"""Module jidenn.model_builders.callbacks

Module containing the custom callbacks and a function that returns
a list of callbacks to be used during the training process.
"""

import tensorflow as tf
import os
from logging import Logger
from typing import List, Optional, Literal
from tensorflow.python.keras.utils import tf_utils
import tensorflow as tf
import shutil
import logging
from datetime import datetime

# Module-level logger, named after this module; used for the warnings and
# file-removal messages emitted by the checkpoint callback below.
log = logging.getLogger(name=__name__)

class LogCallback(tf.keras.callbacks.Callback):
    """Callback to log the training progress to a specified `logging.Logger` object.

    At the end of every epoch a single line is logged at INFO level. The line contains:
    the epoch number, the wall-clock time elapsed during that epoch (labelled "ETA"
    in the output), and every value found in the epoch logs.

    Args:
        epochs (int): The total number of epochs.
        logger (logging.Logger): The logging object to use.

    An example output of the logger:

        Epoch 1/5: ETA: 01:18:43 - loss: 0.5569 - accuracy: 0.7089 - val_loss: 0.533 - val_accuracy: 0.7307
        Epoch 2/5: ETA: 01:15:50 - loss: 0.5333 - accuracy: 0.7289 - val_loss: 0.5291 - val_accuracy: 0.7328
        Epoch 3/5: ETA: 01:15:12 - loss: 0.5304 - accuracy: 0.7311 - val_loss: 0.5266 - val_accuracy: 0.7344
        Epoch 4/5: ETA: 01:15:18 - loss: 0.5281 - accuracy: 0.7328 - val_loss: 0.525 - val_accuracy: 0.7351
        Epoch 5/5: ETA: 01:14:56 - loss: 0.5262 - accuracy: 0.7342 - val_loss: 0.5231 - val_accuracy: 0.7374
    """

    def __init__(self, epochs: int, logger: Logger) -> None:
        # Always run the Keras base-class initialiser before adding our own state.
        super().__init__()
        self.epoch_start_time: Optional[datetime] = None
        self.specified_logger = logger
        self.total_epochs = epochs

    @staticmethod
    def _format_elapsed(elapsed) -> str:
        """Formats a `datetime.timedelta` as zero-padded ``HH:MM:SS``.

        Any sub-second part is truncated. Unlike round-tripping through
        ``str(timedelta)`` and ``strptime``, this also works in two cases:
        durations of a day or more (hours simply exceed 24), and durations
        without a microsecond component.
        """
        # Clamp at 0: utcnow() is not monotonic, so a clock adjustment could yield a negative delta.
        total_seconds = max(int(elapsed.total_seconds()), 0)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

    @staticmethod
    def _format_value(value) -> str:
        """Formats a log value with 4 significant digits, falling back to `str` if not numeric."""
        try:
            # float() makes integer values formattable with a precision spec.
            return f"{float(value):.4}"
        except (TypeError, ValueError):
            return str(value)

    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Called at the beginning of every epoch. Records the epoch start time.

        Args:
            epoch (int): Index of epoch.
            logs (dict, optional): Dictionary of logs for the current epoch.
        """
        self.epoch_start_time = datetime.utcnow()

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Called at the end of every epoch. Logs the epoch progress and the time elapsed during the epoch.

        Args:
            epoch (int): Index of epoch (0-based; printed 1-based).
            logs (dict, optional): Dictionary of logs (metric name -> value) for the current epoch.
        """
        if self.epoch_start_time is None:
            # on_epoch_begin was never called, so there is no start time to measure from.
            eta = '-:--:--'
        else:
            eta = self._format_elapsed(datetime.utcnow() - self.epoch_start_time)

        log_str = f"Epoch {epoch+1}/{self.total_epochs}: "
        log_str += f"ETA: {eta} - "
        log_str += " - ".join(f"{k}: {self._format_value(v)}" for k, v in (logs or {}).items())
        self.specified_logger.info(log_str)

class BestNModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """Custom ModelCheckpoint Callback that saves the best N checkpoints based on a specified monitor metric.

    This callback is a modification of the `ModelCheckpoint` callback in TensorFlow Keras. It keeps
    the best N checkpoints, ranked by a specified monitor metric:

    - Checkpoints are written every time the parent class saves, to paths built from `filepath`.
    - Once more than `max_to_keep` are tracked, the worst one is deleted from disk.
    - Optionally, the most recent checkpoint is always kept, in addition to the best N.

    Args:
        filepath (str): Path where the checkpoints will be saved. Make sure
            to pass a format string, as otherwise the checkpoint will be overridden each step.
            (see `tf.keras.callbacks.ModelCheckpoint` for detailed formatting options)
        max_to_keep (int): Maximum number of best checkpoints to keep. Values below 1 are raised to 1.
        keep_most_recent (bool): If True, the most recent checkpoint will be saved in addition to
            the best `max_to_keep` ones.
        monitor (str): Name of the metric to monitor. The value for this metric will be used to
            decide which checkpoints will be kept. See `keras.callbacks.ModelCheckpoint` for
            more information.
        mode (str): Depending on `mode`, the checkpoints with the highest ('max')
            or lowest ('min') values in the monitored quantity will be kept.
        **kwargs: Additional keyword arguments to pass to the parent class `tf.keras.callbacks.ModelCheckpoint`.
            `save_best_only` is ignored (always False), since pruning is done by this class.
    """

    def __init__(self,
                 filepath: str,
                 max_to_keep: int,
                 keep_most_recent: bool = True,
                 monitor: str = 'val_loss',
                 mode: Literal['min', 'max'] = 'min',
                 **kwargs) -> None:
        if kwargs.pop('save_best_only', None):
            log.warning("Setting `save_best_only` to False.")

        if max_to_keep < 1:
            log.warning("BestNModelCheckpoint parameter `max_to_keep` must be greater than 0, setting it to 1.")
            max_to_keep = 1

        # The parent must write every checkpoint; selecting the best N is done here.
        super().__init__(filepath,
                         monitor=monitor,
                         mode=mode,
                         save_best_only=False,
                         **kwargs)

        self._keep_count = max_to_keep
        # Prunable checkpoints: filepath -> monitored value (None if missing from logs).
        self._checkpoints = {}

        self._keep_most_recent = keep_most_recent
        # {filepath: value} of the newest checkpoint while it is protected from pruning.
        self._most_recent_checkpoint = None

    def _save_model(self, epoch, batch, logs):
        super()._save_model(epoch, batch, logs)
        logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
        filepath = self._get_file_path(epoch, batch, logs)

        if not self._checkpoint_exists(filepath):
            # Did not save a checkpoint for current epoch
            return

        value = logs.get(self.monitor)

        if self._keep_most_recent:
            # Delay adding to the pool of prunable checkpoints until the next save,
            # so that the most recent checkpoint is never deleted.
            if self._most_recent_checkpoint:
                self._checkpoints.update(self._most_recent_checkpoint)
            # If `filepath` is not unique per save, the file just written overwrote an
            # older entry; drop it from the pool so the newest file stays protected.
            self._checkpoints.pop(filepath, None)
            self._most_recent_checkpoint = {filepath: value}
        else:
            self._checkpoints[filepath] = value

        while len(self._checkpoints) > self._keep_count:
            self._delete_worst_checkpoint()

    @staticmethod
    def _is_unusable(value) -> bool:
        """True for a missing (None) or NaN metric value; such checkpoints are pruned first."""
        return value is None or value != value

    def _delete_worst_checkpoint(self):
        """Removes the worst tracked checkpoint (according to `monitor_op`) from the pool and from disk."""
        worst_checkpoint = None  # tuple (filepath, value)

        for checkpoint in self._checkpoints.items():
            if worst_checkpoint is None:
                worst_checkpoint = checkpoint
            elif self._is_unusable(worst_checkpoint[1]):
                # Already holding an unusable value: nothing can be worse.
                continue
            elif self._is_unusable(checkpoint[1]) or self.monitor_op(worst_checkpoint[1], checkpoint[1]):
                # monitor_op(a, b) is True when a is better than b, i.e. b is worse.
                worst_checkpoint = checkpoint

        if worst_checkpoint is None:
            return

        self._checkpoints.pop(worst_checkpoint[0])
        self._delete_checkpoint_files(worst_checkpoint[0])

    @staticmethod
    def _delete_checkpoint_files(checkpoint_path):
        """Deletes every file belonging to the checkpoint at `checkpoint_path`."""
        import glob

        log.info(f"Removing files for checkpoint '{checkpoint_path}'")

        if os.path.isdir(checkpoint_path):
            # SavedModel format delete the whole directory
            shutil.rmtree(checkpoint_path)
            return

        # A single file (e.g. HDF5) or a TF weights prefix (`<path>.index`, `<path>.data-*`).
        # Only delete exact matches or `<path>.<suffix>`, so that e.g. `model-1` never
        # removes `model-10`; escape the path so glob metacharacters in it are literal.
        for f in glob.glob(glob.escape(checkpoint_path) + '*'):
            if (f == checkpoint_path or f.startswith(checkpoint_path + '.')) and os.path.isfile(f):
                os.remove(f)

def get_callbacks(base_logdir: str,
                  epochs: int,
                  log: Logger,
                  checkpoint: Optional[str] = 'checkpoints',
                  backup: Optional[str] = 'backup',
                  backup_freq: Optional[int] = None) -> List[tf.keras.callbacks.Callback]:
    """Returns a list of Keras callbacks for a training session.

    Args:
        base_logdir (str): The base directory to use for saving logs, checkpoints, and backups.
        epochs (int): The number of epochs to train for.
        log (logging.Logger): A logger object used by `LogCallback` to report progress during training.
        checkpoint (str, optional): Sub-directory of `base_logdir` for model checkpoints.
            If None, no checkpoints will be saved.
        backup (str, optional): Sub-directory of `base_logdir` for backups of the training session.
            If None, no backups will be saved.
        backup_freq (int, optional): The frequency (in batches) at which to save backups of the
            training session. If None, backups will only be saved at the end of each epoch.

    Returns:
        List[tf.keras.callbacks.Callback]: The callbacks to use during training, in this order:
        - a `tf.keras.callbacks.TensorBoard` callback writing to `base_logdir`;
        - a `LogCallback` for per-epoch progress logging;
        - a `BestNModelCheckpoint` for saving model checkpoints (only if `checkpoint` is not None);
        - a `tf.keras.callbacks.BackupAndRestore` for backups of the training session
          (only if `backup` is not None).
    """
    callbacks: List[tf.keras.callbacks.Callback] = [
        tf.keras.callbacks.TensorBoard(log_dir=base_logdir),
        LogCallback(epochs, log),
    ]

    if checkpoint is not None:
        checkpoint_dir = os.path.join(base_logdir, checkpoint)
        os.makedirs(checkpoint_dir, exist_ok=True)
        # NOTE(review): the arguments after `filepath` were lost in this copy of the file and are
        # reconstructed. The monitored metric must be the one in the filename template
        # (`val_binary_accuracy`), for which higher is better. Confirm `max_to_keep` and any extra
        # ModelCheckpoint options (e.g. `save_weights_only`) against the original intent.
        callbacks.append(BestNModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, 'model-{epoch:02d}-{val_binary_accuracy:.2f}.h5'),
            max_to_keep=2,
            monitor='val_binary_accuracy',
            mode='max',
        ))

    if backup is not None:
        backup_dir = os.path.join(base_logdir, backup)
        os.makedirs(backup_dir, exist_ok=True)
        callbacks.append(tf.keras.callbacks.BackupAndRestore(
            backup_dir=backup_dir,
            delete_checkpoint=False,
            save_freq=backup_freq if backup_freq is not None else "epoch"))

    return callbacks


