Load PyTorch model with correct args from files


I followed Chris McCormick's tutorial for creating a BERT Fake News Detector (link here). At the end, he saves the PyTorch model using the following code:

import os

output_dir = './model_save/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

As the comment in the code says, the model can be reloaded using from_pretrained(). The code creates an output directory containing six files:

config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.json

How can I use the from_pretrained() method to load the model with all of its arguments and weights, and which of the six files do I need?

I understand that a model can be loaded like this (from the PyTorch documentation):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

but how can I make use of the files in the output directory to do this?

Any help is appreciated!

1 Answer

I was able to do this by simply passing the output directory to from_pretrained(). It identified the relevant JSON config files in that directory and loaded the model, like this:

model = TheModelClass.from_pretrained(output_dir)
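
For a fuller picture, here is a minimal sketch of reloading both the model and the tokenizer. It assumes the Hugging Face transformers library and that the saved model is a sequence classifier; the Auto classes are my choice here, not something the tutorial prescribes, and the function name load_model_and_tokenizer is just an illustrative helper.

```python
def load_model_and_tokenizer(output_dir="./model_save/"):
    """Reload what save_pretrained() wrote into output_dir."""
    # Imported inside the function so transformers is only needed when called.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Assumption: the saved model has a sequence classification head.
    # The Auto class reads config.json to pick the concrete architecture.
    model = AutoModelForSequenceClassification.from_pretrained(output_dir)

    # The tokenizer is rebuilt from the tokenizer files in the same directory.
    tokenizer = AutoTokenizer.from_pretrained(output_dir)

    # from_pretrained() already returns the model in evaluation mode;
    # calling eval() again is harmless and makes the intent explicit.
    model.eval()
    return model, tokenizer
```

Calling load_model_and_tokenizer('./model_save/') should then return a model and tokenizer ready for inference. If you know the concrete class you trained, you can use it in place of the Auto class.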

Sometimes it helps to just try some code and see if it works.
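
As for which of the six files are used: you never pick them by hand. As far as I can tell, the model's from_pretrained() reads config.json and pytorch_model.bin, while the tokenizer's from_pretrained() reads the other four. The sketch below is only a sanity check on the directory, based on the six file names listed in the question; describe_saved_dir is a hypothetical helper, not part of any library.

```python
import os

# Assumed layout: the six files listed in the question.
MODEL_FILES = {"config.json", "pytorch_model.bin"}
TOKENIZER_FILES = {
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "vocab.json",
}


def describe_saved_dir(output_dir):
    """Group the files found in output_dir by the loader that reads them."""
    present = set(os.listdir(output_dir))
    return {
        "model": sorted(MODEL_FILES & present),
        "tokenizer": sorted(TOKENIZER_FILES & present),
        "missing": sorted((MODEL_FILES | TOKENIZER_FILES) - present),
    }
```

If "missing" comes back empty for your output directory, both from_pretrained() calls should have everything they expect.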