I have a family of functions parameterized by args
f(x, args)
and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My first attempt was to loop over the different values of args and call a scipy.optimize solver at each iteration, but that takes too long. I believe the computation can be sped up with vectorization. My next attempt was to use jax.vmap with jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value for args.
Alternatively, I can code my own vectorized optimization method, e.g. bisection. By vectorized I mean operating on whole arrays for a fixed number of iterations, without stopping early when one of the optimization problems reaches its error tolerance before the others.
I was hoping to use some already optimized, off-the-shelf algorithm if an implementation is available in jax. This thread is related, but there the args are not changing.
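For reference, the hand-written fallback described above might look like the following sketch. It runs bisection on the derivative for a fixed number of iterations across all problems at once. The derivative `df`, the brackets, and the iteration count are hypothetical placeholders, and the sketch assumes each bracket satisfies df(lo) < 0 < df(hi).

```python
import jax
import jax.numpy as jnp


def df(x, args):
    # Hypothetical derivative: d/dx of f(x, args) = (x - args)**2.
    return 2.0 * (x - args)


@jax.jit
def batched_bisection(args, lo, hi):
    # Assumes df(lo) < 0 < df(hi) for every problem, so each bracket
    # contains a minimum of f.
    def step(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        go_right = df(mid, args) < 0.0  # minimum lies to the right of mid
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    # Fixed iteration count for all problems; no per-problem early exit.
    lo, hi = jax.lax.fori_loop(0, 40, step, (lo, hi))
    return 0.5 * (lo + hi)


N = 1000
args = jnp.linspace(-2.0, 2.0, N)  # illustrative parameter values
x_star = batched_bisection(args, jnp.full(N, -5.0), jnp.full(N, 5.0))
```

Every problem does the same amount of work, which is what makes the loop easy to express on arrays, at the cost of extra iterations for problems that converge early.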
You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:
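A minimal sketch of this pattern, assuming the per-problem solver is jax.scipy.optimize.minimize with BFGS. The objective `f`, the helper `find_min`, and the parameter values are hypothetical stand-ins for your own:

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def f(x, args):
    # Hypothetical objective: a quadratic in x with its minimum at x = args[0].
    return jnp.sum((x - args[0]) ** 2) + args[1]


def find_min(args, x0):
    # Solve a single minimization problem for one set of parameters.
    result = minimize(f, x0, args=(args,), method="BFGS")
    return result.x


N = 1000
# One row of parameters per problem (illustrative values).
all_args = jnp.stack([jnp.linspace(-2.0, 2.0, N), jnp.ones(N)], axis=1)
x0 = jnp.zeros(1)  # starting point must be a 1-D float array

# Map over the leading axis of all_args; share the same x0 across problems.
batched_min = jax.jit(jax.vmap(find_min, in_axes=(0, None)))
x_star = batched_min(all_args, x0)  # shape (N, 1)
```

The key point is that vmap is applied outside the solver, to a function that takes the parameters as an ordinary argument, rather than inside the call to minimize. Under vmap the solver's internal loop keeps running until every problem in the batch has finished.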