Bayesian framework: Plot prior predictive and posterior predictive distribution with arviz

I'm trying to replicate Figure 1.5 from this source, which shows the prior predictive and posterior predictive distributions, but I'm having trouble understanding what exactly needs to be plotted. I want to build a histogram of the predictions, and since I'm working with 4 chains, 1000 draws, and 20 observations, I'm not sure how to handle these dimensions to construct it. I've tried summing along the chain and draw dimensions, as in the example, but the result doesn't look much like the figure in the reference material; I've also tried taking the mean along those dimensions, which I thought would work better. I want the probability on the y-axis and the 20 observations on the x-axis.

Here's the code snippet I'm working with:

import numpy as np
import pymc as pm
import matplotlib.pyplot as plt
import scipy.stats as stats
import arviz as az

az.style.use("arviz-grayscale")
plt.rcParams['figure.dpi'] = 300
np.random.seed(521)
viridish = [(0.2823529411764706, 0.11372549019607843, 0.43529411764705883, 1.0),
            (0.1450980392156863, 0.6705882352941176, 0.5098039215686274, 1.0),
            (0.6901960784313725, 0.8666666666666667, 0.1843137254901961, 1.0)]

Y = stats.bernoulli(0.7).rvs(20)
# Declare a model in PyMC
with pm.Model() as model:
    # Specify the prior distribution of the unknown parameter
    θ = pm.Beta("θ", alpha=1, beta=1)

    # Specify the likelihood distribution and condition on the observed data
    y_obs = pm.Binomial("y_obs", n=1, p=θ, observed=Y)

    # Sample from the posterior distribution
    idata = pm.sample(1000, return_inferencedata=True)

pred_dists = (pm.sample_prior_predictive(1000, model=model),
              pm.sample_posterior_predictive(idata, model=model))

# Prior
prior_samples = pred_dists[0]['prior']['θ'].values # Prior samples of θ
prior_pred_samples = pred_dists[0]['prior_predictive'] # Prior predictive samples

# Posterior
post_distribution = idata.posterior["θ"] # Posterior samples of θ
posterior_pred_samples = pred_dists[1]['posterior_predictive'] # Posterior predictive samples

posterior_pred_graph = posterior_pred_samples['y_obs'].sum(dim=['draw','chain'])
prior_pred_samples_graph = prior_pred_samples['y_obs'].sum(dim=['draw','chain'])

fig, axes = plt.subplots(4, 1, gridspec_kw={'hspace': 0.1})

az.plot_dist(prior_samples, plot_kwargs={"color":"0.5"},
             fill_kwargs={'alpha':1}, ax=axes[0])
axes[0].set_title("Prior distribution", fontweight='bold',fontsize=10)
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 4)
axes[0].tick_params(axis='both', pad=7)
axes[0].set_xlabel("θ")


az.plot_dist(prior_pred_samples_graph, plot_kwargs={"color":"0.5"},
             fill_kwargs={'alpha':1}, ax=axes[1])
axes[1].set_title("Prior predictive distribution", fontweight='bold',fontsize=10)
# axes[1].set_xlim(-1, 21)
# axes[1].set_ylim(0, 0.15)
axes[1].tick_params(axis='both', pad=7)
axes[1].set_xlabel("number of success")

az.plot_dist(post_distribution, plot_kwargs={"color":"0.5"},
             fill_kwargs={'alpha':1},ax=axes[2])
axes[2].set_title("Posterior distribution", fontweight='bold',fontsize=10)
axes[2].set_xlim(0, 1)
axes[2].set_ylim(0, 5)
axes[2].tick_params(axis='both', pad=7)
axes[2].set_xlabel("θ")

az.plot_dist(posterior_pred_graph, plot_kwargs={"color":"0.5"},
             fill_kwargs={'alpha':1}, ax=axes[3])
axes[3].set_title("Posterior predictive distribution", fontweight='bold',fontsize=10)
# axes[3].set_xlim(-1, 21)
# axes[3].set_ylim(0, 0.15)
axes[3].tick_params(axis='both', pad=7)
axes[3].set_xlabel("number of success")
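
For reference, when I inspect the predictive samples I get xarray objects with dimensions roughly like the following (the exact name of the observation dimension below is my guess and may differ between PyMC/ArviZ versions):

print(posterior_pred_samples["y_obs"].dims)   # e.g. ('chain', 'draw', 'y_obs_dim_0')
print(posterior_pred_samples["y_obs"].shape)  # e.g. (4, 1000, 20)
print(prior_pred_samples["y_obs"].shape)      # e.g. (1, 1000, 20)

So summing over ['chain', 'draw'] as above leaves me with 20 numbers, one per observation, which may be why my plot doesn't match the figure, but I'm not sure what the right reduction is.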

I would appreciate help clarifying how to build these histograms: the number of successes on the x-axis (out of the 20 observed data points) and the corresponding probabilities on the y-axis, as sketched below with synthetic data. Understanding this is important to me because I want to be able to confidently generate the same kind of plots for the predictions of my real model.
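
To make the goal concrete, here is the kind of bar chart I have in mind, sketched with plain NumPy/Matplotlib on synthetic Bernoulli draws rather than on the actual predictive samples (so this is only an illustration of the target plot, not necessarily the correct way to reduce the chain/draw dimensions):

import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for the predictive samples: 1000 draws of 20 Bernoulli(0.7) outcomes
fake_draws = np.random.binomial(n=1, p=0.7, size=(1000, 20))

# One count of successes (0..20) per draw
successes_per_draw = fake_draws.sum(axis=1)

# Bar chart of the empirical probability of each number of successes
values, counts = np.unique(successes_per_draw, return_counts=True)
plt.bar(values, counts / counts.sum(), color="0.5")
plt.xlabel("number of successes")
plt.ylabel("probability")
plt.xlim(-1, 21)
plt.show()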

Thank you for your guidance.
