barlow_loss

selfeeg.losses.losses.barlow_loss(z1: Tensor, z2: Tensor = None, lambda_coeff: float = 0.005) → Tensor

PyTorch implementation of the Barlow Twins loss function, as presented in [barlow].
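For reference, the objective from [barlow] can be written as follows. Here z^A and z^B are the two batches of projections (z1 and z2 in this function), b indexes the samples of the batch, and i, j index the projection dimensions; each dimension is assumed to be mean-centered over the batch. This is a summary of the paper's definition, not a transcription of this function's code. The weight λ corresponds to lambda_coeff.

```latex
\mathcal{C}_{ij} = \frac{\sum_b z^A_{b,i}\, z^B_{b,j}}
  {\sqrt{\sum_b \left(z^A_{b,i}\right)^2}\,\sqrt{\sum_b \left(z^B_{b,j}\right)^2}}

\mathcal{L}_{BT} = \sum_i \left(1 - \mathcal{C}_{ii}\right)^2
  + \lambda \sum_i \sum_{j \neq i} \mathcal{C}_{ij}^2
```

The first (invariance) term pushes the diagonal of the cross-correlation matrix toward 1, and the second (redundancy reduction) term pushes the off-diagonal entries toward 0.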

Parameters:
  • z1 (torch.Tensor) – 2D tensor of shape (batch size, projection dimension) with the projections of one augmented version of the batch.

  • z2 (torch.Tensor, optional) –

    2D projections of the other augmented version of the batch. Can be None if z1 already contains the projections of both augmented versions concatenated along the batch dimension; in this case z1 is split internally into two equal halves.

    Default = None

  • lambda_coeff (float, optional) –

    Scaling factor of the off-diagonal (redundancy reduction) term, denoted λ in the paper.

    Default = 5e-3

Returns:

loss (torch.Tensor) – The calculated loss.
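The following dependency-free sketch makes the computation concrete, using plain Python lists instead of tensors. It is not selfeeg's implementation. It assumes, as is common in Barlow Twins code, that each feature is standardized over the batch (like a BatchNorm1d without affine parameters, with a small eps) before the cross-correlation matrix is formed, and that a single stacked input is split into two equal halves along the batch dimension. The names barlow_loss_sketch and _standardize_columns are hypothetical.

```python
import math
import random
from typing import List, Optional

Matrix = List[List[float]]  # rows = samples, columns = projection dimensions


def _standardize_columns(z: Matrix, eps: float = 1e-5) -> Matrix:
    """Center each column over the batch and scale it to (almost) unit variance.

    Uses the biased 1/N variance plus eps, similar to BatchNorm1d(affine=False).
    """
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [z[b][j] for b in range(n)]
        mean = sum(col) / n
        std = math.sqrt(sum((x - mean) ** 2 for x in col) / n + eps)
        for b in range(n):
            out[b][j] = (col[b] - mean) / std
    return out


def barlow_loss_sketch(z1: Matrix, z2: Optional[Matrix] = None,
                       lambda_coeff: float = 5e-3) -> float:
    # Assumption: with z2=None, both views are stacked along the batch axis.
    if z2 is None:
        if len(z1) % 2 != 0:
            raise ValueError("stacked input needs an even number of rows")
        half = len(z1) // 2
        z1, z2 = z1[:half], z1[half:]
    n, d = len(z1), len(z1[0])
    a, b = _standardize_columns(z1), _standardize_columns(z2)
    # Empirical cross-correlation matrix C (d x d), averaged over the batch.
    c = [[sum(a[k][i] * b[k][j] for k in range(n)) / n for j in range(d)]
         for i in range(d)]
    # Invariance term: diagonal entries pushed toward 1.
    on_diag = sum((1.0 - c[i][i]) ** 2 for i in range(d))
    # Redundancy reduction term: off-diagonal entries pushed toward 0.
    off_diag = sum(c[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    return on_diag + lambda_coeff * off_diag


random.seed(1234)
v1 = [[random.gauss(0, 1) for _ in range(32)] for _ in range(64)]
v2 = [[random.gauss(0, 1) for _ in range(32)] for _ in range(64)]
print(barlow_loss_sketch(v1, v2))   # two independent views
print(barlow_loss_sketch(v1 + v2))  # same value: the stacked input is split in half
```

Because the two random views are independent, the diagonal of C stays near 0 and the invariance term dominates; two identical views would instead drive that term close to 0.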

References

[barlow]

J. Zbontar, L. Jing, I. Misra, Y. LeCun, and S. Deny, “Barlow Twins: Self-supervised learning via redundancy reduction,” in International Conference on Machine Learning, pp. 12310–12320, PMLR, 2021.

Example

>>> import torch
>>> from selfeeg import losses
>>> torch.manual_seed(1234)
>>> z1 = torch.randn(64, 32)
>>> z2 = torch.randn(64, 32)
>>> loss = losses.barlow_loss(z1, z2)
>>> print(loss) # will return 31.6141