Saving Gradients in the Backward Pass (Google JAX)


I am using JAX to implement a simple neural network (NN), and I want to access and save the gradients from the backward pass for further analysis after training has finished. I can inspect the gradients temporarily with the Python debugger (as long as I am not using jit), but I want to save all gradients over the whole training process and analyze them once training is done. I have come up with a rather hacky solution using id_tap and a global variable (see the code below), but I was wondering whether there is a better solution that does not violate the functional principles of JAX.

Many thanks!

import jax.numpy as jnp
from jax import grad, jit, vmap, random, custom_vjp
from jax.experimental.host_callback import id_tap

# experimental solution: a host-side callback that appends every tapped
# value to a global Python dict (stateful, so it violates JAX's functional style)
global_save_list = {'x': [], 'w': [], 'g': [], 'des': []}

def global_save_func(ctx, des):
    # id_tap calls this on the host with (tapped_value, transforms);
    # 'des' receives the transforms/descriptor metadata
    x, w, g = ctx
    global_save_list['x'].append(x)
    global_save_list['w'].append(w)
    global_save_list['g'].append(g)
    global_save_list['des'].append(des)


@custom_vjp
def qmvm(x, w):
    return jnp.dot(x, w)

def qmvm_fwd(x, w):
    return qmvm(x, w), (x, w)

def qmvm_bwd(ctx, g):
    x, w = ctx

    # here I would like to save gradients g - or at least running statistics of them

    # experimental solution with id_tap
    id_tap(global_save_func, (x, w, g))

    x_grad = jnp.dot(g, w.transpose())  # cotangent w.r.t. x, passed further back
    w_grad = jnp.dot(x.transpose(), g)  # cotangent w.r.t. w

    return x_grad, w_grad

qmvm.defvjp(qmvm_fwd, qmvm_bwd)

def run_nn(x, w):
    out = qmvm(x, w)   # 1st MVM
    out = qmvm(out, w) # 2nd MVM
    return out

run_nn_batched = vmap(run_nn)

@jit
def loss(x, w, target):
    out = run_nn_batched(x, w)
    return jnp.sum((out - target)**2)

key = random.PRNGKey(42)
subkey1, subkey2, subkey3 = random.split(key, 3)

A = random.uniform(subkey1, (10, 10, 10), minval = -10, maxval = 10)
B = random.uniform(subkey2, (10, 10, 10), minval = -10, maxval = 10)
C = random.uniform(subkey3, (10, 10, 10), minval = -10, maxval = 10)

for e in range(10):
    gval = grad(loss, argnums = 0)(A, B, C)
    # some type of update rule

# here I would like to access gradients, preferably knowing to which MVM (1st or 2nd) and example they belong

# experimental solution:
print(global_save_list) 
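
For completeness, here is a minimal sketch of the same tap-and-save pattern written against jax.debug.callback, which newer JAX versions provide and which also works under jit; this is only an illustration under that assumption, and the names grads_log, save_grads and qmvm_bwd_dbg are placeholders of mine, not part of the code above:

import jax
import jax.numpy as jnp

# placeholder global store, same idea as global_save_list above
grads_log = {'x': [], 'w': [], 'g': []}

def save_grads(x, w, g):
    # runs on the host; still stateful, but callable from jitted code
    grads_log['x'].append(x)
    grads_log['w'].append(w)
    grads_log['g'].append(g)

def qmvm_bwd_dbg(ctx, g):
    x, w = ctx
    # hand the tensors to the host-side callback
    jax.debug.callback(save_grads, x, w, g)
    return jnp.dot(g, w.transpose()), jnp.dot(x.transpose(), g)

# it would then be registered the same way:
# qmvm.defvjp(qmvm_fwd, qmvm_bwd_dbg)

This still relies on a global Python dict, so it has the same "non-functional" flavour as the id_tap version; it just avoids the experimental host_callback module.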