FBCNet
- class selfeeg.models.zoo.FBCNet(nb_classes: int, Chans: int, Samples: int, Fs: int, FilterBands: int = 9, FilterRange: float = 4, FilterType: str = 'Cheby2', FilterStopRippple: int = 30, FilterPassRipple: int = 3, FilterRangeTol: int = 2, FilterSkipFirst=True, D: int = 32, TemporalType: str = 'logvar', TemporalStride: int = 4, batch_momentum: float = 0.1, depthwise_max_norm: float = None, linear_max_norm: float = None, classifier: Module = None, return_logits: bool = True, seed: int = None)
PyTorch implementation of the FBCNet model.
The FBCNet paper can be found here [fbcnet]. The official implementation can be found here [gitfbc].
The expected input is a 3D tensor with size (Batch x Channels x Samples).
Filtering is applied with the torchaudio filtfilt function (zero-phase, forward-backward filtering). Avoid overly strict filter settings, as they might produce NaN or excessively large values.
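One common cause of NaN or exploding outputs is an unstable IIR filter: strict settings (narrow passbands, narrow transition bands, large stopband attenuation) place the filter poles very close to the unit circle, and rounding errors can push them onto or beyond it. The sketch below is a generic NumPy diagnostic, not part of selfeeg: given the denominator coefficients `a` of a filter you designed yourself (how to retrieve them from selfeeg is not covered here), it checks whether every pole lies strictly inside the unit circle. The names `max_pole_radius` and `is_stable` are hypothetical.

```python
import numpy as np

def max_pole_radius(a):
    """Return the largest pole magnitude of an IIR filter.

    `a` holds the denominator coefficients [a0, a1, ..., an] of
    H(z) = B(z) / (a0 + a1 z^-1 + ... + an z^-n). Multiplying the
    denominator by z^n gives a polynomial whose roots are the poles.
    """
    poles = np.roots(np.asarray(a, dtype=float))
    return float(np.max(np.abs(poles))) if poles.size else 0.0

def is_stable(a, margin=1e-6):
    """True if every pole lies strictly inside the unit circle (with a safety margin)."""
    return max_pole_radius(a) < 1.0 - margin

# Second-order section with a double pole at z = 0.9: stable.
print(is_stable([1.0, -1.8, 0.81]))  # True
# Single pole at z = 1.05: forward or backward filtering will diverge.
print(is_stable([1.0, -1.05]))       # False
```

A pole radius that is technically below 1 but extremely close to it (for example 0.9999) is still a warning sign, since small coefficient errors can move it outside.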
- Parameters:
nb_classes (int) – The number of classes. If less than 2, a binary classification problem is considered (output dimensions will be [batch, 1] in this case).
Chans (int) – The number of EEG channels.
Samples (int) – The number of EEG samples.
Fs (int or float) – The EEG sampling rate.
FilterBands (int, optional) –
The number of filters to apply to the original signal. It is used by the FilterBank layer.
Default = 9
FilterRange (int or float, optional) –
The passband width of each filter, in Hz. It is used by the FilterBank layer.
Default = 4
FilterType (str, optional) –
The type of filter to use. Allowed values are the same as described in the get_filter_coeff() function of the selfeeg.augmentation.functional submodule (‘butter’, ‘ellip’, ‘cheby1’, ‘cheby2’). It is used by the FilterBank layer.
Default = ‘Cheby2’
FilterStopRipple (int or float, optional) –
Stopband ripple in decibels. It is used by the FilterBank layer.
Default = 30
FilterPassRipple (int or float, optional) –
Passband ripple in decibels. It is used by the FilterBank layer.
Default = 3
FilterRangeTol (int or float, optional) –
The filter transition bandwidth in Hz. It is used by the FilterBank layer.
Default = 2
FilterSkipFirst (bool, optional) –
If True, skips the first filter, whose passband is [0, FilterRange] Hz. The total number of filters given by FilterBands is still preserved. It is used by the FilterBank layer.
Default = True
D (int, optional) –
The depth of the depthwise convolutional layer.
Default = 32
TemporalType (str, optional) –
The type of temporal feature extraction layer to use. Accepted values are ‘max’, ‘mean’, ‘std’, ‘var’, or ‘logvar’.
Default = ‘logvar’
TemporalStride (int, optional) –
The length of the time dimension at the output of the temporal feature extraction layer. The kernel length and stride of the layer are computed from this value. Make sure that Samples is a multiple of TemporalStride.
Default = 4
batch_momentum (float, optional) –
The batch normalization momentum.
Default = 0.1
depthwise_max_norm (float, optional) –
The maximum norm each filter can have in the depthwise block. If None, no constraint is applied.
Default = None
linear_max_norm (float, optional) –
The maximum norm each filter can have in the final dense layer. If None, no constraint is applied.
Default = None
classifier (nn.Module, optional) –
A custom block to apply after the encoder instead of the standard linear layer. Must be an instance of nn.Module. If None, a standard linear layer is applied.
Default = None
return_logits (bool, optional) –
Whether to return the output as logits or probabilities. Setting it to False is discouraged, since PyTorch's CrossEntropyLoss already applies log-softmax internally.
Default = True
seed (int, optional) –
A custom seed for model initialization. It must be a nonnegative number. If None is passed, no custom seed will be set.
Default = None
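The FilterBank parameters interact. With the defaults (FilterBands = 9, FilterRange = 4, FilterSkipFirst = True), one plausible reading of the parameter descriptions is nine contiguous 4 Hz passbands covering 4 to 40 Hz, which matches the band layout described in the FBCNet paper. The helpers below are hypothetical (`fbcnet_passbands` and `check_nyquist` are not part of selfeeg): they enumerate the passbands under that assumption and check that each band, widened by FilterRangeTol Hz on its upper edge, stays below the Nyquist frequency Fs / 2.

```python
def fbcnet_passbands(bands=9, range_hz=4.0, skip_first=True):
    """Return `bands` (low, high) passband edges in Hz.

    Assumption: bands are contiguous, non-overlapping and `range_hz` wide,
    starting at 0 Hz, or at `range_hz` Hz when `skip_first` is True. The
    number of returned bands is `bands` in both cases.
    """
    start = 1 if skip_first else 0
    return [(i * range_hz, (i + 1) * range_hz) for i in range(start, start + bands)]

def check_nyquist(passbands, fs, tol_hz=2.0):
    """Assumption: the upper stopband edge is high + tol_hz and must stay below fs / 2."""
    return all(high + tol_hz < fs / 2 for _, high in passbands)

bands = fbcnet_passbands()
print(bands[0], bands[-1])                 # (4.0, 8.0) (36.0, 40.0)
print(fbcnet_passbands(skip_first=False)[0])  # (0.0, 4.0)
print(check_nyquist(bands, fs=128))        # True: 42 Hz < 64 Hz
print(check_nyquist(bands, fs=64))         # False: 42 Hz >= 32 Hz
```

With a low sampling rate, reducing FilterBands or FilterRange keeps the highest band away from the Nyquist frequency, which also helps avoid the unstable designs mentioned above.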
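TemporalStride and TemporalType determine how the time axis is summarized. A minimal NumPy sketch of the ‘logvar’ option follows. It assumes, without this being confirmed by the description above, that the kernel length and the stride are both Samples // TemporalStride, so the signal is split into TemporalStride non-overlapping windows and the log of each window's variance is kept. It also assumes the layer receives FilterBands * D feature maps (D maps per band), which would give FilterBands * D * TemporalStride flattened features with the default settings. `temporal_logvar` is a hypothetical name.

```python
import numpy as np

def temporal_logvar(x, temporal_stride=4, eps=1e-6):
    """Log-variance over non-overlapping windows along the last axis.

    x: array of shape (batch, feature_maps, samples).
    Assumption: window length = stride = samples // temporal_stride.
    Returns an array of shape (batch, feature_maps, temporal_stride).
    """
    batch, maps, samples = x.shape
    if samples % temporal_stride != 0:
        raise ValueError("samples must be a multiple of temporal_stride")
    win = samples // temporal_stride
    if win < 2:
        raise ValueError("each window needs at least 2 samples")
    # Split the time axis into temporal_stride consecutive windows of length win.
    windows = x.reshape(batch, maps, temporal_stride, win)
    var = windows.var(axis=-1, ddof=1)      # unbiased variance per window
    return np.log(np.clip(var, eps, None))  # clipping avoids log(0)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 9 * 32, 512))  # assumed (batch, FilterBands * D, Samples)
feats = temporal_logvar(x, temporal_stride=4)
print(feats.shape)                 # (4, 288, 4)
print(feats.reshape(4, -1).shape)  # (4, 1152)
```

Under these assumptions, a Samples value that is not a multiple of TemporalStride cannot be split into equal windows, which is why the parameter description asks for it.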
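depthwise_max_norm and linear_max_norm bound the norm of each filter. A common way to enforce such a constraint, and the one assumed in this sketch, is to rescale after every optimizer step any output filter whose L2 norm exceeds the limit, leaving the others untouched. `apply_max_norm` is a hypothetical NumPy helper; how selfeeg applies the constraint internally is not described here.

```python
import numpy as np

def apply_max_norm(weight, max_norm):
    """Rescale each output filter so that its L2 norm is at most `max_norm`.

    weight: array whose first axis indexes output filters, e.g. a conv kernel
    (out_channels, in_channels, kh, kw) or a dense matrix (out_features, in_features).
    Filters already within the limit are returned unchanged.
    """
    w = np.asarray(weight, dtype=float)
    flat = w.reshape(w.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    # Scale factor is 1 for filters within the limit; guard against zero norms.
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return (flat * scale).reshape(w.shape)

w = np.array([[3.0, 4.0], [0.3, 0.4]])  # row norms: 5.0 and 0.5
print(apply_max_norm(w, 1.0))           # first row becomes [0.6, 0.8], second unchanged
```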
References
[fbcnet]Mane et al. FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface. https://arxiv.org/abs/2104.01233
Example
>>> from selfeeg import models
>>> import torch
>>> x = torch.randn(4,8,512)
>>> mdl = models.FBCNet(2, 8, 512, 128)
>>> out = mdl(x)
>>> print(out.shape) # should return torch.Size([4, 2])
>>> print(torch.isnan(out).sum()) # should return 0