No module named 'jax.experimental.global_device_array' when running the official Flax example on Colab with V100


I have been trying to understand this official Flax example, using a Colab Pro+ account with a V100 GPU. When I execute the command python main.py --workdir=./imagenet --config=configs/v100_x8.py, I get the following error:

File "/content/FlaxImageNet/main.py", line 29, in <module>
import train
File "/content/FlaxImageNet/train.py", line 30, in <module>
from flax.training import checkpoints
File "/usr/local/lib/python3.10/dist-packages/flax/training/checkpoints.py", line 34, in <module>
from jax.experimental.global_device_array import GlobalDeviceArray
ModuleNotFoundError: No module named 'jax.experimental.global_device_array'

I am not sure whether global_device_array has been moved out of the jax.experimental package, is no longer needed, or has been replaced by an equivalent API.


Accepted answer:

GlobalDeviceArray was deprecated in JAX version 0.4.1 and removed in JAX version 0.4.7.

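With that in mind, it seems the code in question (here, the installed Flax version, whose flax/training/checkpoints.py still imports GlobalDeviceArray as the traceback shows) requires JAX version 0.4.6 or older. You might consider reporting this incompatibility to the Flax project: http://github.com/google/flax/.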
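To confirm which side of those version boundaries a Colab session is on before launching main.py, a quick check can compare the installed JAX version against the 0.4.1 (deprecated) and 0.4.7 (removed) boundaries quoted above. This is a minimal sketch, not part of Flax or JAX: the helper names parse_version, gda_status and installed_jax_version are hypothetical, the version is read through the standard library's importlib.metadata, and the simplified parser ignores pre-release/dev suffixes instead of implementing full PEP 440 ordering.

```python
from __future__ import annotations

import importlib.metadata

# Version boundaries quoted in the answer (assumed accurate for this sketch).
GDA_DEPRECATED_IN = (0, 4, 1)
GDA_REMOVED_IN = (0, 4, 7)


def parse_version(v: str) -> tuple[int, ...]:
    """Hypothetical helper: '0.4.6' -> (0, 4, 6), '0.4.7.dev20230301' -> (0, 4, 7).

    Only the leading digits of the first three components are kept, so
    pre-release suffixes are ignored (a simplification, not PEP 440).
    """
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad "0.4" to (0, 4, 0)
        parts.append(0)
    return tuple(parts)


def gda_status(jax_version: str) -> str:
    """Report whether GlobalDeviceArray should be available for this JAX version."""
    v = parse_version(jax_version)
    if v >= GDA_REMOVED_IN:
        return "removed"
    if v >= GDA_DEPRECATED_IN:
        return "deprecated"
    return "available"


def installed_jax_version() -> str | None:
    """Return the installed jax distribution version, or None if jax is absent."""
    try:
        return importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_jax_version()
    if version is None:
        print("jax is not installed")
    else:
        print(f"jax {version}: GlobalDeviceArray {gda_status(version)}")
```

If it reports "removed", the installed Flax cannot import its checkpoints module against that JAX, which matches the ModuleNotFoundError in the question; pinning JAX (together with a compatible jaxlib) to 0.4.6 or older is the route the answer points to.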