I am trying to use `@jit` with a nested function call and am running into a problem.
I have a class `One` that takes an instance of another class, `Plant`, which has a method `func`.
I would like to call the jitted method `func` from `One`.
I believe I followed the "How to use jit with methods?" section of the JAX FAQ:
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
However, I get the error `TypeError: One.__init__() got multiple values for argument 'plant'`.
Could anyone tell me how to solve this?
```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util

class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        out = self.plant.func(y) + self.x
        return out

    def _tree_flatten(self):
        children = (self.x,)  # arrays / dynamic values
        aux_data = {'plant': self.plant}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        import pdb; pdb.set_trace();
        return cls(*children, **aux_data)

tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)

class Plant:
    def __init__(self, z, kk):
        self.z = z

    @jit
    def func(self, y):
        y = y + self.z
        return y

    def _tree_flatten(self):
        children = (self.z,)  # arrays / dynamic values
        aux_data = None  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        return cls(*children)

tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5, 2)
one = One(plant, 2)
print(one.call_plant_func(10))
```
The last line raises the error described above.
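The error does not actually come from `jit` itself but from the `cls(*children, **aux_data)` call in `One._tree_unflatten`, which JAX invokes when it rebuilds the object. A minimal sketch that reproduces just that call in plain Python, assuming a bare copy of `One.__init__` and an `object()` stand-in for the `Plant` instance (both illustrative, not part of the original code):

```python
class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x


# What the question's One._tree_unflatten effectively does:
# children holds x, aux_data holds plant.
children = (2,)
aux_data = {"plant": object()}  # hypothetical stand-in for the Plant instance

message = ""
try:
    # 2 binds positionally to `plant`, then plant=... collides with it
    One(*children, **aux_data)
except TypeError as e:
    message = str(e)

print(message)
```

Because `x` is the only child, it lands in the first positional slot (`plant`), and the keyword `plant` from `aux_data` then clashes with it.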
You have issues in the `tree_flatten` and `tree_unflatten` code in both classes:

- `One._tree_flatten` treats `plant` as static data, but it is not: it is a pytree that has non-static elements.
- `One._tree_unflatten` instantiates `One` with arguments in the wrong order, leading to the error you're seeing.
- `Plant.__init__` does nothing with the `kk` argument.
- `Plant._tree_unflatten` is missing the `aux_data` argument, and fails to pass the `kk` argument to `Plant.__init__`.

With these issues fixed, your code executes without error.
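The corrected code block itself is missing from this copy of the answer, so the following is a sketch reconstructed from the four points above rather than the answerer's exact code. It assumes `kk` should be treated as a dynamic value like `z`, and uses `None` for `aux_data` in both classes since neither has any static fields:

```python
from jax import jit, tree_util


class Plant:
    def __init__(self, z, kk):
        self.z = z
        self.kk = kk  # store kk so it survives a flatten/unflatten round trip

    @jit
    def func(self, y):
        return y + self.z

    def _tree_flatten(self):
        children = (self.z, self.kk)  # dynamic values (traced under jit)
        aux_data = None  # no static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # JAX always passes aux_data first, even when it is unused
        return cls(*children)


tree_util.register_pytree_node(Plant, Plant._tree_flatten, Plant._tree_unflatten)


class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        return self.plant.func(y) + self.x

    def _tree_flatten(self):
        # plant is itself a pytree with dynamic leaves, so it is a child, not aux data
        children = (self.plant, self.x)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # children are in the same order as One.__init__'s parameters
        return cls(*children)


tree_util.register_pytree_node(One, One._tree_flatten, One._tree_unflatten)

plant = Plant(5, 2)
one = One(plant, 2)
print(one.call_plant_func(10))  # 10 + z (5) + x (2)
```

Making `plant` a child means its leaves (`z`, `kk`) are traced along with `x`, so changing `plant.z` later does not silently reuse a stale compiled function.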