jax/jaxopt solution for linear programming?

EDIT: I just found ott-jax which looks like it might be what I need, but if possible I'd still like to know what I did wrong with jaxopt below!

Original: I'm trying to solve an optimal transport problem, and after following this great blog post I have a working version in numpy/scipy (comments removed for brevity).

In trying to get a jax version of this working I came across this issue and tried looking at the jaxopt library but have not been able to find an implementation of linprog or linear programming (LP). I believe LP is a subset of quadratic programming which jaxopt does implement, but have not been able to replicate the numpy version successfully. Any idea where I am going wrong or how else I can solve this?

import jax
import jax.numpy as jnp
import jaxopt
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import pdist, squareform
from scipy.special import softmax

jax.config.update('jax_platform_name', 'cpu')

def prep_arrays(x, p, q):
  n, d = x.shape
  C = squareform(pdist(x, metric="sqeuclidean"))
  Ap, Aq = [], []
  z = np.zeros((n, n))
  z[:, 0] = 1

  for i in range(n):
    z = np.roll(z, 1, axis=1)

  A = np.row_stack((Ap, Aq))[:-1]
  b = np.concatenate((p, q))[:-1]

  return n, C, A, b

def demo_wasserstein(x, p, q):
  n, C, A, b = prep_arrays(x, p, q)
  result = linprog(C.ravel(), A_eq=A, b_eq=b)
  T = result.x.reshape((n, n))
  return np.sqrt(np.sum(T * C)), T

def jax_attempt_1(x, p, q):
  n, C, A, b = prep_arrays(x, p, q)
  C, A, b = jnp.array(C), jnp.array(A), jnp.array(b)

  def matvec_Q(params_Q, u):
    del params_Q
    return jnp.zeros_like(u)  # no quadratic term so Q is just 0

  def matvec_A(params_A, u):
    return jnp.dot(params_A, u)

  hyper_params = dict(params_obj=(None, C.ravel()), params_eq=A, params_ineq=(b, b))
  osqp = jaxopt.BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A)
  sol, state = osqp.run(None, **hyper_params)
  T = sol.primal[0].reshape((n, n))
  return np.sqrt(np.sum(T * C)), np.array(T)

def jax_attempt_2(x, p, q):
  n, C, A, b = prep_arrays(x, p, q)
  C, A, b = jnp.array(C), jnp.array(A), jnp.array(b)

  def fun(T, params_obj):
    _, c = params_obj
    return jnp.sum(T * c)

  def matvec_A(params_A, u):
    return jnp.dot(params_A, u)

  # solver = jaxopt.EqualityConstrainedQP(fun=fun, matvec_A=matvec_A)
  solver = jaxopt.OSQP(fun=fun, matvec_A=matvec_A)

  init_T = jnp.zeros((16, 16))
  hyper_params = dict(params_obj=(None, C.ravel()), params_eq=(A, b), params_ineq=None)
  init_params = solver.init_params(init_T.ravel(), **hyper_params)
  sol, state = solver.run(init_params=init_params, **hyper_params)
  T = sol.primal.reshape((n, n))
  return np.sqrt(np.sum(T * C)), np.array(T)

if __name__ == '__main__':
  n = 16
  q_values = np.random.normal(size=n)
  p = np.full(n, 1. / n)
  q = softmax(q_values)
  x = np.random.uniform(-1., 1., (n, 1))

  dist_numpy, plan_numpy = demo_wasserstein(x, p, q)
  dist_jax_1, plan_jax_1 = jax_attempt_1(x, p, q)
  dist_jax_2, plan_jax_2 = jax_attempt_2(x, p, q)

  print(f'numpy: dist {dist_numpy}, min {plan_numpy.min()}, max {plan_numpy.max()}')
  print(f'jax_1: dist {dist_jax_1}, min {plan_jax_1.min()}, max {plan_jax_1.max()}')
  print(f'jax_2: dist {dist_jax_2}, min {plan_jax_2.min()}, max {plan_jax_2.max()}')

  # numpy: dist 0.18283759367232585, min 0.0, max 0.06250000000000001
  # jax_1: dist nan, min -395690848.0, max 453536128.0
  # jax_2: dist nan, min -461479360.0, max 528943168.0

There are 1 best solutions below


ott-jax is just what I needed. While it uses the Sinkhorn algorithm as default, and is therefore an approximation, it is more than adequate for my needs. I'm sure with config changes I can improve on the performance as well.

from ott.geometry import pointcloud
from ott.solvers import linear

def jax_attempt_3(x, p, q):
  geom = pointcloud.PointCloud(x, cost_fn=None)
  solve_fn = jax.jit(linear.solve)
  ot = solve_fn(geom, p, q)
  T = ot.matrix.T  # transposed to numpy for some reason
  return np.sqrt(np.sum(T * geom.cost_matrix)), np.array(T)

