moco_loss
- selfeeg.losses.losses.moco_loss(q: Tensor, k: Tensor, queue: Tensor = None, projections_norm: bool = True, temperature: float = 0.07) → Tensor
Simple implementation of the MoCo loss function [moco2]. It is the InfoNCE loss with the dot product as the similarity measure and the memory bank as the source of negative samples. If no queue (memory bank) is given, the MoCo v3 [moco3] loss calculation is performed instead. Note that the full MoCo v3 loss is obtained by calling the function twice (with different q and k tensors) and summing the results.
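As a rough illustration of the two cases described above, the following sketch shows one way they could be computed in plain PyTorch. The function name `moco_loss_sketch` and its internals are assumptions made for illustration, not the selfeeg source, so its values need not match those returned by `moco_loss` (for instance, the queue is used as given and no extra scaling is applied).

```python
import torch
import torch.nn.functional as F


def moco_loss_sketch(q, k, queue=None, projections_norm=True, temperature=0.07):
    """Hypothetical InfoNCE sketch; NOT the selfeeg implementation."""
    if projections_norm:
        # L2-normalize each sample so the dot product is a cosine similarity
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)
    n = q.shape[0]
    if queue is None:
        # No memory bank: the other keys in the batch act as negatives.
        logits = q @ k.T / temperature                    # N x N
        labels = torch.arange(n, device=q.device)         # positives on the diagonal
    else:
        # Memory bank: one positive logit per sample plus K negatives.
        # Assumption: the queue columns are used as given (already normalized).
        l_pos = (q * k).sum(dim=1, keepdim=True)          # N x 1
        l_neg = q @ queue                                 # N x K
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(n, dtype=torch.long, device=q.device)  # positive in column 0
    return F.cross_entropy(logits, labels)


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k = torch.randn(8, 16), torch.randn(8, 16)
    queue = F.normalize(torch.randn(16, 32), dim=0)
    print(moco_loss_sketch(q, k, queue))   # loss with memory bank
    print(moco_loss_sketch(q, k))          # loss without memory bank
```

In both branches the loss reduces to a cross-entropy over similarity logits; only the source of the negatives changes.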
- Parameters:
q (torch.Tensor) – 2D (NxC) Tensor with the queries, i.e. the predictor or projection_head output for one augmented batch. N = batch size, C = number of features.
k (torch.Tensor) – 2D (NxC) Tensor with the keys, i.e. the projection_head output for one augmented batch, which will be added to the memory bank. N = batch size, C = number of features.
queue (torch.Tensor, optional) –
2D (CxK) Tensor with the memory bank, i.e. a collection of previous augmented batch projection_head outputs which act as negative samples. C = number of features, K = memory bank size.
Default = None
projections_norm (bool, optional) –
Whether to normalize the projections or not.
Default = True
temperature (float, optional) –
Temperature coefficient of the InfoNCE loss.
Default = 0.07
- Returns:
loss (torch.Tensor) – The calculated loss.
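The parameter descriptions state that the keys are added to the memory bank and that the queue is a CxK tensor of earlier keys. Since the function returns only the loss, the update presumably happens in the training loop. A minimal sketch of a first-in, first-out update follows; `update_queue` is a hypothetical helper, not part of the documented API.

```python
import torch


def update_queue(queue: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical FIFO update: newest keys enter, oldest keys are dropped.

    queue: C x K memory bank, k: N x C keys from the current batch.
    """
    bank_size = queue.shape[1]
    new_keys = k.detach().T                      # C x N, no gradient through the bank
    # Put the new keys first, then keep only the first K columns.
    return torch.cat([new_keys, queue], dim=1)[:, :bank_size]


if __name__ == "__main__":
    queue = torch.zeros(32, 128)
    k = torch.randn(64, 32)
    queue = update_queue(queue, k)
    print(queue.shape)                           # the bank size is unchanged
```

Detaching the keys reflects the usual MoCo practice of not backpropagating through the memory bank.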
References
[moco2] K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick, “Momentum contrast for unsupervised visual representation learning,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9729–9738, 2020.
[moco3] X. Chen, H. Fan, R. Girshick, and K. He, “Improved baselines with momentum contrastive learning,” arXiv preprint arXiv:2003.04297, 2020.
Example
>>> import torch
>>> from selfeeg import losses
>>> torch.manual_seed(1234)
>>> q = torch.randn(64, 32)
>>> k = torch.randn(64, 32)
>>> queue = torch.randn(32, 128)
>>> loss = losses.moco_loss(q, k, queue)
>>> print(loss)  # will return 17.1668
>>> loss = losses.moco_loss(q, k)
>>> print(loss)  # will return 1.4349
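The description notes that the full MoCo v3 loss is obtained by calling the function twice with different q and k tensors and summing the results. A minimal sketch of that pattern is given below. The helper `symmetric_loss` and the `stand_in_loss` used to keep the block runnable without selfeeg are hypothetical; in practice the loss function passed in would be `losses.moco_loss`.

```python
import torch
import torch.nn.functional as F


def symmetric_loss(loss_fn, q1, q2, k1, k2):
    """Hypothetical helper: pair each view's query with the other view's key."""
    return loss_fn(q1, k2) + loss_fn(q2, k1)


def stand_in_loss(q, k):
    """Stand-in for a queue-free contrastive loss (assumption, for the demo only)."""
    logits = F.normalize(q, dim=1) @ F.normalize(k, dim=1).T / 0.07
    return F.cross_entropy(logits, torch.arange(q.shape[0]))


if __name__ == "__main__":
    torch.manual_seed(0)
    # q1, q2: query outputs for the two augmented views
    # k1, k2: key outputs for the same two views
    q1, q2 = torch.randn(8, 16), torch.randn(8, 16)
    k1, k2 = torch.randn(8, 16), torch.randn(8, 16)
    print(symmetric_loss(stand_in_loss, q1, q2, k1, k2))
```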