STNet

class selfeeg.models.zoo.STNet(nb_classes: int, Samples: int, grid_size: int = 9, F: int = 256, kernlength: int = 5, dropRate: float = 0.5, bias: bool = True, dense_size: int = 1024, return_logits: bool = True, seed: int = None)

PyTorch implementation of the STNet model.

The original paper can be found in [stnet]. Another implementation is available at [stnetgit].

The expected input is a 4D tensor of size (Batch x Samples x Grid_width x Grid_width), where Grid_width equals grid_size. In other words, the classical 2D EEG matrix, with channels as rows and samples as columns, is rearranged into a 3D tensor whose first dimension is the sample (time) dimension and whose last two dimensions hold the channels arranged on a 2D grid. See the original paper for details on how the channels are mapped onto the grid.
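The rearrangement described above can be sketched with NumPy as follows. This is an illustrative helper, not part of selfeeg: the function name `eeg_to_grid` and the `positions` list (the grid cell assigned to each channel) are hypothetical, and the actual electrode-to-grid layout should be taken from the original paper. Grid cells without an electrode are left at zero.

```python
import numpy as np


def eeg_to_grid(eeg, positions, grid_size=9):
    """Rearrange (batch, channels, samples) EEG into (batch, samples, grid, grid).

    Hypothetical helper: positions[c] = (row, col) is the grid cell assigned
    to channel c. Cells with no electrode stay at zero.
    """
    batch, channels, samples = eeg.shape
    if len(positions) != channels:
        raise ValueError("one grid position is needed per channel")
    grid = np.zeros((batch, samples, grid_size, grid_size), dtype=eeg.dtype)
    for ch, (row, col) in enumerate(positions):
        # copy the whole time series of channel `ch` into cell (row, col)
        grid[:, :, row, col] = eeg[:, ch, :]
    return grid


# hypothetical layout for a 3-channel recording on a 9x9 grid
positions = [(0, 3), (0, 5), (4, 4)]
eeg = np.random.default_rng(0).standard_normal((4, 3, 128), dtype=np.float32)
x = eeg_to_grid(eeg, positions)
print(x.shape)  # (4, 128, 9, 9)
```

The resulting array can be converted with `torch.from_numpy` before being passed to the model.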

Parameters:
  • nb_classes (int) – The number of classes. If less than 2, a binary classification problem is considered, and the output will have shape [batch, 1].

  • Samples (int) – The number of time samples in each input, i.e. the size of the Samples dimension. It is used to calculate the embedding size.

  • grid_size (int, optional) –

    The grid size, i.e. the side length of the square 2D grid onto which the EEG channels are mapped. It should match the last two dimensions of the input.

    Default = 9

  • F (int, optional) –

    The number of output filters in the convolutional layer.

    Default = 256

  • kernlength (int, optional) –

    The kernel length of the convolutional layer.

    Default = 5

  • dropRate (float, optional) –

    The dropout probability, in the range [0, 1].

    Default = 0.5

  • bias (bool, optional) –

    If True, adds a learnable bias to the convolutional layers.

    Default = True

  • dense_size (int, optional) –

    The output size of the first dense layer.

    Default = 1024

  • return_logits (bool, optional) –

    Whether to return the output as logits or as probabilities. Setting it to False is not recommended when training with PyTorch's cross-entropy loss, since that loss already applies a (log-)softmax internally.

    Default = True

  • seed (int, optional) –

    A custom seed for model initialization. It must be a non-negative number. If None is passed, no custom seed is set.

    Default = None
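Because the model returns logits by default, probabilities can be computed afterwards, for example at inference time. The sketch below assumes that a [batch, 1] output (nb_classes less than 2) is interpreted with a sigmoid and a [batch, nb_classes] output with a softmax over the class dimension; the helper `logits_to_probabilities` is hypothetical and written in NumPy only for illustration.

```python
import numpy as np


def logits_to_probabilities(logits):
    """Convert STNet-style logits into probabilities (illustrative sketch).

    A (batch, 1) output is treated as binary and passed through a sigmoid;
    a (batch, nb_classes) output is passed through a softmax over classes.
    """
    logits = np.asarray(logits, dtype=np.float64)
    if logits.shape[1] == 1:
        return 1.0 / (1.0 + np.exp(-logits))
    # subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# multiclass case: 2 examples, 4 classes
multi = logits_to_probabilities([[2.0, 0.5, -1.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
# binary case: 2 examples, single output unit
binary = logits_to_probabilities([[0.0], [3.0]])
print(multi.sum(axis=1))  # [1. 1.]
```

Keeping logits during training and converting only at inference avoids applying the softmax twice when the loss already does it.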

References

Example

>>> import selfeeg.models as models
>>> import torch
>>> x = torch.randn(4, 128, 9, 9)
>>> mdl = models.STNet(4, 128)
>>> out = mdl(x)
>>> print(out.shape)  # should print torch.Size([4, 4])
>>> print(torch.isnan(out).sum())  # should print tensor(0)