EarlyStopping

class selfeeg.ssl.base.EarlyStopping(patience: int = 5, min_delta: float = 1e-09, improvement: str = 'decrease', monitored: str = 'validation', record_best_weights: bool = True, device: str = None)

PyTorch implementation of an early stopper.

It can monitor the validation or the training loss (no other metrics are currently supported).

Some arguments mirror those of the Keras EarlyStopping class [early]. For additional functionality, take a look at PyTorch Ignite [ign].

Parameters:
  • patience (int, optional) –

    The number of consecutive epochs without improvement to wait before stopping the training. Can be any positive integer.

    Default = 5

  • min_delta (float, optional) –

    The minimum difference between the current best loss and the newly calculated one for the change to count as an improvement.

    Default = 1e-9

  • improvement (str, optional) –

    Whether an increase or a decrease of the monitored loss with respect to the best loss counts as an improvement. Accepted strings are:

    • ['d', 'dec', 'decrease'] for decrease

    • ['i', 'inc', 'increase'] for increase

    Default = “decrease”

  • monitored (str, optional) –

    Whether to monitor the training or the validation loss. This attribute is used by the fine_tune function and by the fit methods of other classes to decide which computed loss to pass to the stopper. Accepted values are "train" or "validation".

    Default = “validation”

  • record_best_weights (bool, optional) –

    Whether to record the model weights every time a new best loss is reached. The recorded weights are used to restore the model when training is stopped early.

    Default = True

  • device (torch.device or str, optional) –

    The device on which the recorded copy of the model is stored. If given as a string, it will be converted to a torch.device instance. If not given, the 'cpu' device is used.

    Default = None
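The interplay of patience, min_delta, and improvement can be sketched in plain Python. The class below, MiniStopper, is a hypothetical simplification written for illustration only, not selfeeg's implementation; in particular, setting best_loss on the first call is an assumption of this sketch.

```python
class MiniStopper:
    """Hypothetical, simplified early stopper (illustration only, not selfeeg's code)."""

    def __init__(self, patience=5, min_delta=1e-9, improvement="decrease"):
        # Map the accepted aliases to a sign so one comparison covers both modes
        if improvement in ("d", "dec", "decrease"):
            self.sign = 1.0   # a lower loss is better
        elif improvement in ("i", "inc", "increase"):
            self.sign = -1.0  # a higher value is better
        else:
            raise ValueError("improvement must be 'decrease' or 'increase'")
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None  # assumption: set by the first call
        self.counter = 0
        self.stop = False

    def early_stop(self, loss, count_add=1):
        # An improvement must beat the best loss by more than min_delta
        if self.best_loss is None or self.sign * (self.best_loss - loss) > self.min_delta:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += count_add
            if self.counter >= self.patience:
                self.stop = True


stopper = MiniStopper(patience=2, min_delta=0.01)
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.81, 0.7]):
    stopper.early_stop(loss)
    if stopper.stop:
        print(f"stopped at epoch {epoch}")  # prints: stopped at epoch 3
        break
```

Here the drop from 0.8 to 0.795 is smaller than min_delta, so it increments the counter instead of resetting it.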

Note

As in Keras, the early stopper does not automatically restore the best weights if training ends without being stopped early, i.e., when the maximum number of epochs is reached. To get the best weights, call the restore_best_weights(model) method.
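Why an explicit restore is needed can be shown with a plain-Python simulation in which a dictionary stands in for the model weights. The per-epoch history and the variable names are hypothetical; only the snapshot-then-restore pattern reflects the note.

```python
import copy

# Hypothetical per-epoch results: (weights after the epoch, validation loss)
history = [({"w": 0.1}, 0.9), ({"w": 0.4}, 0.5), ({"w": 0.7}, 0.6)]

model_weights = {"w": 0.0}
best_loss, best_weights = float("inf"), None
for new_weights, val_loss in history:    # max epochs = len(history)
    model_weights.update(new_weights)    # simulated optimizer step (in place)
    if val_loss < best_loss:
        best_loss = val_loss
        # Deep copy: a plain reference would follow later in-place updates
        best_weights = copy.deepcopy(model_weights)

# Training ended because the epoch budget ran out, so the "model"
# still holds the last-epoch weights rather than the best ones.
print(model_weights)                 # {'w': 0.7}
model_weights.update(best_weights)   # explicit restore step
print(model_weights)                 # {'w': 0.4}
```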

Example

>>> import pickle
>>> import torch
>>> import torch.nn.functional as F
>>> import selfeeg.dataloading as dl
>>> import selfeeg.models as models
>>> import selfeeg.ssl as ssl
>>> import selfeeg.utils as utils
>>> utils.create_dataset()  # generate the simulated dataset
>>> def loadEEG(path, return_label=False):
...     with open(path, 'rb') as handle:
...         EEG = pickle.load(handle)
...     x, y = EEG['data'], EEG['label']
...     return (x, y) if return_label else x
>>> def loss_fineTuning(yhat, ytrue):
...     return F.binary_cross_entropy_with_logits(torch.squeeze(yhat), ytrue + 0.)
>>> EEGlen = dl.get_eeg_partition_number('Simulated_EEG', 128, 2, 0.3, load_function=loadEEG)
>>> EEGsplit = dl.get_eeg_split_table(EEGlen, seed=1234)
>>> ratios = dl.check_split(EEGlen, EEGsplit, return_ratio=True)
>>> TrainSet = dl.EEGDataset(EEGlen, EEGsplit, [128, 2, 0.3], 'train', True, loadEEG,
...                          optional_load_fun_args=[True], label_on_load=True)
>>> TrainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=32)
>>> shanet = models.ShallowNet(2, 8, 256)
>>> Stopper = ssl.EarlyStopping(patience=1, monitored='train')
>>> Stopper.rec_best_weights(shanet)  # record the initial weights so a copy exists to restore
>>> Stopper.best_loss = 0  # unreachable target: every positive loss counts as no improvement
>>> loss_info = ssl.fine_tune(shanet, TrainLoader, 2, EarlyStopper=Stopper,
...                           loss_func=loss_fineTuning)
>>> # it should stop training and print "no improvement after 1 epochs. Training stopped."

References

early_stop(loss, count_add=1)

Update the counter and the best loss using the newly calculated loss.

Parameters:
  • loss (float) – The calculated loss.

  • count_add (int, optional) – The number to add to the counter. It can be useful when early stopping checks are not performed after each epoch.
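When the loss is checked only every few epochs, passing the number of elapsed epochs as count_add keeps patience measured in epochs rather than in checks. The sketch below uses a hypothetical helper, early_stop_step, that mimics the counter update for a "decrease" stopper; it is not part of selfeeg.

```python
def early_stop_step(state, loss, count_add=1, min_delta=1e-9):
    """Hypothetical counter update for a 'decrease' stopper (not selfeeg's code)."""
    if loss < state["best_loss"] - min_delta:
        state["best_loss"], state["counter"] = loss, 0
    else:
        state["counter"] += count_add  # count epochs, not checks
    return state["counter"]


patience, check_every = 6, 3          # validate only every 3 epochs
state = {"best_loss": float("inf"), "counter": 0}
val_losses = {3: 0.9, 6: 0.7, 9: 0.72, 12: 0.71}  # hypothetical epoch -> loss
stopped_at = None
for epoch in range(1, 16):
    if epoch % check_every:           # no validation this epoch
        continue
    loss = val_losses.get(epoch, 1.0)
    if early_stop_step(state, loss, count_add=check_every) >= patience:
        stopped_at = epoch
        break
print(stopped_at)  # 12: two checks without improvement = 6 epochs
```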

rec_best_weights(model)

Record the model's best weights. The recorded copy is sent to the device set during EarlyStopping initialization (CPU by default); the original model keeps its device.

Parameters:

model (nn.Module) – The model to record.

reset_counter()

Reset the counter and the early stopping flag. This is useful if you want to continue training the model after the first training run has stopped (for example, with a lower learning rate).
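A two-phase schedule, where training resumes with a lower learning rate after the first stop, can be sketched with a hypothetical TinyStopper class (not selfeeg's class). The sketch assumes, based on the description above, that resetting clears only the counter and the flag, so the second phase must beat the best loss reached in the first one.

```python
class TinyStopper:
    """Hypothetical stopper for illustration only (not selfeeg's class)."""

    def __init__(self, patience):
        self.patience, self.best_loss = patience, float("inf")
        self.counter, self.stop = 0, False

    def early_stop(self, loss):
        if loss < self.best_loss:
            self.best_loss, self.counter = loss, 0
        else:
            self.counter += 1
            self.stop = self.counter >= self.patience

    def reset_counter(self):
        # Only the counter and the flag are cleared; best_loss is kept
        self.counter, self.stop = 0, False


stopper = TinyStopper(patience=2)
# Hypothetical loss curves for two learning-rate phases
phases = {"lr=1e-3": [0.9, 0.6, 0.65, 0.7], "lr=1e-4": [0.62, 0.55, 0.56, 0.57]}
for phase, losses in phases.items():
    for epoch, loss in enumerate(losses):
        stopper.early_stop(loss)
        if stopper.stop:
            print(f"{phase}: stopped after epoch {epoch}, best={stopper.best_loss}")
            break
    stopper.reset_counter()  # allow the next phase to train
```

In the second phase, the first loss (0.62) is worse than the best from phase one (0.6), so it already counts as an epoch without improvement.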

restore_best_weights(model)

Restore the model's best weights.

Parameters:

model (nn.Module) – The model to restore.

Warning

Before its best weights are restored, the model is moved to the device set during EarlyStopping initialization. If that device differs from the one you want to use, remember to move the model back afterwards.
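The record/restore pattern behind rec_best_weights and restore_best_weights can be approximated with standard PyTorch calls. The helpers record_weights and restore_weights below are hypothetical, not selfeeg functions. One deliberate difference: load_state_dict copies values into the model's existing parameters, so the model keeps its current device, whereas selfeeg's restore_best_weights, per the warning above, moves the model to the stopper's device.

```python
import copy

import torch
import torch.nn as nn


def record_weights(model: nn.Module, device="cpu") -> nn.Module:
    # Deep-copy first, then move the copy: nn.Module.to() acts in place,
    # so calling it on the original would change the original's device.
    return copy.deepcopy(model).to(torch.device(device))


def restore_weights(model: nn.Module, best: nn.Module) -> None:
    # load_state_dict copies tensors into the existing parameters,
    # so the model stays on whatever device it is currently on.
    model.load_state_dict(best.state_dict())


model = nn.Linear(4, 2)
best = record_weights(model)       # snapshot taken at a "best" epoch
with torch.no_grad():
    model.weight.add_(1.0)         # simulate further (worse) training
restore_weights(model, best)
print(torch.equal(model.weight, best.weight))  # True
```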