FastGAN - RuntimeError: Error(s) in loading state_dict for Generator

1.7k Views Asked by At

I'm running FastGAN (https://github.com/odegeasslbc/FastGAN-pytorch) on Google Colab and am now trying to resume training from a .pth checkpoint that the training script saved earlier. However, it keeps throwing this error:

Traceback (most recent call last):
  File "train.py", line 202, in <module>
    train(args)
  File "train.py", line 117, in train
    netG.load_state_dict(ckpt['g'])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "init.init.0.weight_orig", "init.init.0.weight", "init.init.0.weight_u", "init.init.0.weight_orig", "init.init.0.weight_u", "init.init.0.weight_v", "init.init.1.weight", "init.init.1.bias", "init.init.1.running_mean", "init.init.1.running_var", "feat_8.1.weight_orig", "feat_8.1.weight", "feat_8.1.weight_u", "feat_8.1.weight_orig", "feat_8.1.weight_u", "feat_8.1.weight_v", "feat_8.2.weight", "feat_8.3.weight", "feat_8.3.bias", "feat_8.3.running_mean", "feat_8.3.running_var", "feat_8.5.weight_orig", "feat_8.5.weight", "feat_8.5.weight_u", "feat_8.5.weight_orig", "feat_8.5.weight_u", "feat_8.5.weight_v", "feat_8.6.weight", "feat_8.7.weight", "feat_8.7.bias", "feat_8.7.running_mean", "feat_8.7.running_var", "feat_16.1.weight_orig", "feat_16.1.weight", "feat_16.1.weight_u", "feat_16.1.weight_orig", "feat_16.1.weight_u", "feat_16.1.weight_v", "feat_16.2.weight", "feat_16.2.bias", "feat_16.2.running_mean", "feat_16.2.running_var", "feat_32.1.weight_orig", "feat_32.1.weight", "feat_32.1.weight_u", "feat_32.1.weight_orig", "feat_32.1.weight_u", "feat_32.1.weight_v", "feat_32.2.weight", "feat_32.3.weight", "feat_32.3.bias", "feat_32.3.running_mean", "feat_32.3.running_var", "feat_32.5.weight_orig", "feat_32.5.weight", "feat_32.5.weight_u", "feat_32.5.weight_orig", "feat_32.5.weight_u", "feat_32.5.weight_v", "feat_32.6.weight", "feat_32.7.weight", "feat_32.7.bias", "feat_32.7.running_mean", "feat_32.7.running_var", "feat_64.1.weight_orig", "feat_64.1.weight", "feat_64.1.weight_u", "feat_64.1.weight_orig", "feat_64.1.weight_u", "feat_64.1.weight_v", "feat_64.2.weight", "feat_64.2.bias", "feat_64.2.running_mean", "feat_64.2.running_var", "feat_128.1.weight_orig", "feat_128.1.weight", "feat_128.1.weight_u", "feat_128.1.weight_orig", "feat_128.1.weight_u", "feat_128.1.weight_v", "feat_128.2.weight", "feat_128.3.weight", "feat_128.3.bias", "feat_128.3.running_mean", "feat_128.3.running_var", "feat_128.5.weight_orig", "feat_128.5.weight", 
"feat_128.5.weight_u", "feat_128.5.weight_orig", "feat_128.5.weight_u", "feat_128.5.weight_v", "feat_128.6.weight", "feat_128.7.weight", "feat_128.7.bias", "feat_128.7.running_mean", "feat_128.7.running_var", "feat_256.1.weight_orig", "feat_256.1.weight", "feat_256.1.weight_u", "feat_256.1.weight_orig", "feat_256.1.weight_u", "feat_256.1.weight_v", "feat_256.2.weight", "feat_256.2.bias", "feat_256.2.running_mean", "feat_256.2.running_var", "se_64.main.1.weight_orig", "se_64.main.1.weight", "se_64.main.1.weight_u", "se_64.main.1.weight_orig", "se_64.main.1.weight_u", "se_64.main.1.weight_v", "se_64.main.3.weight_orig", "se_64.main.3.weight", "se_64.main.3.weight_u", "se_64.main.3.weight_orig", "se_64.main.3.weight_u", "se_64.main.3.weight_v", "se_128.main.1.weight_orig", "se_128.main.1.weight", "se_128.main.1.weight_u", "se_128.main.1.weight_orig", "se_128.main.1.weight_u", "se_128.main.1.weight_v", "se_128.main.3.weight_orig", "se_128.main.3.weight", "se_128.main.3.weight_u", "se_128.main.3.weight_orig", "se_128.main.3.weight_u", "se_128.main.3.weight_v", "se_256.main.1.weight_orig", "se_256.main.1.weight", "se_256.main.1.weight_u", "se_256.main.1.weight_orig", "se_256.main.1.weight_u", "se_256.main.1.weight_v", "se_256.main.3.weight_orig", "se_256.main.3.weight", "se_256.main.3.weight_u", "se_256.main.3.weight_orig", "se_256.main.3.weight_u", "se_256.main.3.weight_v", "to_128.weight_orig", "to_128.weight", "to_128.weight_u", "to_128.weight_orig", "to_128.weight_u", "to_128.weight_v", "to_big.weight_orig", "to_big.weight", "to_big.weight_u", "to_big.weight_orig", "to_big.weight_u", "to_big.weight_v", "feat_512.1.weight_orig", "feat_512.1.weight", "feat_512.1.weight_u", "feat_512.1.weight_orig", "feat_512.1.weight_u", "feat_512.1.weight_v", "feat_512.2.weight", "feat_512.3.weight", "feat_512.3.bias", "feat_512.3.running_mean", "feat_512.3.running_var", "feat_512.5.weight_orig", "feat_512.5.weight", "feat_512.5.weight_u", "feat_512.5.weight_orig", 
"feat_512.5.weight_u", "feat_512.5.weight_v", "feat_512.6.weight", "feat_512.7.weight", "feat_512.7.bias", "feat_512.7.running_mean", "feat_512.7.running_var", "se_512.main.1.weight_orig", "se_512.main.1.weight", "se_512.main.1.weight_u", "se_512.main.1.weight_orig", "se_512.main.1.weight_u", "se_512.main.1.weight_v", "se_512.main.3.weight_orig", "se_512.main.3.weight", "se_512.main.3.weight_u", "se_512.main.3.weight_orig", "se_512.main.3.weight_u", "se_512.main.3.weight_v", "feat_1024.1.weight_orig", "feat_1024.1.weight", "feat_1024.1.weight_u", "feat_1024.1.weight_orig", "feat_1024.1.weight_u", "feat_1024.1.weight_v", "feat_1024.2.weight", "feat_1024.2.bias", "feat_1024.2.running_mean", "feat_1024.2.running_var". 
    Unexpected key(s) in state_dict: "module.init.init.0.weight_orig", "module.init.init.0.weight_u", "module.init.init.0.weight_v", "module.init.init.1.weight", "module.init.init.1.bias", "module.init.init.1.running_mean", "module.init.init.1.running_var", "module.init.init.1.num_batches_tracked", "module.feat_8.1.weight_orig", "module.feat_8.1.weight_u", "module.feat_8.1.weight_v", "module.feat_8.2.weight", "module.feat_8.3.weight", "module.feat_8.3.bias", "module.feat_8.3.running_mean", "module.feat_8.3.running_var", "module.feat_8.3.num_batches_tracked", "module.feat_8.5.weight_orig", "module.feat_8.5.weight_u", "module.feat_8.5.weight_v", "module.feat_8.6.weight", "module.feat_8.7.weight", "module.feat_8.7.bias", "module.feat_8.7.running_mean", "module.feat_8.7.running_var", "module.feat_8.7.num_batches_tracked", "module.feat_16.1.weight_orig", "module.feat_16.1.weight_u", "module.feat_16.1.weight_v", "module.feat_16.2.weight", "module.feat_16.2.bias", "module.feat_16.2.running_mean", "module.feat_16.2.running_var", "module.feat_16.2.num_batches_tracked", "module.feat_32.1.weight_orig", "module.feat_32.1.weight_u", "module.feat_32.1.weight_v", "module.feat_32.2.weight", "module.feat_32.3.weight", "module.feat_32.3.bias", "module.feat_32.3.running_mean", "module.feat_32.3.running_var", "module.feat_32.3.num_batches_tracked", "module.feat_32.5.weight_orig", "module.feat_32.5.weight_u", "module.feat_32.5.weight_v", "module.feat_32.6.weight", "module.feat_32.7.weight", "module.feat_32.7.bias", "module.feat_32.7.running_mean", "module.feat_32.7.running_var", "module.feat_32.7.num_batches_tracked", "module.feat_64.1.weight_orig", "module.feat_64.1.weight_u", "module.feat_64.1.weight_v", "module.feat_64.2.weight", "module.feat_64.2.bias", "module.feat_64.2.running_mean", "module.feat_64.2.running_var", "module.feat_64.2.num_batches_tracked", "module.feat_128.1.weight_orig", "module.feat_128.1.weight_u", "module.feat_128.1.weight_v", "module.feat_128.2.weight", 
"module.feat_128.3.weight", "module.feat_128.3.bias", "module.feat_128.3.running_mean", "module.feat_128.3.running_var", "module.feat_128.3.num_batches_tracked", "module.feat_128.5.weight_orig", "module.feat_128.5.weight_u", "module.feat_128.5.weight_v", "module.feat_128.6.weight", "module.feat_128.7.weight", "module.feat_128.7.bias", "module.feat_128.7.running_mean", "module.feat_128.7.running_var", "module.feat_128.7.num_batches_tracked", "module.feat_256.1.weight_orig", "module.feat_256.1.weight_u", "module.feat_256.1.weight_v", "module.feat_256.2.weight", "module.feat_256.2.bias", "module.feat_256.2.running_mean", "module.feat_256.2.running_var", "module.feat_256.2.num_batches_tracked", "module.se_64.main.1.weight_orig", "module.se_64.main.1.weight_u", "module.se_64.main.1.weight_v", "module.se_64.main.3.weight_orig", "module.se_64.main.3.weight_u", "module.se_64.main.3.weight_v", "module.se_128.main.1.weight_orig", "module.se_128.main.1.weight_u", "module.se_128.main.1.weight_v", "module.se_128.main.3.weight_orig", "module.se_128.main.3.weight_u", "module.se_128.main.3.weight_v", "module.se_256.main.1.weight_orig", "module.se_256.main.1.weight_u", "module.se_256.main.1.weight_v", "module.se_256.main.3.weight_orig", "module.se_256.main.3.weight_u", "module.se_256.main.3.weight_v", "module.to_128.weight_orig", "module.to_128.weight_u", "module.to_128.weight_v", "module.to_big.weight_orig", "module.to_big.weight_u", "module.to_big.weight_v", "module.feat_512.1.weight_orig", "module.feat_512.1.weight_u", "module.feat_512.1.weight_v", "module.feat_512.2.weight", "module.feat_512.3.weight", "module.feat_512.3.bias", "module.feat_512.3.running_mean", "module.feat_512.3.running_var", "module.feat_512.3.num_batches_tracked", "module.feat_512.5.weight_orig", "module.feat_512.5.weight_u", "module.feat_512.5.weight_v", "module.feat_512.6.weight", "module.feat_512.7.weight", "module.feat_512.7.bias", "module.feat_512.7.running_mean", "module.feat_512.7.running_var", 
"module.feat_512.7.num_batches_tracked", "module.se_512.main.1.weight_orig", "module.se_512.main.1.weight_u", "module.se_512.main.1.weight_v", "module.se_512.main.3.weight_orig", "module.se_512.main.3.weight_u", "module.se_512.main.3.weight_v", "module.feat_1024.1.weight_orig", "module.feat_1024.1.weight_u", "module.feat_1024.1.weight_v", "module.feat_1024.2.weight", "module.feat_1024.2.bias", "module.feat_1024.2.running_mean", "module.feat_1024.2.running_var", "module.feat_1024.2.num_batches_tracked". 

Any idea what might be happening here?

Thanks so much for any help!

1

There is 1 best solution below

3
On BEST ANSWER

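This typically happens when the checkpoint was saved from a model wrapped in `nn.DataParallel` (or `DistributedDataParallel`). The wrapper stores the original network as a submodule named `module`, so every key in its `state_dict()` gets a `module.` prefix. Loading that checkpoint into the bare, unwrapped `Generator` then fails. The same kind of key mismatch is also common when you rename submodule attributes in your `nn.Module`.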

Notice that your model's keys differ from the ones in the loaded state dict mainly by a prefix: every key in the checkpoint starts with 'module.'.

A quick fix would be to strip this leading prefix from the keys before loading, for instance with a dict comprehension:

# Checkpoints saved from a model wrapped in nn.DataParallel /
# DistributedDataParallel prefix every key with "module.". Strip only that
# *leading* prefix: str.replace would also rewrite any "module." occurring
# later inside a key. (str.removeprefix would do this, but needs Python 3.9+;
# the Colab runtime in the traceback is Python 3.7.)
# NOTE(review): ckpt['d'] was presumably saved the same way and may need the
# same treatment before netD.load_state_dict -- confirm.
prefix = 'module.'
loaded_state = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in ckpt['g'].items()
}
netG.load_state_dict(loaded_state)