Module jidenn.model_builders.normalization_initialization
Module for initializing normalization layers and adapting them to the dataset, i.e. calculating the mean and std of the dataset.
Expand source code
"""
Module for initializing normalization layers and adapting them to the dataset,
i.e. calculating the mean and std of the dataset.
"""
import tensorflow as tf
import logging
from typing import Optional, Tuple, Union, Literal
def _drop_padding(tensor):
    """Return the real (non-padded) entries of ``tensor`` as a dense tensor.

    A ``tf.RaggedTensor`` is assumed to be laid out as ``(batch, None, features)``
    (TODO confirm against the input pipeline). Its batch and ragged axes are merged,
    so only real entries remain instead of the zero padding that ``to_tensor()``
    would add (padding would bias the adapted mean and variance).
    Dense tensors are returned unchanged.
    """
    if not isinstance(tensor, tf.RaggedTensor):
        return tensor
    flat = tensor.merge_dims(0, 1)
    # Any remaining ragged inner axis still has to be densified for `adapt`.
    return flat.to_tensor() if isinstance(flat, tf.RaggedTensor) else flat


def _upper_triangle_pairs(pairs):
    """Return the strictly-upper-triangular entries of a pairwise tensor, flattened.

    ``pairs`` is assumed to be ``(batch, N, N, features)``, ragged in the two ``N``
    axes when it is a ``tf.RaggedTensor`` (TODO confirm). Entries on or below the
    diagonal are dropped. For ragged input, entries that exist only because of
    padding are dropped as well.

    Returns:
        tf.Tensor: Tensor of shape ``(num_pairs, features)``.
    """
    if isinstance(pairs, tf.RaggedTensor):
        dense = pairs.to_tensor()
        # 1 where a real pair exists, 0 where `to_tensor()` inserted padding.
        valid = pairs.with_flat_values(tf.ones_like(pairs.flat_values)).to_tensor()[..., 0]
    else:
        dense = pairs
        valid = tf.ones_like(dense[..., 0])
    # The mask is built over the two pair axes (the feature axis is sliced away
    # above): band_part(.., 0, -1) keeps the upper triangle including the
    # diagonal, band_part(.., 0, 0) keeps only the diagonal, and their
    # difference is the strict upper triangle.
    strict_upper = tf.linalg.band_part(valid, 0, -1) - tf.linalg.band_part(valid, 0, 0)
    return tf.boolean_mask(dense, tf.cast(strict_upper, tf.bool))


def get_normalization(dataset: tf.data.Dataset,
                      log: logging.Logger,
                      adapt: bool = True,
                      ragged: bool = False,
                      normalization_steps: Optional[int] = None,
                      interaction: Optional[bool] = None,
                      ) -> Union[Optional[tf.keras.layers.Normalization],
                                 Tuple[Optional[tf.keras.layers.Normalization],
                                       Optional[tf.keras.layers.Normalization]]]:
    """Function returning normalization layer(s) adapted to the dataset if requested.

    Unadapted normalization layer(s) are assumed to have weights loaded from a checkpoint.
    When adapting, statistics are computed only over real entries:
    - Ragged inputs have their padding excluded.
    - Interaction variables use only the strictly-upper-triangular pairs.

    Args:
        dataset (tf.data.Dataset): Dataset to adapt the normalization layer to. Its model
            inputs are read as ``x[0]``. With ``interaction`` they are read as ``x[0][0]``
            (features) and ``x[0][1]`` (pairwise interaction variables).
        log (logging.Logger): Logger.
        adapt (bool, optional): Whether to adapt the normalization layer to the dataset.
            Defaults to True.
        ragged (bool, optional): Whether the dataset samples are ragged. Kept for backward
            compatibility. Raggedness is now detected from the tensors themselves, so
            this flag does not change the result. Defaults to False.
        normalization_steps (int, optional): Number of batches to use for adaptation.
            Defaults to None (Keras then iterates until the dataset is exhausted).
        interaction (bool, optional): Whether to also adapt a normalization layer to the
            interaction variables. Defaults to None.

    Returns:
        Union[tf.keras.layers.Normalization, Tuple[tf.keras.layers.Normalization, tf.keras.layers.Normalization]]:
            One layer if there is no interaction, or two layers as a `tuple`
            ``(features, interaction)`` if there is interaction. A layer whose
            adaptation raised an exception is returned as `None` (the error is logged).
    """
    if not adapt:
        log.warning("Normalization not adapting. Loading weights expected.")
        if not interaction:
            return tf.keras.layers.Normalization(axis=-1)
        return tf.keras.layers.Normalization(axis=-1), tf.keras.layers.Normalization(axis=-1)

    log.info("Getting std and mean of the dataset...")
    log.info(f"Subsample size (num of batches): {normalization_steps}")

    def picker(*x):
        # With interaction the model input is a (features, interaction) pair; only the
        # features are normalized here, regardless of whether they are ragged.
        features = x[0][0] if interaction else x[0]
        return _drop_padding(features)

    normalizer = tf.keras.layers.Normalization(axis=-1)
    try:
        normalizer.adapt(dataset.map(picker), steps=normalization_steps)
    except Exception as e:
        # Best effort: report (with traceback) and signal failure with None.
        log.exception(f"Normalization failed: {e}")
        normalizer = None

    if not interaction:
        log.info("No interaction normalization")
        return normalizer

    log.info("Getting std and mean of the **interaction** dataset...")

    def picker_interaction(*x):
        return _upper_triangle_pairs(x[0][1])

    normalizer_interaction = tf.keras.layers.Normalization(axis=-1)
    try:
        normalizer_interaction.adapt(dataset.map(picker_interaction), steps=normalization_steps)
    except Exception as e:
        log.exception(f"Interaction normalization failed: {e}")
        normalizer_interaction = None
    return normalizer, normalizer_interaction
Functions
def get_normalization(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, log: logging.Logger, adapt: bool = True, ragged: bool = False, normalization_steps: Optional[int] = None, interaction: Optional[bool] = None) ‑> Union[keras.layers.preprocessing.normalization.Normalization, Tuple[keras.layers.preprocessing.normalization.Normalization, keras.layers.preprocessing.normalization.Normalization]]
-
Function returning normalization layer(s) adapted to the dataset if requested. Unadapted normalization layer(s) are assumed to have weights loaded from checkpoint.
Args
dataset (tf.data.Dataset)
- Dataset to adapt the normalization layer to.
log (logging.Logger)
- Logger.
adapt (bool, optional)
- Whether to adapt the normalization layer to the dataset. Defaults to True.
ragged (bool, optional)
- Whether the dataset samples are ragged. Defaults to False.
normalization_steps (int, optional)
- Number of batches to use for adaptation. Defaults to None.
interaction (bool, optional)
- Whether to also adapt a normalization layer to the interaction variables. Defaults to None.
Returns
Union[tf.keras.layers.Normalization, Tuple[tf.keras.layers.Normalization, tf.keras.layers.Normalization]]
- One layer if there is no interaction, or two layers as a `tuple` if there is interaction. A layer whose adaptation raised an exception is returned as `None` (the error is logged).
Expand source code
def _drop_padding(tensor):
    """Return the real (non-padded) entries of ``tensor`` as a dense tensor.

    A ``tf.RaggedTensor`` is assumed to be laid out as ``(batch, None, features)``
    (TODO confirm against the input pipeline). Its batch and ragged axes are merged,
    so only real entries remain instead of the zero padding that ``to_tensor()``
    would add (padding would bias the adapted mean and variance).
    Dense tensors are returned unchanged.
    """
    if not isinstance(tensor, tf.RaggedTensor):
        return tensor
    flat = tensor.merge_dims(0, 1)
    # Any remaining ragged inner axis still has to be densified for `adapt`.
    return flat.to_tensor() if isinstance(flat, tf.RaggedTensor) else flat


def _upper_triangle_pairs(pairs):
    """Return the strictly-upper-triangular entries of a pairwise tensor, flattened.

    ``pairs`` is assumed to be ``(batch, N, N, features)``, ragged in the two ``N``
    axes when it is a ``tf.RaggedTensor`` (TODO confirm). Entries on or below the
    diagonal are dropped. For ragged input, entries that exist only because of
    padding are dropped as well.

    Returns:
        tf.Tensor: Tensor of shape ``(num_pairs, features)``.
    """
    if isinstance(pairs, tf.RaggedTensor):
        dense = pairs.to_tensor()
        # 1 where a real pair exists, 0 where `to_tensor()` inserted padding.
        valid = pairs.with_flat_values(tf.ones_like(pairs.flat_values)).to_tensor()[..., 0]
    else:
        dense = pairs
        valid = tf.ones_like(dense[..., 0])
    # The mask is built over the two pair axes (the feature axis is sliced away
    # above): band_part(.., 0, -1) keeps the upper triangle including the
    # diagonal, band_part(.., 0, 0) keeps only the diagonal, and their
    # difference is the strict upper triangle.
    strict_upper = tf.linalg.band_part(valid, 0, -1) - tf.linalg.band_part(valid, 0, 0)
    return tf.boolean_mask(dense, tf.cast(strict_upper, tf.bool))


def get_normalization(dataset: tf.data.Dataset,
                      log: logging.Logger,
                      adapt: bool = True,
                      ragged: bool = False,
                      normalization_steps: Optional[int] = None,
                      interaction: Optional[bool] = None,
                      ) -> Union[Optional[tf.keras.layers.Normalization],
                                 Tuple[Optional[tf.keras.layers.Normalization],
                                       Optional[tf.keras.layers.Normalization]]]:
    """Function returning normalization layer(s) adapted to the dataset if requested.

    Unadapted normalization layer(s) are assumed to have weights loaded from a checkpoint.
    When adapting, statistics are computed only over real entries:
    - Ragged inputs have their padding excluded.
    - Interaction variables use only the strictly-upper-triangular pairs.

    Args:
        dataset (tf.data.Dataset): Dataset to adapt the normalization layer to. Its model
            inputs are read as ``x[0]``. With ``interaction`` they are read as ``x[0][0]``
            (features) and ``x[0][1]`` (pairwise interaction variables).
        log (logging.Logger): Logger.
        adapt (bool, optional): Whether to adapt the normalization layer to the dataset.
            Defaults to True.
        ragged (bool, optional): Whether the dataset samples are ragged. Kept for backward
            compatibility. Raggedness is now detected from the tensors themselves, so
            this flag does not change the result. Defaults to False.
        normalization_steps (int, optional): Number of batches to use for adaptation.
            Defaults to None (Keras then iterates until the dataset is exhausted).
        interaction (bool, optional): Whether to also adapt a normalization layer to the
            interaction variables. Defaults to None.

    Returns:
        Union[tf.keras.layers.Normalization, Tuple[tf.keras.layers.Normalization, tf.keras.layers.Normalization]]:
            One layer if there is no interaction, or two layers as a `tuple`
            ``(features, interaction)`` if there is interaction. A layer whose
            adaptation raised an exception is returned as `None` (the error is logged).
    """
    if not adapt:
        log.warning("Normalization not adapting. Loading weights expected.")
        if not interaction:
            return tf.keras.layers.Normalization(axis=-1)
        return tf.keras.layers.Normalization(axis=-1), tf.keras.layers.Normalization(axis=-1)

    log.info("Getting std and mean of the dataset...")
    log.info(f"Subsample size (num of batches): {normalization_steps}")

    def picker(*x):
        # With interaction the model input is a (features, interaction) pair; only the
        # features are normalized here, regardless of whether they are ragged.
        features = x[0][0] if interaction else x[0]
        return _drop_padding(features)

    normalizer = tf.keras.layers.Normalization(axis=-1)
    try:
        normalizer.adapt(dataset.map(picker), steps=normalization_steps)
    except Exception as e:
        # Best effort: report (with traceback) and signal failure with None.
        log.exception(f"Normalization failed: {e}")
        normalizer = None

    if not interaction:
        log.info("No interaction normalization")
        return normalizer

    log.info("Getting std and mean of the **interaction** dataset...")

    def picker_interaction(*x):
        return _upper_triangle_pairs(x[0][1])

    normalizer_interaction = tf.keras.layers.Normalization(axis=-1)
    try:
        normalizer_interaction.adapt(dataset.map(picker_interaction), steps=normalization_steps)
    except Exception as e:
        log.exception(f"Interaction normalization failed: {e}")
        normalizer_interaction = None
    return normalizer, normalizer_interaction