Live normalized stacked area chart in Google Colab

I am training a neural network for binary classification on Google Colab. After each epoch, I evaluate it on the validation dataset and calculate the percentages of true positives, false positives, true negatives, and false negatives. I want to see a live normalized stacked area chart (see https://altair-viz.github.io/gallery/normalized_stacked_area_chart.html for an explanation of what that is) of these four numbers, updated after each epoch as training goes on. How do I achieve this? I am open to using any third-party library.

Solution

I assume that you save the following metrics after every epoch:

  • tp: True Positive
  • tn: True Negative
  • fp: False Positive
  • fn: False Negative

Note: Ideally, you would create a list of dictionaries, where each dictionary consists of the metrics you want and the corresponding epoch number as follows:

# call this list results
results = [
    # ... (one dictionary per epoch)
    {"epoch": 0, "tp": 200, "tn": 80, "fp": 18, "fn": 5},
    # ...
]

Use this list to create a pandas dataframe and then use the custom plotting function as shown below.

df = pd.DataFrame(results)  # one row per epoch, one column per metric
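
As a rough sketch of how results could be filled in during training, the snippet below counts the four outcomes from true and predicted labels and appends one dictionary per epoch. The helper name confusion_counts and the toy labels are made up for illustration; in practice the labels would come from your own validation loop.

```python
from typing import Dict, List, Sequence


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, int]:
    """Count tp, tn, fp, fn for binary labels (1 = positive, 0 = negative)."""
    counts = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
    for truth, pred in zip(y_true, y_pred):
        if pred == 1:
            counts["tp" if truth == 1 else "fp"] += 1
        else:
            counts["tn" if truth == 0 else "fn"] += 1
    return counts


results: List[Dict[str, int]] = []

# Toy validation labels and per-epoch predictions (illustrative values only).
y_true = [1, 0, 1, 1, 0, 0]
predictions_per_epoch = [
    [0, 0, 1, 0, 1, 0],  # epoch 0
    [1, 0, 1, 1, 0, 0],  # epoch 1
]

for epoch, y_pred in enumerate(predictions_per_epoch):
    # One dictionary per epoch: the epoch number plus the four counts.
    results.append({"epoch": epoch, **confusion_counts(y_true, y_pred)})
```

Raw counts are enough here, because the chart normalizes each epoch to 100% on its own.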

The solution below uses the Plotly library to create the desired chart. Here is a Jupyter notebook (with a Google Colab link) to quickly check out the proposed solution.

Make Interactive Stacked Normalized Area Chart

figure_title = "Confusion Matrix Evolution over Model Training Epochs"
columns = ["tp", "tn", "fp", "fn"]
colors = ['#d7191c', '#fdae61', '#abdda4', '#2b83ba']
palette = dict(zip(columns, colors))  # map each metric to its color

# create interactive chart with user-defined function
make_stacked_normalized_chart(
    df, 
    x="epoch", 
    columns=columns, 
    palette=palette, 
    figure_title=figure_title,
)

Define Custom Plotting Function

Here we define a custom function, make_stacked_normalized_chart(), to create an interactive stacked normalized area chart. Run this definition before the call shown above.

from typing import Dict, List

import pandas as pd
import plotly.graph_objects as go

def make_stacked_normalized_chart(df: pd.DataFrame, x: str, 
                                  columns: List[str], 
                                  palette: Dict[str, str], 
                                  figure_title: str="Figure Title"):
    """Create a stacked normalized interactive chart with Plotly library."""
    x_label = x
    x = df[x_label]
    fig = go.Figure()

    def add_trace(column: str):
        fig.add_trace(go.Scatter(
            x=x, y=df[column],
            text=column, # set the name shown while hovering over
            name=column, # set the name in the legend
            # fill='toself',
            hoveron = 'points+fills', # select where hover is active
            hoverinfo='text+x+y',
            mode='lines',        
            line=dict(width=0.5, color=palette.get(column)),
            stackgroup='one', # define stack group
            groupnorm='percent', # sets the normalization for the sum of the stackgroup
        ))

    for column in columns:
        add_trace(column)

    fig.update_layout(
        title_text=figure_title,
        showlegend=True,
        xaxis=dict(
            type="category",
            title=x_label,
        ),
        yaxis=dict(
            type='linear',
            range=[0, 100],  # the normalized stack spans 0-100%
            ticksuffix='%',
        ),
    )
    fig.show()
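
To make the normalization concrete: with groupnorm='percent', each trace's value at a given x is shown as its share of the stack total at that x. The plain-Python sketch below reproduces that arithmetic for the example row used earlier; to_percent is a hypothetical helper for illustration, not part of Plotly.

```python
from typing import Dict


def to_percent(counts: Dict[str, int]) -> Dict[str, float]:
    """Convert raw counts to percentages of their total."""
    total = sum(counts.values())
    return {name: 100 * value / total for name, value in counts.items()}


epoch_counts = {"tp": 200, "tn": 80, "fp": 18, "fn": 5}
shares = to_percent(epoch_counts)

# Rounded for display; the unrounded shares sum to 100.
print({name: round(share, 1) for name, share in shares.items()})
# → {'tp': 66.0, 'tn': 26.4, 'fp': 5.9, 'fn': 1.7}
```

Because Plotly does this per x value, you can pass raw counts and do not need to precompute percentages yourself.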

Dummy Data

We will use the following data to demonstrate the interactive stacked normalized area chart.

import numpy as np
import pandas as pd

np.random.seed(42)

nrows = 100
x = np.arange(nrows)
tp = 60 + np.arange(nrows) + np.random.randint(0, 20, nrows)
tn = 25 + np.arange(nrows) + np.random.randint(0, 20, nrows)
fp = np.random.randint(2, 6, nrows) + np.random.randint(0, 8, nrows)
fn = np.random.randint(4, 7, nrows) + np.random.randint(3, 6, nrows)

df = pd.DataFrame({"epoch": x, "tp": tp, "tn": tn, "fp": fp, "fn": fn})
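
The steps above draw the chart once. To get a live view in Colab, one option (a sketch, not part of the original solution) is to clear the cell output and redraw after every epoch with IPython's clear_output(wait=True). In the code below, run_epoch and draw are hypothetical callables standing in for your own training step and for a wrapper around make_stacked_normalized_chart().

```python
from typing import Callable, Dict, List

try:
    from IPython.display import clear_output  # present in Colab and Jupyter
except ImportError:
    def clear_output(wait: bool = False) -> None:
        """No-op fallback so the sketch also runs outside a notebook."""


def train_with_live_chart(
    n_epochs: int,
    run_epoch: Callable[[int], Dict[str, int]],
    draw: Callable[[List[Dict[str, int]]], None],
) -> List[Dict[str, int]]:
    """Run training and redraw the chart after every epoch."""
    results: List[Dict[str, int]] = []
    for epoch in range(n_epochs):
        # Hypothetical step: train one epoch, validate, return tp/tn/fp/fn.
        counts = run_epoch(epoch)
        results.append({"epoch": epoch, **counts})
        # Replace the previous chart instead of stacking a new output per epoch.
        clear_output(wait=True)
        draw(results)
    return results


# Stand-in callables so the sketch runs on its own.
def fake_epoch(epoch: int) -> Dict[str, int]:
    return {"tp": 60 + epoch, "tn": 25 + epoch, "fp": 5, "fn": 4}


history = train_with_live_chart(
    3, fake_epoch, draw=lambda rows: print(len(rows), "epochs plotted")
)
```

In the notebook, draw could be something like lambda rows: make_stacked_normalized_chart(pd.DataFrame(rows), x="epoch", columns=columns, palette=palette, figure_title=figure_title). Redrawing the whole figure each epoch is simple but may flicker; Plotly's go.FigureWidget is an alternative that updates in place, though it relies on notebook widget support.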