ShallowNet

class selfeeg.models.zoo.ShallowNet(nb_classes: int, Chans: int, Samples: int, F: int = 40, K1: int = 25, Pool: int = 75, p: float = 0.2, return_logits: bool = True, seed: int = None)[source]

PyTorch implementation of the ShallowNet model.

The original paper is listed under [shall]. The expected input is a 3D tensor of size (Batch x Channels x Samples).

Parameters:
  • nb_classes (int) – The number of classes. If less than 2, a binary classification problem is considered (output dimensions will be [batch, 1] in this case).

  • Chans (int) – The number of EEG channels.

  • Samples (int) – The sample length. It will be used to calculate the embedding size (for head initialization).

  • F (int, optional) –

    The number of output filters in the temporal convolution layer.

    Default = 40

  • K1 (int, optional) –

    The kernel length of the temporal convolution layer.

    Default = 25

  • Pool (int, optional) –

    The temporal pooling kernel size.

    Default = 75

  • p (float, optional) –

    The dropout probability. Must be in [0, 1).

    Default = 0.2

  • return_logits (bool, optional) –

    Whether to return the output as logits or probabilities. Setting this to False is not recommended, because PyTorch's cross-entropy loss applies the softmax internally.

    Default = True

  • seed (int, optional) –

    A custom seed for model initialization. It must be a nonnegative integer. If None is passed, no custom seed is set.

    Default = None
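The way Samples, K1, and Pool feed into the embedding size can be illustrated with a little arithmetic. The sketch below is not the selfeeg implementation: it assumes an unpadded temporal convolution of length K1 followed by an unpadded pooling of kernel Pool, and that the channel dimension has been collapsed to 1 before flattening. The helper names and the `pool_stride` argument are hypothetical (the documented signature has no stride parameter), and the value 15 is used purely for illustration.

```python
def out_len(length: int, kernel: int, stride: int = 1) -> int:
    """Output length of an unpadded 1D convolution or pooling (dilation 1)."""
    if length < kernel:
        raise ValueError("input is shorter than the kernel")
    return (length - kernel) // stride + 1


def shallow_embedding_size(samples: int, F: int = 40, K1: int = 25,
                           pool: int = 75, pool_stride: int = 15) -> int:
    """Hypothetical estimate of the flattened feature size fed to the head.

    pool_stride is an assumption for illustration, not a documented argument.
    """
    t = out_len(samples, K1)            # after the temporal convolution
    t = out_len(t, pool, pool_stride)   # after the temporal pooling
    return F * t                        # F feature maps, t time steps each


emb = shallow_embedding_size(512)
print(emb)
```

Under these assumptions, a Samples value smaller than K1 + Pool - 1 leaves nothing to pool and the helper raises an error, which is one reason the sample length has to be known when the model is built.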

Note

In this implementation, the number of channels is an argument. However, in the original paper the authors preprocess the EEG data by selecting a subset of only 21 channels. Since the network is very minimalist, please follow the authors' notes.

References

[shall]

Schirrmeister et al., Deep Learning with convolutional neural networks for decoding and visualization of EEG pathology, arXiv:1708.08012

Example

>>> from selfeeg import models
>>> import torch
>>> x = torch.randn(4, 8, 512)  # (Batch x Channels x Samples)
>>> mdl = models.ShallowNet(4, 8, 512)
>>> out = mdl(x)
>>> print(out.shape)  # should print torch.Size([4, 4])
>>> print(torch.isnan(out).sum())  # should print tensor(0)
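The recommendation to keep return_logits=True can be checked with a small calculation. PyTorch's cross-entropy loss applies a log-softmax to its input, so feeding it probabilities applies the softmax twice. The sketch below reproduces that arithmetic in plain Python for a single sample, without calling selfeeg or PyTorch; the logit values and the helper names are hypothetical and chosen only for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, target):
    """Softmax followed by negative log-likelihood for one sample."""
    return -math.log(softmax(scores)[target])


logits = [2.0, 0.5, -1.0, 0.0]  # hypothetical 4-class model output
target = 0                      # index of the true class

loss_from_logits = cross_entropy(logits, target)
# Passing probabilities means the softmax is applied a second time.
loss_from_probs = cross_entropy(softmax(logits), target)

print(loss_from_logits, loss_from_probs)
```

The two losses differ because the second softmax flattens the distribution, so a model trained with return_logits=False and a standard cross-entropy loss would optimize a distorted objective. If probabilities are needed for reporting, applying a softmax to the logits after the loss has been computed is a common alternative.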