JAX vmap: in_axes doesn't work when a keyword argument is passed


The in_axes parameter of vmap seems to work only for positional arguments.
When the vmapped function is called with a keyword argument, it throws an AssertionError (with no message).

from jax import vmap
import numpy as np

def foo(a, b, c):
    return a * b + c

foo = vmap(foo, in_axes=(0, 0, None))

aj, bj = np.random.rand(2, 100, 1)
foo(aj, bj, 10)  # works
foo(aj, bj, c=10)  # throws error


Traceback (most recent call last):
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api_util.py", line 300, in flatten_axes
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\tree_util.py", line 183, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\tree_util.py", line 183, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Tuple arity mismatch: 2 != 3; tuple: (<object object at 0x00000187F7BF4380>, <object object at 0x00000187F7BF4380>).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\IPython\core\interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-20500a2f8a08>", line 1, in <module>
    runfile('C:\\Users\\Amith\\PycharmProjects\\nntp\\tests\\test2.py', wdir='C:\\Users\\Amith\\PycharmProjects\\nntp\\tests')
  File "C:\Program Files\JetBrains\PyCharm 2022.2\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2022.2\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:\Users\Amith\PycharmProjects\nntp\tests\test2.py", line 11, in <module>
    foo(aj, bj, c=10)
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api.py", line 1481, in vmap_f
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
  File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api_util.py", line 306, in flatten_axes
    assert treedef_is_leaf(leaf)
AssertionError

How would one go about calling foo as foo(aj, bj, c=10) without provoking the error?
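The first exception in the traceback (Tuple arity mismatch: 2 != 3) hints at the cause: in_axes is matched only against the tuple of positional arguments, which here is (aj, bj) with two entries, while in_axes has three. According to the JAX vmap documentation, arguments passed by keyword are always mapped over their leading axis (axis 0), so there is no way to mark a keyword argument as None. The sketch below illustrates that behaviour, assuming c can be given a batch axis: in_axes is shortened to the two positional arguments and a batched c (the name cj is just for illustration) is passed by keyword.

```python
import numpy as np
from jax import vmap

def foo(a, b, c):
    return a * b + c

# in_axes now lists only the two positional arguments (a, b).
# c is passed by keyword, so vmap maps it over its leading axis (axis 0).
foo_v = vmap(foo, in_axes=(0, 0))

aj, bj = np.random.rand(2, 100, 1)
cj = np.full((100,), 10.0)  # c must now carry a batch axis of size 100

out = foo_v(aj, bj, c=cj)
print(out.shape)  # (100, 1)
```

Passing the scalar c=10 by keyword in this setup would still fail, because a scalar has no axis 0 to map over; that is why an unbatched keyword argument needs a different approach.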

Best answer:

Yes, it's true that vmap's in_axes applies only to positional arguments. If you want a vmapped function that also accepts keyword arguments, the best option currently is probably to use a wrapper function that passes everything on to the vmapped function positionally. For example:

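from jax import vmap

def _foo(a, b, c):
    return a * b + c

def foo(a, b, c):
    # Python binds c=10 to the parameter c here, so the vmapped call
    # below always receives all three arguments positionally.
    return vmap(_foo, in_axes=(0, 0, None))(a, b, c)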
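The same idea can be generalised so that the wrapper does not have to be written by hand for each function. The sketch below uses a hypothetical helper, vmap_kw (not part of JAX), which builds the vmapped function once and, on every call, uses Python's inspect.signature(...).bind to move keyword arguments into their positional slots before calling it. It assumes in_axes is a tuple with one entry per positional-or-keyword parameter of the function; keyword-only parameters, if any, are still mapped over axis 0 as vmap normally does.

```python
import functools
import inspect

import numpy as np
from jax import vmap

def vmap_kw(fun, in_axes):
    """Hypothetical helper (not part of JAX): like vmap, but keyword
    arguments are first bound to their positions so in_axes covers them."""
    sig = inspect.signature(fun)
    vfun = vmap(fun, in_axes=in_axes)  # build the vmapped function once

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)  # e.g. c=10 -> third position
        bound.apply_defaults()             # fill in any default values
        # bound.args holds every positional-or-keyword argument in order;
        # keyword-only parameters (if any) remain in bound.kwargs and are
        # therefore mapped over axis 0 by vmap.
        return vfun(*bound.args, **bound.kwargs)

    return wrapper

def _foo(a, b, c):
    return a * b + c

foo = vmap_kw(_foo, in_axes=(0, 0, None))

aj, bj = np.random.rand(2, 100, 1)
out_pos = foo(aj, bj, 10)    # positional call
out_kw = foo(aj, bj, c=10)   # keyword call now works too
```

If c is always the same constant, a simpler alternative is to bind it up front with functools.partial(_foo, c=10) and vmap the result with in_axes=(0, 0).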