Rewriting for loop with jax.lax.scan


I'm having trouble understanding the JAX documentation. Can somebody give me a hint on how to rewrite simple code like this with jax.lax.scan?

import numpy

numbers = numpy.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
evenNumbers = 0
for row in numbers:
    for n in row:
        if n % 2 == 0:
            evenNumbers += 1
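Moving this loop body unchanged into a jax.lax.scan step does not work: inside scan each row is an abstract tracer, and Python's `if` needs a concrete True/False. The minimal sketch below (the step name `step_with_python_if` is hypothetical) reproduces the failure. Recent JAX versions raise `TracerBoolConversionError`, which is a subclass of `jax.errors.ConcretizationTypeError`, so the sketch catches the parent class.

```python
import jax
import jax.numpy as jnp

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])

def step_with_python_if(carry, row):
    # Same logic as the question's loop, written as a scan step (carry, row) -> (carry, y).
    even = 0
    for n in row:
        # Inside scan `n` is a tracer, so Python cannot evaluate this `if`.
        if n % 2 == 0:
            even += 1
    return carry + even, even

try:
    jax.lax.scan(step_with_python_if, 0, numbers)
    failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    failed = True

print(failed)  # True
```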

Best answer

Assuming the goal is to demonstrate the concepts rather than to optimize the example shown, the function passed to jax.lax.scan must have the signature f(carry, x) -> (new_carry, y), and any condition that depends on the scanned values has to be expressed with jax.lax.cond instead of a Python if, because inside scan those values are abstract tracers rather than concrete numbers. The code below is the closest to the original I could think of, but please be aware that I'm anything but a jaxpert.
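To make the f(carry, x) -> (new_carry, y) signature concrete: the JAX documentation describes lax.scan as behaving roughly like a plain Python loop that threads the carry through each step and stacks the per-step outputs. The sketch below is a simplified model of that behaviour, not JAX's implementation. The names `scan_reference` and `count_row` are hypothetical, and the model ignores scan's `length`, `reverse`, and pytree handling. Because it runs on concrete NumPy values, the Python `if` is still allowed here.

```python
import numpy as np

def scan_reference(f, init, xs):
    # Simplified model of lax.scan: thread `carry` through f and collect each y.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def count_row(carry, row):
    # carry = running total of even numbers; y = even count for this row.
    even = 0
    for n in row:
        if n % 2 == 0:  # fine here: n is a concrete NumPy scalar, not a tracer
            even += 1
    return carry + even, even

numbers = np.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
total, per_row = scan_reference(count_row, 0, numbers)
print(total, per_row)  # 2 [1 0 1]
```

The final carry is the question's evenNumbers, and the stacked outputs are the per-row counts, which is the same pair jax.lax.scan returns.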

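import jax
import jax.numpy as jnp

def f(carry, row):
    # Scan step: carry is the running total, row is one row of `numbers`.
    even = 0
    # The Python loop over `row` is unrolled at trace time (the row length is fixed).
    for n in row:
        # lax.cond replaces the Python `if`, which cannot branch on a traced value.
        even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)

    # Return (new carry, per-step output): the running total and this row's count.
    return carry + even, even

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
jax.lax.scan(f, 0, numbers)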

Output

(DeviceArray(2, dtype=int32, weak_type=True),
 DeviceArray([1, 0, 1], dtype=int32, weak_type=True))
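If the goal is the count itself rather than a demonstration of lax.cond, the inner Python loop can be replaced by a vectorised comparison on each row. The sketch below (the step name `count_evens` is hypothetical) keeps jax.lax.scan but sums a boolean mask per row. It starts from a strongly typed `jnp.int32(0)` carry and sums with `dtype=jnp.int32`, so the carry's type stays identical from step to step. It also shows that, for this particular task, a single `jnp.sum` over the whole matrix gives the same total without scan.

```python
import jax
import jax.numpy as jnp

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])

def count_evens(carry, row):
    # Compare the whole row at once; True counts as 1 when summed.
    even = jnp.sum(row % 2 == 0, dtype=jnp.int32)
    return carry + even, even

# int32 initial carry matches the int32 per-row counts added to it.
total, per_row = jax.lax.scan(count_evens, jnp.int32(0), numbers)
print(total, per_row)  # 2 [1 0 1]

# Without scan: one vectorised reduction over the whole matrix.
vectorised = jnp.sum(numbers % 2 == 0)
print(vectorised)  # 2
```

scan only pays off when each step genuinely depends on the previous carry. Here it does not, which is why the one-line reduction is enough.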