Python PyTorch Pyro - Multivariate Distributions


How does one sample a multivariate distribution in Pyro? I just want an (M, N) Beta distribution, but the following doesn't work:

import torch
import pyro
with pyro.plate("theta_plate", M):
    theta = pyro.sample("theta",
                        pyro.distributions.Beta(concentration0=torch.ones(N),
                                                concentration1=torch.ones(N)))

Best answer:

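Use to_event(n) to declare dependent samples. It reinterprets the rightmost n batch dimensions of a distribution as event dimensions, so those N values form one joint draw and the plate only has to supply the batch dimension of size M.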

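import torch
import pyro
import pyro.distributions as dist

def model(N, M):
    with pyro.plate("theta_plate", M):
        # Beta(concentration1, concentration0) has batch shape (N,);
        # .to_event(1) turns that dimension into an event dimension,
        # so the plate only adds the batch dimension of size M.
        theta = pyro.sample("theta", dist.Beta(torch.ones(N), 1.).to_event(1))
    return theta


if __name__ == '__main__':
    print(model(10, 12).shape)  # torch.Size([12, 10]), i.e. (M, N)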

Another answer:

For both PyTorch and Pyro distributions, the syntax is the same:

import pyro.distributions as dist

samples = dist.Beta(2, 2).sample([200]) # Will draw 200 samples.

You shouldn't need the plate notation if you only want to draw samples from a distribution.