Triplet Loss — Advanced Intro

What are the advantages of Triplet Loss over Contrastive loss, and how to efficiently implement it?

Paths followed by moving points under Triplet Loss. Image by author.

Triplet Loss was first introduced in FaceNet: A Unified Embedding for Face Recognition and Clustering in 2015, and it has been one of the most popular loss functions for supervised similarity or metric learning ever since. In its simplest explanation, Triplet Loss encourages that dissimilar pairs be distant from any similar pairs by at least a certain margin value. Mathematically, the loss value can be calculated as L=max(d(a, p) - d(a, n) + m, 0), where:

  • p, i.e., positive, is a sample that has the same label as a, i.e., anchor,
  • n, i.e., negative, is another sample that has a label different from a,
  • d is a distance function measuring how far apart two samples are,
  • and m is a margin value that keeps negative samples at least that far away from positive samples.

The paper uses Euclidean distance, but it is equally valid to use any other distance metric, e.g., cosine distance.
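To make the formula concrete, here is a minimal sketch of it in PyTorch for a single batch of pre-formed (anchor, positive, negative) triplets; the function name and default margin are mine, and PyTorch also ships this loss out of the box as torch.nn.TripletMarginLoss:

```python
import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Euclidean distances d(a, p) and d(a, n), computed row-wise over the batch
    d_ap = F.pairwise_distance(anchor, positive, p=2)
    d_an = F.pairwise_distance(anchor, negative, p=2)
    # L = max(d(a, p) - d(a, n) + m, 0), averaged over the batch
    return torch.relu(d_ap - d_an + margin).mean()
```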

The function has a learning objective that can be visualized as in the following:

Triplet Loss objective. Image by author.

Notice that Triplet Loss does not have the side effect of pushing the anchor and positive samples into the very same point of the vector space, as Contrastive Loss does. This lets Triplet Loss tolerate some intra-class variance, unlike Contrastive Loss, which essentially forces the distance between an anchor and any positive to 0. In other words, Triplet Loss allows clusters to stretch enough to include outliers while still ensuring a margin between samples from different clusters, i.e., negative pairs.

Additionally, Triplet Loss is less greedy. Unlike Contrastive Loss, it is already satisfied when dissimilar samples are easily distinguishable from similar ones, and it does not change the distances within a positive cluster if there is no interference from negative examples. This is because Triplet Loss only tries to ensure a margin between the distances of negative pairs and the distances of positive pairs. Contrastive Loss, on the other hand, takes the margin into account only when comparing dissimilar pairs and does not care at all where similar pairs are at that moment. As a result, Contrastive Loss may reach a local minimum earlier, while Triplet Loss may continue to organize the vector space into a better state.

Let’s demonstrate how two loss functions organize the vector space by animations. For simpler visualization, the vectors are represented by points in a 2-dimensional space, and they are selected randomly from a normal distribution.

Animation that shows how Contrastive Loss moves points in the course of training. Image by author.
Animation that shows how Triplet Loss moves points in the course of training. Image by author.

From the mathematical interpretations of the two loss functions, it is clear that Triplet Loss is theoretically stronger, but Triplet Loss also comes with additional tricks that help it work better. Most importantly, it is typically paired with online triplet mining strategies, i.e., automatically forming the most useful triplets.

Why triplet mining matters

The formulation of Triplet Loss demonstrates that it works on three objects at a time:

  1. anchor,
  2. positive - a sample that has the same label as the anchor,
  3. and negative - a sample with a different label from the anchor and the positive.

In a naive implementation, we could form such triplets of samples at the beginning of each epoch and then feed batches of these triplets to the model throughout that epoch. This is called the “offline strategy.” However, it would not be very efficient, for several reasons:

  • It needs to pass 3n samples through the model to get loss values for only n triplets.
  • Not all of these triplets will be useful for the model to learn anything, i.e., yield a positive loss value.
  • Even if we form “useful” triplets at the beginning of each epoch with one of the methods that I will be implementing in this series, they may become “useless” at some point in the epoch as the model weights will be constantly updated.

Instead, we can get a batch of n samples and their associated labels, and form triplets on the fly. This is called the “online strategy.” Normally, this gives n^3 possible triplets, and although only a subset of them will actually be valid, we still obtain a loss value computed from far more triplets than with the offline strategy.

Given a triplet of (a, p, n), it is valid only if:

  1. a and p have the same label,
  2. a and p are distinct samples,
  3. and n has a different label from a and p.

These constraints may seem to require expensive computation with nested loops, but they can be implemented efficiently with tricks such as a distance matrix, masking, and broadcasting. The rest of this series will focus on the implementation of these tricks.

Distance matrix

A distance matrix is a matrix of shape (n, n) that holds the distance values between all possible pairs formed from the items of two n-sized collections. This matrix can be used to vectorize calculations that would otherwise need inefficient loops. Its calculation can be optimized as well, and we will implement the Euclidean Distance Matrix Trick (PDF) explained by Samuel Albanie. You may want to read this three-page document for the full intuition behind the trick, but a brief explanation is as follows:

  1. Calculate the dot product of two collections of vectors, e.g., embeddings in our case.
  2. Extract the diagonal from this matrix, which holds the squared Euclidean norm of each embedding.
  3. Calculate the squared Euclidean distance matrix based on the following equation: ||a - b||^2 = ||a||^2 - 2 ⟨a, b⟩ + ||b||^2
  4. Get the square root of this matrix for non-squared distances.

We will implement it in PyTorch, so let’s start with imports.
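The original code blocks are not reproduced here, so what follows is a minimal sketch of the trick, assuming the embeddings arrive as a (batch_size, embedding_dim) float tensor; the helper name euclidean_distance_matrix is my own.

```python
import torch

def euclidean_distance_matrix(x: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """x: tensor of shape (batch_size, embedding_dim).
    Returns a (batch_size, batch_size) matrix of Euclidean distances."""
    # Step 1: dot products between all pairs of embeddings, shape (batch_size, batch_size)
    dot_product = torch.mm(x, x.t())
    # Step 2: the diagonal holds the squared Euclidean norm of each embedding
    squared_norm = torch.diag(dot_product)
    # Step 3: ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)
    # Numerical errors can produce small negative values; clamp before the square root
    distance_matrix = torch.clamp(distance_matrix, min=0.0)
    # Step 4: take the square root; add eps where the distance is 0 to keep gradients finite
    zero_mask = (distance_matrix == 0.0).float()
    distance_matrix = torch.sqrt(distance_matrix + zero_mask * eps)
    # Undo the eps contribution for the entries that were exactly 0
    distance_matrix = distance_matrix * (1.0 - zero_mask)
    return distance_matrix
```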

Invalid triplet masking

Now that we can compute a distance matrix for all possible pairs of embeddings in a batch, we can apply broadcasting to enumerate the distance differences for all possible triplets and represent them in a tensor of shape (batch_size, batch_size, batch_size). However, only a subset of these n^3 triplets is actually valid, as mentioned earlier, so we need a corresponding mask to compute the loss value correctly. We will implement such a helper function in three steps, with a sketch after the list:

  1. Compute a mask for distinct indices, i.e., i != j, i != k, and j != k.
  2. Compute a mask for valid anchor-positive-negative triplets, i.e., labels[i] == labels[j] and labels[j] != labels[k].
  3. Combine two masks.
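A possible sketch of this helper, assuming integer class labels of shape (batch_size,) and the (hypothetical) name get_triplet_mask:

```python
import torch

def get_triplet_mask(labels: torch.Tensor) -> torch.Tensor:
    """labels: tensor of shape (batch_size,).
    Returns a boolean mask of shape (batch_size, batch_size, batch_size) where
    mask[i, j, k] is True iff (i, j, k) is a valid (anchor, positive, negative) triplet."""
    # Step 1: mask for distinct indices, i.e., i != j, i != k, j != k
    indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)   # shape (B, B, 1)
    i_not_equal_k = indices_not_equal.unsqueeze(1)   # shape (B, 1, B)
    j_not_equal_k = indices_not_equal.unsqueeze(0)   # shape (1, B, B)
    distinct_indices = i_not_equal_j & i_not_equal_k & j_not_equal_k
    # Step 2: mask for valid labels, i.e., labels[i] == labels[j] and labels[j] != labels[k]
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)  # shape (B, B)
    i_equal_j = labels_equal.unsqueeze(2)
    i_equal_k = labels_equal.unsqueeze(1)
    valid_labels = i_equal_j & ~i_equal_k
    # Step 3: combine the two masks
    return distinct_indices & valid_labels
```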

Batch-all strategy for online triplet mining

Now we are ready to actually implement Triplet Loss itself. Triplet Loss involves several strategies for forming or selecting triplets, and the simplest one is to use all valid triplets that can be formed from the samples in a batch. This can be achieved in four easy steps thanks to the utility functions we’ve already implemented:

  1. Get a distance matrix of all possible pairs that can be formed from embeddings in a batch.
  2. Apply broadcasting to this matrix to compute loss values for all possible triplets.
  3. Set loss values of invalid or easy triplets to 0.
  4. Average the remaining positive values to return a scalar loss.

I will start by implementing this strategy, and more complex ones will follow as separate posts.
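As a minimal sketch of the batch-all strategy following these four steps, reusing the euclidean_distance_matrix and get_triplet_mask helpers sketched above (the class name and default margin are my own):

```python
import torch
import torch.nn as nn

class BatchAllTripletLoss(nn.Module):
    """Batch-all online triplet mining over a batch of embeddings and labels."""
    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Step 1: distance matrix of all possible pairs, shape (B, B)
        distance_matrix = euclidean_distance_matrix(embeddings)
        # Step 2: broadcast to loss values for all (anchor, positive, negative) index triplets
        anchor_positive_dists = distance_matrix.unsqueeze(2)  # d(a, p), shape (B, B, 1)
        anchor_negative_dists = distance_matrix.unsqueeze(1)  # d(a, n), shape (B, 1, B)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin  # (B, B, B)
        # Step 3: zero out invalid triplets, then clamp easy (negative-loss) triplets to 0
        mask = get_triplet_mask(labels)
        triplet_loss = torch.relu(triplet_loss * mask.float())
        # Step 4: average over the triplets that still contribute a positive loss
        num_positive_losses = (triplet_loss > 1e-16).float().sum()
        return triplet_loss.sum() / (num_positive_losses + 1e-16)
```

In a training loop, this module would be used like any other PyTorch loss, e.g., loss = criterion(model(batch), labels).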

Conclusion

I mentioned that Triplet Loss differs from Contrastive Loss not only mathematically but also in its sample selection strategies, and in this post I implemented the batch-all strategy for online triplet mining efficiently by using several tricks. There are other, more complicated strategies such as batch-hard and batch-semihard mining, but their implementations, and discussions of the tricks I used for efficiency in this post, deserve separate posts of their own. Future posts will cover these topics, along with additional discussion of tricks to avoid vector collapsing and to control intra-class and inter-class variance. Meanwhile, you can join Qdrant’s Discord server to discuss, learn more, and ask questions about metric learning.



Triplet Loss — Advanced Intro was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.


