TensorFlow unique function with a 2D tensor?


tf.unique currently only works on 1D tensors. How can I find the indices of the unique values in each row of a 2D tensor?

import tensorflow as tf

input = tf.constant([[0,0,1,1,1,8], [2,2,5,5,5,2], [6,6,6,8,8,9]])
# expected output: [[0,0,1,1,1,2], [0,0,1,1,1,0], [0,0,0,1,1,2]]


Answer by Luca Anzalone:

You can loop over the 1D rows that make up your 2D tensor and call tf.unique on each one; its idx output gives, for every element, the index of that element's value among the row's unique values:

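import tensorflow as tf

input = tf.constant([[0,0,1,1,1,8], [2,2,5,5,5,2], [6,6,6,8,8,9]])

indices = []
# In eager mode, iterating over a 2D tensor yields its rows as 1D tensors
for row in input:
    # tf.unique returns (y, idx); idx maps each element to its position in y
    idx = tf.unique(row).idx
    indices.append(idx)

print(indices)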
[<tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 1, 1, 1, 2])>,
 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 1, 1, 1, 0])>,
 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 0, 0, 1, 1, 2])>]
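The loop above relies on eager execution to iterate over the tensor and collect results in a Python list. A minimal sketch of the same per-row tf.unique idea that also runs inside a tf.function is to let tf.map_fn apply it to each row; this assumes a TensorFlow 2.x release that accepts the fn_output_signature argument, and unique_idx_per_row is a hypothetical helper name:

```python
import tensorflow as tf

input = tf.constant([[0, 0, 1, 1, 1, 8], [2, 2, 5, 5, 5, 2], [6, 6, 6, 8, 8, 9]])

@tf.function
def unique_idx_per_row(x):
    # Hypothetical helper: apply tf.unique to every row and keep only idx.
    # fn_output_signature=tf.int32 because idx is int32 by default,
    # which may differ from the dtype of x (e.g. for float inputs).
    return tf.map_fn(lambda row: tf.unique(row).idx, x,
                     fn_output_signature=tf.int32)

result = unique_idx_per_row(input)
print(result.numpy().tolist())
```

Because each row's idx has the same length as the row, tf.map_fn can stack the per-row results directly into a (3, 6) tensor.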

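Finally, if you need a single tensor, you can combine the list with tf.stack(indices), which gives a (3, 6) tensor matching the desired output, or with tf.concat(indices, axis=0), which joins the rows end to end into a flat (18,) tensor.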
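A short sketch contrasting the two ways of combining the per-row results, assuming eager execution as in the loop above:

```python
import tensorflow as tf

input = tf.constant([[0, 0, 1, 1, 1, 8], [2, 2, 5, 5, 5, 2], [6, 6, 6, 8, 8, 9]])
indices = [tf.unique(row).idx for row in input]

stacked = tf.stack(indices)        # new leading axis: shape (3, 6)
flat = tf.concat(indices, axis=0)  # rows joined end to end: shape (18,)

print(stacked.shape, flat.shape)
print(stacked.numpy().tolist())
```

tf.stack preserves the row structure of the original 2D tensor, so it is usually the one you want when the indices must line up with the input rows.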