Creating a categorical heatmap with sparklines?


Does anyone know of an example of how to create a categorical heat map with individual sparklines within each cell? Or have a suggestion on how to use matplotlib's annotation to produce this (or something similar)?

Essentially turning this: Matplotlib heatmap annotation

into this: Heatmap with sparkline


Best answer

Assuming the input is in long format as shown below (an arbitrary number of rows for each row/col combination), and that we want a heatmap of the average value per row/col cell, with a small line inside each cell showing that cell's consecutive values:

    row col     value
0     A   a -2.911793
1     A   a -3.066935
2     A   a -0.940881
3     A   a  1.838795
4     A   a  2.359492
..   ..  ..       ...
595   E   f -3.233857
596   E   f -4.348279
597   E   f -4.236598
598   E   f -4.697110
599   E   f -3.618638

[600 rows x 3 columns]
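
For reference, the per-cell average that the heatmap colors encode can be computed from this long format with pivot_table. A minimal sketch on a made-up toy frame (the toy values are hypothetical and only there to show the reshaping):

```python
import pandas as pd

# Hypothetical toy data in the same long format (values made up)
toy = pd.DataFrame({'row': ['A', 'A', 'A', 'B', 'B'],
                    'col': ['a', 'a', 'b', 'a', 'b'],
                    'value': [1.0, 3.0, 5.0, 2.0, 4.0]})

# One cell per row/col combination, holding the mean of its values
grid = toy.pivot_table(index='row', columns='col',
                       values='value', aggfunc='mean')
print(grid)
```

Here the two A/a values (1.0 and 3.0) collapse into a single cell with mean 2.0, which is the kind of grid sns.heatmap expects.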

You could plot a heatmap using sns.heatmap on the reshaped data (with pivot_table, here using the mean of the data per group), then rework the data to plot a line on top of it:

import pandas as pd
import seaborn as sns

# df: long-format DataFrame (see "Reproducible input" below)

ax = sns.heatmap(df.pivot_table(index='row', columns='col',
                                values='value', aggfunc='mean'))

margin = 0.1

def norm(s, margin=0):
    '''Normalizes the input Series between 0+margin and 1-margin'''
    MIN = s.min()
    return (s-MIN)/(s.max()-MIN)*(1-2*margin)+margin

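tmp = (df
       .sort_values(by=['row', 'col']) # keep each row/col group contiguous
       # compute row/col position per group to match the heatmap
       # (sort=True aligns the ids with pivot_table's sorted labels)
       .assign(row_id=lambda d: pd.factorize(d['row'], sort=True)[0],
               col_id=lambda d: pd.factorize(d['col'], sort=True)[0],
               # number the points within each group to form an x-value
               # in [0, 1], then shift by the column position
               x=lambda d: (x:=d.groupby(['row_id', 'col_id']
                           ).cumcount())/x.max()+d['col_id'],
               # normalize the data, flip it (the heatmap's y-axis is
               # inverted) and shift by the row position; the first point
               # after each column change is masked to break the line
               norm_value=lambda d: (norm(d['value'], margin=margin).rsub(1)
                                     + d['row_id']
                                    ).mask(d['col_id'].ne(d['col_id'].shift())),
               )
      )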

tmp.plot(x='x', y='norm_value', ax=ax, legend=False)
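
To make the y-placement arithmetic concrete, here is a small worked sketch of what norm plus the row shift does to three made-up values. The series and the row_id are hypothetical, and norm is copied from the code above so the snippet runs on its own:

```python
import pandas as pd

def norm(s, margin=0):
    '''Same helper as in the answer: rescale s to [margin, 1-margin]'''
    MIN = s.min()
    return (s-MIN)/(s.max()-MIN)*(1-2*margin)+margin

s = pd.Series([-5.0, 0.0, 5.0])   # made-up values
scaled = norm(s, margin=0.1)      # approximately 0.1, 0.5, 0.9
row_id = 2                        # hypothetical: third heatmap row
# flip (the heatmap's y-axis points down) and move into the row's band
y = scaled.rsub(1) + row_id       # approximately 2.9, 2.5, 2.1
print(y.tolist())
```

The largest value ends up nearest the top edge of its cell (y close to 2) and the smallest nearest the bottom (y close to 3), with the margin keeping the line off the cell borders. Note that in the answer's code norm is applied to the whole value column, so all sparklines share one common scale.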

Example output:

heatmap + lineplot for each cell

Reproducible input:

import numpy as np
import pandas as pd
from string import ascii_uppercase, ascii_lowercase

R, C, N = 5, 6, 20

np.random.seed(0)
a = np.arange(R*C*N)
df = (pd.DataFrame({'row': np.array(list(ascii_uppercase))[a//(C*N)],
                    'col': np.array(list(ascii_lowercase))[a%(C*N)//N],
                    'value': 10*np.sin(a/N*5)+np.random.normal(scale=2, size=R*C*N),
                   })
        .assign(value=lambda d: d.groupby(['row', 'col'])['value']
                .transform(lambda s: s*np.random.uniform(0.1, 1)+np.random.uniform(-10, 10)))
        #.sample(frac=0.7).sort_index()
     )

Alternative output when .sample(frac=0.7).sort_index() is uncommented (to simulate uneven groups):

heatmap + lineplot for each cell
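
With uneven groups, the x-values above are divided by the largest group's count, so cells with fewer points stop short of their right edge. If each line should instead span its whole cell, one possible variation (a sketch on hypothetical toy data, assuming that behavior is what you want) is to scale by each group's own size:

```python
import pandas as pd

# Hypothetical uneven groups: A/a has 3 points, A/b has 2
toy = pd.DataFrame({'row': ['A'] * 5,
                    'col': ['a', 'a', 'a', 'b', 'b'],
                    'value': [1.0, 2.0, 3.0, 4.0, 5.0]})

g = toy.groupby(['row', 'col'])
pos = g.cumcount()                        # position within each group
last = g['value'].transform('size') - 1   # last position in each group
# clip avoids dividing by zero when a group has a single point
toy['x_frac'] = pos / last.clip(lower=1)
print(toy)
```

In the answer's code this fraction would take the place of the cumcount()/x.max() term before d['col_id'] is added.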