I'm training a seq2seq model with TensorFlow. Correct me if I'm wrong: I understand that tf.train.Checkpoint saves only checkpoint files, which are useful only when the source code that will use the saved parameter values is available. I would like to know how I can instantiate my model later on and load the trained weights from a checkpoint in order to test it.
checkpoint_dir = 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
Here is the training code:
EPOCHS = 20

for epoch in range(EPOCHS):
    start = time.time()

    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {}'.format(epoch + 1, batch, batch_loss.numpy()))

    # save a checkpoint every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)
Regards
Here is a proposed answer, which suggests using tf.train.CheckpointManager.
Ref - https://www.tensorflow.org/guide/checkpoint#train_and_checkpoint_the_model
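Following the linked guide, the key idea is: in the later process you rebuild the same object graph (same attribute names passed to tf.train.Checkpoint), then call restore; variables are matched by those names. Below is a minimal, self-contained sketch of that pattern. The small Dense model here is only a stand-in for your Encoder/Decoder classes (which are not shown in the question), and the temporary directory is just for the demo -- with your code you would construct Encoder/Decoder with the same hyperparameters used during training and point at your own checkpoint_dir.

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

# Tiny stand-in for the Encoder/Decoder in the question -- any object
# holding tf.Variables (a Keras model, a tf.Module, ...) checkpoints the same way.
def make_model():
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 3))  # create the variables now
    return model

checkpoint_dir = tempfile.mkdtemp()

# --- training run: build the objects, save via CheckpointManager ---
encoder, decoder = make_model(), make_model()
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
manager.save()

# --- later run: rebuild the SAME object graph with fresh (random) weights,
# then restore the latest checkpoint; variables are matched by the
# attribute names used above (optimizer / encoder / decoder) ---
encoder2, decoder2 = make_model(), make_model()
checkpoint2 = tf.train.Checkpoint(optimizer=tf.keras.optimizers.Adam(),
                                  encoder=encoder2,
                                  decoder=decoder2)
manager2 = tf.train.CheckpointManager(checkpoint2, checkpoint_dir, max_to_keep=3)
checkpoint2.restore(manager2.latest_checkpoint)

# the rebuilt encoder now carries the saved values instead of its random init
restored_ok = all(np.allclose(a.numpy(), b.numpy())
                  for a, b in zip(encoder.weights, encoder2.weights))
```

CheckpointManager also keeps only the most recent `max_to_keep` checkpoints and exposes `manager.latest_checkpoint`, so you can use the same restore call to resume an interrupted training run (it is simply None when no checkpoint exists yet).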