Source code for selfeeg.models.layers

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import filtfilt
from selfeeg.augmentation.functional import get_filter_coeff

__all__ = [
    "ConstrainedConv1d",
    "ConstrainedConv2d",
    "ConstrainedDense",
    "DepthwiseConv2d",
    "FilterBank",
    "SeparableConv2d",
]


[docs] class ConstrainedDense(nn.Linear): """ nn.Linear layer with norm constraints. This is a Pytorch implementation of the Dense layer with the possibility of adding a MaxNorm, MinMaxNorm, or a UnitNorm constraint. Most of the parameters are the same as described in torch.nn.Linear help. Parameters ---------- in_features: int Number of input features. out_channels: int Number of output features. bias: bool, optional If True, adds a learnable bias to the output. Default = True device: torch.device or str, optional The torch device. dtype: torch dtype, optional layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. max_norm: float, optional The maximum norm each hidden unit can have. If None no constraint will be added. Default = 2.0 min_norm: float, optional The minimum norm each hidden unit can have. Must be a float lower than max_norm. If given, MinMaxNorm will be applied in the case max_norm is also given. Otherwise, it will be ignored. Default = None axis_norm: Union[int, list, tuple], optional The axis along weights are constrained. It behaves like Keras. So, considering that a Conv2D layer has shape (output_depth, input_depth), set axis to 1 will constrain the weights of each filter tensor of size (input_depth,). Default = 1 minmax_rate: float, optional A constraint for MinMaxNorm setting how weights will be rescaled at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using minmax_rate = 1 will set a strict enforcement of the constraint, while rate<1.0 will slowly rescale layer's hidden units at each step. Default = 1.0 Note ---- To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, set both min_norm and max_norm to 1.0. Example ------- >>> from selfeeg.models import ConstrainedDense >>> import torch >>> x = torch.randn(4,64) >>> mdl = ConstrainedDense(64,32) >>> out = mdl(x) >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=1)) >>> print(out.shape) # shoud return torch.Size([4, 32]) >>> print(torch.isnan(out).sum()) # shoud return 0 >>> print(torch.sum(norms>(1.4+1e-3)).item() == 0) # should return True """ def __init__( self, in_features, out_features, bias=True, device=None, dtype=None, max_norm=2.0, min_norm=None, axis_norm=1, minmax_rate=1.0, ): super(ConstrainedDense, self).__init__(in_features, out_features, bias, device, dtype) # Check that max_norm is valid if max_norm is not None: if max_norm <= 0: raise ValueError("max_norm can't be lower or equal than 0") else: self.max_norm = max_norm else: self.max_norm = max_norm # Check that min_norm is valid if min_norm is not None: if min_norm <= 0: raise ValueError("min_norm can't be lower or equal than 0") else: self.min_norm = min_norm else: self.min_norm = min_norm # If both are given, check that max_norm is bigger than min_norm if (self.min_norm is not None) and (self.max_norm is not None): if self.min_norm > self.max_norm: raise ValueError("max_norm can't be lower than min_norm") # Check that minmax_rate is bigger than 0 if minmax_rate <= 0.0 or minmax_rate > 1.0: raise ValueError("minmax_rate must be in (0,1]") self.minmax_rate = minmax_rate # Check that axis is a valid enter if type(axis_norm) not in [tuple, list, int]: raise TypeError("axis must be a tuple, list, or int") else: if type(axis_norm) == int: if axis_norm < 0 or axis_norm > 1: raise ValueError("Linear has 2 axis. Values must be in 0 or 1") else: for i in axis_norm: if i < 0 or i > 1: raise ValueError("Axis values must be in 0 or 1") self.axis_norm = axis_norm # set the constraint case: # 0 --> no contraint # 1 --> MaxNorm # 2 --> MinMaxNorm # 3 --> UnitNorm # The division is for computational purpose. # MinMaxNorm takes almost twice to execute than other operations. if self.max_norm is not None: if self.min_norm is not None: if self.min_norm == 1 and self.max_norm == 1: self.constraint_type = 3 else: self.constraint_type = 2 else: self.constraint_type = 1 else: self.constraint_type = 0
[docs] @torch.no_grad() def scale_norm(self, eps=1e-9): """ applies the desired constraint on the Layer. It is highly based on the Keras implementation, but here MaxNorm, MinMaxNorm and UnitNorm are all implemented inside this function. """ if self.constraint_type == 1: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = torch.clamp(norms, 0, self.max_norm) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 2: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = ( self.minmax_rate * torch.clamp(norms, self.min_norm, self.max_norm) + (1 - self.minmax_rate) * norms ) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 3: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) self.weight.data /= eps + norms
def forward(self, input): """ :meta private: """ if self.constraint_type != 0: self.scale_norm() return F.linear(input, self.weight, self.bias)
[docs] class ConstrainedConv1d(nn.Conv1d): """ nn.Conv1d layer with norm constraints. This is a Pytorch implementation of the 1D Convolutional layer with the possibility to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along the given axis. Most of the parameters are the same as described in pytorch Conv1d help. Parameters ---------- in_channels: int Number of input channels. out_channels: int Number of output channels. kernel_size: int or tuple Size of the convolving kernel. stride: int or tuple, optional Stride of the convolution. Default = 1 padding: int, tuple or str, optional Padding added to all four sides of the input. This class also accepts the string 'causal', which triggers causal convolution like in Wavenet. Default = 0 dilation: int or tuple, optional Spacing between kernel elements. Default = 1 groups: int, optional Number of blocked connections from input channels to output channels. Default = 1 bias: bool, optional If True, adds a learnable bias to the output. Default = True padding_mode: str, optional Any of 'zeros', 'reflect', 'replicate' or 'circular'. Default = 'zeros' device: torch.device or str, optional The torch device. dtype: torch.dtype, optional Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. max_norm: float, optional The maximum norm each hidden unit can have. If None no constraint will be added. Default = 2.0 min_norm: float, optional The minimum norm each hidden unit can have. Must be a float lower than max_norm. If given, MinMaxNorm will be applied in the case max_norm is also given. Otherwise, it will be ignored. Default = None axis_norm: Union[int, list, tuple], optional The axis along weights are constrained. It behaves like Keras. So, considering that a Conv1D layer has shape (output_depth, input_depth, length), set axis to [1, 2] will constrain the weights of each filter tensor of size (input_depth, length). Default = [1,2] minmax_rate: float, optional A constraint for MinMaxNorm setting how weights will be rescaled at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using minmax_rate = 1 will set a strict enforcement of the constraint, while rate<1.0 will slowly rescale layer's hidden units at each step. Default = 1.0 Note ---- To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, set both min_norm and max_norm to 1.0. Note ---- When setting ``padding`` to ``"causal"``, padding will be internally changed to an integer equal to ``(kernel_size - 1) * dilation``. Then, during forward, the extra features are removed. This is preferable over F.pad, which can lead to memory allocation or even non-deterministic operations during the backboard pass. Additional information can be found at the following link: https://github.com/pytorch/pytorch/issues/1333 Example ------- >>> from import selfeeg.models import ConstrainedConv1d >>> import torch >>> x = torch.randn(4, 8, 64) >>> mdl = ConstrainedConv1d(8, 16, 15, max_norm = 1.4, min_norm = 0.3) >>> mdl.weight = torch.nn.Parameter(mdl.weight*10) >>> out = mdl(x) >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2])) >>> print(out.shape) # shoud return torch.Size([4, 16, 64]) >>> print(torch.isnan(out).sum()) # shoud return 0 >>> print(torch.sum(norms>(1.4+1e-3)).item() == 0) # should return True """ def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding="same", dilation=1, groups=1, bias=True, padding_mode="zeros", device=None, dtype=None, max_norm=2.0, min_norm=None, axis_norm=[1, 2], minmax_rate=1.0, ): # Check causal Padding self.pad = padding self.causal_pad = False if isinstance(padding, str): if padding.casefold() == "causal": self.causal_pad = True self.pad = (kernel_size - 1) * dilation super(ConstrainedConv1d, self).__init__( in_channels, out_channels, kernel_size, stride, self.pad, dilation, groups, bias, padding_mode, device, dtype, ) # Check that max_norm is valid if max_norm is not None: if max_norm <= 0: raise ValueError("max_norm can't be lower or equal than 0") else: self.max_norm = max_norm else: self.max_norm = max_norm # Check that min_norm is valid if min_norm is not None: if min_norm <= 0: raise ValueError("min_norm can't be lower or equal than 0") else: self.min_norm = min_norm else: self.min_norm = min_norm # If both are given, check that max_norm is bigger than min_norm if (self.min_norm is not None) and (self.max_norm is not None): if self.min_norm > self.max_norm: raise ValueError("max_norm can't be lower than min_norm") # Check that minmax_rate is bigger than 0 if minmax_rate <= 0.0 or minmax_rate > 1.0: raise ValueError("minmax_rate must be in (0,1]") self.minmax_rate = minmax_rate # Check that axis is a valid enter if type(axis_norm) not in [tuple, list, int]: raise TypeError("axis must be a tuple, list, or int") else: if type(axis_norm) == int: if axis_norm < 0 or axis_norm > 2: raise ValueError("Conv2D has 4 axis. Values must be in [0, 2]") else: for i in axis_norm: if i < 0 or i > 2: raise ValueError("Axis values must be in [0, 2]") self.axis_norm = axis_norm # set the constraint case: # 0 --> no contraint # 1 --> MaxNorm # 2 --> MinMaxNorm # 3 --> UnitNorm # The division is for computational purpose. # MinMaxNorm takes almost twice to execute than other operations. if self.max_norm is not None: if self.min_norm is not None: if self.min_norm == 1 and self.max_norm == 1: self.constraint_type = 3 else: self.constraint_type = 2 else: self.constraint_type = 1 else: self.constraint_type = 0
[docs] @torch.no_grad() def scale_norm(self, eps=1e-9): """ applies the desired constraint on the Layer. It is highly based on the Keras implementation, but here MaxNorm, MinMaxNorm and UnitNorm are all implemented inside this function. """ if self.constraint_type == 1: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = torch.clamp(norms, 0, self.max_norm) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 2: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = ( self.minmax_rate * torch.clamp(norms, self.min_norm, self.max_norm) + (1 - self.minmax_rate) * norms ) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 3: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) self.weight.data /= eps + norms
def forward(self, input): """ :meta private: """ if self.constraint_type != 0: self.scale_norm() if self.causal_pad: return self._conv_forward(input, self.weight, self.bias)[:, :, : -self.pad] else: return self._conv_forward(input, self.weight, self.bias)
[docs] class ConstrainedConv2d(nn.Conv2d): """ nn.conv2d layer with norm constraints. This is a Pytorch implementation of the 2D Convolutional layer with the possibility to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along the given axis. Most of the parameters are the same as described in pytorch Conv2D help. Parameters ---------- in_channels: int Number of input channels. out_channels: int Number of output channels. kernel_size: int or tuple Size of the convolving kernel. stride: int or tuple, optional Stride of the convolution. Default = 1 padding: int, tuple or str, optional Padding added to all four sides of the input. Default = 0 dilation: int or tuple, optional Spacing between kernel elements. Default = 1 groups: int, optional Number of blocked connections from input channels to output channels. Default = 1 bias: bool, optional If True, adds a learnable bias to the output. Default = True padding_mode: str, optional Any of 'zeros', 'reflect', 'replicate' or 'circular'. Default = 'zeros' device: torch.device or str, optional The torch device. dtype: torch.dtype, optional Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. max_norm: float, optional The maximum norm each hidden unit can have. If None no constraint will be added. Default = 2.0 min_norm: float, optional The minimum norm each hidden unit can have. Must be a float lower than max_norm. If given, MinMaxNorm will be applied in the case max_norm is also given. Otherwise, it will be ignored. Default = None axis_norm: Union[int, list, tuple], optional The axis along weights are constrained. It behaves like Keras. So, considering that a Conv2D layer has shape (output_depth, input_depth, rows, cols), set axis to [1, 2, 3] will constrain the weights of each filter tensor of size (input_depth, rows, cols). Default = [1,2,3] minmax_rate: float, optional A constraint for MinMaxNorm setting how weights will be rescaled at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using minmax_rate = 1 will set a strict enforcement of the constraint, while rate<1.0 will slowly rescale layer's hidden units at each step. Default = 1.0 Note ---- To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, set both min_norm and max_norm to 1.0. Example ------- >>> from import selfeeg.models import ConstrainedConv2d >>> import torch >>> x = torch.randn(4, 1, 8, 64) >>> mdl = ConstrainedConv2d(1, 4, (1, 15), max_norm = 0.5) >>> mdl.weight = torch.nn.Parameter(mdl.weight*10) >>> out = mdl(x) >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2,3])) >>> print(out.shape) # shoud return torch.Size([4, 2, 8, 64]) >>> print(torch.isnan(out).sum()) # shoud return 0 >>> print(torch.sum(norms>(0.5+1e-3)).item() == 0) # should return True """ def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding="same", dilation=1, groups=1, bias=True, padding_mode="zeros", device=None, dtype=None, max_norm=2.0, min_norm=None, axis_norm=[1, 2, 3], minmax_rate=1.0, ): super(ConstrainedConv2d, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, ) # Check that max_norm is valid if max_norm is not None: if max_norm <= 0: raise ValueError("max_norm can't be lower or equal than 0") else: self.max_norm = max_norm else: self.max_norm = max_norm # Check that min_norm is valid if min_norm is not None: if min_norm <= 0: raise ValueError("min_norm can't be lower or equal than 0") else: self.min_norm = min_norm else: self.min_norm = min_norm # If both are given, check that max_norm is bigger than min_norm if (self.min_norm is not None) and (self.max_norm is not None): if self.min_norm > self.max_norm: raise ValueError("max_norm can't be lower than min_norm") # Check that minmax_rate is bigger than 0 if minmax_rate <= 0.0 or minmax_rate > 1.0: raise ValueError("minmax_rate must be in (0,1]") self.minmax_rate = minmax_rate # Check that axis is a valid enter if type(axis_norm) not in [tuple, list, int]: raise TypeError("axis must be a tuple, list, or int") else: if type(axis_norm) == int: if axis_norm < 0 or axis_norm > 3: raise ValueError("Conv2D has 4 axis. Values must be in [0, 3]") else: for i in axis_norm: if i < 0 or i > 3: raise ValueError("Axis values must be in [0, 3]") self.axis_norm = axis_norm # set the constraint case: # 0 --> no contraint # 1 --> MaxNorm # 2 --> MinMaxNorm # 3 --> UnitNorm # The division is for computational purpose. # MinMaxNorm takes almost twice to execute than other operations. if self.max_norm is not None: if self.min_norm is not None: if self.min_norm == 1 and self.max_norm == 1: self.constraint_type = 3 else: self.constraint_type = 2 else: self.constraint_type = 1 else: self.constraint_type = 0
[docs] @torch.no_grad() def scale_norm(self, eps=1e-9): """ applies the desired constraint on the Layer. It is highly based on the Keras implementation, but here MaxNorm, MinMaxNorm and UnitNorm are all implemented inside this function. """ if self.constraint_type == 1: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = torch.clamp(norms, 0, self.max_norm) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 2: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = ( self.minmax_rate * torch.clamp(norms, self.min_norm, self.max_norm) + (1 - self.minmax_rate) * norms ) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 3: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) self.weight.data /= eps + norms
def forward(self, input): """ :meta private: """ if self.constraint_type != 0: self.scale_norm() return self._conv_forward(input, self.weight, self.bias)
[docs] class DepthwiseConv2d(nn.Conv2d): """ Depthwise 2D layer with norm constraints. Pytorch implementation of the Depthwise Convolutional layer with the possibility to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along the given axis. Most of the parameters are the same as described in pytorch Conv2D help. Parameters ---------- in_channels: int Number of input channels. depth_multiplier: int The depth multiplier. Output channels will be depth_multiplier*in_channels. kernel_size: int or tuple Size of the convolving kernel. stride: int or tuple, optional Stride of the convolution. Default = 1 padding: int, tuple or str, optional Padding added to all four sides of the input. Default = 0 dilation: int or tuple, optional Spacing between kernel elements. Default = 1 bias: bool, optional If True, adds a learnable bias to the output. Default = True max_norm: float, optional The maximum norm each hidden unit can have. If None no constraint will be added. Default = 2.0 min_norm: float, optional The minimum norm each hidden unit can have. Must be a float lower than max_norm. If given, MinMaxNorm will be applied in the case max_norm is also given. Otherwise, it will be ignored. Default = None axis_norm: Union[int, list, tuple], optional The axis along weights are constrained. It behaves like Keras. So, considering that a Conv2D layer has shape (output_depth, input_depth, rows, cols), set axis to [1, 2, 3] will constrain the weights of each filter tensor of size (input_depth, rows, cols). Default = [1,2,3] minmax_rate: float, optional A constraint for MinMaxNorm setting how weights will be rescaled at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using minmax_rate = 1 will set a strict enforcement of the constraint, while rate<1.0 will slowly rescale layer's hidden units at each step. Default = 1.0 Note ---- To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, set both min_norm and max_norm to 1.0. Example ------- >>> from import selfeeg.models import DepthwiseConv2d >>> import torch >>> x = torch.randn(4,1,8,64) >>> mdl = DepthwiseConv2d(1, 2, (1, 15), max_norm = 0.5) >>> mdl.weight = torch.nn.Parameter(mdl.weight*10) >>> out = mdl(x) >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2,3])) >>> print(out.shape) # shoud return torch.Size([4, 2, 8, 64]) >>> print(torch.isnan(out).sum()) # shoud return 0 >>> print(torch.sum(norms>(0.5+1e-3)).item() == 0) # should return True """ def __init__( self, in_channels, depth_multiplier, kernel_size, stride=1, padding="same", dilation=1, bias=False, max_norm=2.0, min_norm=None, axis_norm=[1, 2, 3], minmax_rate=1.0, ): super(DepthwiseConv2d, self).__init__( in_channels, depth_multiplier * in_channels, kernel_size, groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, ) # Check that max_norm is valid if max_norm is not None: if max_norm <= 0: raise ValueError("max_norm can't be lower or equal than 0") else: self.max_norm = max_norm else: self.max_norm = max_norm # Check that min_norm is valid if min_norm is not None: if min_norm <= 0: raise ValueError("min_norm can't be lower or equal than 0") else: self.min_norm = min_norm else: self.min_norm = min_norm # If both are given, check that max_norm is bigger than min_norm if (self.min_norm is not None) and (self.max_norm is not None): if self.min_norm > self.max_norm: raise ValueError("max_norm can't be lower than min_norm") # Check that minmax_rate is bigger than 0 if minmax_rate <= 0.0 or minmax_rate > 1.0: raise ValueError("minmax_rate must be in (0,1]") self.minmax_rate = minmax_rate # Check that axis is a valid enter if type(axis_norm) not in [tuple, list, int]: raise TypeError("axis must be a tuple, list, or int") else: if type(axis_norm) == int: if axis_norm < 0 or axis_norm > 3: raise ValueError("Conv2D has 4 axis. Values must be in [0, 3]") else: for i in axis_norm: if i < 0 or i > 3: raise ValueError("Axis values must be in [0, 3]") self.axis_norm = axis_norm # set the constraint case: # 0 --> no contraint # 1 --> MaxNorm # 2 --> MinMaxNorm # 3 --> UnitNorm # The division is for computational purpose. # MinMaxNorm takes almost twice to execute than other operations. if self.max_norm is not None: if self.min_norm is not None: if self.min_norm == 1 and self.max_norm == 1: self.constraint_type = 3 else: self.constraint_type = 2 else: self.constraint_type = 1 else: self.constraint_type = 0
[docs] @torch.no_grad() def scale_norm(self, eps=1e-9): """ applies the desired constraint on the Layer. It is highly based on the Keras implementation, but here MaxNorm, MinMaxNorm and UnitNorm are all implemented inside this function. """ if self.constraint_type == 1: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = torch.clamp(norms, 0, self.max_norm) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 2: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) desired = ( self.minmax_rate * torch.clamp(norms, self.min_norm, self.max_norm) + (1 - self.minmax_rate) * norms ) self.weight.data *= desired / (eps + norms) elif self.constraint_type == 3: norms = torch.sqrt( torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) ) self.weight.data /= eps + norms
def forward(self, input): """ :meta private: """ if self.constraint_type != 0: self.scale_norm() return self._conv_forward(input, self.weight, self.bias)
[docs] class SeparableConv2d(nn.Module): """ Separable Convolutional layer with norm constraints. Pytorch implementation of the Separable Convolutional layer with the possibility of adding a norm constraint on both the depthwise and pointwise filters (feature) dimension. The layer applies first a depthwise conv2d, then a pointwise conv2d (kernel size = 1) Most of the parameters are the same as described in pytorch conv2D help. Parameters ---------- in_channels: int Number of input channels. out_channels: int Number of output channels. kernel_size: int or tuple Size of the convolving kernel stride: int or tuple, optional Stride of the convolution. Default = 1 padding: int, tuple or str, optional Padding added to all four sides of the input. Default = 0 dilation: int or tuple, optional Spacing between kernel elements. Default = 1 bias: bool, optional If True, adds a learnable bias to the output. Default = True depth_multiplier: int, optional The depth multiplier of the depthwise block. Default = 1 depth_max_norm: float, optional The maximum norm each hidden unit in the depthwise layer can have. If None no constraint will be added. Default = None depth_min_norm: float, optional The minimum norm each hidden unit in the depthwise layer can have. Must be a float lower than max_norm. If given, MinMaxNorm will be applied in the case max_norm is also given. Otherwise, it will be ignored. Default = None depth_minmax_rate: float, optional A constraint for depthwise's MinMaxNorm setting how weights will be rescaled at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using minmax_rate = 1 will set a strict enforcement of the constraint, while rate<1.0 will slowly rescale layer's hidden units at each step. Default = 1.0 axis_norm: Union[int, list, tuple], optional The axis along weights are constrained. It behaves like Keras. So, considering that a Conv2D layer has shape (output_depth, input_depth), set axis to 1 will constrain the weights of each filter tensor of size (input_depth,). Default = 1 point_max_norm: float, optional Same as depth_max_norm, but applied to the pointwise Convolutional layer. Default = None point_min_norm: float, optional Same as depth_min_norm, but applied to the pointwise Convolutional layer. Default = None point_minmax_rate: float, optional Same as depth_minmax_rate, but applied to the pointwise Convolutional layer. Default = 1.0 Example ------- >>> from selfeeg.models import SeparableConv2d >>> import torch >>> x = torch.randn(4, 1, 8, 64) >>> mdl = SeparableConv2d(1,4, (1,15), depth_multiplier=4) >>> out = mdl(x) >>> print(out.shape) # shoud return torch.Size([4, 4, 8, 64]) >>> print(torch.isnan(out).sum()) # shoud return 0 """ def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding="same", dilation=1, bias=False, depth_multiplier=1, depth_max_norm=None, depth_min_norm=None, depth_minmax_rate=1.0, point_max_norm=None, point_min_norm=None, point_minmax_rate=1.0, axis_norm=[1, 2, 3], ): super(SeparableConv2d, self).__init__() self.depthwise = DepthwiseConv2d( in_channels, depth_multiplier, kernel_size, stride, padding, dilation, bias, max_norm=depth_max_norm, min_norm=depth_min_norm, axis_norm=axis_norm, minmax_rate=depth_minmax_rate, ) self.pointwise = ConstrainedConv2d( in_channels * depth_multiplier, out_channels, kernel_size=1, bias=bias, max_norm=point_max_norm, min_norm=point_min_norm, axis_norm=axis_norm, minmax_rate=point_minmax_rate, ) def forward(self, input): """ :meta private: """ out = self.depthwise(input) out = self.pointwise(out) return out
[docs] class FilterBank(nn.Module): """ Filter bank layer of the FBCNet model. Pytorch implementation of the Filter Bank layer used in the FBCNet model The layer applies a sequence of filters with different bandwidth; then, it concatenates each generated signal in the convolutional channel dimension. The expected **input** is a **3D tensor** with size (Batch x Channels x Samples). The generated output is a a **4D tensor** with size (Batch x Filters x Channels x Samples). See the original paper for more info Parameters ---------- Fs: int or float The EEG sampling rate. Bands: int, optional The number of filters to apply to the original signal. Default = 9 Range: int or float, optional The passband of each filter, given in Hz. Default = 4 Type: str, optional The type of filter to use. Allowed arguments are the same as described in the `get_filter_coeff()` function of the `selfeeg.augmentation.functional` submodule ('butter', 'ellip', 'cheby1', 'cheby2') Default = 'cheby1' StopRipple: int or float, optional Ripple at stopband in decibel. Default = 30 PassRipple: int or float, optional Ripple at passband in decibel. Default = 3 RangeTol: int or float, optional The filter transition bandwidth in Hz. Default = 2 SkipFirst: bool, optional If True, skips the first filter with passband equal to [0, Range] Hz. The number of filters specified in Bands will still be preserved. Default = True Warnings -------- Do not use too strict filter specs (e.g. narrow transition bandwidths, high stopband ripple, low passband ripple) as this can generate undesired outputs (nan values or incredibly high values). Example ------- >>> from selfeeg.models import FilterBank >>> import torch >>> x = torch.randn(4, 8, 512) >>> mdl = FilterBank(128) >>> out = mdl(x) >>> print(out.shape) # shoud return torch.Size([4, 9, 8, 512]) >>> print(torch.isnan(out).sum()) # shoud return 0 """ def __init__( self, Fs, Bands=9, Range=4, Type="Cheby2", StopRipple=30, PassRipple=3, RangeTol=2, SkipFirst=True, ): self.Bands = Bands super(FilterBank, self).__init__() A_list = [None] * Bands B_list = [None] * Bands max_b, max_a = 0, 0 NStart = 1 if SkipFirst else 0 NEnd = Bands + 1 if SkipFirst else 0 for i in range(NStart, NEnd): if i == 0: b, a = get_filter_coeff( Fs, Range, Range + RangeTol, btype="lowpass", filter_type=Type, rp=PassRipple, rs=StopRipple, ) elif (Range * (i + 1)) + RangeTol < Fs / 2: b, a = get_filter_coeff( Fs, [Range * i, Range * (i + 1)], [Range * i - RangeTol, Range * (i + 1) + RangeTol], btype="bandpass", filter_type=Type, rp=PassRipple, rs=StopRipple, ) else: b, a = get_filter_coeff( Fs, Range * i, Range * i - RangeTol, btype="highpass", filter_type=Type, rp=PassRipple, rs=StopRipple, ) if SkipFirst: A_list[i - 1] = a B_list[i - 1] = b else: A_list[i - 1] = a B_list[i - 1] = b max_b = max_b if b.shape[0] < max_b else b.shape[0] max_a = max_a if a.shape[0] < max_a else a.shape[0] self.A_coeffs = torch.zeros(Bands, max_a) self.B_coeffs = torch.zeros(Bands, max_b) for i in range(Bands): self.A_coeffs[i, : A_list[i].shape[0]] = torch.from_numpy(A_list[i]) self.B_coeffs[i, : B_list[i].shape[0]] = torch.from_numpy(B_list[i]) self.A_coeffs = self.A_coeffs.to(dtype=torch.float32) self.B_coeffs = self.B_coeffs.to(dtype=torch.float32) self.A_coeffs = torch.nn.Parameter(self.A_coeffs, requires_grad=False) self.B_coeffs = torch.nn.Parameter(self.B_coeffs, requires_grad=False) def forward(self, x): """ :meta private: """ x = x.unsqueeze(2).repeat(1, 1, self.Bands, 1) x = filtfilt(x, self.A_coeffs, self.B_coeffs, clamp=False) x = torch.permute(x, [0, 2, 1, 3]) return x