Build a self-supervised learning pipeline
This notebook will explain how to use the selfeeg library to build a self-supervised learning pipeline.
To summarize, typical steps include:
- define the pretraining model (and optional training elements)
- define the fine-tuning model (and optional training elements)
To better understand how the dataloading and augmentation modules work, check the respective introductory notebooks.
First, let’s import all the packages necessary to run this notebook.
WARNING
To run this notebook you will also need matplotlib, which is not listed in the main dependencies of the selfeeg library. Be sure to install it in your environment.
[1]:
# IMPORT BASE PACKAGES
import os
import random
import pickle
import copy
import sys
sys.path.append('..') # needed if you run this inside the selfeeg/doc folder
import selfeeg
import selfeeg.augmentation as aug
import selfeeg.dataloading as dl
# IMPORT CLASSICAL PACKAGES
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Draw figures inline with this notebook
%matplotlib inline
# IMPORT TORCH
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# set seeds for reproducibility
seed = 1234
random.seed( seed )
np.random.seed( seed )
torch.manual_seed( seed )
plt.style.use('seaborn-v0_8-white')
plt.rcParams['figure.figsize'] = (15.0, 6.0)
Pretraining Phase
Create simulated data
As in the dataloading introductory notebook, we will create a dataset with simulated EEG samples.
EEG files will be stored in the format "{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle". Also, EEG files will contain a dict with keys:
data: the EEG with 8 channels, random length from 512 to 1024 samples, sampling rate 128 Hz
label: a binary label with 0 = normal EEG and 1 = abnormal (epileptic) EEG. Class ratio is 60% normal and 40% abnormal.
EEG files will be generated using an order-1 AutoRegressive model. Coefficients were calculated using a channel acquisition from two real EEGs, one for each category.
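To make the generation step concrete, here is a minimal sketch of sampling one channel from an order-1 autoregressive process, x[t] = a * x[t-1] + e[t]. The coefficient `coeff`, the noise scale, and the function name `simulate_ar1` are illustrative assumptions; the coefficients actually used by the library were estimated from real EEG recordings and are not reproduced here.

```python
import random

def simulate_ar1(n_samples, coeff=0.9, noise_std=1.0, seed=None):
    """Sample n_samples points from x[t] = coeff * x[t-1] + e[t], e ~ N(0, noise_std)."""
    rng = random.Random(seed)
    signal = []
    previous = 0.0
    for _ in range(n_samples):
        current = coeff * previous + rng.gauss(0.0, noise_std)
        signal.append(current)
        previous = current
    return signal

# one simulated channel with a random length between 512 and 1024 samples
length = random.Random(1234).randint(512, 1024)
channel = simulate_ar1(length, coeff=0.9, seed=1234)
```

With |coeff| < 1 the process is stationary, so the simulated trace fluctuates around zero instead of drifting away.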
The following cell will create a folder with 1000 simulated EEGs coming from:
5 datasets (ID from 1 to 5)
40 subjects per dataset (ID from 1 to 40)
5 sessions per subject (ID from 1 to 5)
[2]:
classes = selfeeg.utils.create_dataset(p=0.6,return_labels=True)
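The naming scheme described above can be sketched with plain Python. This is only an illustration: the helper names `make_file_name` and `parse_file_name` are hypothetical, and using a single trial per session (trial ID 1) is an assumption made here to reach 1000 files.

```python
from itertools import product

def make_file_name(dataset_id, subject_id, session_id, trial_id):
    """Build a name following "{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle"."""
    return f"{dataset_id}_{subject_id}_{session_id}_{trial_id}.pickle"

def parse_file_name(name):
    """Recover the four integer IDs from a file name built with make_file_name."""
    stem = name.rsplit(".", 1)[0]
    dataset_id, subject_id, session_id, trial_id = (int(v) for v in stem.split("_"))
    return {"dataset": dataset_id, "subject": subject_id,
            "session": session_id, "trial": trial_id}

# 5 datasets x 40 subjects x 5 sessions, one trial each (assumption)
names = [
    make_file_name(d, s, r, 1)
    for d, s, r in product(range(1, 6), range(1, 41), range(1, 6))
]

# files belonging to the fifth dataset, which is kept aside later in the notebook
dataset5 = [n for n in names if parse_file_name(n)["dataset"] == 5]
```

Encoding the IDs in the file name is what lets a split be defined per dataset, subject, or session without opening any file.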
Define dataloaders
Let’s assume we want to pretrain our model with the first four datasets, and evaluate it on the fifth.
NOTE 1
In this phase, it is suggested to put the fine-tuning data in the test set. In this way, it will be easier to extract the subtables with only the fine-tuning samples.
Since each EEG has a different length, we will split every recording into 2 s windows with 10% overlap.
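As a rough guide to how many windows a recording yields, the sketch below counts overlapping windows for a given length. The function name `count_windows` and the rounding convention are assumptions; the library's own partition count may round differently.

```python
def count_windows(n_samples, freq, window, overlap):
    """Number of full windows of `window` seconds, overlapping by the fraction `overlap`."""
    win_len = int(round(freq * window))               # window length in samples
    step = win_len - int(round(win_len * overlap))    # hop between window starts
    if n_samples < win_len:
        return 0
    return (n_samples - win_len) // step + 1

# shortest and longest recordings in the simulated dataset
shortest = count_windows(512, freq=128, window=2, overlap=0.1)
longest = count_windows(1024, freq=128, window=2, overlap=0.1)
```

Under these assumptions each file contributes between `shortest` and `longest` windows, which is why the number of samples in the dataset is larger than the number of files.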
[4]:
# define file path, sampling rate, window length, overlap percentage, workers and batch size
eegpath = 'Simulated_EEG'
Chan = 8
freq = 128
window = 2
overlap = 0.1
workers = 0
batchsize = 16
# define custom loading function
def loadEEG(path, return_label=False):
    with open(path, 'rb') as eegfile:
        EEG = pickle.load(eegfile)
    x = EEG['data']
    y = EEG['label']
    if return_label:
        return x, y
    else:
        return x
# calculate dataset length
EEGlen = dl.get_eeg_partition_number(
eegpath, freq, window, overlap, file_format='*.pickle',
load_function=loadEEG, verbose=True
)
# split dataset
EEGsplit= dl.get_eeg_split_table(
partition_table=EEGlen, val_ratio= 0.1, stratified=True, labels=classes,
test_data_id=[5], split_tolerance=0.001, perseverance=10000
)
# check split
dl.check_split(EEGlen, EEGsplit, classes)
# define training dataloader
trainset = dl.EEGDataset(EEGlen, EEGsplit, [freq, window, overlap], load_function=loadEEG)
trainsampler = dl.EEGSampler(trainset, batchsize, workers)
trainloader = DataLoader(
dataset = trainset, batch_size=batchsize, sampler=trainsampler, num_workers=workers)
# define validation dataloader
valset = dl.EEGDataset(
EEGlen, EEGsplit, [freq, window, overlap], 'validation', load_function=loadEEG)
valloader = DataLoader(dataset = valset, batch_size= batchsize, shuffle=False, num_workers=0)
extracting EEG samples: 100%|███████████████████| 1000/1000 [00:00<00:00, 1930.72 files/s]
Concluded extraction of repository length with the following specific:
window ==> 2.00 s
overlap ==> 10.00 %
sampling rate ==> 128.00 Hz
-----------------------------
dataset length ==> 3203
train ratio: 0.70
validation ratio: 0.10
test ratio: 0.20
train labels ratio: 0.0=0.589, 1.0=0.411,
val labels ratio: 0.0=0.597, 1.0=0.403,
test labels ratio: 0.0=0.620, 1.0=0.380,
Define the data augmenter
Now we need to define an augmenter. To keep things simple, we define an augmenter which combines:
the addition of some noise or the loss of some channels, with add_band_noise or masking
the warp or crop_and_resize augmentation
a final rescale of the range [-500, 500] uV to [-1, 1], with soft clipping toward a horizontal asymptote of 1.2
This is similar to the augmentation proposed in the augmentation module introductory notebook.
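The structure of such an augmenter (pick one augmentation at random inside each block, then chain the blocks) can be sketched in plain Python. The helpers `random_choice_aug` and `sequential_aug` and the toy augmentations are hypothetical stand-ins written for this illustration; they are not the library's classes.

```python
import random

def random_choice_aug(*augs, p=None, rng=random):
    """Build a callable that applies ONE of `augs`, drawn with weights `p`, per call."""
    def apply(x):
        chosen = rng.choices(augs, weights=p, k=1)[0]
        return chosen(x)
    return apply

def sequential_aug(*augs):
    """Build a callable that applies every element of `augs` in order."""
    def apply(x):
        for aug in augs:
            x = aug(x)
        return x
    return apply

# toy stand-ins for the real augmentations (they act on a list of floats)
def add_one(x):
    return [v + 1.0 for v in x]

def add_two(x):
    return [v + 2.0 for v in x]

def halve(x):
    return [v / 2.0 for v in x]

block1 = random_choice_aug(add_one, add_two, p=[0.7, 0.3])
toy_augmenter = sequential_aug(block1, halve)
augmented = toy_augmenter([0.0, 2.0])
```

Because the choice is redrawn on every call, two calls on the same input can return different results, which is what a contrastive method needs to build two views of one sample.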
[5]:
# DEFINE AUGMENTER
# First block: noise addition
AUG_band = aug.DynamicSingleAug(
aug.add_band_noise,
discrete_arg={
'bandwidth': ["delta", "theta", "alpha", "beta", (30,49) ],
'samplerate': freq,
'noise_range': 0.5
}
)
AUG_mask = aug.DynamicSingleAug(
aug.masking,
discrete_arg = {'mask_number': [1,2,3,4], 'masked_ratio': 0.25}
)
Block1 = aug.RandomAug( AUG_band, AUG_mask, p=[0.7, 0.3])
# second block: warp or crop and resize
AUG_crop = aug.DynamicSingleAug(
aug.crop_and_resize,
discrete_arg={'batch_equal': False},
range_arg ={'N_cut': [1, 4], 'segments': [10,15]},
range_type =[True, True]
)
AUG_warp = aug.DynamicSingleAug(
aug.warp_signal,
discrete_arg = {'batch_equal': [True, False]},
range_arg= {'segments': [5,10],
'stretch_strength': [1.75,2.25],
'squeeze_strength': [0.45,0.55]},
range_type=[True, False, False]
)
Block2 = aug.RandomAug( AUG_crop, AUG_warp)
# third block: rescale
Block3 = lambda x: selfeeg.utils.scale_range_soft_clip(x, 500, 1.2, 'uV', True)
# FINAL AUGMENTER: SEQUENCE OF THE THREE RANDOM LISTS
Augmenter = aug.SequentialAug(Block1, Block2, Block3)
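To give an intuition for the rescaling step (Block3), the sketch below rescales microvolt values and then applies a soft clip with a horizontal asymptote. The tanh-based curve is an assumption chosen for illustration: it is one common soft-clipping function, not necessarily the one implemented by scale_range_soft_clip, and it slightly compresses in-range values instead of mapping them exactly onto [-1, 1].

```python
import math

def rescale_soft_clip(value_uv, data_range=500.0, asymptote=1.2):
    """Divide by `data_range`, then squash smoothly so that |output| never exceeds `asymptote`."""
    scaled = value_uv / data_range
    return asymptote * math.tanh(scaled / asymptote)

inside = rescale_soft_clip(100.0)     # well inside the [-500, 500] uV range
edge = rescale_soft_clip(500.0)       # at the edge of the range
artifact = rescale_soft_clip(5000.0)  # large artifact, far outside the range
```

Unlike a hard clip, the curve stays smooth and monotonic, so large artifacts are bounded without creating flat segments at exactly the clipping value.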
[6]:
# visualize a random data augmentation
Sample = trainset.__getitem__(random.randint(0,len(trainset)-1))  # randint includes both endpoints
t = np.linspace(0, Sample.shape[1]-1, Sample.shape[1])/freq
SampleAug = Augmenter(Sample)
RandChan= random.randint(0,Chan-1)
fig, ax1 = plt.subplots()
color = 'tab:blue'
ax1.set_xlabel('time [s]', fontsize=15)
ax1.set_ylabel('[uV]', color=color, fontsize=15)
ax1.plot(t, Sample[RandChan,:], color=color)
ax1.tick_params(axis='y', labelcolor=color, labelsize=15)
ax1.tick_params(axis='x', labelsize=15)
ax2 = ax1.twinx()
color = 'tab:orange'
ax2.set_ylabel('[ ]', color=color, fontsize=15) # we already handled the x-label with ax1
ax2.plot(t, SampleAug[RandChan,:],color=color, linewidth=2.5)
ax2.tick_params(axis='y', labelsize=15, labelcolor=color)
plt.title('Same random channel from one sample: augmented version', fontsize=20)
fig.tight_layout()
plt.show()
Define pretraining model and other training objects
Now we need to define the pretraining model. To do that, one must:
instantiate an nn.Module defining the encoder (backbone)
instantiate the right SSL module, giving the encoder and the network head’s spec.
For now, let’s use a simple ShallowNet encoder (with F=8) and SimCLR as the SSL algorithm.
NOTE 1
Each model in the models module has an extra class with only the encoder, named modelnameEncoder (e.g., EEGNetEncoder). This makes model creation much easier.
NOTE 2
Each model in the ssl module accepts a list or an nn.Module to create the network head. In the case of a list, the head will be a sequence of dense layers with input and output sizes equal to the values of the list. Batchnorm and activation are based on the original works.
WARNING
Remember to check that the encoder output size matches the head input size. The classes in the ssl module do not check this.
[7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Encoder
NNencoder= selfeeg.models.ShallowNetEncoder(Chans=Chan, F=8)
# It's suggested to copy the random initialization for embedding analysis
NNencoder2= copy.deepcopy(NNencoder)
# SSL model
head_size=[ 88, 64, 64]
SelfMdl = selfeeg.ssl.SimCLR(
encoder=NNencoder, projection_head=head_size).to(device=device)
# loss (the fit method has a default loss based on the SSL algorithm)
loss=selfeeg.losses.simclr_loss
loss_arg={'temperature': 0.5}
# earlystopper
earlystop = selfeeg.ssl.EarlyStopping(
patience=25, min_delta=1e-05, record_best_weights=True)
# optimizer
optimizer = torch.optim.Adam(SelfMdl.parameters(), lr=1e-3)
# lr scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
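The first entry of head_size (88) must equal the flattened encoder output. The arithmetic below shows one way to arrive at that number. It assumes the layer sizes of the original ShallowNet design (a temporal convolution of length 25 without padding, followed by pooling of length 75 with stride 15); whether ShallowNetEncoder uses exactly these values is an assumption, although the result matches the 88 used above.

```python
def output_length(n_in, kernel, stride=1):
    """Output length of an unpadded 1D convolution or pooling layer."""
    return (n_in - kernel) // stride + 1

samples = int(128 * 2)                                   # 2 s at 128 Hz
after_conv = output_length(samples, kernel=25)           # temporal convolution (assumed size)
after_pool = output_length(after_conv, kernel=75, stride=15)  # pooling (assumed size)
n_filters = 8                                            # F=8 in the encoder definition
embedding_size = n_filters * after_pool
```

When the layer sizes are not known, a more robust check is to pass a dummy batch of shape (1, Chan, Samples) through the encoder and read the size of the output.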
Pretrain the model
Each SSL algorithm has an already implemented fit method, similar to scikit-learn or Keras. It is not as complete as the fit methods of bigger frameworks, but it saves you many lines of code and helps you monitor the training.
[8]:
loss_info = SelfMdl.fit(
train_dataloader = trainloader,
augmenter = Augmenter,
epochs = 5,
optimizer = optimizer,
loss_func = loss,
loss_args = loss_arg,
lr_scheduler = scheduler,
EarlyStopper = earlystop,
validation_dataloader = valloader,
verbose = True,
device = device,
return_loss_info = True
)
epoch [1/5]
val 20/20: 100%|███████████████| 161/161 [ 8.59 Batch/s, train_loss=4.72287, val_loss=4.97601]
epoch [2/5]
val 20/20: 100%|███████████████| 161/161 [ 8.28 Batch/s, train_loss=4.46746, val_loss=4.84657]
epoch [3/5]
val 20/20: 100%|███████████████| 161/161 [ 1.97 Batch/s, train_loss=4.47779, val_loss=4.97520]
epoch [4/5]
val 20/20: 100%|███████████████| 161/161 [ 2.03 Batch/s, train_loss=4.34311, val_loss=4.78815]
epoch [5/5]
val 20/20: 100%|███████████████| 161/161 [ 3.59 Batch/s, train_loss=4.34365, val_loss=4.78576]
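For readers unfamiliar with the objective minimized above, here is a pure-Python sketch of the NT-Xent contrastive loss commonly associated with SimCLR. It is written for clarity on tiny inputs and is not the library's simclr_loss, which may differ in details such as normalization or reduction; the function name `nt_xent` is an assumption.

```python
import math

def _normalize(vec):
    norm = math.sqrt(sum(c * c for c in vec))
    return [c / norm for c in vec]

def nt_xent(view1, view2, temperature=0.5):
    """NT-Xent loss for two lists of embeddings; view1[i] and view2[i] form a positive pair."""
    n = len(view1)
    z = [_normalize(v) for v in list(view1) + list(view2)]
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of the same sample
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / temperature
                for k in range(2 * n)]
        # denominator: every other embedding in the batch, excluding i itself
        denom = sum(math.exp(sims[k]) for k in range(2 * n) if k != i)
        total += -(sims[positive] - math.log(denom))
    return total / (2 * n)

a, b = [1.0, 0.0], [0.0, 1.0]
aligned = nt_xent([a, b], [a, b])   # each sample matches its own second view
swapped = nt_xent([a, b], [b, a])   # positives point to the wrong embedding
```

The loss is low when the two views of a sample are close and far from all other samples in the batch, which is why `aligned` is smaller than `swapped`.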
Fine-tuning Phase
Now that the encoder is pretrained, let’s perform fine-tuning. To do that, we need to recreate the right dataloaders and models. After that, the selfeeg library provides a fine-tuning function that is similar to the fit method.
Define fine-tuning dataloaders
This phase is basically the same as the previous one. The only differences are:
We are using the fine-tuning data
We are creating a test set for evaluation
We need to extract a label from each sample.
The classes and methods used here are the same ones from the dataloading module used in the pretraining phase.
[9]:
# Extract only the samples for fine-tuning
filesFT= EEGsplit.loc[EEGsplit['split_set']==2, 'file_name'].values
EEGlenFT= EEGlen.loc[EEGlen['file_name'].isin(filesFT)]
EEGlenFT = EEGlenFT.reset_index().drop(columns=['index'])
labels = classes[ EEGsplit[EEGsplit['split_set']==2].index.tolist()]
# split the fine-tuning data in train-test-validation
EEGsplitFT = dl.get_eeg_split_table(
partition_table=EEGlenFT,
test_ratio = 0.2,
val_ratio= 0.1,
val_ratio_on_all_data=False,
stratified=True,
labels=labels,
split_tolerance=0.001,
perseverance=10000
)
# TRAINING DATALOADER
trainsetFT = dl.EEGDataset(
EEGlenFT, EEGsplitFT, [freq, window, overlap], 'train', supervised=True,
label_on_load=True, load_function=loadEEG, optional_load_fun_args=[True]
)
trainsamplerFT = dl.EEGSampler(trainsetFT, batchsize, workers)
trainloaderFT = DataLoader(
dataset = trainsetFT, batch_size= batchsize, sampler=trainsamplerFT, num_workers=workers)
# VALIDATION DATALOADER
valsetFT = dl.EEGDataset(
EEGlenFT, EEGsplitFT, [freq, window, overlap], 'validation', supervised=True,
label_on_load=True, load_function=loadEEG, optional_load_fun_args=[True]
)
valloaderFT = DataLoader(
dataset=valsetFT, batch_size=batchsize, num_workers=workers, shuffle=False)
#TEST DATALOADER
testsetFT = dl.EEGDataset(
EEGlenFT, EEGsplitFT, [freq, window, overlap], 'test', supervised=True,
label_on_load=True, load_function=loadEEG, optional_load_fun_args=[True]
)
testloaderFT = DataLoader(dataset = testsetFT, batch_size= batchsize, shuffle=False)
dl.check_split(EEGlenFT, EEGsplitFT, labels)
train ratio: 0.72
validation ratio: 0.08
test ratio: 0.20
train labels ratio: 0.0=0.620, 1.0=0.380,
val labels ratio: 0.0=0.627, 1.0=0.373,
test labels ratio: 0.0=0.617, 1.0=0.383,
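The printed validation ratio is 0.08 rather than 0.10. A plausible reading, consistent with the ratios above, is that with val_ratio_on_all_data=False the validation ratio is applied to the data left after removing the test set. The arithmetic under that assumption:

```python
test_ratio = 0.2
val_ratio = 0.1

remaining = 1 - test_ratio              # share of data left after removing the test set
val_share = val_ratio * remaining       # validation share of the whole fine-tuning data
train_share = remaining - val_share     # what is left for training
```

Rounded to two decimals, the three shares match the train, validation, and test ratios printed by check_split.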
Define fine-tuning model and other training objects
Remember that in this phase you need to transfer the pretrained encoder to the fine-tuning model.
[10]:
FinalMdl = selfeeg.models.ShallowNet(
nb_classes = 2, Chans = Chan, Samples = int(freq*window), F=8)
# Transfer the pretrained backbone and move the final model to the right device
SelfMdl.train()
SelfMdl.to(device='cpu')
FinalMdl.encoder = SelfMdl.get_encoder()
FinalMdl.train()
FinalMdl.to(device=device)
# DEFINE LOSS
def loss_fineTuning(yhat, ytrue):
    ytrue = ytrue + 0.
    yhat = torch.squeeze(yhat)
    return F.binary_cross_entropy_with_logits(yhat, ytrue)
# DEFINE EARLYSTOPPER
earlystopFT = selfeeg.ssl.EarlyStopping(
patience=10, min_delta=1e-03, record_best_weights=True)
# DEFINE OPTIMIZER
optimizerFT = torch.optim.Adam(FinalMdl.parameters(), lr=1e-3)
schedulerFT = torch.optim.lr_scheduler.ExponentialLR(optimizerFT, gamma=0.97)
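The roles of patience and min_delta in the early stopper can be illustrated with a small stand-alone class. This is a sketch of the usual early-stopping rule, not the library's EarlyStopping implementation (which, for example, can also record the best weights); the class name `SimpleEarlyStopping` is hypothetical.

```python
class SimpleEarlyStopping:
    """Stop when the validation loss has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience=10, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Register one epoch; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss   # meaningful improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1      # no meaningful improvement this epoch
        return self.counter >= self.patience

stopper = SimpleEarlyStopping(patience=2, min_delta=1e-3)
decisions = [stopper.step(loss) for loss in [1.0, 0.9, 0.9005, 0.8999]]
```

In the toy run the last two losses do not beat the best value (0.9) by at least min_delta, so the stop signal is raised on the fourth epoch.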
Fine-tuning
Fine-tuning can be easily performed with the fine_tune function.
NOTE 1
It is usually better to first train only the new head, and then update all the model’s weights. However, ShallowNet is small and this is a simple example, so we directly fine-tune the whole network.
[11]:
finetuning_loss=selfeeg.ssl.fine_tune(
model = FinalMdl,
train_dataloader = trainloaderFT,
epochs = 10,
optimizer = optimizerFT,
loss_func = loss_fineTuning,
lr_scheduler = schedulerFT,
EarlyStopper = earlystopFT,
validation_dataloader = valloaderFT,
verbose = True,
device = device,
return_loss_info = True
)
epoch [1/10]
val 4/4: 100%|███████████████| 33/33 [101.80 Batch/s, train_loss=0.74088, val_loss=0.52754]
epoch [2/10]
val 4/4: 100%|███████████████| 33/33 [107.47 Batch/s, train_loss=0.51391, val_loss=0.42154]
epoch [3/10]
val 4/4: 100%|███████████████| 33/33 [112.56 Batch/s, train_loss=0.41865, val_loss=0.26677]
epoch [4/10]
val 4/4: 100%|███████████████| 33/33 [106.99 Batch/s, train_loss=0.26489, val_loss=0.18554]
epoch [5/10]
val 4/4: 100%|███████████████| 33/33 [113.98 Batch/s, train_loss=0.18971, val_loss=0.11007]
epoch [6/10]
val 4/4: 100%|███████████████| 33/33 [115.99 Batch/s, train_loss=0.12763, val_loss=0.07227]
epoch [7/10]
val 4/4: 100%|███████████████| 33/33 [108.40 Batch/s, train_loss=0.14895, val_loss=0.05631]
epoch [8/10]
val 4/4: 100%|███████████████| 33/33 [109.08 Batch/s, train_loss=0.09240, val_loss=0.04787]
epoch [9/10]
val 4/4: 100%|███████████████| 33/33 [114.48 Batch/s, train_loss=0.11739, val_loss=0.03897]
epoch [10/10]
val 4/4: 100%|███████████████| 33/33 [114.82 Batch/s, train_loss=0.08307, val_loss=0.03281]
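To see how much the exponential scheduler changes the learning rate over a run of this length, the sketch below evaluates lr_k = lr_0 * gamma^k. It assumes the scheduler is stepped once per epoch, which is the usual convention but is not confirmed here for fine_tune.

```python
initial_lr = 1e-3
gamma = 0.97
epochs = 10

# learning rate in use at the start of each epoch, plus the value after the last step
learning_rates = [initial_lr * gamma ** k for k in range(epochs + 1)]
relative_drop = 1 - learning_rates[-1] / learning_rates[0]
```

With gamma=0.97 the decay is gentle: after 10 steps the learning rate is roughly a quarter lower than the initial value.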
Evaluate fine-tuned model
Now you can evaluate your model with whatever method you prefer. Here is a simple example with the classification report from sklearn.
[12]:
from sklearn.metrics import classification_report
nb_classes=2
FinalMdl.eval()
ytrue=torch.zeros(len(testloaderFT.dataset))
ypred=torch.zeros_like(ytrue)
cnt=0
for i, (X, Y) in enumerate(testloaderFT):
    X=X.to(device=device)
    ytrue[cnt:cnt+X.shape[0]]= Y
    with torch.no_grad():
        yhat = torch.sigmoid(FinalMdl(X)).to(device='cpu')
        ypred[cnt:cnt+X.shape[0]] = torch.squeeze(yhat)
    cnt += X.shape[0]
print('Results of trivial Example\n')
print(classification_report(ytrue,ypred>0.5))
Results of trivial Example
              precision    recall  f1-score   support
         0.0       1.00      1.00      1.00        58
         1.0       1.00      1.00      1.00        70
    accuracy                           1.00       128
   macro avg       1.00      1.00      1.00       128
weighted avg       1.00      1.00      1.00       128
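For reference, the quantities in the report can be computed by hand from thresholded probabilities. The sketch below does this for the positive class on a small made-up example; the function name `binary_report` and the toy values are illustrative and unrelated to the results above.

```python
def binary_report(y_true, y_prob, threshold=0.5):
    """Precision, recall, F1 (positive class) and accuracy from predicted probabilities."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(y_true)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}

# toy labels and probabilities (made up for illustration)
toy_report = binary_report([0, 0, 1, 1, 1], [0.1, 0.7, 0.8, 0.4, 0.9])
```

The threshold of 0.5 mirrors the `ypred>0.5` comparison used in the evaluation cell; moving it trades precision against recall.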