I'm trying to vmap a function. My understanding of vmap is that, roughly anywhere I would write a for loop or list comprehension, I should consider vmapping instead. I have a few points of confusion:
1. Does vmap need fixed sizes for everything throughout the function(s) being vmapped?
2. Does vmap try to JIT my function behind the scenes? (I'm wondering because 1 is a behavior I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)
3. If vmap is JIT-ing something, how would one use something like static arguments with vmap?
4. What is the best practice for dealing with extraneous information (e.g. if some outputs have size a and some have size b, do you just make an array of size max(a, b) and then ignore the extra values)?
The reason I'm asking is that vmap, like JIT, seems to run into all sorts of `ConcretizationTypeError`s and seems (I'm not 100% sure yet) to need constant-sized items for everything. I associate this behavior with any function I'm trying to JIT, but not necessarily with any function I write in JAX.
1. Yes. `vmap`, like all JAX transformations, requires any arrays defined in the function to have static shapes.
2. No, `vmap` does not `jit`-compile a function by default, although you can always compose both if you wish (e.g. `jit(vmap(f))`).
3. As mentioned, `vmap` is unrelated to `jit`, but an analogy of `jit`'s `static_argnums` is passing `None` to `in_axes`, which will keep the argument unmapped and therefore static within the transformation.
4. This is a difficult question to answer without more detail. I'd suggest opening a new question with more specifics. See How to ask a good question, and in particular try to include a Minimal reproducible example of what you're attempting to do.
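To make the static-shape requirement concrete, here is a minimal sketch (the functions `row_norm`, `nonzero_indices` and `nonzero_indices_padded` are hypothetical names for illustration). A function whose intermediate shapes depend only on the input shape vmaps without trouble. `jnp.nonzero`, whose output length depends on the input *values*, raises a `ConcretizationTypeError` under `vmap` unless its `size` argument fixes the output length. Exact error messages may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def row_norm(row):
    # Every intermediate shape is determined by row.shape, so vmap can batch it.
    return jnp.sqrt(jnp.sum(row ** 2))

def nonzero_indices(row):
    # Output length depends on the values in row, not just its shape.
    return jnp.nonzero(row)[0]

def nonzero_indices_padded(row):
    # Fixing the output size (padding unused slots with -1) makes the shape static.
    return jnp.nonzero(row, size=row.shape[0], fill_value=-1)[0]

x = jnp.array([[3.0, 4.0, 0.0],
               [0.0, 0.0, 5.0]])

norms = jax.vmap(row_norm)(x)                  # static shapes: works
padded = jax.vmap(nonzero_indices_padded)(x)   # fixed output size: works

try:
    jax.vmap(nonzero_indices)(x)               # value-dependent shape: fails
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
```

Padding to a fixed maximum size, as in `nonzero_indices_padded`, is one common way to keep shapes static when the "real" output length varies.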
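One way to see that `vmap` does not compile anything on its own is to count how often the Python body runs. This is a sketch; `make_f` and `trace_count` are hypothetical helpers. Plain `vmap` traces the function once per call (once for the whole batch, not once per row) and does not cache that trace, whereas `jit(vmap(f))` traces once and reuses the compiled result for later calls with the same shapes and dtypes.

```python
import jax
import jax.numpy as jnp

trace_count = {"vmap": 0, "jit_vmap": 0}

def make_f(key):
    def f(row):
        # Python side effect: runs whenever JAX traces f, not once per element.
        trace_count[key] += 1
        return jnp.sum(row * 2.0)
    return f

x = jnp.arange(6.0).reshape(3, 2)   # 3 rows of 2 elements

batched = jax.vmap(make_f("vmap"))
out_vmap = batched(x)               # traces f once for all 3 rows
batched(x)                          # traces f again: vmap alone does not cache

compiled = jax.jit(jax.vmap(make_f("jit_vmap")))
out_jit = compiled(x)               # traces and compiles once
compiled(x)                         # cache hit: f is not traced again
```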
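A sketch of the `in_axes=None` analogy, using a hypothetical `first_k` function. Because `k` is unmapped, plain `vmap` passes the same Python int to every row, so it can be used where a concrete value is required, such as a slice bound. If you also wrap the result in `jit`, `jit` will trace `k` unless you additionally mark it with `static_argnums`.

```python
import jax
import jax.numpy as jnp

def first_k(row, k):
    # A slice bound must be concrete; k is unmapped, so it stays a plain int here.
    return row[:k]

xs = jnp.arange(12.0).reshape(3, 4)

# Map over axis 0 of xs only; k (in_axes=None) is shared by every row.
out = jax.vmap(first_k, in_axes=(0, None))(xs, 2)

def batched_first_k(xs, k):
    return jax.vmap(first_k, in_axes=(0, None))(xs, k)

# When composing with jit, k must also be declared static for jit itself.
fast_first_k = jax.jit(batched_first_k, static_argnums=1)
out_jit = fast_first_k(xs, 2)
```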