Iterate over a tensor's rows and columns in TensorFlow


Part of my project applies a thresholding kernel to an image. The thresholding kernel could look like this:

[50  100]
[150 200]

I would like to go over each non-overlapping block of pixels the same size as the kernel (2x2 in this example) and threshold it using my kernel.

For example, if I have this grayscale image:

[120 120 120 120]
[120 120 120 120]
[170 170 170 170]
[170 170 170 170]

Then after thresholding I should get this image:

[1 1 1 1]
[0 0 0 0]
[1 1 1 1]
[1 0 1 0]
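To make the expected output concrete, here is a minimal, dependency-free sketch of the rule the example implies: each pixel is compared with the kernel entry at the same position inside its block, which is the same as tiling the kernel across the image. It assumes a strict "greater than" comparison (the example never hits equality, so ">=" would give the same result here); `block_threshold` is a hypothetical helper name.

```python
# Kernel and image from the example above.
kernel = [[50, 100],
          [150, 200]]
image = [[120, 120, 120, 120],
         [120, 120, 120, 120],
         [170, 170, 170, 170],
         [170, 170, 170, 170]]

def block_threshold(image, kernel):
    """Return 1 where a pixel is greater than the kernel entry at the same
    position inside its non-overlapping block, else 0.
    Indexing the kernel with (i % kh, j % kw) is equivalent to tiling it."""
    kh, kw = len(kernel), len(kernel[0])
    return [[1 if pixel > kernel[i % kh][j % kw] else 0
             for j, pixel in enumerate(row)]
            for i, row in enumerate(image)]

for row in block_threshold(image, kernel):
    print(row)
```

The printed rows match the thresholded image in the question.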

I am using TensorFlow, and my network needs to accept batches of different shapes.

The input is:

data['input_tensor'] = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='Input')

The thresholding kernel is a tf.Variable of size 5x5, and its values must be learned!

I can't find a way to make this work. I tried iterating over the input batch, but its shape is unknown when the graph is built (it is only known when the session runs). I don't want to duplicate the thresholding kernel, because then the network would try to learn the values of every copy (and that is too many parameters for now).
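One loop-free option is a sketch like the following, which tiles the kernel inside the graph using the run-time shape from `tf.shape`, so the image size never has to be known in Python. Note that `tf.tile` copies values, not variables: every copy reads the same kernel, and their gradients are summed back into it, so only the kernel's own entries are learned. Also, a hard 0/1 comparison passes no gradient to the kernel, so the sketch adds a sigmoid surrogate for training. Assumptions: TensorFlow 2.x in eager mode for the demo (the function body uses only ops that should also work on the TF 1.x placeholder in graph mode), a single input channel, and a hypothetical `temperature` parameter controlling how sharp the surrogate is.

```python
import tensorflow as tf

def tiled_threshold(images, kernel, temperature=10.0):
    """Threshold `images` ([batch, height, width, 1]) block-wise against
    `kernel` ([kh, kw]) without Python loops.
    `temperature` is a hypothetical knob for the soft (trainable) version."""
    kh, kw = int(kernel.shape[0]), int(kernel.shape[1])  # static kernel size
    shape = tf.shape(images)                             # dynamic, known at run time
    h, w = shape[1], shape[2]
    # Repeat the kernel enough times to cover the image (ceil division), then crop.
    reps = tf.stack([(h + kh - 1) // kh, (w + kw - 1) // kw])
    thresholds = tf.tile(kernel, reps)[:h, :w]             # [h, w]
    thresholds = thresholds[tf.newaxis, :, :, tf.newaxis]  # broadcast over batch
    hard = tf.cast(images > thresholds, images.dtype)      # 0/1, no gradient
    soft = tf.sigmoid((images - thresholds) / temperature) # differentiable surrogate
    return hard, soft

kernel = tf.Variable([[50.0, 100.0], [150.0, 200.0]])       # one learnable 2x2 kernel
image = tf.constant([[120.0] * 4] * 2 + [[170.0] * 4] * 2)  # [4, 4]
images = image[tf.newaxis, :, :, tf.newaxis]                # [1, 4, 4, 1]

with tf.GradientTape() as tape:
    hard, soft = tiled_threshold(images, kernel)
    loss = tf.reduce_sum(soft)  # stand-in loss, only to show the gradient path
grad = tape.gradient(loss, kernel)

print(hard[0, :, :, 0].numpy())  # same 0/1 pattern as the example image
print(grad.shape)                # one gradient per kernel entry, not per pixel
```

The same function can be applied to a 5x5 kernel unchanged; if the image size is not a multiple of the kernel size, the last partial block simply uses the top-left part of the kernel.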

Is there a way to do it without loops? If not, how could I do it with loops?

Thanks.
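If the kernel is square and the image height and width are always multiples of its size, another loop-free sketch avoids tiling altogether: `tf.nn.space_to_depth` turns each non-overlapping k x k block into the channel dimension (row-major within the block, for a single input channel), so the flattened kernel broadcasts against every block at once, and `tf.nn.depth_to_space` restores the image layout. Assumptions: TensorFlow 2.x eager mode for the demo, one input channel, and a hypothetical helper name `block_threshold_s2d`. As written it returns a hard 0/1 result, which passes no gradient to the kernel; for training, a smooth function such as `tf.sigmoid` of the difference would be needed instead of the comparison.

```python
import tensorflow as tf

def block_threshold_s2d(images, kernel):
    """Block-wise threshold using space_to_depth.
    Assumes a k x k kernel, 1 input channel, and height/width divisible by k."""
    k = int(kernel.shape[0])
    blocks = tf.nn.space_to_depth(images, k)     # [batch, h/k, w/k, k*k]
    flat_kernel = tf.reshape(kernel, [k * k])    # row-major, same order as blocks
    hard = tf.cast(blocks > flat_kernel, images.dtype)
    return tf.nn.depth_to_space(hard, k)         # back to [batch, h, w, 1]

kernel = tf.Variable([[50.0, 100.0], [150.0, 200.0]])
image = tf.constant([[120.0] * 4] * 2 + [[170.0] * 4] * 2)
out = block_threshold_s2d(image[tf.newaxis, :, :, tf.newaxis], kernel)
print(out[0, :, :, 0].numpy())
```

The trade-off is that this version fails at run time if the image size is not divisible by k, whereas tiling and cropping the kernel handles any size.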
