I'm using Hydra for training machine learning models. It works well for composing complex commands like python train.py data=MNIST batch_size=64 loss=l2. However, if I then want to run the trained model with the same parameters, I have to do something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml. I then use argparse to load the previous yaml and use the compose API to initialize the Hydra environment. The path to the trained model is inferred from the path to Hydra's .yaml file. If I want to modify one of the parameters, I have to add more argparse parameters and run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128. The code then manually overrides any Hydra parameters with those that were specified on the command line.
What's the right way of doing this?
My current code looks something like the following:
train.py:
import hydra

@hydra.main(config_name="config", config_path="conf")
def main(cfg):
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]
    pass

if __name__ == "__main__":
    main()
reconstruct.py:
import argparse
import os
from hydra.experimental import initialize, compose

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('hydra_config')
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()
    # Create the Hydra environment.
    initialize()
    cfg = compose(config_name=args.hydra_config)
    # Since checkpoints are stored next to the .hydra, we manually generate the path.
    checkpoint_dir = os.path.dirname(os.path.dirname(args.hydra_config))
    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size is not None else cfg.data.batch_size
    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]
This is my first time posting, so let me know if I should clarify anything.
If you want to load the previous config as is and not change it, use OmegaConf.load(file_path).
If you want to re-compose the config (and it sounds like you do, because you added that you want to override things), I recommend that you use the Compose API and pass in the parameters from the overrides file in the job output directory (next to the stored config.yaml), concatenated with the current run's parameters.
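To make the concatenation step concrete, here is a minimal sketch in plain Python. It is only an illustration: merge_overrides is a made-up helper name, not part of Hydra, and the override strings are the ones from the question. It assumes the previous run's overrides have already been read into a list of strings (for example from the overrides file in the job's .hydra directory) and that the text before the first "=" identifies the parameter; prefixes such as + or ~ get no special treatment here. The resulting list is the kind of value one would pass as the overrides argument of compose.

```python
from typing import Iterable, List


def merge_overrides(previous: Iterable[str], current: Iterable[str]) -> List[str]:
    """Concatenate two override lists; a later entry replaces an earlier one with the same key."""
    merged = {}
    for item in [*previous, *current]:
        # Simplifying assumption: everything before the first "=" is the key.
        key = item.partition("=")[0]
        # Dicts keep first-insertion order, so a replaced key keeps its position.
        merged[key] = item
    return list(merged.values())


# Hypothetical values: overrides recorded by the training run, plus one new override.
previous = ["data=MNIST", "batch_size=64", "loss=l2"]
current = ["batch_size=128"]
print(merge_overrides(previous, current))
# prints ['data=MNIST', 'batch_size=128', 'loss=l2']
```

Dropping the earlier duplicate, rather than passing both entries along, is a design choice of this sketch: it keeps the final list unambiguous regardless of how a given Hydra version treats repeated keys.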
This script seems to be doing the job: