Selective output branch training in Keras


I am building a U-Net-like model with the Keras Functional API, but with one shared encoder and several branched decoders:

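from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# cfg, FEATURE_MAPS, build_united_encoder and build_branch_decoder are defined elsewhere in my code
def build_model():
    # Shared encoder
    input_layer = Input(shape=(cfg.MODEL_WIDTH, cfg.MODEL_HEIGHT, 3),
                        name="real_patch_input", batch_size=cfg.BATCH_SIZE)

    x_8_pooled, x_16, x_32, x_64, x_128 = build_united_encoder(input_layer)

    # Branches: one decoder per keypoint feature map, all using the same skip connections
    output_layers = []
    for i in range(FEATURE_MAPS):
        output_layers.append(build_branch_decoder(x_8_pooled, x_16, x_32, x_64, x_128, i))

    # One output per branch, so each branch can get its own loss
    return Model(inputs=input_layer, outputs=output_layers)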

I am trying to train a separate decoder for each feature (keypoint) on top of a shared encoder, where the number of features is FEATURE_MAPS. Each decoder ends in a 1x1 convolution with a sigmoid activation. For each training image, I generate the image and the target map for one feature, and I would like to train each branch selectively. The catch is that my dataset only gives me the location of one keypoint per image; other keypoints may also be visible, but I have no labels for them. I therefore want FEATURE_MAPS separate loss functions, which I managed to set up. However, I face one big issue:

  • Each batch naturally contains samples labeled for different keypoint channels (0 to FEATURE_MAPS - 1). Since each sample should only affect its own branch, every solution I can think of effectively splits the batch into per-branch sub-batches, which means small batch sizes. For practical training reasons I would like to avoid small batches. I am now trying to solve this with a custom training loop, but the small batch size still seems unavoidable.
  • Each branch should also be updated only by the samples labeled for its keypoint.
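Neither point necessarily requires splitting the batch: every sample can run through all decoders, and a one-hot mask built from the sample's keypoint index zeroes the loss on the branches it is not labeled for. Below is a minimal NumPy sketch of such a masked per-branch binary cross-entropy, assuming each image comes with one target heatmap plus the index of its keypoint channel. The function names `branch_mask` and `masked_branch_bce` are hypothetical, not Keras APIs.

```python
import numpy as np


def branch_mask(channel_ids, num_branches):
    """One-hot mask (batch, num_branches): 1.0 where sample b is labeled for branch k."""
    channel_ids = np.asarray(channel_ids)
    return (channel_ids[:, None] == np.arange(num_branches)[None, :]).astype(np.float32)


def masked_branch_bce(y_true, branch_preds, channel_ids, eps=1e-7):
    """Per-branch binary cross-entropy that ignores unlabeled branches.

    y_true:       (batch, ...) target heatmap of the single labeled keypoint
    branch_preds: list of num_branches sigmoid outputs, each shaped like y_true
    channel_ids:  (batch,) keypoint channel each sample is labeled for
    Returns one loss per branch. Each is divided by the full batch size, so the
    sum over branches is the mean, over all samples, of each sample's loss on
    its own branch.
    """
    mask = branch_mask(channel_ids, len(branch_preds))  # (batch, K)
    batch = mask.shape[0]
    losses = []
    for k, pred in enumerate(branch_preds):
        p = np.clip(pred, eps, 1.0 - eps)
        bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
        per_sample = bce.reshape(batch, -1).mean(axis=1)  # average over pixels
        # Samples labeled for another keypoint contribute exactly zero here
        losses.append(float((mask[:, k] * per_sample).sum() / batch))
    return np.array(losses)


rng = np.random.default_rng(0)
y = (rng.random((4, 8, 8)) > 0.9).astype(np.float32)
preds = [rng.random((4, 8, 8)) for _ in range(3)]
ids = np.array([0, 0, 2, 2])  # no sample in this batch is labeled for branch 1
losses = masked_branch_bce(y, preds, ids)
```

Dividing by the full batch size rather than by the number of samples per branch is a design choice: it keeps the total equal to an ordinary mean loss, at the cost of branches that appear rarely in a batch receiving proportionally smaller gradients.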

Is there a correct way to handle this?
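With the built-in training loop, the same masking can in principle be expressed as per-output sample weights: every branch gets a target for every sample, and a weight of 0 hides the samples labeled for other keypoints. The sketch below only builds those arrays with NumPy; the helper name `per_output_targets_and_weights` is hypothetical. Passing a list of sample weights (one per output, like the targets) to `Model.fit` is an assumption to verify for your Keras version; tf.keras 2.x accepts sample weights structured like the targets, but check the documentation for the version you use.

```python
import numpy as np


def per_output_targets_and_weights(heatmaps, channel_ids, num_branches):
    """Expand single-keypoint labels into per-branch targets and sample weights.

    heatmaps:    (batch, H, W, 1) target for the one labeled keypoint per image
    channel_ids: (batch,) keypoint channel each heatmap belongs to
    Returns (targets, weights), lists of length num_branches. Every target has
    the full batch size; weights[k][b] is 1.0 only if sample b is labeled for
    branch k, so the other branches ignore that sample's placeholder target.
    """
    channel_ids = np.asarray(channel_ids)
    targets, weights = [], []
    for k in range(num_branches):
        w = (channel_ids == k).astype(np.float32)
        # Zero placeholder targets where the branch is unlabeled; their weight is 0 anyway
        targets.append(heatmaps * w[:, None, None, None])
        weights.append(w)
    return targets, weights


# Hypothetical usage, assuming model.compile(loss=[...]) has one loss per output:
# targets, weights = per_output_targets_and_weights(heatmaps, channel_ids, FEATURE_MAPS)
# model.fit(images, targets, sample_weight=weights, batch_size=cfg.BATCH_SIZE)
```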

I have tried setting up a custom training loop, but I don't know how to solve these conceptual issues.
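In a custom training loop, the masked loss would be computed on the full batch inside `tf.GradientTape`, and `tape.gradient` plus `optimizer.apply_gradients` would then update everything at once; automatic differentiation routes the gradients so that branch k is driven only by samples labeled k, while the shared encoder is driven by all samples. To make that gradient flow visible without TensorFlow, the toy below assumes a single linear "encoder" and one linear head per keypoint (stand-ins for the real U-Net parts) and derives the gradients by hand. The function name `masked_loss_and_grads` is hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def masked_loss_and_grads(x, y, channel_ids, W_enc, heads):
    """Toy shared-encoder / multi-head model with a masked BCE loss.

    x: (B, D) inputs; y: (B,) binary target for each sample's labeled head;
    channel_ids: (B,) labeled head per sample; W_enc: (D, H) shared linear
    "encoder"; heads: (K, H) one linear head per keypoint (stand-ins for the
    decoders). Returns (loss, grad_W_enc, grad_heads).
    """
    B, K = x.shape[0], heads.shape[0]
    mask = (np.asarray(channel_ids)[:, None] == np.arange(K)[None, :]).astype(float)
    h = x @ W_enc                          # shared features for the whole batch, (B, H)
    p = sigmoid(h @ heads.T)               # every head runs on every sample, (B, K)
    t = y[:, None]                         # label broadcast to all heads; masked below
    bce = -(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))
    loss = (mask * bce).sum() / B
    # Sigmoid + BCE gives d(bce)/d(logit) = p - t; the mask zeroes unlabeled heads
    d_logits = mask * (p - t) / B          # (B, K)
    grad_heads = d_logits.T @ h            # (K, H): head k only sees samples labeled k
    grad_W_enc = x.T @ (d_logits @ heads)  # (D, H): the encoder sees every sample
    return loss, grad_W_enc, grad_heads


rng = np.random.default_rng(1)
x = rng.normal(size=(6, 5))
y = rng.integers(0, 2, size=6).astype(float)
ids = np.array([0, 2, 0, 2, 0, 2])  # no sample labeled for head 1
W_enc = rng.normal(scale=0.1, size=(5, 4))
heads = rng.normal(scale=0.1, size=(3, 4))
loss, gW, gH = masked_loss_and_grads(x, y, ids, W_enc, heads)
```

Here `gH[1]` is exactly zero while the encoder gradient uses all six samples, so the batch never has to be split. One caveat that is independent of the masking: optimizers with momentum, such as Adam, keep moving a branch for some steps after its gradient becomes zero, because the update uses accumulated moment estimates. With large enough batches most branches should appear in most batches, so this is likely minor.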
