As the title says, I currently manually hstack() the first axis of a 3D array returned by jax.vmap(). In my code, the copy operation in hstack() is currently a speed bottleneck. Can I avoid this by instructing jax.vmap() to do this directly?
Here is a simplified example:
import jax
import jax.numpy as jnp
def f(a, b, c):
    return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])  # Returns a 2x2 array
def arr(m, n):
    return jnp.arange(m*n).reshape((m, n))
m = 3
a = arr(m, 2)
b = arr(m, 5)
c = arr(m, 7)
fv = jax.vmap(f)
vmap_output = fv(a, b, c)
desired_output = jnp.hstack(fv(a, b, c))
print(vmap_output)
print(desired_output)
This yields:
# vmap() output
[[[  1.  10.]
  [ 21.   0.]]
 [[  5.  35.]
  [ 70.   0.]]
 [[  9.  60.]
  [119.   0.]]]
# Desired output
[[  1.  10.   5.  35.   9.  60.]
 [ 21.   0.  70.   0. 119.   0.]]
If this is not possible, I would resort to pre-allocating an array and simply writing to the columns manually, but I hope to avoid this. Thanks for any clue!
Update from @jakevdp's answer
Alright, it isn't possible. So I resort to writing to the columns, but this fails as well:
def g(output, idx, a, b, c):
    block = jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])  # Returns a 2x2 array
    jax.lax.dynamic_update_slice_in_dim(output, block, idx*2, axis=1)
# Defined above: jax, jnp, m, a, b, c
g_output = jnp.zeros((2, 2*m))
idxs = jnp.arange(m)
gv = jax.vmap(g, in_axes=(None, 0, 0, 0, 0))
gv(g_output, idxs, a, b, c)
print(g_output)
This yields:
[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]
So writing to g_output in the function g is not retained. Is there a way around this?
No, vmap does not have any built-in capability to stack outputs differently than the batching semantics would imply. But if you're interested in fusing the hstack operation with the vmap operation to the extent possible, you could do so by wrapping it in jit.
Edit: responding to your edited question: the result is all zeros because your function doesn't do anything: it returns None, so there's no way for it to affect the input array called g_output. JAX requires pure functions, so side-effecting code like what you wrote above is not compatible. If you wanted to replace the hstack with an indexed update, you could do so, but a nontrivial scatter operation like that will not typically be faster than a simple reshape, especially if you're running on an accelerator like GPU.
If your arrays are large enough that reshapes are costly, you might find that a more direct implementation is better.
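As a minimal sketch of the jit-wrapping suggestion above: the wrapper name fv_stacked is hypothetical, and f and arr are copied from the question. Wrapping the vmap and the hstack in one jitted function lets the compiler see both operations together; whether this actually removes the copy depends on the backend, so it is worth benchmarking.

```python
import jax
import jax.numpy as jnp

def f(a, b, c):
    # Same per-example function as in the question: returns a 2x2 array.
    return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])

def arr(m, n):
    return jnp.arange(m * n).reshape((m, n))

@jax.jit
def fv_stacked(a, b, c):
    # vmap produces shape (m, 2, 2); hstack joins the m blocks along axis 1,
    # giving shape (2, 2*m). Both steps are compiled as a single function.
    return jnp.hstack(jax.vmap(f)(a, b, c))

m = 3
a, b, c = arr(m, 2), arr(m, 5), arr(m, 7)
out = fv_stacked(a, b, c)
print(out)
```

The result has the same layout as desired_output in the question.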
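For the indexed-update route mentioned above, one possible sketch uses the functional .at[...].set(...) syntax, which returns a new array instead of mutating its input. The index arrays rows and cols are names chosen here for illustration; f and arr are copied from the question. As noted, this scatter is unlikely to beat a plain reshape or hstack.

```python
import jax
import jax.numpy as jnp

def f(a, b, c):
    # Same per-example function as in the question: returns a 2x2 array.
    return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])

def arr(m, n):
    return jnp.arange(m * n).reshape((m, n))

m = 3
a, b, c = arr(m, 2), arr(m, 5), arr(m, 7)

blocks = jax.vmap(f)(a, b, c)  # shape (m, 2, 2), indexed as blocks[k, r, j]

# Target position of blocks[k, r, j] is row r, column 2*k + j.
rows = jnp.arange(2).reshape(1, 2, 1)
cols = (2 * jnp.arange(m)).reshape(m, 1, 1) + jnp.arange(2).reshape(1, 1, 2)

# .at[...].set(...) is a pure update: it returns a new array,
# so the result must be captured rather than relying on a side effect.
g_output = jnp.zeros((2, 2 * m)).at[rows, cols].set(blocks)
print(g_output)
```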
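A more direct implementation could skip vmap entirely and build the two output rows from batched sums. This is only a sketch for the specific f in the question; the name f_direct is hypothetical, and the interleaving relies on f having exactly this 2x2 structure.

```python
import jax
import jax.numpy as jnp

def arr(m, n):
    return jnp.arange(m * n).reshape((m, n))

@jax.jit
def f_direct(a, b, c):
    # Row 0 interleaves the per-example sums of a and b: a0, b0, a1, b1, ...
    top = jnp.stack([a.sum(axis=1), b.sum(axis=1)], axis=1).ravel()
    # Row 1 interleaves the per-example sums of c with zeros: c0, 0, c1, 0, ...
    bottom = jnp.stack([c.sum(axis=1), jnp.zeros(c.shape[0])], axis=1).ravel()
    return jnp.vstack([top, bottom])

m = 3
a, b, c = arr(m, 2), arr(m, 5), arr(m, 7)
direct_output = f_direct(a, b, c)
print(direct_output)
```

This produces the (2, 2*m) layout directly, without forming the intermediate (m, 2, 2) array.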