I have seen a few posts on restoring TF models and the Google doc page on exporting graphs, but I think I am missing something.
I use the code in this Gist to save the model, along with this utils file, which defines the model.
Now I would like to restore the model and run it on previously unseen test data as follows:
def evaluate(X_data, y_data):
    num_examples = len(X_data)
    total_accuracy = 0
    total_loss = 0
    sess = tf.get_default_session()
    acc_steps = len(X_data) // BATCH_SIZE
    for i in range(acc_steps):
        # draw batches from the data passed in (not from the validation set)
        batch_x, batch_y = next_batch(X_data, y_data, BATCH_SIZE)
        loss, accuracy = sess.run([loss_value, acc], feed_dict={
            images_placeholder: batch_x,
            labels_placeholder: batch_y,
            keep_prob: 1.0  # no dropout at evaluation time
        })
        # weight each batch's metrics by its size
        total_accuracy += (accuracy * len(batch_x))
        total_loss += (loss * len(batch_x))
    return (total_accuracy / num_examples, total_loss / num_examples)
## re-execute the code that defines the model
# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')
gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')
gray /= 255.
# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')
# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')
# construct model
logits = inference(gray, keep_prob)
# calculate loss
loss_value = loss(logits, labels_placeholder)
# training
train_op = training(loss_value, 0.001)
# accuracy
acc = accuracy(logits, labels_placeholder)
with tf.Session() as sess:
    loader = tf.train.import_meta_graph('gtsd.meta')
    loader.restore(sess, tf.train.latest_checkpoint('./'))
    sess.run(tf.initialize_all_variables())
    test_accuracy = evaluate(X_test, y_test)
    print("Test Accuracy = {:.3f}".format(test_accuracy[0]))
I'm getting a test accuracy of only 3%. However, if I don't close the notebook and run the test code immediately after training the model, I get 95% accuracy.
Could it be that I'm not loading the model correctly?
The problem arises from these two lines:
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())
The first line loads the saved model from a checkpoint. The second line re-initializes all of the variables in the model (such as the weight matrices, convolutional filters, and bias vectors), usually to random numbers, and overwrites the loaded values.
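To see why the order matters, here is a toy model in plain Python. It is not TensorFlow's API: the ToySession class and the checkpoint values are hypothetical, and a "session" simply holds variable values by name. Restoring and then initializing leaves random values behind, while initializing and then restoring leaves the checkpoint values in place.

```python
import random


class ToySession:
    """Hypothetical stand-in for a TF session: maps variable names to values.

    This is NOT TensorFlow's API; it only mimics the order-of-operations issue.
    """

    def __init__(self, var_names):
        self.var_names = list(var_names)
        self.values = {}

    def initialize_all(self, seed=0):
        # Like running an initializer op: assign fresh random values to every variable.
        rng = random.Random(seed)
        for name in self.var_names:
            self.values[name] = rng.uniform(-1.0, 1.0)

    def restore(self, checkpoint):
        # Like Saver.restore: overwrite every variable stored in the checkpoint.
        for name, value in checkpoint.items():
            self.values[name] = value


# Hypothetical "trained" values, as if read from a checkpoint file.
checkpoint = {"weights": 0.75, "bias": -0.25}

# Wrong order: restore, then re-initialize (the trained values are overwritten).
sess = ToySession(["weights", "bias"])
sess.restore(checkpoint)
sess.initialize_all()
print(sess.values == checkpoint)  # False

# Safe order: initialize, then restore (the checkpoint values win).
sess2 = ToySession(["weights", "bias"])
sess2.initialize_all()
sess2.restore(checkpoint)
print(sess2.values == checkpoint)  # True
```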
The solution is simple: delete the second line,
sess.run(tf.initialize_all_variables())
and evaluation will proceed with the trained values loaded from the checkpoint.

PS. There is a small chance that this change will give you an error about "uninitialized variables". In that case, you should execute
sess.run(tf.initialize_all_variables())
to initialize any variables not saved in the checkpoint before executing
loader.restore(sess, tf.train.latest_checkpoint('./'))
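A plain-Python sketch of the order suggested in the PS, assuming variables are just names mapped to numbers. The helper restore_with_fallback and all variable names are hypothetical (not TensorFlow functions). Every variable is initialized first; the restore step then overwrites each variable the checkpoint contains, so only variables missing from the checkpoint keep their initial values.

```python
def restore_with_fallback(graph_vars, checkpoint, init_fn):
    """Hypothetical helper: initialize everything, then restore from a checkpoint.

    Returns the final values and the sorted names that were NOT in the checkpoint.
    """
    # Step 1: initialize every variable in the graph (like the initializer op).
    values = {name: init_fn(name) for name in graph_vars}
    # Step 2: restore overwrites each variable the checkpoint actually holds.
    for name in graph_vars:
        if name in checkpoint:
            values[name] = checkpoint[name]
    # Variables absent from the checkpoint keep their initialized values.
    missing = sorted(set(graph_vars) - set(checkpoint))
    return values, missing


# Hypothetical graph variables; "extra_counter" was never saved.
graph_vars = ["conv1/weights", "conv1/bias", "fc/weights", "extra_counter"]
checkpoint = {"conv1/weights": 0.5, "conv1/bias": 0.1, "fc/weights": -0.3}

values, missing = restore_with_fallback(graph_vars, checkpoint, lambda name: 0.0)
print(values)
print(missing)  # ['extra_counter']
```

Initializing first is harmless here because restore replaces every value it knows about; the initializer only matters for variables the checkpoint lacks.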