How can I use jit and vmap in JAX to vectorize and speed up the following computation:
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    results = []
    # Iterate over idxs, skipping empty index tuples
    for i in idxs:
        if i:
            results.append(distance(X[:, i], Y[:, i]))
    return results

# data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)

# indices
idxs = ((7,8), (1,7,9), (), (1), ())

# call the regular function
print(compute_metrics(idxs, X, Y)) # works

# call the function with vmap
print(vmap(compute_metrics, in_axes=(None, 0, 0))(idxs, X, Y)) # doesn't work
I followed the JAX website and tutorials, but I can't figure out how to make this work. The non-vmap version works. However, I get an IndexError for the vmap version (last line above) that looks like this:
jax._src.traceback_util.UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
Any idea how I could get this to work? Also, idxs might change and can be any valid combination of indices, e.g.
idxs = ((1,3,4,5), (3,9), (3,2,5), (), (5,8))
As explained above, I tried this with and without vmap and could only get the non-vmap version to work.
I don't think vmap is going to work with a tuple of scalars. What you need to do is put the indices into an array and vmap over it.
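A minimal sketch of that idea, assuming every row holds the same number of indices and the empty tuples have already been dropped (the name metric_for and the example index values are only illustrative):

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def distance(X, Y):
    # Mean absolute difference between two equally shaped arrays
    return jnp.mean(jnp.abs(X - Y))

def metric_for(idx, X, Y):
    # idx is a 1-D integer array of column indices (hypothetical helper)
    return distance(X[:, idx], Y[:, idx])

rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))

# One row per index group; all rows must have the same length
idxs = jnp.array([[7, 8], [1, 7], [3, 9]])

# Map over the rows of idxs only; X and Y are shared by every call
results = vmap(metric_for, in_axes=(0, None, None))(idxs, X, Y)
print(results.shape)
```

Note that in_axes maps over idxs rather than over X and Y, which is the opposite of the call in the question.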
I am not sure this solution satisfies you, because we have to get rid of the empty index tuples ().
You can also jit everything:
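One way this could look, as a sketch under the same assumptions as before (fixed-length rows, no empty tuples); here the vmap is simply wrapped inside a jitted compute_metrics:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

@jit
def distance(X, Y):
    # Mean absolute difference between two equally shaped arrays
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    # Vectorize over the rows of idxs; X and Y are closed over
    return vmap(lambda idx: distance(X[:, idx], Y[:, idx]))(idxs)

rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs = jnp.array([[7, 8], [1, 7], [3, 9]])

results = compute_metrics(idxs, X, Y)
print(results.shape)
```

The compiled function is reused as long as the shapes of idxs, X and Y stay the same; a different number of rows or columns in idxs triggers a recompilation.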
Update 19/05/2023:
The question is how to make this more general, i.e. how to support a variable number of indices. The problem is that JAX needs static shapes for inputs and outputs, so we need some tricks to deal with this. The most obvious trick in such cases is to use the jnp.where function to handle the conditional behavior. The other choice is jax.lax.cond. As before, we put the indices into an array, but this time we use -1 as a special flag marking an empty slot in the matrix (this is like zero-padding, but with -1 instead of 0). Because arrays have a static shape, the number of columns in idxs_pairs should be the maximum number of indices in any tuple.
For example:
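A sketch of how the padded array could be built from the tuples in the question. The helper pad_indices is hypothetical, and (1) is written as (1,) so that every entry really is a tuple:

```python
import numpy as np
import jax.numpy as jnp

def pad_indices(idxs, fill=-1):
    # Hypothetical helper: right-pad every tuple with `fill` to a common width
    width = max(1, max((len(t) for t in idxs), default=0))
    padded = np.full((len(idxs), width), fill, dtype=np.int32)
    for row, t in enumerate(idxs):
        if t:  # leave empty tuples as a row of `fill`
            padded[row, :len(t)] = t
    return jnp.asarray(padded)

idxs = ((7, 8), (1, 7, 9), (), (1,), ())
idxs_pairs = pad_indices(idxs)
# rows: (7, 8, -1), (1, 7, 9), (-1, -1, -1), (1, -1, -1), (-1, -1, -1)
print(idxs_pairs)
```

The padding is done in plain NumPy outside of any jitted function, since the tuple lengths are only known in Python.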
We now redefine our new functions:
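A possible version based on jnp.where, as a sketch rather than a definitive implementation. The name masked_distance is illustrative, and it is assumed here that a row consisting only of -1 should yield a distance of 0. The -1 entries are first replaced by a valid column index, because a negative index in JAX would otherwise silently select the last column:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

def masked_distance(idx, X, Y):
    # idx: 1-D int array of column indices, padded with -1
    valid = idx >= 0
    safe_idx = jnp.where(valid, idx, 0)            # any in-range index works
    diff = jnp.abs(X[:, safe_idx] - Y[:, safe_idx])  # shape (n, k)
    total = jnp.sum(jnp.where(valid, diff, 0.0))   # ignore padded columns
    count = jnp.sum(valid) * X.shape[0]            # number of real entries
    # Mean over the real entries; 0.0 for an all-padding row
    return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

compute_metrics = jit(vmap(masked_distance, in_axes=(0, None, None)))

rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs_pairs = jnp.array([[7, 8, -1], [1, 7, 9], [-1, -1, -1], [1, -1, -1], [-1, -1, -1]])

results = compute_metrics(idxs_pairs, X, Y)
print(results)
```

The output always has one entry per row of idxs_pairs, so the entries for empty rows have to be filtered out afterwards if they are not wanted.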
I am not sure this is the most efficient way of doing it. It depends on whether the XLA compiler can recognize that we set a distance of zero for the -1 indices, but I am not an XLA expert. I will later provide another solution based on jax.lax.cond, which may be faster, so we can benchmark.
Update 22/05/2023:
With jax.lax.cond, the implementation can look like this:
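A sketch of one such variant, again assuming -1 padding and a distance of 0 for all-padding rows (masked_distance_cond and column_sum are illustrative names). Each column index is handled by a lax.cond that either sums the absolute differences of that column or returns zero:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, lax, vmap

def masked_distance_cond(idx, X, Y):
    def column_sum(j):
        # Sum of |X - Y| over column j, or zero if j is the -1 padding flag
        return lax.cond(
            j >= 0,
            lambda jj: jnp.sum(jnp.abs(X[:, jj] - Y[:, jj])),
            lambda jj: jnp.zeros((), dtype=X.dtype),
            j,
        )

    total = jnp.sum(vmap(column_sum)(idx))
    count = jnp.sum(idx >= 0) * X.shape[0]
    # Mean over the real entries; 0.0 for an all-padding row
    return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

compute_metrics = jit(vmap(masked_distance_cond, in_axes=(0, None, None)))

rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs_pairs = jnp.array([[7, 8, -1], [1, 7, 9], [-1, -1, -1], [1, -1, -1], [-1, -1, -1]])

results = compute_metrics(idxs_pairs, X, Y)
print(results)
```

When lax.cond is vmapped over a batched predicate, JAX converts it into a select that evaluates both branches, so a large difference from the jnp.where version should not be expected.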
I tested it, and the execution times are the same as for the jnp.where case.