I'm having trouble understanding the JAX documentation. Can somebody give me a hint on how to rewrite simple code like this with `jax.lax.scan`?
```python
import numpy

numbers = numpy.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])

# Count the even entries with two nested Python loops.
evenNumbers = 0
for row in numbers:
    for n in row:
        if n % 2 == 0:
            evenNumbers += 1
```
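For orientation, the JAX documentation explains `jax.lax.scan(f, init, xs)` in terms of an ordinary Python loop: `f` receives the current carry and one slice of `xs`, and returns a new carry plus a per-step output. The sketch below is a simplified model of that behaviour (it ignores the `length`, `reverse` and `unroll` options and pytree inputs); `scan_reference` and `running_sum` are hypothetical helpers for illustration, and the last lines compare the model with the real `jax.lax.scan`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def scan_reference(f, init, xs):
    """Simplified pure-Python model of jax.lax.scan(f, init, xs)."""
    carry = init
    ys = []
    for x in xs:                # iterate over the leading axis of xs
        carry, y = f(carry, x)  # f must return a (new_carry, y) pair
        ys.append(y)
    return carry, np.stack(ys)  # per-step outputs are stacked


def running_sum(carry, x):
    # carry: the total so far; x: the current element
    new_carry = carry + x
    return new_carry, new_carry  # emit the running total at every step


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
init = jnp.array(0.0)  # same dtype as xs, so carry types stay consistent

total, partial_sums = jax.lax.scan(running_sum, init, xs)
ref_total, ref_partial = scan_reference(running_sum, init, xs)
print(float(total), np.asarray(partial_sums))
```

Both calls produce a final total of 10.0 and partial sums 1, 3, 6, 10; the difference is that `jax.lax.scan` traces `running_sum` once and compiles the loop instead of running Python on every step.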
Assuming a solution should demonstrate the concepts rather than optimize the example shown, the function passed to `jax.lax.scan` must have the signature `f(carry, x) -> (new_carry, y)`, and any condition that depends on a traced (runtime) value has to be expressed with `jax.lax.cond` instead of a Python `if`. I tried to stay as close to the original as I could, but please be aware that I'm anything but a jaxpert.
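A minimal sketch of that idea, assuming the goal is only to mirror the two nested loops: an outer scan walks over the rows, its body runs an inner scan over the elements of one row, and `jax.lax.cond` replaces the `if`. The function names `count_in_row` and `count_in_matrix` are hypothetical.

```python
import numpy
import jax
import jax.numpy as jnp

numbers = numpy.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])


def count_in_row(count, n):
    # Inner scan body: carry = running count, x = one element of a row.
    # A Python `if` on the traced value n would fail, so branch with lax.cond.
    count = jax.lax.cond(n % 2 == 0,
                         lambda c: c + 1,  # even: increment the count
                         lambda c: c,      # odd: keep the count
                         count)
    return count, None  # no per-step output is needed


def count_in_matrix(count, row):
    # Outer scan body: carry = running count, x = one row of the matrix.
    count, _ = jax.lax.scan(count_in_row, count, row)
    return count, None


evenNumbers, _ = jax.lax.scan(count_in_matrix, jnp.int32(0), jnp.asarray(numbers))
print(int(evenNumbers))  # 2
```

Returning `None` as the per-step output tells `scan` there is nothing to stack, so only the final carry matters. For this particular task a vectorized `jnp.sum(numbers % 2 == 0)` gives the same count without any explicit loop.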