moco_loss

selfeeg.losses.losses.moco_loss(q: Tensor, k: Tensor, queue: Tensor = None, projections_norm: bool = True, temperature: float = 0.07) -> Tensor

Simple implementation of the MoCo loss function [moco2]. It is the InfoNCE loss, with the dot product as the similarity measure and the memory bank entries as negative samples. If no queue (memory bank) is given, the MoCo v3 [moco3] loss is computed instead, where the other keys in the batch act as negatives. Note that the full MoCo v3 loss is symmetric: it is obtained by calling this function twice, swapping the roles of the two augmented views (i.e., with different q and k tensors), and summing the two results.
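To make the two branches concrete, here is a minimal PyTorch sketch of the computation described above. It is not selfeeg's source code: the helper name `moco_loss_sketch`, the normalization of the queue along its feature dimension, and the `detach()` on the keys are assumptions, and any extra scaling the library may apply is omitted, so its numbers need not match the values in the example below.

```python
import torch
import torch.nn.functional as F


def moco_loss_sketch(q, k, queue=None, projections_norm=True, temperature=0.07):
    """Hypothetical re-implementation of the InfoNCE loss described above.

    q, k:  (N, C) queries and keys; row i of k is the positive for row i of q.
    queue: (C, K) memory bank of negatives, or None for the MoCo v3 variant.
    """
    if projections_norm:
        # L2-normalize so every dot product becomes a cosine similarity
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)
        if queue is not None:
            queue = F.normalize(queue, dim=0)  # assumption: one stored key per column
    k = k.detach()  # assumption: keys come from the momentum encoder, no gradient
    n = q.shape[0]
    if queue is not None:
        # Memory-bank branch: 1 positive logit + K negative logits per query
        l_pos = (q * k).sum(dim=1, keepdim=True)  # (N, 1)
        l_neg = q @ queue                         # (N, K)
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(n, dtype=torch.long, device=q.device)  # positive at index 0
    else:
        # MoCo v3 branch: the other keys of the batch are the negatives
        logits = (q @ k.T) / temperature           # (N, N)
        labels = torch.arange(n, device=q.device)  # positives on the diagonal
    return F.cross_entropy(logits, labels)


torch.manual_seed(0)
q1, q2 = torch.randn(64, 32), torch.randn(64, 32)  # queries from two augmented views
k1, k2 = torch.randn(64, 32), torch.randn(64, 32)  # keys from the same two views
queue = torch.randn(32, 128)

loss_v2 = moco_loss_sketch(q1, k2, queue)
# Symmetrized MoCo v3 loss: swap the roles of the two views and sum both terms
loss_v3 = moco_loss_sketch(q1, k2) + moco_loss_sketch(q2, k1)
```

Placing the positive logit in column 0 lets the memory-bank branch reuse a standard cross-entropy with all-zero targets, while the in-batch branch puts each positive on the diagonal of the N x N similarity matrix.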

Parameters:
  • q (torch.Tensor) – 2D (NxC) Tensor with the queries, i.e., the predictor or projection_head output for one augmented version of the batch. N = batch size, C = number of features.

  • k (torch.Tensor) – 2D (NxC) Tensor with the keys, i.e., the projection_head output for one augmented version of the batch; these are the samples that will be added to the memory bank. N = batch size, C = number of features.

  • queue (torch.Tensor, optional) –

    2D (CxK) Tensor with the memory bank, i.e., a collection of projection_head outputs from previous augmented batches, which act as negative samples. C = number of features, K = memory bank size.

    Default = None

  • projections_norm (bool, optional) –

    Whether to normalize the projections or not.

    Default = True

  • temperature (float, optional) –

    Temperature coefficient that scales the similarity logits before the InfoNCE (NT-Xent) cross-entropy.

    Default = 0.07
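Since `moco_loss` only reads `queue` and returns the loss, keeping the memory bank up to date is presumably left to the training loop. A simple first-in-first-out scheme for a (CxK) bank is sketched below; `update_queue` is a hypothetical helper, not part of selfeeg, and the original MoCo code uses a circular buffer with a pointer rather than concatenation.

```python
import torch


@torch.no_grad()
def update_queue(queue: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical FIFO update of a (C, K) memory bank with an (N, C) batch of keys.

    The N newest keys are appended as columns and the N oldest columns are
    dropped, so the memory bank size K stays constant.
    """
    n, c = k.shape
    if queue.shape[0] != c or n > queue.shape[1]:
        raise ValueError("expected a queue of shape (C, K) with K >= N")
    return torch.cat([queue[:, n:], k.T], dim=1)


C, K, N = 32, 128, 64
queue = torch.randn(C, K)  # memory bank: one stored key per column
k = torch.randn(N, C)      # keys of the current batch (the same k passed to the loss)
new_queue = update_queue(queue, k)
print(new_queue.shape)     # torch.Size([32, 128])
```

Updating the bank only after the loss has been computed means the current keys act as negatives for later batches, never for their own queries.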

Returns:

loss (torch.Tensor) – The calculated loss.

References

[moco2]

K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick, “Momentum contrast for unsupervised visual representation learning,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 9729–9738, 2020.

[moco3]

X. Chen, H. Fan, R. Girshick, and K. He, “Improved baselines with momentum contrastive learning,” arXiv preprint arXiv:2003.04297, 2020.

Example

>>> import torch
>>> from selfeeg import losses
>>> torch.manual_seed(1234)
>>> q = torch.randn(64, 32)
>>> k = torch.randn(64, 32)
>>> queue = torch.randn(32, 128)
>>> loss = losses.moco_loss(q, k, queue)
>>> print(loss) # will return 17.1668
>>> loss = losses.moco_loss(q, k)
>>> print(loss) # will return 1.4349