Module jidenn.preprocess.flatten_dataset

Set of functions to flatten a dataset. The flattening is done by the reference variable, which is assumed to be a tf.RaggedTensor. The shape of the reference variable is used to infer the shape of the other variables. The other variables are tiled to match the shape of the reference variable.

Expand source code
"""
Set of functions to flatten a dataset. The flattening is done by the reference variable, which is assumed to be a
tf.RaggedTensor. The shape of the reference variable is used to infer the shape of the other variables. The other
variables are tiled to match the shape of the reference variable.

"""
import tensorflow as tf
from typing import Tuple, List, Dict, Union, Optional, Callable

ROOTVariables = Dict[str, tf.RaggedTensor]


def get_ragged_to_dataset_fn(reference_variable: str = 'jets_PartonTruthLabelID') -> Callable[[ROOTVariables], tf.data.Dataset]:
    """Get a function that converts one sample (a dictionary of variables) to a tf.data.Dataset.

    The intended use is to pass the returned function to a tf.data.Dataset.interleave call to flatten a dataset.
    The leading dimension of the reference variable is used as the target length. Variables other than the
    reference variable whose leading dimension does not match are tiled (duplicated) to that length, and the
    resulting dictionary is sliced along the leading dimension with `tf.data.Dataset.from_tensor_slices`.

    Args:
        reference_variable (str, optional): The variable to use as reference for inferring the shape of
            variables to flatten. Defaults to 'jets_PartonTruthLabelID'.

    Returns:
        Callable[[ROOTVariables], tf.data.Dataset]: A function that converts a sample to a tf.data.Dataset
    """
    @tf.function
    def _ragged_to_dataset(sample: ROOTVariables) -> tf.data.Dataset:
        # Shallow copy, so the dictionary passed in by the caller is not modified.
        sample = sample.copy()
        ragged_shape = tf.shape(sample[reference_variable])
        # Only existing keys are reassigned below, so the dict size does not change during iteration.
        for key, item in sample.items():
            if isinstance(item, tf.RaggedTensor) and ragged_shape[0] == tf.shape(item)[0]:
                # RaggedTensor whose leading dimension already matches the reference: keep as is.
                continue
            elif len(tf.shape(item)) == 0:
                # Scalar: add two axes and repeat it once per reference entry -> shape [n, 1].
                sample[key] = tf.tile(item[tf.newaxis, tf.newaxis], [ragged_shape[0], 1])
            elif tf.shape(item)[0] != ragged_shape[0]:
                # Leading dimension differs from the reference: add a leading axis and repeat the item.
                # NOTE(review): the multiples list has two entries, so this presumably expects a
                # 1-D item -- confirm against the variables actually stored in the dataset.
                sample[key] = tf.tile(item[tf.newaxis, :], [ragged_shape[0], 1])
            else:
                # Leading dimension already matches the reference: nothing to do.
                continue
        return tf.data.Dataset.from_tensor_slices(sample)
    return _ragged_to_dataset


def get_filter_empty_fn(reference_variable: str = 'jets_PartonTruthLabelID') -> Callable[[ROOTVariables], tf.Tensor]:
    """Build a predicate that tells whether the reference variable of a sample holds any element.

    The returned function is meant to be passed to a tf.data.Dataset.filter call, so that samples (events)
    whose reference variable is empty are dropped.

    Args:
        reference_variable (str, optional): The variable whose number of elements decides whether a sample
            is empty. Defaults to 'jets_PartonTruthLabelID'.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: A function returning a boolean tensor that is true when the
            reference variable of the sample contains at least one element.
    """
    @tf.function
    def _filter_empty(sample: ROOTVariables) -> tf.Tensor:
        # Total number of elements stored in the reference variable of this sample.
        n_elements = tf.size(sample[reference_variable])
        return tf.greater(n_elements, 0)
    return _filter_empty


def get_filter_ragged_values_fn(reference_variable: str = 'jets_PartonTruthLabelID',
                                         wanted_values: List[int] = [1, 2, 3, 4, 5, 6, 21]) -> Callable[[ROOTVariables], ROOTVariables]:
    """Get a function that filters out unwanted values from a ROOTVariables dictionary containing a RaggedTensor.

    The intended use is to use this function in a tf.data.Dataset.map call to filter out unwanted values from a
    dataset. The `get_filter_empty_fn` function should be used after this function to filter out empty events.

    A boolean mask is built from the reference variable and applied to every variable in the sample whose shape
    equals the shape of the mask; all other variables are left untouched.

    Args:
        reference_variable (str, optional): The variable whose values to filter. Defaults to 'jets_PartonTruthLabelID'.
        wanted_values (List[int], optional): The values to keep. Defaults to [1, 2, 3, 4, 5, 6, 21].
            The default is a shared list object; it is only read here, never modified.

    Returns:
        Callable[[ROOTVariables], ROOTVariables]: A function that filters out unwanted values from a ROOTVariables
    """
    @tf.function
    def _filter_unwanted_ragged_values_fn(sample: ROOTVariables) -> ROOTVariables:
        # Shallow copy, so the dictionary passed in by the caller is not modified.
        sample = sample.copy()
        # True where the reference variable equals any of the wanted values (reduced over the last axis).
        # NOTE(review): the comparison relies on broadcasting between the reference variable and
        # `wanted_values`; presumably the reference variable carries a trailing axis of size 1 -- confirm.
        mask = tf.math.reduce_any(tf.math.equal(sample[reference_variable], wanted_values), axis=-1)
        for key, item in sample.items():
            # Only variables with exactly the same shape as the mask are masked.
            if tf.reduce_all(tf.math.equal(tf.shape(item), tf.shape(mask))):
                sample[key] = tf.ragged.boolean_mask(item, mask)
        return sample
    return _filter_unwanted_ragged_values_fn


def flatten_dataset(dataset: tf.data.Dataset,
                    reference_variable: str = 'jets_PartonTruthLabelID',
                    wanted_values: Optional[List[int]] = None) -> tf.data.Dataset:
    """Apply a series of transformations to a tf.data.Dataset to flatten it.

    The flattening is done by the reference variable, which is assumed to be a tf.RaggedTensor. The steps are:

    1. If `wanted_values` is not None, the reference variable is filtered to only contain the wanted values
       (`get_filter_ragged_values_fn`, applied with `map`).
    2. Events whose reference variable is empty are removed (`get_filter_empty_fn`, applied with `filter`).
    3. Each remaining event is converted to a dataset of its entries and the results are interleaved
       (`get_ragged_to_dataset_fn`, applied with `interleave`). The other variables are tiled to match the
       shape of the reference variable.

    Args:
        dataset (tf.data.Dataset): The dataset to flatten.
        reference_variable (str, optional): The variable to use as reference for inferring the shape of
            variables to flatten. Defaults to 'jets_PartonTruthLabelID'.
        wanted_values (Optional[List[int]], optional): The values to keep in the reference variable. Defaults to None.

    Returns:
        tf.data.Dataset: The flattened dataset
    """
    if wanted_values is not None:
        dataset = dataset.map(get_filter_ragged_values_fn(reference_variable, wanted_values))

    # The emptiness predicate must go through `filter`: with `map` every sample would be replaced by a
    # boolean, and the interleave step would no longer receive the dictionary of variables.
    return (
        dataset
        .filter(get_filter_empty_fn(reference_variable))
        .interleave(get_ragged_to_dataset_fn(reference_variable))
    )

Functions

def flatten_dataset(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, reference_variable: str = 'jets_PartonTruthLabelID', wanted_values: Optional[List[int]] = None) ‑> tensorflow.python.data.ops.dataset_ops.DatasetV2

Apply a series of transformations to a tf.data.Dataset to flatten it. The flattening is done by the reference variable, which is assumed to be a tf.RaggedTensor. The shape of the reference variable is used to infer the shape of the other variables. The other variables are tiled to match the shape of the reference variable. The dataset is then filtered to remove empty events. If wanted_values is not None, the reference variable is filtered to only contain the wanted values.

Args

dataset : tf.data.Dataset
The dataset to flatten.
reference_variable : str, optional
The variable to use as reference for inferring the shape of variables to flatten. Defaults to 'jets_PartonTruthLabelID'.
wanted_values : Optional[List[int]], optional
The values to keep in the reference variable. Defaults to None.

Returns

tf.data.Dataset
The flattened dataset
Expand source code
def flatten_dataset(dataset: tf.data.Dataset,
                    reference_variable: str = 'jets_PartonTruthLabelID',
                    wanted_values: Optional[List[int]] = None) -> tf.data.Dataset:
    """Apply a series of transformations to a tf.data.Dataset to flatten it.

    The flattening is done by the reference variable, which is assumed to be a tf.RaggedTensor. The steps are:

    1. If `wanted_values` is not None, the reference variable is filtered to only contain the wanted values
       (`get_filter_ragged_values_fn`, applied with `map`).
    2. Events whose reference variable is empty are removed (`get_filter_empty_fn`, applied with `filter`).
    3. Each remaining event is converted to a dataset of its entries and the results are interleaved
       (`get_ragged_to_dataset_fn`, applied with `interleave`). The other variables are tiled to match the
       shape of the reference variable.

    Args:
        dataset (tf.data.Dataset): The dataset to flatten.
        reference_variable (str, optional): The variable to use as reference for inferring the shape of
            variables to flatten. Defaults to 'jets_PartonTruthLabelID'.
        wanted_values (Optional[List[int]], optional): The values to keep in the reference variable. Defaults to None.

    Returns:
        tf.data.Dataset: The flattened dataset
    """
    if wanted_values is not None:
        dataset = dataset.map(get_filter_ragged_values_fn(reference_variable, wanted_values))

    # The emptiness predicate must go through `filter`: with `map` every sample would be replaced by a
    # boolean, and the interleave step would no longer receive the dictionary of variables.
    return (
        dataset
        .filter(get_filter_empty_fn(reference_variable))
        .interleave(get_ragged_to_dataset_fn(reference_variable))
    )
def get_filter_empty_fn(reference_variable: str = 'jets_PartonTruthLabelID') ‑> Callable[[Dict[str, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]], tensorflow.python.framework.ops.Tensor]

Get a function that filters out empty RaggedTensors from a ROOTVariables dictionary. The intended use is to use this function in a tf.data.Dataset.filter call to filter out empty events.

Args

reference_variable : str, optional
The variable to use as reference for inferring the shape of empty events. Defaults to 'jets_PartonTruthLabelID'.

Returns

Callable[[ROOTVariables], tf.Tensor]
A function that filters out empty RaggedTensors from a ROOTVariables
Expand source code
def get_filter_empty_fn(reference_variable: str = 'jets_PartonTruthLabelID') -> Callable[[ROOTVariables], tf.Tensor]:
    """Build a predicate that tells whether the reference variable of a sample holds any element.

    The returned function is meant to be passed to a tf.data.Dataset.filter call, so that samples (events)
    whose reference variable is empty are dropped.

    Args:
        reference_variable (str, optional): The variable whose number of elements decides whether a sample
            is empty. Defaults to 'jets_PartonTruthLabelID'.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: A function returning a boolean tensor that is true when the
            reference variable of the sample contains at least one element.
    """
    @tf.function
    def _filter_empty(sample: ROOTVariables) -> tf.Tensor:
        # Total number of elements stored in the reference variable of this sample.
        n_elements = tf.size(sample[reference_variable])
        return tf.greater(n_elements, 0)
    return _filter_empty
def get_filter_ragged_values_fn(reference_variable: str = 'jets_PartonTruthLabelID', wanted_values: List[int] = [1, 2, 3, 4, 5, 6, 21]) ‑> Callable[[Dict[str, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]], Dict[str, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]

Get a function that filters out unwanted values from a ROOTVariables dictionary containing a RaggedTensor. The intended use is to use this function in a tf.data.Dataset.map call to filter out unwanted values from a dataset. The get_filter_empty_fn() function should be used after this function to filter out empty events.

Args

reference_variable : str, optional
The variable whose values to filter. Defaults to 'jets_PartonTruthLabelID'.
wanted_values : List[int], optional
The values to keep. Defaults to [1, 2, 3, 4, 5, 6, 21].

Returns

Callable[[ROOTVariables], ROOTVariables]
A function that filters out unwanted values from a ROOTVariables
Expand source code
def get_filter_ragged_values_fn(reference_variable: str = 'jets_PartonTruthLabelID',
                                         wanted_values: List[int] = [1, 2, 3, 4, 5, 6, 21]) -> Callable[[ROOTVariables], ROOTVariables]:
    """Get a function that filters out unwanted values from a ROOTVariables dictionary containing a RaggedTensor.

    The intended use is to use this function in a tf.data.Dataset.map call to filter out unwanted values from a
    dataset. The `get_filter_empty_fn` function should be used after this function to filter out empty events.

    A boolean mask is built from the reference variable and applied to every variable in the sample whose shape
    equals the shape of the mask; all other variables are left untouched.

    Args:
        reference_variable (str, optional): The variable whose values to filter. Defaults to 'jets_PartonTruthLabelID'.
        wanted_values (List[int], optional): The values to keep. Defaults to [1, 2, 3, 4, 5, 6, 21].
            The default is a shared list object; it is only read here, never modified.

    Returns:
        Callable[[ROOTVariables], ROOTVariables]: A function that filters out unwanted values from a ROOTVariables
    """
    @tf.function
    def _filter_unwanted_ragged_values_fn(sample: ROOTVariables) -> ROOTVariables:
        # Shallow copy, so the dictionary passed in by the caller is not modified.
        sample = sample.copy()
        # True where the reference variable equals any of the wanted values (reduced over the last axis).
        # NOTE(review): the comparison relies on broadcasting between the reference variable and
        # `wanted_values`; presumably the reference variable carries a trailing axis of size 1 -- confirm.
        mask = tf.math.reduce_any(tf.math.equal(sample[reference_variable], wanted_values), axis=-1)
        for key, item in sample.items():
            # Only variables with exactly the same shape as the mask are masked.
            if tf.reduce_all(tf.math.equal(tf.shape(item), tf.shape(mask))):
                sample[key] = tf.ragged.boolean_mask(item, mask)
        return sample
    return _filter_unwanted_ragged_values_fn
def get_ragged_to_dataset_fn(reference_variable: str = 'jets_PartonTruthLabelID') ‑> Callable[[Dict[str, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]], tensorflow.python.data.ops.dataset_ops.DatasetV2]

Get a function that converts a tf.RaggedTensor to a tf.data.Dataset. The intended use is to use this function in a tf.data.Dataset.interleave call to flatten a dataset. The function will infer the shape of the ragged tensor from the shape of the reference variable. Variables other than the reference variable will be tiled to match the shape of the reference variable (they will be duplicated).

Args

reference_variable : str, optional
The variable to use as reference for inferring the shape of variables to flatten. Defaults to 'jets_PartonTruthLabelID'.

Returns

Callable[[ROOTVariables], tf.data.Dataset]
A function that converts a tf.RaggedTensor to a tf.data.Dataset
Expand source code
def get_ragged_to_dataset_fn(reference_variable: str = 'jets_PartonTruthLabelID') -> Callable[[ROOTVariables], tf.data.Dataset]:
    """Get a function that converts one sample (a dictionary of variables) to a tf.data.Dataset.

    The intended use is to pass the returned function to a tf.data.Dataset.interleave call to flatten a dataset.
    The leading dimension of the reference variable is used as the target length. Variables other than the
    reference variable whose leading dimension does not match are tiled (duplicated) to that length, and the
    resulting dictionary is sliced along the leading dimension with `tf.data.Dataset.from_tensor_slices`.

    Args:
        reference_variable (str, optional): The variable to use as reference for inferring the shape of
            variables to flatten. Defaults to 'jets_PartonTruthLabelID'.

    Returns:
        Callable[[ROOTVariables], tf.data.Dataset]: A function that converts a sample to a tf.data.Dataset
    """
    @tf.function
    def _ragged_to_dataset(sample: ROOTVariables) -> tf.data.Dataset:
        # Shallow copy, so the dictionary passed in by the caller is not modified.
        sample = sample.copy()
        ragged_shape = tf.shape(sample[reference_variable])
        # Only existing keys are reassigned below, so the dict size does not change during iteration.
        for key, item in sample.items():
            if isinstance(item, tf.RaggedTensor) and ragged_shape[0] == tf.shape(item)[0]:
                # RaggedTensor whose leading dimension already matches the reference: keep as is.
                continue
            elif len(tf.shape(item)) == 0:
                # Scalar: add two axes and repeat it once per reference entry -> shape [n, 1].
                sample[key] = tf.tile(item[tf.newaxis, tf.newaxis], [ragged_shape[0], 1])
            elif tf.shape(item)[0] != ragged_shape[0]:
                # Leading dimension differs from the reference: add a leading axis and repeat the item.
                # NOTE(review): the multiples list has two entries, so this presumably expects a
                # 1-D item -- confirm against the variables actually stored in the dataset.
                sample[key] = tf.tile(item[tf.newaxis, :], [ragged_shape[0], 1])
            else:
                # Leading dimension already matches the reference: nothing to do.
                continue
        return tf.data.Dataset.from_tensor_slices(sample)
    return _ragged_to_dataset