Source code for bird_cloud_gnn.early_stopper

"""Module for early stopping class
"""
import numpy as np


[docs] class EarlyStopper: """Early stopper check. This class is used to stop the training process if the validation loss starts increasing. The validation loss is considered to be increasing if it is greater than the minimum validation loss found so far plus an absolute and/or relative tolerance. The class keeps track of the minimum validation loss found so far and the number of consecutive iterations where the validation loss has increased. If the number of consecutive iterations where the validation loss has increased exceeds a certain threshold, the training process is stopped. Attributes: patience (int): How many consecutive iterations to wait before stopping. min_abs_delta (float): Absolute tolerance to the increase. min_rel_delta (float): Relative tolerance to the increase. counter (int): Number of consecutive iterations where the validation loss has increased. min_validation_loss (float): Minimum validation loss found so far. Methods: __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0): Initializes the EarlyStopper object. early_stop(self, validation_loss): Checks whether it is time to stop, and updates the internal state of the EarlyStopper object. Example usage: early_stopper = EarlyStopper(patience=5, min_abs_delta=0.1, min_rel_delta=0.01) for epoch in range(num_epochs): train_loss = train(model, train_loader) val_loss = validate(model, val_loader) if early_stopper.early_stop(val_loss): print(f"Validation loss has been increasing for {early_stopper.patience} consecutive epochs. " f"Training stopped.") break """
[docs] def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0): """Initializes the EarlyStopper object. Args: patience (int, optional): How many consecutive iterations to wait before stopping. Defaults to 3. min_abs_delta (float, optional): Absolute tolerance to the increase. Defaults to 1e-2. min_rel_delta (float, optional): Relative tolerance to the increase. Defaults to 0.0. """ self.patience = patience self.min_abs_delta = min_abs_delta self.min_rel_delta = min_rel_delta self.counter = 0 self.min_validation_loss = np.inf
[docs] def early_stop(self, validation_loss): """Checks whether it is time to stop, and updates the internal state of the EarlyStopper object. Args: validation_loss (float): Current validation loss Returns: stop (boolean): Whether it is time to stop (True) or not (False). """ if validation_loss < self.min_validation_loss: self.min_validation_loss = validation_loss self.counter = 0 return False if self.min_validation_loss is np.inf: return False loss_threshold = ( self.min_validation_loss + self.min_abs_delta + self.min_rel_delta * np.abs(self.min_validation_loss) ) if validation_loss > loss_threshold: self.counter += 1 if self.counter >= self.patience: return True return False