How to save the base network of a siamese model?


I am trying to build a siamese model with a rather complex base network. After building the base network, I use the following code to build my siamese network:

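```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

# create_base_model and euclidean_distance are defined elsewhere in my code
base_network = create_base_model(0.2)
img1 = Input(shape=(256, 256, 3))
img2 = Input(shape=(256, 256, 3))
text_input1 = Input(shape=(), dtype=tf.string, name='text_1')
text_input2 = Input(shape=(), dtype=tf.string, name='text_2')
# The same base_network instance is applied to both branches, so weights are shared
output1 = base_network([img1, text_input1])
output2 = base_network([img2, text_input2])
distance = Lambda(euclidean_distance)([output1, output2])
siamese_model = Model([[img1, text_input1], [img2, text_input2]], distance)
```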

create_base_model returns a two-input Keras model of the form:

```python
model = Model(inputs=[input1, input2], outputs=[z])
```
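One way to get the base network back out of the siamese model is to give it a name when it is built and then look it up with `get_layer`. The sketch below is only an illustration under assumptions: it uses a hypothetical toy base network (two small numeric inputs instead of an image and a string), the name `'base_network'` is something you would have to pass yourself inside `create_base_model`, and the `euclidean_distance` Lambda is swapped for a squared Euclidean distance built from the stock `Subtract` and `Dot` layers.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy base network with two inputs; the important part is name='base_network'
in_a, in_b = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
z = tf.keras.layers.Dense(2)(tf.keras.layers.Concatenate()([in_a, in_b]))
base_network = tf.keras.Model(inputs=[in_a, in_b], outputs=z, name='base_network')

# Siamese wiring: the same base_network object is called on both branches
a1, b1 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
a2, b2 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
diff = tf.keras.layers.Subtract()([base_network([a1, b1]), base_network([a2, b2])])
distance = tf.keras.layers.Dot(axes=1)([diff, diff])  # squared Euclidean distance
siamese_model = tf.keras.Model([a1, b1, a2, b2], distance)

# Pull the shared base network back out of the siamese model by name
embedder = siamese_model.get_layer('base_network')
rng = np.random.default_rng(0)
x_a = rng.normal(size=(5, 4)).astype('float32')
x_b = rng.normal(size=(5, 3)).astype('float32')
embeddings = embedder.predict([x_a, x_b], verbose=0)  # one 2-D embedding per sample
```

The nested model is stored as a single layer of the siamese model, so the same `get_layer('base_network')` lookup should also work on a siamese model reloaded with `tf.keras.models.load_model`, as long as the model itself loads (custom layers or functions may need to be supplied via `custom_objects`).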

The issue is that, after training the siamese network, I want to use the output of the base network as an embedding so that I can run unsupervised learning algorithms on it. However, I want to train the siamese network for 10 epochs at a time, then save it and continue training if needed. In this scenario, I am not sure how to save or access the base network when I save and reload the siamese model. For example, the plot I get for the siamese model (below) shows that it takes 2 inputs (technically 4, since my base model takes 2 inputs), whereas after training I want to use the base model, which takes only 1 input (technically 2).

Can anyone advise me on how to load the updated base model from the saved siamese model, or whether there is a better way to save it in the first place?
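An alternative that avoids serializing the whole siamese model is to save only its weights after each 10-epoch run and rebuild the architecture in code before loading them. Because the base network is shared, restoring the siamese weights also restores the base network. This is a minimal sketch under assumptions: `build_models()` is a hypothetical function standing in for `create_base_model` plus the siamese wiring, the toy layers replace the real image/text model, and training is omitted.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_models():
    """Hypothetical stand-in for create_base_model plus the siamese wiring."""
    in_a, in_b = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
    z = tf.keras.layers.Dense(2)(tf.keras.layers.Concatenate()([in_a, in_b]))
    base = tf.keras.Model([in_a, in_b], z, name='base_network')

    a1, b1 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
    a2, b2 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
    diff = tf.keras.layers.Subtract()([base([a1, b1]), base([a2, b2])])
    distance = tf.keras.layers.Dot(axes=1)([diff, diff])  # squared Euclidean distance
    return base, tf.keras.Model([a1, b1, a2, b2], distance)


# First session: build, train for 10 epochs (omitted here), then save the siamese weights
base_network, siamese_model = build_models()
path = os.path.join(tempfile.mkdtemp(), 'siamese.weights.h5')
siamese_model.save_weights(path)

# Later session: rebuild the same architecture and restore the weights
restored_base, restored_siamese = build_models()
restored_siamese.load_weights(path)  # also fills restored_base, since it is shared
```

One trade-off: weight files do not include the optimizer state, so an optimizer such as Adam restarts its moment estimates when you resume training from them.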

Thanks very much.

[Figure: plot of the siamese model, showing its four inputs]


Answer:

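Because the siamese model reuses the same base_network object, training the siamese model also updates base_network, so you can save it on its own, for example on every fifth epoch: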
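```python
import tensorflow as tf

# Inside the training loop, where `epoch` is the current epoch number
if epoch % 5 == 0:
    path = f'/tmp/model{epoch}.h5'
    base_network.save(path)  # saves only the shared base network

# Later: load the base network on its own and use it to compute embeddings
base_network = tf.keras.models.load_model(path)
```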

Wouldn't that work?
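If you train with `fit`, the same idea can be wrapped in a custom callback. `tf.keras.callbacks.ModelCheckpoint` would save the model being fitted (the whole siamese model), which is why a small custom callback is used here instead. This is a sketch under assumptions: `SaveBaseNetwork`, `every` and `path_template` are hypothetical names, not part of Keras, and the toy two-input base network and random data stand in for the real model and dataset.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class SaveBaseNetwork(tf.keras.callbacks.Callback):
    """Hypothetical callback: save the shared base network every `every` epochs."""

    def __init__(self, base_network, path_template, every=5):
        super().__init__()
        self.base_network = base_network
        self.path_template = path_template  # e.g. '/tmp/model{epoch}.h5'
        self.every = every
        self.saved_paths = []

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; use 1-based numbers in file names
        if (epoch + 1) % self.every == 0:
            path = self.path_template.format(epoch=epoch + 1)
            self.base_network.save(path)
            self.saved_paths.append(path)


# Hypothetical toy base network with two inputs, standing in for create_base_model
in_a, in_b = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
z = tf.keras.layers.Dense(2)(tf.keras.layers.Concatenate()([in_a, in_b]))
base_network = tf.keras.Model([in_a, in_b], z, name='base_network')

a1, b1 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
a2, b2 = tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(3,))
diff = tf.keras.layers.Subtract()([base_network([a1, b1]), base_network([a2, b2])])
distance = tf.keras.layers.Dot(axes=1)([diff, diff])  # squared Euclidean distance
siamese_model = tf.keras.Model([a1, b1, a2, b2], distance)
siamese_model.compile(optimizer='adam', loss='mse')

# Random placeholder data, only to make the sketch runnable
rng = np.random.default_rng(0)
x = [rng.normal(size=(32, d)).astype('float32') for d in (4, 3, 4, 3)]
y = rng.uniform(size=(32, 1)).astype('float32')

saver = SaveBaseNetwork(base_network, os.path.join(tempfile.mkdtemp(), 'model{epoch}.h5'))
siamese_model.fit(x, y, epochs=10, verbose=0, callbacks=[saver])

# Reload the most recent base network on its own and use it as an embedding model
embedder = tf.keras.models.load_model(saver.saved_paths[-1])
embeddings = embedder.predict([x[0], x[1]], verbose=0)
```

Saving the base network as its own file means the embedding model can later be loaded without rebuilding or reloading the siamese model at all.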