Module jidenn.models.BDT
BDT model based on the `tensorflow_decision_forests` implementation of `tfdf.keras.GradientBoostedTreesModel`.
See https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel for more details.
Expand source code
"""
BDT model based on `tensorflow_decision_forests` implementation of `tfdf.keras.GradientBoostedTreesModel`.
See https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel for more details.
"""
import tensorflow_decision_forests as tfdf
from jidenn.config import model_config
def bdt_model(args_model: model_config.BDT) -> tfdf.keras.GradientBoostedTreesModel:
    """Construct a gradient boosted decision tree (BDT) model from a config.

    Most hyperparameters are read from ``args_model``. Three settings are
    pinned by this builder and cannot be changed through the config:
    ``early_stopping='NONE'``, ``verbose=2`` and
    ``loss='BINOMIAL_LOG_LIKELIHOOD'``.

    Args:
        args_model (model_config.BDT): BDT model config providing
            ``num_trees``, ``growing_strategy``, ``max_depth``,
            ``split_axis``, ``shrinkage``, ``min_examples``,
            ``num_threads``, ``l2_regularization``, ``max_num_nodes``
            and ``tmp_dir``.

    Returns:
        tfdf.keras.GradientBoostedTreesModel: the configured BDT model.
    """
    # Keyword arguments for the tfdf constructor, collected up front so the
    # config-driven values and the pinned ones are visible in one place.
    gbt_kwargs = {
        'num_trees': args_model.num_trees,
        'growing_strategy': args_model.growing_strategy,
        'max_depth': args_model.max_depth,
        'split_axis': args_model.split_axis,
        'early_stopping': 'NONE',  # pinned, not configurable
        'shrinkage': args_model.shrinkage,
        'min_examples': args_model.min_examples,
        'verbose': 2,  # pinned, not configurable
        'num_threads': args_model.num_threads,
        'l2_regularization': args_model.l2_regularization,
        'loss': 'BINOMIAL_LOG_LIKELIHOOD',  # pinned, not configurable
        'max_num_nodes': args_model.max_num_nodes,
        # The config calls this `tmp_dir`; tfdf calls it `temp_directory`.
        'temp_directory': args_model.tmp_dir,
    }
    return tfdf.keras.GradientBoostedTreesModel(**gbt_kwargs)
Functions
def bdt_model(args_model: BDT) ‑> tensorflow_decision_forests.keras.GradientBoostedTreesModel
-
Builds a BDT model
Args
args_model
:model_config.BDT
- BDT model config
Returns
tfdf.keras.GradientBoostedTreesModel
- BDT model
Expand source code
def bdt_model(args_model: model_config.BDT) -> tfdf.keras.GradientBoostedTreesModel:
    """Construct a gradient boosted decision tree (BDT) model from a config.

    Most hyperparameters are read from ``args_model``. Three settings are
    pinned by this builder and cannot be changed through the config:
    ``early_stopping='NONE'``, ``verbose=2`` and
    ``loss='BINOMIAL_LOG_LIKELIHOOD'``.

    Args:
        args_model (model_config.BDT): BDT model config providing
            ``num_trees``, ``growing_strategy``, ``max_depth``,
            ``split_axis``, ``shrinkage``, ``min_examples``,
            ``num_threads``, ``l2_regularization``, ``max_num_nodes``
            and ``tmp_dir``.

    Returns:
        tfdf.keras.GradientBoostedTreesModel: the configured BDT model.
    """
    # Keyword arguments for the tfdf constructor, collected up front so the
    # config-driven values and the pinned ones are visible in one place.
    gbt_kwargs = {
        'num_trees': args_model.num_trees,
        'growing_strategy': args_model.growing_strategy,
        'max_depth': args_model.max_depth,
        'split_axis': args_model.split_axis,
        'early_stopping': 'NONE',  # pinned, not configurable
        'shrinkage': args_model.shrinkage,
        'min_examples': args_model.min_examples,
        'verbose': 2,  # pinned, not configurable
        'num_threads': args_model.num_threads,
        'l2_regularization': args_model.l2_regularization,
        'loss': 'BINOMIAL_LOG_LIKELIHOOD',  # pinned, not configurable
        'max_num_nodes': args_model.max_num_nodes,
        # The config calls this `tmp_dir`; tfdf calls it `temp_directory`.
        'temp_directory': args_model.tmp_dir,
    }
    return tfdf.keras.GradientBoostedTreesModel(**gbt_kwargs)