I'm trying to learn a bit about Pyro and building probabilistic neural networks with PyTorch. Normally, with a PyTorch nn.Module I can move it to the GPU with model.to('cuda'); however, this does not seem to work with a PyroModule. How does one correctly place a PyroModule model onto the GPU?
Example Model:
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO, Predictive
class Model(PyroModule):
    def __init__(self, h1=20, h2=20):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](1, h1)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([h1]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))
        self.fc3 = PyroModule[nn.Linear](h2, 1)
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([1, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 1.).expand([1]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc3(x).squeeze()
        sigma = pyro.sample("sigma", dist.Uniform(0., 1.))
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu
then:
model = Model()
however, model.to('cuda') does not seem to actually move the model to the GPU.
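For context, here is a minimal sketch of how I'm trying to use the model on the GPU (the guide/SVI setup and the toy data are my own additions for illustration, following the usual Pyro regression examples):

from pyro.optim import Adam

model = Model()
model.to('cuda')  # intended to move the model to the GPU

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

# toy data placed on the GPU
x_train = torch.linspace(-1., 1., 100, device='cuda')
y_train = x_train ** 2 + 0.1 * torch.randn(100, device='cuda')

# this step fails for me with a device mismatch, because the priors built
# from dist.Normal(0., 1.) still produce CPU tensors for the sampled weights
svi.step(x_train, y_train)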
Update: I'm not sure if this is a correct solution...
I find that if I replace PyroSample with pyro.nn.PyroParam, then they are listed in the ParamDict and can be moved to the GPU.
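A rough sketch of what I mean (only the first layer shown; note that swapping PyroSample for PyroParam turns the weights into ordinary learnable parameters rather than latent variables with priors, so it changes the model):

import torch
import torch.nn as nn
from pyro.nn import PyroModule, PyroParam

class ParamModel(PyroModule):
    def __init__(self, h1=20):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](1, h1)
        # PyroParam registers an unconstrained nn.Parameter under the hood
        self.fc1.weight = PyroParam(torch.randn(h1, 1))
        self.fc1.bias = PyroParam(torch.randn(h1))

model = ParamModel()
model.to('cuda')  # in my tests these parameters do get moved to the GPU
print([name for name, _ in model.named_parameters()])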
I've run into a similar problem. Even though PyroModule is subclassed from nn.Module, as far as I can tell the .to method does not carry over PyroSample objects the same way one might expect for nn.Parameter objects. (I think it might work for PyroParam, like you say.) I found this Pyro forum post to have a workable solution. It says you can initialize the PyroSample(dist...) calls with tensors that are already on the GPU. For example, change lines like yours so that the loc/scale arguments are tensors created on the GPU, as in the sketch below.
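Here is a minimal sketch of that change (only the first layer shown, assuming a CUDA device is available; the exact device handling is my own illustration of the forum post's idea, not code copied from it):

import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

device = torch.device('cuda')

class CudaModel(PyroModule):
    def __init__(self, h1=20):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](1, h1)
        # build the priors from tensors that already live on the GPU, so the
        # weights sampled from them are created on the GPU as well
        self.fc1.weight = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(1., device=device)).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(1., device=device)).expand([h1]).to_event(1))

With the priors constructed like this, the weights sampled inside forward() come out on the GPU, so you only need to make sure the inputs (and observations) are on the GPU as well.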
I've found this is only a problem when you have multivariate PyroSample objects (i.e., when you have the .expand(...).to_event(...) calls, as here). If it's a univariate object, then the variable transfers from the CPU to the GPU without complaining.