Module jidenn.preprocess.resampling

Expand source code
import tensorflow as tf
import numpy as np
from typing import Callable, List, Tuple, Union, Optional

ROOTVariables = dict[str, Union[tf.Tensor, tf.RaggedTensor]]


@tf.function
def take_second(x, y):
    return y

@tf.function
def take_first(x, y):
    return x

def get_bin_fn(bins: Union[int, List[float]] = 100,
               lower_var_limit: int = 60_000,
               upper_var_limit: int = 5_600_000,
               log_binning_base: Optional[Union[int, float]] = None,
               variable: str = 'jets_pt') -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that returns the bin index of a variable. The intended use is to use this function
    with tf.data.Dataset.rejection_resample() to resample a dataset to a target distribution.

    Args:
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (int, optional): The lower limit of the variable. Defaults to 60_000.
        upper_var_limit (int, optional): The upper limit of the variable. Defaults to 5_600_000.
        variable (str, optional): Variable to bin. Defaults to 'jets_pt'.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that returns the bin index of a variable.
    """
    
    if log_binning_base is not None and isinstance(bins, int):
        bin_edges =  np.logspace(np.log(lower_var_limit), np.log(upper_var_limit), bins + 1, base=log_binning_base)
        bin_edges = tf.constant(bin_edges, dtype=tf.float32)
    elif isinstance(bins, int):
        lower_var_limit_casted = tf.cast(lower_var_limit, dtype=tf.float32)
        upper_var_limit_casted = tf.cast(upper_var_limit, dtype=tf.float32)
        bin_edges = tf.linspace(lower_var_limit_casted, upper_var_limit_casted, bins + 1)
    else:
        bin_edges = tf.constant(bins, dtype=tf.float32)
        
    @tf.function
    def _rebin(data: ROOTVariables) -> tf.Tensor:
        var = data[variable]
        var = tf.reshape(var, (1, ))
        bin_indices = tf.searchsorted(bin_edges, var, side='right') - 1
        bin_indices = tf.where(bin_indices < 0, 0, bin_indices)
        bin_indices = tf.clip_by_value(bin_indices, 0, len(bin_edges) - 2)
        bin_indices = tf.squeeze(bin_indices)
        return bin_indices
    return _rebin


def resample_dataset(dataset: tf.data.Dataset,
                     bins: int = 100,
                     lower_var_limit: int = 60_000,
                     upper_var_limit: int = 5_600_000,
                     variable: str = 'jets_pt',
                     log_binning_base: Optional[Union[int, float]] = None,
                     target_dist: Optional[List[int]] = None,
                     precompute_init_dist: Optional[bool] = False) -> tf.data.Dataset:
    """Resample a dataset to a target distribution. 

    Args:
        dataset (tf.data.Dataset): The dataset to resample.
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (int, optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (int, optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Variable to bin. Defaults to 'jets_pt'.
        target_dist (Optional[List[int]], optional): The target distribution. If None, a uniform distribution
            is used. Defaults to None. It must have the same length as the number of bins.

    Returns:
        tf.data.Dataset: The resampled dataset.
    """

    dist = [1 / bins] * bins if target_dist is None else target_dist
    
    if target_dist is None and precompute_init_dist: 
        initial_dist = dataset.map(get_bin_fn(bins=bins,
                                            lower_var_limit=lower_var_limit,
                                            upper_var_limit=upper_var_limit,
                                            log_binning_base=log_binning_base,
                                            variable=variable))
        
        initial_dist = initial_dist.reduce(tf.zeros(bins, dtype=tf.int64), lambda x, y: x + tf.one_hot(y, bins, dtype=tf.int64))
        initial_dist = tf.cast(initial_dist, dtype=tf.float64)
        initial_dist = initial_dist / tf.reduce_sum(initial_dist)
        # initial_dist = tf.squeeze(initial_dist)
        print(initial_dist)
    else:
        initial_dist = None
    
    
    dataset = dataset.rejection_resample(get_bin_fn(bins=bins,
                                                    lower_var_limit=lower_var_limit,
                                                    log_binning_base=log_binning_base,
                                                    upper_var_limit=upper_var_limit,
                                                    variable=variable), target_dist=dist, initial_dist=initial_dist)
    return dataset.map(take_second)


def get_label_bin_fn(bins: int = 100,
                     lower_var_limit: Union[int, float] = 60_000,
                     upper_var_limit: Union[int, float] = 5_600_000,
                     variable: str = 'jets_pt',
                     label_variable: str = 'jets_PartonTruthLabelID',
                     label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                     log_binning_base: Optional[Union[int, float]] = None,
                     label_class_2: List[int] = [21]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that returns the bin index of a data sample based on the value of the binned 
    variable and the label. The intended use is to use this function with tf.data.Dataset.rejection_resample()
    to resample a dataset to a target distribution.


    Args:
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (Union[int, float], optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (Union[int, float], optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Name of the variable to bin. Defaults to 'jets_pt'.
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
            The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.
            The values will be regarded as the same class. Defaults to [21].

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: _description_
    """
    if log_binning_base is not None and isinstance(bins, int):
        bin_edges =  np.logspace(np.log(lower_var_limit), np.log(upper_var_limit), bins + 1, base=log_binning_base)
        bin_edges = tf.constant(bin_edges, dtype=tf.float32)
    elif isinstance(bins, int):
        lower_var_limit_casted = tf.cast(lower_var_limit, dtype=tf.float32)
        upper_var_limit_casted = tf.cast(upper_var_limit, dtype=tf.float32)
        bin_edges = tf.linspace(lower_var_limit_casted, upper_var_limit_casted, bins + 1)
    else:
        bin_edges = tf.constant(bins, dtype=tf.float32)
        
    n_bins = len(bin_edges) - 1
        
    @tf.function
    def _rebin(data: ROOTVariables) -> tf.Tensor:
        var = data[variable]
        label = data[label_variable]
        var = tf.reshape(var, (1, ))
        bin_indices = tf.searchsorted(bin_edges, var, side='right') - 1
        # bin_indices = tf.where(bin_indices < 0, 0, bin_indices)
        bin_indices = tf.clip_by_value(bin_indices, 0, n_bins - 1)
        bin_indices = tf.squeeze(bin_indices)
        if tf.reduce_any(tf.equal(label, tf.constant(label_class_1, dtype=label.dtype))):
            return bin_indices
        elif tf.reduce_any(tf.equal(label, tf.constant(label_class_2, dtype=label.dtype))):
            return bin_indices + n_bins
        else:
            return tf.constant(0, dtype=tf.int32)
    return _rebin


def rasample_from_min_bin_count(dataset: tf.data.Dataset,
                                bin_fn: Callable[[ROOTVariables], tf.Tensor],
                                min_bin_count: int,
                                bins: int = 100,) -> tf.data.Dataset:
    
    @tf.function
    def assign_if_filter_fn(state, x):
        bin_idx = bin_fn(x)
        state = state + tf.one_hot(bin_idx, bins, dtype=tf.int64)
        return state, (x, tf.greater(min_bin_count, state[bin_idx]))

    dataset = dataset.scan(tf.zeros((bins), dtype=tf.int64), assign_if_filter_fn)
    dataset = dataset.filter(take_second).map(take_first)
    
    return dataset
                                


def resample_with_labels_dataset(dataset: tf.data.Dataset,
                                 bins: int = 100,
                                 lower_var_limit: int = 60_000,
                                 upper_var_limit: int = 5_600_000,
                                 variable: str = 'jets_pt',
                                 label_variable: str = 'jets_PartonTruthLabelID',
                                 label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                                 label_class_2: List[int] = [21],
                                 target_dist: Optional[List[int]] = None,
                                 log_binning_base: Optional[Union[int, float]] = None,
                                 precompute_init_dist: Optional[bool] = False,
                                 from_min_count: bool = False) -> tf.data.Dataset:
    """Resample a dataset to a target distribution based on the value of the label variable and the binned variable.


    Args:
        dataset (tf.data.Dataset): The dataset to resample.
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (Union[int, float], optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (Union[int, float], optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Name of the variable to bin. Defaults to 'jets_pt'.
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
            The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.
            The values will be regarded as the same class. Defaults to [21].
        target_dist (Optional[List[int]], optional): The target distribution. If None, a uniform distribution
            is used. Defaults to None. It must have the same length as the number of bins.

    Returns:
        tf.data.Dataset: The resampled dataset.
    """
    dist = [1 / (bins * 2)] * bins * 2 if target_dist is None else target_dist
    
    if target_dist is None and precompute_init_dist: 
        initial_dist = dataset.map(get_label_bin_fn(bins=bins,
                                                            lower_var_limit=lower_var_limit,
                                                            upper_var_limit=upper_var_limit,
                                                            variable=variable,
                                                            label_variable=label_variable,
                                                            log_binning_base=log_binning_base,
                                                            label_class_1=label_class_1,
                                                            label_class_2=label_class_2))
        initial_dist = initial_dist.reduce(tf.zeros((bins * 2), dtype=tf.int64), lambda x, y: x + tf.one_hot(y, bins * 2, dtype=tf.int64))
        min_count = tf.reduce_min(initial_dist).numpy()
        min_count = tf.cast(min_count, dtype=tf.int64)
        initial_dist = tf.cast(initial_dist, dtype=tf.float64)
        print(f'Min count: {min_count:,}')
        print(f'Estimated number of samples: {min_count * bins * 2:,}')
        initial_dist = initial_dist / tf.reduce_sum(initial_dist)
        print(initial_dist)
    else:
        initial_dist = None
        
    label_bin_fn = get_label_bin_fn(bins=bins,
                                    lower_var_limit=lower_var_limit,
                                    upper_var_limit=upper_var_limit,
                                    variable=variable,
                                    log_binning_base=log_binning_base,
                                    label_variable=label_variable,
                                    label_class_1=label_class_1,
                                    label_class_2=label_class_2)
        
    if from_min_count:
        return rasample_from_min_bin_count(dataset, label_bin_fn, min_bin_count=min_count, bins=bins * 2)
    else:
        dataset = dataset.rejection_resample(label_bin_fn, target_dist=dist, initial_dist=initial_dist)
        return dataset.map(take_second)


def write_new_variable(variable_value: int, variable_name: str = 'JZ_slice') -> Callable[[ROOTVariables], ROOTVariables]:
    """Get a function that writes a new variable to a dataset. The intended use is to use this function
    with tf.data.Dataset.map() to add a new variable to a dataset. The value is the same for all samples.

    Args:
        variable_value (int): The value of the new variable.
        variable_name (str, optional): The name of the new variable. Defaults to 'JZ_slice'.

    Returns:
        Callable[[ROOTVariables], ROOTVariables]: Function that writes a new variable to a dataset.
    """

    @tf.function
    def _write(data: ROOTVariables) -> ROOTVariables:
        new_data = data.copy()
        new_data[variable_name] = variable_value
        return new_data
    return _write


def get_label_splitting_fn(label_variable: str = 'jets_PartonTruthLabelID',
                           label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                           label_class_2: List[int] = [21]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that splits a dataset based on the value of the label variable. The intended use is to use this function
    with tf.data.Dataset.filter() to split a dataset into two classes.

    Args:
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that splits a dataset based on the value of the label variable.
    """
    @tf.function
    def _split(data: ROOTVariables) -> tf.Tensor:
        label = data[label_variable]
        if tf.reduce_any(tf.equal(label, tf.constant(label_class_1, dtype=tf.int32))):
            return tf.constant(1, dtype=tf.int32)
        elif tf.reduce_any(tf.equal(label, tf.constant(label_class_2, dtype=tf.int32))):
            return tf.constant(0, dtype=tf.int32)
        else:
            return tf.constant(0, dtype=tf.int32)
    return _split


def get_cut_fn(variable: str = 'jets_pt', lower_limit: Optional[float] = None, upper_limit: Optional[float] = None) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that cuts a dataset based on the value of a variable. The intended use is to use this function
    with tf.data.Dataset.filter() to cut a dataset.

    Args:
        variable (str, optional): Name of the variable to cut. Defaults to 'jets_pt'.
        lower_limit (Optional[float], optional): The lower limit of the variable. Defaults to None.
        upper_limit (Optional[float], optional): The upper limit of the variable. Defaults to None.

    Raises:
        ValueError: At least one of the limits must be specified.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that cuts a dataset based on the value of a variable.
    """

    if lower_limit is None and upper_limit is None:
        raise ValueError('At least one of the limits must be specified.')

    @tf.function
    def _low_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all(data[variable] > lower_limit)

    @tf.function
    def _up_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all(data[variable] < upper_limit)

    @tf.function
    def _up_low_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all([data[variable] > lower_limit, data[variable] < upper_limit])

    if lower_limit is None:
        return _up_cut
    elif upper_limit is None:
        return _low_cut
    else:
        return _up_low_cut


def get_filter_fn(variable: str, values: List[Union[float, int]]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that filters a dataset based on the value of a variable. The intended use is to use this function
    with tf.data.Dataset.filter() to filter a dataset.

    Args:
        variable (str): Name of the variable to filter.
        values (List[Union[float, int]]): Values to keep.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that filters a dataset based on the value of a variable.
    """
    @tf.function
    def _filter(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_any(tf.equal(data[variable], tf.constant(values)))
    return _filter

Functions

def get_bin_fn(bins: Union[int, List[float]] = 100, lower_var_limit: int = 60000, upper_var_limit: int = 5600000, log_binning_base: Union[int, float, ForwardRef(None)] = None, variable: str = 'jets_pt') ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor]

Get a function that returns the bin index of a variable. The intended use is to use this function with tf.data.Dataset.rejection_resample() to resample a dataset to a target distribution.

Args

n_bins : int, optional
Number of bins to use. Defaults to 100.
lower_var_limit : int, optional
The lower limit of the variable. Defaults to 60_000.
upper_var_limit : int, optional
The upper limit of the variable. Defaults to 5_600_000.
variable : str, optional
Variable to bin. Defaults to 'jets_pt'.

Returns

Callable[[ROOTVariables], tf.Tensor]
Function that returns the bin index of a variable.
Expand source code
def get_bin_fn(bins: Union[int, List[float]] = 100,
               lower_var_limit: int = 60_000,
               upper_var_limit: int = 5_600_000,
               log_binning_base: Optional[Union[int, float]] = None,
               variable: str = 'jets_pt') -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that returns the bin index of a variable. The intended use is to use this function
    with tf.data.Dataset.rejection_resample() to resample a dataset to a target distribution.

    Args:
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (int, optional): The lower limit of the variable. Defaults to 60_000.
        upper_var_limit (int, optional): The upper limit of the variable. Defaults to 5_600_000.
        variable (str, optional): Variable to bin. Defaults to 'jets_pt'.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that returns the bin index of a variable.
    """
    
    if log_binning_base is not None and isinstance(bins, int):
        bin_edges =  np.logspace(np.log(lower_var_limit), np.log(upper_var_limit), bins + 1, base=log_binning_base)
        bin_edges = tf.constant(bin_edges, dtype=tf.float32)
    elif isinstance(bins, int):
        lower_var_limit_casted = tf.cast(lower_var_limit, dtype=tf.float32)
        upper_var_limit_casted = tf.cast(upper_var_limit, dtype=tf.float32)
        bin_edges = tf.linspace(lower_var_limit_casted, upper_var_limit_casted, bins + 1)
    else:
        bin_edges = tf.constant(bins, dtype=tf.float32)
        
    @tf.function
    def _rebin(data: ROOTVariables) -> tf.Tensor:
        var = data[variable]
        var = tf.reshape(var, (1, ))
        bin_indices = tf.searchsorted(bin_edges, var, side='right') - 1
        bin_indices = tf.where(bin_indices < 0, 0, bin_indices)
        bin_indices = tf.clip_by_value(bin_indices, 0, len(bin_edges) - 2)
        bin_indices = tf.squeeze(bin_indices)
        return bin_indices
    return _rebin
def get_cut_fn(variable: str = 'jets_pt', lower_limit: Optional[float] = None, upper_limit: Optional[float] = None) ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor]

Get a function that cuts a dataset based on the value of a variable. The intended use is to use this function with tf.data.Dataset.filter() to cut a dataset.

Args

variable : str, optional
Name of the variable to cut. Defaults to 'jets_pt'.
lower_limit : Optional[float], optional
The lower limit of the variable. Defaults to None.
upper_limit : Optional[float], optional
The upper limit of the variable. Defaults to None.

Raises

ValueError
At least one of the limits must be specified.

Returns

Callable[[ROOTVariables], tf.Tensor]
Function that cuts a dataset based on the value of a variable.
Expand source code
def get_cut_fn(variable: str = 'jets_pt', lower_limit: Optional[float] = None, upper_limit: Optional[float] = None) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that cuts a dataset based on the value of a variable. The intended use is to use this function
    with tf.data.Dataset.filter() to cut a dataset.

    Args:
        variable (str, optional): Name of the variable to cut. Defaults to 'jets_pt'.
        lower_limit (Optional[float], optional): The lower limit of the variable. Defaults to None.
        upper_limit (Optional[float], optional): The upper limit of the variable. Defaults to None.

    Raises:
        ValueError: At least one of the limits must be specified.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that cuts a dataset based on the value of a variable.
    """

    if lower_limit is None and upper_limit is None:
        raise ValueError('At least one of the limits must be specified.')

    @tf.function
    def _low_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all(data[variable] > lower_limit)

    @tf.function
    def _up_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all(data[variable] < upper_limit)

    @tf.function
    def _up_low_cut(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_all([data[variable] > lower_limit, data[variable] < upper_limit])

    if lower_limit is None:
        return _up_cut
    elif upper_limit is None:
        return _low_cut
    else:
        return _up_low_cut
def get_filter_fn(variable: str, values: List[Union[float, int]]) ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor]

Get a function that filters a dataset based on the value of a variable. The intended use is to use this function with tf.data.Dataset.filter() to filter a dataset.

Args

variable : str
Name of the variable to filter.
values : List[Union[float, int]]
Values to keep.

Returns

Callable[[ROOTVariables], tf.Tensor]
Function that filters a dataset based on the value of a variable.
Expand source code
def get_filter_fn(variable: str, values: List[Union[float, int]]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that filters a dataset based on the value of a variable. The intended use is to use this function
    with tf.data.Dataset.filter() to filter a dataset.

    Args:
        variable (str): Name of the variable to filter.
        values (List[Union[float, int]]): Values to keep.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that filters a dataset based on the value of a variable.
    """
    @tf.function
    def _filter(data: ROOTVariables) -> tf.Tensor:
        return tf.reduce_any(tf.equal(data[variable], tf.constant(values)))
    return _filter
def get_label_bin_fn(bins: int = 100, lower_var_limit: Union[int, float] = 60000, upper_var_limit: Union[int, float] = 5600000, variable: str = 'jets_pt', label_variable: str = 'jets_PartonTruthLabelID', label_class_1: List[int] = [1, 2, 3, 4, 5, 6], log_binning_base: Union[int, float, ForwardRef(None)] = None, label_class_2: List[int] = [21]) ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor]

Get a function that returns the bin index of a data sample based on the value of the binned variable and the label. The intended use is to use this function with tf.data.Dataset.rejection_resample() to resample a dataset to a target distribution.

Args

n_bins : int, optional
Number of bins to use. Defaults to 100.
lower_var_limit : Union[int, float], optional
The lower limit of the binned variable (the lowest bin). Defaults to 60_000.
upper_var_limit : Union[int, float], optional
The upper limit of the binned variable (the highest bin). Defaults to 5_600_000.
variable : str, optional
Name of the variable to bin. Defaults to 'jets_pt'.
label_variable : str, optional
Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
label_class_1 : List[int], optional
Values of the label variable that correspond to the first class. The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
label_class_2 : List[int], optional
Values of the label variable that correspond to the second class. The values will be regarded as the same class. Defaults to [21].

Returns

Callable[[ROOTVariables], tf.Tensor]
description
Expand source code
def get_label_bin_fn(bins: int = 100,
                     lower_var_limit: Union[int, float] = 60_000,
                     upper_var_limit: Union[int, float] = 5_600_000,
                     variable: str = 'jets_pt',
                     label_variable: str = 'jets_PartonTruthLabelID',
                     label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                     log_binning_base: Optional[Union[int, float]] = None,
                     label_class_2: List[int] = [21]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that returns the bin index of a data sample based on the value of the binned 
    variable and the label. The intended use is to use this function with tf.data.Dataset.rejection_resample()
    to resample a dataset to a target distribution.


    Args:
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (Union[int, float], optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (Union[int, float], optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Name of the variable to bin. Defaults to 'jets_pt'.
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
            The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.
            The values will be regarded as the same class. Defaults to [21].

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: _description_
    """
    if log_binning_base is not None and isinstance(bins, int):
        bin_edges =  np.logspace(np.log(lower_var_limit), np.log(upper_var_limit), bins + 1, base=log_binning_base)
        bin_edges = tf.constant(bin_edges, dtype=tf.float32)
    elif isinstance(bins, int):
        lower_var_limit_casted = tf.cast(lower_var_limit, dtype=tf.float32)
        upper_var_limit_casted = tf.cast(upper_var_limit, dtype=tf.float32)
        bin_edges = tf.linspace(lower_var_limit_casted, upper_var_limit_casted, bins + 1)
    else:
        bin_edges = tf.constant(bins, dtype=tf.float32)
        
    n_bins = len(bin_edges) - 1
        
    @tf.function
    def _rebin(data: ROOTVariables) -> tf.Tensor:
        var = data[variable]
        label = data[label_variable]
        var = tf.reshape(var, (1, ))
        bin_indices = tf.searchsorted(bin_edges, var, side='right') - 1
        # bin_indices = tf.where(bin_indices < 0, 0, bin_indices)
        bin_indices = tf.clip_by_value(bin_indices, 0, n_bins - 1)
        bin_indices = tf.squeeze(bin_indices)
        if tf.reduce_any(tf.equal(label, tf.constant(label_class_1, dtype=label.dtype))):
            return bin_indices
        elif tf.reduce_any(tf.equal(label, tf.constant(label_class_2, dtype=label.dtype))):
            return bin_indices + n_bins
        else:
            return tf.constant(0, dtype=tf.int32)
    return _rebin
def get_label_splitting_fn(label_variable: str = 'jets_PartonTruthLabelID', label_class_1: List[int] = [1, 2, 3, 4, 5, 6], label_class_2: List[int] = [21]) ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor]

Get a function that splits a dataset based on the value of the label variable. The intended use is to use this function with tf.data.Dataset.filter() to split a dataset into two classes.

Args

label_variable : str, optional
Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
label_class_1 : List[int], optional
Values of the label variable that correspond to the first class.
label_class_2 : List[int], optional
Values of the label variable that correspond to the second class.

Returns

Callable[[ROOTVariables], tf.Tensor]
Function that splits a dataset based on the value of the label variable.
Expand source code
def get_label_splitting_fn(label_variable: str = 'jets_PartonTruthLabelID',
                           label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                           label_class_2: List[int] = [21]) -> Callable[[ROOTVariables], tf.Tensor]:
    """Get a function that splits a dataset based on the value of the label variable. The intended use is to use this function
    with tf.data.Dataset.filter() to split a dataset into two classes.

    Args:
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.

    Returns:
        Callable[[ROOTVariables], tf.Tensor]: Function that splits a dataset based on the value of the label variable.
    """
    @tf.function
    def _split(data: ROOTVariables) -> tf.Tensor:
        label = data[label_variable]
        if tf.reduce_any(tf.equal(label, tf.constant(label_class_1, dtype=tf.int32))):
            return tf.constant(1, dtype=tf.int32)
        elif tf.reduce_any(tf.equal(label, tf.constant(label_class_2, dtype=tf.int32))):
            return tf.constant(0, dtype=tf.int32)
        else:
            return tf.constant(0, dtype=tf.int32)
    return _split
def rasample_from_min_bin_count(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, bin_fn: Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], tensorflow.python.framework.ops.Tensor], min_bin_count: int, bins: int = 100) ‑> tensorflow.python.data.ops.dataset_ops.DatasetV2
Expand source code
def rasample_from_min_bin_count(dataset: tf.data.Dataset,
                                bin_fn: Callable[[ROOTVariables], tf.Tensor],
                                min_bin_count: int,
                                bins: int = 100,) -> tf.data.Dataset:
    
    @tf.function
    def assign_if_filter_fn(state, x):
        bin_idx = bin_fn(x)
        state = state + tf.one_hot(bin_idx, bins, dtype=tf.int64)
        return state, (x, tf.greater(min_bin_count, state[bin_idx]))

    dataset = dataset.scan(tf.zeros((bins), dtype=tf.int64), assign_if_filter_fn)
    dataset = dataset.filter(take_second).map(take_first)
    
    return dataset
def resample_dataset(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, bins: int = 100, lower_var_limit: int = 60000, upper_var_limit: int = 5600000, variable: str = 'jets_pt', log_binning_base: Union[int, float, ForwardRef(None)] = None, target_dist: Optional[List[int]] = None, precompute_init_dist: Optional[bool] = False) ‑> tensorflow.python.data.ops.dataset_ops.DatasetV2

Resample a dataset to a target distribution.

Args

dataset : tf.data.Dataset
The dataset to resample.
n_bins : int, optional
Number of bins to use. Defaults to 100.
lower_var_limit : int, optional
The lower limit of the binned variable (the lowest bin). Defaults to 60_000.
upper_var_limit : int, optional
The upper limit of the binned variable (the highest bin). Defaults to 5_600_000.
variable : str, optional
Variable to bin. Defaults to 'jets_pt'.
target_dist : Optional[List[int]], optional
The target distribution. If None, a uniform distribution is used. Defaults to None. It must have the same length as the number of bins.

Returns

tf.data.Dataset
The resampled dataset.
Expand source code
def resample_dataset(dataset: tf.data.Dataset,
                     bins: int = 100,
                     lower_var_limit: int = 60_000,
                     upper_var_limit: int = 5_600_000,
                     variable: str = 'jets_pt',
                     log_binning_base: Optional[Union[int, float]] = None,
                     target_dist: Optional[List[int]] = None,
                     precompute_init_dist: Optional[bool] = False) -> tf.data.Dataset:
    """Resample a dataset to a target distribution. 

    Args:
        dataset (tf.data.Dataset): The dataset to resample.
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (int, optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (int, optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Variable to bin. Defaults to 'jets_pt'.
        target_dist (Optional[List[int]], optional): The target distribution. If None, a uniform distribution
            is used. Defaults to None. It must have the same length as the number of bins.

    Returns:
        tf.data.Dataset: The resampled dataset.
    """

    dist = [1 / bins] * bins if target_dist is None else target_dist
    
    if target_dist is None and precompute_init_dist: 
        initial_dist = dataset.map(get_bin_fn(bins=bins,
                                            lower_var_limit=lower_var_limit,
                                            upper_var_limit=upper_var_limit,
                                            log_binning_base=log_binning_base,
                                            variable=variable))
        
        initial_dist = initial_dist.reduce(tf.zeros(bins, dtype=tf.int64), lambda x, y: x + tf.one_hot(y, bins, dtype=tf.int64))
        initial_dist = tf.cast(initial_dist, dtype=tf.float64)
        initial_dist = initial_dist / tf.reduce_sum(initial_dist)
        # initial_dist = tf.squeeze(initial_dist)
        print(initial_dist)
    else:
        initial_dist = None
    
    
    dataset = dataset.rejection_resample(get_bin_fn(bins=bins,
                                                    lower_var_limit=lower_var_limit,
                                                    log_binning_base=log_binning_base,
                                                    upper_var_limit=upper_var_limit,
                                                    variable=variable), target_dist=dist, initial_dist=initial_dist)
    return dataset.map(take_second)
def resample_with_labels_dataset(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, bins: int = 100, lower_var_limit: int = 60000, upper_var_limit: int = 5600000, variable: str = 'jets_pt', label_variable: str = 'jets_PartonTruthLabelID', label_class_1: List[int] = [1, 2, 3, 4, 5, 6], label_class_2: List[int] = [21], target_dist: Optional[List[int]] = None, log_binning_base: Union[int, float, ForwardRef(None)] = None, precompute_init_dist: Optional[bool] = False, from_min_count: bool = False) ‑> tensorflow.python.data.ops.dataset_ops.DatasetV2

Resample a dataset to a target distribution based on the value of the label variable and the binned variable.

Args

dataset : tf.data.Dataset
The dataset to resample.
n_bins : int, optional
Number of bins to use. Defaults to 100.
lower_var_limit : Union[int, float], optional
The lower limit of the binned variable (the lowest bin). Defaults to 60_000.
upper_var_limit : Union[int, float], optional
The upper limit of the binned variable (the highest bin). Defaults to 5_600_000.
variable : str, optional
Name of the variable to bin. Defaults to 'jets_pt'.
label_variable : str, optional
Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
label_class_1 : List[int], optional
Values of the label variable that correspond to the first class. The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
label_class_2 : List[int], optional
Values of the label variable that correspond to the second class. The values will be regarded as the same class. Defaults to [21].
target_dist : Optional[List[int]], optional
The target distribution. If None, a uniform distribution is used. Defaults to None. It must have the same length as the number of bins.

Returns

tf.data.Dataset
The resampled dataset.
Expand source code
def resample_with_labels_dataset(dataset: tf.data.Dataset,
                                 bins: int = 100,
                                 lower_var_limit: int = 60_000,
                                 upper_var_limit: int = 5_600_000,
                                 variable: str = 'jets_pt',
                                 label_variable: str = 'jets_PartonTruthLabelID',
                                 label_class_1: List[int] = [1, 2, 3, 4, 5, 6],
                                 label_class_2: List[int] = [21],
                                 target_dist: Optional[List[int]] = None,
                                 log_binning_base: Optional[Union[int, float]] = None,
                                 precompute_init_dist: Optional[bool] = False,
                                 from_min_count: bool = False) -> tf.data.Dataset:
    """Resample a dataset to a target distribution based on the value of the label variable and the binned variable.


    Args:
        dataset (tf.data.Dataset): The dataset to resample.
        n_bins (int, optional): Number of bins to use. Defaults to 100.
        lower_var_limit (Union[int, float], optional): The lower limit of the binned variable (the lowest bin). 
            Defaults to 60_000.
        upper_var_limit (Union[int, float], optional): The upper limit of the binned variable (the highest bin).
            Defaults to 5_600_000.
        variable (str, optional): Name of the variable to bin. Defaults to 'jets_pt'.
        label_variable (str, optional): Name of the label variable. Defaults to 'jets_PartonTruthLabelID'.
        label_class_1 (List[int], optional): Values of the label variable that correspond to the first class.
            The values will be regarded as the same class. Defaults to [1, 2, 3, 4, 5, 6].
        label_class_2 (List[int], optional): Values of the label variable that correspond to the second class.
            The values will be regarded as the same class. Defaults to [21].
        target_dist (Optional[List[int]], optional): The target distribution. If None, a uniform distribution
            is used. Defaults to None. It must have the same length as the number of bins.

    Returns:
        tf.data.Dataset: The resampled dataset.
    """
    dist = [1 / (bins * 2)] * bins * 2 if target_dist is None else target_dist
    
    if target_dist is None and precompute_init_dist: 
        initial_dist = dataset.map(get_label_bin_fn(bins=bins,
                                                            lower_var_limit=lower_var_limit,
                                                            upper_var_limit=upper_var_limit,
                                                            variable=variable,
                                                            label_variable=label_variable,
                                                            log_binning_base=log_binning_base,
                                                            label_class_1=label_class_1,
                                                            label_class_2=label_class_2))
        initial_dist = initial_dist.reduce(tf.zeros((bins * 2), dtype=tf.int64), lambda x, y: x + tf.one_hot(y, bins * 2, dtype=tf.int64))
        min_count = tf.reduce_min(initial_dist).numpy()
        min_count = tf.cast(min_count, dtype=tf.int64)
        initial_dist = tf.cast(initial_dist, dtype=tf.float64)
        print(f'Min count: {min_count:,}')
        print(f'Estimated number of samples: {min_count * bins * 2:,}')
        initial_dist = initial_dist / tf.reduce_sum(initial_dist)
        print(initial_dist)
    else:
        initial_dist = None
        
    label_bin_fn = get_label_bin_fn(bins=bins,
                                    lower_var_limit=lower_var_limit,
                                    upper_var_limit=upper_var_limit,
                                    variable=variable,
                                    log_binning_base=log_binning_base,
                                    label_variable=label_variable,
                                    label_class_1=label_class_1,
                                    label_class_2=label_class_2)
        
    if from_min_count:
        return rasample_from_min_bin_count(dataset, label_bin_fn, min_bin_count=min_count, bins=bins * 2)
    else:
        dataset = dataset.rejection_resample(label_bin_fn, target_dist=dist, initial_dist=initial_dist)
        return dataset.map(take_second)
def take_first(x, y)
Expand source code
@tf.function
def take_first(x, y):
    return x
def take_second(x, y)
Expand source code
@tf.function
def take_second(x, y):
    return y
def write_new_variable(variable_value: int, variable_name: str = 'JZ_slice') ‑> Callable[[dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]], dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor]]]

Get a function that writes a new variable to a dataset. The intended use is to use this function with tf.data.Dataset.map() to add a new variable to a dataset. The value is the same for all samples.

Args

variable_value : int
The value of the new variable.
variable_name : str, optional
The name of the new variable. Defaults to 'JZ_slice'.

Returns

Callable[[ROOTVariables], ROOTVariables]
Function that writes a new variable to a dataset.
Expand source code
def write_new_variable(variable_value: int, variable_name: str = 'JZ_slice') -> Callable[[ROOTVariables], ROOTVariables]:
    """Get a function that writes a new variable to a dataset. The intended use is to use this function
    with tf.data.Dataset.map() to add a new variable to a dataset. The value is the same for all samples.

    Args:
        variable_value (int): The value of the new variable.
        variable_name (str, optional): The name of the new variable. Defaults to 'JZ_slice'.

    Returns:
        Callable[[ROOTVariables], ROOTVariables]: Function that writes a new variable to a dataset.
    """

    @tf.function
    def _write(data: ROOTVariables) -> ROOTVariables:
        new_data = data.copy()
        new_data[variable_name] = variable_value
        return new_data
    return _write