ValueError when trying to reconfigure Siamese network trained on triplets to operate on pairs


My problem

I am trying to implement a Siamese network in Keras that is trained on triplets ([anchor_input, similar_input, different_input]) and executes on pairs ([input_a, input_b]), telling me whether they are similar or different. I am able to train fine on triplets, and have even trained and tested on pairs, but when I train on triplets and then try to create my pairwise network, I get the following error:

ValueError: logits and labels must have the same shape (() vs (?, ?))

Current network overview

My triplet network is defined using the following code, which I have pared down to something pretty minimal for this example:

from keras import backend as K  # assuming standalone Keras 2.x
from keras.layers import Input, Lambda
from keras.models import Model
from keras.optimizers import SGD

def siamese_triplet(in_shape, feature_dim):
    input_a = Input(shape=in_shape)  # input is a series, n_samples by n_features
    input_b = Input(shape=in_shape)
    input_c = Input(shape=in_shape)
    base_network = create_model()  # makes my Siamese kernel network (defined elsewhere)
    processed_a = base_network(input_a)  # base_network outputs a vector, say 50 elements long
    processed_b = base_network(input_b)
    processed_c = base_network(input_c)
    l1_distance = lambda x: K.abs(x[0] - x[1])  # vector, 50 elements long
    p_distance = Lambda(l1_distance,
                        output_shape=lambda x: x[0])([processed_a, processed_b])
    n_distance = Lambda(l1_distance,
                        output_shape=lambda x: x[0])([processed_a, processed_c])
    triplet_loss = Lambda(lambda x: K.mean(K.maximum(0, x[0] - x[1] + 1)),
                          output_shape=(1,))([p_distance, n_distance])  # scalar
    model = Model([input_a, input_b, input_c], triplet_loss)
    optimizer = SGD()
    # Keras calls loss(y_true, y_pred); returning y_pred passes the triplet loss through
    model.compile(optimizer=optimizer,
                  loss=lambda y_true, y_pred: y_pred)
    return model
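The triplet_loss Lambda can be mirrored in NumPy to check what it actually rewards. This is a sketch that assumes K.abs, K.maximum and K.mean behave like np.abs, np.maximum and np.mean (including K.mean reducing over every axis, batch included, when no axis is given); the function names and the tiny 3-element embeddings are placeholders, not real base_network outputs.

```python
import numpy as np

def l1_distance(x, y):
    # Element-wise absolute difference, like the l1_distance Lambda.
    return np.abs(x - y)

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Mirrors the triplet_loss Lambda: hinge on p_distance - n_distance + margin,
    # then np.mean over *all* axes (batch included), like K.mean with no axis.
    p_distance = l1_distance(anchor, positive)
    n_distance = l1_distance(anchor, negative)
    return np.mean(np.maximum(0.0, p_distance - n_distance + margin))

# Placeholder batch of 2 triplets with 3-element embeddings.
anchor = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
positive = anchor + 0.1        # close to the anchor
far_negative = anchor + 2.0    # further away than positive + margin
near_negative = anchor + 0.5   # inside the margin

print(triplet_loss(anchor, positive, far_negative))   # 0.0
print(triplet_loss(anchor, positive, near_negative))  # > 0: negative is inside the margin
```

The hinge term is zero as soon as each element of n_distance exceeds p_distance by the margin, so the loss pushes negatives away relative to positives; it does not by itself drive distances to exactly 0 and 1. The result is also a single scalar for the whole batch, which training tolerates here because the pass-through loss never compares it with y_true.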

I train the network, which updates the model's weights in place. I then try to re-create the network with a structure built around input pairs, passing it the trained base_network layer extracted from my model, which ends up being model.layers[3]:

# model = siamese_triplet(...)
# model.fit(...)  # trains in place; fit() returns a History object, not a model
pair_net = siamese_pair(input_shape, model.layers[3])

with siamese_pair defined as:

def siamese_pair(in_shape, base_network):
    input_a = Input(shape=in_shape)
    input_b = Input(shape=in_shape)
    processed_a = base_network(input_a)  # vector, 50 elements
    processed_b = base_network(input_b)
    distance = Lambda(lambda x: K.abs(x[0] - x[1]),
                      output_shape=lambda x: x[0])([processed_a, processed_b])  # vector, 50 elements
    prediction = Lambda(lambda x: K.mean(x),
                        output_shape=(1,))(distance)  # scalar
    model = Model([input_a, input_b], prediction)
    optimizer = SGD()
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy')  # for evaluation purposes
    return model

The model.compile(...) call in siamese_pair throws the error.
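A likely reading of the error, hedged because the details depend on the Keras version: output_shape=(1,) only tells Keras what shape to assume for the Lambda's output, so it builds a rank-2 target placeholder (the (?, ?) in the message), but it does not reshape the tensor. K.mean(x) with no axis argument collapses the whole batch into a rank-0 scalar (the () in the message), and binary_crossentropy cannot match those logits to the labels. The NumPy sketch below uses np.mean as a stand-in for K.mean; the batch of 3 distance vectors is made up for illustration.

```python
import numpy as np

# Hypothetical batch of 3 L1-distance vectors (50 elements each),
# standing in for the output of the distance Lambda in siamese_pair.
distance = np.full((3, 50), 0.2)
distance[2] = 0.9  # make the third pair look "different"

# What K.mean(x) computes: one number for the entire batch.
batch_mean = np.mean(distance)

# A per-pair mean that matches output_shape=(1,): shape (batch, 1).
pair_mean = np.mean(distance, axis=-1, keepdims=True)

labels = np.array([[0.0], [0.0], [1.0]])  # y_true with shape (batch, 1)

print(np.shape(batch_mean), pair_mean.shape, labels.shape)  # () (3, 1) (3, 1)
```

In the Keras model this would correspond to K.mean(x, axis=-1, keepdims=True) inside the prediction Lambda, which yields one value per pair in the shape that output_shape=(1,) already declares (untested here against the exact Keras version in the question).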

Note that the triplet loss should push distances close to zero for objects that are the same (class 0), and push them towards 1 for objects that are different (class 1), so setting prediction = K.mean(distance) should be pretty close to one of these class labels, I would think.
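The idea of reading the mean distance as a class label can be sketched as a per-pair threshold. This assumes the embeddings for each side of the pair have already been computed (for example with base_network.predict, assuming base_network is a Keras Model); the helper name pair_predictions and the threshold of 0.5 are hypothetical choices, not tuned values. Because a mean L1 distance is not bounded to [0, 1], the sketch clips it before thresholding.

```python
import numpy as np

def pair_predictions(emb_a, emb_b, threshold=0.5):
    """Hypothetical helper: classify pairs as same (0) or different (1).

    emb_a, emb_b: arrays of shape (n_pairs, feature_dim) holding the
    embeddings of the two inputs of each pair.
    """
    distance = np.abs(emb_a - emb_b)          # element-wise L1 distance
    score = distance.mean(axis=-1)            # one mean distance per pair
    score = np.clip(score, 0.0, 1.0)          # keep scores in [0, 1]
    labels = (score > threshold).astype(int)  # 1 = different, 0 = same
    return score, labels

emb_a = np.array([[0.1, 0.2, 0.3], [0.0, 0.0, 0.0]])
emb_b = np.array([[0.1, 0.25, 0.3], [1.5, 1.0, 2.0]])
score, labels = pair_predictions(emb_a, emb_b)
print(labels)  # [0 1]
```

Comparing precomputed embeddings this way keeps evaluation outside Keras, so no second model has to be compiled with a loss just to get pairwise predictions.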

My transition from triplet-loss training to pairwise evaluation seems kind of janky to me, and I would love to figure out the best way to do it, so I am open to suggestions to improve the design. In the meantime, I would be happy just getting past this error so I can at least run and evaluate my performance classifying input pairs as similar or different.

My questions

  1. Why am I getting the error above? It seems like it is expecting no y_true at all for the loss function, which is strange to me.
  2. How do I fix the error above?
  3. Is there a better way to pass my trained base_network into a different Siamese network structure?
  4. Is there a better way to get pairwise predictions out of my trained base_network in a different network structure?

Answers to just the first two would be great, but if somebody has suggestions on the last two as well, I am all ears.
