Module jidenn.model_builders.model_initialization

Module for initializing models using the configuration file. The config options are explicitly passed to the `__init__` method of each model class. For each model, a function is defined that returns an instance of the model class. The string names of the models in the config file are then mapped to the corresponding functions.

Expand source code
"""
Module for initializing models using the configuration file.
The config options are explicitly passed to the model classes `__init__` method.
For each model, a function is defined that returns an instance of the model class.
The string names of the models in the config file are then mapped to the corresponding functions.
"""

import tensorflow as tf
import tensorflow_addons as tfa
from typing import Tuple, Optional, Callable, List, Literal, Union, Any

from jidenn.config import model_config

from ..models.FC import FCModel
from ..models.Highway import HighwayModel
from ..models.Transformer import TransformerModel
from ..models.ParT import ParTModel
from ..models.DeParT import DeParTModel
from ..models.PFN import PFNModel
from ..models.EFN import EFNModel
from ..models.BDT import bdt_model


def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) -> Callable[[tf.Tensor], tf.Tensor]:
    """Resolve an activation name from the config into the callable implementing it.

    Args:
        activation (str): Activation name; 'relu', 'gelu', 'tanh' and 'swish' are recognised.

    Returns:
        Callable[[tf.Tensor], tf.Tensor]: `tf.nn.relu`, `tfa.activations.gelu`, `tf.nn.tanh`
            or `tf.keras.activations.swish`, respectively.

    Raises:
        NotImplementedError: If `activation` is none of the recognised names.
    """
    # Guard clauses: every recognised name returns immediately,
    # anything else falls through to the error at the bottom.
    if activation == 'relu':
        return tf.nn.relu
    if activation == 'gelu':
        return tfa.activations.gelu
    if activation == 'tanh':
        return tf.nn.tanh
    if activation == 'swish':
        return tf.keras.activations.swish
    raise NotImplementedError(f'Activation {activation} not supported.')


def get_FC_model(input_size: int,
                 output_layer: tf.keras.layers.Layer,
                 args_model: model_config.FC,
                 preprocess: Optional[tf.keras.layers.Layer] = None) -> FCModel:
    """Build the fully-connected model from its configuration.

    Args:
        input_size (int): Input size forwarded to the model.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (jidenn.config.model_config.FC): Configuration supplying `layer_size`,
            `num_layers`, `dropout` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        FCModel: The constructed fully-connected model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'layer_size': args_model.layer_size,
        'num_layers': args_model.num_layers,
        'dropout': args_model.dropout,
        'input_size': input_size,
        'output_layer': output_layer,
        'activation': get_activation(args_model.activation),
        'preprocess': preprocess,
    }
    return FCModel(**model_kwargs)


def get_highway_model(input_size: int,
                      output_layer: tf.keras.layers.Layer,
                      args_model: model_config.Highway,
                      preprocess: Optional[tf.keras.layers.Layer] = None) -> HighwayModel:
    """Build the highway model from its configuration.

    Args:
        input_size (int): Input size forwarded to the model.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (jidenn.config.model_config.Highway): Configuration supplying `layer_size`,
            `num_layers`, `dropout` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        HighwayModel: The constructed highway model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'layer_size': args_model.layer_size,
        'num_layers': args_model.num_layers,
        'dropout': args_model.dropout,
        'input_size': input_size,
        'output_layer': output_layer,
        'activation': get_activation(args_model.activation),
        'preprocess': preprocess,
    }
    return HighwayModel(**model_kwargs)


def get_pfn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.PFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> PFNModel:
    """Build the PFN model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.PFN): Configuration supplying `Phi_sizes`, `F_sizes`,
            `Phi_backbone`, `batch_norm` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        PFNModel: The constructed PFN model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'input_shape': input_size,
        'Phi_sizes': args_model.Phi_sizes,
        'F_sizes': args_model.F_sizes,
        'Phi_backbone': args_model.Phi_backbone,
        'activation': get_activation(args_model.activation),
        'batch_norm': args_model.batch_norm,
        'preprocess': preprocess,
        'output_layer': output_layer,
    }
    return PFNModel(**model_kwargs)


def get_efn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.EFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> EFNModel:
    """Build the EFN model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.EFN): Configuration supplying `Phi_sizes`, `F_sizes`,
            `Phi_backbone`, `batch_norm` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        EFNModel: The constructed EFN model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'input_shape': input_size,
        'Phi_sizes': args_model.Phi_sizes,
        'F_sizes': args_model.F_sizes,
        'Phi_backbone': args_model.Phi_backbone,
        'activation': get_activation(args_model.activation),
        'batch_norm': args_model.batch_norm,
        'preprocess': preprocess,
        'output_layer': output_layer,
    }
    return EFNModel(**model_kwargs)


def get_transformer_model(input_size: Tuple[None, int],
                          output_layer: tf.keras.layers.Layer,
                          args_model: model_config.Transformer,
                          preprocess: Optional[tf.keras.layers.Layer] = None) -> TransformerModel:
    """Build the transformer model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.Transformer): Configuration supplying `embed_dim`,
            `embed_layers`, `dropout`, `expansion`, `heads`, `self_attn_layers` and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        TransformerModel: The constructed transformer model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'dropout': args_model.dropout,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'self_attn_layers': args_model.self_attn_layers,
        'activation': get_activation(args_model.activation),
        'input_shape': input_size,
        'output_layer': output_layer,
        'preprocess': preprocess,
    }
    return TransformerModel(**model_kwargs)


def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                   output_layer: tf.keras.layers.Layer,
                   args_model: model_config.ParT,
                   preprocess: Optional[tf.keras.layers.Layer] = None) -> ParTModel:
    """Build the ParT model from its configuration.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.ParT): Configuration supplying the hyper-parameters and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        ParTModel: The constructed ParT model.
    """
    model_kwargs = {
        'input_shape': input_size,
        'output_layer': output_layer,
        # Hyper-parameters read from the config. Note the config spells the interaction
        # fields `interaction_embedding_*` while the model expects `interaction_embed_*`.
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'self_attn_layers': args_model.self_attn_layers,
        'class_attn_layers': args_model.class_attn_layers,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'dropout': args_model.dropout,
        'interaction_embed_layers': args_model.interaction_embedding_layers,
        'interaction_embed_layer_size': args_model.interaction_embedding_layer_size,
        # The config stores the activation by name; get_activation turns it into a callable.
        'preprocess': preprocess,
        'activation': get_activation(args_model.activation),
    }
    return ParTModel(**model_kwargs)


def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                     output_layer: tf.keras.layers.Layer,
                     args_model: model_config.DeParT,
                     preprocess: Optional[tf.keras.layers.Layer] = None) -> DeParTModel:
    """Build the DeParT model from its configuration.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.DeParT): Configuration supplying the hyper-parameters and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        DeParTModel: The constructed DeParT model.
    """
    model_kwargs = {
        'input_shape': input_size,
        'output_layer': output_layer,
        # Hyper-parameters read from the config. Note the config spells the interaction
        # fields `interaction_embedding_*` while the model expects `interaction_embed_*`.
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'self_attn_layers': args_model.self_attn_layers,
        'class_attn_layers': args_model.class_attn_layers,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'dropout': args_model.dropout,
        'layer_scale_init_value': args_model.layer_scale_init_value,
        'stochastic_depth_drop_rate': args_model.stochastic_depth_drop_rate,
        'class_dropout': args_model.class_dropout,
        'class_stochastic_depth_drop_rate': args_model.class_stochastic_depth_drop_rate,
        'interaction_embed_layers': args_model.interaction_embedding_layers,
        'interaction_embed_layer_size': args_model.interaction_embedding_layer_size,
        # The config stores the activation by name; get_activation turns it into a callable.
        'preprocess': preprocess,
        'activation': get_activation(args_model.activation),
    }
    return DeParTModel(**model_kwargs)


def get_bdt_model(input_size: int,
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.BDT,
                  preprocess: Optional[tf.keras.layers.Layer] = None):
    """Build the BDT model from its configuration.

    The signature matches the other model getters, but only `args_model` is used.

    Args:
        input_size (int): Input size of the model. UNUSED
        output_layer (tf.keras.layers.Layer): Output layer of the model. UNUSED
        args_model (model_config.BDT): Model configuration, passed on to `bdt_model`.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer.
            Defaults to None. UNUSED

    Returns:
        tfdf.keras.GradientBoostedTreesModel: BDT model, as returned by `bdt_model`.
    """
    model = bdt_model(args_model)
    return model


def model_getter_lookup(model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']
                        ) -> Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
    """Return the getter function registered for a model name.

    Args:
        model_name (str): Name of the model. Options are 'fc', 'highway', 'pfn', 'efn',
            'transformer', 'part', 'depart', 'bdt'.

    Returns:
        Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
            Getter taking `(input_size, output_layer, args_model, preprocess)`.

    Raises:
        ValueError: If `model_name` is not one of the registered names.
    """
    getters = {
        'fc': get_FC_model,
        'highway': get_highway_model,
        'pfn': get_pfn_model,
        'efn': get_efn_model,
        'transformer': get_transformer_model,
        'part': get_part_model,
        'depart': get_depart_model,
        'bdt': get_bdt_model,
    }

    # No registered getter is None, so a None result can only mean an unknown name.
    getter = getters.get(model_name)
    if getter is None:
        raise ValueError(f'Unknown model {model_name}')
    return getter

Functions

def get_FC_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: FC, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> FCModel

Get an instance of the fully-connected model.

Args

input_size : int
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : FC
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

FCModel
Fully-connected model.
Expand source code
def get_FC_model(input_size: int,
                 output_layer: tf.keras.layers.Layer,
                 args_model: model_config.FC,
                 preprocess: Optional[tf.keras.layers.Layer] = None) -> FCModel:
    """Build the fully-connected model from its configuration.

    Args:
        input_size (int): Input size forwarded to the model.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (jidenn.config.model_config.FC): Configuration supplying `layer_size`,
            `num_layers`, `dropout` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        FCModel: The constructed fully-connected model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'layer_size': args_model.layer_size,
        'num_layers': args_model.num_layers,
        'dropout': args_model.dropout,
        'input_size': input_size,
        'output_layer': output_layer,
        'activation': get_activation(args_model.activation),
        'preprocess': preprocess,
    }
    return FCModel(**model_kwargs)
def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) ‑> Callable[[tensorflow.python.framework.ops.Tensor], tensorflow.python.framework.ops.Tensor]

Map string names of activation functions to the corresponding function.

Args

activation : str
Name of the activation function. One of ['relu', 'gelu', 'tanh', 'swish'].

Returns

Callable[[tf.Tensor], tf.Tensor]
Activation function.
Expand source code
def get_activation(activation: Literal['relu', 'gelu', 'tanh', 'swish']) -> Callable[[tf.Tensor], tf.Tensor]:
    """Resolve an activation name from the config into the callable implementing it.

    Args:
        activation (str): Activation name; 'relu', 'gelu', 'tanh' and 'swish' are recognised.

    Returns:
        Callable[[tf.Tensor], tf.Tensor]: `tf.nn.relu`, `tfa.activations.gelu`, `tf.nn.tanh`
            or `tf.keras.activations.swish`, respectively.

    Raises:
        NotImplementedError: If `activation` is none of the recognised names.
    """
    # Guard clauses: every recognised name returns immediately,
    # anything else falls through to the error at the bottom.
    if activation == 'relu':
        return tf.nn.relu
    if activation == 'gelu':
        return tfa.activations.gelu
    if activation == 'tanh':
        return tf.nn.tanh
    if activation == 'swish':
        return tf.keras.activations.swish
    raise NotImplementedError(f'Activation {activation} not supported.')
def get_bdt_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: BDT, preprocess: Optional[keras.engine.base_layer.Layer] = None)

Get an instance of the BDT model.

Args

args_model : model_config.BDT
Model configuration.
input_size : int
Input size of the model. UNUSED
output_layer : tf.keras.layers.Layer
Output layer of the model. UNUSED
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None. UNUSED

Returns

tfdf.keras.GradientBoostedTreesModel
BDT model.
Expand source code
def get_bdt_model(input_size: int,
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.BDT,
                  preprocess: Optional[tf.keras.layers.Layer] = None):
    """Build the BDT model from its configuration.

    The signature matches the other model getters, but only `args_model` is used.

    Args:
        input_size (int): Input size of the model. UNUSED
        output_layer (tf.keras.layers.Layer): Output layer of the model. UNUSED
        args_model (model_config.BDT): Model configuration, passed on to `bdt_model`.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer.
            Defaults to None. UNUSED

    Returns:
        tfdf.keras.GradientBoostedTreesModel: BDT model, as returned by `bdt_model`.
    """
    model = bdt_model(args_model)
    return model
def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]], output_layer: keras.engine.base_layer.Layer, args_model: DeParT, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> DeParTModel

Get an instance of the DeParT model.

Args

input_size : Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : model_config.DeParT
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

DeParTModel
DeParT model.
Expand source code
def get_depart_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                     output_layer: tf.keras.layers.Layer,
                     args_model: model_config.DeParT,
                     preprocess: Optional[tf.keras.layers.Layer] = None) -> DeParTModel:
    """Build the DeParT model from its configuration.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.DeParT): Configuration supplying the hyper-parameters and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        DeParTModel: The constructed DeParT model.
    """
    model_kwargs = {
        'input_shape': input_size,
        'output_layer': output_layer,
        # Hyper-parameters read from the config. Note the config spells the interaction
        # fields `interaction_embedding_*` while the model expects `interaction_embed_*`.
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'self_attn_layers': args_model.self_attn_layers,
        'class_attn_layers': args_model.class_attn_layers,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'dropout': args_model.dropout,
        'layer_scale_init_value': args_model.layer_scale_init_value,
        'stochastic_depth_drop_rate': args_model.stochastic_depth_drop_rate,
        'class_dropout': args_model.class_dropout,
        'class_stochastic_depth_drop_rate': args_model.class_stochastic_depth_drop_rate,
        'interaction_embed_layers': args_model.interaction_embedding_layers,
        'interaction_embed_layer_size': args_model.interaction_embedding_layer_size,
        # The config stores the activation by name; get_activation turns it into a callable.
        'preprocess': preprocess,
        'activation': get_activation(args_model.activation),
    }
    return DeParTModel(**model_kwargs)
def get_efn_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: EFN, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> EFNModel

Get an instance of the EFN model.

Args

input_size : Tuple[None, int]
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : model_config.EFN
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

EFNModel
EFN model.
Expand source code
def get_efn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.EFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> EFNModel:
    """Build the EFN model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.EFN): Configuration supplying `Phi_sizes`, `F_sizes`,
            `Phi_backbone`, `batch_norm` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        EFNModel: The constructed EFN model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'input_shape': input_size,
        'Phi_sizes': args_model.Phi_sizes,
        'F_sizes': args_model.F_sizes,
        'Phi_backbone': args_model.Phi_backbone,
        'activation': get_activation(args_model.activation),
        'batch_norm': args_model.batch_norm,
        'preprocess': preprocess,
        'output_layer': output_layer,
    }
    return EFNModel(**model_kwargs)
def get_highway_model(input_size: int, output_layer: keras.engine.base_layer.Layer, args_model: Highway, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> HighwayModel

Get an instance of the highway model.

Args

input_size : int
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : Highway
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

HighwayModel
Highway model.
Expand source code
def get_highway_model(input_size: int,
                      output_layer: tf.keras.layers.Layer,
                      args_model: model_config.Highway,
                      preprocess: Optional[tf.keras.layers.Layer] = None) -> HighwayModel:
    """Build the highway model from its configuration.

    Args:
        input_size (int): Input size forwarded to the model.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (jidenn.config.model_config.Highway): Configuration supplying `layer_size`,
            `num_layers`, `dropout` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        HighwayModel: The constructed highway model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'layer_size': args_model.layer_size,
        'num_layers': args_model.num_layers,
        'dropout': args_model.dropout,
        'input_size': input_size,
        'output_layer': output_layer,
        'activation': get_activation(args_model.activation),
        'preprocess': preprocess,
    }
    return HighwayModel(**model_kwargs)
def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]], output_layer: keras.engine.base_layer.Layer, args_model: ParT, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> ParTModel

Get an instance of the ParT model.

Args

input_size : Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : model_config.ParT
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

ParTModel
ParT model.
Expand source code
def get_part_model(input_size: Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]],
                   output_layer: tf.keras.layers.Layer,
                   args_model: model_config.ParT,
                   preprocess: Optional[tf.keras.layers.Layer] = None) -> ParTModel:
    """Build the ParT model from its configuration.

    Args:
        input_size (Union[Tuple[None, int], Tuple[Tuple[None, int], Tuple[None, None, int]]]):
            Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.ParT): Configuration supplying the hyper-parameters and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        ParTModel: The constructed ParT model.
    """
    model_kwargs = {
        'input_shape': input_size,
        'output_layer': output_layer,
        # Hyper-parameters read from the config. Note the config spells the interaction
        # fields `interaction_embedding_*` while the model expects `interaction_embed_*`.
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'self_attn_layers': args_model.self_attn_layers,
        'class_attn_layers': args_model.class_attn_layers,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'dropout': args_model.dropout,
        'interaction_embed_layers': args_model.interaction_embedding_layers,
        'interaction_embed_layer_size': args_model.interaction_embedding_layer_size,
        # The config stores the activation by name; get_activation turns it into a callable.
        'preprocess': preprocess,
        'activation': get_activation(args_model.activation),
    }
    return ParTModel(**model_kwargs)
def get_pfn_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: PFN, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> PFNModel

Get an instance of the PFN model.

Args

input_size : Tuple[None, int]
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : model_config.PFN
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

PFNModel
PFN model.
Expand source code
def get_pfn_model(input_size: Tuple[None, int],
                  output_layer: tf.keras.layers.Layer,
                  args_model: model_config.PFN,
                  preprocess: Optional[tf.keras.layers.Layer] = None) -> PFNModel:
    """Build the PFN model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.PFN): Configuration supplying `Phi_sizes`, `F_sizes`,
            `Phi_backbone`, `batch_norm` and the `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        PFNModel: The constructed PFN model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'input_shape': input_size,
        'Phi_sizes': args_model.Phi_sizes,
        'F_sizes': args_model.F_sizes,
        'Phi_backbone': args_model.Phi_backbone,
        'activation': get_activation(args_model.activation),
        'batch_norm': args_model.batch_norm,
        'preprocess': preprocess,
        'output_layer': output_layer,
    }
    return PFNModel(**model_kwargs)
def get_transformer_model(input_size: Tuple[None, int], output_layer: keras.engine.base_layer.Layer, args_model: Transformer, preprocess: Optional[keras.engine.base_layer.Layer] = None) ‑> TransformerModel

Get an instance of the transformer model.

Args

input_size : Tuple[None, int]
Input size of the model.
output_layer : tf.keras.layers.Layer
Output layer of the model.
args_model : model_config.Transformer
Model configuration.
preprocess : Optional[tf.keras.layers.Layer], optional
Preprocessing layer. Defaults to None.

Returns

TransformerModel
Transformer model.
Expand source code
def get_transformer_model(input_size: Tuple[None, int],
                          output_layer: tf.keras.layers.Layer,
                          args_model: model_config.Transformer,
                          preprocess: Optional[tf.keras.layers.Layer] = None) -> TransformerModel:
    """Build the transformer model from its configuration.

    Args:
        input_size (Tuple[None, int]): Input shape, forwarded to the model as `input_shape`.
        output_layer (tf.keras.layers.Layer): Layer used as the model output.
        args_model (model_config.Transformer): Configuration supplying `embed_dim`,
            `embed_layers`, `dropout`, `expansion`, `heads`, `self_attn_layers` and the
            `activation` name.
        preprocess (Optional[tf.keras.layers.Layer], optional): Preprocessing layer forwarded
            to the model. Defaults to None.

    Returns:
        TransformerModel: The constructed transformer model.
    """
    # The config stores the activation by name; get_activation turns it into a callable.
    model_kwargs = {
        'embed_dim': args_model.embed_dim,
        'embed_layers': args_model.embed_layers,
        'dropout': args_model.dropout,
        'expansion': args_model.expansion,
        'heads': args_model.heads,
        'self_attn_layers': args_model.self_attn_layers,
        'activation': get_activation(args_model.activation),
        'input_shape': input_size,
        'output_layer': output_layer,
        'preprocess': preprocess,
    }
    return TransformerModel(**model_kwargs)
def model_getter_lookup(model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']) ‑> Callable[[Any, keras.engine.base_layer.Layer, Model, Optional[keras.engine.base_layer.Layer]], keras.engine.training.Model]

Get a model getter function.

Args

model_name : str
Name of the model. Options are 'fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt'.

Returns

Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]
Model getter function.
Expand source code
def model_getter_lookup(model_name: Literal['fc', 'highway', 'pfn', 'efn', 'transformer', 'part', 'depart', 'bdt']
                        ) -> Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
    """Return the getter function registered for a model name.

    Args:
        model_name (str): Name of the model. Options are 'fc', 'highway', 'pfn', 'efn',
            'transformer', 'part', 'depart', 'bdt'.

    Returns:
        Callable[[Any, tf.keras.layers.Layer, model_config.Model, Optional[tf.keras.layers.Layer]], tf.keras.Model]:
            Getter taking `(input_size, output_layer, args_model, preprocess)`.

    Raises:
        ValueError: If `model_name` is not one of the registered names.
    """
    getters = {
        'fc': get_FC_model,
        'highway': get_highway_model,
        'pfn': get_pfn_model,
        'efn': get_efn_model,
        'transformer': get_transformer_model,
        'part': get_part_model,
        'depart': get_depart_model,
        'bdt': get_bdt_model,
    }

    # No registered getter is None, so a None result can only mean an unknown name.
    getter = getters.get(model_name)
    if getter is None:
        raise ValueError(f'Unknown model {model_name}')
    return getter