how to correctly set legend and colors with arviz.plot_bpv

109 Views Asked by At

I'm trying to manually specify colors for lines in arviz.plot_bpv to evaluate the impact of a model parameter. However, I'm facing challenges in correctly displaying the legend.

How should I set the legend in order to each line be different? For example, in the image below, sigma 1 and sigma 3 has the same black dashed line. Also the legend in the ax1 is not ok.

import matplotlib
import pymc as pm
import arviz as az
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
RANDOM_SEED = 58
rng = np.random.default_rng(RANDOM_SEED)

data = pd.read_csv(pm.get_data("babies.csv"))
X = data.Month.values
Y = data.Length.values
sigma_test = [1,10,100,1000]

with pm.Model() as model:
    Y = pm.ConstantData('Y',Y)
    X = pm.MutableData('X',X)
    sigma = pm.MutableData('sigma_test',sigma_test[0])
    
    intercept = pm.LogNormal('Intercept', mu=3., sigma=1.)
    slope = pm.Normal('Slope', mu=0.5, sigma=0.1)
    sigma = pm.HalfNormal('sigma', sigma=sigma)
    mu = pm.Deterministic('mu',slope * X + intercept)
    
    likelihood = pm.Normal('y_est', mu=mu, sigma=sigma, observed = Y,shape=X.shape[0])
    
# Generate one trace for each sigma
trace = []
post_pred = []
for sigma_vals in sigma_test:
    with model:
        # Switch out the sigma
        pm.set_data({"sigma_test": sigma_vals})
        trace_i = pm.sample(random_seed=rng)
        trace.append(trace_i)
        post_pred.append(pm.sample_posterior_predictive(trace_i,extend_inferencedata=True,random_seed=rng))
        
# Plot

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
legend_labels = []  # To store legend labels
colors = ['C0', 'C1', 'C2', 'C3','C4','C5','C6']
for i in range(len(sigma_test)):
    post_pred_i = post_pred[i]
    az.plot_bpv(post_pred_i, kind="p_value",color=colors[i], ax=ax[0])
    az.plot_bpv(post_pred_i, kind="u_value",color=colors[i], ax=ax[1])
    
    # Append a label to the legend_labels list
    legend_labels.append(f"Sigma {i}")

# Make sure to adjust the labels or titles for your subplots as needed
ax[0].set_title("P-Values")
ax[1].set_title("U-Values")

# Add a legend to the subplots
ax[0].legend(legend_labels)
ax[1].legend(legend_labels)

plt.show()

enter image description here

0

There are 0 best solutions below