Hi, why can't I vectorize the condition function so it applies to a list of booleans? Or is something else going on here?
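```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])  # shape (4, 1)
f1 = lambda x: 1  # branch taken when the predicate is True
f2 = lambda y: 0  # branch taken when the predicate is False
cond = lambda dk: jax.lax.cond(dk, f1, f2)
vcond = jax.vmap(cond)  # map cond over the leading axis of DK
vcond(DK)
```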
I was expecting it to return an array.
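A quick way to see what each vmapped call actually receives is to record the argument's shape while tracing. This is a small diagnostic sketch; the helper `show_shape` is hypothetical and only for illustration. Because `DK` has shape `(4, 1)`, `vmap` strips the leading axis and hands each call a row of shape `(1,)`, whereas `jax.lax.cond` expects a scalar (shape `()`) predicate.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])

shapes = []

def show_shape(dk):
    # Hypothetical helper: records the per-call (un-batched) shape seen at trace time.
    shapes.append(dk.shape)
    return dk

jax.vmap(show_shape)(DK)
print(shapes)  # [(1,)]: one row per call, not a scalar
```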
Try this:
Output:
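A minimal sketch of a version that does return an array, assuming the goal is one `1` or `0` per boolean in `DK`. The reshape and the extra operand are choices made for this sketch, not necessarily what the answer above used. It makes two changes: flatten `DK` so each vmapped predicate is a scalar, and pass an operand to `jax.lax.cond` so `f1` and `f2` receive the single argument they declare (with no operands, `cond` calls the branches with no arguments).

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
flat = DK.reshape(-1)  # shape (4,): each vmapped call now gets a scalar bool

f1 = lambda x: 1  # branch for True
f2 = lambda x: 0  # branch for False

# Pass dk as the operand so each branch is called with exactly one argument.
cond = lambda dk: jax.lax.cond(dk, f1, f2, dk)
vcond = jax.vmap(cond)

result = vcond(flat)
print(result)  # [1 1 0 1]
```

For a plain elementwise choice like this, `jnp.where(DK.reshape(-1), 1, 0)` gives the same values without `lax.cond`. Under `vmap` with a batched predicate, `lax.cond` is converted into a select that evaluates both branches anyway, so `jnp.where` loses nothing here.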