JAX/XLA slow compilation using conda


I'm getting into using Google JAX and the built-in jit and grad functionality. These aspects are working nicely on my machine, but when I increase the number of arguments I get the following notification:

    ********************************
    Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
    Compiling module jit_obj_func__1.9055
    ********************************

I would like to increase the number of input parameters further, so I expect I'll soon need faster compile times. The suggestion in this message appeals to me, but I don't understand how to act on it.
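For reference, here is a minimal sketch of the kind of setup described above, useful for measuring how compile time scales with the number of parameters. The original obj_func is not shown, so make_obj_func and obj_func_array are hypothetical stand-ins, and compile_seconds is a hypothetical helper built on JAX's ahead-of-time jax.jit(...).lower(...).compile() path. One assumption worth testing: when each parameter is a separate argument and the objective loops over them in Python, the loop is unrolled during tracing, so the program XLA has to compile grows with the argument count, whereas packing the parameters into one array keeps it the same size.

```python
import time

import jax
import jax.numpy as jnp


def make_obj_func(n_params):
    """Hypothetical stand-in for obj_func: n_params separate scalar arguments."""
    def obj_func(*params):
        total = 0.0
        # This Python loop is unrolled at trace time, so the jitted program
        # gets larger as the number of arguments grows.
        for i, p in enumerate(params):
            total = total + (p - i) ** 2
        return total
    return obj_func


def obj_func_array(params):
    """Same objective with all parameters packed into one 1-D array."""
    # A single vectorized expression: program size does not depend on len(params).
    return jnp.sum((params - jnp.arange(params.shape[0])) ** 2)


def compile_seconds(fn, *args):
    """Wall-clock time to trace, lower, and compile jit(fn) for these args."""
    start = time.perf_counter()
    jax.jit(fn).lower(*args).compile()  # ahead-of-time compilation API
    return time.perf_counter() - start


if __name__ == "__main__":
    for n in (10, 50, 200):
        scalars = [jnp.zeros((), jnp.float32)] * n
        # Gradient with respect to every separate scalar argument.
        grad_scalars = jax.grad(make_obj_func(n), argnums=tuple(range(n)))
        # Gradient with respect to the single array argument.
        grad_array = jax.grad(obj_func_array)
        t_scalars = compile_seconds(grad_scalars, *scalars)
        t_array = compile_seconds(grad_array, jnp.zeros(n, jnp.float32))
        print(f"n={n}: separate args {t_scalars:.2f}s, one array {t_array:.2f}s")
```

Absolute timings depend on the machine and the JAX version; the comparison that matters is how each column changes as n grows.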

I've been using conda to install JAX. Specifically, I run the following commands in the terminal:

    ~$ conda create --name jax
    ~$ conda activate jax
    ~$ conda install -c conda-forge jax matplotlib cudatoolkit
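Before changing how JAX is installed, it may help to confirm what the conda environment actually resolved to: which jax and jaxlib versions were installed and whether a GPU backend is visible. The following is a small check, assuming jax imports in the activated environment and that jaxlib carries package metadata (as pip and conda-forge installs normally do). jax_environment_summary is a hypothetical helper name.

```python
from importlib.metadata import version

import jax


def jax_environment_summary():
    """Collect the installed JAX/jaxlib versions and the active backend."""
    return {
        "jax": jax.__version__,
        "jaxlib": version("jaxlib"),  # read from installed package metadata
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in jax_environment_summary().items():
        print(f"{key}: {value}")
```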

I'm sure there must be a way to pass options when installing with conda (for example, something like conda install jax=arguments), but I can't find how to do it anywhere in the documentation. There doesn't seem to be anything on Stack Overflow either; a search only turned up the following question: Very slow jit compile for XLA when using jax

Any advice would be greatly appreciated!

