I followed Chris McCormick's tutorial for creating a BERT Fake News Detector (link here). At the end, he saves the PyTorch model using the following code:
import os

output_dir = './model_save/'
# Create the output directory if it doesn't already exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
# Unwrap the model if it is wrapped in DataParallel/DistributedDataParallel
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
As he says himself, it can be reloaded using from_pretrained(). Currently, the code creates an output directory with six files:
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.json
So how can I use the from_pretrained() method to load the model with all of its arguments and respective weights, and which of the six files do I need?
I understand that a model can be loaded like this (from the PyTorch documentation):
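import torch

# TheModelClass, args, kwargs and PATH are placeholders from the PyTorch docs
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # switch to inference mode (disables dropout)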
but how can I make use of the files in the output directory to do this?
Any help is appreciated!
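To map the files onto that PyTorch pattern: config.json plays the role of *args/**kwargs, and the weights file holds the state_dict. The following is a minimal sketch, assuming Hugging Face transformers with its Auto classes and a sequence-classification head (as in a fake-news classifier); load_with_state_dict is a hypothetical helper name. Depending on the transformers version, save_pretrained() writes the weights as pytorch_model.bin (as in the question) or as model.safetensors, so the sketch checks for both.

```python
import os

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification


def load_with_state_dict(output_dir):
    """Rebuild the model from config.json, then load the saved weights.

    Hypothetical helper mirroring the PyTorch docs pattern:
    model = TheModelClass(*args, **kwargs); model.load_state_dict(...)
    """
    # config.json stores the architecture and hyperparameters.
    config = AutoConfig.from_pretrained(output_dir)
    # Build the model skeleton with freshly initialised weights.
    model = AutoModelForSequenceClassification.from_config(config)

    bin_path = os.path.join(output_dir, "pytorch_model.bin")
    st_path = os.path.join(output_dir, "model.safetensors")
    if os.path.exists(bin_path):
        # Older transformers versions save a pickled state_dict.
        state_dict = torch.load(bin_path, map_location="cpu")
    elif os.path.exists(st_path):
        # Newer versions default to the safetensors format.
        from safetensors.torch import load_file
        state_dict = load_file(st_path, device="cpu")
    else:
        raise FileNotFoundError(f"no weights file found in {output_dir}")

    # Copy the trained weights into the skeleton.
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout
    return model
```

In practice from_pretrained(output_dir) performs these steps in a single call; the manual route mainly shows which file does what. The remaining files (merges.txt, vocab.json, special_tokens_map.json, tokenizer_config.json) belong to the tokenizer and are not used here.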
I was able to do this by simply passing the model directory path to the from_pretrained() function. from_pretrained() identified the relevant JSON config files and loaded the model. Like this:
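A minimal sketch of that approach, assuming Hugging Face transformers and a sequence-classification head; load_from_dir is a hypothetical helper name. If you know the concrete class used during training, calling that class's from_pretrained(output_dir) works the same way as the Auto class used here.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def load_from_dir(output_dir):
    """Reload a model and tokenizer saved with save_pretrained() (hypothetical helper)."""
    # Reads config.json to choose the model class and hyperparameters,
    # then loads the weights file stored in the same directory.
    model = AutoModelForSequenceClassification.from_pretrained(output_dir)
    # Reads tokenizer_config.json, special_tokens_map.json and the
    # vocabulary files (e.g. vocab.json and merges.txt) from the same directory.
    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    # from_pretrained() already puts the model in eval mode; kept for clarity.
    model.eval()
    return model, tokenizer


# Usage with the directory from the question:
# model, tokenizer = load_from_dir("./model_save/")
```

So none of the six files needs to be opened by hand: pointing both from_pretrained() calls at the directory is enough.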
Sometimes it helps to just try some code and see if it works.