# Source code for gradgraph.optimization.tf.callbacks

#!/usr/bin/env python3
# 
# callbacks.py
# 
# Created by Nicolas Fricker on 08/31/2025.
# Copyright © 2025 Nicolas Fricker. All rights reserved.
# 

import warnings
import numpy as np
import tensorflow as tf

class EarlyStoppingByThreshold(tf.keras.callbacks.Callback):
    """
    Stop training once a monitored metric crosses a fixed threshold.

    At the end of every epoch the monitored value is compared against
    ``threshold`` with a strict comparison (``<`` when the metric is
    minimized, ``>`` when it is maximized). When the comparison holds and
    the epoch index is greater than zero, ``model.stop_training`` is set.
    The callback also tracks the best value seen so far and can restore the
    corresponding weights when training ends.

    Parameters
    ----------
    monitor : str, optional
        Name of the metric to read from the epoch logs.
        Default is 'val_loss'.
    threshold : float, optional
        Value the monitored metric must cross to stop training.
        Default is 0.
    min_delta : float, optional
        Minimum change required for a value to count as a new best. Its
        absolute value is used. It only affects best-value tracking, not
        the threshold comparison. Default is 0.
    baseline : float or None, optional
        Stored on the instance as ``self.baseline``. It is not consulted by
        any logic in this class. Default is None.
    verbose : int, optional
        0 = silent, any value > 0 prints messages at the end of training.
        Default is 0.
    mode : {'auto', 'min', 'max'}, optional
        Direction of the monitored metric. In 'auto' mode the direction is
        inferred from the metric name ('loss') or from a model metric's
        ``_direction`` attribute. Default is 'auto'.
    restore_best_weights : bool, optional
        If True, the weights from the best epoch are restored when training
        ends. Default is False.
    start_from_epoch : int, optional
        Epochs with an index below this value are ignored. Default is 0.

    Raises
    ------
    ValueError
        Raised lazily on the first ``on_epoch_end`` call when ``mode`` is
        'auto' and the direction of ``monitor`` cannot be inferred.

    Warns
    -----
    UserWarning
        If ``mode`` is not one of the supported values (falls back to
        'auto'), or if the monitored metric is missing from the logs.

    Examples
    --------
    >>> early_stopping = EarlyStoppingByThreshold(
    ...     monitor='val_accuracy', threshold=0.95, mode='max')
    >>> model.fit(X_train, y_train, callbacks=[early_stopping])
    """

    def __init__(
        self,
        monitor="val_loss",
        threshold=0,
        min_delta=0,
        baseline=None,
        verbose=0,
        mode="auto",
        restore_best_weights=False,
        start_from_epoch=0,
    ) -> None:
        """
        Store the configuration and initialise the tracking state.

        See the class docstring for the meaning of each parameter. An
        unrecognised ``mode`` does not raise: a ``UserWarning`` is issued
        and 'auto' is used instead. The comparison operation itself is
        resolved later, on the first ``on_epoch_end`` call.
        """
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.best_epoch = 0
        self.start_from_epoch = start_from_epoch

        if mode not in ["auto", "min", "max"]:
            warnings.warn(
                f"EarlyStopping mode {mode} is unknown, fallback to auto mode.",
                stacklevel=2,
            )
            mode = "auto"
        self.mode = mode
        # Resolved lazily because 'auto' mode needs access to self.model.
        self.monitor_op = None

    def _initial_best(self):
        """Return the worst possible value for the current comparison op."""
        if self.monitor_op == tf.math.less:
            return float("inf")
        return -float("inf")

    def _set_monitor_op(self):
        """
        Resolve the comparison operation used for the monitored metric.

        - ``mode == 'min'`` selects ``tf.math.less``.
        - ``mode == 'max'`` selects ``tf.math.greater``.
        - Otherwise the ``val_`` prefix is stripped from ``monitor``; the
          name 'loss' selects ``tf.math.less``, and a model metric with the
          same name and a ``_direction`` attribute selects
          ``tf.math.greater`` for 'up' and ``tf.math.less`` for anything
          else.

        Afterwards ``min_delta`` is given the sign matching the direction
        (negative when minimizing) and ``best`` is set to the worst
        possible value.

        Raises
        ------
        ValueError
            If no comparison operation could be determined.
        """
        if self.mode == "min":
            self.monitor_op = tf.math.less
        elif self.mode == "max":
            self.monitor_op = tf.math.greater
        else:
            metric_name = self.monitor.removeprefix("val_")
            if metric_name == "loss":
                self.monitor_op = tf.math.less
            if hasattr(self.model, "metrics"):
                for m in self.model.metrics:
                    if m.name == metric_name:
                        if hasattr(m, "_direction"):
                            if m._direction == "up":
                                self.monitor_op = tf.math.greater
                            else:
                                self.monitor_op = tf.math.less
        if self.monitor_op is None:
            raise ValueError(
                f"EarlyStopping callback received monitor={self.monitor} "
                "but Keras isn't able to automatically determine whether "
                "that metric should be maximized or minimized. "
                "Pass `mode='max'` in order to do early stopping based "
                "on the highest metric value, or pass `mode='min'` "
                "in order to use the lowest value."
            )
        # Set the sign explicitly instead of flipping it, so that calling
        # this method more than once cannot undo the negation.
        if self.monitor_op == tf.math.less:
            self.min_delta = -abs(self.min_delta)
        else:
            self.min_delta = abs(self.min_delta)
        self.best = self._initial_best()

    def on_train_begin(self, logs=None) -> None:
        """
        Reset the tracking state so the instance can be reused.

        Clears the stopped epoch, the stored best weights and the best
        epoch. If the comparison operation has already been resolved by a
        previous run, the best value is reset as well, so that a new run is
        not compared against the previous run's best.

        Parameters
        ----------
        logs : dict, optional
            Unused; present for compatibility with the Keras callback API.
        """
        self.stopped_epoch = 0
        self.best_weights = None
        self.best_epoch = 0
        if self.monitor_op is not None:
            self.best = self._initial_best()

    def on_epoch_end(self, epoch: int, logs=None) -> None:
        """
        Update the best value and stop training if the threshold is crossed.

        Steps, in order:

        1. Resolve the comparison operation on the first call.
        2. Read the monitored value; return early if it is missing or if
           ``epoch < start_from_epoch``.
        3. If ``restore_best_weights`` is set and no weights are stored
           yet, store the current weights.
        4. If the value improves on the best value (taking ``min_delta``
           into account), record it, the epoch and optionally the weights.
        5. If the value strictly crosses ``threshold`` and ``epoch > 0``,
           record the epoch and set ``model.stop_training = True``.

        Parameters
        ----------
        epoch : int
            Zero-based index of the epoch that just finished.
        logs : dict, optional
            Metrics of the finished epoch.

        Warns
        -----
        UserWarning
            If the monitored metric is not available in ``logs``.
        """
        if self.monitor_op is None:
            self._set_monitor_op()

        current = self.get_monitor_value(logs)
        if current is None or epoch < self.start_from_epoch:
            return

        if self.restore_best_weights and self.best_weights is None:
            # No weights stored yet: the current ones are the best so far.
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch

        if self._is_improvement(current, self.best):
            self.best = current
            self.best_epoch = epoch
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()

        # Stopping is never triggered on epoch index 0; on_train_end only
        # reports a stop when stopped_epoch > 0.
        if self.monitor_op(current, self.threshold) and epoch > 0:
            self.stopped_epoch = epoch
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        """
        Report early stopping and optionally restore the best weights.

        If training was stopped by this callback (``stopped_epoch > 0``)
        and ``verbose > 0``, a message with the one-based epoch number is
        printed. If ``restore_best_weights`` is set and weights were
        stored, they are loaded back into the model, with a message when
        ``verbose > 0``.

        Parameters
        ----------
        logs : dict, optional
            Unused; present for compatibility with the Keras callback API.
        """
        if self.stopped_epoch > 0 and self.verbose > 0:
            tf.print(
                f"Epoch {self.stopped_epoch + 1}: early stopping"
            )
        if self.restore_best_weights and self.best_weights is not None:
            if self.verbose > 0:
                tf.print(
                    "Restoring model weights from "
                    "the end of the best epoch: "
                    f"{self.best_epoch + 1}."
                )
            self.model.set_weights(self.best_weights)

    def get_monitor_value(self, logs):
        """
        Return the monitored metric from ``logs``.

        Parameters
        ----------
        logs : dict or None
            Metrics of the finished epoch. ``None`` (or an empty dict) is
            treated as having no metrics.

        Returns
        -------
        monitor_value : any
            The value stored under ``self.monitor``, or ``None`` if the
            key is missing.

        Warns
        -----
        UserWarning
            If the monitored metric is missing; the message lists the
            available metric names.
        """
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                (
                    f"Early stopping conditioned on metric `{self.monitor}` "
                    "which is not available. "
                    f"Available metrics are: {','.join(logs)}"
                ),
                stacklevel=2,
            )
        return monitor_value

    def _is_improvement(self, monitor_value, reference_value):
        """
        Tell whether ``monitor_value`` improves on ``reference_value``.

        The signed ``min_delta`` is subtracted from ``monitor_value``
        before the comparison operation is applied, so a change smaller
        than ``min_delta`` does not count as an improvement.

        Parameters
        ----------
        monitor_value : float
            Current value of the monitored metric.
        reference_value : float
            Value to compare against, typically the best value so far.

        Returns
        -------
        Result of ``self.monitor_op``; truthy when the value improved.
        """
        return self.monitor_op(monitor_value - self.min_delta, reference_value)
class ReduceLROnPlateau(tf.keras.callbacks.Callback):
    """
    Reduce a learning rate when a monitored metric stops improving.

    The optimizer to adjust is looked up by attribute name on the model
    (``getattr(model, optimizer)``), so one instance can be attached per
    optimizer attribute. If the model has no attribute with that name,
    ``on_epoch_end`` returns without doing anything.

    Parameters
    ----------
    optimizer : str, optional
        Name of the model attribute that holds the optimizer whose
        ``learning_rate`` is adjusted. Default is 'optimizer'.
    monitor : str, optional
        Name of the metric to read from the epoch logs.
        Default is 'val_loss'.
    factor : float, optional
        Multiplier applied to the learning rate: ``new_lr = lr * factor``.
        Must be less than 1.0. Default is 0.1.
    patience : int, optional
        Number of epochs without improvement after which the learning rate
        is reduced. Default is 10.
    verbose : int, optional
        0 = silent, any value > 0 prints a message on each reduction.
        Default is 0.
    mode : {'auto', 'min', 'max'}, optional
        In 'min' mode an improvement is a decrease, in 'max' mode an
        increase. In 'auto' mode the metric is maximized if its name
        contains 'acc' and minimized otherwise. Default is 'auto'.
    min_delta : float, optional
        Minimum change required to count as an improvement.
        Default is 1e-4.
    cooldown : int, optional
        Number of epochs to wait after a reduction before counting
        non-improving epochs again. Default is 0.
    min_lr : float, optional
        Lower bound on the learning rate. Default is 0.0.
    **kwargs
        Forwarded to the base class initializer.

    Raises
    ------
    ValueError
        If ``factor`` is greater than or equal to 1.0.

    Warns
    -----
    UserWarning
        If ``mode`` is unknown (falls back to 'auto'), or if the monitored
        metric is not available in the logs.

    Examples
    --------
    >>> reduce_lr = ReduceLROnPlateau(
    ...     monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
    >>> model.fit(x_train, y_train, callbacks=[reduce_lr])
    """

    def __init__(
        self,
        optimizer='optimizer',
        monitor="val_loss",
        factor=0.1,
        patience=10,
        verbose=0,
        mode="auto",
        min_delta=1e-4,
        cooldown=0,
        min_lr=0.0,
        **kwargs,
    ):
        """
        Store the configuration and initialise the tracking state.

        See the class docstring for the meaning of each parameter.

        Raises
        ------
        ValueError
            If ``factor`` is greater than or equal to 1.0.
        """
        super().__init__(**kwargs)
        self.optimizer = optimizer
        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError(
                "ReduceLROnPlateau does not support a factor >= 1.0. "
                f"Received factor={factor}"
            )
        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self._reset()

    def _reset(self):
        """
        Reset the counters and rebuild the comparison operation.

        The metric is minimized when ``mode`` is 'min', or when ``mode`` is
        'auto' and 'acc' does not occur in the monitored name; otherwise it
        is maximized. ``best`` is set to the worst possible value, and
        ``cooldown_counter`` and ``wait`` are set to zero.

        Warns
        -----
        UserWarning
            If ``mode`` is unknown; it is then replaced by 'auto'.
        """
        if self.mode not in {"auto", "min", "max"}:
            warnings.warn(
                f"Learning rate reduction mode {self.mode} is unknown, "
                "fallback to auto mode.",
                stacklevel=2,
            )
            self.mode = "auto"
        if self.mode == "min" or (
            self.mode == "auto" and "acc" not in self.monitor
        ):
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        """
        Reset the internal state at the beginning of training.

        Parameters
        ----------
        logs : dict, optional
            Unused; present for compatibility with the Keras callback API.
        """
        self._reset()

    def on_epoch_end(self, epoch, logs=None):
        """
        Track the monitored metric and reduce the learning rate if needed.

        The current learning rate is written to ``logs['learning_rate']``.
        If the monitored metric has not improved for ``patience`` epochs
        outside of a cooldown period, and the learning rate is above
        ``min_lr``, it is multiplied by ``factor`` (clamped at ``min_lr``)
        and the cooldown starts.

        Parameters
        ----------
        epoch : int
            Zero-based index of the epoch that just finished.
        logs : dict, optional
            Metrics of the finished epoch. Updated in place with the
            'learning_rate' key.

        Warns
        -----
        UserWarning
            If the monitored metric is not found in ``logs``.

        Notes
        -----
        Nothing happens, and ``logs`` is left untouched, if the model has
        no attribute named ``self.optimizer``.
        """
        # Only substitute a fresh dict for None: replacing an empty but
        # real dict would drop the learning-rate entry written below.
        if logs is None:
            logs = {}
        if not hasattr(self.model, self.optimizer):
            return

        optimizer = getattr(self.model, self.optimizer)
        current_lr = float(
            tf.keras.ops.convert_to_numpy(optimizer.learning_rate)
        )
        logs["learning_rate"] = current_lr

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                "Learning rate reduction is conditioned on metric "
                f"`{self.monitor}` which is not available. Available metrics "
                f"are: {','.join(logs)}.",
                stacklevel=2,
            )
            return

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0

        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
        elif not self.in_cooldown():
            self.wait += 1
            if self.wait >= self.patience:
                # The rate has not changed since it was read above.
                old_lr = current_lr
                if old_lr > np.float32(self.min_lr):
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    optimizer.learning_rate = new_lr
                    if self.verbose > 0:
                        tf.print(
                            f"\nEpoch {epoch + 1}: "
                            "ReduceLROnPlateau reducing "
                            f"{self.optimizer} learning rate to {new_lr}."
                        )
                    self.cooldown_counter = self.cooldown
                    self.wait = 0

    def in_cooldown(self):
        """
        Check whether the cooldown period is active.

        Returns
        -------
        bool
            True if ``cooldown_counter`` is greater than zero.

        Examples
        --------
        >>> obj = ReduceLROnPlateau()
        >>> obj.cooldown_counter = 5
        >>> obj.in_cooldown()
        True
        >>> obj.cooldown_counter = 0
        >>> obj.in_cooldown()
        False
        """
        return self.cooldown_counter > 0