Accumulation in JAX


What is the best way to handle memory when compiling an accumulation in JAX (e.g. with jax.lax.scan) in cases where allocating a full buffer is excessive?

The following is a geometric-progression example. The temptation is to recognise that the accumulation depends only on an input size, and to implement it accordingly:

import jax
import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size, x0, a):
    # Each step multiplies the carry by a, returning it both as the new
    # carry and as the per-step output stacked by scan.
    scan_fun = lambda carry, i: (a * carry,) * 2
    xn, x = lax.scan(scan_fun, x0, None, length=size - 1)
    return jnp.concatenate((x0[None], x))

jax.config.update("jax_enable_x64", True)

size = jnp.array(2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')

jax.jit(calc_gp_size)(size,x0,a)

However, attempting to use jax.jit predictably results in a ConcretizationTypeError, because under jit the size argument becomes a traced value rather than a concrete integer.
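
The failure can be reproduced and caught explicitly (a minimal sketch, assuming the imports and definitions above):

try:
    jax.jit(calc_gp_size)(size, x0, a)
except jax.errors.ConcretizationTypeError as err:
    # lax.scan needs a concrete Python int for `length`, but the traced
    # `size` cannot be converted to one.
    print(type(err).__name__)  # ConcretizationTypeError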

The correct way is to pass an array argument so that the buffer already exists:

def calc_gp_array(array, x0, a):
    # The values of `array` are ignored; scanning over it only fixes the
    # number of steps (and hence the size of the stacked output).
    scan_fun = lambda carry, i: (a * carry,) * 2
    xn, x = lax.scan(scan_fun, x0, array)
    return jnp.concatenate((x0[None], x))

array = jnp.arange(1,2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')  

jax.jit(calc_gp_array)(array,x0,a)

My concern is that a lot of memory is being allocated without being used (or is it?). Is there a more memory-efficient approach to this example, or is the allocated memory actually being used somehow?

EDIT: Incorporating the comments of @jakevdp, treating the function as main (a single call, so compilation is included and caching is excluded), and with size = 2**26 as a plain Python int so it can be marked static, profiling produced the results below.
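
The %memit measurements assume an IPython session with the memory_profiler extension loaded (an assumption; the setup is not shown above):

%load_ext memory_profiler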

%memit jax.jit(calc_gp_size, static_argnums=0)(size,x0,a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB

%memit jax.jit(calc_gp_array)(jnp.arange(1,size,dtype='u8'),x0,a).block_until_ready()
# peak memory: 7850.83 MiB, increment: 1240.22 MiB

%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
# peak memory: 8150.05 MiB, increment: 1539.70 MiB

More granular results would require line-profiling the jitted code (I'm not sure how this could be done).
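
One possibility might be JAX's built-in device memory profiler, which snapshots the live device buffers in pprof format (a sketch only; I have not tied it to the numbers above, and inspecting the output requires the pprof tool):

result = jax.jit(calc_gp_array)(jnp.arange(1, size, dtype='u8'), x0, a)
result.block_until_ready()
# Write a pprof-compatible snapshot of what is currently live on the device.
jax.profiler.save_device_memory_profile("memory.prof")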

Initialising the array in a separate statement and then calling jax.jit appears to save memory:

%memit array = jnp.arange(1,size,dtype='u8'); jax.jit(calc_gp_array)(array,x0,a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB

%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB

BEST ANSWER

The first version will work if you mark the size argument as static and pass a hashable value:

import jax
import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size,x0,a):
    scan_fun = lambda carry, i : (a*carry,)*2
    xn, x = lax.scan(scan_fun,x0,None,length=size-1)
    return jnp.concatenate((x0[None],x))

jax.config.update("jax_enable_x64", True)

size = 2 ** 26
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')

jax.jit(calc_gp_size, static_argnums=0)(size,x0,a)
# DeviceArray([1.        , 1.00000001, 1.00000002, ..., 1.95636587,
#              1.95636589, 1.95636591], dtype=float64)

I think this may be slightly more memory efficient than pre-allocating the array as in your second example, though it would be worth benchmarking if that's important.
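
One way to see what the compiled programs actually allocate is the ahead-of-time lower/compile API (a sketch; memory_analysis() is only available in reasonably recent JAX versions and can return None on some backends):

stats_size = jax.jit(calc_gp_size, static_argnums=0).lower(size, x0, a).compile().memory_analysis()
stats_array = jax.jit(calc_gp_array).lower(jnp.arange(1, size, dtype='u8'), x0, a).compile().memory_analysis()
print(stats_size)   # argument/output/temporary buffer sizes, if reported
print(stats_array)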

Also, if you're doing this sort of operation on GPU, you'll likely find built-in accumulations like jnp.cumprod to be much more performant. I believe this is more or less equivalent to your scan-based function:

result = jnp.cumprod(jnp.full(size, 1 + 1E-8, dtype='f8').at[0].set(1))
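
As a quick sanity check (a sketch, assuming the definitions from the snippets above), the two results should agree to floating-point tolerance:

scan_result = jax.jit(calc_gp_size, static_argnums=0)(size, x0, a)
print(jnp.allclose(result, scan_result))  # expected: True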