Vanishing parameters in MAML implementation in JAX (meta-learning)


I am working on an implementation of MAML (see https://arxiv.org/pdf/1703.03400.pdf) in JAX.

When training on a distribution of simple linear regression tasks, it seems to perform fine (it takes a while to converge, but it ultimately works).
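
For context, the linear tasks are produced by a sampler with the same shape as the sinusoidal one shown further down; the version below is only an illustrative sketch (the exact sampler and its ranges are in the notebook):

def sample_linear_task(samples, key):
    # y = a * x + b, with a and b drawn per task (these ranges are assumptions)
    xs_key, slope_key, intercept_key = random.split(key, num=3)
    slope = random.uniform(slope_key, (1, 1)) * 2 - 1
    intercept = random.uniform(intercept_key, (1, 1)) * 2 - 1

    xs = random.uniform(xs_key, (samples, 1)) * 4 - 2
    ys = slope * xs + intercept
    return xs, ys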

However, when training on tasks distributed like A * sin(B + X), where A and B are random variables, all the weights in the network converge to 0 (see the training results plot).

This is clearly not right. Thanks in advance for any help provided.

Full code is here: https://colab.research.google.com/drive/1YoOkwo5tI42LeIbBOxpImkN55Kg9wScl?usp=sharing, or see below for a minimal version.

Task Generation code:

from functools import partial
from typing import Sequence

import jax.numpy as jnp
import optax
from flax import linen as nn
from jax import grad, jit, lax, random, value_and_grad, vmap
from jax.tree_util import tree_map
from tqdm import tqdm


class MAMLDataLoader:
    def __init__(self, sample_task_fn, num_tasks, batch_size):
        self.sample_task_fn = sample_task_fn
        self.num_tasks = num_tasks
        self.batch_size = batch_size

    def sample_tasks(self, key):
        XS = jnp.empty((self.num_tasks, 2 * self.batch_size, 1))
        YS = jnp.empty((self.num_tasks, 2 * self.batch_size, 1))

        for i in range(self.num_tasks):
            key, subkey = random.split(key)
            xs, ys = self.sample_task_fn(self.batch_size * 2, subkey)
            XS = XS.at[i].set(xs)
            YS = YS.at[i].set(ys)

        x_train, x_test = XS[:, :self.batch_size], XS[:, self.batch_size:]
        y_train, y_test = YS[:, :self.batch_size], YS[:, self.batch_size:]

        return x_train, y_train, x_test, y_test

    def dummy_input(self):
        key = random.PRNGKey(0)
        x = self.sample_task_fn(1, key)[0][0]
        return x

def sample_sinusoidal_task(samples, key):
    # y = a * sin(b + x)
    xs_key, amplitude_key, phase_key = random.split(key, num=3)
    amplitude = random.uniform(amplitude_key, (1, 1))
    phase = random.uniform(phase_key, (1, 1)) * jnp.pi * 2

    xs = (random.uniform(xs_key, (samples, 1)) * 4 - 2) * jnp.pi
    ys = amplitude * jnp.sin(xs + phase)
    return xs, ys
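
For reference, the loader stacks num_tasks tasks and splits each one into support and query halves; a quick shape check, assuming the definitions and imports above:

loader = MAMLDataLoader(sample_sinusoidal_task, num_tasks=2, batch_size=100)
x_train, y_train, x_test, y_test = loader.sample_tasks(random.PRNGKey(1))
print(x_train.shape, y_train.shape)  # (2, 100, 1) (2, 100, 1)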

Here is the main MAML code:

class MAMLTrainer:
    def __init__(self, model, alpha, optimiser, inner_steps=1):
        self.model = model
        self.alpha = alpha
        self.optimiser = optimiser
        self.inner_steps = inner_steps

        self.jit_step = jit(self.step)

    def loss(self, params, x, y):
        preds = self.model.apply(params, x)
        return jnp.mean(jnp.inner(y - preds, y - preds) / 2.0)

    def update(self, params, x, y, inner_steps=None):
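        # Take `inner_steps` gradient steps on (x, y) with learning rate alpha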
        if inner_steps is None:
            inner_steps = self.inner_steps

        loss_grad = grad(self.loss)

        def _update(i, params):
            grads = loss_grad(params, x, y)
            new_params = tree_map(lambda p, g: p - self.alpha * g, params, grads)
            return new_params

        return lax.fori_loop(0, inner_steps, _update, params)

    def meta_loss(self, params, x1, y1, x2, y2):
        return self.loss(self.update(params, x1, x2), x2, y2) 

    def batch_meta_loss(self, params, x1, y1, x2, y2):
        return jnp.mean(vmap(partial(self.meta_loss, params))(x1, y1, x2, y2))

    def step(self, params, optimiser, x1, y1, x2, y2):
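        # One meta-update: gradient of the averaged post-adaptation loss w.r.t. the meta-parameters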
        loss, grads = value_and_grad(self.batch_meta_loss)(params, x1, y1, x2, y2)

        updates, opt_state = self.optimiser.update(grads, optimiser, params)
        params = optax.apply_updates(params, updates)

        return params, loss


    def train(self, dataloader, steps, key, params=None):
        if params is None:
            key, subkey = random.split(key)
            params = self.model.init(subkey, dataloader.dummy_input())

        optimiser = self.optimiser.init(params)

        pbar, losses = tqdm(range(steps), desc='Training'), []

        for epoch in pbar:
            key, subkey = random.split(key)
            params, loss = self.jit_step(params, optimiser, *dataloader.sample_tasks(subkey))
            losses.append(loss)

            if epoch % 100 == 0:
                avg_loss = jnp.mean(jnp.array(losses[-100:]))

            pbar.set_postfix_str(f'current_loss: {loss:.3f}, running_loss_100_epochs: {avg_loss:.3f}')

        return params, jnp.array(losses)

    def n_shot_learn(self, x_train, y_train, params, n):
        return self.update(params, x_train, y_train, n)
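
For reference, the computation I am aiming for is the one from the paper: the inner update adapts the parameters on the support set, θ'_i = θ − α ∇_θ L(θ; x1, y1), and the meta-loss is the loss of the adapted parameters on the query set, L(θ'_i; x2, y2), averaged over the task batch in batch_meta_loss.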

Training Code:

class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features[:-1]):
            x = nn.Dense(feat)(x)
            x = nn.relu(x)

        return nn.Dense(self.features[-1])(x)

model = SimpleMLP([64, 64, 1])
optimiser = optax.adam(1e-3)
trainer = MAMLTrainer(model, 0.1, optimiser, 1)
dataloader = MAMLDataLoader(sample_sinusoidal_task, 2, 100)
key = random.PRNGKey(0)
params, losses = trainer.train(dataloader, 10000, key)
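
For completeness, a minimal way to sanity-check the meta-trained parameters (a sketch using the classes above; the actual evaluation and plotting are in the notebook):

# Adapt on the support half of a fresh task, then compare query loss before/after adaptation
eval_key = random.PRNGKey(1)
x_tr, y_tr, x_te, y_te = dataloader.sample_tasks(eval_key)
adapted_params = trainer.n_shot_learn(x_tr[0], y_tr[0], params, 5)
print('query loss before adaptation:', trainer.loss(params, x_te[0], y_te[0]))
print('query loss after adaptation: ', trainer.loss(adapted_params, x_te[0], y_te[0]))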