gradgraph.optimization.tf.callbacks module

class EarlyStoppingByThreshold(monitor='val_loss', threshold=0, min_delta=0, baseline=None, verbose=0, mode='auto', restore_best_weights=False, start_from_epoch=0)

Bases: Callback

Early stopping callback that terminates training when a monitored metric reaches a specified threshold.

Training ends as soon as the monitored metric reaches threshold in the direction set by mode, for example when val_loss falls to the threshold or val_accuracy rises to it. This avoids spending computation once the target performance has been achieved and can also help limit overfitting.

Parameters:
  • monitor (str, optional) – The metric to be monitored. Default is ‘val_loss’.

  • threshold (float, optional) – The threshold value that the monitored metric must reach to stop training. Default is 0.

  • min_delta (float, optional) – Minimum change in the monitored metric to qualify as an improvement. Default is 0.

  • baseline (float, optional) – Baseline value for the monitored metric. Training will stop if the model does not show improvement over the baseline. Default is None.

  • verbose (int, optional) – Verbosity mode. 0 = silent, 1 = progress messages. Default is 0.

  • mode ({'auto', 'min', 'max'}, optional) – Mode for determining whether the monitored metric should be minimized or maximized. Default is ‘auto’.

  • restore_best_weights (bool, optional) – Whether to restore model weights from the epoch with the best monitored metric value. Default is False.

  • start_from_epoch (int, optional) – The epoch from which to start monitoring the metric. Default is 0.

Raises:

ValueError – If the mode is not recognized or if the monitored metric cannot be automatically determined to be minimized or maximized.

Warns:

UserWarning – If the mode is unknown, or if the monitored metric is not available in the logs.

Examples

>>> early_stopping = EarlyStoppingByThreshold(monitor='val_accuracy', threshold=0.95, mode='max')
>>> model.fit(X_train, y_train, callbacks=[early_stopping])
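The mode argument decides which direction counts as "reaching" the threshold. The sketch below shows one way 'auto' could be resolved. It assumes a name-based heuristic loosely modeled on Keras's built-in EarlyStopping (names ending in acc, accuracy or auc are maximized, names ending in loss are minimized), with an unknown mode warning and falling back to 'auto' and an undecidable metric raising ValueError, matching the Warns and Raises sections above. resolve_mode is a hypothetical helper, not part of gradgraph.

```python
import warnings


def resolve_mode(monitor: str, mode: str = "auto") -> str:
    """Return 'min' or 'max' for `monitor` (hypothetical helper, not gradgraph API)."""
    if mode not in ("auto", "min", "max"):
        # Unknown mode: warn and fall back to 'auto', as the Warns section describes.
        warnings.warn(f"Unknown mode {mode!r}, falling back to 'auto'.", UserWarning, stacklevel=2)
        mode = "auto"
    if mode != "auto":
        return mode
    # Assumed name-based heuristic for 'auto' mode.
    if monitor.endswith(("acc", "accuracy", "auc")):
        return "max"
    if monitor.endswith("loss"):
        return "min"
    # Direction cannot be inferred from the name, as in the Raises section.
    raise ValueError(
        f"Cannot infer whether {monitor!r} should be minimized or maximized; "
        "pass mode='min' or mode='max'."
    )
```

Under this rule, monitor='val_loss' with the default mode='auto' resolves to 'min', so training would stop once val_loss falls to threshold; a custom metric such as 'f1_score' would need an explicit mode.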
get_monitor_value(logs)

Retrieve the value of the monitored metric from the logs.

Parameters:

logs (dict) – A dictionary containing the metrics and their corresponding values. If None, an empty dictionary is used.

Returns:

monitor_value – The value of the monitored metric specified by self.monitor. If the metric is not found in logs, None is returned.

Return type:

any

Warns:

UserWarning – If the monitored metric specified by self.monitor is not found in logs, a warning is issued indicating the available metrics.

Examples

>>> import warnings
>>> class Monitor:
...     def __init__(self, monitor):
...         self.monitor = monitor
...     def get_monitor_value(self, logs):
...         logs = logs or {}
...         monitor_value = logs.get(self.monitor)
...         if monitor_value is None:
...             warnings.warn(
...                 (
...                     f"Early stopping conditioned on metric `{self.monitor}` "
...                     "which is not available. "
...                     f"Available metrics are: {','.join(list(logs.keys()))}"
...                 ),
...                 stacklevel=2,
...             )
...         return monitor_value
>>> monitor = Monitor('accuracy')
>>> logs = {'loss': 0.25, 'val_loss': 0.3}
>>> # The call below also writes this warning to stderr:
>>> # UserWarning: Early stopping conditioned on metric `accuracy` which is not available. Available metrics are: loss,val_loss
>>> print(monitor.get_monitor_value(logs))
None
property model
on_batch_begin(batch, logs=None)

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch, logs=None)

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Args:
  epoch: Integer, index of epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of each epoch to decide whether training should stop.

This function reads the monitored metric from logs and compares it with threshold, in the direction given by mode. Once the metric reaches the threshold, training is stopped and the epoch is recorded in self.stopped_epoch, which on_train_end reports. Epochs before start_from_epoch are not checked.

Parameters:
  • epoch (int) – The index of the current epoch.

  • logs (dict, optional) – A dictionary of logs from the current epoch. If not provided, an empty dictionary is used.

Warns:

UserWarning – If the monitored metric is not available in the logs.

Return type:

None

Examples

>>> # Assuming `early_stopping` was created with monitor='accuracy' and attached to a model by fit()
>>> early_stopping.on_epoch_end(epoch=5, logs={'accuracy': 0.8})
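The per-epoch check performed by EarlyStoppingByThreshold can be replayed offline on a recorded metric history. The sketch below is illustrative only: it assumes that epochs before start_from_epoch are skipped and that training stops at the first epoch whose value reaches threshold (>= in 'max' mode, <= in 'min' mode), and it ignores min_delta and baseline. find_stop_epoch is a hypothetical helper, not part of gradgraph.

```python
from typing import Optional, Sequence


def find_stop_epoch(
    values: Sequence[float],
    threshold: float,
    mode: str = "min",
    start_from_epoch: int = 0,
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    values[i] is the monitored metric logged at the end of epoch i.
    (Hypothetical helper, not gradgraph API.)
    """
    for epoch, value in enumerate(values):
        if epoch < start_from_epoch:
            continue  # the metric is not monitored yet
        reached = value >= threshold if mode == "max" else value <= threshold
        if reached:
            return epoch  # the callback would record this as stopped_epoch
    return None  # threshold never reached: training runs for all epochs
```

For example, with the history [0.70, 0.88, 0.96, 0.97] and threshold=0.95 in 'max' mode the function returns 2, which the on_train_end message would report as "Epoch 3" because messages count epochs from 1. With start_from_epoch=3 it returns 3.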
on_predict_batch_begin(batch, logs=None)

Called at the beginning of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_predict_batch_end(batch, logs=None)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_predict_begin(logs=None)

Called at the beginning of prediction.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_predict_end(logs=None)

Called at the end of prediction.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_batch_begin(batch, logs=None)

Called at the beginning of a batch in evaluate methods.

Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_end(logs=None)

Called at the end of evaluation or validation.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently the output of the last call to on_test_batch_end() is passed to this argument for this method, but that may change in the future.

on_train_batch_begin(batch, logs=None)

Called at the beginning of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_train_begin(logs=None)

Executes actions at the beginning of the training process.

This method is typically called at the start of the training process to initialize or reset any necessary states or variables.

Parameters:

logs (dict, optional) – Currently, this parameter is not used. It is included for compatibility with similar methods that may require logging information.

Return type:

None

Notes

This method calls the _reset function to ensure that the training state is initialized properly before training begins.

on_train_end(logs=None)

Handles operations to be performed at the end of training.

This method is typically used in a training loop to manage actions such as early stopping and restoring model weights to the best observed state during training.

Parameters:

logs (dict, optional) – Currently not used. Defaults to None.

Notes

  • If early stopping was triggered (i.e., self.stopped_epoch > 0), and verbosity is enabled (self.verbose > 0), a message indicating the epoch at which training was stopped is printed.

  • If self.restore_best_weights is True and self.best_weights is not None, the model’s weights are restored to the best observed state. A message is printed if verbosity is enabled.

Examples

>>> class MockCallback:
...     """Stand-in exposing the attributes that on_train_end reads."""
...     def __init__(self):
...         self.stopped_epoch = 5
...         self.verbose = 1
...         self.restore_best_weights = True
...         self.best_weights = [0.1, 0.2, 0.3]
...         self.best_epoch = 3
...         self.model = self  # the mock acts as its own model
...     def set_weights(self, weights):
...         print("Weights set to:", weights)
...     def on_train_end(self, logs=None):
...         if self.stopped_epoch > 0 and self.verbose > 0:
...             print(f"Epoch {self.stopped_epoch + 1}: early stopping")
...         if self.restore_best_weights and self.best_weights is not None:
...             if self.verbose > 0:
...                 print("Restoring model weights from the end of the best epoch:", f"{self.best_epoch + 1}.")
...             self.model.set_weights(self.best_weights)
>>> callback = MockCallback()
>>> callback.on_train_end()
Epoch 6: early stopping
Restoring model weights from the end of the best epoch: 4.
Weights set to: [0.1, 0.2, 0.3]
set_model(model)
set_params(params)
class ReduceLROnPlateau(optimizer='optimizer', monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0.0, **kwargs)

Bases: Callback

Reduce learning rate when a metric has stopped improving.

This callback is used to reduce the learning rate by a specified factor when a monitored metric has stopped improving. It helps in fine-tuning the learning process by decreasing the learning rate when the model reaches a plateau in its performance.

Parameters:
  • optimizer (str, optional) – Name of the model attribute that holds the optimizer whose learning rate is adjusted. Default is ‘optimizer’.

  • monitor (str, optional) – The metric to be monitored. Default is ‘val_loss’.

  • factor (float, optional) – Factor by which the learning rate will be reduced. new_lr = lr * factor. Must be less than 1.0. Default is 0.1.

  • patience (int, optional) – Number of epochs with no improvement after which learning rate will be reduced. Default is 10.

  • verbose (int, optional) – Verbosity mode. 0 = silent, 1 = update messages. Default is 0.

  • mode ({'auto', 'min', 'max'}, optional) – Mode for reducing the learning rate. In ‘min’ mode, the learning rate will be reduced when the monitored quantity has stopped decreasing; in ‘max’ mode it will be reduced when the monitored quantity has stopped increasing; in ‘auto’ mode, the direction is automatically inferred from the name of the monitored quantity. Default is ‘auto’.

  • min_delta (float, optional) – Threshold for measuring the new optimum, to only focus on significant changes. Default is 1e-4.

  • cooldown (int, optional) – Number of epochs to wait before resuming normal operation after the learning rate has been reduced. Default is 0.

  • min_lr (float, optional) – Lower bound on the learning rate. Default is 0.0.

Raises:

ValueError – If factor is greater than or equal to 1.0.

Warns:

UserWarning – If mode is unknown, in which case the callback falls back to ‘auto’ mode, or if the monitored metric is not available in the logs.

Examples

>>> reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
>>> model.fit(x_train, y_train, callbacks=[reduce_lr])
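Each reduction applies new_lr = lr * factor and then clamps the result at min_lr. The sketch below lists the learning rate after successive reductions, assuming that gradgraph, like tf.keras's ReduceLROnPlateau, skips the update once the learning rate is already at or below min_lr. reduction_schedule is a hypothetical helper, not part of gradgraph.

```python
from typing import List


def reduction_schedule(initial_lr: float, factor: float, min_lr: float, reductions: int) -> List[float]:
    """Learning rate after 0, 1, ..., `reductions` reductions (hypothetical helper)."""
    if factor >= 1.0:
        # Mirrors the documented ValueError for factor >= 1.0.
        raise ValueError("factor must be less than 1.0")
    lrs = [initial_lr]
    for _ in range(reductions):
        lr = lrs[-1]
        # Assumed tf.keras behaviour: no update once lr is at or below min_lr.
        lrs.append(max(lr * factor, min_lr) if lr > min_lr else lr)
    return lrs
```

With the settings from the example above (factor=0.2, min_lr=0.001), an optimizer starting at 0.01 goes 0.01, 0.002, 0.001 and then stays at 0.001. An optimizer that already starts at 0.001, the Keras default for Adam, would never be reduced with these settings.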
in_cooldown()

Check if the cooldown period is active.

Returns:

True if the cooldown counter is greater than zero, indicating that the cooldown period is still active. False otherwise.

Return type:

bool

Examples

>>> obj = ReduceLROnPlateau()
>>> obj.cooldown_counter = 5
>>> obj.in_cooldown()
True
>>> obj.cooldown_counter = 0
>>> obj.in_cooldown()
False
property model
on_batch_begin(batch, logs=None)

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch, logs=None)

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Args:
  epoch: Integer, index of epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of each epoch to monitor and adjust the learning rate.

If the monitored metric has not improved for a number of epochs defined by patience, and the cooldown period has elapsed, the learning rate is reduced by the specified factor, but never below min_lr.

Parameters:
  • epoch (int) – Index of the current epoch.

  • logs (dict, optional) – Dictionary of metrics from the current epoch, including the monitored metric and learning rate. If the monitored metric is not in logs, a warning is issued.

Warns:

UserWarning – If the monitor metric is not found in logs.

Notes

The new learning rate is assigned directly to the optimizer attribute specified during initialization. If verbose > 0, a message is printed when the learning rate is reduced.
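The interplay of patience, cooldown, min_delta and min_lr is easiest to see by replaying a metric history. The sketch below follows the update rule of tf.keras's built-in ReduceLROnPlateau, on the assumption that this class behaves the same way; it handles 'min' mode only and tracks the learning rate as a plain float instead of an optimizer attribute. simulate_reduce_lr is a hypothetical helper, not part of gradgraph.

```python
import math
from typing import List, Sequence


def simulate_reduce_lr(
    values: Sequence[float],
    lr: float,
    factor: float = 0.1,
    patience: int = 10,
    cooldown: int = 0,
    min_lr: float = 0.0,
    min_delta: float = 1e-4,
) -> List[float]:
    """Learning rate in effect after each epoch, in 'min' mode (hypothetical helper)."""
    best = math.inf  # best monitored value seen so far
    wait = 0  # epochs since the last significant improvement
    cooldown_counter = 0
    schedule = []
    for current in values:
        if cooldown_counter > 0:  # equivalent to in_cooldown()
            cooldown_counter -= 1
            wait = 0
        if current < best - min_delta:  # improvement larger than min_delta
            best = current
            wait = 0
        elif cooldown_counter == 0:  # only count stagnation outside cooldown
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # never go below min_lr
                cooldown_counter = cooldown
                wait = 0
        schedule.append(lr)
    return schedule
```

For simulate_reduce_lr([1.0] + [0.9] * 11, lr=0.1, factor=0.5, patience=2, cooldown=2, min_lr=0.02), reductions happen at epochs 3, 6 and 9 (counting from 0). The first comes two stagnant epochs after the last improvement at epoch 1; the cooldown stretches the gap between later reductions from two epochs to three; and the last reduction is clamped at min_lr=0.02.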

on_predict_batch_begin(batch, logs=None)

Called at the beginning of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_predict_batch_end(batch, logs=None)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_predict_begin(logs=None)

Called at the beginning of prediction.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_predict_end(logs=None)

Called at the end of prediction.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_batch_begin(batch, logs=None)

Called at the beginning of a batch in evaluate methods.

Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_test_end(logs=None)

Called at the end of evaluation or validation.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently the output of the last call to on_test_batch_end() is passed to this argument for this method, but that may change in the future.

on_train_batch_begin(batch, logs=None)

Called at the beginning of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Currently no data is passed to this argument for this method, but that may change in the future.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in Model is set to N, this method will only be called every N batches.

Args:
  batch: Integer, index of batch within the current epoch.
  logs: Dict. Aggregated metric results up until this batch.

on_train_begin(logs=None)

Resets the internal state at the beginning of training.

This method is called at the start of training to initialize or reset internal counters such as cooldown and wait, and to prepare for tracking the monitored metric.

Parameters:

logs (dict, optional) – Currently unused. Reserved for future use or compatibility with the Keras callback API. Defaults to None.

on_train_end(logs=None)

Called at the end of training.

Subclasses should override for any actions to run.

Args:
  logs: Dict. Currently the output of the last call to on_epoch_end() is passed to this argument for this method, but that may change in the future.

set_model(model)
set_params(params)