Indexing a batch of images with an index image in PyTorch

Suppose I have a batch of images M stored as a torch tensor of shape (B, W, H), and an image I of shape (W, H) whose pixels are indices into the batch dimension.

I want to get an image of shape (W, H) in which each pixel (w, h) comes from image I[w, h] of the batch, i.e. result[w, h] = M[I[w, h], w, h].

Example

Given M of shape (3, 4, 8):

tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],

        [[-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.]],

        [[-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.]]])

and I of shape (4, 8):

tensor([[2, 0, 2, 0, 1, 0, 1, 0],
        [2, 2, 1, 0, 0, 2, 1, 0],
        [2, 0, 0, 2, 1, 1, 0, 0],
        [0, 1, 0, 0, 2, 0, 2, 1]], dtype=torch.int32)

the resulting image would be:

tensor([[-2.,  0., -2.,  0., -1.,  0., -1.,  0.],
        [-2., -2., -1.,  0.,  0., -2., -1.,  0.],
        [-2.,  0.,  0., -2., -1., -1.,  0.,  0.],
        [ 0., -1.,  0.,  0., -2.,  0., -2., -1.]])
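A plain double loop makes the intended semantics explicit and is handy as a reference to check a vectorized version against. This is a minimal sketch that rebuilds the example above, assuming (as in the example) that image b of the batch is filled with the value -b:

```python
import torch

# Rebuild the example: image b of the batch is filled with the value -b.
B, W, H = 3, 4, 8
M = torch.stack([torch.full((W, H), float(-b)) for b in range(B)])  # (B, W, H)

I = torch.tensor([[2, 0, 2, 0, 1, 0, 1, 0],
                  [2, 2, 1, 0, 0, 2, 1, 0],
                  [2, 0, 0, 2, 1, 1, 0, 0],
                  [0, 1, 0, 0, 2, 0, 2, 1]], dtype=torch.int32)  # (W, H)

# Naive per-pixel reference: result[w, h] = M[I[w, h], w, h]
result = torch.empty(W, H)
for w in range(W):
    for h in range(H):
        result[w, h] = M[int(I[w, h]), w, h]

print(result)
```

This is far too slow for real images, but it reproduces the expected output shown above.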

Note 1

I don't care about the order of M's dimensions; it could be (W, H, B) as well if that makes for an easier solution.
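With a channel-last (W, H, B) layout, the selection can be expressed as a `torch.gather` along the last axis. A minimal sketch, assuming random indices in place of the example's I and a hypothetical name `M_last` for the (W, H, B) tensor:

```python
import torch

B, W, H = 3, 4, 8
# Channel-last batch of shape (W, H, B); image b is filled with -b, as in the example.
M_last = torch.stack([torch.full((W, H), float(-b)) for b in range(B)], dim=-1)
I = torch.randint(0, B, (W, H))  # int64 index image, values in [0, B)

# torch.gather needs an int64 index with the same number of dims as M_last,
# so add a trailing axis of size 1 and drop it again afterwards.
result = torch.gather(M_last, -1, I.unsqueeze(-1)).squeeze(-1)  # (W, H)
```

Each output pixel is result[w, h] = M_last[w, h, I[w, h]].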

Note 2

I am also interested in a NumPy solution.

Answer by arthur.sw:

One solution would be:

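import torch
# (W, H) row and column index grids; M[I[w, h], w, h] is picked for every pixel
indices = torch.meshgrid(torch.arange(I.shape[0]), torch.arange(I.shape[1]), indexing="ij")
result = M[(I.long(), *indices)]  # tuple form also works on Python < 3.11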
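The same selection can also be written without materializing full meshgrid tensors, either by broadcasting a column vector of row indices against a row vector of column indices, or with `torch.gather` along the batch dimension. A minimal sketch under the question's (B, W, H) layout, assuming random indices for I:

```python
import torch

B, W, H = 3, 4, 8
M = torch.stack([torch.full((W, H), float(-b)) for b in range(B)])  # (B, W, H)
I = torch.randint(0, B, (W, H), dtype=torch.int32)                   # (W, H)
idx = I.long()  # int64 indices are accepted by every PyTorch version

# Option A: broadcast a (W, 1) row index against a (1, H) column index
# instead of building two full (W, H) meshgrid tensors.
rows = torch.arange(W).unsqueeze(1)
cols = torch.arange(H).unsqueeze(0)
result_a = M[idx, rows, cols]  # (W, H)

# Option B: gather along the batch dimension (dim 0).
result_b = torch.gather(M, 0, idx.unsqueeze(0)).squeeze(0)  # (W, H)
```

Both options return the same (W, H) tensor as the meshgrid version.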

or, with NumPy arrays:

import numpy as np
indices = np.indices(I.shape)  # np.indices takes a shape, not the array itself
result = M[I, indices[0], indices[1]]
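In NumPy, `np.take_along_axis` plays the role of `torch.gather`, and broadcasting `np.arange` vectors avoids allocating the full `np.indices` grids. A minimal sketch, assuming M and I are NumPy arrays with the question's (B, W, H) and (W, H) shapes and random indices for I:

```python
import numpy as np

B, W, H = 3, 4, 8
M = np.stack([np.full((W, H), float(-b)) for b in range(B)])  # (B, W, H)
I = np.random.default_rng(0).integers(0, B, size=(W, H))      # (W, H)

# take_along_axis needs an index array with the same ndim as M,
# so add a leading axis of length 1 and drop it afterwards.
result = np.take_along_axis(M, I[np.newaxis], axis=0)[0]  # (W, H)

# Broadcasting form that avoids building the full np.indices grids.
result_bc = M[I, np.arange(W)[:, None], np.arange(H)]
```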