I'm trying to code an MCMC for the lorenz system (based on the following example of the predator-prey model - https://num.pyro.ai/en/stable/examples/ode.html). I used the structure of the example program and simply replaced the model, however, I am running into the following error:
ValueError: Normal distribution got invalid loc parameter.
I noticed that in the original program, theta has 4 components and the loc/scale parameters also had 4 elements in their array argument, so I assumed these numbers have to be the same. Is this an invalid assumption? If so, how can I correct the loc/scale parameters to get the desired outcome?
Note: the "data" array was calculated using the RK4 method to simulate a lorenz system and output as a jax.numpy.array()
import os
import matplotlib
import matplotlib.pyplot as plt
from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
def da_dt(a,t,theta):
x = a[0]
y = a[1]
z = a[2]
sigma,rho,beta = (
theta[...,0],
theta[...,1],
theta[...,2],
)
dx_dt = sigma*(y-x)
dy_dt = (x*(rho-z))-y
dz_dt = (x*y)-(beta*z)
return jnp.stack([dx_dt,dy_dt,dz_dt])
def model(N,b=None):
a_init = numpyro.sample("a_init",dist.LogNormal(jnp.log(10),1).expand([3]))
ts = jnp.arange(float(N))
theta = numpyro.sample(
"theta",
dist.TruncatedNormal(
low=0.0,
loc=jnp.array([1.0,0.05,1.0]),
scale=jnp.array([0.5,0.05,0.5]),
),
)
a = odeint(da_dt,a_init,ts,theta,rtol=1e-6,atol=1e-5,mxstep=1000)
sigma = numpyro.sample("sigma",dist.LogNormal(-1,1).expand([3]))
numpyro.sample("y",dist.LogNormal(jnp.log(a),sigma),obs=b)
n_warmup = 1000
n_samples = 1000
n_chains = 1
mcmc = MCMC(
NUTS(model,dense_mass=True),
num_warmup = n_warmup,
num_samples = n_samples,
num_chains = n_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(PRNGKey(1),N=data.shape[0],b=data)
mcmc.print_summary()