How to vectorize JAX functions using jit compilation and vmap auto-vectorization

How can I use jit and vmap in JAX to vectorize and speed up the following computation:

import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    results = []
    # Iterate over idxs
    for i in idxs:
        if i:
            results.append(distance(X[:, i], Y[:, i]))
    return results

#data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
#indices
idxs = ((7,8), (1,7,9), (), (1), ())

# call the regular function
print(compute_metrics(idxs, X, Y)) # works
# call the function with vmap
print(vmap(compute_metrics, in_axes=(None, 0, 0))(idxs, X, Y)) # doesn't work

I followed the JAX website and tutorials, but I can't figure out how to make this work. The non-vmap version works. However, the vmap version (last line above) raises an IndexError that looks like this:

jax._src.traceback_util.UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

Any idea how I could get this to work? Also, idxs might change and can be any valid combination of indices, e.g.

idxs = ((1,3,4,5), (3,9), (3,2,5), (), (5,8))

As explained above, I tried the above version with and without vmap and couldn't get the latter, vmap, version to work.

There is 1 best solution below

I don't think vmap is going to work with a tuple of scalars. What you need is to put the indices into an array and vmap over it.

I am not sure this solution will satisfy you, because we have to drop the empty index tuples ().

idxs_pairs = jnp.array([[7,8],[7,9]]) # put the indices pairs into array

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    return distance(X[:,idxs], Y[:,idxs])

vmap(compute_metrics, in_axes=(0, None, None))(idxs_pairs, X, Y)

You can also jit everything:

jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
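
For reference, the original call fails because in_axes=(None, 0, 0) maps over axis 0 of X and Y, so each call receives a 1-D row of shape (10,) and X[:, i] supplies one index too many. The sketch below illustrates that shape and checks the array-index version against a plain Python loop. The random data and the helper name row_shape are assumptions made for illustration only.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

# Assumed data with the same shapes as in the question.
rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))

def row_shape(x):
    # Hypothetical helper: under vmap with in_axes=0, x is a single row of X.
    return jnp.zeros(x.shape)

# 600 mapped calls, each on a 1-D array of shape (10,).
per_call = vmap(row_shape, in_axes=0)(X)
print(per_call.shape)  # (600, 10)

def compute_metrics(idxs, X, Y):
    # Mean absolute difference over the selected columns.
    return jnp.mean(jnp.abs(X[:, idxs] - Y[:, idxs]))

idxs_pairs = jnp.array([[7, 8], [7, 9]])

# Map over the rows of idxs_pairs only; X and Y are passed whole.
batched = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference: call the same function once per index row.
looped = jnp.stack([compute_metrics(row, X, Y) for row in idxs_pairs])
print(jnp.allclose(batched, looped))
```

Mapping over the index array rather than over X and Y keeps both matrices two-dimensional inside the function, which is what the column indexing needs.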

Update 19/05/2023:

The follow-up question is how to generalize this to a variable number of indices per group. The problem is that JAX needs static input and output shapes, so we need a trick. The most obvious one is to handle the conditional behavior with jnp.where; the other option is jax.lax.cond. As before, we put the indices into an array, but this time we use -1 as a special flag marking an empty slot (this is like zero-padding, but with -1 instead of 0). Because arrays have a static shape, the number of columns in idxs_pairs must equal the largest number of indices in any group.

For example:

# 7, 8, -1 -> we only use indices: 7, 8 
# 7, 9, -1 -> we only use indices: 7, 9 
# 7, 5, 6 -> we use indices: 7, 5, 6 
# 1, -1, -1 -> we use only index: 1 
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]]) # put the indices pairs into array
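
If the indices start out as a ragged tuple of tuples, as in the question, the padded array can be built in plain NumPy before handing it to JAX. This is only a sketch: pad_indices is a hypothetical helper name, and dropping empty tuples is an assumption that mirrors the `if i:` check in the question. Note that a one-element group has to be written (1,), because (1) is just the integer 1.

```python
import numpy as np
import jax.numpy as jnp

def pad_indices(idxs, fill=-1):
    """Hypothetical helper: turn ragged index tuples into a padded 2-D array."""
    # Drop empty tuples, since they select no columns.
    rows = [tuple(t) for t in idxs if len(t) > 0]
    # The widest group determines the static number of columns.
    width = max(len(t) for t in rows)
    padded = np.full((len(rows), width), fill, dtype=np.int32)
    for r, t in enumerate(rows):
        padded[r, :len(t)] = t
    return jnp.asarray(padded)

idxs = ((1, 3, 4, 5), (3, 9), (3, 2, 5), (), (5, 8))
idxs_pairs = pad_indices(idxs)
print(idxs_pairs)
# [[ 1  3  4  5]
#  [ 3  9 -1 -1]
#  [ 3  2  5 -1]
#  [ 5  8 -1 -1]]
```

The padding happens outside any jit-compiled function, so the ragged Python structure never has to be traced.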

We now redefine the functions:

def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference between one column of X and Y.

    Args:
        idx (jax.numpy.ndarray): scalar column index (-1 marks padding)
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        jax.numpy.ndarray: vector of shape (n,)
    """
    # A padding index of -1 selects the last column; the caller masks it out.
    return jnp.abs(X[:,idx] - Y[:,idx])

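def compute_metrics(idxs, X, Y):
    # distances has shape (k, n): one row per entry in idxs
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    # zero out the columns that correspond to -1 padding
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    # mean over the n rows and the real (non-padding) indices
    output = jnp.sum(distances) / (n_of_actual_indices * X.shape[0])
    return output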

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
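
One way to sanity-check the masking idea is to compare it with a NumPy reference computed on the unpadded index lists. The sketch below is a compact variant, not the code above verbatim: masked_metric is a hypothetical name, the data is random, and returning 0 for a row that contains only padding (which would otherwise divide by zero) is an assumed convention.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

# Assumed data with the same shapes as in the question.
rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))

def masked_metric(idxs, X, Y):
    """Mean absolute difference over the columns in idxs, ignoring -1 padding."""
    mask = idxs >= 0                              # (k,), True for real indices
    diffs = jnp.abs(X[:, idxs] - Y[:, idxs])       # (n, k); -1 wraps to the last column
    total = jnp.sum(jnp.where(mask, diffs, 0.0))   # padded columns contribute 0
    count = jnp.sum(mask) * X.shape[0]             # number of real entries
    # Assumed convention: a row with no real indices yields 0 instead of NaN.
    return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

idxs_pairs = jnp.array([[7, 8, -1], [7, 5, 6], [1, -1, -1], [-1, -1, -1]])
batched = jit(vmap(masked_metric, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference computed with NumPy on the unpadded index lists.
Xn, Yn = np.asarray(X), np.asarray(Y)
expected = [np.mean(np.abs(Xn[:, c] - Yn[:, c])) for c in ([7, 8], [7, 5, 6], [1])]
print(np.allclose(np.asarray(batched[:3]), expected, atol=1e-5))
print(float(batched[3]))  # 0.0 for the all-padding row
```

Keeping an all-padding row in the array preserves the position of an empty group in the output, at the cost of choosing a placeholder value for it.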

I am not sure this is the optimal way to do it. It depends on whether the XLA compiler can exploit the fact that the distances for -1 indices are set to zero, but I am not an XLA expert. I will add another solution based on jax.lax.cond later, which may be faster, so we can benchmark the two.

Update 22/05/2023:

With jax.lax.cond, the implementation can look like this:

from jax import lax

def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference for one column, or zeros for padding.

    Args:
        idx (jax.numpy.ndarray): scalar column index (-1 marks padding)
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        jax.numpy.ndarray: vector of shape (n,)
    """
    return lax.cond(idx >= 0, lambda: jnp.abs(X[:,idx] - Y[:,idx]), lambda: jnp.zeros_like(X[:,idx]))

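def compute_metrics(idxs, X, Y):
    # padded entries already come back as zeros, so no extra mask is needed
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    # mean over the n rows and the real (non-padding) indices
    output = jnp.sum(distances) / (n_of_actual_indices * X.shape[0])
    return output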

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

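I tested it, and the execution times are the same as in the jnp.where case. This is consistent with the documented behavior of lax.cond: when it is vmapped over a batch of predicates, it is converted to a select, so both branches are evaluated anyway.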
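
A comparison like this can be reproduced with a small timing harness. The sketch below is one possible setup, not the benchmark that was actually run: the function names, the random data, and the repeat count are all assumptions. Because JAX dispatches work asynchronously, each timed call ends with block_until_ready(), and a warm-up call keeps compilation out of the measurement.

```python
import timeit

import numpy as np
import jax.numpy as jnp
from jax import jit, lax, vmap

# Assumed data with the same shapes as in the question.
rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

def column_where(idx, X, Y):
    # Mask padded (-1) columns with jnp.where.
    return jnp.where(idx >= 0, jnp.abs(X[:, idx] - Y[:, idx]), 0.0)

def column_cond(idx, X, Y):
    # Same result expressed with lax.cond.
    return lax.cond(idx >= 0,
                    lambda: jnp.abs(X[:, idx] - Y[:, idx]),
                    lambda: jnp.zeros_like(X[:, idx]))

def make_metric(column_fn):
    def metric(idxs, X, Y):
        d = vmap(column_fn, in_axes=(0, None, None))(idxs, X, Y)
        return jnp.sum(d) / (jnp.sum(idxs >= 0) * X.shape[0])
    return jit(vmap(metric, in_axes=(0, None, None)))

f_where = make_metric(column_where)
f_cond = make_metric(column_cond)

def time_it(fn, repeats=200):
    # Warm-up call triggers compilation so it is not included in the timing.
    fn(idxs_pairs, X, Y).block_until_ready()
    return timeit.timeit(lambda: fn(idxs_pairs, X, Y).block_until_ready(),
                         number=repeats)

print("where:", time_it(f_where))
print("cond: ", time_it(f_cond))
```

Absolute numbers depend on the backend and on the array sizes, so the two timings are only meaningful relative to each other on the same machine.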