I'm working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model containing interpolation. The current optimization process using jaxopt.ScipyBoundedMinimize takes around 30 seconds for 100 iterations, with most of the time apparently spent during or before the first iteration. The relevant code snippet is below, and the data it needs is available at the following link.
necessary data (idc, sg and cpcs)
import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle
with open('idc.pkl', 'rb') as file1:
    idc = pickle.load(file1)
with open('sg.pkl', 'rb') as file2:
    sg = pickle.load(file2)
with open('cpcs.pkl', 'rb') as file3:
    cpcs = pickle.load(file3)
def model(fssc, fssh, time, rv, amp):
    # Remaining fraction after the cool and hot components
    fssp = 1.0 - (fssc + fssh)
    ivis = cpcs['common'][time]['ivis']
    areas = cpcs['common'][time]['areas']
    mus = cpcs['common'][time]['mus']
    vels = idc['vels'].copy()
    ldfs_phot = cpcs['line'][time]['ldfs_phot']
    ldfs_cool = cpcs['line'][time]['ldfs_cool']
    ldfs_hot = cpcs['line'][time]['ldfs_hot']
    lps_phot = cpcs['line'][time]['lps_phot']
    lps_cool = cpcs['line'][time]['lps_cool']
    lps_hot = cpcs['line'][time]['lps_hot']
    lis_phot = cpcs['line'][time]['lis_phot']
    lis_cool = cpcs['line'][time]['lis_cool']
    lis_hot = cpcs['line'][time]['lis_hot']
    coeffs_phot = lis_phot * ldfs_phot * areas * mus
    wgt_phot = coeffs_phot * fssp[ivis]
    wgtn_phot = jnp.sum(wgt_phot)
    coeffs_cool = lis_cool * ldfs_cool * areas * mus
    wgt_cool = coeffs_cool * fssc[ivis]
    wgtn_cool = jnp.sum(wgt_cool)
    coeffs_hot = lis_hot * ldfs_hot * areas * mus
    wgt_hot = coeffs_hot * fssh[ivis]
    wgtn_hot = jnp.sum(wgt_hot)
    # Weighted sum of the local profiles, normalized by the total weight
    prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
    prf /= wgtn_phot + wgtn_cool + wgtn_hot
    # Shift the profile by rv via interpolation, then add the offset amp
    prf = jnp.interp(vels, vels + rv, prf)
    prf = prf + amp
    # Normalize to unit mean
    avg = jnp.mean(prf)
    prf = prf / avg
    return prf
def loss(x0s, lmbd):
    noes = sg['noes']
    noo = len(idc['times'])
    # Unpack the flat parameter vector: fssc, fssh, then per-time rv and amp
    fssc = x0s[:noes]
    fssh = x0s[noes: 2 * noes]
    fssp = 1.0 - (fssc + fssh)
    rv = x0s[2 * noes: 2 * noes + noo]
    amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]
    chisq = 0
    # Python loop over the observation times
    for i, itime in enumerate(idc['times']):
        oprf = idc['data'][itime]['prf']
        oprf_errs = idc['data'][itime]['errs']
        nop = len(oprf)
        sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])
        chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)
    # Regularization term, weighted by relative grid area
    wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])
    mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
                        (1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']
    ftot = chisq + lmbd * mem
    return ftot
if __name__ == '__main__':
    # idc: a dictionary containing observational data (150 x 32)
    # sg and cpcs: dictionaries with related coefficients
    noes = sg['noes']
    lmbd = 1.0
    maxiter = 1000
    tol = 1e-5
    fss = jnp.ones(2 * noes) * 1e-5
    x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))
    minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
    maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2
    bounds = (minx0s, maxx0s)
    start = ela_time.time()
    optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
                                     options={'disp': True})
    x0s, info = optimizer.run(x0s, bounds, lmbd)
    # optimizer = optax.adam(learning_rate=0.1)
    # optimizer_state = optimizer.init(x0s)
    #
    # for i in range(1, maxiter + 1):
    #
    #     print('ITERATION -->', i)
    #
    #     gradients = jax.grad(loss)(x0s, lmbd)
    #     updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
    #     x0s = optax.apply_updates(x0s, updates)
    #     x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
    #     print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))
    end = ela_time.time()
    print(end - start)  # total elapsed time: ~30 seconds
Here's a breakdown of the relevant aspects:
- Number of free parameters (x0s): 5263
- Data: Observational data stored in the idc dictionary (4800 data points)
- Model: Defined in the model function, also utilizes interpolation
- Optimization methods tried:
  - jaxopt.ScipyBoundedMinimize with the L-BFGS-B method (slow, ~30 seconds, with most of the time spent during or just before the first iteration)
  - optax.adam (too slow, ~200 seconds)
- Attempted parallelization: I attempted to parallelize optax.adam, yet due to the inherent nature of the modeling, I couldn't succeed as the x0s couldn't be divided (assuming I understood parallelization correctly).
Questions:
- What are potential reasons for the slowness before or during the first iteration in ScipyBoundedMinimize?
- Are there alternative optimization algorithms in jax that might be faster for my scenario (large number of free parameters and data points, complex model with interpolation)?
- Did I misunderstand parallelization with optax.adam? Are there any strategies for potential parallelization in this case?
- Are there any code optimizations within the provided snippet that could improve performance (e.g., vectorization)?
Additional Information:
- Hardware: Intel® Core™ i7-9750H CPU @ 2.60GHz × 12, 16 GiB RAM (laptop)
- Software: OS Ubuntu 22.04, Python 3.10.12, JAX 0.4.25, optax 0.2.1
I'd appreciate any insights or suggestions to improve the optimization performance.
JAX code is Just-in-time (JIT) compiled, meaning that the long duration of the first step is likely related to compilation costs. The longer your code is, the more time it will take to compile.
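One way to check whether compilation is what dominates is to time the first and second calls of a jitted function separately, and to count the equations in the traced program as a rough proxy for its length. The sketch below is only an illustration under assumptions: looped is a hypothetical stand-in for an objective that repeats similar work in a Python loop, not your actual loss.

```python
import time

import jax
import jax.numpy as jnp


def looped(x, n):
    # Hypothetical stand-in for an objective that repeats similar work n times
    # in a Python loop; JAX unrolls the loop while tracing.
    total = 0.0
    for k in range(n):
        total = total + jnp.sum(jnp.sin(x + k) ** 2)
    return total


x = jnp.linspace(0.0, 1.0, 32)

# Program length: number of equations in the traced program for several n.
eqn_counts = {
    n: len(jax.make_jaxpr(lambda v: looped(v, n))(x).jaxpr.eqns)
    for n in (1, 10, 100)
}
print("traced equations per loop length:", eqn_counts)

# Compile cost: the first call traces and compiles, the second reuses the cache.
jitted = jax.jit(lambda v: looped(v, 100))

t0 = time.perf_counter()
jitted(x).block_until_ready()
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
jitted(x).block_until_ready()
second_call = time.perf_counter() - t0

print(f"first call (trace + compile + run): {first_call:.4f} s")
print(f"second call (run only):             {second_call:.4f} s")
```

If the first call is much slower than the second, the time is going into tracing and compilation rather than into the numerical work itself.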
One common issue leading to long compile times is the use of Python control flow such as for loops. JAX's tracing machinery essentially flattens out these loops (see JAX Sharp Bits: Control Flow). In your case, you loop over every entry of idc['times'] (150 profiles of 32 points each, going by the comment in your code), and the model is traced again on every pass, so you are creating a very long and inefficient program.

The typical solution in a case like this is to rewrite your program using jax.vmap. Like most JAX constructs, this works best with a struct-of-arrays pattern rather than the array-of-structs pattern used in your data. So the first step to using vmap is to restructure your data in a way that JAX can use: stack the per-time prf and errs entries into single arrays, then map the per-profile computation over their leading axis.

Doing this for the observations alone will not work directly: you'll also have to restructure the data used by your model function into the struct-of-arrays style, but hopefully this gives you the general idea.

Note also that this assumes that every entry of idc['data'][i]['prf'] and idc['data'][i]['errs'] has the same shape. If that's not the case, then I'm afraid your problem is not particularly well-suited to JAX's SPMD programming model, and there's not an easy way to work around the need for long compilations.
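As a rough illustration of the struct-of-arrays idea, here is a minimal sketch on synthetic data. Everything in it is an assumption made for the example: data is a made-up dictionary with the same 150 x 32 layout as idc['data'], and toy_model and chisq_one are hypothetical stand-ins, not your actual model. It only shows the stacking step and the vmap over the leading axis.

```python
import jax
import jax.numpy as jnp

# Synthetic stand-in for the question's data: 150 profiles of 32 points each.
n_times, n_points = 150, 32
vels = jnp.linspace(-1.0, 1.0, n_points)
noise = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (n_times, n_points))

# Array-of-structs layout, as in the question: one dict per observation time.
data = {
    t: {"prf": 1.0 + noise[t], "errs": jnp.full((n_points,), 0.01)}
    for t in range(n_times)
}

# Struct-of-arrays layout: stack once, outside the loss, into (150, 32) arrays.
oprf = jnp.stack([data[t]["prf"] for t in range(n_times)])
errs = jnp.stack([data[t]["errs"] for t in range(n_times)])


def toy_model(rv, amp):
    # Hypothetical stand-in for the real model: a shifted Gaussian line.
    base = 1.0 - 0.5 * jnp.exp(-((vels / 0.3) ** 2))
    prf = jnp.interp(vels, vels + rv, base) + amp
    return prf / jnp.mean(prf)


def chisq_one(rv, amp, oprf_i, errs_i):
    # Chi-square contribution of a single profile.
    return jnp.sum(((oprf_i - toy_model(rv, amp)) / errs_i) ** 2)


@jax.jit
def loss(params):
    rv, amp = params[:n_times], params[n_times:]
    # vmap traces chisq_one once and maps it over the leading axis,
    # instead of tracing one copy per observation time.
    per_profile = jax.vmap(chisq_one)(rv, amp, oprf, errs)
    return jnp.sum(per_profile) / (n_times * n_points)


params = jnp.zeros(2 * n_times)
value, grad = jax.value_and_grad(loss)(params)
print(value, grad.shape)
```

In the real problem, the per-time arrays pulled from cpcs inside model would need the same treatment. If any of them turn out to have different lengths from one time to the next, they would have to be padded or masked before they can be stacked like this.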