Loss Overview - Sentence Transformers documentation

Warning

To train a SparseEncoder, you need either SpladeLoss, CachedSpladeLoss, or CSRLoss, depending on the architecture. These are wrapper losses that add regularization on top of a main loss function, which must be provided as a parameter. SpladeLoss and CachedSpladeLoss add sparsity regularization, while CSRLoss adds a reconstruction loss. The only loss that can be used on its own is SparseMSELoss, because it performs embedding-level distillation: it ensures sparsity by directly copying the teacher's sparse embedding.

Sparse-specific Loss Functions

SPLADE Loss

The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to balance effectiveness and efficiency:

  1. Main loss: Supports all the losses from the Loss Table and Distillation sections. SparseMultipleNegativesRankingLoss, SparseMarginMSELoss and SparseDistillKLDivLoss are the most commonly used.
  2. Regularization loss: FlopsLoss is used by default to control sparsity, but custom regularizers are also supported.
    • query_regularizer and document_regularizer can be set to any custom regularization loss.
    • query_regularizer_threshold and document_regularizer_threshold control sparsity strictness separately for queries and documents. If an embedding has fewer active (non-zero) dimensions than the threshold, its regularization loss is set to zero.
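The regularization term above can be sketched in plain Python. This is a minimal illustration and not the library's implementation. `flops_loss` and `splade_total_loss` are hypothetical names. The FLOPS penalty is taken as the sum, over dimensions, of the squared mean absolute activation across the batch. The threshold is assumed to work by masking out embeddings that have too few active dimensions. The weights `query_weight` and `doc_weight` are also assumptions, standing in for whatever regularizer weights you configure.

```python
from typing import List, Optional

Vector = List[float]


def count_active(vector: Vector) -> int:
    """Number of non-zero (active) dimensions in a sparse embedding."""
    return sum(1 for x in vector if x != 0.0)


def flops_loss(batch: List[Vector], threshold: Optional[int] = None) -> float:
    """Hypothetical FLOPS-style regularizer over a batch of sparse embeddings.

    Sums, over every dimension, the square of the mean absolute activation.
    If `threshold` is set, embeddings with fewer than `threshold` active
    dimensions are zeroed out first (an assumed masking rule), so they add
    nothing to the penalty.
    """
    if not batch:
        return 0.0
    if threshold is not None:
        batch = [
            v if count_active(v) >= threshold else [0.0] * len(v)
            for v in batch
        ]
    n = len(batch)
    total = 0.0
    for j in range(len(batch[0])):
        mean_abs = sum(abs(v[j]) for v in batch) / n
        total += mean_abs ** 2
    return total


def splade_total_loss(
    main_loss: float,
    queries: List[Vector],
    docs: List[Vector],
    query_weight: float,
    doc_weight: float,
    query_threshold: Optional[int] = None,
    doc_threshold: Optional[int] = None,
) -> float:
    """Main loss plus separately weighted query and document regularizers."""
    return (
        main_loss
        + query_weight * flops_loss(queries, query_threshold)
        + doc_weight * flops_loss(docs, doc_threshold)
    )


docs = [[1.0, 0.0, 2.0], [0.0, 0.0, 2.0]]
print(flops_loss(docs))               # 0.5**2 + 0**2 + 2.0**2 = 4.25
print(flops_loss(docs, threshold=2))  # second doc masked: 0.5**2 + 1.0**2 = 1.25
```

The penalty squares the per-dimension mean. As a result, a dimension that is active in many embeddings of a batch costs far more than one that is rarely used. This pushes the model toward fewer active dimensions, spread more evenly.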

Cached SPLADE Loss

The CachedSpladeLoss is a variant of the SPLADE loss that adopts GradCache. GradCache allows much larger batch sizes without a corresponding increase in GPU memory usage. It does this by computing and caching loss gradients in mini-batches, at the cost of some extra computation.

Main losses that use in-batch negatives, primarily SparseMultipleNegativesRankingLoss, benefit heavily from larger batch sizes. Larger batches provide more negatives and therefore a stronger training signal.
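The following pure-Python sketch shows why batch size matters. It implements an in-batch negatives ranking loss in the spirit of SparseMultipleNegativesRankingLoss. Each anchor must pick out its own positive among all positives in the batch, so a batch of size n gives every anchor n - 1 negatives for free. The function name, the plain dot-product score and the `scale` parameter are assumptions for illustration, and the library's defaults may differ. GradCache itself, which is what makes large batches affordable, is not shown.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def in_batch_negatives_loss(
    anchors: List[Vector], positives: List[Vector], scale: float = 1.0
) -> float:
    """Mean cross-entropy where anchor i should score highest with positives[i].

    Every other positives[j] (j != i) acts as a negative for anchor i, so the
    number of negatives per anchor grows with the batch size.
    """
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * dot(anchor, p) for p in positives]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[1.0, 0.0], [0.0, 1.0]]
print(in_batch_negatives_loss(anchors, positives))        # lower: correct pairs score highest
print(in_batch_negatives_loss(anchors, positives[::-1]))  # higher: pairs are mismatched
```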

CSR Loss

If you are using the SparseAutoEncoder module, then you have to use the CSRLoss (Contrastive Sparse Representation Loss). It combines two components:

  1. Main loss: Supports all the losses from the Loss Table and Distillation sections. The CSR paper uses SparseMultipleNegativesRankingLoss.
  2. Reconstruction loss: CSRReconstructionLoss ensures that the sparse representation can faithfully reconstruct the original dense embeddings.
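The following toy, pure-Python sketch illustrates the two CSR components. It assumes a dense embedding and a latent activation vector, which is sparsified by keeping only its `k` largest entries. A linear decoder then maps the sparse code back to the dense space. The reconstruction term is a mean squared error, added to the main loss with a hypothetical `recon_weight`. The actual CSRReconstructionLoss may combine several reconstruction terms, so this sketch illustrates the idea rather than the library's formula.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def top_k(latent: Vector, k: int) -> Vector:
    """Keep the k largest activations and zero out the rest."""
    keep = set(sorted(range(len(latent)), key=lambda j: latent[j], reverse=True)[:k])
    return [x if j in keep else 0.0 for j, x in enumerate(latent)]


def decode(sparse: Vector, decoder: Matrix) -> Vector:
    """Linear decoder: one row of weights per dense output dimension."""
    return [sum(w * s for w, s in zip(row, sparse)) for row in decoder]


def mse(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def csr_total_loss(
    main_loss: float,
    dense: Vector,
    latent: Vector,
    decoder: Matrix,
    k: int,
    recon_weight: float = 1.0,
) -> float:
    """Main loss plus how badly the top-k sparse code reconstructs `dense`."""
    reconstruction = decode(top_k(latent, k), decoder)
    return main_loss + recon_weight * mse(reconstruction, dense)


dense = [1.0, 3.0]
latent = [2.0, 0.5, 1.5]
decoder = [[0.5, 0.0, 0.0], [0.0, 0.0, 2.0]]
print(csr_total_loss(0.5, dense, latent, decoder, k=2))  # perfect reconstruction: 0.5
print(csr_total_loss(0.5, dense, latent, decoder, k=1))  # 0.5 + (0**2 + 3**2) / 2 = 5.0
```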

Loss Table

Loss functions play a critical role in the performance of your fine-tuned model. Sadly, there is no “one size fits all” loss function. Ideally, this table should help narrow down your choice of loss function(s) by matching them to your data formats.

Note

You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, (sentence_A, sentence_B) pairs with class labels can be converted into (anchor, positive, negative) triplets by sampling sentences with the same or different classes.
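The sampling idea from this note can be sketched in a few lines of Python. For simplicity, the sketch assumes that each sentence carries its own class label rather than each pair. `labeled_to_triplets` is a hypothetical helper. For every anchor, it draws a positive from the same class and a negative from a different class.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def labeled_to_triplets(
    examples: List[Tuple[str, str]], seed: int = 0
) -> List[Tuple[str, str, str]]:
    """Turn (sentence, class_label) examples into (anchor, positive, negative) triplets."""
    rng = random.Random(seed)
    by_label: Dict[str, List[str]] = defaultdict(list)
    for sentence, label in examples:
        by_label[label].append(sentence)

    triplets = []
    for anchor, label in examples:
        positives = [s for s in by_label[label] if s != anchor]
        other_labels = [other for other in by_label if other != label]
        if not positives or not other_labels:
            continue  # no valid positive or negative for this anchor, so skip it
        positive = rng.choice(positives)
        negative = rng.choice(by_label[rng.choice(other_labels)])
        triplets.append((anchor, positive, negative))
    return triplets


examples = [
    ("How do I reset my password?", "account"),
    ("I forgot my login credentials.", "account"),
    ("When will my order arrive?", "shipping"),
    ("Where is my package?", "shipping"),
]
for triplet in labeled_to_triplets(examples):
    print(triplet)
```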

Legend: Loss functions marked with are commonly recommended default choices.

Inputs | Labels | Appropriate Loss Functions
(anchor, positive) pairs | none | SparseMultipleNegativesRankingLoss
(sentence_A, sentence_B) pairs | float similarity score between 0 and 1 | SparseCoSENTLoss, SparseAnglELoss, SparseCosineSimilarityLoss
(anchor, positive, negative) triplets | none | SparseMultipleNegativesRankingLoss, SparseTripletLoss
(anchor, positive, negative_1, ..., negative_n) | none | SparseMultipleNegativesRankingLoss
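For the float-score row, CoSENT-style losses compare pairs against each other instead of regressing each score directly. The following pure-Python sketch shows the CoSENT objective: log(1 + sum of exp(scale * (cos_j - cos_i))), taken over all pairs (i, j) where the gold score of i exceeds that of j. The function name and the default scale of 20 are assumptions, and the library's SparseCoSENTLoss may compute the similarity differently.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # an all-zero sparse embedding has no direction
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def cosent_loss(
    pairs: List[Tuple[Vector, Vector]], gold: List[float], scale: float = 20.0
) -> float:
    """Penalize every pair i that should be more similar than pair j but is not."""
    sims = [scale * cosine(a, b) for a, b in pairs]
    # Start with 0.0 so the log-sum-exp below equals log(1 + sum(exp(...))).
    terms = [0.0]
    for i in range(len(pairs)):
        for j in range(len(pairs)):
            if gold[i] > gold[j]:
                terms.append(sims[j] - sims[i])
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


pairs = [([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [0.0, 1.0])]  # cosines 1.0 and 0.0
print(cosent_loss(pairs, gold=[1.0, 0.0]))  # ranking agrees with gold: close to 0
print(cosent_loss(pairs, gold=[0.0, 1.0]))  # ranking is reversed: close to 20
```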

Distillation

These loss functions are specifically designed for distilling knowledge from one model (the teacher) into another (the student). Distillation is commonly used when training sparse embedding models.

Texts | Labels | Appropriate Loss Functions
sentence | model sentence embeddings | SparseMSELoss
(sentence_1, sentence_2, ..., sentence_N) | model sentence embeddings | SparseMSELoss
(query, passage_one, passage_two) | gold_sim(query, passage_one) - gold_sim(query, passage_two) | SparseMarginMSELoss
(query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive) - gold_sim(query, negative_i) for i in 1..n] | SparseMarginMSELoss
(query, positive, negative) | [gold_sim(query, positive), gold_sim(query, negative)] | SparseDistillKLDivLoss, SparseMarginMSELoss
(query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive), gold_sim(query, negative_1), ..., gold_sim(query, negative_n)] | SparseDistillKLDivLoss, SparseMarginMSELoss
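The two score-based distillation objectives in this table can be sketched in pure Python. `margin_mse` compares the student's score margin (positive minus negative) with the teacher's. `distill_kl` compares the softmax distributions over one query's candidates using KL(teacher || student). The function names, the per-batch averaging and the `temperature` handling are all assumptions for illustration. The library's SparseMarginMSELoss and SparseDistillKLDivLoss may normalize or rescale differently.

```python
import math
from typing import List


def margin_mse(
    student_pos: List[float],
    student_neg: List[float],
    teacher_pos: List[float],
    teacher_neg: List[float],
) -> float:
    """Mean squared difference between student and teacher margins."""
    squared_errors = [
        ((sp - sn) - (tp - tn)) ** 2
        for sp, sn, tp, tn in zip(student_pos, student_neg, teacher_pos, teacher_neg)
    ]
    return sum(squared_errors) / len(squared_errors)


def log_softmax(scores: List[float], temperature: float) -> List[float]:
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]


def distill_kl(
    student_scores: List[float], teacher_scores: List[float], temperature: float = 1.0
) -> float:
    """KL(teacher || student) over one query's [positive, negative_1, ...] scores."""
    log_p = log_softmax(teacher_scores, temperature)
    log_q = log_softmax(student_scores, temperature)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Only the margin matters: the student scores are 10 higher, but the gaps match.
print(margin_mse([15.0], [13.0], [5.0], [3.0]))  # 0.0
print(margin_mse([5.0], [3.0], [4.0], [1.0]))    # (2 - 3)**2 = 1.0
print(distill_kl([1.0, 3.0], [0.0, 2.0]))        # ~0: same relative scores
```

Both objectives depend only on relative scores. A student does not need to match the teacher's absolute score scale, only how the teacher ranks and separates the candidates.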

Commonly used Loss Functions

In practice, not all loss functions get used equally often. The most common scenarios are:

Custom Loss Functions

Advanced users can create and train with their own loss functions. Custom loss functions only have a few requirements:

To get full support with the automatic model card generation, you may also wish to implement:

Consider inspecting existing loss functions to get a feel for how loss functions are commonly implemented.