Exclude subplots without any data and left-align the rest in relplot

332 Views Asked by At

Related to this question: Use relplot to plot a pandas dataframe leading to error

Data for reproducible example is here:

import pandas as pd

data = {'Index': ['TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN'],
        'Stage': [10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28],
        'Z-Score CEI': [-0.688363146221944, 0.5773502691896258, -0.1132178081286216, -0.4278470185781525, 1.0564189237269357, -0.2085144140570746, -0.2085144140570747, 0.2094308186874662, 0.7196177629619716, 0.0, 0.2085144140570762, -1.3803992008056865, -1.3414801279616884, -0.898669162696764, -0.3015113445777637, -0.2953788838542738, 1.1753566728623484, 0.887285779752818, -0.7071067811865475, 0.2847473987257496, 0.1877402877114761, -0.14246249364941, 0.9686648999069224, -0.3015113445777636, -0.2734952011457535, 0.5888914135578924, -0.4488478006064821, -0.7745966692414834, 0.3052145041378634, 0.8197566686157259, 0.3377616284580471, 1.1832159566199232, -0.3015113445777637, -0.2952684241380082, -0.7971688059921156, 0.4479595231454734, -0.5805577953661853, 0.3015113445777642, -0.610500944190139, -0.7734588159553295, -0.5434722467562666, -0.2085144140570747, -0.2085144140570747, 0.8838570486142397, -0.7976091842744983, 2.213211486674006, 0.3779644730092272, -0.6900911175081499, -0.4856558012299846, -0.6044504143545613, -0.2085144140570746, -0.2085144140570747, 1.6498242899497324, 0.463638205246897, -0.064684622735315, 0.5488212999484522, -0.665392754456709, -1.096398502672124, 0.9387247898517332, -0.2085144140570747, -0.2085144140570748, 1.5486212537866115, 0.6776076459912243, -0.7973761651368712, 0.4773960376293314, 0.2611306759187019, -0.2450438178293888, 0.1097642599896903, -0.2085144140570746, -0.2085144140570747, 1.2468175442040146, 0.4912008775378222, -0.8071397220005339, 0.3015113445777636, -0.4051430868010012, -0.9843673918740764, 0.4231429298696365, -0.2085144140570746, -0.2182178902359924, 1.0617336112420042, 0.4221998839727844, -0.2267786838055363, 0.2847473987257496, 1.2708306299144654, 2.4058495687034616, -0.1042572070285372, 4.79583152331272, 4.79583152331272, -0.1758750648062869, 0.9614146130140746, -0.6493094697110509, 0.2847473987257496, -0.0566333001085325, 0.0970016157961683, -0.3380617018914065, -0.2085144140570746, -0.2132007163556104, 
1.6462867435913509, 0.8920062635166146, -0.649519052838329, 0.2847473987257496, -0.5727902328114448, -0.385256843427376, 0.123403510468459, -0.2085144140570747, -0.2085144140570747, 0.7206954054604126, -0.0169294393471337, -0.1547646465068273, 0.3900382256192578, -0.91200685504817, -0.7643838011372592, -0.8553913029328061, -0.2085144140570746, -0.2132007163556104, 1.999517273479448, 0.2135313581345105, 0.3577708763999664, 0.2085144140570741, -0.5245759407883583, -0.3972170332271401, 0.1363988678940945, -0.2085144140570746, -0.2085144140570747, 2.180043023382912, 0.6949201395674811, -0.0345238339879863, 0.3872983346207417, -1.054383845470446, -0.7524909974608698, -0.79555728417573, -0.2085144140570747, -0.2085144140570747, 2.597515932302782, -0.0173575308522844, -0.7839294959021852, 0.5496481403962044, 0.3346732026206391, -0.1729151200242987, 0.8108848540793832, -0.2085144140570747, -0.2085144140570747, -0.1975075078549267, -0.1333012766349092, -0.7300956427599692, 0.3495310368212778, -0.9383516638143292, 0.3757624051611033, -0.9198662110078, -0.2085144140570747, -0.2085144140570747, 0.1077379509580834, -0.0391099277150297, -0.8006407690254357, 0.5226257719601375, 0.2650955994479978, -0.3323178678594628, 1.348187695720845, -0.2085144140570746, -0.2085144140570748, 0.6009413558916348, 0.455353435995126, -0.5933908290969269, 0.0, 0.1226864783178058, -0.0252747129054563, 0.8212299340934688, -0.2085144140570746, -0.2132007163556105, -0.8954835101738379, -1.1134420487718968],
        'Type': ['Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI']}

# Build the example DataFrame from the `data` dict defined above
# (its keys 'Index', 'Stage', 'Z-Score CEI' and 'Type' become the columns).
df = pd.DataFrame(data)

I want to plot the data; rows should be based on the column Type, cols should be based on the column Index, the x-axis should be Z-Score CEI, and the y-axis should be based on Stage column. Currently, I am using relplot to do this:

# Keep only the Index groups that have at least one non-missing 'Z-Score CEI'.
df = df.groupby('Index').filter(lambda grp: grp['Z-Score CEI'].notna().any())

# Make both facet variables categorical and drop categories that do not occur.
for facet_col in ("Type", "Index"):
    df[facet_col] = df[facet_col].astype("category").cat.remove_unused_categories()

# One line plot per (Type, Index) facet; both axes are shared across facets.
g = sns.relplot(
    data=df,
    kind='line',
    x='Z-Score CEI',
    y='Stage',
    row='Type',
    col='Index',
    legend=False,
    facet_kws={'sharex': True, 'sharey': True},
)

# Switch off the axes of every facet that received no rows.
for (row_i, col_j, _hue_k), facet_rows in g.facet_data():
    if not facet_rows.empty:
        continue
    g.facet_axis(row_i, col_j).set_axis_off()

However, this leads to a plot where the empty plots are distorting the placement of the subplots with data. I want there to be no empty areas.

The current output looks like this: (image not shown)

In the graphic above, I want to remove all the subplots which have no data. This will result in different rows having a different number of subplots, e.g. the 1st row might have 5 subplots while the 2nd row has only 4.

I want each row to only have the same Type, not mix multiple Types.

2

There are 2 best solutions below

0
cottontail On BEST ANSWER

Here is another solution that is based on @mwaskom's suggestion in the comments. The basic idea is to create an auxiliary column where for each Type, existing Index values are labeled 0,1,2,... which will act as the column index in the FacetGrid. Then after plotting the relplot, remove all Axes without data and fix the title of the ones with data by replacing the column index by the "real" Index value.

# Label the existing Type-Index pairs: within each Type, the Index values
# that occur are numbered 0, 1, 2, ... (as strings). This number is the
# column position of the pair in the FacetGrid.
col_idx = df.value_counts(['Type', 'Index']).groupby(level=0, observed=False).cumcount().astype(str)
# Map the labels back to the dataframe as a new 'column_loc' column.
df1 = df.merge(col_idx.reset_index(name='column_loc'), on=['Type', 'Index'], how='left')

# plot relplot, faceting the columns by position instead of by Index
g = sns.relplot(
    data=df1,             # <--- new dataframe
    x='Z-Score CEI',
    y='Stage',
    col='column_loc',     # <--- column is by the newly created column
    row='Type',
    facet_kws={'sharey': True, 'sharex': True},
    kind='line',
    legend=False,
)
# axes_dict maps each (row level, column level) pair to its Axes, so the
# facet keys are read directly instead of being parsed back out of the
# title text (which breaks if a value contains ' | ' or ' = ').
for (typ, loc), ax in g.axes_dict.items():
    if not ax.lines:
        g.figure.delaxes(ax)  # remove empty subplots
    else:
        # fix the title: look up the real Index value behind this position
        idx, = col_idx[col_idx == loc].loc[typ].index
        ax.set_title(f"Type = {typ} | Index = {idx}")

result

For this particular task, I think plain matplotlib is very easy to use. Because both the Type and Index columns have a Categorical dtype, passing observed=True to pandas groupby simply drops the Index values that don't exist for each Type. Basically, we can use a nested groupby to create sub-dataframes, each of which can be fed into lineplot. However, because each lineplot has to be drawn manually, it may be slow (or maybe not, since relplot is slow anyway).

import matplotlib.pyplot as plt

# One grid row per observed Type; the Type with the most distinct Index
# values decides how many columns the grid needs.
gby_obj = df.groupby('Type', observed=True)
nrows = gby_obj.ngroups
ncols = gby_obj['Index'].nunique().max()

# squeeze=False keeps `axs` two-dimensional even when nrows or ncols is 1,
# so iterating it row by row below always works.
fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20), sharey=True, sharex=True, squeeze=False)
for row_axes, (typ, g1) in zip(axs, gby_obj):
    # observed=True yields only the Index values present for this Type.
    by_index = list(g1.groupby('Index', observed=True))
    for ax, (idx, g2) in zip(row_axes, by_index):
        sns.lineplot(data=g2, x='Z-Score CEI', y='Stage', ax=ax)
        ax.set_title(f'Type = {typ} | Index = {idx}')
    # Delete the unused right-most axes of this row so no empty plots remain.
    for ax in row_axes[len(by_index):]:
        fig.delaxes(ax)
sns.despine(fig, top=True, right=True)
fig.tight_layout()
0
puchal On

I don't believe you can achieve what you want by using relplot.

What I would suggest is to create a FacetGrid, then for each row adjust the number of columns by deleting the right-most plots where needed, and finally draw a lineplot for each Type | Index pair.

import seaborn as sns
import matplotlib.pyplot as plt


# Keep only the Index groups that have at least one non-missing 'Z-Score CEI'.
df = df.groupby('Index').filter(lambda grp: grp['Z-Score CEI'].notna().any())

# Cast both grouping columns to categorical and drop categories that do not occur.
for facet_col in ("Type", "Index"):
    df[facet_col] = df[facet_col].astype("category").cat.remove_unused_categories()

# Get unique types
types = df['Type'].unique()
# Get unique indexes
indices = df['Index'].unique()

indices_len = len(indices)

# Create FacetGrid. It is used only for its grid of axes (one row per Type,
# one column per Index); the lines themselves are drawn manually below.
g = sns.FacetGrid(df, col="Index", row="Type")

for row_number, t in enumerate(types):

    # Filter for Type
    df_type = df[df['Type'] == t]
    # Determine the number of columns for this type
    num_cols = len(df_type['Index'].unique())

    # Remove the right-most plots that this row does not need
    for col_number in range(num_cols, indices_len):
        plt.delaxes(g.facet_axis(row_number, col_number))

    # Pair every index with its rows for this type and keep the non-empty
    # ones, so the plots of a row are packed to the left without gaps.
    frames = ((index, df_type[df_type['Index'] == index]) for index in indices)
    non_empty = ((index, frame) for index, frame in frames if not frame.empty)
    for col_number, (index, df_index) in enumerate(non_empty):
        ax = g.facet_axis(row_number, col_number)
        sns.lineplot(data=df_index, x='Z-Score CEI', y='Stage', ax=ax)
        ax.set_title(f"Type: {t} | Index: {index}")

Output plot: (image not shown)