I am rewriting some code from pure Python to JAX. I have reached a part of my old code that used Python's multiprocessing module to parallelize the evaluation of a function over all of the CPU cores in a single node, as follows:
import multiprocessing

# start pool process
pool = multiprocessing.Pool(processes=10)  # if node has 10 CPU cores, start 10 processes
# use pool.map to evaluate function(input) for each input in parallel
# suppose len(inputs) is very large and 10 inputs are processed in parallel at a time
# store the results in a list called out
out = pool.map(function,inputs)
# close pool processes to free memory
pool.close()
pool.join()
I know that JAX has vmap and pmap, but I don't understand whether either of those is a drop-in replacement for how I'm using multiprocessing.pool.map above.
- Is vmap(function,in_axes=0)(inputs) distributing to all available CPU cores or what?
- How is pmap(function,in_axes=0)(inputs) different from vmap and multiprocessing.pool.map?
- Is my usage of multiprocessing.pool.map above an example of a "single-program, multiple-data (SPMD)" code that pmap is meant for?
- When I actually do pmap(function,in_axes=0)(inputs) I get an error -- ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1) -- what does this mean?
- Finally, my use case is very simple: I merely want to use some/all of the CPU cores on a single node (e.g., all 10 CPU cores on my Macbook). But I have heard about nesting pmap(vmap) -- is this used to parallelize over the cores of multiple connected nodes (say on a supercomputer)? This would be more akin to mpi4py rather than multiprocessing (the latter is restricted to a single node).
- No, vmap has nothing to do with parallelization. It is a vectorizing transformation, not a parallelizing transformation. In the course of normal operation, JAX may use multiple cores via XLA, so vmapped operations may also do this. But there's no explicit parallelization in vmap.
- pmap parallelizes over multiple XLA devices. vmap does not parallelize, but rather vectorizes on a single device. multiprocessing parallelizes over multiple Python processes.
- Yes, it could be described as SPMD across multiple Python processes.
- pmap parallelizes over multiple XLA devices, and you have configured only a single XLA device, so the requested operation is not possible.
- Yes, I believe that pmap can be used to compute on multiple CPU cores. Whether it's nested with vmap is irrelevant. See the question "JAX pmap with multi-core CPU".

Note also that jax.pmap is deprecated in favor of the newer jax.shard_map, which is a much more flexible transformation for multi-device/multi-host computation. There's some info here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
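To make the point about XLA devices concrete, here is a minimal sketch, not a definitive recipe. It assumes you are running on the CPU backend and that setting the XLA flag --xla_force_host_platform_device_count before JAX is imported makes the host appear as several CPU devices. The device count of 4, the body of function, and the shape of inputs are placeholders chosen for illustration, not something taken from your code.

```python
import os

# Assumption: this has to be set before JAX is first imported, otherwise it has no effect.
# It asks XLA to present the host CPU as 4 separate devices (4 is an arbitrary example).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp


def function(x):
    # Hypothetical stand-in for the real per-input computation.
    return jnp.sum(x ** 2)


# Number of XLA devices visible to this process.
n_devices = jax.local_device_count()

# One row of inputs per device; pmap requires the mapped axis
# to be no larger than the number of available devices.
inputs = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

# vmap: vectorizes over the leading axis on a single device.
out_vmap = jax.vmap(function, in_axes=0)(inputs)

# pmap: runs one slice of the leading axis on each XLA device.
out_pmap = jax.pmap(function, in_axes=0)(inputs)

print("devices:", n_devices)
print("vmap result:", out_vmap)
print("pmap result:", out_pmap)
```

Both calls should produce the same values; the difference is only in where the work runs. The ValueError from the question arises when the leading axis of inputs is larger than the number of visible devices, which is why the sketch sizes inputs from jax.local_device_count() rather than hard-coding 10.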