I am trying to use `@jit` with nested functions and am having a problem. I have a class `One` that takes in another class, `Plant`, which has a method `func`. I would like to call this jitted method `func` from `One`.
I think I followed the "How to use jit with methods?" section of the JAX FAQ:
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
However, I get the following error:

```
TypeError: One.__init__() got multiple values for argument 'plant'
```

Could anyone tell me how to solve this?
```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util

class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        out = self.plant.func(y) + self.x
        return out

    def _tree_flatten(self):
        children = (self.x,)  # arrays / dynamic values
        aux_data = {'plant': self.plant}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        import pdb; pdb.set_trace();
        return cls(*children, **aux_data)

tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)

class Plant:
    def __init__(self, z, kk):
        self.z = z

    @jit
    def func(self, y):
        y = y + self.z
        return y

    def _tree_flatten(self):
        children = (self.z,)  # arrays / dynamic values
        aux_data = None  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        return cls(*children)

tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5, 2)
one = One(plant, 2)
print(one.call_plant_func(10))
```
The last line raises the error described above.
You have issues in the `tree_flatten` and `tree_unflatten` code in both classes:

- `One._tree_flatten` treats `plant` as static data, but it is not: it is a pytree that has non-static elements.
- `One._tree_unflatten` instantiates `One` with arguments in the wrong order, leading to the error you're seeing.
- `Plant.__init__` does nothing with the `kk` argument.
- `Plant._tree_unflatten` is missing the `aux_data` argument, and fails to pass the `kk` argument to `Plant.__init__`.
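To see why the argument order produces exactly this `TypeError`, here is a small plain-Python sketch (no JAX involved) that mimics what `One._tree_unflatten` does. `One._tree_flatten` returns `children = (self.x,)` and `aux_data = {'plant': self.plant}`, so `cls(*children, **aux_data)` expands to `One(x, plant=plant)`: `x` binds positionally to `plant`, and then the keyword argument targets `plant` a second time. The string `"a Plant"` is a hypothetical stand-in for the real `Plant` object.

```python
class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

# What the original One._tree_flatten returns for One(plant, 2):
children = (2,)                  # dynamic values: (self.x,)
aux_data = {"plant": "a Plant"}  # "static" values; a string stands in for the Plant object

# cls(*children, **aux_data) expands to One(2, plant="a Plant"):
# 2 binds positionally to `plant`, then the keyword argument targets `plant` again.
try:
    One(*children, **aux_data)
    message = None
except TypeError as err:
    message = str(err)

print(message)

# Passing the values in the order __init__ expects avoids the clash.
fixed = One(aux_data["plant"], *children)  # plant first, then x
print(fixed.plant, fixed.x)
```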
With these issues fixed, your code executes without error:
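A minimal sketch of the corrected classes, applying each of the four fixes listed above and keeping the `register_pytree_node` style from the question. Treating `kk` as a dynamic child rather than static `aux_data` is an assumption on my part; since `kk` is never used inside the jitted methods, either choice should work here.

```python
from jax import jit, tree_util


class Plant:
    def __init__(self, z, kk):
        self.z = z
        self.kk = kk  # fix: actually store kk

    @jit
    def func(self, y):
        return y + self.z

    def _tree_flatten(self):
        children = (self.z, self.kk)  # arrays / dynamic values (kk as a child is an assumption)
        aux_data = None               # no static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):  # fix: accept aux_data
        return cls(*children)                      # fix: passes both z and kk


tree_util.register_pytree_node(Plant, Plant._tree_flatten, Plant._tree_unflatten)


class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        return self.plant.func(y) + self.x

    def _tree_flatten(self):
        # fix: plant is itself a pytree with array leaves, so it is a child,
        # not static aux_data
        children = (self.plant, self.x)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # fix: children are (plant, x), matching the order of __init__
        return cls(*children)


tree_util.register_pytree_node(One, One._tree_flatten, One._tree_unflatten)

plant = Plant(5, 2)
one = One(plant, 2)
result = one.call_plant_func(10)
print(result)  # 10 + z (5) + x (2) = 17
```

Because `plant` is now a child of `One`, JAX flattens it recursively, so its `z` and `kk` become traced leaves of `one` instead of being hidden inside the static metadata.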