Source code for selfeeg.losses.losses

import torch
import torch.nn as nn
import torch.nn.functional as F

# Public API of this module: the names exported by a star-import.
__all__ = ["barlow_loss", "byol_loss", "moco_loss", "simclr_loss", "simsiam_loss", "vicreg_loss"]


def simclr_loss(
    projections: torch.Tensor,
    projections_norm: bool = True,
    temperature: float = 0.15,
) -> torch.Tensor:
    """
    ``simclr_loss`` computes the normalized temperature-scaled cross entropy
    loss [NTXent]_ , which is used in many contrastive learning algorithms.
    It is basically a simple implementation of the InfoNCE loss provided in
    the official simCLR repository [simgit]_ using only torch functions.

    Parameters
    ----------
    projections: torch.Tensor
        2D Tensor where projections[0:N/2] are the projections of one batch
        augmented version and projections[N/2:] are the projections of the
        other batch augmented version. The first dimension must therefore be
        even and non-zero.
    projections_norm: bool, optional
        Whether to L2-normalize each projection (row) or not.

        Default = True
    temperature: float, optional
        Temperature coefficient of the NTX_ent loss (See references to check
        loss formula).

        Default = 0.15

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor): the sum of the two per-view
        cross entropies, each averaged over the N/2 samples of the view.

    Raises
    ------
    ValueError
        If the first dimension of ``projections`` is zero or odd, i.e. the
        tensor cannot be split into two equally sized views.

    Note
    ----
    Looking at some implementations (e.g. the one in lightlyAI), the returned
    loss seems to be double. However the function add_contrastive_loss in the
    original repo returns the same value as this implementation.

    References
    ----------
    .. [NTXent] Chen et al. A Simple Framework for Contrastive Learning of
      Visual Representations. (2020). https://doi.org/10.48550/arXiv.2002.05709
    .. [simgit] To check the original tensorflow implementation visit the
      following repository: https://github.com/google-research/simclr
      (look at the function add_contrastive_loss in objective.py)

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> projections = torch.randn(64, 32)
    >>> loss = losses.simclr_loss(projections)
    >>> print(loss)  # will return 10.2866

    """
    n_rows = projections.shape[0]
    if n_rows == 0 or n_rows % 2 != 0:
        raise ValueError(
            "projections must stack two augmented views along the first "
            f"dimension (even, non-zero number of rows), got {n_rows} rows"
        )

    if projections_norm:
        # L2-normalize every projection, i.e. each row along the feature axis
        projections = F.normalize(projections, p=2.0, dim=1)

    half = n_rows // 2
    view_a, view_b = projections[:half], projections[half:]

    # Temperature-scaled similarities within and across the two views.
    sim_aa = torch.matmul(view_a, view_a.transpose(0, 1)) / temperature
    sim_bb = torch.matmul(view_b, view_b.transpose(0, 1)) / temperature
    sim_ab = torch.matmul(view_a, view_b.transpose(0, 1)) / temperature
    # view_b @ view_a.T is just the transpose of sim_ab
    sim_ba = sim_ab.transpose(0, 1)

    # A sample must not act as its own negative: push the diagonal of the
    # within-view blocks to the most negative representable value so that it
    # contributes exp(.) = 0 to the softmax denominator.
    self_pairs = torch.eye(half, dtype=torch.bool, device=projections.device)
    lowest = torch.finfo(sim_aa.dtype).min
    sim_aa = sim_aa.masked_fill(self_pairs, lowest)
    sim_bb = sim_bb.masked_fill(self_pairs, lowest)

    # Logits are laid out as [cross-view | within-view], so the positive of
    # sample i is column i of the cross-view block.
    targets = torch.arange(half, device=projections.device)
    loss_a = F.cross_entropy(torch.cat([sim_ab, sim_aa], 1), targets, reduction="mean")
    loss_b = F.cross_entropy(torch.cat([sim_ba, sim_bb], 1), targets, reduction="mean")
    return loss_a + loss_b
def simsiam_loss(
    p1: torch.Tensor,
    z1: torch.Tensor,
    p2: torch.Tensor,
    z2: torch.Tensor,
    projections_norm: bool = True,
) -> torch.Tensor:
    """
    Simple implementation of the SimSiam [simsiam]_ loss function with the
    possibility to not normalize tensors.
    Official repo can be found here [siamgit]_

    The projections ``z1`` and ``z2`` are always detached (stop-gradient), so
    the loss only back-propagates through the predictor outputs ``p1`` and
    ``p2``, regardless of ``projections_norm``.

    Parameters
    ----------
    p1: torch.Tensor
        2D Tensor with one augmented batch predictor output.
    z1: torch.Tensor
        2D Tensor with one augmented batch projection output.
    p2: torch.Tensor
        Same as p1 but with the other augmented batch.
    z2: torch.Tensor
        Same as z1 with the other augmented batch.
    projections_norm: bool, optional
        Whether to L2-normalize each row of the four tensors or not.

        Default = True

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor): the average of the two negative
        mean row-wise dot products ``-(p1 . z2)`` and ``-(p2 . z1)``.

    References
    ----------
    .. [siamgit] Original github repo:
      https://github.com/facebookresearch/simsiam
    .. [simsiam] Original paper: Chen & He. Exploring Simple Siamese
      Representation Learning. https://arxiv.org/abs/2011.10566

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> p1 = torch.randn(64, 32)
    >>> z1 = torch.randn(64, 32)
    >>> p2 = torch.randn(64, 32)
    >>> z2 = torch.randn(64, 32)
    >>> loss = losses.simsiam_loss(p1, z1, p2, z2)
    >>> print(loss)  # will return -0.0161

    """
    # Stop-gradient on the projection targets. This must not depend on
    # whether normalization is requested.
    z1 = z1.detach()
    z2 = z2.detach()

    if projections_norm:
        p1 = F.normalize(p1, p=2.0, dim=1)
        z1 = F.normalize(z1, p=2.0, dim=1)
        p2 = F.normalize(p2, p=2.0, dim=1)
        z2 = F.normalize(z2, p=2.0, dim=1)

    # Negative similarity of each predictor with the *other* view's projection
    neg_sim_12 = -(p1 * z2).sum(dim=1).mean()
    neg_sim_21 = -(p2 * z1).sum(dim=1).mean()
    return 0.5 * neg_sim_12 + 0.5 * neg_sim_21
def moco_loss(
    q: torch.Tensor,
    k: torch.Tensor,
    queue: torch.Tensor = None,
    projections_norm: bool = True,
    temperature: float = 0.07,
) -> torch.Tensor:
    """
    Simple implementation of the MoCo loss function [moco2]_.
    It is the InfoNCE loss with dot product as similarity and memory bank
    as negative samples. If no queue related to the memory bank is given,
    MoCo v3 [moco3]_ loss calculation is performed. Note that the real MoCo v3
    loss is calculated by calling the function 2 times (with different
    q and k tensors) and summing up the results.

    Parameters
    ----------
    q: torch.Tensor
        2D (NxC) Tensor with the queries, i.e. one augmented batch predictor
        or projection_head output. N = batch size, C = number of features.
    k: torch.Tensor
        2D (NxC) Tensor with the keys, i.e. one augmented batch
        projection_head output which will be added to the memory bank.
        N = batch size, C = number of features. ``k`` is not detached inside
        this function.
    queue: torch.Tensor, optional
        2D (CxK) Tensor with the memory bank, i.e. a collection of previous
        augmented batch projection_head outputs which act as negative
        samples. C = number of features, K = memory bank size. It is
        detached before being used.

        Default = None
    projections_norm: bool, optional
        Whether to L2-normalize each row of ``q`` and ``k`` or not. The queue
        is never normalized here.

        Default = True
    temperature: float, optional
        Temperature coefficient of the NTX_ent loss.

        Default = 0.07

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor).

    Raises
    ------
    ValueError
        If ``q`` is not a 2D tensor.

    References
    ----------
    .. [moco2] K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick, “Momentum
      contrast for unsupervised visual representation learning,” in
      Proceedings of the IEEE/CVF conference on computer vision and pattern
      recognition, pp. 9729–9738, 2020.
    .. [moco3] X. Chen, H. Fan, R. Girshick, and K. He, “Improved baselines
      with momentum contrastive learning,” arXiv preprint arXiv:2003.04297,
      2020.

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> q = torch.randn(64, 32)
    >>> k = torch.randn(64, 32)
    >>> queue = torch.randn(32, 128)
    >>> loss = losses.moco_loss(q, k, queue)
    >>> print(loss)  # will return 17.1668
    >>> loss = losses.moco_loss(q, k)
    >>> print(loss)  # will return 1.4349

    """
    if q.dim() != 2:
        raise ValueError(f"q must be a 2D (NxC) tensor, got {q.dim()} dimension(s)")

    if projections_norm:
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)

    # Without a memory bank, run the MoCo v3 loss: the other keys of the
    # batch are the negatives and the positive of query i is key i.
    # (the full MoCo v3 loss is moco_loss(q1, k2) + moco_loss(q2, k1))
    if queue is None:
        logits = torch.einsum("nc,mc->nm", q, k) / temperature
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels) * (2 * temperature)

    # positive logits (Nx1): row-wise dot product between query and key
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
    # negative logits (NxK): queries against the (detached) memory bank
    l_neg = torch.matmul(q, queue.detach())

    # logits are Nx(1+K) with the positive in column 0, hence all-zero labels
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels, reduction="mean")
def byol_loss(
    p1: torch.Tensor,
    z1: torch.Tensor,
    p2: torch.Tensor,
    z2: torch.Tensor,
    projections_norm: bool = True,
) -> torch.Tensor:
    """
    Simple pytorch implementation of the BYOL loss function presented in
    [BYOL]_ .

    The projections ``z1`` and ``z2`` are always detached (stop-gradient), so
    the loss only back-propagates through the predictor outputs ``p1`` and
    ``p2``, regardless of ``projections_norm``.

    Parameters
    ----------
    p1: torch.Tensor
        2D Tensor with one augmented batch predictor output.
    z1: torch.Tensor
        2D Tensor with one augmented batch projection output.
    p2: torch.Tensor
        Same as p1 but with the other augmented batch.
    z2: torch.Tensor
        Same as z1 with the other augmented batch.
    projections_norm: bool, optional
        Whether to L2-normalize each row of the four tensors or not.

        Default = True

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor): the batch mean of
        ``(2 - 2 * p1 . z2) + (2 - 2 * p2 . z1)``.

    References
    ----------
    .. [BYOL] J.-B. Grill, F. Strub, F. Altché, C. Tallec, P. Richemond,
      E. Buchatskaya, C. Doersch, B. Avila Pires, Z. Guo, M. Gheshlaghi Azar,
      et al., “Bootstrap your own latent - a new approach to self-supervised
      learning,” Advances in neural information processing systems, vol. 33,
      pp. 21271–21284, 2020.

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> p1 = torch.randn(64, 32)
    >>> z1 = torch.randn(64, 32)
    >>> p2 = torch.randn(64, 32)
    >>> z2 = torch.randn(64, 32)
    >>> loss = losses.byol_loss(p1, z1, p2, z2)
    >>> print(loss)  # will return 3.9357

    """
    # Stop-gradient on the projection targets. This must not depend on
    # whether normalization is requested.
    z1 = z1.detach()
    z2 = z2.detach()

    if projections_norm:
        p1 = F.normalize(p1, p=2.0, dim=1)
        z1 = F.normalize(z1, p=2.0, dim=1)
        p2 = F.normalize(p2, p=2.0, dim=1)
        z2 = F.normalize(z2, p=2.0, dim=1)

    # For unit-norm rows, 2 - 2 * <p, z> equals the squared distance ||p - z||^2
    dist_12 = 2 - 2 * (p1 * z2).sum(dim=-1)
    dist_21 = 2 - 2 * (p2 * z1).sum(dim=-1)
    return (dist_12 + dist_21).mean()
def barlow_loss(
    z1: torch.Tensor,
    z2: torch.Tensor = None,
    lambda_coeff: float = 5e-3,
) -> torch.Tensor:
    """
    Pytorch implementation of the Barlow Twins loss function as presented in
    [barlow]_ .

    Parameters
    ----------
    z1: torch.tensor
        2D tensor with projections of one augmented version of the batch.
    z2: torch.tensor, optional
        2D projections of the other augmented version of the batch.
        Can be None if z1 and z2 are concatenated along the first dimension.
        In this case internal split is done, and the first dimension of z1
        must be even and non-zero.

        Default = None
    lambda_coeff: float, optional
        Off diagonal scaling factor described in the paper.

        Default = 5e-3

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor).

    Raises
    ------
    ValueError
        If ``z2`` is None and the first dimension of ``z1`` is zero or odd.

    References
    ----------
    .. [barlow] J. Zbontar, L. Jing, I. Misra, Y. LeCun, and S. Deny, “Barlow
      twins: Self-supervised learning via redundancy reduction,” in
      International Conference on Machine Learning, pp. 12310–12320,
      PMLR, 2021.

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> z1 = torch.randn(64, 32)
    >>> z2 = torch.randn(64, 32)
    >>> loss = losses.barlow_loss(z1, z2)
    >>> print(loss)  # will return 31.6141

    """
    if z2 is None:
        n_rows = z1.shape[0]
        if n_rows == 0 or n_rows % 2 != 0:
            raise ValueError(
                "when z2 is None, z1 must stack two views along the first "
                f"dimension (even, non-zero number of rows), got {n_rows} rows"
            )
        z1, z2 = torch.split(z1, n_rows // 2)

    N, D = z1.shape

    # Standardize every feature over the batch (zero mean, unit std)
    z1_norm = (z1 - z1.mean(0)) / z1.std(0)
    z2_norm = (z2 - z2.mean(0)) / z2.std(0)

    # DxD cross-correlation matrix between the two views
    c_mat = (z1_norm.T @ z2_norm) / N
    c_mat_sq = c_mat.pow(2)

    # Expanded form of
    #   sum_i (1 - c_ii)^2 + lambda * sum_{i != j} c_ij^2
    # where sum_i (1 - c_ii)^2 = D - 2 * tr(c) + sum_i c_ii^2 and the
    # off-diagonal sum is sum(c^2) - sum_i c_ii^2.
    loss = (
        D
        - 2 * torch.trace(c_mat)
        + lambda_coeff * torch.sum(c_mat_sq)
        + (1 - lambda_coeff) * torch.trace(c_mat_sq)
    )
    return loss
def vicreg_loss(
    z1: torch.Tensor,
    z2: torch.Tensor = None,
    Lambda: float = 25,
    Mu: float = 25,
    Nu: float = 1,
    epsilon: float = 1e-4,
) -> torch.Tensor:
    """
    Pytorch implementation of the VICReg loss function [VIC]_ .

    Parameters
    ----------
    z1: torch.tensor
        2D tensor with projections of one augmented version of the batch.
    z2: torch.tensor, optional
        2D projections of the other augmented version of the batch.
        Can be None if z1 and z2 are concatenated along the first dimension.
        In this case internal split is done, and the first dimension of z1
        must be even and non-zero.

        Default = None
    Lambda: float, optional
        Coefficient applied to the invariance loss.

        Default = 25
    Mu: float, optional
        Coefficient applied to the variance loss.

        Default = 25
    Nu: float, optional
        Coefficient applied to the covariance loss.

        Default = 1
    epsilon: float, optional
        Value summed to the variance for stability purposes.

        Default = 1e-4

    Returns
    -------
    loss: torch.Tensor
        The calculated loss (scalar tensor):
        ``Lambda * invariance + Mu * variance + Nu * covariance``.

    Raises
    ------
    ValueError
        If ``z2`` is None and the first dimension of ``z1`` is zero or odd.

    References
    ----------
    .. [VIC] A. Bardes, J. Ponce, and Y. LeCun, “Vicreg:
      Variance-invariance-covariance regularization for self-supervised
      learning,” arXiv preprint arXiv:2105.04906, 2021.

    Example
    -------
    >>> import torch
    >>> from selfeeg import losses
    >>> torch.manual_seed(1234)
    >>> z1 = torch.randn(64, 32)
    >>> z2 = torch.randn(64, 32)
    >>> loss = losses.vicreg_loss(z1, z2)
    >>> print(loss)  # will return 53.0773

    """
    if z2 is None:
        n_rows = z1.shape[0]
        if n_rows == 0 or n_rows % 2 != 0:
            raise ValueError(
                "when z2 is None, z1 must stack two views along the first "
                f"dimension (even, non-zero number of rows), got {n_rows} rows"
            )
        z1, z2 = torch.split(z1, n_rows // 2)

    N, D = z1.shape

    # Center every feature over the batch
    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)

    # invariance loss
    # NOTE(review): this term is computed on the centered embeddings; the
    # reference VICReg code appears to compute it before centering. Confirm
    # that this is intended before changing it (it alters the loss value).
    sim_loss = F.mse_loss(z1, z2)

    # variance loss: hinge pushing the per-feature std of each view above 1
    std_z1 = torch.sqrt(z1.var(dim=0) + epsilon)
    std_z2 = torch.sqrt(z2.var(dim=0) + epsilon)
    std_loss = (torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))) / 2

    # covariance loss: squared off-diagonal entries of each view's DxD
    # covariance matrix, selected with a boolean mask rather than zeroing
    # the diagonal in place
    off_diagonal = ~torch.eye(D, dtype=torch.bool, device=z1.device)
    cov_z1 = (z1.T @ z1) / (N - 1)
    cov_z2 = (z2.T @ z2) / (N - 1)
    cov_loss = cov_z1[off_diagonal].pow(2).sum() / D + cov_z2[off_diagonal].pow(2).sum() / D

    return Lambda * sim_loss + Mu * std_loss + Nu * cov_loss