Module jidenn.model_builders.callbacks
Module containing the custom callbacks and a function that returns a list of callbacks to be used during the training process.
Expand source code
"""
Module containing the custom callbacks and a function that returns
a list of callbacks to be used during the training process.
"""
import tensorflow as tf
import os
from logging import Logger
from typing import List, Optional, Literal
from tensorflow.python.keras.utils import tf_utils
import tensorflow as tf
import shutil
import logging
from datetime import datetime
log = logging.getLogger(__name__)
class LogCallback(tf.keras.callbacks.Callback):
    """
    Callback to log the training progress to a specified logging.Logger object.
    Logs the total time elapsed during the epoch (ETA).

    Args:
        epochs (int): The total number of epochs.
        logger (logging.Logger): The logging object to use.

    Example:
        An example output of the logger:
        ```
        Epoch 1/5: ETA: 01:18:43 - loss: 0.5569 - accuracy: 0.7089 - val_loss: 0.533 - val_accuracy: 0.7307
        Epoch 2/5: ETA: 01:15:50 - loss: 0.5333 - accuracy: 0.7289 - val_loss: 0.5291 - val_accuracy: 0.7328
        Epoch 3/5: ETA: 01:15:12 - loss: 0.5304 - accuracy: 0.7311 - val_loss: 0.5266 - val_accuracy: 0.7344
        Epoch 4/5: ETA: 01:15:18 - loss: 0.5281 - accuracy: 0.7328 - val_loss: 0.525 - val_accuracy: 0.7351
        Epoch 5/5: ETA: 01:14:56 - loss: 0.5262 - accuracy: 0.7342 - val_loss: 0.5231 - val_accuracy: 0.7374
        ```
    """

    def __init__(self, epochs: int, logger: Logger) -> None:
        super().__init__()  # initialize base Callback state (was missing)
        self.epoch_start_time: Optional[datetime] = None
        self.specified_logger = logger
        self.total_epochs = epochs

    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None) -> None:
        """
        Called at the beginning of every epoch. Sets the epoch start time.

        Args:
            epoch (int): Index of epoch.
            logs (dict, optional): Dictionary of logs for the current epoch.

        Returns:
            None
        """
        self.epoch_start_time = datetime.utcnow()

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """
        Called at the end of every epoch. Logs the epoch progress and the total
        time elapsed during the epoch (ETA).

        Args:
            epoch (int): Index of epoch.
            logs (dict, optional): Dictionary of logs for the current epoch.

        Returns:
            None
        """
        # Format the elapsed time manually instead of round-tripping it through
        # datetime.strptime: str(timedelta) yields e.g. '1 day, 0:00:01' for
        # epochs >= 24h (which strptime("%H:%M:%S.%f") cannot parse), the
        # '-:--:--' fallback would also crash strptime, and the '%T' strftime
        # directive used previously is not available on all platforms (Windows).
        if self.epoch_start_time is not None:
            elapsed = datetime.utcnow() - self.epoch_start_time
            hours, remainder = divmod(int(elapsed.total_seconds()), 3600)
            minutes, seconds = divmod(remainder, 60)
            eta = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
        else:
            eta = '-:--:--'
        logs = logs or {}  # Keras may pass None; avoid AttributeError on .items()
        log_str = f"Epoch {epoch+1}/{self.total_epochs}: "
        log_str += f"ETA: {eta} - "
        log_str += " - ".join(f"{k}: {v:.4}" for k, v in logs.items())
        self.specified_logger.info(log_str)
class BestNModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """Custom ModelCheckpoint Callback that saves the best N checkpoints based on a specified monitor metric.

    Original source: https://github.com/schustmi/tf_utils/blob/915fe5e231ca302b28cd02dc8ac2e4c772a62e0b/tf_utils/callbacks.py#L34

    This callback is a modification of the `ModelCheckpoint` callback in TensorFlow Keras, allowing for keeping
    the best N checkpoints based on a specified monitor metric. The checkpoints are saved in a directory specified
    by `filepath` argument, and the N best checkpoints are kept based on the value of the monitor metric. Optionally,
    the most recent checkpoint can also be kept.

    NOTE(review): this class overrides and calls several private members of the
    Keras ModelCheckpoint implementation (`_save_model`, `_get_file_path`,
    `_checkpoint_exists`, `monitor_op`) and `tensorflow.python.keras.utils.tf_utils`;
    these are internal APIs and may change between TensorFlow versions — verify
    against the pinned TF version before upgrading.

    Args:
        filepath (str): path where the checkpoints will be saved. Make sure
            to pass a format string as otherwise the checkpoint will be overridden each step.
            (see `tf.keras.callbacks.ModelCheckpoint` for detailed formatting options)
        max_to_keep (int): Maximum number of best checkpoints to keep.
        keep_most_recent (bool): if True, the most recent checkpoint will be saved in addition to
            the best `max_to_keep` ones.
        monitor (str): Name of the metric to monitor. The value for this metric will be used to
            decide which checkpoints will be kept. See `keras.callbacks.ModelCheckpoint` for
            more information.
        mode (str): Depending on `mode`, the checkpoints with the highest ('max')
            or lowest ('min') values in the monitored quantity will be kept.
        **kwargs: Additional keyword arguments to pass to the parent class `tf.keras.callbacks.ModelCheckpoint`.
    """

    def __init__(self,
                 filepath: str,
                 max_to_keep: int,
                 keep_most_recent: bool = True,
                 monitor: str = 'val_loss',
                 mode: Literal['min', 'max'] = 'min',
                 **kwargs):
        # `save_best_only` conflicts with keeping the N best checkpoints, so it
        # is removed from kwargs (with a warning if it was set) and forced to
        # False in the super() call below.
        if kwargs.pop('save_best_only', None):
            log.warning("Setting `save_best_only` to False.")
        if max_to_keep < 1:
            log.warning("BestNModelCheckpoint parameter `max_to_keep` must be greater than 0, setting it to 1.")
            max_to_keep = 1
        super().__init__(filepath,
                         monitor=monitor,
                         mode=mode,
                         save_best_only=False,
                         **kwargs)
        self._keep_count = max_to_keep
        # Maps checkpoint filepath -> monitored metric value at save time.
        self._checkpoints = {}
        self._keep_most_recent = keep_most_recent
        if self._keep_most_recent:
            # One-element dict {filepath: value}; the latest checkpoint is held
            # out of `_checkpoints` so it is never deleted as the "worst" one.
            self._most_recent_checkpoint = None

    def _save_model(self, epoch, batch, logs):
        # Let the parent write the checkpoint first, then update bookkeeping.
        super()._save_model(epoch, batch, logs)

        logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
        filepath = self._get_file_path(epoch, batch, logs)

        if not self._checkpoint_exists(filepath):
            # Did not save a checkpoint for current epoch
            return

        # NOTE(review): `value` is None if `self.monitor` is absent from logs;
        # the comparison in `_delete_worst_checkpoint` would then fail — confirm
        # the monitor name always matches a logged metric.
        value = logs.get(self.monitor)

        if self._keep_most_recent:
            # delay adding to list of current checkpoints until next save
            # if we should always keep the most recent checkpoint
            if self._most_recent_checkpoint:
                self._checkpoints.update(self._most_recent_checkpoint)
            self._most_recent_checkpoint = {filepath: value}
        else:
            self._checkpoints[filepath] = value

        if len(self._checkpoints) > self._keep_count:
            self._delete_worst_checkpoint()

    def _delete_worst_checkpoint(self):
        # Find the tracked checkpoint with the worst monitored value and remove
        # both its bookkeeping entry and its files on disk.
        worst_checkpoint = None  # tuple (filepath, value)

        for checkpoint in self._checkpoints.items():
            # monitor_op(a, b) is true when `a` is better than `b` (parent sets
            # it from `mode`), so a True result here means the current candidate
            # is worse than the running worst and replaces it.
            if worst_checkpoint is None or self.monitor_op(worst_checkpoint[1], checkpoint[1]):
                worst_checkpoint = checkpoint

        self._checkpoints.pop(worst_checkpoint[0])
        self._delete_checkpoint_files(worst_checkpoint[0])

    @staticmethod
    def _delete_checkpoint_files(checkpoint_path):
        """Delete the on-disk files belonging to a single checkpoint."""
        log.info(f"Removing files for checkpoint '{checkpoint_path}'")

        if os.path.isdir(checkpoint_path):
            # SavedModel format delete the whole directory
            shutil.rmtree(checkpoint_path)
            return

        # Weights format: remove the index/data shard files sharing the prefix.
        for f in tf.io.gfile.glob(checkpoint_path + '*'):
            os.remove(f)
def get_callbacks(base_logdir: str,
                  epochs: int,
                  log: Logger,
                  checkpoint: Optional[str] = 'checkpoints',
                  backup: Optional[str] = 'backup',
                  backup_freq: Optional[int] = None) -> List[tf.keras.callbacks.Callback]:
    """
    Returns a list of Keras callbacks for a training session.

    Args:
        base_logdir (str): The base directory to use for saving logs, checkpoints, and backups.
        epochs (int): The number of epochs to train for.
        log (logging.Logger): A logger object to use for logging information during training.
        checkpoint (str, optional): The subdirectory to use for saving model checkpoints.
            If None, no checkpoints will be saved.
        backup (str, optional): The subdirectory to use for saving backups of the training
            session. If None, no backups will be saved.
        backup_freq (int, optional): The frequency (in batches) at which to save backups of
            the training session. If None, backups will only be saved at the end of each epoch.

    Returns:
        A list of Keras callbacks to use during training: a `tf.keras.callbacks.TensorBoard`
        callback for logging, a `LogCallback` for logging epoch progress, a
        `BestNModelCheckpoint` for saving model checkpoints (if `checkpoint` is not None),
        and a `tf.keras.callbacks.BackupAndRestore` for saving backups of the training
        session (if `backup` is not None).
    """
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=base_logdir)
    log_callback = LogCallback(epochs, log)
    callbacks = [tb_callback, log_callback]

    if checkpoint is not None:
        os.makedirs(os.path.join(base_logdir, checkpoint), exist_ok=True)
        # `save_best_only` is intentionally NOT passed: BestNModelCheckpoint pops
        # it and forces it to False, so passing True (as the previous version did)
        # only triggered a spurious warning on every construction.
        base_checkpoints = BestNModelCheckpoint(
            filepath=os.path.join(base_logdir, checkpoint,
                                  'model-{epoch:02d}-{val_binary_accuracy:.2f}.h5'),
            max_to_keep=2,
            monitor='val_binary_accuracy',
            mode='max',
            save_weights_only=True,
            save_freq='epoch',
            verbose=1,
        )
        callbacks.append(base_checkpoints)

    if backup is not None:
        os.makedirs(os.path.join(base_logdir, backup), exist_ok=True)
        backup_callback = tf.keras.callbacks.BackupAndRestore(
            backup_dir=os.path.join(base_logdir, backup),
            delete_checkpoint=False,
            # BackupAndRestore accepts 'epoch' or an integer batch frequency.
            save_freq=backup_freq if backup_freq is not None else "epoch",
        )
        callbacks.append(backup_callback)
    return callbacks
Functions
def get_callbacks(base_logdir: str, epochs: int, log: logging.Logger, checkpoint: Optional[str] = 'checkpoints', backup: Optional[str] = 'backup', backup_freq: Optional[int] = None) ‑> List[keras.callbacks.Callback]
-
Returns a list of Keras callbacks for a training session.
Args
base_log_dir
:str
- The base directory to use for saving logs, checkpoints, and backups.
epochs
:int
- The number of epochs to train for.
log
:logging.Logger
- A logger object to use for logging information during training.
checkpoint
:str
- The directory to use for saving model checkpoints. If None, no checkpoints will be saved.
backup
:str
- The directory to use for saving backups of the training session. If None, no backups will be saved.
backup_freq
:int
, optional- The frequency (in batches) at which to save backups of the training session. If None, backups will only be saved at the end of each epoch.
Returns
A list of Keras callbacks to use during training. The list contains a
tf.keras.callbacks.TensorBoard
callback for logging training information, a jidenn.callbacks.LogCallback.LogCallback
callback for logging training information, a jidenn.callbacks.BestNModelCheckpoint.BestNModelCheckpoint
callback for saving model checkpoints, and a tf.keras.callbacks.BackupAndRestore
callback for saving backups of the training session. Expand source code
def get_callbacks(base_logdir: str, epochs: int, log: Logger, checkpoint: Optional[str] = 'checkpoints', backup: Optional[str] = 'backup', backup_freq: Optional[int] = None) -> List[tf.keras.callbacks.Callback]: """ Returns a list of Keras callbacks for a training session. Args: base_log_dir (str): The base directory to use for saving logs, checkpoints, and backups. epochs (int): The number of epochs to train for. log (logging.Logger): A logger object to use for logging information during training. checkpoint (str): The directory to use for saving model checkpoints. If None, no checkpoints will be saved. backup (str): The directory to use for saving backups of the training session. If None, no backups will be saved. backup_freq (int, optional): The frequency (in batches) at which to save backups of the training session. If None, backups will only be saved at the end of each epoch. Returns: A list of Keras callbacks to use during training. The list contains a `tf.keras.callbacks.TensorBoard` callback for logging training information, a `jidenn.callbacks.LogCallback.LogCallback` callback for logging training information, a `jidenn.callbacks.BestNModelCheckpoint.BestNModelCheckpoint` callback for saving model checkpoints, and a `tf.keras.callbacks.BackupAndRestore` callback for saving backups of the training session. 
""" tb_callback = tf.keras.callbacks.TensorBoard(log_dir=base_logdir) log_callback = LogCallback(epochs, log) callbacks = [tb_callback, log_callback] if checkpoint is not None: os.makedirs(os.path.join(base_logdir, checkpoint), exist_ok=True) base_checkpoints = BestNModelCheckpoint(filepath=os.path.join(base_logdir, checkpoint, 'model-{epoch:02d}-{val_binary_accuracy:.2f}.h5'), max_to_keep=2, monitor='val_binary_accuracy', mode='max', save_weights_only=True, save_best_only=True, save_freq='epoch', verbose=1,) callbacks.append(base_checkpoints) if backup is not None: os.makedirs(os.path.join(base_logdir, backup), exist_ok=True) backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir=os.path.join( base_logdir, backup), delete_checkpoint=False, save_freq=backup_freq if backup_freq is not None else "epoch") callbacks.append(backup_callback) return callbacks
Classes
class BestNModelCheckpoint (filepath: str, max_to_keep: int, keep_most_recent: bool = True, monitor: str = 'val_loss', mode: Literal['min', 'max'] = 'min', **kwargs)
-
Custom ModelCheckpoint Callback that saves the best N checkpoints based on a specified monitor metric.
Original source: https://github.com/schustmi/tf_utils/blob/915fe5e231ca302b28cd02dc8ac2e4c772a62e0b/tf_utils/callbacks.py#L34
This callback is a modification of the
ModelCheckpoint
callback in TensorFlow Keras, allowing for keeping the best N checkpoints based on a specified monitor metric. The checkpoints are saved in a directory specified by the filepath
argument, and the N best checkpoints are kept based on the value of the monitor metric. Optionally, the most recent checkpoint can also be kept. Args
filepath
:str
- path where the checkpoints will be saved. Make sure
to pass a format string as otherwise the checkpoint will be overridden each step.
(see
tf.keras.callbacks.ModelCheckpoint
for detailed formatting options) max_to_keep
:int
- Maximum number of best checkpoints to keep.
keep_most_recent
:bool
- if True, the most recent checkpoint will be saved in addition to
the best
max_to_keep
ones. monitor
:str
- Name of the metric to monitor. The value for this metric will be used to
decide which checkpoints will be kept. See
keras.callbacks.ModelCheckpoint
for more information. mode
:str
- Depending on
mode
, the checkpoints with the highest ('max') or lowest ('min') values in the monitored quantity will be kept. **kwargs
- Additional keyword arguments to pass to the parent class
tf.keras.callbacks.ModelCheckpoint
.
Expand source code
class BestNModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): """Custom ModelCheckpoint Callback that saves the best N checkpoints based on a specified monitor metric. Original source: https://github.com/schustmi/tf_utils/blob/915fe5e231ca302b28cd02dc8ac2e4c772a62e0b/tf_utils/callbacks.py#L34 This callback is a modification of the `ModelCheckpoint` callback in TensorFlow Keras, allowing for keeping the best N checkpoints based on a specified monitor metric. The checkpoints are saved in a directory specified by `filepath` argument, and the N best checkpoints are kept based on the value of the monitor metric. Optionally, the most recent checkpoint can also be kept. Args: filepath (str): path where the checkpoints will be saved. Make sure to pass a format string as otherwise the checkpoint will be overridden each step. (see `tf.keras.callbacks.ModelCheckpoint` for detailed formatting options) max_to_keep (int): Maximum number of best checkpoints to keep. keep_most_recent (bool): if True, the most recent checkpoint will be saved in addition to the best `max_to_keep` ones. monitor (str): Name of the metric to monitor. The value for this metric will be used to decide which checkpoints will be kept. See `keras.callbacks.ModelCheckpoint` for more information. mode (str): Depending on `mode`, the checkpoints with the highest ('max') or lowest ('min') values in the monitored quantity will be kept. **kwargs: Additional keyword arguments to pass to the parent class `tf.keras.callbacks.ModelCheckpoint`. 
""" def __init__(self, filepath: str, max_to_keep: int, keep_most_recent: bool = True, monitor: str = 'val_loss', mode: Literal['min', 'max'] = 'min', **kwargs): if kwargs.pop('save_best_only', None): log.warning("Setting `save_best_only` to False.") if max_to_keep < 1: log.warning("BestNModelCheckpoint parameter `max_to_keep` must be greater than 0, setting it to 1.") max_to_keep = 1 super().__init__(filepath, monitor=monitor, mode=mode, save_best_only=False, **kwargs) self._keep_count = max_to_keep self._checkpoints = {} self._keep_most_recent = keep_most_recent if self._keep_most_recent: self._most_recent_checkpoint = None def _save_model(self, epoch, batch, logs): super()._save_model(epoch, batch, logs) logs = tf_utils.sync_to_numpy_or_python_type(logs or {}) filepath = self._get_file_path(epoch, batch, logs) if not self._checkpoint_exists(filepath): # Did not save a checkpoint for current epoch return value = logs.get(self.monitor) if self._keep_most_recent: # delay adding to list of current checkpoints until next save # if we should always keep the most recent checkpoint if self._most_recent_checkpoint: self._checkpoints.update(self._most_recent_checkpoint) self._most_recent_checkpoint = {filepath: value} else: self._checkpoints[filepath] = value if len(self._checkpoints) > self._keep_count: self._delete_worst_checkpoint() def _delete_worst_checkpoint(self): worst_checkpoint = None # tuple (filepath, value) for checkpoint in self._checkpoints.items(): if worst_checkpoint is None or self.monitor_op(worst_checkpoint[1], checkpoint[1]): worst_checkpoint = checkpoint self._checkpoints.pop(worst_checkpoint[0]) self._delete_checkpoint_files(worst_checkpoint[0]) @staticmethod def _delete_checkpoint_files(checkpoint_path): log.info(f"Removing files for checkpoint '{checkpoint_path}'") if os.path.isdir(checkpoint_path): # SavedModel format delete the whole directory shutil.rmtree(checkpoint_path) return for f in tf.io.gfile.glob(checkpoint_path + '*'): os.remove(f)
Ancestors
- keras.callbacks.ModelCheckpoint
- keras.callbacks.Callback
class LogCallback (epochs: int, logger: logging.Logger)
-
Callback to log the training progress to a specified logging.Logger object. Logs the total time elapsed during the epoch (ETA).
Args
epochs
:int
- The total number of epochs.
logger
:logging.Logger
- The logging object to use.
Example
An example output of the logger:
Epoch 1/5: ETA: 01:18:43 - loss: 0.5569 - accuracy: 0.7089 - val_loss: 0.533 - val_accuracy: 0.7307 Epoch 2/5: ETA: 01:15:50 - loss: 0.5333 - accuracy: 0.7289 - val_loss: 0.5291 - val_accuracy: 0.7328 Epoch 3/5: ETA: 01:15:12 - loss: 0.5304 - accuracy: 0.7311 - val_loss: 0.5266 - val_accuracy: 0.7344 Epoch 4/5: ETA: 01:15:18 - loss: 0.5281 - accuracy: 0.7328 - val_loss: 0.525 - val_accuracy: 0.7351 Epoch 5/5: ETA: 01:14:56 - loss: 0.5262 - accuracy: 0.7342 - val_loss: 0.5231 - val_accuracy: 0.7374
Expand source code
class LogCallback(tf.keras.callbacks.Callback): """ Callback to log the training progress to a specified logging.Logger object. Logs the total time elapsed during the epoch (ETA). Args: epochs (int): The total number of epochs. logger (logging.Logger): The logging object to use. Example: An example output of the logger: ``` Epoch 1/5: ETA: 01:18:43 - loss: 0.5569 - accuracy: 0.7089 - val_loss: 0.533 - val_accuracy: 0.7307 Epoch 2/5: ETA: 01:15:50 - loss: 0.5333 - accuracy: 0.7289 - val_loss: 0.5291 - val_accuracy: 0.7328 Epoch 3/5: ETA: 01:15:12 - loss: 0.5304 - accuracy: 0.7311 - val_loss: 0.5266 - val_accuracy: 0.7344 Epoch 4/5: ETA: 01:15:18 - loss: 0.5281 - accuracy: 0.7328 - val_loss: 0.525 - val_accuracy: 0.7351 Epoch 5/5: ETA: 01:14:56 - loss: 0.5262 - accuracy: 0.7342 - val_loss: 0.5231 - val_accuracy: 0.7374 ``` """ def __init__(self, epochs: int, logger: Logger) -> None: self.epoch_start_time = None self.specified_logger = logger self.total_epochs = epochs def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None) -> None: """ Called at the beginning of every epoch. Sets the epoch start time. Args: epoch (int): Index of epoch. logs (dict, optional): Dictionary of logs for the current epoch. Returns: None """ self.epoch_start_time = datetime.utcnow() def on_epoch_end(self, epoch: int, logs: dict) -> None: """ Called at the end of every epoch. Logs the epoch progress and calculates the total time elapsed during the epoch (ETA). Args: epoch (int): Index of epoch. logs (dict): Dictionary of logs for the current epoch. Returns: None """ eta = datetime.utcnow() - self.epoch_start_time if self.epoch_start_time is not None else '-:--:--' eta = datetime.strptime(str(eta), "%H:%M:%S.%f") log_str = f"Epoch {epoch+1}/{self.total_epochs}: " log_str += f"ETA: {eta:%T} - " log_str += " - ".join([f"{k}: {v:.4}" for k, v in logs.items()]) self.specified_logger.info(log_str)
Ancestors
- keras.callbacks.Callback
Methods
def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None) ‑> None
-
Called at the beginning of every epoch. Sets the epoch start time.
Args
epoch
:int
- Index of epoch.
logs
:dict
, optional- Dictionary of logs for the current epoch.
Returns
None
Expand source code
def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None) -> None: """ Called at the beginning of every epoch. Sets the epoch start time. Args: epoch (int): Index of epoch. logs (dict, optional): Dictionary of logs for the current epoch. Returns: None """ self.epoch_start_time = datetime.utcnow()
def on_epoch_end(self, epoch: int, logs: dict) ‑> None
-
Called at the end of every epoch. Logs the epoch progress and calculates the total time elapsed during the epoch (ETA).
Args
epoch
:int
- Index of epoch.
logs
:dict
- Dictionary of logs for the current epoch.
Returns
None
Expand source code
def on_epoch_end(self, epoch: int, logs: dict) -> None: """ Called at the end of every epoch. Logs the epoch progress and calculates the total time elapsed during the epoch (ETA). Args: epoch (int): Index of epoch. logs (dict): Dictionary of logs for the current epoch. Returns: None """ eta = datetime.utcnow() - self.epoch_start_time if self.epoch_start_time is not None else '-:--:--' eta = datetime.strptime(str(eta), "%H:%M:%S.%f") log_str = f"Epoch {epoch+1}/{self.total_epochs}: " log_str += f"ETA: {eta:%T} - " log_str += " - ".join([f"{k}: {v:.4}" for k, v in logs.items()]) self.specified_logger.info(log_str)