I am working on a problem that I can solve with small-scale neural networks, but I need to use a Newton-like update. I have provided a working example below and would like some advice on speeding up the Newton update.
import jax
import equinox as eqx
from jax import numpy as jnp
import matplotlib.pyplot as plt
from jax import flatten_util
key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)
data = jnp.concatenate((jax.random.normal(subkey1, shape=(100, 2)) * 0.1 - 1, jax.random.normal(subkey2, shape=(100, 2)) * 0.1 + 1), axis=0)
labels = jnp.array([0] * 100 + [1] * 100)
plt.scatter(data[:,0], data[:,1])
plt.show()
class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(2, 10, key=key1),
            jax.nn.relu,
            eqx.nn.Linear(10, 12, key=key2),
            jax.nn.relu,
            eqx.nn.Linear(12, 2, key=key3),
            jax.nn.log_softmax,
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
def loss_fn(model, ins, ytrue):
    pred_y = eqx.filter_vmap(model)(ins)
    return cross_entropy(ytrue, pred_y)

def cross_entropy(y, pred_y):
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)
@eqx.filter_jit
def compute_accuracy(m, x, y):
    pred_y = eqx.filter_vmap(m)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)
key = jax.random.PRNGKey(42)
@eqx.filter_jit
def step(mlp, xs, ys):
    # Plain gradient descent with a fixed 0.1 step size.
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    updates = jax.tree_util.tree_map(lambda g: -0.1 * g, grads)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals
epochs = 50
batch_size = 100
grad_loss = []
grad_acc = []
key, subkey = jax.random.split(key, 2)
model = MLP(subkey)
for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)
    key, subkey = jax.random.split(key, 2)
    inds = jax.random.randint(subkey, minval=0, maxval=len(data), shape=(batch_size,))
    inputs = data[inds]
    ls = labels[inds]
    model, loss = step(model, inputs, ls)
    grad_loss.append(loss)
    grad_acc.append(compute_accuracy(model, data, labels))
@eqx.filter_jit
def loss_h(arrs, static, ins, ytrue, uf):
    # Loss as a function of the flattened parameter vector, so jax.hessian can be used.
    arrs = uf(arrs)
    model = eqx.combine(arrs, static)
    pred_y = eqx.filter_vmap(model)(ins)
    return cross_entropy(ytrue, pred_y)
@eqx.filter_jit
def step_h(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
    # Dense Hessian with respect to the flattened parameters.
    h = jax.hessian(loss_h)(flat_a, s, xs, ys, unflat_a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)
    # Newton step: -pinv(H) @ g.
    updates = unflat(-1 * jnp.linalg.pinv(h) @ g_flat)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key, 2)
model_h = MLP(subkey)
epochs = 50
batch_size = 100
h_loss = []
h_acc = []
for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)
    key, subkey = jax.random.split(key, 2)
    inds = jax.random.randint(subkey, minval=0, maxval=len(data), shape=(batch_size,))
    inputs = data[inds]
    ls = labels[inds]
    model_h, loss = step_h(model_h, inputs, ls)
    h_loss.append(loss_fn(model_h, data, labels))
    h_acc.append(compute_accuracy(model_h, data, labels))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(h_loss, label="Newton")
ax2.plot(h_acc, label="Newton")
ax1.plot(grad_loss, label="Grad")
ax2.plot(grad_acc, label="Grad")
ax1.legend()
ax2.legend()
plt.show()
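For reference, one variant of step_h I could imagine is solving a damped linear system instead of computing the pseudoinverse (sketch only; step_h_solve and the damping constant are names/values I made up, not benchmarked):

@eqx.filter_jit
def step_h_solve(mlp, xs, ys, damping=1e-4):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
    h = jax.hessian(loss_h)(flat_a, s, xs, ys, unflat_a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)
    # Solve (H + damping * I) d = g directly instead of forming pinv(H).
    d = jnp.linalg.solve(h + damping * jnp.eye(h.shape[0]), g_flat)
    mlp = eqx.apply_updates(mlp, unflat(-d))
    return mlp, vals

My understanding is that jnp.linalg.solve avoids the SVD inside pinv, but I do not know how much that matters at this scale.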
My interest is really in getting as much speed as I can out of the update itself, since that is where the script spends most of its time. Hoping a few experts can help me speed this up a lot.
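Would something matrix-free be the right direction, e.g. conjugate gradient on Hessian-vector products so the dense Hessian is never built? A rough sketch of what I mean (step_h_cg and the maxiter value are placeholders, and I am assuming CG copes with a possibly indefinite Hessian):

from jax.scipy.sparse.linalg import cg

@eqx.filter_jit
def step_h_cg(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)

    def hvp(v):
        # Hessian-vector product via forward-over-reverse, no dense Hessian.
        grad_fn = lambda p: jax.grad(loss_h)(p, s, xs, ys, unflat_a)
        return jax.jvp(grad_fn, (flat_a,), (v,))[1]

    # Approximately solve H d = g with a few CG iterations.
    d, _ = cg(hvp, g_flat, maxiter=20)
    mlp = eqx.apply_updates(mlp, unflat(-d))
    return mlp, vals

I have not verified whether this is actually faster here, or whether the CG tolerance/iteration count would need tuning.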