TripletMarginWithDistanceLoss - PyTorch 2.7 documentation

class torch.nn.TripletMarginWithDistanceLoss(*, distance_function=None, margin=1.0, swap=False, reduction='mean')

Creates a criterion that measures the triplet loss given input tensors $a$, $p$, and $n$ (representing anchor, positive, and negative examples, respectively), and a nonnegative, real-valued function (“distance function”) used to compute the relationship between the anchor and positive example (“positive distance”) and the anchor and negative example (“negative distance”).

The unreduced loss (i.e., with reduction set to 'none') can be described as:

$$\ell(a, p, n) = L = \{l_1, \dots, l_N\}^\top, \quad l_i = \max\{d(a_i, p_i) - d(a_i, n_i) + \mathrm{margin}, 0\}$$

where $N$ is the batch size; $d$ is a nonnegative, real-valued function quantifying the closeness of two tensors, referred to as the distance_function; and $\mathrm{margin}$ is a nonnegative margin representing the minimum difference between the positive and negative distances that is required for the loss to be 0. Each input tensor holds $N$ examples along its first (batch) dimension and can otherwise be of any shape that the distance function can handle.
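To make the formula concrete, the sketch below computes $L$ by hand and compares it with the module's unreduced output. It assumes Euclidean distance as $d$ (written with torch.linalg.vector_norm), random inputs, and a batch of $N = 4$; the helper name d is chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small batch of N = 4 triplets with 8-dimensional embeddings.
N, D = 4, 8
anchor = torch.randn(N, D)
positive = torch.randn(N, D)
negative = torch.randn(N, D)
margin = 1.0

# Hypothetical d(x, y): Euclidean distance per row, one value per triplet (shape (N,)).
def d(x, y):
    return torch.linalg.vector_norm(x - y, dim=1)

# l_i = max{d(a_i, p_i) - d(a_i, n_i) + margin, 0}, computed elementwise.
manual = torch.clamp(d(anchor, positive) - d(anchor, negative) + margin, min=0.0)

# With reduction='none' the module returns the vector L itself.
criterion = nn.TripletMarginWithDistanceLoss(
    distance_function=d, margin=margin, reduction="none"
)
L = criterion(anchor, positive, negative)

print(L.shape)                    # torch.Size([4])
print(torch.allclose(L, manual))  # True
```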

If reduction is not 'none' (default 'mean'), then:

$$\ell(a, p, n) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{`mean';} \\ \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases}$$
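A quick check of the two reductions, assuming random inputs and the default distance function: 'mean' averages the $N$ per-triplet losses and 'sum' adds them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
anchor, positive, negative = (torch.randn(5, 16) for _ in range(3))

# Same (default) distance function evaluated under each reduction mode.
losses = {
    r: nn.TripletMarginWithDistanceLoss(reduction=r)(anchor, positive, negative)
    for r in ("none", "mean", "sum")
}

L = losses["none"]  # per-triplet losses, shape (5,)
print(torch.allclose(losses["mean"], L.mean()))  # True
print(torch.allclose(losses["sum"], L.sum()))    # True
```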

See also TripletMarginLoss, which computes the triplet loss for input tensors using the $l_p$ distance as the distance function.
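The relationship to TripletMarginLoss can be checked numerically: with the $l_2$ distance supplied explicitly through nn.PairwiseDistance, both criteria should agree. This sketch assumes both classes keep their default eps (1e-6) for the pairwise distance and uses random inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
anchor, positive, negative = (torch.randn(8, 32) for _ in range(3))

# Specialized criterion: l_p distance with p = 2.
builtin = nn.TripletMarginLoss(margin=1.0, p=2)
# General criterion given the same l_2 distance explicitly.
general = nn.TripletMarginWithDistanceLoss(
    distance_function=nn.PairwiseDistance(p=2), margin=1.0
)

loss_builtin = builtin(anchor, positive, negative)
loss_general = general(anchor, positive, negative)
print(torch.allclose(loss_builtin, loss_general))  # True
```

TripletMarginWithDistanceLoss is the more general tool; TripletMarginLoss is the convenient choice when an $l_p$ distance is all you need.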

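Parameters

- distance_function (Callable, optional): A nonnegative, real-valued function that quantifies the closeness of two tensors. If not specified, nn.PairwiseDistance will be used. Default: None
- margin (float, optional): A nonnegative margin representing the minimum difference between the positive and negative distances required for the loss to be 0. Larger margins penalize cases where the negative examples are not distant enough from the anchors, relative to the positives. Default: 1.0
- swap (bool, optional): Whether to use the distance swap described in the paper Learning shallow convolutional feature descriptors with triplet losses by V. Balntas et al. If True, and if the positive example is closer to the negative example than the anchor is, swaps the positive example and the anchor in the loss computation. Default: False
- reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied; 'mean': the sum of the output will be divided by the number of elements in the output; 'sum': the output will be summed. Default: 'mean'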
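The swap flag in the class signature enables the distance swap from the referenced Balntas et al. paper. The sketch below reproduces its effect by hand, assuming nn.PairwiseDistance as the distance and the default margin of 1.0: for each triplet, the negative distance becomes the smaller of $d(a_i, n_i)$ and $d(p_i, n_i)$, so the harder of the two negative pairs is used. This is a reconstruction for checking the behavior, not a copy of the library source.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
anchor, positive, negative = (torch.randn(6, 10) for _ in range(3))
dist = nn.PairwiseDistance()

# Module output with the distance swap enabled, one loss per triplet.
swapped = nn.TripletMarginWithDistanceLoss(
    distance_function=dist, swap=True, reduction="none"
)(anchor, positive, negative)

# Manual version: use min(d(a, n), d(p, n)) as the negative distance.
d_pos = dist(anchor, positive)
d_neg = torch.minimum(dist(anchor, negative), dist(positive, negative))
manual = torch.clamp(d_pos - d_neg + 1.0, min=0.0)

print(torch.allclose(swapped, manual))  # True
```

Because the swap can only shrink the negative distance, each per-triplet loss with swap=True is greater than or equal to the loss without it.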

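Shape:

- Input: $(N, *)$ where $*$ represents any number of additional dimensions as supported by the distance function.
- Output: A tensor of shape $(N)$ if reduction is 'none', or a scalar otherwise.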
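Because the inputs may have any shape the distance function can handle, a distance that reduces over every non-batch dimension lets the criterion accept, for example, tensors of shape (3, 4, 5). The helper flat_l2 below is hypothetical and assumes a Euclidean distance on the flattened examples.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Inputs of shape (N, *) with N = 3 and * = (4, 5).
anchor, positive, negative = (torch.randn(3, 4, 5) for _ in range(3))

def flat_l2(x, y):
    # Hypothetical distance: flatten all non-batch dims, then take the L2 norm,
    # so the result has one value per example, i.e. shape (N,).
    return torch.linalg.vector_norm((x - y).flatten(start_dim=1), dim=1)

per_triplet = nn.TripletMarginWithDistanceLoss(
    distance_function=flat_l2, reduction="none"
)(anchor, positive, negative)
scalar = nn.TripletMarginWithDistanceLoss(distance_function=flat_l2)(
    anchor, positive, negative
)

print(per_triplet.shape)  # torch.Size([3])
print(scalar.shape)       # torch.Size([])
```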

Examples:

Initialize embeddings

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    embedding = nn.Embedding(1000, 128)
    anchor_ids = torch.randint(0, 1000, (1,))
    positive_ids = torch.randint(0, 1000, (1,))
    negative_ids = torch.randint(0, 1000, (1,))
    anchor = embedding(anchor_ids)
    positive = embedding(positive_ids)
    negative = embedding(negative_ids)

Built-in Distance Function

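    triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
    output = triplet_loss(anchor, positive, negative)
    # retain_graph=True keeps the embedding lookup's graph alive, so the same
    # anchor/positive/negative tensors can be reused in the next examples.
    output.backward(retain_graph=True)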

Custom Distance Function

    def l_infinity(x1, x2):
        # L-infinity distance: the largest absolute coordinate difference per row.
        return torch.max(torch.abs(x1 - x2), dim=1).values

    triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
    output = triplet_loss(anchor, positive, negative)
    output.backward(retain_graph=True)  # keep the graph for the next example

Custom Distance Function (Lambda)

    triplet_loss = nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
    output = triplet_loss(anchor, positive, negative)
    output.backward()

Reference:

V. Balntas et al.: Learning shallow convolutional feature descriptors with triplet losses. https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html