
Commit e26b0e2

Author: Jordan Stomps
Message: adding details to NTXentLoss documentation
1 parent: 6b4d4d1

File tree

1 file changed (+17, −0 lines)


docs/losses.md

Lines changed: 17 additions & 0 deletions
@@ -787,6 +787,13 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
 - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf){target=_blank}
 - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank}
 - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank}
+
+In the equation below, the loss is computed for each positive pair, `k_+`, in a batch, normalized over all pairs in the batch, `k_i in K`.
+For each `embeddings` with `labels` and `ref_emb` with `ref_labels`, a positive pair `(embeddings[i], ref_emb[j])` is defined when `labels[i] == ref_labels[j]`.
+When `embeddings` and `ref_emb` are augmented versions of each other (e.g. SimCLR), `labels[i] == ref_labels[i]` (see [SelfSupervisedLoss](losses.md#selfsupervisedloss)).
+Note that multiple positive pairs can exist if the same label is present multiple times in `labels` and/or `ref_labels`.
+
+Instead of passing labels (`NTXentLoss(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)`), an `indices_tuple` can be passed (see [`pytorch_metric_learning.utils.loss_and_miner_utils.get_all_pairs_indices`](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/utils/loss_and_miner_utils.py)).
 ```python
 losses.NTXentLoss(temperature=0.07, **kwargs)
 ```
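
As a rough illustration of the call pattern documented in the hunk above, here is a minimal sketch. It is not part of this commit; the tensor shapes and label values are made up for the example, but the `ref_emb`/`ref_labels` keyword arguments are the ones named in the added text:

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.NTXentLoss(temperature=0.07)

# Hypothetical batches: 4 embeddings of dimension 128 each.
embeddings = torch.randn(4, 128)
ref_emb = torch.randn(4, 128)  # e.g. embeddings of augmented views
labels = torch.tensor([0, 0, 1, 2])
ref_labels = torch.tensor([0, 0, 1, 2])

# Positive pairs are (embeddings[i], ref_emb[j]) wherever labels[i] == ref_labels[j].
loss = loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
```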
@@ -799,6 +806,16 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
 
 * **temperature**: This is tau in the above equation. The MoCo paper uses 0.07, while SimCLR uses 0.5.
 
+**Other info:**
+
+For example, consider `labels = ref_labels = [0, 0, 1, 2]`. Two losses will be computed:
+
+* Positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2]` and `[0, 3]`.
+
+* Positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2]` and `[1, 3]`.
+
+Labels `1` and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
+
 **Default distance**:
 
 - [```CosineSimilarity()```](distances.md#cosinesimilarity)
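
To make the worked example above concrete, here is a small sketch (again, not part of the commit) that enumerates the pairs with `get_all_pairs_indices` and passes them as an `indices_tuple`; the pair indices in the comments follow the description in the added text:

```python
import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

labels = torch.tensor([0, 0, 1, 2])

# get_all_pairs_indices returns (anchor1, positives, anchor2, negatives).
a1, p, a2, n = lmu.get_all_pairs_indices(labels)
# Positive pairs: (0, 1) and (1, 0).
# Negative pairs include (0, 2), (0, 3), (1, 2), (1, 3), and (2, 3);
# per the note above, NTXentLoss only uses negatives whose anchor also
# has a positive, so (2, 3) does not contribute to the loss.

embeddings = torch.randn(4, 128)
loss_fn = losses.NTXentLoss(temperature=0.07)
loss = loss_fn(embeddings, labels, indices_tuple=(a1, p, a2, n))
```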
