I would like to implement SegNet for 3D images (width, height, and depth, where depth is a spatial dimension, not the channel axis). The decoder part of the network therefore needs the pooling indices from the encoder. The function K.tf.nn.max_pool_with_argmax() only works for 2D images (width and height). Keras has a MaxPooling3D layer, but it only returns the tensor after the pooling operation, without the indices.
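To make the requirement concrete, here is a minimal NumPy sketch of the behaviour I am after. It assumes non-overlapping pooling windows and a single (depth, height, width) volume with no batch or channel axes. The names max_pool3d_with_argmax and max_unpool3d are hypothetical helpers written for illustration; they are not part of Keras or TensorFlow.

```python
import numpy as np


def max_pool3d_with_argmax(x, pool=(2, 2, 2)):
    """Non-overlapping 3D max pooling over a (D, H, W) volume.

    Returns the pooled volume and, for each pooled voxel, the flat index
    (into x.ravel()) of the input element that produced the maximum.
    Hypothetical helper for illustration only.
    """
    d, h, w = x.shape
    pd, ph, pw = pool
    if d % pd or h % ph or w % pw:
        raise ValueError("each dimension must be divisible by its pool size")
    od, oh, ow = d // pd, h // ph, w // pw

    # Split every axis into (output position, offset inside the window),
    # move the three offset axes to the end, and flatten them into one axis.
    windows = (x.reshape(od, pd, oh, ph, ow, pw)
                .transpose(0, 2, 4, 1, 3, 5)
                .reshape(od, oh, ow, pd * ph * pw))

    local = windows.argmax(axis=-1)  # position of the max inside each window
    pooled = np.take_along_axis(windows, local[..., None], axis=-1)[..., 0]

    # Convert the window-local position back to coordinates in the input.
    lz, ly, lx = np.unravel_index(local, (pd, ph, pw))
    oz, oy, ox = np.indices((od, oh, ow))
    flat = np.ravel_multi_index(
        (oz * pd + lz, oy * ph + ly, ox * pw + lx), (d, h, w))
    return pooled, flat


def max_unpool3d(pooled, flat, shape):
    """Scatter pooled values back to their argmax positions; zeros elsewhere."""
    out = np.zeros(int(np.prod(shape)), dtype=pooled.dtype)
    out[flat.ravel()] = pooled.ravel()
    return out.reshape(shape)


# Example: pool a 4x4x4 volume, then place the maxima back where they came from.
volume = np.arange(64).reshape(4, 4, 4)
pooled, indices = max_pool3d_with_argmax(volume)
restored = max_unpool3d(pooled, indices, volume.shape)
```

The second return value plays the role of the argmax output in the 2D case: the encoder would store it, and the decoder would use it to scatter values back into the larger volume. What I am missing is the equivalent operation on tensors inside the network.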
Does anyone know how to get the pooling indices for 3D max pooling in Keras/TensorFlow?