Module jidenn.model_builders.model_initialization
Module for initializing models using the configuration file.
The config options are explicitly passed to the model classes' __init__
methods.
For each model, a function is defined that returns an instance of the model class.
The string names of the models in the config file are then mapped to the corresponding functions.
Expand source code
"""
Module for initializing models using the configuration file.
The config options are explicitly passed to the model classes' `__init__` methods.
For each model, a function is defined that returns an instance of the model class.
The string names of the models in the config file are mapped to these functions
by `model_getter_lookup`.
"""
import tensorflow as tf
import tensorflow_addons as tfa
from typing import Tuple, Optional, Callable, List, Literal, Union, Any
from jidenn.config import model_config
from ..models.FC import FCModel
from ..models.Highway import HighwayModel
from ..models.Transformer import TransformerModel
from ..models.ParT import ParTModel
from ..models.DeParT import DeParTModel
from ..models.PFN import PFNModel
from ..models.EFN import EFNModel
from ..models.BDT import bdt_model
def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) -> Callable[[tf.Tensor], tf.Tensor]:
    """Translate an activation-function name from the config into the function itself.

    Args:
        activation (str): Name of the activation function. One of ['relu', 'gelu', 'tanh', 'swish'].

    Returns:
        Callable[[tf.Tensor], tf.Tensor]: ``tf.nn.relu``, ``tfa.activations.gelu``,
            ``tf.nn.tanh`` or ``tf.keras.activations.swish`` respectively.

    Raises:
        NotImplementedError: If ``activation`` matches none of the supported names.
    """
    # Names are compared with ``==`` in this order; the first match wins.
    candidates = (
        ('relu', tf.nn.relu),
        ('gelu', tfa.activations.gelu),
        ('tanh', tf.nn.tanh),
        ('swish', tf.keras.activations.swish),
    )
    for name, function in candidates:
        if activation == name:
            return function
    raise NotImplementedError(f'Activation {activation} not supported.')
def get_FC_model(input_size: int,
                 output_layer: tf.keras.layers.Layer,
                 args_model: model_config.FC,
                 preprocess: Optional[tf.keras.layers.Layer] = None) -> FCModel:
    """Build a fully-connected model from its config section.

    Args:
        input_size (int): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (jidenn.config.model_config.FC): Model configuration; the
            ``layer_size``, ``num_layers``, ``dropout`` and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        FCModel: Fully-connected model.
    """
    layer_size = args_model.layer_size
    num_layers = args_model.num_layers
    dropout = args_model.dropout
    activation = get_activation(args_model.activation)
    return FCModel(layer_size=layer_size,
                   num_layers=num_layers,
                   dropout=dropout,
                   input_size=input_size,
                   output_layer=output_layer,
                   activation=activation,
                   preprocess=preprocess)
def get_highway_model(input_size: int,
                      output_layer: tf.keras.layers.Layer,
                      args_model: model_config.Highway,
                      preprocess: Optional[tf.keras.layers.Layer] = None) -> HighwayModel:
    """Build a highway network from its config section.

    Args:
        input_size (int): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (jidenn.config.model_config.Highway): Model configuration; the
            ``layer_size``, ``num_layers``, ``dropout`` and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        HighwayModel: Highway model.
    """
    layer_size = args_model.layer_size
    num_layers = args_model.num_layers
    dropout = args_model.dropout
    activation = get_activation(args_model.activation)
    return HighwayModel(layer_size=layer_size,
                        num_layers=num_layers,
                        dropout=dropout,
                        input_size=input_size,
                        output_layer=output_layer,
                        activation=activation,
                        preprocess=preprocess)
def get_pfn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.PFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> PFNModel:
    """Build a Particle Flow Network (PFN) from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.PFN): Model configuration; the ``Phi_sizes``, ``F_sizes``,
            ``Phi_backbone``, ``activation`` and ``batch_norm`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        PFNModel: PFN model.
    """
    phi_sizes = args_model.Phi_sizes
    f_sizes = args_model.F_sizes
    phi_backbone = args_model.Phi_backbone
    activation = get_activation(args_model.activation)
    batch_norm = args_model.batch_norm
    return PFNModel(input_shape=input_size,
                    Phi_sizes=phi_sizes,
                    F_sizes=f_sizes,
                    Phi_backbone=phi_backbone,
                    activation=activation,
                    batch_norm=batch_norm,
                    preprocess=preprocess,
                    output_layer=output_layer)
def get_efn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.EFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> EFNModel:
    """Build an Energy Flow Network (EFN) from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.EFN): Model configuration; the ``Phi_sizes``, ``F_sizes``,
            ``Phi_backbone``, ``activation`` and ``batch_norm`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        EFNModel: EFN model.
    """
    phi_sizes = args_model.Phi_sizes
    f_sizes = args_model.F_sizes
    phi_backbone = args_model.Phi_backbone
    activation = get_activation(args_model.activation)
    batch_norm = args_model.batch_norm
    return EFNModel(input_shape=input_size,
                    Phi_sizes=phi_sizes,
                    F_sizes=f_sizes,
                    Phi_backbone=phi_backbone,
                    activation=activation,
                    batch_norm=batch_norm,
                    preprocess=preprocess,
                    output_layer=output_layer)
def get_transformer_model(input_size: Tuple[None, int],
                          output_layer: tf.keras.layers.Layer,
                          args_model: model_config.Transformer,
                          preprocess: Optional[tf.keras.layers.Layer] = None) -> TransformerModel:
    """Build a transformer model from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.Transformer): Model configuration; the ``embed_dim``,
            ``embed_layers``, ``dropout``, ``expansion``, ``heads``, ``self_attn_layers``
            and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        TransformerModel: Transformer model.
    """
    embed_dim = args_model.embed_dim
    embed_layers = args_model.embed_layers
    dropout = args_model.dropout
    expansion = args_model.expansion
    heads = args_model.heads
    self_attn_layers = args_model.self_attn_layers
    activation = get_activation(args_model.activation)
    return TransformerModel(embed_dim=embed_dim,
                            embed_layers=embed_layers,
                            dropout=dropout,
                            expansion=expansion,
                            heads=heads,
                            self_attn_layers=self_attn_layers,
                            activation=activation,
                            input_shape=input_size,
                            output_layer=output_layer,
                            preprocess=preprocess)
def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                   output_layer: tf.keras.layers.Layer,
                   args_model: model_config.ParT,
                   preprocess: Optional[tf.keras.layers.Layer] = None) -> ParTModel:
    """Build a ParT model from its config section.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.ParT): Model configuration.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        ParTModel: ParT model.
    """
    # Architecture hyper-parameters copied from the config. Note that the config's
    # ``interaction_embedding_*`` fields feed the model's ``interaction_embed_*`` arguments.
    hyperparams = dict(
        embed_dim=args_model.embed_dim,
        embed_layers=args_model.embed_layers,
        self_attn_layers=args_model.self_attn_layers,
        class_attn_layers=args_model.class_attn_layers,
        expansion=args_model.expansion,
        heads=args_model.heads,
        dropout=args_model.dropout,
        interaction_embed_layers=args_model.interaction_embedding_layers,
        interaction_embed_layer_size=args_model.interaction_embedding_layer_size,
    )
    return ParTModel(input_shape=input_size,
                     output_layer=output_layer,
                     **hyperparams,
                     preprocess=preprocess,
                     activation=get_activation(args_model.activation))
def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                     output_layer: tf.keras.layers.Layer,
                     args_model: model_config.DeParT,
                     preprocess: Optional[tf.keras.layers.Layer] = None) -> DeParTModel:
    """Build a DeParT model from its config section.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.DeParT): Model configuration.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        DeParTModel: DeParT model.
    """
    # Architecture and regularisation hyper-parameters copied from the config. Note that the
    # config's ``interaction_embedding_*`` fields feed the model's ``interaction_embed_*`` arguments.
    hyperparams = dict(
        embed_dim=args_model.embed_dim,
        embed_layers=args_model.embed_layers,
        self_attn_layers=args_model.self_attn_layers,
        class_attn_layers=args_model.class_attn_layers,
        expansion=args_model.expansion,
        heads=args_model.heads,
        dropout=args_model.dropout,
        layer_scale_init_value=args_model.layer_scale_init_value,
        stochastic_depth_drop_rate=args_model.stochastic_depth_drop_rate,
        class_dropout=args_model.class_dropout,
        class_stochastic_depth_drop_rate=args_model.class_stochastic_depth_drop_rate,
        interaction_embed_layers=args_model.interaction_embedding_layers,
        interaction_embed_layer_size=args_model.interaction_embedding_layer_size,
    )
    return DeParTModel(input_shape=input_size,
                       output_layer=output_layer,
                       **hyperparams,
                       preprocess=preprocess,
                       activation=get_activation(args_model.activation))
def get_bdt_model(input_size: int,
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.BDT,
                  preprocess: Optional[tf.keras.layers.Layer] = None):
    """Build the BDT model from its config section.

    Only ``args_model`` is used. The other parameters exist so that this getter can be
    called the same way as the other model getters returned by ``model_getter_lookup``.

    Args:
        input_size (int): Input size of the model. UNUSED
        output_layer (tf.keras.layers.Layer): Output layer of the model. UNUSED
        args_model (model_config.BDT): Model configuration, forwarded to ``bdt_model``.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None. UNUSED

    Returns:
        tfdf.keras.GradientBoostedTreesModel: BDT model, as returned by ``bdt_model``.
    """
    # The shared getter signature is kept, but these arguments are deliberately ignored.
    del input_size, output_layer, preprocess
    return bdt_model(args_model)
def model_getter_lookup(
        model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']
) -> Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
    """Return the getter function that builds the model named ``model_name``.

    Args:
        model_name (str): Name of the model. Options are 'fc', 'highway', 'pfn', 'efn',
            'transformer', 'part', 'depart', 'bdt'.

    Returns:
        Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
            Getter taking ``(input_size, output_layer, args_model, preprocess)``.

    Raises:
        ValueError: If ``model_name`` is not one of the options above.
    """
    getters = {
        'fc': get_FC_model,
        'highway': get_highway_model,
        'pfn': get_pfn_model,
        'efn': get_efn_model,
        'transformer': get_transformer_model,
        'part': get_part_model,
        'depart': get_depart_model,
        'bdt': get_bdt_model,
    }
    getter = getters.get(model_name)
    if getter is None:
        raise ValueError(f'Unknown model {model_name}')
    return getter
Functions
def get_FC_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: FC, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> FCModel
-
Get an instance of the fully-connected model.
Args
input_size
:int
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:FC
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
FCModel
- Fully-connected model.
Expand source code
def get_FC_model(input_size: int,
                 output_layer: tf.keras.layers.Layer,
                 args_model: model_config.FC,
                 preprocess: Optional[tf.keras.layers.Layer] = None) -> FCModel:
    """Build a fully-connected model from its config section.

    Args:
        input_size (int): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (jidenn.config.model_config.FC): Model configuration; the
            ``layer_size``, ``num_layers``, ``dropout`` and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        FCModel: Fully-connected model.
    """
    layer_size = args_model.layer_size
    num_layers = args_model.num_layers
    dropout = args_model.dropout
    activation = get_activation(args_model.activation)
    return FCModel(layer_size=layer_size,
                   num_layers=num_layers,
                   dropout=dropout,
                   input_size=input_size,
                   output_layer=output_layer,
                   activation=activation,
                   preprocess=preprocess)
def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) -> Callable[[tensorflow.python.framework.ops.Tensor], tensorflow.python.framework.ops.Tensor]
-
Map string names of activation functions to the corresponding function.
Args
activation
:str
- Name of the activation function. One of ['relu', 'gelu', 'tanh', 'swish'].
Returns
Callable[[tf.Tensor], tf.Tensor]
- Activation function.
Expand source code
def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) -> Callable[[tf.Tensor], tf.Tensor]:
    """Translate an activation-function name from the config into the function itself.

    Args:
        activation (str): Name of the activation function. One of ['relu', 'gelu', 'tanh', 'swish'].

    Returns:
        Callable[[tf.Tensor], tf.Tensor]: ``tf.nn.relu``, ``tfa.activations.gelu``,
            ``tf.nn.tanh`` or ``tf.keras.activations.swish`` respectively.

    Raises:
        NotImplementedError: If ``activation`` matches none of the supported names.
    """
    # Names are compared with ``==`` in this order; the first match wins.
    candidates = (
        ('relu', tf.nn.relu),
        ('gelu', tfa.activations.gelu),
        ('tanh', tf.nn.tanh),
        ('swish', tf.keras.activations.swish),
    )
    for name, function in candidates:
        if activation == name:
            return function
    raise NotImplementedError(f'Activation {activation} not supported.')
def get_bdt_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: BDT, preprocess: Optional[keras.engine.base_layer.Layer] = None)
-
Get an instance of the BDT model.
Args
args_model
:model_config.BDT
- Model configuration.
input_size
:int
- Input size of the model. UNUSED
output_layer
:tf.keras.layers.Layer
- Output layer of the model. UNUSED
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None. UNUSED
Returns
tfdf.keras.GradientBoostedTreesModel
- BDT model.
Expand source code
def get_bdt_model(input_size: int,
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.BDT,
                  preprocess: Optional[tf.keras.layers.Layer] = None):
    """Build the BDT model from its config section.

    Only ``args_model`` is used. The other parameters exist so that this getter can be
    called the same way as the other model getters returned by ``model_getter_lookup``.

    Args:
        input_size (int): Input size of the model. UNUSED
        output_layer (tf.keras.layers.Layer): Output layer of the model. UNUSED
        args_model (model_config.BDT): Model configuration, forwarded to ``bdt_model``.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None. UNUSED

    Returns:
        tfdf.keras.GradientBoostedTreesModel: BDT model, as returned by ``bdt_model``.
    """
    # The shared getter signature is kept, but these arguments are deliberately ignored.
    del input_size, output_layer, preprocess
    return bdt_model(args_model)
def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]], output_layer: keras.engine.base_layer.Layer, args_model: DeParT, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> DeParTModel
-
Get an instance of the DeParT model.
Args
input_size
:Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:model_config.DeParT
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
DeParTModel
- DeParT model.
Expand source code
def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                     output_layer: tf.keras.layers.Layer,
                     args_model: model_config.DeParT,
                     preprocess: Optional[tf.keras.layers.Layer] = None) -> DeParTModel:
    """Build a DeParT model from its config section.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.DeParT): Model configuration.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        DeParTModel: DeParT model.
    """
    # Architecture and regularisation hyper-parameters copied from the config. Note that the
    # config's ``interaction_embedding_*`` fields feed the model's ``interaction_embed_*`` arguments.
    hyperparams = dict(
        embed_dim=args_model.embed_dim,
        embed_layers=args_model.embed_layers,
        self_attn_layers=args_model.self_attn_layers,
        class_attn_layers=args_model.class_attn_layers,
        expansion=args_model.expansion,
        heads=args_model.heads,
        dropout=args_model.dropout,
        layer_scale_init_value=args_model.layer_scale_init_value,
        stochastic_depth_drop_rate=args_model.stochastic_depth_drop_rate,
        class_dropout=args_model.class_dropout,
        class_stochastic_depth_drop_rate=args_model.class_stochastic_depth_drop_rate,
        interaction_embed_layers=args_model.interaction_embedding_layers,
        interaction_embed_layer_size=args_model.interaction_embedding_layer_size,
    )
    return DeParTModel(input_shape=input_size,
                       output_layer=output_layer,
                       **hyperparams,
                       preprocess=preprocess,
                       activation=get_activation(args_model.activation))
def get_efn_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: EFN, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> EFNModel
-
Get an instance of the EFN model.
Args
input_size
:Tuple[None, int]
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:model_config.EFN
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
EFNModel
- EFN model.
Expand source code
def get_efn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.EFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> EFNModel:
    """Build an Energy Flow Network (EFN) from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.EFN): Model configuration; the ``Phi_sizes``, ``F_sizes``,
            ``Phi_backbone``, ``activation`` and ``batch_norm`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        EFNModel: EFN model.
    """
    phi_sizes = args_model.Phi_sizes
    f_sizes = args_model.F_sizes
    phi_backbone = args_model.Phi_backbone
    activation = get_activation(args_model.activation)
    batch_norm = args_model.batch_norm
    return EFNModel(input_shape=input_size,
                    Phi_sizes=phi_sizes,
                    F_sizes=f_sizes,
                    Phi_backbone=phi_backbone,
                    activation=activation,
                    batch_norm=batch_norm,
                    preprocess=preprocess,
                    output_layer=output_layer)
def get_highway_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: Highway, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> HighwayModel
-
Get an instance of the highway model.
Args
input_size
:int
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:Highway
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
HighwayModel
- Highway model.
Expand source code
def get_highway_model(input_size: int,
                      output_layer: tf.keras.layers.Layer,
                      args_model: model_config.Highway,
                      preprocess: Optional[tf.keras.layers.Layer] = None) -> HighwayModel:
    """Build a highway network from its config section.

    Args:
        input_size (int): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (jidenn.config.model_config.Highway): Model configuration; the
            ``layer_size``, ``num_layers``, ``dropout`` and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        HighwayModel: Highway model.
    """
    layer_size = args_model.layer_size
    num_layers = args_model.num_layers
    dropout = args_model.dropout
    activation = get_activation(args_model.activation)
    return HighwayModel(layer_size=layer_size,
                        num_layers=num_layers,
                        dropout=dropout,
                        input_size=input_size,
                        output_layer=output_layer,
                        activation=activation,
                        preprocess=preprocess)
def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]], output_layer: keras.engine.base_layer.Layer, args_model: ParT, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> ParTModel
-
Get an instance of the ParT model.
Args
input_size
:Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:model_config.ParT
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
ParTModel
- ParT model.
Expand source code
def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                   output_layer: tf.keras.layers.Layer,
                   args_model: model_config.ParT,
                   preprocess: Optional[tf.keras.layers.Layer] = None) -> ParTModel:
    """Build a ParT model from its config section.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.ParT): Model configuration.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        ParTModel: ParT model.
    """
    # Architecture hyper-parameters copied from the config. Note that the config's
    # ``interaction_embedding_*`` fields feed the model's ``interaction_embed_*`` arguments.
    hyperparams = dict(
        embed_dim=args_model.embed_dim,
        embed_layers=args_model.embed_layers,
        self_attn_layers=args_model.self_attn_layers,
        class_attn_layers=args_model.class_attn_layers,
        expansion=args_model.expansion,
        heads=args_model.heads,
        dropout=args_model.dropout,
        interaction_embed_layers=args_model.interaction_embedding_layers,
        interaction_embed_layer_size=args_model.interaction_embedding_layer_size,
    )
    return ParTModel(input_shape=input_size,
                     output_layer=output_layer,
                     **hyperparams,
                     preprocess=preprocess,
                     activation=get_activation(args_model.activation))
def get_pfn_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: PFN, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> PFNModel
-
Get an instance of the PFN model.
Args
input_size
:Tuple[None, int]
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:model_config.PFN
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
PFNModel
- PFN model.
Expand source code
def get_pfn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.PFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> PFNModel:
    """Build a Particle Flow Network (PFN) from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.PFN): Model configuration; the ``Phi_sizes``, ``F_sizes``,
            ``Phi_backbone``, ``activation`` and ``batch_norm`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        PFNModel: PFN model.
    """
    phi_sizes = args_model.Phi_sizes
    f_sizes = args_model.F_sizes
    phi_backbone = args_model.Phi_backbone
    activation = get_activation(args_model.activation)
    batch_norm = args_model.batch_norm
    return PFNModel(input_shape=input_size,
                    Phi_sizes=phi_sizes,
                    F_sizes=f_sizes,
                    Phi_backbone=phi_backbone,
                    activation=activation,
                    batch_norm=batch_norm,
                    preprocess=preprocess,
                    output_layer=output_layer)
def get_transformer_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: Transformer, preprocess: Optional[keras.engine.base_layer.Layer] = None) -> TransformerModel
-
Get an instance of the transformer model.
Args
input_size
:Tuple[None, int]
- Input size of the model.
output_layer
:tf.keras.layers.Layer
- Output layer of the model.
args_model
:model_config.Transformer
- Model configuration.
preprocess
:Optional[tf.keras.layers.Layer], optional
- Preprocessing layer. Defaults to None.
Returns
TransformerModel
- Transformer model.
Expand source code
def get_transformer_model(input_size: Tuple[None, int],
                          output_layer: tf.keras.layers.Layer,
                          args_model: model_config.Transformer,
                          preprocess: Optional[tf.keras.layers.Layer] = None) -> TransformerModel:
    """Build a transformer model from its config section.

    Args:
        input_size (Tuple[None, int]): Input size of the model.
        output_layer (tf.keras.layers.Layer): Output layer of the model.
        args_model (model_config.Transformer): Model configuration; the ``embed_dim``,
            ``embed_layers``, ``dropout``, ``expansion``, ``heads``, ``self_attn_layers``
            and ``activation`` fields are read.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer. Defaults to None.

    Returns:
        TransformerModel: Transformer model.
    """
    embed_dim = args_model.embed_dim
    embed_layers = args_model.embed_layers
    dropout = args_model.dropout
    expansion = args_model.expansion
    heads = args_model.heads
    self_attn_layers = args_model.self_attn_layers
    activation = get_activation(args_model.activation)
    return TransformerModel(embed_dim=embed_dim,
                            embed_layers=embed_layers,
                            dropout=dropout,
                            expansion=expansion,
                            heads=heads,
                            self_attn_layers=self_attn_layers,
                            activation=activation,
                            input_shape=input_size,
                            output_layer=output_layer,
                            preprocess=preprocess)
def model_getter_lookup(model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']) -> Callable[[Any, keras.engine.base_layer.Layer, Model, Optional[keras.engine.base_layer.Layer]], keras.engine.training.Model]
-
Get a model getter function.
Args
model_name
:str
- Name of the model. Options are 'fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt'.
Returns
Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]
- Model getter function.
Expand source code
def model_getter_lookup(
        model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']
) -> Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
    """Return the getter function that builds the model named ``model_name``.

    Args:
        model_name (str): Name of the model. Options are 'fc', 'highway', 'pfn', 'efn',
            'transformer', 'part', 'depart', 'bdt'.

    Returns:
        Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
            Getter taking ``(input_size, output_layer, args_model, preprocess)``.

    Raises:
        ValueError: If ``model_name`` is not one of the options above.
    """
    getters = {
        'fc': get_FC_model,
        'highway': get_highway_model,
        'pfn': get_pfn_model,
        'efn': get_efn_model,
        'transformer': get_transformer_model,
        'part': get_part_model,
        'depart': get_depart_model,
        'bdt': get_bdt_model,
    }
    getter = getters.get(model_name)
    if getter is None:
        raise ValueError(f'Unknown model {model_name}')
    return getter