Very slow jit compile for XLA when using jax


I am using Jax for some machine learning jobs. Jax uses XLA to just-in-time compile code for acceleration, but the compilation itself is too slow on CPU. In my case, the compiler uses only a single CPU core, which is not efficient at all.

I have found some answers suggesting that compilation can be very fast on a GPU. Can anyone tell me how to make the compilation run on the GPU? I have not done any configuration related to compilation. Thanks!
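For context, one way to see how much time the compile step itself takes (separately from tracing and execution) is JAX's ahead-of-time lowering API. This is a minimal sketch, assuming a reasonably recent JAX version that provides `jit(...).lower(...).compile()`:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.arange(1000.0)

lowered = f.lower(x)             # trace and lower to XLA HLO (usually fast)
t0 = time.perf_counter()
compiled = lowered.compile()     # the XLA compile step being discussed
t1 = time.perf_counter()

print(f"XLA compile took {t1 - t0:.3f}s")
print(compiled(x))               # run the already-compiled function
```

Timing the `compile()` call in isolation makes it easier to confirm that the bottleneck is really XLA compilation rather than tracing or the first execution.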

Some additions to the question: I am using Jax to calculate the gradient and Hessian, which makes the compilation very slow. The code looks like this:

    import jax
    import jax.numpy as jnp
    from jax import jacfwd, jacrev, vmap

    ## get results from model ##
    def get_model_value(images):
        return jnp.sum(model(images))

    def get_model_grad(images):
        images = jnp.expand_dims(images, axis=0)
        image_grad = jacfwd(get_model_value)(images)
        return image_grad

    def get_model_hessian(images):
        images = jnp.expand_dims(images, axis=0)
        image_hess = jacfwd(jacrev(get_model_value))(images)
        return image_hess

    # get value
    model_value = model(dis_img)
    FR_value = jnp.expand_dims(FR_value, axis=1)
    value_loss = crit_mse(model_value, FR_value)

    # get grad
    vmap_model_grad = vmap(get_model_grad)
    model_grad = vmap_model_grad(dis_img)

    # get hessian
    vmap_model_hessian = vmap(get_model_hessian)
    model_hessian = vmap_model_hessian(dis_img)
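The snippet above is not self-contained (`model`, `dis_img`, `FR_value`, and `crit_mse` come from the asker's surrounding code). A minimal runnable version of the same forward-over-reverse Hessian pattern, with a toy linear-plus-`tanh` function standing in for the real `model` (an assumption for illustration only), would look like:

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, vmap


def model(images):
    # Toy stand-in for the real model: (batch, 4) -> (batch, 1).
    w = jnp.ones((4, 1))
    return jnp.tanh(images @ w)


def get_model_value(images):
    # Scalar summary so jacrev produces a gradient.
    return jnp.sum(model(images))


def get_model_hessian(image):
    image = jnp.expand_dims(image, axis=0)           # add a batch dimension
    # Forward-over-reverse is the usual Hessian composition in JAX.
    return jacfwd(jacrev(get_model_value))(image)


dis_img = jnp.zeros((8, 4))                          # dummy batch of 8 "images"
hessians = jax.jit(vmap(get_model_hessian))(dis_img)
print(hessians.shape)                                # one Hessian per image
```

Nested `jacfwd(jacrev(...))` plus `vmap` produces a large traced program, which is one reason the XLA compile step can take so long for Hessian code.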