Slice array along axis with list of different indices


I have a 3-dimensional array/tensor of shape (a, b, c), and I have a list B of length a containing different indices, each in the range [0, b). I want to use these indices to get an array of shape (a, c). Right now I do this with an ugly list comprehension:

z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])

This runs in the forward pass of a neural network, so I really want to avoid a list comprehension. Is there a torch (or numpy) function that does this more efficiently?

Here is a small example:

tensor = [[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],
          [[ 6,  7],
           [ 8,  9],
           [10, 11]],
          [[12, 13],
           [14, 15],
           [16, 17]],
          [[18, 19],
           [20, 21],
           [22, 23]]]  # shape: (4, 3, 2)
B = [0, 1, 2, 2]
output = [[ 0,  1],
          [ 8,  9],
          [16, 17],
          [22, 23]]  # shape: (4, 2)
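For reference, the example tensor holds the values 0 to 23 in order, so it can be built with `torch.arange(24).reshape(4, 3, 2)`. The sketch below rebuilds it and runs the list comprehension from above to confirm that it reproduces the expected output.

```python
import torch

# The example tensor contains 0..23 in row-major order, shape (4, 3, 2)
tensor = torch.arange(24).reshape(4, 3, 2)
B = [0, 1, 2, 2]

# The original list-comprehension approach: pick row B[i] from tensor[i]
z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])

expected = torch.tensor([[0, 1],
                         [8, 9],
                         [16, 17],
                         [22, 23]])
print(z.shape)                   # torch.Size([4, 2])
print(torch.equal(z, expected))  # True
```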

Background: I have time-series data with time windows of different lengths. I use torch's pack_padded_sequence (and pad_packed_sequence to reverse it) to mask the padding, but I need the LSTM output at the last time step before the padding starts, because after that point the network's output is unusable. In the example, there are 4 sequences whose last valid time steps are at indices 0, 1, 2, 2 (that is, lengths 1, 2, 3, 3), each step with 2 features.
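A minimal sketch of that use case, assuming a single-layer, unidirectional `nn.LSTM` with `batch_first=True` and made-up sizes (hidden size 5, random inputs): pack, run the LSTM, unpack, then pick each sequence's output at step `lengths - 1` with the same advanced-indexing trick used in the answer. For packed input, `h_n` already holds each sequence's hidden state at its last valid step, so the sketch also compares against it.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical sizes: batch of 4 sequences, up to 3 time steps, 2 features
batch, max_len, n_features, hidden = 4, 3, 2, 5
x = torch.randn(batch, max_len, n_features)
lengths = torch.tensor([1, 2, 3, 3])  # valid steps per sequence (B + 1)

lstm = nn.LSTM(n_features, hidden, batch_first=True)

# Pack so the LSTM skips padded steps; enforce_sorted=False allows unsorted lengths
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to (batch, max_len, hidden); padded steps are filled with zeros
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Advanced indexing: for each sequence i, take the output at step lengths[i] - 1
last = out[torch.arange(batch), out_lengths - 1]
print(last.shape)  # torch.Size([4, 5])

# For a single-layer, unidirectional LSTM this matches the final hidden state
print(torch.allclose(last, h_n[-1]))
```

With this approach the list comprehension is not needed at all; if only the last valid output is wanted, `h_n[-1]` alone may be enough for this configuration.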

Best answer, by cottontail:

Use advanced indexing. To get the desired output, we also need matching indices for the first axis, which are created with torch.arange() below:

output = tensor[torch.arange(len(B)), B]
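Advanced indexing pairs the two index sequences element-wise: row i of the result is `tensor[i, B[i], :]`, and the last axis is carried through whole. For comparison, `torch.gather` (a different function from the one in the answer) can do the same selection without an `arange`, but it needs an index tensor with the same number of dimensions as the input. The sketch below expands `B` to shape `(a, 1, c)` and checks that both approaches agree on the example.

```python
import torch

tensor = torch.arange(24).reshape(4, 3, 2)  # shape (a, b, c) = (4, 3, 2)
B = torch.tensor([0, 1, 2, 2])              # one index into axis 1 per row

# Advanced indexing: pairs (0, B[0]), (1, B[1]), ... and keeps axis 2 whole
by_index = tensor[torch.arange(len(B)), B]

# torch.gather alternative: index must have shape (a, 1, c) to gather along dim=1
a, _, c = tensor.shape
idx = B.view(a, 1, 1).expand(a, 1, c)
by_gather = tensor.gather(1, idx).squeeze(1)

print(by_index)
print(torch.equal(by_index, by_gather))  # True
```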

or, if tensor is a NumPy array, using NumPy:

output = tensor[np.arange(len(B)), B]
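The NumPy line assumes `tensor` is a NumPy array (for example `np.arange(24).reshape(4, 3, 2)`) rather than a torch tensor. A runnable NumPy sketch follows, together with `np.take_along_axis`, a different NumPy function shown here for comparison; it requires an index array with the same number of dimensions as the input.

```python
import numpy as np

arr = np.arange(24).reshape(4, 3, 2)  # same values as the example, shape (4, 3, 2)
B = np.array([0, 1, 2, 2])

# Advanced indexing with a matching arange for the first axis
out = arr[np.arange(len(B)), B]

# take_along_axis alternative: B reshaped to (4, 1, 1) broadcasts over the last axis
out2 = np.take_along_axis(arr, B[:, None, None], axis=1)[:, 0, :]

print(out)
print(np.array_equal(out, out2))  # True
```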

Both produce the following (printed here as the torch result):

tensor([[ 0,  1],
        [ 8,  9],
        [16, 17],
        [22, 23]])

Full code using the example:

import torch
tensor = torch.tensor([
    [[ 0,  1],
     [ 2,  3],
     [ 4,  5]],
    [[ 6,  7],
     [ 8,  9],
     [10, 11]],
    [[12, 13],
     [14, 15],
     [16, 17]],
    [[18, 19],
     [20, 21],
     [22, 23]]])
B = [0, 1, 2, 2]
# Pair row i with index B[i]; the last axis is kept whole
output = tensor[torch.arange(len(B)), B]
print(output)
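Because the question is about a forward pass, it may help to confirm that gradients flow through the advanced-indexing step. The sketch below uses a float copy of the example (autograd requires a floating-point tensor) and shows that only the selected entries receive gradient.

```python
import torch

# Float version of the example so autograd can track it
x = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2).requires_grad_()
B = torch.tensor([0, 1, 2, 2])

output = x[torch.arange(len(B)), B]  # shape (4, 2)
output.sum().backward()

# x.grad is 1 at the selected (i, B[i], :) entries and 0 everywhere else
print(x.grad[:, :, 0])
```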