Multi-class segmentation in Keras

I'm trying to implement multi-class segmentation in Keras:

  • the input image is grayscale (i.e. 1 channel)
  • the ground truth image has 3 channels; each pixel is a one-hot vector of length 3
  • the model is a standard U-Net with a softmax output producing 3 channels, trained with categorical_crossentropy
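To make the setup concrete, here is a minimal NumPy sketch (not the poster's code, which isn't shown) of the pieces listed above: turning an integer label map into a one-hot ground truth, a per-pixel softmax over the channel axis, and the per-pixel categorical cross-entropy averaged over the image, which is roughly what Keras' `categorical_crossentropy` computes. The helper names and the toy 2x2 label map are hypothetical.

```python
import numpy as np

NUM_CLASSES = 3  # matches the 3-channel ground truth described above


def to_one_hot(label_map: np.ndarray, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Turn an (H, W) integer label map into an (H, W, C) one-hot mask."""
    return np.eye(num_classes, dtype=np.float32)[label_map]


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax over the channel axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def categorical_crossentropy(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> float:
    """Mean over pixels of -sum_c y_true[..., c] * log(y_pred[..., c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_pixel = -np.sum(y_true * np.log(y_pred), axis=-1)
    return float(per_pixel.mean())


# Toy 2x2 "image" with pixel labels in {0, 1, 2}
labels = np.array([[0, 1], [2, 0]])
y_true = to_one_hot(labels)               # shape (2, 2, 3)
logits = np.zeros((2, 2, NUM_CLASSES))    # an uninformative network output
y_pred = softmax(logits)                  # every pixel -> [1/3, 1/3, 1/3]
print(y_true.shape, round(categorical_crossentropy(y_true, y_pred), 4))  # (2, 2, 3) 1.0986
```

An untrained network that outputs uniform probabilities starts at ln 3 ≈ 1.0986, which is a useful reference point when reading the loss curve.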

What is wrong with this setup? The training loss behaves strangely:

  • in lucky runs it behaves as expected (it decreases)
  • 90% of the time it gets stuck at ~0.9
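One way to read a plateau like this (a hypothesis, not something the question confirms): if the network collapses to predicting the same class probabilities at every pixel, the best it can do is output the global class frequencies, and the loss then equals the entropy of the label distribution. With class imbalance that value sits below ln 3 ≈ 1.10 even though the model has learned nothing spatial. The sketch below computes that floor for a label map; the function name and the 70/20/10 class split are illustrative assumptions.

```python
import numpy as np


def constant_prediction_loss(label_map: np.ndarray, num_classes: int = 3) -> float:
    """Cross-entropy of predicting the global class frequencies at every pixel.

    By Gibbs' inequality this is the lowest loss any spatially constant
    prediction can reach, and it equals the entropy of the class
    distribution (in nats).
    """
    counts = np.bincount(label_map.ravel(), minlength=num_classes)
    freqs = counts / counts.sum()
    nonzero = freqs[freqs > 0]  # treat 0 * log(0) as 0
    return float(-np.sum(nonzero * np.log(nonzero)))


# Hypothetical imbalanced mask: ~70% background, ~20% class 1, ~10% class 2
rng = np.random.default_rng(0)
labels = rng.choice(3, size=(128, 128), p=[0.7, 0.2, 0.1])
print(round(constant_prediction_loss(labels), 3))
```

If a stuck run's loss sits close to this number for the real training masks, that would be consistent with the network ignoring the input and predicting class priors everywhere.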

My implementation can be found here

I don't think there is anything wrong with the code: if my ground truth is 1-channel (i.e. 1s in some places and 0s everywhere else) and I use binary_crossentropy with sigmoid as the final activation, I see no weird behaviour.

I'll answer my own question. The solution is to weight each class, i.e. to use a weighted cross-entropy loss.
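The answer doesn't say how the weights were chosen, so the following is a minimal NumPy sketch under one common assumption: weights inversely proportional to class frequency, normalised to a mean of 1. The per-pixel loss is -Σ_c w_c · y_c · log(p_c), averaged over pixels. The function names are hypothetical. The toy example compares the plain and weighted loss for a collapsed "background everywhere" prediction; the weighted version penalises it more because the rare classes carry larger weights.

```python
import numpy as np


def inverse_frequency_weights(label_map: np.ndarray, num_classes: int = 3) -> np.ndarray:
    """One common heuristic: w_c proportional to 1 / frequency_c, normalised to mean 1."""
    counts = np.bincount(label_map.ravel(), minlength=num_classes).astype(np.float64)
    counts = np.maximum(counts, 1.0)  # avoid division by zero for absent classes
    weights = counts.sum() / counts
    return weights / weights.mean()


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_true[..., c] * log(y_pred[..., c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_pixel = -np.sum(class_weights * y_true * np.log(y_pred), axis=-1)
    return float(per_pixel.mean())


labels = np.array([[0, 0, 0, 0], [0, 0, 1, 2]])  # class 0 dominates
y_true = np.eye(3)[labels]                       # (2, 4, 3) one-hot
weights = inverse_frequency_weights(labels)
# A collapsed prediction: mostly "background" at every pixel
y_pred = np.tile([0.9, 0.05, 0.05], labels.shape + (1,))

print(weights.round(3))
print(weighted_categorical_crossentropy(y_true, y_pred, np.ones(3)))  # unweighted
print(weighted_categorical_crossentropy(y_true, y_pred, weights))     # weighted
```

To use this as a Keras loss, the same expression can be written with TensorFlow ops (`tf.clip_by_value`, `tf.math.log`, `tf.reduce_sum`, `tf.reduce_mean`) inside a closure such as `def loss(y_true, y_pred): ...` and passed to `model.compile(loss=loss)`. In practice the weights would be computed from class counts over the whole training set rather than a single mask.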