from __future__ import annotations
import copy
import inspect
import random
from typing import Any, Dict
import numpy as np
from numpy.typing import ArrayLike
# Public API of this module: the names exported by a star-import.
__all__ = [
"CircularAug",
"DynamicSingleAug",
"RandomAug",
"SequentialAug",
"StaticSingleAug",
]
[docs]
class StaticSingleAug:
    """
    Single static augmentation with preset arguments.

    ``StaticSingleAug`` performs a single data augmentation
    where the optional arguments are previously set and given during initialization.
    No random choice of the arguments is performed. The class accepts multiple sets of
    optional arguments. In this case they are used one at a time at each class
    call, in a circular manner. This means that the first call uses the first set of
    arguments, the second will use the second set of arguments, and so on.
    When the last set of arguments is used, the class will restart from the
    first set of arguments.

    To perform an augmentation, simply call the instantiated class
    (see provided example or check the introductory notebook).

    Parameters
    ----------
    augmentation: function
        The augmentation function to apply. It can be a custom function,
        but the first argument must be the element to augment.
    arguments: list, dict, list[list or dict], optional
        The set of arguments to pass to the augmentation function. It can be:

            1. None. In this case the default parameters of the function are used.
            2. list. In this case the function is called with the syntax
               ``augmentation(x, *arguments)``
            3. dict. In this case the function is called with the syntax
               ``augmentation(x, **arguments)``
            4. a list of dicts or lists. This is a particular case where multiple
               combinations of arguments are given. Each element of the list must be a
               list or a dict with the specific argument combination.
               Every time ``perform_augmentation`` is called, one of the given
               combinations is used to perform the data augmentation. The list
               is followed sequentially with repetition, meaning that the first call
               uses the first set of arguments of the list, the second call uses the
               second set of arguments, and so on. When the last element of the list
               is used, the function will restart scrolling the given list.

        An empty list or an empty dict is treated as a single (empty) set of
        arguments, i.e. cases 2 and 3, never as case 4.

        Default = None

    Raises
    ------
    ValueError
        If ``augmentation`` is neither a Python function nor a builtin.

    Methods
    -------
    perform_augmentation(X: ArrayLike)
        Apply the augmentation with the given arguments.
        __call__() will call this method.

    Example
    -------
    >>> import selfeeg.augmentation as aug
    >>> import numpy as np
    >>> import torch
    >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
    >>> Fs = 64
    >>> Aug_eye = aug.StaticSingleAug(
    ...     aug.add_eeg_artifact,
    ...     [{'Fs': Fs, 'artifact': 'eye', 'amplitude': 0.5, 'batch_equal': False},
    ...      {'Fs': Fs, 'artifact': 'eye', 'amplitude': 1.0, 'batch_equal': False}]
    ... )
    >>> BatchEEGaug1 = Aug_eye(BatchEEG)
    >>> BatchEEGaug2 = Aug_eye(BatchEEG)

    plot the augmentations (require matplotlib to be installed)

    >>> import matplotlib.pyplot as plt
    >>> plt.style.use('seaborn-v0_8-white')
    >>> plt.rcParams['figure.figsize'] = (15.0, 6.0)
    >>> plt.plot(BatchEEG[0,0],linewidth=2.5)
    >>> plt.plot(BatchEEGaug1[0,0])
    >>> plt.plot(BatchEEGaug2[0,0])
    >>> plt.tick_params(axis='both', which='major', labelsize=12)
    >>> plt.title('Example of addition of EEG eye blink artifact', fontsize=15)
    >>> plt.legend(['original sample', 'augmented sample amp 0.5',
    ...             'augmented sample amp 1'])
    >>> plt.show()

    """

    def __init__(
        self,
        augmentation: "function",
        arguments: "list | dict | list[list | dict] | None" = None,
    ):
        if not (inspect.isfunction(augmentation) or inspect.isbuiltin(augmentation)):
            raise ValueError("augmentation must be a function to call")

        self.augmentation = augmentation
        self.arguments = arguments

        # counter starts at 0; maxcounter holds the number of argument sets
        # (it stays 0 when a single set, or no set, is given).
        # NOTE(review): presumably perform_augmentation advances ``counter``
        # to cycle through the sets - that method is not visible here.
        self.counter = 0
        self.maxcounter = 0
        self.multipleStaticArguments = False

        # "Multiple sets" means a non-empty sequence whose elements are all
        # lists or dicts. A dict is always a single keyword set, and an empty
        # container must not qualify: all() over nothing is True, which would
        # otherwise flag it as multiple sets with zero available entries.
        if (
            arguments is not None
            and not isinstance(arguments, dict)
            and all(isinstance(i, (list, dict)) for i in arguments)
            and len(arguments) > 0
        ):
            self.multipleStaticArguments = True
            self.maxcounter = len(arguments)

    def __call__(self, X):
        """Apply the augmentation to ``X`` by delegating to ``perform_augmentation``."""
        # NOTE(review): perform_augmentation is documented above but is not
        # defined in the visible part of this class - confirm it is provided.
        return self.perform_augmentation(X)
[docs]
class DynamicSingleAug:
    """
    Single augmentation with randomly chosen arguments.

    ``DynamicSingleAug`` performs a single data augmentation
    where the optional arguments are chosen at random from a given discrete or
    continuous range of values. Random choice of the arguments is performed
    at each call.

    To perform an augmentation, simply call the instantiated class
    (see provided example or check the introductory notebook)

    Parameters
    ----------
    augmentation: function
        The augmentation function to apply. It can be a custom function,
        but the first argument must be the element to augment.
    discrete_arg: dict, optional
        A dictionary specifying arguments whose value must be chosen
        within a discrete set. The dict must have:

            - keys as string with the name of one of the optional arguments
            - values as lists of elements to be randomly chosen.
              Single elements are allowed if a specific value for an argument
              needs to be set. In this case it is not mandatory to give it as list,
              as automatic conversion will be performed internally. In other words,
              a key-value pair given as ``{"arg": value}`` is allowed, since the
              conversion to ``{"arg": [value]}`` is automatically performed.

        The given dictionary is not modified: a normalized copy is stored.

        Default = None
    range_arg: dict, optional
        A dictionary specifying arguments whose value must be chosen within
        a continuous range.
        The dict must have:

            - keys as string with the name of one of the optional arguments
            - values as two element lists specifying the range of values
              where to randomly select the argument value.

        Default = None
    range_type: dict or list, optional
        A dictionary or a list specifying if values in range_arg must be given
        to the augmentation function as integers. If given as a dict, keys must
        be the same as the one of range_arg argument. If given as a list, the
        length must be the same of range_arg.
        In particular:

            1. if range_type is a **dict**:

                - keys must be those in range_arg
                - values must be single element specifying if the argument must be
                  an integer. In this case, use a **boolean True** or a **string
                  'int'** to specify if the argument must be converted to an integer.

            2. if range_type is a **list**:

                - values must be set as the values in the dict. The order is the
                  one used when iterating along the range_arg dict.

            3. if **None** is given, a list of True with length equal to range_arg
               is automatically created, since int arguments are more compatible
               compared to float ones.

        Default = None

    Raises
    ------
    ValueError
        If ``augmentation`` is neither a Python function nor a builtin; if
        ``discrete_arg`` or ``range_arg`` is not a dict or has keys that are not
        arguments of ``augmentation``; if a ``range_arg`` value is not a two
        element list; if ``range_type`` is not a dict or a list, is given
        without ``range_arg``, or does not match the keys/length of ``range_arg``.

    Methods
    -------
    perform_augmentation(x: ArrayLike)
        Apply the augmentation with the given arguments.
        __call__() will call this method.

    Note
    ----
    At least one of **discrete_arg** or **range_arg** arguments should be given.
    If no argument needs to be randomized, ``StaticSingleAug`` is the
    suggested alternative.

    Example
    -------
    >>> import selfeeg.augmentation as aug
    >>> import numpy as np
    >>> import torch
    >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
    >>> Aug_warp = aug.DynamicSingleAug(
    ...     aug.warp_signal,
    ...     discrete_arg = {'batch_equal': [True, False]},
    ...     range_arg= {'segments': [5,15], 'stretch_strength': [1.5,2.5],
    ...                 'squeeze_strength': [0.4,2/3]},
    ...     range_type={'segments': True, 'stretch_strength': False,
    ...                 'squeeze_strength': False}
    ... )
    >>> BatchEEGaug1 = Aug_warp(BatchEEG)
    >>> BatchEEGaug2 = Aug_warp(BatchEEG)

    plot the augmentations (require matplotlib to be installed)

    >>> import matplotlib.pyplot as plt
    >>> plt.style.use('seaborn-v0_8-white')
    >>> plt.rcParams['figure.figsize'] = (15.0, 6.0)
    >>> plt.plot(BatchEEG[0,0],linewidth=2.5)
    >>> plt.plot(BatchEEGaug1[0,0])
    >>> plt.plot(BatchEEGaug2[0,0])
    >>> plt.tick_params(axis='both', which='major', labelsize=12)
    >>> plt.title('Example of dynamic aug with warp augmentation', fontsize=15)
    >>> plt.legend(['original sample', 'augmented sample 1', 'augmented sample 2'])
    >>> plt.show()

    """

    def __init__(
        self,
        augmentation,
        discrete_arg: Dict[str, Any] = None,
        range_arg: "Dict[str, list[int | float]]" = None,
        range_type: "Dict[str, str | bool] | list[str | bool]" = None,
    ):
        # set augmentation function
        if not (inspect.isfunction(augmentation) or inspect.isbuiltin(augmentation)):
            raise ValueError("augmentation must be a function to call")
        self.augmentation = augmentation

        # names of the positional parameters of the augmentation, skipping the
        # first one (the element to augment)
        self.argnames = inspect.getfullargspec(augmentation)[0][1:]

        # discrete_arg: keys must be augmentation arguments; bare values are
        # wrapped in a one-element list, as promised by the documentation
        self.discrete_arg = None
        if discrete_arg is not None:
            if not isinstance(discrete_arg, dict):
                raise ValueError("discrete_arg must be a dictionary")
            if not all(i in self.argnames for i in discrete_arg):
                raise ValueError(
                    "keys of discrete_arg argument must be the "
                    "argument of the augmentation fun"
                )
            # build a new dict so the caller's dictionary is left untouched
            self.discrete_arg = {
                name: choices if isinstance(choices, list) else [choices]
                for name, choices in discrete_arg.items()
            }

        # range_arg: keys must be augmentation arguments and every value must
        # be a two element list
        self.range_arg = None
        if range_arg is not None:
            if not isinstance(range_arg, dict):
                raise ValueError("range_arg must be a dictionary")
            if not all(i in self.argnames for i in range_arg):
                raise ValueError(
                    "keys of range_arg argument must be the "
                    "argument of the augmentation function."
                )
            if not all(isinstance(i, list) and len(i) == 2 for i in range_arg.values()):
                raise ValueError(
                    "range_arg values must be a length 2 " "list with min and max range"
                )
            self.range_arg = range_arg

        # range_type: must mirror range_arg, either by keys (dict) or by
        # length (list)
        self.range_type = None
        if range_type is not None:
            if not isinstance(range_type, (dict, list)):
                raise ValueError("range_type must be a dictionary or a list")
            if self.range_arg is None:
                # without range_arg there is nothing to compare against
                raise ValueError("range_type can be given only together with range_arg")
            if isinstance(range_type, dict):
                if range_type.keys() != self.range_arg.keys():
                    raise ValueError("keys of range_type must be the same as range_arg")
            elif len(range_type) != len(self.range_arg):
                raise ValueError("range_type must have the same length as range_args")
            self.range_type = range_type
        elif self.range_arg is not None:
            # default: every range argument is flagged as integer
            self.range_type = [True] * len(self.range_arg)

        self.is_range_type_dict = isinstance(self.range_type, dict)

        # names of all the arguments given by the user, discrete ones first
        self.given_arg = list(self.discrete_arg) if self.discrete_arg is not None else []
        self.given_arg += list(self.range_arg) if self.range_arg is not None else []

    def __call__(self, X):
        """Apply the augmentation to ``X`` by delegating to ``perform_augmentation``."""
        # NOTE(review): perform_augmentation is documented above but is not
        # defined in the visible part of this class - confirm it is provided.
        return self.perform_augmentation(X)
[docs]
class SequentialAug:
    """
    Multiple augmentations applied sequentially.

    ``SequentialAug`` applies a sequence of augmentations in a specified order.
    No random choice between the given list of augmentation is performed, just
    pure call of all the augmentations in the specified order.

    To perform an augmentation, simply call the instantiated class
    (see provided example or check the introductory notebook)

    Parameters
    ----------
    *augmentations: "callable objects"
        The sequence of augmentations to apply at each call.
        It can be any callable object, but the first argument to pass must be
        the element to augment. It is suggested to give a sequence of
        ``StaticSingleAug`` or ``DynamicSingleAug`` instantiations.

    Note
    ----
    If you provide an augmentation implemented outside of this library, be
    sure that the function will return a single output with the element to pass
    to the next augmentation function of the list.

    Note
    ----
    The function will automatically handle RandomAug instances with return_index
    set to True. In this case, an internal deepcopy with return_index set to false
    will be automatically created.

    Methods
    -------
    perform_augmentation(X: ArrayLike)
        Apply the augmentations with the given arguments and specified order.
        __call__() will call this method.

    Example
    -------
    >>> import selfeeg.augmentation as aug
    >>> import numpy as np
    >>> import torch
    >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
    >>> Aug_eye = aug.StaticSingleAug(
    ...     aug.add_eeg_artifact,{'Fs': 64, 'artifact': 'eye', 'amplitude': 0.5})
    >>> Aug_warp = aug.DynamicSingleAug(
    ...     aug.warp_signal,
    ...     discrete_arg = {'batch_equal': [True, False]},
    ...     range_arg= {'segments': [5,15], 'stretch_strength': [1.5,1.8],
    ...                 'squeeze_strength': [0.5,2/3]},
    ...     range_type={'segments': True, 'stretch_strength': False,
    ...                 'squeeze_strength': False}
    ... )
    >>> Sequence1= aug.SequentialAug(Aug_eye, Aug_warp)
    >>> BatchEEGaug1 = Sequence1(BatchEEG)
    >>> BatchEEGaug2 = Sequence1(BatchEEG)

    plot the augmentations (require matplotlib to be installed)

    >>> import matplotlib.pyplot as plt
    >>> plt.style.use('seaborn-v0_8-white')
    >>> plt.rcParams['figure.figsize'] = (15.0, 6.0)
    >>> plt.plot(BatchEEG[0,0],linewidth=2.5)
    >>> plt.plot(BatchEEGaug1[0,0])
    >>> plt.plot(BatchEEGaug2[0,0])
    >>> plt.tick_params(axis='both', which='major', labelsize=12)
    >>> plt.title(
    ...     'Sequential Aug with eye blink artifact and warp augmentation',
    ...     fontsize=15)
    >>> plt.legend(['original sample', 'augmented sample 1', 'augmented sample 2'])
    >>> plt.show()

    """

    def __init__(self, *augmentations):
        self.augs = list(augmentations)
        self._search_for_random_aug_with_index()

    def __call__(self, X):
        """Apply the augmentations to ``X`` by delegating to ``perform_augmentation``."""
        # NOTE(review): perform_augmentation is documented above but is not
        # defined in the visible part of this class - confirm it is provided.
        return self.perform_augmentation(X)

    def _search_for_random_aug_with_index(self):
        """
        Replace every RandomAug that returns an index with an index-free copy.

        The caller's RandomAug is left untouched: a deepcopy with
        ``return_index`` set to False is stored in its place. Nested
        SequentialAug and CircularAug containers are re-scanned as well.
        """
        for idx, item in enumerate(self.augs):
            if isinstance(item, RandomAug):
                if item.return_index:
                    replacement = copy.deepcopy(item)
                    replacement.return_index = False
                    self.augs[idx] = replacement
            elif isinstance(item, (SequentialAug, CircularAug)):
                item._search_for_random_aug_with_index()
[docs]
class CircularAug:
    """
    Single Augmenter called sequentially from a list, following a circular order.

    ``CircularAug`` calls an Augmenter from a given sequence following the order.
    Augmenters are called circularly. This means that the first call uses the first
    Augmenter from the input list, the second call will use the second, and so on.
    When the last Augmenter is called, the class will restart from the
    first one.

    To perform an augmentation, simply call the instantiated class
    (see provided example or check the introductory notebook)

    Parameters
    ----------
    *augmentations: "callable objects"
        The list of augmentations to apply at each call.
        It can be any callable object, but the first argument to pass must be
        the element to augment. It is suggested to give a sequence of
        ``StaticSingleAug`` or ``DynamicSingleAug`` instantiations.

    Note
    ----
    The function will automatically handle RandomAug instances with return_index
    set to True. In this case, an internal deepcopy with return_index set to false
    will be automatically created.

    Methods
    -------
    perform_augmentation(X: ArrayLike)
        Apply the augmentations with the given arguments and specified order.
        __call__() will call this method.

    Example
    -------
    >>> import selfeeg.augmentation as aug
    >>> import numpy as np
    >>> import torch
    >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
    >>> Aug_eye = aug.StaticSingleAug(
    ...     aug.add_eeg_artifact,{'Fs': 64, 'artifact': 'eye', 'amplitude': 0.5})
    >>> Circular = aug.CircularAug(Aug_eye, aug.identity)
    >>> EEGeye = Circular(BatchEEG)
    >>> EEGid = Circular(BatchEEG)

    """

    def __init__(self, *augmentations):
        self.augs = list(augmentations)
        # _augcnt is the circular position, _augnumber the number of augmenters
        self._augcnt = 0
        self._augnumber = len(self.augs)
        self._search_for_random_aug_with_index()

    def __call__(self, X):
        """Apply the next augmenter to ``X`` by delegating to ``perform_augmentation``."""
        # NOTE(review): perform_augmentation is documented above but is not
        # defined in the visible part of this class - confirm it is provided.
        return self.perform_augmentation(X)

    def _update_counter(self):
        """Advance the circular position, wrapping to 0 past the last augmenter."""
        self._augcnt += 1
        # ">=" instead of "==": a position that is already out of range (e.g.
        # an empty augmenter list) is reset rather than growing without bound.
        if self._augcnt >= self._augnumber:
            self._augcnt = 0

    def _search_for_random_aug_with_index(self):
        """
        Replace every RandomAug that returns an index with an index-free copy.

        The caller's RandomAug is left untouched: a deepcopy with
        ``return_index`` set to False is stored in its place. Nested
        SequentialAug and CircularAug containers are re-scanned as well.
        """
        for idx, item in enumerate(self.augs):
            if isinstance(item, RandomAug):
                if item.return_index:
                    replacement = copy.deepcopy(item)
                    replacement.return_index = False
                    self.augs[idx] = replacement
            elif isinstance(item, (SequentialAug, CircularAug)):
                item._search_for_random_aug_with_index()
[docs]
class RandomAug:
    """
    Random augmentation chosen from a given set.

    ``RandomAug`` applies an augmentation selected randomly from a given set.

    To perform an augmentation, simply call the instantiated class
    (see provided example or check the introductory notebook)

    Parameters
    ----------
    *augmentations: "callable objects"
        The set of augmentations to randomly choose at each call.
        It can be any callable object, but the first arguments to pass must
        be the ArrayLike object to augment. It is suggested to give a set of
        ``StaticSingleAug`` or ``DynamicSingleAug`` instantiations.
    p: 1D ArrayLike, optional
        A 1D array or list with the weights associated to each augmentation
        (higher the weight, higher the frequency of choosing an augmentation
        of the list). Elements of p must be in the same order as the given
        augmentations.
        If given, p will be scaled so to have sum 1 (so you can give any
        non-negative values, as long as their sum is positive).
        If not given, all augmentations will be chosen with equal probability.

        Default = None
    return_index: bool, optional
        Whether to return an index identifying the selected augmentation or not.
        The index is simply the output of random.choice function used to select
        the augmentation from the given list. Indices follow the augmentation order
        given during class instantiation.

        Default = False

    Raises
    ------
    ValueError
        If the length of ``p`` differs from the number of augmentations, or if
        ``p`` has negative entries or a sum that is not positive and finite.

    Methods
    -------
    perform_augmentation(X: ArrayLike)
        Apply a random augmentation from the given list of augmenters.
        __call__() will call this method.

    Example
    -------
    >>> import selfeeg.augmentation as aug
    >>> import numpy as np
    >>> import torch
    >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
    >>> Aug_eye = aug.StaticSingleAug(
    ...     aug.add_eeg_artifact,{'Fs': 64, 'artifact': 'eye', 'amplitude': 0.5})
    >>> Aug_warp = aug.DynamicSingleAug(
    ...     aug.warp_signal,
    ...     discrete_arg = {'batch_equal': [True, False]},
    ...     range_arg= {'segments': [5,15], 'stretch_strength': [1.5,1.8],
    ...                 'squeeze_strength': [0.5,2/3]},
    ...     range_type={'segments': True, 'stretch_strength': False,
    ...                 'squeeze_strength': False}
    ... )
    >>> Sequence2= aug.RandomAug(Aug_eye, Aug_warp, p=[0.7, 0.3])
    >>> BatchEEGaug1 = Sequence2(BatchEEG)
    >>> BatchEEGaug2 = Sequence2(BatchEEG)

    plot the augmentations (require matplotlib to be installed)

    >>> # simulate 1000 augmentation calls. This line is used in RandomAug
    >>> # to choose the index of the list of augmentations to call.
    >>> # Note that the size argument has been added to make computation faster, in
    >>> # the class only 1 value is returned
    >>> import matplotlib.pyplot as plt
    >>> idx=np.random.choice(Sequence2.nprange_, size= 1000, p=Sequence2.p)
    >>> counts=[(1000-len(np.nonzero(idx)[0]))/1000, len(np.nonzero(idx)[0])/1000]
    >>> plt.subplot(1,3,(1,2))
    >>> plt.plot(BatchEEG[0,0],linewidth=2.5)
    >>> plt.plot(BatchEEGaug1[0,0])
    >>> plt.plot(BatchEEGaug2[0,0])
    >>> plt.tick_params(axis='both', which='major', labelsize=12)
    >>> plt.title(
    ...     'Example of Random aug between eye blinking artifact and warp',
    ...     fontsize=15)
    >>> plt.legend(
    ...     ['original sample', 'augmented sample 1', 'augmented sample 2'],
    ...     loc='upper left')
    >>> plt.subplot(1,3,3)
    >>> plt.bar(['eye blinking', 'warp'],counts)
    >>> plt.tick_params(axis='both', which='major', labelsize=12)
    >>> plt.title('barplot of chosen augmentations', fontsize=15)
    >>> plt.ylabel('probability',fontsize=12)
    >>> plt.show()

    """

    def __init__(self, *augmentations, p=None, return_index=False):
        self.augs = list(augmentations)
        self.N = len(self.augs)
        self.p = p
        self.return_index = return_index
        if p is not None:
            if len(p) != self.N:
                raise ValueError("length of p does not match the number of augmentations")
            weights = np.asarray(p, dtype=float)
            total = weights.sum()
            # Reject weights that cannot be scaled into a probability vector:
            # dividing by a zero or non-finite sum would silently store NaNs.
            if np.any(weights < 0) or not np.isfinite(total) or total <= 0:
                raise ValueError(
                    "p must contain non-negative weights with a positive, finite sum"
                )
            # out-of-place division: the array given by the caller is not modified
            self.p = weights / total
        # indices of the augmentations, in the order given at instantiation
        self.nprange_ = np.arange(0, self.N)

    def __call__(self, X):
        """Apply a random augmentation to ``X`` by delegating to ``perform_augmentation``."""
        # NOTE(review): perform_augmentation is documented above but is not
        # defined in the visible part of this class - confirm it is provided.
        return self.perform_augmentation(X)