I was going through the Omniglot MAML example for `higher` and saw that it calls `net.train()` at the top of its testing code. This seems like a mistake, since it means the batch-norm statistics from each task at meta-test time are shared:
```python
import time

import higher
import torch
import torch.nn.functional as F


def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })
```
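To make "the stats are shared" concrete, here is a minimal sketch on a single layer. It assumes a standalone `nn.BatchNorm2d` stands in for the BN layers inside `net`, and the two random tensors are hypothetical tasks; `momentum=1.0` is chosen only to make the overwrite obvious. In `train()` mode every forward pass writes into the same running buffers, while in `eval()` mode they are only read.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for one BN layer inside `net` (assumption: momentum=1.0 so the
# running stats are fully replaced by each batch; the default 0.1 drifts slower).
bn = nn.BatchNorm2d(num_features=3, momentum=1.0)

# Two hypothetical "tasks" whose inputs have very different statistics.
task_a = torch.randn(8, 3, 4, 4) + 5.0
task_b = torch.randn(8, 3, 4, 4) - 5.0

bn.train()
bn(task_a)
mean_after_a = bn.running_mean.clone()
bn(task_b)
mean_after_b = bn.running_mean.clone()

# In train() mode both tasks write into the same running buffers:
# task B has overwritten what task A left behind.
assert torch.all(mean_after_a > 4.0)
assert torch.all(mean_after_b < -4.0)

bn.eval()
frozen = bn.running_mean.clone()
bn(task_a)
# In eval() mode the forward pass only reads the buffers.
assert torch.equal(bn.running_mean, frozen)
```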
However, whenever I call `eval()` instead, my MAML model diverges (though my test is on mini-ImageNet):

```
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5942, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
eval_loss=0.9859228551387786, eval_acc=0.5907692521810531
args.meta_learner.lr_inner=0.01
==== in forward2
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(171440.6875, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(208426.0156, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(17067344., grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(40371.8125, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.0911e+11, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(21.3515, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(5.4257e+13, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(128.9109, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(3994.7734, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1682896., grad_fn=<NormBackward1>)
eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224
```
So what is one supposed to do to avoid this divergence?

Notes:
- Retraining is really expensive: training a 5-layer CNN with MAML takes 18 days for me. A distributed solution would really help here: https://github.com/learnables/learn2learn/issues/170
- Perhaps just use `train()` mode everywhere, even when evaluating during training (that might be a good idea so that the batch stats saved in the checkpoint are the ones evaluation actually uses).
- Or, next time, train with batch statistics from the beginning.
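For the last option, one way to "train with batch statistics from the beginning" is PyTorch's `track_running_stats=False` flag on `nn.BatchNorm2d`: the layer then keeps no running buffers and normalizes with the current batch's statistics in both `train()` and `eval()` mode. A minimal sketch, where `conv_block` is a hypothetical helper in the style of a 5-layer CNN and the 84x84 input size is an assumption for mini-ImageNet; a model built this way would need to be trained from scratch.

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Hypothetical conv block; only the BatchNorm2d flag matters here.
    # With track_running_stats=False there are no running_mean/running_var
    # buffers, so BN always uses the statistics of the current batch.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch, track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


block = conv_block(3, 8)
x = torch.randn(10, 3, 84, 84)  # 84x84 input size is an assumption

block.train()
out_train = block(x)
block.eval()
out_eval = block(x)

# Same batch in, same output out: eval() no longer switches BN to stored stats.
assert torch.allclose(out_train, out_eval)
# No running buffers exist, so nothing from meta-testing can leak into a checkpoint.
assert block[1].running_mean is None
```

The trade-off is the same as calling `train()` at test time: each prediction depends on the rest of the batch it is evaluated with.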
Related:
- https://github.com/facebookresearch/higher/issues/107
- https://discuss.pytorch.org/t/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-higher-library/136022
- How to use have batch norm not forget batch statistics it just used in Pytorch?
- https://discuss.pytorch.org/t/how-does-pytorch-s-batch-norm-know-if-the-forward-pass-its-doing-is-for-inference-or-training/16857/10
- https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du/551153#551153
- https://github.com/tristandeleu/pytorch-maml/issues/19
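TLDR: Use `mdl.train()`, since that uses batch statistics (but inference is then no longer deterministic per example: a prediction depends on the rest of the batch). You probably don't want to use `mdl.eval()` in meta-learning.

BN intended behaviour: in `train()` mode, batch norm normalizes with the statistics of the current batch and updates its running estimates; in `eval()` mode, it normalizes with the stored running estimates instead.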
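A minimal sketch of the train/eval difference on a single `nn.BatchNorm1d` layer (the layer size and input distribution are arbitrary choices for illustration, and `affine=False` just removes the learnable scale and shift). In `train()` mode the output is normalized with the batch's own statistics; in `eval()` mode it is normalized with the running estimates, which can be far from the current task's statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=2, affine=False)  # no learnable scale/shift

x = torch.randn(16, 2) * 3.0 + 10.0  # a batch far from the initial stats (0, 1)

bn.train()
y_train = bn(x)
# train(): normalized with this batch's own mean/variance, so each
# feature's output mean is ~0.
assert torch.allclose(y_train.mean(dim=0), torch.zeros(2), atol=1e-5)

bn.eval()
y_eval = bn(x)
# eval(): normalized with running_mean/running_var. With the default
# momentum=0.1 they moved only 10% of the way toward the batch statistics
# after that single train() pass, so the output is far from zero-mean.
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
assert torch.allclose(y_eval, expected, atol=1e-5)
```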
This is likely why I don't see divergence in my testing with `mdl.train()`. So just make sure you use `mdl.train()` (since that uses batch statistics: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d), but also make sure that the new running stats it accumulates, which cheat, are neither saved nor used later.
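One way to keep those test-time running stats out of checkpoints is to snapshot every BN buffer before meta-testing and restore it afterwards. A minimal sketch, assuming the meta-test forward passes update the BN buffers of the model you pass in (worth checking for your setup, e.g. when forwarding through `higher`'s functional copies); `preserve_bn_stats` is a hypothetical helper, not part of PyTorch or `higher`.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@contextmanager
def preserve_bn_stats(model: nn.Module):
    # Hypothetical helper: snapshot running_mean / running_var /
    # num_batches_tracked of every BN layer, restore them on exit.
    saved = {}
    for name, module in model.named_modules():
        if isinstance(module, BN_TYPES):
            saved[name] = {k: b.detach().clone()
                           for k, b in module.named_buffers(recurse=False)}
    was_training = model.training
    try:
        yield model
    finally:
        modules = dict(model.named_modules())
        with torch.no_grad():
            for name, buffers in saved.items():
                for k, b in buffers.items():
                    getattr(modules[name], k).copy_(b)
        model.train(was_training)  # restore the previous train/eval mode


torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
before = net[1].running_mean.clone()
with preserve_bn_stats(net):
    net.train()  # meta-test with batch statistics
    net(torch.randn(5, 3, 8, 8) + 3.0)
    assert not torch.equal(net[1].running_mean, before)  # stats were touched...
assert torch.equal(net[1].running_mean, before)  # ...and restored on exit
```

Applied to the question's code, this would hypothetically look like `with preserve_bn_stats(net): test(db, net, device, epoch, log)`, so the test-time stats are discarded before the next checkpoint is saved.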