simsiam_loss

selfeeg.losses.losses.simsiam_loss(p1: Tensor, z1: Tensor, p2: Tensor, z2: Tensor, projections_norm: bool = True) → Tensor

A simple implementation of the SimSiam [simsiam] loss function, with the option to skip normalizing the input tensors. The official repository is available at [siamgit].
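For reference, the SimSiam paper [simsiam] defines the loss as a symmetrized negative cosine similarity. Here sg denotes the stop-gradient operation, and D is averaged over the samples in the batch. This is the paper's definition, not a transcription of the selfeeg code. The normalization inside D corresponds to the projections_norm=True case; this page does not state exactly which terms projections_norm=False removes.

```latex
\mathcal{D}(p, z) = -\frac{p}{\lVert p \rVert_2} \cdot \frac{z}{\lVert z \rVert_2},
\qquad
\mathcal{L} = \tfrac{1}{2}\,\mathcal{D}\big(p_1, \operatorname{sg}(z_2)\big) + \tfrac{1}{2}\,\mathcal{D}\big(p_2, \operatorname{sg}(z_1)\big)
```

Since D lies in [-1, 1], the normalized loss reaches its minimum of -1 when every predictor output points in the same direction as the other view's projection.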

Parameters:
  • p1 (torch.Tensor) – 2D tensor with the predictor output for one augmented batch.

  • z1 (torch.Tensor) – 2D tensor with the projection output for the same augmented batch.

  • p2 (torch.Tensor) – Same as p1, but for the other augmented batch.

  • z2 (torch.Tensor) – Same as z1, but for the other augmented batch.

  • projections_norm (bool, optional) –

    Whether to normalize the projections or not.

    Default = True

Returns:

loss (torch.Tensor) – The computed loss, as a scalar tensor.
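The sketch below shows one way to compute the SimSiam loss in PyTorch, following the paper's definition. It is a hypothetical re-implementation named simsiam_loss_sketch and not the selfeeg source. It makes two assumptions. First, projections_norm=True L2-normalizes both the predictor and the projection outputs. Second, z1 and z2 are detached (stop-gradient), as in the paper.

```python
import torch
import torch.nn.functional as F


def simsiam_loss_sketch(p1, z1, p2, z2, projections_norm=True):
    """Hypothetical sketch of the symmetrized SimSiam loss (not selfeeg's code).

    p1, p2: 2D predictor outputs for the two augmented batches.
    z1, z2: 2D projection outputs for the same two batches.
    """
    # Stop-gradient: projections act as fixed targets, as in the SimSiam paper.
    z1, z2 = z1.detach(), z2.detach()
    if projections_norm:
        # Assumption: normalize every row of p and z to unit L2 norm, so each
        # row-wise dot product below is a cosine similarity.
        p1, p2 = F.normalize(p1, dim=1), F.normalize(p2, dim=1)
        z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    # Cross terms: p of one view against z of the other view, averaged over the batch.
    d1 = -(p1 * z2).sum(dim=1).mean()
    d2 = -(p2 * z1).sum(dim=1).mean()
    return 0.5 * d1 + 0.5 * d2


# Usage: random inputs, then check that no gradient reaches the detached z tensors.
torch.manual_seed(0)
p1, z1, p2, z2 = (torch.randn(64, 32, requires_grad=True) for _ in range(4))
loss = simsiam_loss_sketch(p1, z1, p2, z2)
loss.backward()
print(loss.item(), z1.grad is None, p1.grad is not None)
```

Because of the detach calls, backpropagating through the result leaves z1.grad and z2.grad as None when z1 and z2 are leaf tensors. In a full model, gradients still reach the shared encoder through p1 and p2.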

References

[simsiam]

Original paper: Chen & He. Exploring Simple Siamese Representation Learning. https://arxiv.org/abs/2011.10566

Example

>>> import torch
>>> from selfeeg import losses
>>> _ = torch.manual_seed(1234)
>>> p1 = torch.randn(64, 32)
>>> z1 = torch.randn(64, 32)
>>> p2 = torch.randn(64, 32)
>>> z2 = torch.randn(64, 32)
>>> loss = losses.simsiam_loss(p1, z1, p2, z2)
>>> print(loss)  # expected: -0.0161
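The example above feeds random tensors to the loss. In a real pretraining loop, the four inputs come from a siamese network. A shared encoder and projection head map each augmented view to z, and a predictor head maps each z to p. The sketch below uses hypothetical nn.Linear stand-ins for those modules and additive noise as a placeholder augmentation. It computes the paper's loss inline with F.cosine_similarity, so it runs without selfeeg installed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a real setup would use an EEG encoder and MLP heads.
encoder = nn.Linear(128, 64)    # shared backbone
projector = nn.Linear(64, 32)   # projection head, outputs z
predictor = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 32))  # outputs p

x = torch.randn(64, 128)  # one batch of flattened input samples
# Two "augmented" views; additive noise is only a placeholder augmentation.
x1 = x + 0.1 * torch.randn_like(x)
x2 = x + 0.1 * torch.randn_like(x)

z1, z2 = projector(encoder(x1)), projector(encoder(x2))  # 2D projection outputs
p1, p2 = predictor(z1), predictor(z2)                    # 2D predictor outputs

# Symmetrized negative cosine similarity with stop-gradient on z (SimSiam paper).
loss = -0.5 * (F.cosine_similarity(p1, z2.detach(), dim=1).mean()
               + F.cosine_similarity(p2, z1.detach(), dim=1).mean())

params = list(encoder.parameters()) + list(projector.parameters()) + list(predictor.parameters())
optimizer = torch.optim.SGD(params, lr=0.05)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

In a selfeeg pipeline, the inline loss would be replaced by losses.simsiam_loss(p1, z1, p2, z2). This page does not state whether that function detaches z1 and z2 internally, so check the source before relying on it.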