Module jidenn.evaluation.evaluation_metrics

Module containing custom metrics for evaluating models mainly used in HEP applications.

"""Module containing custom metrics for evaluating models, mainly used in HEP applications."""
from typing import List, Dict, Literal, Optional, Union, Tuple
import tensorflow as tf
# import tensorflow_probability as tfp
import numpy as np

class BinaryEfficiency(tf.keras.metrics.Metric):
    r"""Binary efficiency metric.

    It is defined as

    $$\varepsilon_i=\frac{T_i}{T_i+F_i}$$

    where $T_i$ is the number of correctly classified data of class $i$
    and $F_i$ is the number of incorrectly classified data of class $i$.

    If $i = 1$, the efficiency is called the true positive rate (TPR).

    Args:
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'efficiency'.
    """

    def __init__(self, label_id: Literal[0, 1] = 1, threshold=0.5, name='efficiency'):
        super(BinaryEfficiency, self).__init__(name=name)
        # NOTE(review): the attribute names below are reconstructed; only the
        # weight names 'tp' and 'total' survive in the extracted listing.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.total = self.add_weight(name='total', initializer='zeros')
        self.label_id = label_id
        self.threshold = threshold

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the efficiency.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted scores; values above `threshold` count as class 1.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        y_true = tf.cast(y_true, tf.bool)
        # For any label other than 1, flip both tensors so that the class of
        # interest plays the role of the "positive" class below.
        if self.label_id is not None and self.label_id != 1:
            y_true = tf.logical_not(y_true)
            y_pred = tf.logical_not(y_pred)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            # `tf.broadcast_weights` is not part of the public TensorFlow API;
            # `tf.broadcast_to` is the documented equivalent for this pattern.
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)
        # NOTE(review): the two accumulation statements are reconstructed from
        # the surviving fragment ", self.dtype)))" -- the numerator sums the
        # (weighted) hits, the denominator counts the members of the class.
        self.true_positives.assign_add(tf.reduce_sum(values))
        self.total.assign_add(tf.reduce_sum(tf.cast(y_true, self.dtype)))

    def result(self):
        """Returns the accumulated hits divided by the accumulated class total."""
        return self.true_positives / self.total

    def reset_state(self):
        """Resets both accumulators to zero."""
        self.true_positives.assign(0.)
        self.total.assign(0.)

    def get_config(self):
        """Returns the base metric config extended with `threshold` and `label_id`."""
        config = super(BinaryEfficiency, self).get_config()
        config.update({'threshold': self.threshold, 'label_id': self.label_id})
        return config

class BinaryRejection(tf.keras.metrics.Metric):
    r"""Binary rejection metric.

    With the efficiency $\varepsilon_i = T_i / (T_i + F_i)$, where $T_i$ is the
    number of correctly classified data of class $i$ and $F_i$ is the number of
    incorrectly classified data of class $i$, `result` returns
    $1 / (1 - \varepsilon_i)$.

    NOTE(review): the rendered documentation quotes
    $\varepsilon_i^{-1}=\frac{T_i+F_i}{T_i}$, which is not what `result`
    computes -- confirm which definition is intended.

    Args:
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'rejection'.
    """

    def __init__(self, label_id: Literal[0, 1] = 1, threshold=0.5, name='rejection'):
        super(BinaryRejection, self).__init__(name=name)
        # NOTE(review): the attribute names below are reconstructed; only the
        # weight names 'tp' and 'total' survive in the extracted listing.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.total = self.add_weight(name='total', initializer='zeros')
        self.label_id = label_id
        self.threshold = threshold

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the efficiency that the rejection is derived from.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted scores; values above `threshold` count as class 1.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        y_true = tf.cast(y_true, tf.bool)
        # For any label other than 1, flip both tensors so that the class of
        # interest plays the role of the "positive" class below.
        if self.label_id is not None and self.label_id != 1:
            y_true = tf.logical_not(y_true)
            y_pred = tf.logical_not(y_pred)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            # `tf.broadcast_weights` is not part of the public TensorFlow API;
            # `tf.broadcast_to` is the documented equivalent for this pattern.
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)
        # NOTE(review): the two accumulation statements are reconstructed from
        # the surviving fragment ", self.dtype)))".
        self.true_positives.assign_add(tf.reduce_sum(values))
        self.total.assign_add(tf.reduce_sum(tf.cast(y_true, self.dtype)))

    def result(self):
        """Returns `1 / (1 - tpr)`, with `tpr` the accumulated hits over the class total."""
        tpr = self.true_positives / self.total
        return 1.0 / (1 - tpr)

    def reset_state(self):
        """Resets both accumulators to zero."""
        self.true_positives.assign(0.)
        self.total.assign(0.)

    def get_config(self):
        """Returns the base metric config extended with `threshold` and `label_id`."""
        config = super(BinaryRejection, self).get_config()
        config.update({'threshold': self.threshold, 'label_id': self.label_id})
        return config

class RejectionAtEfficiency(tf.keras.metrics.SpecificityAtSensitivity):
    """Rejection at efficiency metric.

    Built on `tf.keras.metrics.SpecificityAtSensitivity`;
    in this case the threshold is chosen such that the efficiency of one class is equal to the given `efficiency`.
    `result` returns the reciprocal of the value computed by the parent metric.

    Args:
        efficiency (float, optional): The efficiency. Defaults to 0.5.
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        name (str, optional): The name of the metric. Defaults to 'rejection_at_efficiency'.
    """

    def __init__(self, efficiency: float = 0.5, label_id: Literal[0, 1] = 1, name='rejection_at_efficiency'):
        super(RejectionAtEfficiency, self).__init__(
            name=name, sensitivity=efficiency)
        self.label_id = label_id
        self.efficiency = efficiency

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the statistics of the parent metric.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted scores.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_true = tf.cast(y_true, tf.bool)
        # For any label other than 1, swap the roles of the two classes by
        # negating the labels and mirroring the scores.
        if self.label_id is not None and self.label_id != 1:
            y_true = tf.logical_not(y_true)
            y_pred = 1 - y_pred
        super(RejectionAtEfficiency, self).update_state(
            y_true, y_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the parent config extended with `efficiency` and `label_id`."""
        config = super(RejectionAtEfficiency, self).get_config()
        config.update({'efficiency': self.efficiency,
                      'label_id': self.label_id})
        return config

    def result(self):
        """Returns the reciprocal of the parent metric's result."""
        return 1 / super(RejectionAtEfficiency, self).result()

class EffectiveTaggingEfficiency(tf.keras.metrics.Metric):
    """Effective tagging efficiency metric.

    Scores are histogrammed by `|2 * score - 1|` into `bins`. For every
    populated bin, `result` combines the fraction of entries in that bin with
    the fraction `w` of wrongly tagged entries as `sum(fraction * (1 - 2 * w)**2)`.

    Args:
        bins (List[float], optional): Bin edges, each between 0 and 1.
        threshold (float, optional): Stored and returned by `get_config`. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'eff_tag_efficiency'.

    Raises:
        ValueError: If any bin edge lies outside the interval [0, 1].
    """

    def __init__(self, bins: List[float] = [0, 0.1, 0.25, 0.5, 0.625, 0.75, 0.875, 1], threshold: float = 0.5, name='eff_tag_efficiency', **kwargs):
        super(EffectiveTaggingEfficiency, self).__init__(name=name, **kwargs)
        for value in bins:
            if value < 0 or value > 1:
                raise ValueError('The bins must be between 0 and 1.')
        # Copy the edges so that the shared default list is never aliased by
        # (and cannot be mutated through) an instance.
        self.bins = list(bins)
        self.threshold = threshold
        self.bin_sums = self.add_weight(
            name='bin_sums', initializer='zeros', shape=(len(bins) - 1,))
        self.bin_counts = self.add_weight(
            name='bin_counts', initializer='zeros', shape=(len(bins) - 1,))

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], score: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the metric.

        Args:
            y_true (Union[tf.Tensor, np.ndarray]): The true labels.
            score (Union[tf.Tensor, np.ndarray]): The output of the model, i.e. number **between 0 and 1**.
            sample_weight (Union[tf.Tensor, np.ndarray], optional): Accepted for
                interface compatibility; not used by this metric.
        """
        score = tf.squeeze(score)
        y_true = tf.squeeze(y_true)
        dilution_factor = tf.abs(2 * score - 1)
        indices = tf.searchsorted(self.bins, dilution_factor)
        # NOTE(review): the tag decision uses a fixed 0.5 cut, not
        # `self.threshold` -- confirm this is intended.
        pred = tf.cast(score > 0.5, tf.bool)
        wrong_tag = tf.cast(tf.not_equal(
            pred, tf.cast(y_true, tf.bool)), tf.float32)

        bin_counts = tf.cast(tf.math.bincount(
            indices, minlength=len(self.bins)), tf.float32)
        # Index 0 is dropped so that the histogram has len(bins) - 1 entries.
        bin_counts = bin_counts[1:]
        bin_sums = tf.math.bincount(
            indices, weights=wrong_tag, minlength=len(self.bins))
        bin_sums = bin_sums[1:]

        # NOTE(review): reconstructed -- the extracted listing never stores the
        # per-batch histograms, although `result` reads these accumulators.
        self.bin_counts.assign_add(bin_counts)
        self.bin_sums.assign_add(bin_sums)

    def get_config(self):
        """Returns the base metric config extended with `bins` and `threshold`."""
        config = super(EffectiveTaggingEfficiency, self).get_config()
        config.update({'bins': self.bins, 'threshold': self.threshold})
        return config

    def result(self):
        """Returns the effective tagging efficiency over all populated bins."""
        # Empty bins are masked out to avoid a 0 / 0 wrong-tag fraction.
        mask = tf.where(self.bin_counts > 0, True, False)
        bin_counts = tf.boolean_mask(self.bin_counts, mask)
        bin_sums = tf.boolean_mask(self.bin_sums, mask)
        binned_wrong_tag_fraction = bin_sums / bin_counts
        eff = bin_counts / tf.reduce_sum(bin_counts)
        eff_tag_eff = tf.reduce_sum(
            eff * (1 - 2 * binned_wrong_tag_fraction)**2)
        return eff_tag_eff

    def reset_state(self):
        """Resets both histograms to zero."""
        # NOTE(review): body reconstructed; the extracted listing has none.
        self.bin_sums.assign(tf.zeros_like(self.bin_sums))
        self.bin_counts.assign(tf.zeros_like(self.bin_counts))

class FixedWorkingPointBase(tf.keras.metrics.Metric):
    """Base class for metrics that calculate the efficiencies and threshold at a fixed working point,
    i.e. one fixed efficiency with variable threshold.

    Args:
        working_point (float): The working point, i.e. the efficiency at which the threshold is calculated.
        num_thresholds (int): The number of thresholds to calculate for finding the threshold at the working point.
        name (str): The name of the metric.
        dtype (tf.dtypes.DType): The data type of the metric.
    """

    def __init__(self, working_point: float = 0.5,
                 num_thresholds: int = 200, name: Optional[str] = None,
                 dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(name=name, dtype=dtype)
        self.working_point = working_point
        self.num_thresholds = num_thresholds

        self.true_positives = self.add_weight(name='true_positives', initializer='zeros', shape=(num_thresholds,))
        self.total_positives = self.add_weight(name='total_positives', initializer='zeros', shape=(1,))
        self.false_negatives = self.add_weight(name='false_negatives', initializer='zeros', shape=(num_thresholds,))
        self.total_negatives = self.add_weight(name='total_negatives', initializer='zeros', shape=(1,))

        # Evenly spaced interior thresholds, bracketed by 0.0 and 1.0.
        thresholds = [
            (i + 1) * 1.0 / (num_thresholds - 1)
            for i in range(num_thresholds - 2)
        ]
        self.thresholds = [0.0] + thresholds + [1.0]

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Optional[Union[tf.Tensor, np.ndarray]] = None):
        """Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold.
        The total positives and total negatives are calculated once.

        Args:
            y_true (Union[tf.Tensor, np.ndarray]): True labels.
            y_pred (Union[tf.Tensor, np.ndarray]): Predicted scores, i.e. the output of the model.
            sample_weight (Optional[Union[tf.Tensor, np.ndarray]], optional): Sample weights. Defaults to None.
        """
        # Comparing against the whole threshold list adds a threshold axis.
        y_pred = tf.cast(tf.expand_dims(y_pred, axis=1) > self.thresholds, tf.bool)
        y_true = tf.expand_dims(y_true, axis=1)
        positive_equals = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        # Despite the attribute name `false_negatives`, this counts samples
        # where label and prediction are both False.
        negative_equals = tf.logical_and(tf.equal(y_true, False), tf.equal(y_pred, False))
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, tf.float32)
            sample_weight = tf.expand_dims(sample_weight, axis=1)
            positive_equals = tf.multiply(tf.cast(positive_equals, tf.float32), sample_weight)
            negative_equals = tf.multiply(tf.cast(negative_equals, tf.float32), sample_weight)

        self.true_positives.assign_add(tf.reduce_sum(tf.cast(positive_equals, tf.float32), axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(tf.cast(negative_equals, tf.float32), axis=0))

        positives = tf.cast(tf.equal(y_true, True), tf.float32)
        positives = tf.multiply(positives, sample_weight) if sample_weight is not None else positives
        negatives = tf.cast(tf.equal(y_true, False), tf.float32)
        negatives = tf.multiply(negatives, sample_weight) if sample_weight is not None else negatives
        self.total_positives.assign_add(tf.reduce_sum(positives, axis=0))
        self.total_negatives.assign_add(tf.reduce_sum(negatives, axis=0))

    def result(self) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Calculates the efficiencies and threshold at the working point.
        Both the fixed and variable efficiencies are calculated and returned
        to allow a check of the working point.

        Returns:
            Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: The efficiency of the positive class, the efficiency
                of the negative class and the threshold, all taken at the threshold whose positive-class
                efficiency is closest to the working point.
        """
        efficiency_positives = self.true_positives / self.total_positives
        efficiency_negatives = self.false_negatives / self.total_negatives

        result_index = self._find_index_of_threshold(efficiency_positives, self.working_point)
        positive_at_wp = tf.gather(efficiency_positives, result_index)
        negative_at_wp = tf.gather(efficiency_negatives, result_index)
        threshold_at_wp = tf.gather(self.thresholds, result_index)

        return positive_at_wp, negative_at_wp, threshold_at_wp

    def _find_index_of_threshold(self, efficiencies: tf.Tensor, working_point: float) -> tf.Tensor:
        """Finds the index of the threshold whose efficiency is closest to the working point."""
        result_index = tf.math.squared_difference(efficiencies, working_point)
        result_index = tf.argmin(result_index, axis=0)
        return result_index

    def reset_state(self):
        """Resets the confusion matrix statistics."""
        # NOTE(review): body reconstructed; the extracted listing has only the docstring.
        self.true_positives.assign(tf.zeros_like(self.true_positives))
        self.total_positives.assign(tf.zeros_like(self.total_positives))
        self.false_negatives.assign(tf.zeros_like(self.false_negatives))
        self.total_negatives.assign(tf.zeros_like(self.total_negatives))

class EfficiencyAtFixedWorkingPoint(FixedWorkingPointBase):
    """Calculates the efficiency at a fixed working point. The working point is defined by the user.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        returned_label_id (Literal[0, 1], optional): The label id whose efficiency is returned. Defaults to 0.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        name (Optional[str], optional): The name of the metric. Defaults to 'efficiency_at_fixed_wp'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0,
                 num_thresholds: int = 200, name: Optional[str] = 'efficiency_at_fixed_wp', dtype: Optional[tf.dtypes.DType] = None):

        super().__init__(working_point=working_point, num_thresholds=num_thresholds, name=name, dtype=dtype)
        self.fixed_label_id = fixed_label_id
        self.returned_label_id = returned_label_id

    def result(self) -> tf.Tensor:
        """Calculates the efficiency at the working point.

        Raises:
            ValueError: If the fixed_label_id or the returned_label_id is not 0 or 1.

        Returns:
            tf.Tensor: The efficiency at the working point.
        """
        efficiency_positives = self.true_positives / self.total_positives
        efficiency_negatives = self.false_negatives / self.total_negatives

        if self.fixed_label_id == 1:
            fixed_efficiency = efficiency_positives
        elif self.fixed_label_id == 0:
            fixed_efficiency = efficiency_negatives
        else:
            raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

        closest_index = self._find_index_of_threshold(fixed_efficiency, self.working_point)

        if self.returned_label_id == 1:
            efficiency_at_wp = tf.gather(efficiency_positives, closest_index)
        elif self.returned_label_id == 0:
            efficiency_at_wp = tf.gather(efficiency_negatives, closest_index)
        else:
            # Without this branch an unsupported id would surface as an
            # unbound-variable error on the return below.
            raise ValueError(f'returned_label_id must be 0 or 1, but is {self.returned_label_id}')

        return efficiency_at_wp

class RejectionAtFixedWorkingPoint(EfficiencyAtFixedWorkingPoint):
    """Calculates the rejection at a fixed working point. The working point is defined by the user.

    The rejection is `1 / (1 - efficiency)`, with the efficiency taken from
    `EfficiencyAtFixedWorkingPoint.result`.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        returned_label_id (Literal[0, 1], optional): The label id whose efficiency enters the rejection. Defaults to 0.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        name (Optional[str], optional): The name of the metric. Defaults to 'rejection_at_fixed_wp'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0,
                 num_thresholds: int = 200, name: Optional[str] = 'rejection_at_fixed_wp', dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(working_point=working_point, fixed_label_id=fixed_label_id, returned_label_id=returned_label_id,
                         num_thresholds=num_thresholds, name=name, dtype=dtype)

    def result(self) -> tf.Tensor:
        """Calculates the rejection at the working point.

        Returns:
            tf.Tensor: The rejection at the working point.
        """
        return 1.0 / (1 - super().result())

class ThresholdAtFixedWorkingPoint(FixedWorkingPointBase):
    """Calculates the threshold at a fixed working point. The working point is defined by the user.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        name (Optional[str], optional): The name of the metric. Defaults to 'threshold_at_fixed_efficiency'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, num_thresholds: int = 200, fixed_label_id: Literal[0, 1] = 1, name: Optional[str] = 'threshold_at_fixed_efficiency', dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(working_point=working_point, num_thresholds=num_thresholds, name=name, dtype=dtype)
        self.fixed_label_id = fixed_label_id

    def result(self) -> tf.Tensor:
        """Calculates the threshold at the working point.

        Raises:
            ValueError: If the fixed_label_id is not 0 or 1.

        Returns:
            tf.Tensor: The threshold at the working point.
        """
        if self.fixed_label_id == 1:
            efficiencies = self.true_positives / self.total_positives
        elif self.fixed_label_id == 0:
            efficiencies = self.false_negatives / self.total_negatives
        else:
            raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

        closest_index = self._find_index_of_threshold(efficiencies, self.working_point)

        return tf.gather(self.thresholds, closest_index)

def get_metrics(threshold: float = 0.5) -> List[tf.keras.metrics.Metric]:
    """Returns a list of metrics.

    NOTE(review): the extracted listing lost the first line of every
    constructor call. The classes below are inferred from the indentation and
    keyword arguments of the surviving continuation lines; the `name=` values
    marked as placeholders are NOT the original names and must be confirmed
    against the real source. Entries that fit on a single line may be missing.

    Args:
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.

    Returns:
        List[tf.keras.metrics.Metric]: The list of selected metrics.
    """
    metrics = [
        tf.keras.metrics.BinaryAccuracy(
            name='binary_accuracy', threshold=threshold),
        # Placeholder names from here on (see the note in the docstring).
        BinaryEfficiency(name='efficiency_label_0',
                         label_id=0, threshold=threshold),
        BinaryEfficiency(name='efficiency_label_1',
                         label_id=1, threshold=threshold),
        BinaryRejection(name='rejection_label_0',
                        label_id=0, threshold=threshold),
        BinaryRejection(name='rejection_label_1',
                        label_id=1, threshold=threshold),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_0_at_label_1_50wp',
                                      fixed_label_id=1, working_point=0.5, returned_label_id=0),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_1_at_label_1_50wp',
                                      fixed_label_id=1, working_point=0.5, returned_label_id=1),
        RejectionAtFixedWorkingPoint(name='rejection_label_0_at_label_1_50wp',
                                     fixed_label_id=1, working_point=0.5, returned_label_id=0),
        ThresholdAtFixedWorkingPoint(name='threshold_at_label_1_50wp',
                                     fixed_label_id=1, working_point=0.5),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_0_at_label_1_80wp',
                                      fixed_label_id=1, working_point=0.8, returned_label_id=0),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_1_at_label_1_80wp',
                                      fixed_label_id=1, working_point=0.8, returned_label_id=1),
        RejectionAtFixedWorkingPoint(name='rejection_label_0_at_label_1_80wp',
                                     fixed_label_id=1, working_point=0.8, returned_label_id=0),
        ThresholdAtFixedWorkingPoint(name='threshold_at_label_1_80wp',
                                     fixed_label_id=1, working_point=0.8),
        EffectiveTaggingEfficiency(
            name='effective_tagging_efficiency', threshold=threshold),
    ]
    return metrics

def calculate_metrics(y_true: np.ndarray, score: np.ndarray, threshold: float = 0.5, weights: Optional[np.ndarray] = None) -> Dict[str, float]:
    """Calculates the metrics.

    Args:
        y_true (np.ndarray): The true labels.
        score (np.ndarray): The output scores of the model.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        weights (np.ndarray, optional): The sample weights. Defaults to None.

    Returns:
        Dict[str, float]: The dictionary of metric names and values.
    """
    metrics = get_metrics(threshold)
    results = {}
    for metric in metrics:
        metric.update_state(y_true, score, sample_weight=weights)
        result = metric.result().numpy()
        # Keyed by the metric's name; the extracted listing lost the subscript.
        results[metric.name] = result
    return results


def calculate_metrics(y_true: numpy.ndarray, score: numpy.ndarray, threshold: float = 0.5, weights: Optional[numpy.ndarray] = None) ‑> Dict[str, float]

Calculates the metrics.

Args

y_true : np.ndarray
The true labels.
score : np.ndarray
The output scores of the model.
threshold : float, optional
The threshold for the prediction. Defaults to 0.5.
weights : np.ndarray, optional
The sample weights. Defaults to None.

Returns

Dict[str, float]
The dictionary of metric names and values.
def calculate_metrics(y_true: np.ndarray, score: np.ndarray, threshold: float = 0.5, weights: Optional[np.ndarray] = None) -> Dict[str, float]:
    """Calculates the metrics.

    Args:
        y_true (np.ndarray): The true labels.
        score (np.ndarray): The output scores of the model.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        weights (np.ndarray, optional): The sample weights. Defaults to None.

    Returns:
        Dict[str, float]: The dictionary of metric names and values.
    """
    metrics = get_metrics(threshold)
    results = {}
    for metric in metrics:
        metric.update_state(y_true, score, sample_weight=weights)
        result = metric.result().numpy()
        # Keyed by the metric's name; the extracted listing lost the subscript.
        results[metric.name] = result
    return results
def get_metrics(threshold: float = 0.5) ‑> List[keras.metrics.base_metric.Metric]

Returns a list of metrics.

Args

threshold : float, optional
The threshold for the prediction. Defaults to 0.5.

Returns

List[tf.keras.metrics.Metric]
The list of selected metrics.
def get_metrics(threshold: float = 0.5) -> List[tf.keras.metrics.Metric]:
    """Returns a list of metrics.

    NOTE(review): the extracted listing lost the first line of every
    constructor call. The classes below are inferred from the indentation and
    keyword arguments of the surviving continuation lines; the `name=` values
    marked as placeholders are NOT the original names and must be confirmed
    against the real source. Entries that fit on a single line may be missing.

    Args:
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.

    Returns:
        List[tf.keras.metrics.Metric]: The list of selected metrics.
    """
    metrics = [
        tf.keras.metrics.BinaryAccuracy(
            name='binary_accuracy', threshold=threshold),
        # Placeholder names from here on (see the note in the docstring).
        BinaryEfficiency(name='efficiency_label_0',
                         label_id=0, threshold=threshold),
        BinaryEfficiency(name='efficiency_label_1',
                         label_id=1, threshold=threshold),
        BinaryRejection(name='rejection_label_0',
                        label_id=0, threshold=threshold),
        BinaryRejection(name='rejection_label_1',
                        label_id=1, threshold=threshold),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_0_at_label_1_50wp',
                                      fixed_label_id=1, working_point=0.5, returned_label_id=0),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_1_at_label_1_50wp',
                                      fixed_label_id=1, working_point=0.5, returned_label_id=1),
        RejectionAtFixedWorkingPoint(name='rejection_label_0_at_label_1_50wp',
                                     fixed_label_id=1, working_point=0.5, returned_label_id=0),
        ThresholdAtFixedWorkingPoint(name='threshold_at_label_1_50wp',
                                     fixed_label_id=1, working_point=0.5),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_0_at_label_1_80wp',
                                      fixed_label_id=1, working_point=0.8, returned_label_id=0),
        EfficiencyAtFixedWorkingPoint(name='efficiency_label_1_at_label_1_80wp',
                                      fixed_label_id=1, working_point=0.8, returned_label_id=1),
        RejectionAtFixedWorkingPoint(name='rejection_label_0_at_label_1_80wp',
                                     fixed_label_id=1, working_point=0.8, returned_label_id=0),
        ThresholdAtFixedWorkingPoint(name='threshold_at_label_1_80wp',
                                     fixed_label_id=1, working_point=0.8),
        EffectiveTaggingEfficiency(
            name='effective_tagging_efficiency', threshold=threshold),
    ]
    return metrics


class BinaryEfficiency (label_id: Literal[0, 1] = 1, threshold=0.5, name='efficiency')

Binary Efficiency metric. It is defined as $$\varepsilon_i=\frac{T_i}{T_i+F_i}$$ where $T_i$ is the number of correctly classified data of class $i$ and $F_i$ is the number of incorrectly classified data of class $i$.

If $i = 1$, the efficiency is called the true positive rate (TPR).


label_id : int, optional
The label id for which the efficiency is calculated. Defaults to 1.
threshold : float, optional
The threshold for the prediction. Defaults to 0.5.
name : str, optional
The name of the metric. Defaults to 'efficiency'.
class BinaryEfficiency(tf.keras.metrics.Metric):
    r"""Binary efficiency metric.

    It is defined as

    $$\varepsilon_i=\frac{T_i}{T_i+F_i}$$

    where $T_i$ is the number of correctly classified data of class $i$
    and $F_i$ is the number of incorrectly classified data of class $i$.

    If $i = 1$, the efficiency is called the true positive rate (TPR).

    Args:
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'efficiency'.
    """

    def __init__(self, label_id: Literal[0, 1] = 1, threshold=0.5, name='efficiency'):
        super(BinaryEfficiency, self).__init__(name=name)
        # NOTE(review): the attribute names below are reconstructed; only the
        # weight names 'tp' and 'total' survive in the extracted listing.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.total = self.add_weight(name='total', initializer='zeros')
        self.label_id = label_id
        self.threshold = threshold

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the efficiency.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted scores; values above `threshold` count as class 1.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        y_true = tf.cast(y_true, tf.bool)
        # For any label other than 1, flip both tensors so that the class of
        # interest plays the role of the "positive" class below.
        if self.label_id is not None and self.label_id != 1:
            y_true = tf.logical_not(y_true)
            y_pred = tf.logical_not(y_pred)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            # `tf.broadcast_weights` is not part of the public TensorFlow API;
            # `tf.broadcast_to` is the documented equivalent for this pattern.
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)
        # NOTE(review): the two accumulation statements are reconstructed from
        # the surviving fragment ", self.dtype)))" -- the numerator sums the
        # (weighted) hits, the denominator counts the members of the class.
        self.true_positives.assign_add(tf.reduce_sum(values))
        self.total.assign_add(tf.reduce_sum(tf.cast(y_true, self.dtype)))

    def result(self):
        """Returns the accumulated hits divided by the accumulated class total."""
        return self.true_positives / self.total

    def reset_state(self):
        """Resets both accumulators to zero."""
        self.true_positives.assign(0.)
        self.total.assign(0.)

    def get_config(self):
        """Returns the base metric config extended with `threshold` and `label_id`."""
        config = super(BinaryEfficiency, self).get_config()
        config.update({'threshold': self.threshold, 'label_id': self.label_id})
        return config


  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def get_config(self)

Returns the serializable config of the metric.

def get_config(self):
    """Return the parent config with this metric's `threshold` and `label_id` added."""
    merged = super(BinaryEfficiency, self).get_config()
    merged['threshold'] = self.threshold
    merged['label_id'] = self.label_id
    return merged
def reset_state(self)

Resets all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

def reset_state(self):
    """Resets both accumulators of the metric to zero."""
    # NOTE(review): body reconstructed -- the extracted listing has none. The
    # attribute names are assumed to be the ones holding the 'tp' and 'total'
    # weights created in __init__; confirm against the real source.
    self.true_positives.assign(0.)
    self.total.assign(0.)
def result(self)

Computes and returns the scalar metric value tensor or a dict of scalars.

Result computation is an idempotent operation that simply calculates the metric value using the state variables.


A scalar tensor, or a dictionary of scalar tensors.

def result(self):
    """Returns the accumulated hits divided by the accumulated class total."""
    # NOTE(review): both operands were lost in extraction; they are assumed to
    # be the 'tp' and 'total' weights created in __init__.
    return self.true_positives / self.total
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray] = None)

Accumulates the efficiency.


y_true : tf.Tensor
The true labels.
y_pred : tf.Tensor
The predicted labels.
sample_weight : tf.Tensor, optional
The sample weights. Defaults to None.
def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
    """Accumulates the efficiency.

    Args:
        y_true (tf.Tensor): The true labels.
        y_pred (tf.Tensor): The predicted scores; values above `threshold` count as class 1.
        sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
    """
    y_pred = tf.cast(y_pred > self.threshold, tf.bool)
    y_true = tf.cast(y_true, tf.bool)
    # For any label other than 1, flip both tensors so that the class of
    # interest plays the role of the "positive" class below.
    if self.label_id is not None and self.label_id != 1:
        y_true = tf.logical_not(y_true)
        y_pred = tf.logical_not(y_pred)

    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self.dtype)
        # `tf.broadcast_weights` is not part of the public TensorFlow API;
        # `tf.broadcast_to` is the documented equivalent for this pattern.
        sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
        values = tf.multiply(values, sample_weight)
    # NOTE(review): the two accumulation statements and their attribute names
    # are reconstructed from the surviving fragment ", self.dtype)))".
    self.true_positives.assign_add(tf.reduce_sum(values))
    self.total.assign_add(tf.reduce_sum(tf.cast(y_true, self.dtype)))
class BinaryRejection (label_id: Literal[0, 1] = 1, threshold=0.5, name='rejection')

Binary Rejection metric. It is defined as $$\varepsilon_i^{-1}=\frac{T_i+F_i}{T_i}$$ where $T_i$ is the number of correctly classified data of class $i$ and $F_i$ is the number of incorrectly classified data of class $i$.


label_id : int, optional
The label id for which the efficiency is calculated. Defaults to 1.
threshold : float, optional
The threshold for the prediction. Defaults to 0.5.
name : str, optional
The name of the metric. Defaults to 'rejection'.
class BinaryRejection(tf.keras.metrics.Metric):
    r"""Binary rejection metric.

    With the efficiency $\varepsilon_i = T_i / (T_i + F_i)$, where $T_i$ is the
    number of correctly classified data of class $i$ and $F_i$ is the number of
    incorrectly classified data of class $i$, `result` returns
    $1 / (1 - \varepsilon_i)$.

    NOTE(review): the rendered documentation quotes
    $\varepsilon_i^{-1}=\frac{T_i+F_i}{T_i}$, which is not what `result`
    computes -- confirm which definition is intended.

    Args:
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        threshold (float, optional): The threshold for the prediction. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'rejection'.
    """

    def __init__(self, label_id: Literal[0, 1] = 1, threshold=0.5, name='rejection'):
        super(BinaryRejection, self).__init__(name=name)
        # NOTE(review): the attribute names below are reconstructed; only the
        # weight names 'tp' and 'total' survive in the extracted listing.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.total = self.add_weight(name='total', initializer='zeros')
        self.label_id = label_id
        self.threshold = threshold

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the efficiency that the rejection is derived from.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted scores; values above `threshold` count as class 1.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        y_true = tf.cast(y_true, tf.bool)
        # For any label other than 1, flip both tensors so that the class of
        # interest plays the role of the "positive" class below.
        if self.label_id is not None and self.label_id != 1:
            y_true = tf.logical_not(y_true)
            y_pred = tf.logical_not(y_pred)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            # `tf.broadcast_weights` is not part of the public TensorFlow API;
            # `tf.broadcast_to` is the documented equivalent for this pattern.
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)
        # NOTE(review): the two accumulation statements are reconstructed from
        # the surviving fragment ", self.dtype)))".
        self.true_positives.assign_add(tf.reduce_sum(values))
        self.total.assign_add(tf.reduce_sum(tf.cast(y_true, self.dtype)))

    def result(self):
        """Returns `1 / (1 - tpr)`, with `tpr` the accumulated hits over the class total."""
        tpr = self.true_positives / self.total
        return 1.0 / (1 - tpr)

    def reset_state(self):
        """Resets both accumulators to zero."""
        self.true_positives.assign(0.)
        self.total.assign(0.)

    def get_config(self):
        """Returns the base metric config extended with `threshold` and `label_id`."""
        config = super(BinaryRejection, self).get_config()
        config.update({'threshold': self.threshold, 'label_id': self.label_id})
        return config


  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def get_config(self)

Returns the serializable config of the metric.

def get_config(self):
    """Returns the parent config extended by `threshold` and `label_id`."""
    base_config = super(BinaryRejection, self).get_config()
    base_config.update(threshold=self.threshold, label_id=self.label_id)
    return base_config
def reset_state(self)

Resets all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

def reset_state(self):
    """Resets both accumulators to zero."""
    self.tp.assign(0.)
    self.total.assign(0.)
def result(self)

Computes and returns the scalar metric value tensor or a dict of scalars.

Result computation is an idempotent operation that simply calculates the metric value using the state variables.


A scalar tensor, or a dictionary of scalar tensors.

def result(self):
    """Returns `1 / (1 - tpr)`, where `tpr` is the ratio of the two accumulators."""
    tpr = self.tp / self.total
    return 1.0 / (1 - tpr)
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray] = None)

Accumulates the efficiency.


y_true : tf.Tensor
The true labels.
y_pred : tf.Tensor
The predicted labels.
sample_weight : tf.Tensor, optional
The sample weights. Defaults to None.
def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
    """Accumulates the statistics from which the rejection is computed.

    Args:
        y_true (tf.Tensor): The true labels.
        y_pred (tf.Tensor): The predicted scores; they are binarised with `self.threshold`.
        sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
    """
    y_pred = tf.cast(y_pred > self.threshold, tf.bool)
    y_true = tf.cast(y_true, tf.bool)
    if self.label_id is not None and self.label_id != 1:
        # Swap the roles of the two classes so that `label_id` is treated as the positive one.
        y_true = tf.logical_not(y_true)
        y_pred = tf.logical_not(y_pred)

    values = tf.cast(tf.logical_and(y_true, y_pred), self.dtype)
    totals = tf.cast(y_true, self.dtype)
    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self.dtype)
        sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
        values = tf.multiply(values, sample_weight)
        # Weight the denominator as well, otherwise the ratio mixes weighted and raw counts.
        totals = tf.multiply(totals, sample_weight)
    self.tp.assign_add(tf.reduce_sum(values))
    self.total.assign_add(tf.reduce_sum(totals))
class EffectiveTaggingEfficiency (bins: List[float] = [0, 0.1, 0.25, 0.5, 0.625, 0.75, 0.875, 1], threshold: float = 0.5, name='eff_tag_efficiency', **kwargs)

Encapsulates metric logic and state.


(Optional) string name of the metric instance.
(Optional) data type of the metric result.
Additional layer keywords arguments.

Standalone usage:

m = SomeMetric(...)
for input in ...:
  m.update_state(input)
print('Final result: ', m.result().numpy())

Usage with compile() API:

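model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)

model.fit(dataset, epochs=10)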

To be implemented by subclasses: * __init__(): All state variables should be created in this method by calling self.add_weight() like: self.var = self.add_weight(...) * update_state(): Has all updates to the state variables like: self.var.assign_add(…). * result(): Computes and returns a scalar value or a dict of scalar values for the metric from the state variables.

Example subclass implementation:

class BinaryTruePositives(tf.keras.metrics.Metric):
  """Example metric that sums the (optionally weighted) true positives."""

  def __init__(self, name='binary_true_positives', **kwargs):
    super(BinaryTruePositives, self).__init__(name=name, **kwargs)
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)

    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = tf.cast(sample_weight, self.dtype)
      sample_weight = tf.broadcast_to(sample_weight, values.shape)
      values = tf.multiply(values, sample_weight)
    # Without this update the state variable never leaves zero.
    self.true_positives.assign_add(tf.reduce_sum(values))

  def result(self):
    return self.true_positives
class EffectiveTaggingEfficiency(tf.keras.metrics.Metric):
    """Effective tagging efficiency accumulated in bins of `|2 * score - 1|`.

    `update_state` histograms the samples and the wrongly tagged samples per bin;
    `result` combines the per-bin population fractions with the per-bin wrong-tag fractions.

    Args:
        bins (List[float], optional): Bin edges of `|2 * score - 1|`. Every edge must lie in [0, 1].
        threshold (float, optional): Stored and returned by `get_config`. The tag decision
            in `update_state` is taken at a fixed score of 0.5. Defaults to 0.5.
        name (str, optional): The name of the metric. Defaults to 'eff_tag_efficiency'.
        **kwargs: Forwarded to `tf.keras.metrics.Metric`.

    Raises:
        ValueError: If any bin edge lies outside [0, 1].
    """

    def __init__(self, bins: List[float] = [0, 0.1, 0.25, 0.5, 0.625, 0.75, 0.875, 1], threshold: float = 0.5, name='eff_tag_efficiency', **kwargs):
        super(EffectiveTaggingEfficiency, self).__init__(name=name, **kwargs)
        for value in bins:
            if value < 0 or value > 1:
                raise ValueError('The bins must be between 0 and 1.')
        # Copy so that the shared default list (or the caller's list) is never aliased.
        self.bins = list(bins)
        self.threshold = threshold
        self.bin_sums = self.add_weight(
            name='bin_sums', initializer='zeros', shape=(len(bins) - 1,))
        self.bin_counts = self.add_weight(
            name='bin_counts', initializer='zeros', shape=(len(bins) - 1,))

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], score: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the metric.

        Args:
            y_true (Union[tf.Tensor, np.ndarray]): The true labels.
            score (Union[tf.Tensor, np.ndarray]): The output of the model, i.e. number **between 0 and 1**.
            sample_weight (Union[tf.Tensor, np.ndarray], optional): Accepted for API compatibility; not used.
        """
        score = tf.squeeze(score)
        y_true = tf.squeeze(y_true)
        dilution_factor = tf.abs(2 * score - 1)
        indices = tf.searchsorted(self.bins, dilution_factor)
        pred = tf.cast(score > 0.5, tf.bool)
        wrong_tag = tf.cast(tf.not_equal(
            pred, tf.cast(y_true, tf.bool)), tf.float32)

        bin_counts = tf.cast(tf.math.bincount(
            indices, minlength=len(self.bins)), tf.float32)
        # Index 0 holds the values that are not above the first edge; it has no matching bin.
        bin_counts = bin_counts[1:]
        bin_sums = tf.math.bincount(
            indices, weights=wrong_tag, minlength=len(self.bins))
        bin_sums = bin_sums[1:]

        self.bin_sums.assign_add(bin_sums)
        self.bin_counts.assign_add(bin_counts)

    def get_config(self):
        """Returns the parent config extended by `bins` and `threshold`."""
        config = super(EffectiveTaggingEfficiency, self).get_config()
        config.update({'bins': self.bins, 'threshold': self.threshold})
        return config

    def result(self):
        """Returns the sum over populated bins of `fraction * (1 - 2 * wrong_tag_fraction)**2`."""
        # Empty bins are dropped so that they do not produce 0/0.
        mask = self.bin_counts > 0
        bin_counts = tf.boolean_mask(self.bin_counts, mask)
        bin_sums = tf.boolean_mask(self.bin_sums, mask)
        binned_wrong_tag_fraction = bin_sums / bin_counts
        eff = bin_counts / tf.reduce_sum(bin_counts)
        eff_tag_eff = tf.reduce_sum(
            eff * (1 - 2 * binned_wrong_tag_fraction)**2)
        return eff_tag_eff

    def reset_state(self):
        """Resets both per-bin accumulators to zero."""
        self.bin_sums.assign(tf.zeros_like(self.bin_sums))
        self.bin_counts.assign(tf.zeros_like(self.bin_counts))


  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def get_config(self)

Returns the serializable config of the metric.

def get_config(self):
    """Returns the parent config extended by `bins` and `threshold`."""
    base_config = super(EffectiveTaggingEfficiency, self).get_config()
    base_config.update(bins=self.bins, threshold=self.threshold)
    return base_config
def reset_state(self)

Resets all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

def reset_state(self):
    """Resets both per-bin accumulators to zero."""
    self.bin_sums.assign(tf.zeros_like(self.bin_sums))
    self.bin_counts.assign(tf.zeros_like(self.bin_counts))
def result(self)

Computes and returns the scalar metric value tensor or a dict of scalars.

Result computation is an idempotent operation that simply calculates the metric value using the state variables.


A scalar tensor, or a dictionary of scalar tensors.

def result(self):
    """Returns the sum over populated bins of `fraction * (1 - 2 * wrong_tag_fraction)**2`."""
    # Only bins that received at least one sample take part in the sum.
    populated = tf.where(self.bin_counts > 0, True, False)
    counts = tf.boolean_mask(self.bin_counts, populated)
    wrong_tag_sums = tf.boolean_mask(self.bin_sums, populated)

    wrong_tag_fraction = wrong_tag_sums / counts
    bin_fraction = counts / tf.reduce_sum(counts)
    dilution_squared = (1 - 2 * wrong_tag_fraction)**2
    return tf.reduce_sum(bin_fraction * dilution_squared)
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], score: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray] = None)

Accumulates the metric.


y_true : Union[tf.Tensor, np.ndarray]
The true labels.
score : Union[tf.Tensor, np.ndarray]
The output of the model, i.e. number between 0 and 1.
def update_state(self, y_true: Union[tf.Tensor, np.ndarray], score: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
    """Accumulates the metric.

    Args:
        y_true (Union[tf.Tensor, np.ndarray]): The true labels.
        score (Union[tf.Tensor, np.ndarray]): The output of the model, i.e. number **between 0 and 1**.
        sample_weight (Union[tf.Tensor, np.ndarray], optional): Accepted for API compatibility; not used.
    """
    score = tf.squeeze(score)
    y_true = tf.squeeze(y_true)
    dilution_factor = tf.abs(2 * score - 1)
    indices = tf.searchsorted(self.bins, dilution_factor)
    pred = tf.cast(score > 0.5, tf.bool)
    wrong_tag = tf.cast(tf.not_equal(
        pred, tf.cast(y_true, tf.bool)), tf.float32)

    bin_counts = tf.cast(tf.math.bincount(
        indices, minlength=len(self.bins)), tf.float32)
    # Index 0 holds the values that are not above the first edge; it has no matching bin.
    bin_counts = bin_counts[1:]
    bin_sums = tf.math.bincount(
        indices, weights=wrong_tag, minlength=len(self.bins))
    bin_sums = bin_sums[1:]

    self.bin_sums.assign_add(bin_sums)
    self.bin_counts.assign_add(bin_counts)
class EfficiencyAtFixedWorkingPoint (working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0, num_thresholds: int = 200, name: Optional[str] = 'efficiency_at_fixed_wp', dtype: Optional[tensorflow.python.framework.dtypes.DType] = None)

Calculates the efficiency at a fixed working point. The working point is defined by the user.


working_point : float
The working point. Defaults to 0.5.
fixed_label_id : Literal[0, 1], optional
The label id whose efficiency is fixed. Defaults to 1.
returned_label_id : Literal[0, 1], optional
The label id whose efficiency is returned. Defaults to 0.
num_thresholds : int, optional
The number of thresholds to use for the estimation. Defaults to 200.
name : Optional[str], optional
The name of the metric. Defaults to 'efficiency_at_fixed_wp'.
dtype : Optional[tf.dtypes.DType], optional
The data type of the metric. Defaults to None.
class EfficiencyAtFixedWorkingPoint(FixedWorkingPointBase):
    """Calculates the efficiency at a fixed working point. The working point is defined by the user.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        returned_label_id (Literal[0, 1], optional): The label id whose efficiency is returned. Defaults to 0.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        name (Optional[str], optional): The name of the metric. Defaults to 'efficiency_at_fixed_wp'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0,
                 num_thresholds: int = 200, name: Optional[str] = 'efficiency_at_fixed_wp', dtype: Optional[tf.dtypes.DType] = None):

        super().__init__(working_point=working_point, num_thresholds=num_thresholds, name=name, dtype=dtype)
        self.fixed_label_id = fixed_label_id
        self.returned_label_id = returned_label_id

    def result(self) -> tf.Tensor:
        """Calculates the efficiency at the working point.

        Raises:
            ValueError: If `fixed_label_id` or `returned_label_id` is not 0 or 1.

        Returns:
            tf.Tensor: The efficiency at the working point.
        """
        efficiency_positives = self.true_positives / self.total_positives
        efficiency_negatives = self.false_negatives / self.total_negatives

        if self.fixed_label_id == 1:
            fixed_efficiency = efficiency_positives
        elif self.fixed_label_id == 0:
            fixed_efficiency = efficiency_negatives
        else:
            raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

        # Index of the threshold whose fixed-class efficiency is closest to the working point.
        closest_index = self._find_index_of_threshold(fixed_efficiency, self.working_point)

        if self.returned_label_id == 1:
            efficiency_at_wp = tf.gather(efficiency_positives, closest_index)
        elif self.returned_label_id == 0:
            efficiency_at_wp = tf.gather(efficiency_negatives, closest_index)
        else:
            # Without this branch an invalid id surfaced as an unbound local variable.
            raise ValueError(f'returned_label_id must be 0 or 1, but is {self.returned_label_id}')

        return efficiency_at_wp


  • FixedWorkingPointBase
  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector



def reset_state(self)

Inherited from: FixedWorkingPointBase.reset_state

Resets the confusion matrix statistics.

def result(self) ‑> tensorflow.python.framework.ops.Tensor

Calculates the efficiency at the working point.


If the fixed_label_id is not 0 or 1.


The efficiency at the working point.
def result(self) -> tf.Tensor:
    """Calculates the efficiency at the working point.

    Raises:
        ValueError: If `fixed_label_id` or `returned_label_id` is not 0 or 1.

    Returns:
        tf.Tensor: The efficiency at the working point.
    """
    efficiency_positives = self.true_positives / self.total_positives
    efficiency_negatives = self.false_negatives / self.total_negatives

    if self.fixed_label_id == 1:
        fixed_efficiency = efficiency_positives
    elif self.fixed_label_id == 0:
        fixed_efficiency = efficiency_negatives
    else:
        raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

    # Index of the threshold whose fixed-class efficiency is closest to the working point.
    closest_index = self._find_index_of_threshold(fixed_efficiency, self.working_point)

    if self.returned_label_id == 1:
        efficiency_at_wp = tf.gather(efficiency_positives, closest_index)
    elif self.returned_label_id == 0:
        efficiency_at_wp = tf.gather(efficiency_negatives, closest_index)
    else:
        # Without this branch an invalid id surfaced as an unbound local variable.
        raise ValueError(f'returned_label_id must be 0 or 1, but is {self.returned_label_id}')

    return efficiency_at_wp
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray, ForwardRef(None)] = None)

Inherited from: FixedWorkingPointBase.update_state

Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold. The total positives and total …

class FixedWorkingPointBase (working_point: float = 0.5, num_thresholds: int = 200, name: Optional[str] = None, dtype: Optional[tensorflow.python.framework.dtypes.DType] = None)

Base class for metrics that calculate the efficiencies and threshold at a fixed working point, i.e. one fixed efficiency with a variable threshold.


working_point : float
The working point, i.e. the efficiency at which the threshold is calculated.
num_thresholds : int
The number of thresholds to calculate for finding the threshold at the working point.
name : str
The name of the metric.
dtype : tf.dtypes.DType
The data type of the metric.
class FixedWorkingPointBase(tf.keras.metrics.Metric):
    """Base class for metrics that calculate the efficiencies and threshold at a fixed working point,
    i.e. one fixed efficiency with a variable threshold.

    Args:
        working_point (float): The working point, i.e. the efficiency at which the threshold is calculated.
        num_thresholds (int): The number of thresholds to calculate for finding the threshold at the working point.
        name (str): The name of the metric.
        dtype (tf.dtypes.DType): The data type of the metric.
    """

    def __init__(self, working_point: float = 0.5,
                 num_thresholds: int = 200, name: Optional[str] = None,
                 dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(name=name, dtype=dtype)
        self.working_point = working_point
        self.num_thresholds = num_thresholds

        # Per-threshold counts of positives predicted positive.
        self.true_positives = self.add_weight(name='true_positives', initializer='zeros', shape=(num_thresholds,))
        self.total_positives = self.add_weight(name='total_positives', initializer='zeros', shape=(1,))
        # Despite its name this holds, per threshold, the negatives predicted negative.
        self.false_negatives = self.add_weight(name='false_negatives', initializer='zeros', shape=(num_thresholds,))
        self.total_negatives = self.add_weight(name='total_negatives', initializer='zeros', shape=(1,))

        # Interior thresholds are evenly spaced; the end points 0 and 1 are added explicitly.
        thresholds = [
            (i + 1) * 1.0 / (num_thresholds - 1)
            for i in range(num_thresholds - 2)
        ]
        self.thresholds = [0.0] + thresholds + [1.0]

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Optional[Union[tf.Tensor, np.ndarray]] = None):
        """Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold.
        The total positives and total negatives are calculated once.

        Args:
            y_true (Union[tf.Tensor, np.ndarray]): True labels.
            y_pred (Union[tf.Tensor, np.ndarray]): Predicted scores, i.e. the output of the model.
            sample_weight (Optional[Union[tf.Tensor, np.ndarray]], optional): Sample weights. Defaults to None.
        """
        # Comparing against every threshold at once adds a trailing threshold axis.
        y_pred = tf.expand_dims(y_pred, axis=1) > self.thresholds
        # Cast first so that 0/1 integer or float labels are accepted as well as booleans.
        y_true = tf.expand_dims(tf.cast(y_true, tf.bool), axis=1)
        positive_equals = tf.cast(tf.logical_and(y_true, y_pred), tf.float32)
        negative_equals = tf.cast(tf.logical_and(tf.logical_not(y_true), tf.logical_not(y_pred)), tf.float32)
        positives = tf.cast(y_true, tf.float32)
        negatives = tf.cast(tf.logical_not(y_true), tf.float32)

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, tf.float32)
            sample_weight = tf.expand_dims(sample_weight, axis=1)
            positive_equals = tf.multiply(positive_equals, sample_weight)
            negative_equals = tf.multiply(negative_equals, sample_weight)
            positives = tf.multiply(positives, sample_weight)
            negatives = tf.multiply(negatives, sample_weight)

        self.true_positives.assign_add(tf.reduce_sum(positive_equals, axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(negative_equals, axis=0))
        self.total_positives.assign_add(tf.reduce_sum(positives, axis=0))
        self.total_negatives.assign_add(tf.reduce_sum(negatives, axis=0))

    def result(self) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Calculates the efficiencies and threshold at the working point.
        Both the fixed and variable efficiencies are calculated and returned
        to allow a check of the working point.

        Returns:
            Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: The fixed efficiency, the variable efficiency and the threshold at the working point.
        """
        efficiency_positives = self.true_positives / self.total_positives
        efficiency_negatives = self.false_negatives / self.total_negatives

        result_index = self._find_index_of_threshold(efficiency_positives, self.working_point)
        positive_at_wp = tf.gather(efficiency_positives, result_index)
        negative_at_wp = tf.gather(efficiency_negatives, result_index)
        threshold_at_wp = tf.gather(self.thresholds, result_index)

        return positive_at_wp, negative_at_wp, threshold_at_wp

    def _find_index_of_threshold(self, efficiencies: tf.Tensor, working_point: float) -> tf.Tensor:
        """Finds the index of the threshold at the working point. """
        result_index = tf.math.squared_difference(efficiencies, working_point)
        result_index = tf.argmin(result_index, axis=0)
        return result_index

    def reset_state(self):
        """Resets the confusion matrix statistics."""
        for variable in (self.true_positives, self.total_positives,
                         self.false_negatives, self.total_negatives):
            variable.assign(tf.zeros_like(variable))


  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector



def reset_state(self)

Resets the confusion matrix statistics.

def reset_state(self):
    """Resets the confusion matrix statistics."""
    for variable in (self.true_positives, self.total_positives,
                     self.false_negatives, self.total_negatives):
        variable.assign(tf.zeros_like(variable))
def result(self) ‑> Tuple[tensorflow.python.framework.ops.Tensor, tensorflow.python.framework.ops.Tensor, tensorflow.python.framework.ops.Tensor]

Calculates the efficiencies and threshold at the working point. Both the fixed and variable efficiencies are calculated and returned to allow a check of the working point.


Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
The fixed efficiency, the variable efficiency and the threshold at the working point.
def result(self) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Calculates the efficiencies and threshold at the working point.
    Both the fixed and variable efficiencies are calculated and returned
    to allow a check of the working point.

    Returns:
        Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: The fixed efficiency, the variable efficiency and the threshold at the working point.
    """
    efficiency_positives = self.true_positives / self.total_positives
    efficiency_negatives = self.false_negatives / self.total_negatives

    result_index = self._find_index_of_threshold(efficiency_positives, self.working_point)
    positive_at_wp = tf.gather(efficiency_positives, result_index)
    negative_at_wp = tf.gather(efficiency_negatives, result_index)
    threshold_at_wp = tf.gather(self.thresholds, result_index)

    return positive_at_wp, negative_at_wp, threshold_at_wp
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray, ForwardRef(None)] = None)

Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold. The total positives and total negatives are calculated once.


y_true : Union[tf.Tensor, np.ndarray]
True labels.
y_pred : Union[tf.Tensor, np.ndarray]
Predicted scores, i.e. the output of the model.
sample_weight : Optional[Union[tf.Tensor, np.ndarray]], optional
Sample weights. Defaults to None.
def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Optional[Union[tf.Tensor, np.ndarray]] = None):
    """Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold.
    The total positives and total negatives are calculated once.

    Args:
        y_true (Union[tf.Tensor, np.ndarray]): True labels.
        y_pred (Union[tf.Tensor, np.ndarray]): Predicted scores, i.e. the output of the model.
        sample_weight (Optional[Union[tf.Tensor, np.ndarray]], optional): Sample weights. Defaults to None.
    """
    # Comparing against every threshold at once adds a trailing threshold axis.
    y_pred = tf.expand_dims(y_pred, axis=1) > self.thresholds
    # Cast first so that 0/1 integer or float labels are accepted as well as booleans.
    y_true = tf.expand_dims(tf.cast(y_true, tf.bool), axis=1)
    positive_equals = tf.cast(tf.logical_and(y_true, y_pred), tf.float32)
    negative_equals = tf.cast(tf.logical_and(tf.logical_not(y_true), tf.logical_not(y_pred)), tf.float32)
    positives = tf.cast(y_true, tf.float32)
    negatives = tf.cast(tf.logical_not(y_true), tf.float32)

    if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, tf.float32)
        sample_weight = tf.expand_dims(sample_weight, axis=1)
        positive_equals = tf.multiply(positive_equals, sample_weight)
        negative_equals = tf.multiply(negative_equals, sample_weight)
        positives = tf.multiply(positives, sample_weight)
        negatives = tf.multiply(negatives, sample_weight)

    self.true_positives.assign_add(tf.reduce_sum(positive_equals, axis=0))
    self.false_negatives.assign_add(tf.reduce_sum(negative_equals, axis=0))
    self.total_positives.assign_add(tf.reduce_sum(positives, axis=0))
    self.total_negatives.assign_add(tf.reduce_sum(negatives, axis=0))
class RejectionAtEfficiency (efficiency: float = 0.5, label_id: Literal[0, 1] = 1, name='rejection_at_efficiency')

Rejection at efficiency metric. In this case the threshold is chosen such that the efficiency of one class is equal to the given efficiency.


efficiency : float, optional
The efficiency. Defaults to 0.5.
label_id : int, optional
The label id for which the efficiency is calculated. Defaults to 1.
name : str, optional
The name of the metric. Defaults to 'rejection_at_efficiency'.
class RejectionAtEfficiency(tf.keras.metrics.SpecificityAtSensitivity):
    """Rejection at efficiency metric.
    In this case the threshold is chosen such that the efficiency of one class is equal to the given `efficiency`.

    Args:
        efficiency (float, optional): The efficiency. Defaults to 0.5.
        label_id (int, optional): The label id for which the efficiency is calculated. Defaults to 1.
        name (str, optional): The name of the metric. Defaults to 'rejection_at_efficiency'.
    """

    def __init__(self, efficiency: float = 0.5, label_id: Literal[0, 1] = 1, name='rejection_at_efficiency'):
        # The requested efficiency is handed to the parent as its `sensitivity`.
        super(RejectionAtEfficiency, self).__init__(
            name=name, sensitivity=efficiency)
        self.label_id = label_id
        self.efficiency = efficiency

    def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
        """Accumulates the efficiency.

        Args:
            y_true (tf.Tensor): The true labels.
            y_pred (tf.Tensor): The predicted labels.
            sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
        """
        y_true = tf.cast(y_true, tf.bool)
        if self.label_id is not None and self.label_id != 1:
            # Flip labels and scores so that `label_id` plays the role of the positive class.
            y_true = tf.logical_not(y_true)
            y_pred = 1 - y_pred
        super(RejectionAtEfficiency, self).update_state(
            y_true, y_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the parent config extended by `efficiency` and `label_id`."""
        config = super(RejectionAtEfficiency, self).get_config()
        config.update({'efficiency': self.efficiency,
                      'label_id': self.label_id})
        return config

    def result(self):
        """Returns the reciprocal of the parent metric's result."""
        # NOTE(review): the other rejection metrics in this file use 1 / (1 - x); confirm 1 / x is intended here.
        return 1 / super(RejectionAtEfficiency, self).result()


  • keras.metrics.metrics.SpecificityAtSensitivity
  • keras.metrics.metrics.SensitivitySpecificityBase
  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def get_config(self)

Returns the serializable config of the metric.

def get_config(self):
    """Returns the parent config extended by `efficiency` and `label_id`."""
    base_config = super(RejectionAtEfficiency, self).get_config()
    base_config.update(efficiency=self.efficiency, label_id=self.label_id)
    return base_config
def result(self)

Computes and returns the scalar metric value tensor or a dict of scalars.

Result computation is an idempotent operation that simply calculates the metric value using the state variables.


A scalar tensor, or a dictionary of scalar tensors.

def result(self):
    """Returns the reciprocal of the parent metric's result."""
    parent_result = super(RejectionAtEfficiency, self).result()
    return 1 / parent_result
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray] = None)

Accumulates the efficiency.


y_true : tf.Tensor
The true labels.
y_pred : tf.Tensor
The predicted labels.
sample_weight : tf.Tensor, optional
The sample weights. Defaults to None.
def update_state(self, y_true: Union[tf.Tensor, np.ndarray], y_pred: Union[tf.Tensor, np.ndarray], sample_weight: Union[tf.Tensor, np.ndarray] = None):
    """Accumulates the efficiency.

    Args:
        y_true (tf.Tensor): The true labels.
        y_pred (tf.Tensor): The predicted labels.
        sample_weight (tf.Tensor, optional): The sample weights. Defaults to None.
    """
    y_true = tf.cast(y_true, tf.bool)
    if self.label_id is not None and self.label_id != 1:
        # Flip labels and scores so that `label_id` plays the role of the positive class.
        y_true = tf.logical_not(y_true)
        y_pred = 1 - y_pred
    super(RejectionAtEfficiency, self).update_state(
        y_true, y_pred, sample_weight=sample_weight)
class RejectionAtFixedWorkingPoint (working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0, num_thresholds: int = 200, name: Optional[str] = 'rejection_at_fixed_wp', dtype: Optional[tensorflow.python.framework.dtypes.DType] = None)

Calculates the rejection at a fixed working point. The working point is defined by the user.


working_point : float
The working point. Defaults to 0.5.
fixed_label_id : Literal[0, 1], optional
The label id whose efficiency is fixed. Defaults to 1.
returned_label_id : Literal[0, 1], optional
The label id whose efficiency is returned. Defaults to 0.
num_thresholds : int, optional
The number of thresholds to use for the estimation. Defaults to 200.
name : Optional[str], optional
The name of the metric. Defaults to 'rejection_at_fixed_wp'.
dtype : Optional[tf.dtypes.DType], optional
The data type of the metric. Defaults to None.
class RejectionAtFixedWorkingPoint(EfficiencyAtFixedWorkingPoint):
    """Calculates the rejection at a fixed working point. The working point is defined by the user.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        returned_label_id (Literal[0, 1], optional): The label id whose efficiency is turned into the rejection. Defaults to 0.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        name (Optional[str], optional): The name of the metric. Defaults to 'rejection_at_fixed_wp'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, fixed_label_id: Literal[0, 1] = 1, returned_label_id: Literal[0, 1] = 0,
                 num_thresholds: int = 200, name: Optional[str] = 'rejection_at_fixed_wp', dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(working_point=working_point, fixed_label_id=fixed_label_id, returned_label_id=returned_label_id,
                         num_thresholds=num_thresholds, name=name, dtype=dtype)

    def result(self) -> tf.Tensor:
        """Calculates the rejection at the working point.

        Returns:
            tf.Tensor: The rejection at the working point, `1 / (1 - efficiency)`.
        """
        return 1.0 / (1 - super().result())


  • EfficiencyAtFixedWorkingPoint
  • FixedWorkingPointBase
  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def reset_state(self)

Inherited from: EfficiencyAtFixedWorkingPoint.reset_state

Resets the confusion matrix statistics.

def result(self) ‑> tensorflow.python.framework.ops.Tensor

Calculates the rejection at the working point.


The rejection at the working point.
def result(self) -> tf.Tensor:
    """Calculates the rejection at the working point.

    Returns:
        tf.Tensor: The rejection at the working point, `1 / (1 - efficiency)`.
    """
    return 1.0 / (1 - super().result())
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray, ForwardRef(None)] = None)

Inherited from: EfficiencyAtFixedWorkingPoint.update_state

Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold. The total positives and total …

class ThresholdAtFixedWorkingPoint (working_point: float = 0.5, num_thresholds: int = 200, fixed_label_id: Literal[0, 1] = 1, name: Optional[str] = 'threshold_at_fixed_efficiency', dtype: Optional[tensorflow.python.framework.dtypes.DType] = None)

Calculates the threshold at a fixed working point. The working point is defined by the user.


working_point : float
The working point. Defaults to 0.5.
fixed_label_id : Literal[0, 1], optional
The label id whose efficiency is fixed. Defaults to 1.
num_thresholds : int, optional
The number of thresholds to use for the estimation. Defaults to 200.
name : Optional[str], optional
The name of the metric. Defaults to 'threshold_at_fixed_efficiency'.
dtype : Optional[tf.dtypes.DType], optional
The data type of the metric. Defaults to None.
class ThresholdAtFixedWorkingPoint(FixedWorkingPointBase):
    """Calculates the threshold at a fixed working point. The working point is defined by the user.

    Args:
        working_point (float): The working point. Defaults to 0.5.
        fixed_label_id (Literal[0, 1], optional): The label id whose efficiency is fixed. Defaults to 1.
        num_thresholds (int, optional): The number of thresholds to use for the estimation. Defaults to 200.
        name (Optional[str], optional): The name of the metric. Defaults to 'threshold_at_fixed_efficiency'.
        dtype (Optional[tf.dtypes.DType], optional): The data type of the metric. Defaults to None.
    """

    def __init__(self, working_point: float = 0.5, num_thresholds: int = 200, fixed_label_id: Literal[0, 1] = 1, name: Optional[str] = 'threshold_at_fixed_efficiency', dtype: Optional[tf.dtypes.DType] = None):
        super().__init__(working_point=working_point, num_thresholds=num_thresholds, name=name, dtype=dtype)
        # Kept on the instance for `result`, which selects the efficiency formula from it.
        self.fixed_label_id = fixed_label_id

    def result(self) -> tf.Tensor:
        """Calculates the threshold at the working point.

        Raises:
            ValueError: If the fixed_label_id is not 0 or 1.

        Returns:
            tf.Tensor: The threshold at the working point.
        """
        if self.fixed_label_id == 1:
            efficiencies = self.true_positives / self.total_positives
        elif self.fixed_label_id == 0:
            efficiencies = self.false_negatives / self.total_negatives
        else:
            # Only reached for ids other than 0 and 1; without this branch the
            # code below would run with `efficiencies` unbound.
            raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

        # The inherited helper picks an index from the per-threshold efficiencies
        # and the requested working point.
        closest_index = self._find_index_of_threshold(efficiencies, self.working_point)

        return tf.gather(self.thresholds, closest_index)


  • FixedWorkingPointBase
  • keras.metrics.base_metric.Metric
  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector


def reset_state(self)

Inherited from: FixedWorkingPointBase.reset_state

Resets the confusion matrix statistics.

def result(self) ‑> tensorflow.python.framework.ops.Tensor

Calculates the threshold at the working point.


Raises ValueError if the fixed_label_id is not 0 or 1.


Returns: the threshold at the working point (a tf.Tensor).
def result(self) -> tf.Tensor:
    """Calculates the threshold at the working point.

    Raises:
        ValueError: If the fixed_label_id is not 0 or 1.

    Returns:
        tf.Tensor: The threshold at the working point.
    """
    if self.fixed_label_id == 1:
        efficiencies = self.true_positives / self.total_positives
    elif self.fixed_label_id == 0:
        efficiencies = self.false_negatives / self.total_negatives
    else:
        # Only reached for ids other than 0 and 1; without this branch the
        # code below would run with `efficiencies` unbound.
        raise ValueError(f'fixed_label_id must be 0 or 1, but is {self.fixed_label_id}')

    # The helper on `self` picks an index from the efficiencies and the
    # requested working point; that index selects the returned threshold.
    closest_index = self._find_index_of_threshold(efficiencies, self.working_point)

    return tf.gather(self.thresholds, closest_index)
def update_state(self, y_true: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], y_pred: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray], sample_weight: Union[tensorflow.python.framework.ops.Tensor, numpy.ndarray, ForwardRef(None)] = None)

Inherited from: FixedWorkingPointBase.update_state

Accumulates the confusion matrix statistics. The true positives and false negatives are calculated for each threshold. The total positives and total …