Module jidenn.preprocess.plotter

Expand source code
import os

import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds

def plot_pt_dist(dataset: tf.data.Dataset, save_path: str = 'figs.png') -> None:
    """Plot jet pT histograms split by parton truth label and save them.

    The dataset is mapped to ``{'jets_pt', 'label', 'all_label'}`` records,
    converted to a pandas DataFrame with ``tfds.as_dataframe`` and drawn
    with ``seaborn.histplot``. Four images are written, all derived from
    ``save_path`` (shown here as ``<stem><ext>``):

    * ``<stem><ext>``         -- quark vs. gluon, linear y-axis
    * ``<stem>_log<ext>``     -- the same figure with a logarithmic y-axis
    * ``<stem>_all<ext>``     -- stacked histogram per individual truth
      label, linear y-axis
    * ``<stem>_all_log<ext>`` -- the same figure with a logarithmic y-axis

    Args:
        dataset: Dataset whose elements are dicts providing at least the
            ``'jets_pt'`` and ``'jets_PartonTruthLabelID'`` keys.
            NOTE(review): the labelling step uses the truth label as an
            ``if`` condition, so each element is presumably a single jet
            with a scalar int32 label -- confirm against the caller.
        save_path: Path of the first image. The other three file names
            are built by inserting a suffix before its extension.
    """

    @tf.function
    def get_labeled_pt(data):
        parton = data['jets_PartonTruthLabelID']
        # 21 -> 0 (named 'gluon' below), 1..6 -> 1 (named 'quark' below),
        # every other value -> -1.
        if tf.equal(parton, tf.constant(21)):
            label = 0
        elif tf.reduce_any(tf.equal(parton, tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.int32))):
            label = 1
        else:
            label = -1
        return {'jets_pt': data['jets_pt'], 'label': label, 'all_label': parton}

    # Split on the real extension instead of str.replace('.png', ...):
    # replace() is a no-op for any other extension (every savefig would then
    # overwrite the same file) and also rewrites '.png' inside directory names.
    stem, ext = os.path.splitext(save_path)

    dataset = dataset.map(get_labeled_pt)
    # The whole dataset is materialised as a DataFrame here.
    df = tfds.as_dataframe(dataset)
    df['Truth Label'] = df['label'].replace({1: 'quark', 0: 'gluon'})
    df['All Truth Label'] = df['all_label'].replace({1: 'd', 2: 'u', 3: 's', 4: 'c', 5: 'b', 6: 't', 21: 'g'})

    # Figure 1: quark vs. gluon; only these two categories are listed in
    # hue_order.
    sns.histplot(data=df, x='jets_pt', hue='Truth Label',
                 stat='count', element="step", fill=True,
                 palette='Set1', common_norm=False, hue_order=['quark', 'gluon'])
    plt.savefig(save_path)
    # Re-save the same figure with a logarithmic y-axis.
    plt.yscale('log')
    plt.savefig(f'{stem}_log{ext}')
    plt.close()

    # Figure 2: one stacked component per individual truth label.
    sns.histplot(data=df, x='jets_pt', hue='All Truth Label',
                 stat='count', element="step", fill=True, multiple='stack',
                 palette='Set1', common_norm=False)
    plt.savefig(f'{stem}_all{ext}')
    plt.yscale('log')
    plt.savefig(f'{stem}_all_log{ext}')
    plt.close()

Functions

def plot_pt_dist(dataset: tensorflow.python.data.ops.dataset_ops.DatasetV2, save_path: str = 'figs.png') -> None
Expand source code
def plot_pt_dist(dataset: tf.data.Dataset, save_path: str = 'figs.png') -> None:
    """Plot jet pT histograms split by parton truth label and save them.

    The dataset is mapped to ``{'jets_pt', 'label', 'all_label'}`` records,
    converted to a pandas DataFrame with ``tfds.as_dataframe`` and drawn
    with ``seaborn.histplot``. Four images are written, all derived from
    ``save_path`` (shown here as ``<stem><ext>``):

    * ``<stem><ext>``         -- quark vs. gluon, linear y-axis
    * ``<stem>_log<ext>``     -- the same figure with a logarithmic y-axis
    * ``<stem>_all<ext>``     -- stacked histogram per individual truth
      label, linear y-axis
    * ``<stem>_all_log<ext>`` -- the same figure with a logarithmic y-axis

    Args:
        dataset: Dataset whose elements are dicts providing at least the
            ``'jets_pt'`` and ``'jets_PartonTruthLabelID'`` keys.
            NOTE(review): the labelling step uses the truth label as an
            ``if`` condition, so each element is presumably a single jet
            with a scalar int32 label -- confirm against the caller.
        save_path: Path of the first image. The other three file names
            are built by inserting a suffix before its extension.
    """

    @tf.function
    def get_labeled_pt(data):
        parton = data['jets_PartonTruthLabelID']
        # 21 -> 0 (named 'gluon' below), 1..6 -> 1 (named 'quark' below),
        # every other value -> -1.
        if tf.equal(parton, tf.constant(21)):
            label = 0
        elif tf.reduce_any(tf.equal(parton, tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.int32))):
            label = 1
        else:
            label = -1
        return {'jets_pt': data['jets_pt'], 'label': label, 'all_label': parton}

    # Split on the real extension instead of str.replace('.png', ...):
    # replace() is a no-op for any other extension (every savefig would then
    # overwrite the same file) and also rewrites '.png' inside directory names.
    stem, ext = os.path.splitext(save_path)

    dataset = dataset.map(get_labeled_pt)
    # The whole dataset is materialised as a DataFrame here.
    df = tfds.as_dataframe(dataset)
    df['Truth Label'] = df['label'].replace({1: 'quark', 0: 'gluon'})
    df['All Truth Label'] = df['all_label'].replace({1: 'd', 2: 'u', 3: 's', 4: 'c', 5: 'b', 6: 't', 21: 'g'})

    # Figure 1: quark vs. gluon; only these two categories are listed in
    # hue_order.
    sns.histplot(data=df, x='jets_pt', hue='Truth Label',
                 stat='count', element="step", fill=True,
                 palette='Set1', common_norm=False, hue_order=['quark', 'gluon'])
    plt.savefig(save_path)
    # Re-save the same figure with a logarithmic y-axis.
    plt.yscale('log')
    plt.savefig(f'{stem}_log{ext}')
    plt.close()

    # Figure 2: one stacked component per individual truth label.
    sns.histplot(data=df, x='jets_pt', hue='All Truth Label',
                 stat='count', element="step", fill=True, multiple='stack',
                 palette='Set1', common_norm=False)
    plt.savefig(f'{stem}_all{ext}')
    plt.yscale('log')
    plt.savefig(f'{stem}_all_log{ext}')
    plt.close()