
Module for reading ROOT files and converting them to Tensorflow tf.RaggedTensor or tf.Tensor objects. The module contains the ROOTDataset class, which is a wrapper of tf.data.Dataset. Its main purpose is to read ROOT files, convert them to Tensorflow tf.RaggedTensor or tf.Tensor objects, and then to a tf.data.Dataset object. It relies on the uproot package.

Two optional backends are available for converting ROOT files to Tensorflow objects: pandas and awkward.

"""
Module for reading ROOT files and converting them to Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects.
The module contains the `ROOTDataset` class, which is a wrapper of `tf.data.Dataset`.
Its main purpose is to read ROOT files, convert them to Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects,
and then to a `tf.data.Dataset` object. It relies on the `uproot` package.

Two optional backends are available for converting ROOT files to Tensorflow objects: `pandas` and `awkward`.
"""
from __future__ import annotations
import tensorflow as tf
from typing import Callable, List, Union, Dict, Optional, Literal
import pickle
import os
import uproot
import logging
import pandas as pd
import awkward as ak

ROOTVariables = Dict[str, Union[tf.RaggedTensor, tf.Tensor]]
"""Type alias for a dictionary of ROOT variables. The keys are the variable names and the values are the corresponding
Tensorflow `tf.RaggedTensor` or `tf.Tensor`.

Example:
```python
variables = {
    'jets_pt': tf.ragged.constant([[1, 2, 3, 4, 5], [2, 3]], dtype=tf.float32),
    'eventNumber': tf.constant([1, 2], dtype=tf.int32),
    ...
}
```
"""
def pandas_to_tensor(df: pd.Series) -> Union[tf.RaggedTensor, tf.Tensor]:
    """Converts a pandas `pd.Series` to a Tensorflow `tf.RaggedTensor` or `tf.Tensor`. The output is a `tf.RaggedTensor`
    if the Series has a multiple level index, otherwise it is a `tf.Tensor`. The number of levels of the index gives the
    number of dimensions of the output.

    The flat values are taken in the order they appear in `df`, while row lengths come from `groupby` (which sorts
    group keys). NOTE(review): this assumes the index is already sorted by its levels, as `ak.to_dataframe` produces.

    Args:
        df (pd.Series): pandas pd.Series to be converted. Can have a single or multiple level index (`pd.MultiIndex`).

    Returns:
        tf.RaggedTensor or tf.Tensor: `tf.RaggedTensor` if df has more than one index level, else `tf.Tensor`.
    """
    levels = df.index.nlevels
    if levels == 1:
        return tf.constant(df.values)
    # `size()` counts every row of a group; `count()` skips NaN values, which would give row lengths that no longer
    # add up to len(df) and, with validate=False, silently corrupt the ragged structure.
    if levels == 2:
        row_lengths = df.groupby(level=[0]).size()
        return tf.RaggedTensor.from_row_lengths(df.values, row_lengths.values, validate=False)
    max_level_group = list(range(levels - 1))
    # Innermost row lengths first: number of values per group of all index levels except the last one.
    nested_row_lengths = [df.groupby(level=max_level_group).size()]
    for i in range(1, levels - 1):
        # Grouping the previous lengths by one fewer level yields the row lengths of the next outer dimension.
        nested_row_lengths.append(
            nested_row_lengths[-1].groupby(level=max_level_group[:-i]).size())
    # `from_nested_row_lengths` expects the outermost dimension first.
    return tf.RaggedTensor.from_nested_row_lengths(
        df.values, nested_row_lengths=[lengths.values for lengths in nested_row_lengths[::-1]], validate=False)

def awkward_to_tensor(array: ak.Array) -> Union[tf.RaggedTensor, tf.Tensor]:
    """Converts an awkward `ak.Array` to a Tensorflow `tf.RaggedTensor` or `tf.Tensor`. The output is a `tf.RaggedTensor`
    if the array has more than one dimension, otherwise it is a `tf.Tensor`. The number of dimensions of the array
    gives the number of dimensions of the output.

    Values are converted through Python lists (`to_list`), so the resulting dtype is whatever `tf.constant` infers
    from Python scalars.

    Args:
        array (ak.Array): awkward ak.Array to be converted. Can have one or more dimensions.

    Returns:
        tf.RaggedTensor or tf.Tensor: `tf.RaggedTensor` if the array has more than one dimension, else `tf.Tensor`.
    """
    if array.ndim == 1:
        return tf.constant(array.to_list())
    flat_values = ak.flatten(array, axis=None).to_list()
    if array.ndim == 2:
        row_lengths = ak.num(array, axis=1).to_list()
        return tf.RaggedTensor.from_row_lengths(flat_values, row_lengths=row_lengths, validate=False)
    # One row-length vector per ragged dimension, outermost first, as `from_nested_row_lengths` expects.
    nested_row_lengths = [ak.flatten(ak.num(array, axis=ax), axis=None).to_list()
                          for ax in range(1, array.ndim)]
    return tf.RaggedTensor.from_nested_row_lengths(
        flat_values, nested_row_lengths=nested_row_lengths, validate=False)

def read_ttree(tree: uproot.TTree, backend: Literal['pd', 'ak'] = 'pd', downcast: bool = True) -> ROOTVariables:
    """Reads a ROOT TTree and returns a dictionary of Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects. The keys are
    the variable names and the values are read from the TTree. The TTree is converted one variable at a time.
    Branches that contain no values at all are skipped.

    Args:
        tree (uproot.TTree): ROOT TTree to be read.
        backend (str, optional): 'pd' or 'ak'. Backend to use for reading the TTree, 'pd' is faster but consumes more memory. Defaults to 'pd'.
        downcast (bool, optional): Downcast the output to `tf.float32`, `tf.int32` or `tf.uint32`. Defaults to True.

    Raises:
        ValueError: If backend is not 'pd' or 'ak'.

    Returns:
        ROOTVariables: Dictionary of Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects. The keys are the variable names and the values read from the TTree.
    """
    if backend not in ('pd', 'ak'):
        raise ValueError(
            f'Backend {backend} not supported. Choose from pd (pandas) or ak (awkward).')
    output = {}
    for var in tree.keys():
        var_branch = tree[var].array(library="ak")
        # Skip branches without a single value (e.g. made only of empty lists).
        if ak.num(ak.flatten(var_branch, axis=None), axis=0) == 0:
            continue
        if backend == 'ak':
            tensor = awkward_to_tensor(var_branch)
        else:
            # NOTE(review): ak.to_dataframe emits no rows for empty lists, so events without entries in a jagged
            # branch may be missing from the resulting tensor - confirm this is acceptable for the inputs used.
            var_df = ak.to_dataframe(var_branch)
            if var_df.empty:
                continue
            tensor = pandas_to_tensor(var_df['values'])

        if downcast:
            # Narrow 64-bit dtypes to their 32-bit counterparts.
            if tensor.dtype == tf.float64:
                tensor = tf.cast(tensor, tf.float32)
            elif tensor.dtype == tf.int64:
                tensor = tf.cast(tensor, tf.int32)
            elif tensor.dtype == tf.uint64:
                tensor = tf.cast(tensor, tf.uint32)

        output[var] = tensor
        logging.info('%s: %s %s', var, tensor.shape, tensor.dtype)
    return output

class ROOTDataset:
    """Class to read a ROOT file and return a `tf.data.Dataset` object. The dataset contains a dictionary of Tensorflow
    `tf.RaggedTensor` or `tf.Tensor` objects. The keys are the variable names and the values read from the TTree.
    The `.root` files are read using `uproot` and the `tf.data.Dataset` is created using
    `tf.data.Dataset.from_tensor_slices`.

    The ROOT file is read one variable at a time, so the memory consumption may be high for large files. More precisely,
    the process of creating the `tf.data.Dataset` from a dictionary of `tf.RaggedTensor` or `tf.Tensor` objects
    **consumes a lot of memory**. This is done as a trade-off for higher conversion speed.

    Example:
    ```python
    import tensorflow as tf

    root_file = 'path/to/file.root'
    save_path = 'path/to/save/dataset'
    root_dataset = ROOTDataset.from_root_file(root_file)
    root_dataset.save(save_path)
    root_dataset = ROOTDataset.load(save_path)
    dataset = root_dataset.dataset
    # Use as a training dataset
    ```

    The initialization is only a convenience method. The `ROOTDataset.from_root_file` or `ROOTDataset.from_root_files`
    methods should be used for creating a `ROOTDataset` object instead.

    Args:
        dataset (tf.data.Dataset): Tensorflow `tf.data.Dataset` object.
        variables (list[str]): List of variable names.
    """

    def __init__(self, dataset: tf.data.Dataset, variables: List[str]):
        self._variables = variables
        self._dataset = dataset

    @property
    def variables(self) -> List[str]:
        """List of variable names inferred from the ROOT file."""
        return self._variables

    @property
    def dataset(self) -> tf.data.Dataset:
        """Tensorflow `tf.data.Dataset` object created from the ROOT file."""
        return self._dataset

    @classmethod
    def from_root_file(cls, filename: str,
                       tree_name: str = 'NOMINAL',
                       metadata_hist: Optional[str] = 'h_metadata',
                       backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
        """Reads a ROOT file and returns a `ROOTDataset` object. The file is closed once all variables have been
        converted to tensors.

        Args:
            filename (str): Path to the ROOT file.
            tree_name (str, optional): Name of the TTree in the ROOT file. Defaults to 'NOMINAL'.
            metadata_hist (str, optional): Name of the histogram containing the metadata. Defaults to 'h_metadata'.
                Could be `None`. When given, its bin contents are repeated once per event under the `'metadata'` key.
            backend (str, optional): 'pd' or 'ak'. Backend to use for reading the TTree, 'pd' is faster but consumes more memory. Defaults to 'pd'.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        # All data is materialized into tensors inside the block, so the file handle can be released afterwards.
        with uproot.open(filename, object_cache=None, array_cache=None) as file:
            tree = file[tree_name]
            logging.info("Loading ROOT file %s", filename)
            sample = read_ttree(tree, backend=backend)

            if metadata_hist is not None:
                logging.info("Getting metadata")
                metadata = file[metadata_hist].values()
                # Event count from `eventNumber` when available, otherwise from the TTree itself.
                n_events = sample['eventNumber'].shape[0] if 'eventNumber' in sample else tree.num_entries
                # One copy of the metadata row per event so it is sliced together with the other variables.
                sample['metadata'] = tf.tile(tf.constant(metadata)[tf.newaxis, :], [n_events, 1])

        logging.info("Done loading file:%s", filename)
        dataset = tf.data.Dataset.from_tensor_slices(sample)
        return cls(dataset, list(sample.keys()))

    @classmethod
    def concat(cls, datasets: List[ROOTDataset]) -> ROOTDataset:
        """Concatenates a list of `ROOTDataset` objects. Data samples are sequentially concatenated using
        `tf.data.Dataset.concatenate`.

        Args:
            datasets (list[ROOTDataset]): List of `ROOTDataset` objects.

        Raises:
            ValueError: If `datasets` is empty or the variables of the datasets do not match.

        Returns:
            ROOTDataset: Combined `ROOTDataset` object.
        """
        if not datasets:
            raise ValueError("Cannot concatenate an empty list of datasets")
        reference_variables = datasets[0].variables
        if any(ds.variables != reference_variables for ds in datasets[1:]):
            raise ValueError("Variables of datasets do not match")
        final_dataset = datasets[0]._dataset
        for ds in datasets[1:]:
            final_dataset = final_dataset.concatenate(ds._dataset)
        return cls(final_dataset, reference_variables)

    @classmethod
    def from_root_files(cls, filenames: Union[List[str], str],
                        tree_name: str = 'NOMINAL',
                        metadata_hist: Optional[str] = 'h_metadata',
                        backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
        """Reads a list of ROOT files and returns a `ROOTDataset` object. Can also be used to read a single file.

        Args:
            filenames (list[str] or str): List of paths to the ROOT files or a single path to a ROOT file.
            tree_name (str, optional): Name of the TTree in each ROOT file. Defaults to 'NOMINAL'.
            metadata_hist (str, optional): Name of the metadata histogram, or `None`. Defaults to 'h_metadata'.
            backend (str, optional): 'pd' or 'ak'. Backend used for reading each TTree. Defaults to 'pd'.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        if isinstance(filenames, str):
            filenames = [filenames]
        return cls.concat([cls.from_root_file(filename, tree_name=tree_name,
                                              metadata_hist=metadata_hist, backend=backend)
                           for filename in filenames])

    @classmethod
    def load(cls, file: str, element_spec_path: Optional[str] = None) -> ROOTDataset:
        """Loads a `ROOTDataset` object from a saved directory. The saved object is a `tf.data.Dataset` object
        saved using `tf.data.experimental.save` with GZIP compression. The `element_spec` is loaded separately as a
        pickle object and is used to create the `tf.data.Dataset` object. Defaults to the `element_spec` file inside
        the saved directory. Optionally, the `element_spec_path` can be passed as an argument as full path.

        Example:
        Example of creating a `ROOTDataset` object from a saved `tf.data.Dataset` object.
        ```python
        import os
        import pickle
        import tensorflow as tf

        file_path = 'path/to/save/dataset'
        dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 'b': [4, 5, 6]})
        tf.data.experimental.save(dataset, file_path, compression='GZIP')
        with open(os.path.join(file_path, 'element_spec'), 'wb') as f:
            pickle.dump(dataset.element_spec, f)

        root_dataset = ROOTDataset.load(file_path)
        ```

        Args:
            file (str): Path to the saved directory.
            element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to `element_spec` file inside the saved directory.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        if element_spec_path is None:
            element_spec_path = os.path.join(file, 'element_spec')
        # NOTE: pickle.load can execute arbitrary code; only load element specs from trusted locations.
        with open(element_spec_path, 'rb') as f:
            element_spec = pickle.load(f)
        dataset = tf.data.experimental.load(
            file, compression='GZIP', element_spec=element_spec)
        return cls(dataset, list(element_spec.keys()))

    def save(self, save_path: str, element_spec_path: Optional[str] = None,
             shard_func: Optional[Callable[[ROOTVariables], tf.Tensor]] = None) -> None:
        """Saves a `ROOTDataset` object to a directory. The saved object is a `tf.data.Dataset` object written with
        GZIP compression, and the `element_spec` is saved separately as a pickle object inside the saved directory.

        Args:
            save_path (str): Path to the directory where the object is to be saved.
            element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to `element_spec` file inside the saved directory.
            shard_func (Callable, optional): Function to shard the dataset. Used as a `shard_func` argument in `tf.data.experimental.save`. Defaults to `None`.
        """
        if element_spec_path is None:
            element_spec_path = os.path.join(save_path, 'element_spec')
        element_spec = self._dataset.element_spec
        tf.data.experimental.save(self._dataset, save_path, compression='GZIP', shard_func=shard_func)
        with open(element_spec_path, 'wb') as f:
            pickle.dump(element_spec, f)

    def map(self, func: Callable[[ROOTVariables], ROOTVariables],
            num_parallel_calls: Optional[int] = None) -> ROOTDataset:
        """Maps a function over the dataset. The function should take a `ROOTVariables` object as input and return a
        `ROOTVariables` object as output. The mapped dataset is prefetched with `tf.data.AUTOTUNE`.

        Args:
            func (Callable): Function to be mapped.
            num_parallel_calls (int, optional): Number of parallel calls to use, forwarded to `tf.data.Dataset.map`.
                Defaults to `None` (elements are mapped sequentially).

        Returns:
            ROOTDataset: Mapped `ROOTDataset` object.
        """
        new_ds = self._dataset.map(func, num_parallel_calls=num_parallel_calls)
        new_ds = new_ds.prefetch(tf.data.AUTOTUNE)
        return type(self)(new_ds, list(new_ds.element_spec.keys()))

Global variables

var ROOTVariables

Type alias for a dictionary of ROOT variables. The keys are the variable names and the values are the corresponding Tensorflow tf.RaggedTensor or tf.Tensor.

Example:

variables = {
    'jets_pt': tf.ragged.constant([[1, 2, 3, 4, 5], [2, 3]], dtype=tf.float32),
    'eventNumber': tf.constant([1, 2], dtype=tf.int32),
    ...
}


def awkward_to_tensor(array: ak.Array) -> Union[tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor, tensorflow.python.framework.ops.Tensor]

Converts an awkward ak.Array to a Tensorflow tf.RaggedTensor or tf.Tensor. The output is a tf.RaggedTensor if the array has more than one dimension, otherwise it is a tf.Tensor. The number of dimensions of the array gives the number of dimensions of the output.

Args

array : ak.Array
awkward ak.Array to be converted. Can have one or more dimensions.

Returns

tf.RaggedTensor or tf.Tensor
tf.RaggedTensor if the array has more than one dimension, else tf.Tensor.
def awkward_to_tensor(array: ak.Array) -> Union[tf.RaggedTensor, tf.Tensor]:
    """Converts an awkward `ak.Array` to a Tensorflow `tf.RaggedTensor` or `tf.Tensor`. The output is a `tf.RaggedTensor`
    if the array has more than one dimension, otherwise it is a `tf.Tensor`. The number of dimensions of the array
    gives the number of dimensions of the output.

    Values are converted through Python lists (`to_list`), so the resulting dtype is whatever `tf.constant` infers
    from Python scalars.

    Args:
        array (ak.Array): awkward ak.Array to be converted. Can have one or more dimensions.

    Returns:
        tf.RaggedTensor or tf.Tensor: `tf.RaggedTensor` if the array has more than one dimension, else `tf.Tensor`.
    """
    if array.ndim == 1:
        return tf.constant(array.to_list())
    flat_values = ak.flatten(array, axis=None).to_list()
    if array.ndim == 2:
        row_lengths = ak.num(array, axis=1).to_list()
        return tf.RaggedTensor.from_row_lengths(flat_values, row_lengths=row_lengths, validate=False)
    # One row-length vector per ragged dimension, outermost first, as `from_nested_row_lengths` expects.
    nested_row_lengths = [ak.flatten(ak.num(array, axis=ax), axis=None).to_list()
                          for ax in range(1, array.ndim)]
    return tf.RaggedTensor.from_nested_row_lengths(
        flat_values, nested_row_lengths=nested_row_lengths, validate=False)
def pandas_to_tensor(df: pd.Series) -> Union[tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor, tensorflow.python.framework.ops.Tensor]

Converts a pandas pd.Series to a Tensorflow tf.RaggedTensor or tf.Tensor. The output is a tf.RaggedTensor if the Series has a multiple level index, otherwise it is a tf.Tensor. The number of levels of the index gives the number of dimensions of the output.

Args

df : pd.Series
pandas pd.Series to be converted. Can have a single or multiple level index (pd.MultiIndex).

Returns

tf.RaggedTensor or tf.Tensor
tf.RaggedTensor if df has more than one index level, else tf.Tensor.
def pandas_to_tensor(df: pd.Series) -> Union[tf.RaggedTensor, tf.Tensor]:
    """Converts a pandas `pd.Series` to a Tensorflow `tf.RaggedTensor` or `tf.Tensor`. The output is a `tf.RaggedTensor`
    if the Series has a multiple level index, otherwise it is a `tf.Tensor`. The number of levels of the index gives the
    number of dimensions of the output.

    The flat values are taken in the order they appear in `df`, while row lengths come from `groupby` (which sorts
    group keys). NOTE(review): this assumes the index is already sorted by its levels, as `ak.to_dataframe` produces.

    Args:
        df (pd.Series): pandas pd.Series to be converted. Can have a single or multiple level index (`pd.MultiIndex`).

    Returns:
        tf.RaggedTensor or tf.Tensor: `tf.RaggedTensor` if df has more than one index level, else `tf.Tensor`.
    """
    levels = df.index.nlevels
    if levels == 1:
        return tf.constant(df.values)
    # `size()` counts every row of a group; `count()` skips NaN values, which would give row lengths that no longer
    # add up to len(df) and, with validate=False, silently corrupt the ragged structure.
    if levels == 2:
        row_lengths = df.groupby(level=[0]).size()
        return tf.RaggedTensor.from_row_lengths(df.values, row_lengths.values, validate=False)
    max_level_group = list(range(levels - 1))
    # Innermost row lengths first: number of values per group of all index levels except the last one.
    nested_row_lengths = [df.groupby(level=max_level_group).size()]
    for i in range(1, levels - 1):
        # Grouping the previous lengths by one fewer level yields the row lengths of the next outer dimension.
        nested_row_lengths.append(
            nested_row_lengths[-1].groupby(level=max_level_group[:-i]).size())
    # `from_nested_row_lengths` expects the outermost dimension first.
    return tf.RaggedTensor.from_nested_row_lengths(
        df.values, nested_row_lengths=[lengths.values for lengths in nested_row_lengths[::-1]], validate=False)
def read_ttree(tree: uproot.TTree, backend: "Literal['pd', 'ak']" = 'pd', downcast: bool = True) -> Dict[str, Union[tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor, tensorflow.python.framework.ops.Tensor]]

Reads a ROOT TTree and returns a dictionary of Tensorflow tf.RaggedTensor or tf.Tensor objects. The keys are the variable names and the values are read from the TTree. The TTree is converted one variable at a time.

Args

tree : uproot.TTree
ROOT TTree to be read.
backend : str, optional
'pd' or 'ak'. Backend to use for reading the TTree; 'pd' is faster but consumes more memory. Defaults to 'pd'.
downcast : bool, optional
Downcast the output to tf.float32, tf.int32 or tf.uint32. Defaults to True.

Raises

ValueError
If backend is not 'pd' or 'ak'.

Returns

ROOTVariables
Dictionary of Tensorflow tf.RaggedTensor or tf.Tensor objects. The keys are the variable names and the values are read from the TTree.
def read_ttree(tree: uproot.TTree, backend: Literal['pd', 'ak'] = 'pd', downcast: bool = True) -> ROOTVariables:
    """Reads a ROOT TTree and returns a dictionary of Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects. The keys are
    the variable names and the values are read from the TTree. The TTree is converted one variable at a time.
    Branches that contain no values at all are skipped.

    Args:
        tree (uproot.TTree): ROOT TTree to be read.
        backend (str, optional): 'pd' or 'ak'. Backend to use for reading the TTree, 'pd' is faster but consumes more memory. Defaults to 'pd'.
        downcast (bool, optional): Downcast the output to `tf.float32`, `tf.int32` or `tf.uint32`. Defaults to True.

    Raises:
        ValueError: If backend is not 'pd' or 'ak'.

    Returns:
        ROOTVariables: Dictionary of Tensorflow `tf.RaggedTensor` or `tf.Tensor` objects. The keys are the variable names and the values read from the TTree.
    """
    if backend not in ('pd', 'ak'):
        raise ValueError(
            f'Backend {backend} not supported. Choose from pd (pandas) or ak (awkward).')
    output = {}
    for var in tree.keys():
        var_branch = tree[var].array(library="ak")
        # Skip branches without a single value (e.g. made only of empty lists).
        if ak.num(ak.flatten(var_branch, axis=None), axis=0) == 0:
            continue
        if backend == 'ak':
            tensor = awkward_to_tensor(var_branch)
        else:
            # NOTE(review): ak.to_dataframe emits no rows for empty lists, so events without entries in a jagged
            # branch may be missing from the resulting tensor - confirm this is acceptable for the inputs used.
            var_df = ak.to_dataframe(var_branch)
            if var_df.empty:
                continue
            tensor = pandas_to_tensor(var_df['values'])

        if downcast:
            # Narrow 64-bit dtypes to their 32-bit counterparts.
            if tensor.dtype == tf.float64:
                tensor = tf.cast(tensor, tf.float32)
            elif tensor.dtype == tf.int64:
                tensor = tf.cast(tensor, tf.int32)
            elif tensor.dtype == tf.uint64:
                tensor = tf.cast(tensor, tf.uint32)

        output[var] = tensor
        logging.info('%s: %s %s', var, tensor.shape, tensor.dtype)
    return output


class ROOTDataset (dataset: tf.data.Dataset, variables: List[str])

Class to read a ROOT file and return a tf.data.Dataset object. The dataset contains a dictionary of Tensorflow tf.RaggedTensor or tf.Tensor objects. The keys are the variable names and the values are read from the TTree. The .root files are read using uproot and the tf.data.Dataset is created using tf.data.Dataset.from_tensor_slices.

The ROOT file is read one variable at a time, so the memory consumption may be high for large files. More precisely, the process of creating the tf.data.Dataset from a dictionary of tf.RaggedTensor or tf.Tensor objects consumes a lot of memory. This is done as a trade-off for higher conversion speed.

Example:

import tensorflow as tf

root_file = 'path/to/file.root'
save_path = 'path/to/save/dataset'
root_dataset = ROOTDataset.from_root_file(root_file)
root_dataset.save(save_path)
root_dataset = ROOTDataset.load(save_path)
dataset = root_dataset.dataset
# Use as a training dataset

The initialization is only a convenience method. The ROOTDataset.from_root_file() or ROOTDataset.from_root_files() methods should be used for creating a ROOTDataset object instead.

Args

dataset : tf.data.Dataset
Tensorflow tf.data.Dataset object.
variables : list[str]
List of variable names.
class ROOTDataset:
    """Class to read a ROOT file and return a `tf.data.Dataset` object. The dataset contains a dictionary of Tensorflow
    `tf.RaggedTensor` or `tf.Tensor` objects. The keys are the variable names and the values read from the TTree.
    The `.root` files are read using `uproot` and the `tf.data.Dataset` is created using
    `tf.data.Dataset.from_tensor_slices`.

    The ROOT file is read one variable at a time, so the memory consumption may be high for large files. More precisely,
    the process of creating the `tf.data.Dataset` from a dictionary of `tf.RaggedTensor` or `tf.Tensor` objects
    **consumes a lot of memory**. This is done as a trade-off for higher conversion speed.

    Example:
    ```python
    import tensorflow as tf

    root_file = 'path/to/file.root'
    save_path = 'path/to/save/dataset'
    root_dataset = ROOTDataset.from_root_file(root_file)
    root_dataset.save(save_path)
    root_dataset = ROOTDataset.load(save_path)
    dataset = root_dataset.dataset
    # Use as a training dataset
    ```

    The initialization is only a convenience method. The `ROOTDataset.from_root_file` or `ROOTDataset.from_root_files`
    methods should be used for creating a `ROOTDataset` object instead.

    Args:
        dataset (tf.data.Dataset): Tensorflow `tf.data.Dataset` object.
        variables (list[str]): List of variable names.
    """

    def __init__(self, dataset: tf.data.Dataset, variables: List[str]):
        self._variables = variables
        self._dataset = dataset

    @property
    def variables(self) -> List[str]:
        """List of variable names inferred from the ROOT file."""
        return self._variables

    @property
    def dataset(self) -> tf.data.Dataset:
        """Tensorflow `tf.data.Dataset` object created from the ROOT file."""
        return self._dataset

    @classmethod
    def from_root_file(cls, filename: str,
                       tree_name: str = 'NOMINAL',
                       metadata_hist: Optional[str] = 'h_metadata',
                       backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
        """Reads a ROOT file and returns a `ROOTDataset` object. The file is closed once all variables have been
        converted to tensors.

        Args:
            filename (str): Path to the ROOT file.
            tree_name (str, optional): Name of the TTree in the ROOT file. Defaults to 'NOMINAL'.
            metadata_hist (str, optional): Name of the histogram containing the metadata. Defaults to 'h_metadata'.
                Could be `None`. When given, its bin contents are repeated once per event under the `'metadata'` key.
            backend (str, optional): 'pd' or 'ak'. Backend to use for reading the TTree, 'pd' is faster but consumes more memory. Defaults to 'pd'.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        # All data is materialized into tensors inside the block, so the file handle can be released afterwards.
        with uproot.open(filename, object_cache=None, array_cache=None) as file:
            tree = file[tree_name]
            logging.info("Loading ROOT file %s", filename)
            sample = read_ttree(tree, backend=backend)

            if metadata_hist is not None:
                logging.info("Getting metadata")
                metadata = file[metadata_hist].values()
                # Event count from `eventNumber` when available, otherwise from the TTree itself.
                n_events = sample['eventNumber'].shape[0] if 'eventNumber' in sample else tree.num_entries
                # One copy of the metadata row per event so it is sliced together with the other variables.
                sample['metadata'] = tf.tile(tf.constant(metadata)[tf.newaxis, :], [n_events, 1])

        logging.info("Done loading file:%s", filename)
        dataset = tf.data.Dataset.from_tensor_slices(sample)
        return cls(dataset, list(sample.keys()))

    @classmethod
    def concat(cls, datasets: List[ROOTDataset]) -> ROOTDataset:
        """Concatenates a list of `ROOTDataset` objects. Data samples are sequentially concatenated using
        `tf.data.Dataset.concatenate`.

        Args:
            datasets (list[ROOTDataset]): List of `ROOTDataset` objects.

        Raises:
            ValueError: If `datasets` is empty or the variables of the datasets do not match.

        Returns:
            ROOTDataset: Combined `ROOTDataset` object.
        """
        if not datasets:
            raise ValueError("Cannot concatenate an empty list of datasets")
        reference_variables = datasets[0].variables
        if any(ds.variables != reference_variables for ds in datasets[1:]):
            raise ValueError("Variables of datasets do not match")
        final_dataset = datasets[0]._dataset
        for ds in datasets[1:]:
            final_dataset = final_dataset.concatenate(ds._dataset)
        return cls(final_dataset, reference_variables)

    @classmethod
    def from_root_files(cls, filenames: Union[List[str], str],
                        tree_name: str = 'NOMINAL',
                        metadata_hist: Optional[str] = 'h_metadata',
                        backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
        """Reads a list of ROOT files and returns a `ROOTDataset` object. Can also be used to read a single file.

        Args:
            filenames (list[str] or str): List of paths to the ROOT files or a single path to a ROOT file.
            tree_name (str, optional): Name of the TTree in each ROOT file. Defaults to 'NOMINAL'.
            metadata_hist (str, optional): Name of the metadata histogram, or `None`. Defaults to 'h_metadata'.
            backend (str, optional): 'pd' or 'ak'. Backend used for reading each TTree. Defaults to 'pd'.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        if isinstance(filenames, str):
            filenames = [filenames]
        return cls.concat([cls.from_root_file(filename, tree_name=tree_name,
                                              metadata_hist=metadata_hist, backend=backend)
                           for filename in filenames])

    @classmethod
    def load(cls, file: str, element_spec_path: Optional[str] = None) -> ROOTDataset:
        """Loads a `ROOTDataset` object from a saved directory. The saved object is a `tf.data.Dataset` object
        saved using `tf.data.experimental.save` with GZIP compression. The `element_spec` is loaded separately as a
        pickle object and is used to create the `tf.data.Dataset` object. Defaults to the `element_spec` file inside
        the saved directory. Optionally, the `element_spec_path` can be passed as an argument as full path.

        Example:
        Example of creating a `ROOTDataset` object from a saved `tf.data.Dataset` object.
        ```python
        import os
        import pickle
        import tensorflow as tf

        file_path = 'path/to/save/dataset'
        dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 'b': [4, 5, 6]})
        tf.data.experimental.save(dataset, file_path, compression='GZIP')
        with open(os.path.join(file_path, 'element_spec'), 'wb') as f:
            pickle.dump(dataset.element_spec, f)

        root_dataset = ROOTDataset.load(file_path)
        ```

        Args:
            file (str): Path to the saved directory.
            element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to `element_spec` file inside the saved directory.

        Returns:
            ROOTDataset: `ROOTDataset` object.
        """
        if element_spec_path is None:
            element_spec_path = os.path.join(file, 'element_spec')
        # NOTE: pickle.load can execute arbitrary code; only load element specs from trusted locations.
        with open(element_spec_path, 'rb') as f:
            element_spec = pickle.load(f)
        dataset = tf.data.experimental.load(
            file, compression='GZIP', element_spec=element_spec)
        return cls(dataset, list(element_spec.keys()))

    def save(self, save_path: str, element_spec_path: Optional[str] = None,
             shard_func: Optional[Callable[[ROOTVariables], tf.Tensor]] = None) -> None:
        """Saves a `ROOTDataset` object to a directory. The saved object is a `tf.data.Dataset` object written with
        GZIP compression, and the `element_spec` is saved separately as a pickle object inside the saved directory.

        Args:
            save_path (str): Path to the directory where the object is to be saved.
            element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to `element_spec` file inside the saved directory.
            shard_func (Callable, optional): Function to shard the dataset. Used as a `shard_func` argument in `tf.data.experimental.save`. Defaults to `None`.
        """
        if element_spec_path is None:
            element_spec_path = os.path.join(save_path, 'element_spec')
        element_spec = self._dataset.element_spec
        tf.data.experimental.save(self._dataset, save_path, compression='GZIP', shard_func=shard_func)
        with open(element_spec_path, 'wb') as f:
            pickle.dump(element_spec, f)

    def map(self, func: Callable[[ROOTVariables], ROOTVariables],
            num_parallel_calls: Optional[int] = None) -> ROOTDataset:
        """Maps a function over the dataset. The function should take a `ROOTVariables` object as input and return a
        `ROOTVariables` object as output. The mapped dataset is prefetched with `tf.data.AUTOTUNE`.

        Args:
            func (Callable): Function to be mapped.
            num_parallel_calls (int, optional): Number of parallel calls to use, forwarded to `tf.data.Dataset.map`.
                Defaults to `None` (elements are mapped sequentially).

        Returns:
            ROOTDataset: Mapped `ROOTDataset` object.
        """
        new_ds = self._dataset.map(func, num_parallel_calls=num_parallel_calls)
        new_ds = new_ds.prefetch(tf.data.AUTOTUNE)
        return type(self)(new_ds, list(new_ds.element_spec.keys()))

Static methods

def concat(datasets: List[ROOTDataset]) -> ROOTDataset

Concatenates a list of ROOTDataset objects. Data samples are sequentially concatenated using tf.data.Dataset.concatenate.

Args

datasets : list[ROOTDataset]
List of ROOTDataset objects.

Raises

ValueError
If the variables of the datasets do not match.

Returns

ROOTDataset
Combined ROOTDataset object.
def concat(cls, datasets: List[ROOTDataset]) -> ROOTDataset:
    """Concatenates a list of `ROOTDataset` objects. Data samples are sequentially concatenated using
    `tf.data.Dataset.concatenate`.

    Args:
        datasets (list[ROOTDataset]): List of `ROOTDataset` objects.

    Raises:
        ValueError: If `datasets` is empty or the variables of the datasets do not match.

    Returns:
        ROOTDataset: Combined `ROOTDataset` object.
    """
    if not datasets:
        raise ValueError("Cannot concatenate an empty list of datasets")
    reference_variables = datasets[0].variables
    if any(ds.variables != reference_variables for ds in datasets[1:]):
        raise ValueError("Variables of datasets do not match")
    final_dataset = datasets[0]._dataset
    for ds in datasets[1:]:
        final_dataset = final_dataset.concatenate(ds._dataset)
    return cls(final_dataset, reference_variables)
def from_root_file(filename: str, tree_name: str = 'NOMINAL', metadata_hist: Optional[str] = 'h_metadata', backend: "Literal['pd', 'ak']" = 'pd') -> ROOTDataset

Reads a ROOT file and returns a ROOTDataset object.

Args

filename : str
Path to the ROOT file.
tree_name : str, optional
Name of the TTree in the ROOT file. Defaults to 'NOMINAL'.
metadata_hist : str, optional
Name of the histogram containing the metadata. Defaults to 'h_metadata'. Could be None.
backend : str, optional
'pd' or 'ak'. Backend to use for reading the TTree; 'pd' is faster but consumes more memory. Defaults to 'pd'.

Returns

ROOTDataset
ROOTDataset object.
def from_root_file(cls, filename: str,
                   tree_name: str = 'NOMINAL',
                   metadata_hist: Optional[str] = 'h_metadata',
                   backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
    """Reads a ROOT file and returns a `ROOTDataset` object. The file is closed once all variables have been
    converted to tensors.

    Args:
        filename (str): Path to the ROOT file.
        tree_name (str, optional): Name of the TTree in the ROOT file. Defaults to 'NOMINAL'.
        metadata_hist (str, optional): Name of the histogram containing the metadata. Defaults to 'h_metadata'.
            Could be `None`. When given, its bin contents are repeated once per event under the `'metadata'` key.
        backend (str, optional): 'pd' or 'ak'. Backend to use for reading the TTree, 'pd' is faster but consumes more memory. Defaults to 'pd'.

    Returns:
        ROOTDataset: `ROOTDataset` object.
    """
    # All data is materialized into tensors inside the block, so the file handle can be released afterwards.
    with uproot.open(filename, object_cache=None, array_cache=None) as file:
        tree = file[tree_name]
        logging.info("Loading ROOT file %s", filename)
        sample = read_ttree(tree, backend=backend)

        if metadata_hist is not None:
            logging.info("Getting metadata")
            metadata = file[metadata_hist].values()
            # Event count from `eventNumber` when available, otherwise from the TTree itself.
            n_events = sample['eventNumber'].shape[0] if 'eventNumber' in sample else tree.num_entries
            # One copy of the metadata row per event so it is sliced together with the other variables.
            sample['metadata'] = tf.tile(tf.constant(metadata)[tf.newaxis, :], [n_events, 1])

    logging.info("Done loading file:%s", filename)
    dataset = tf.data.Dataset.from_tensor_slices(sample)
    return cls(dataset, list(sample.keys()))
def from_root_files(filenames: Union[List[str], str]) -> ROOTDataset

Reads a list of ROOT files and returns a ROOTDataset object. Can also be used to read a single file.

Args

filenames : list[str] or str
List of paths to the ROOT files or a single path to a ROOT file.

Returns

ROOTDataset
ROOTDataset object.
def from_root_files(cls, filenames: Union[List[str], str],
                    tree_name: str = 'NOMINAL',
                    metadata_hist: Optional[str] = 'h_metadata',
                    backend: Literal['pd', 'ak'] = 'pd') -> ROOTDataset:
    """Reads a list of ROOT files and returns a `ROOTDataset` object. Can also be used to read a single file.

    Args:
        filenames (list[str] or str): List of paths to the ROOT files or a single path to a ROOT file.
        tree_name (str, optional): Name of the TTree in each ROOT file. Defaults to 'NOMINAL'.
        metadata_hist (str, optional): Name of the metadata histogram, or `None`. Defaults to 'h_metadata'.
        backend (str, optional): 'pd' or 'ak'. Backend used for reading each TTree. Defaults to 'pd'.

    Returns:
        ROOTDataset: `ROOTDataset` object.
    """
    if isinstance(filenames, str):
        filenames = [filenames]
    return cls.concat([cls.from_root_file(filename, tree_name=tree_name,
                                          metadata_hist=metadata_hist, backend=backend)
                       for filename in filenames])
def load(file: str, element_spec_path: Optional[str] = None) -> ROOTDataset

Loads a ROOTDataset object from a saved directory. The saved object is a tf.data.Dataset object saved using tf.data.experimental.save with GZIP compression. The element_spec is loaded separately as a pickle object and is used to create the tf.data.Dataset object. Defaults to the element_spec file inside the saved directory. Optionally, the element_spec_path can be passed as an argument as full path.

Example: Example of creating a ROOTDataset object from a saved tf.data.Dataset object.

import os
import pickle
import tensorflow as tf

file_path = 'path/to/save/dataset'
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 'b': [4, 5, 6]})
tf.data.experimental.save(dataset, file_path, compression='GZIP')
with open(os.path.join(file_path, 'element_spec'), 'wb') as f:
    pickle.dump(dataset.element_spec, f)

root_dataset = ROOTDataset.load(file_path)

Args

file : str
Path to the saved directory.
element_spec_path : str, optional
Path to the saved element_spec as a pickle file. Defaults to element_spec file inside the saved directory.

Returns

ROOTDataset
ROOTDataset object.
@classmethod
def load(cls, file: str, element_spec_path: Optional[str] = None) -> ROOTDataset:
    """Loads a `ROOTDataset` object from a saved directory. The saved object is a `tf.data.Dataset` object
    saved using `tf.data.Dataset.save` with GZIP compression. The `element_spec` is loaded separately as a
    pickle object and is used to create the `tf.data.Dataset` object. By default the `element_spec` is read
    from the `element_spec` file inside the saved directory; optionally, its full path can be passed as
    `element_spec_path`.

    Example:
        Example of creating a `ROOTDataset` object from a saved `tf.data.Dataset` object.
        ```python
        import os
        import pickle
        import tensorflow as tf

        file_path = 'saved_dataset'
        dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 'b': [4, 5, 6]})
        dataset.save(file_path, compression='GZIP')
        with open(os.path.join(file_path, 'element_spec'), 'wb') as f:
            pickle.dump(dataset.element_spec, f)

        root_dataset = ROOTDataset.load(file_path)
        ```

    Args:
        file (str): Path to the saved directory.
        element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to
            the `element_spec` file inside the saved directory.

    Returns:
        ROOTDataset: `ROOTDataset` object.
    """
    if element_spec_path is None:
        element_spec_path = os.path.join(file, 'element_spec')
    # NOTE(security): pickle.load can execute arbitrary code; only load element_spec files from trusted sources.
    with open(element_spec_path, 'rb') as f:
        element_spec = pickle.load(f)
    # The compression must match the one used when saving (GZIP, as in `save`).
    dataset = tf.data.Dataset.load(file, compression='GZIP', element_spec=element_spec)
    return cls(dataset, list(element_spec.keys()))

Instance variables

var dataset : tf.data.Dataset

Tensorflow tf.data.Dataset object created from the ROOT file.

@property
def dataset(self) -> tf.data.Dataset:
    """Tensorflow `tf.data.Dataset` object created from the ROOT file."""
    return self._dataset
var variables : List[str]

List of variable names inferred from the ROOT file.

@property
def variables(self) -> List[str]:
    """List of variable names inferred from the ROOT file."""
    return self._variables


Methods

Maps a function over the dataset. The function should take a ROOTVariables object as input and return a ROOTVariables object as output.

Args

func : Callable
Function to be mapped.
num_parallel_calls : int, optional
Number of parallel calls to use. Defaults to None.

Returns

Mapped ROOTDataset object.
def map(self, func: Callable[[ROOTVariables], ROOTVariables], num_parallel_calls: Optional[int] = None) -> ROOTDataset:
    """Maps a function over the dataset. The function should take a `ROOTVariables` object as input and return a `ROOTVariables` object as output.

    Args:
        func (Callable): Function to be mapped.
        num_parallel_calls (int, optional): Number of parallel calls to use. Defaults to `None`.

    Returns:
        ROOTDataset: Mapped `ROOTDataset` object.
    """
    new_ds = self._dataset.map(func, num_parallel_calls=num_parallel_calls)
    # Let tf.data tune the prefetch buffer size automatically.
    new_ds = new_ds.prefetch(tf.data.AUTOTUNE)
    # Re-derive the variable names: `func` may add or remove keys.
    return ROOTDataset(new_ds, list(new_ds.element_spec.keys()))
def save(self, save_path: str, element_spec_path: Optional[str] = None, shard_func: Optional[Callable[[ROOTVariables], tf.Tensor]] = None) -> None

Saves a ROOTDataset object to a directory. The dataset is saved as a GZIP-compressed tf.data.Dataset object using tf.data.Dataset.save, and the element_spec is saved separately as a pickle object (by default inside the saved directory).

Args

save_path : str
Path to the directory where the object is to be saved.
element_spec_path : str, optional
Path to the saved element_spec as a pickle file. Defaults to the element_spec file inside the saved directory.
shard_func : Callable, optional
Function to shard the dataset. Used as the shard_func argument of tf.data.Dataset.save. Defaults to None.



def save(self, save_path: str, element_spec_path: Optional[str] = None, shard_func: Optional[Callable[[ROOTVariables], tf.Tensor]] = None) -> None:
    """Saves a `ROOTDataset` object to a directory. The dataset is saved as a GZIP-compressed `tf.data.Dataset`
    object using `tf.data.Dataset.save`, and the `element_spec` is saved separately as a pickle object
    (by default inside the saved directory).

    Args:
        save_path (str): Path to the directory where the object is to be saved.
        element_spec_path (str, optional): Path to the saved `element_spec` as a pickle file. Defaults to the `element_spec` file inside the saved directory.
        shard_func (Callable, optional): Function to shard the dataset. Used as the `shard_func` argument of `tf.data.Dataset.save`. Defaults to `None`.
    """
    if element_spec_path is None:
        element_spec_path = os.path.join(save_path, 'element_spec')
    element_spec = self._dataset.element_spec
    # Write the dataset first; the default element_spec file is then placed inside `save_path`.
    self._dataset.save(save_path, compression='GZIP', shard_func=shard_func)
    with open(element_spec_path, 'wb') as f:
        pickle.dump(element_spec, f)