PyTorch: setting elements to zero with an "index tensor"


I've used PyTorch for a few months. Recently I wanted to create a custom pooling layer similar to a "Max-Pooling Dropout" layer, and I think PyTorch provides the tools needed to build it. Here is my approach:

  1. Use MaxPool2d with return_indices=True.
  2. Set tensor[indices] to zero.
  3. If possible, make step 2 behave like torch.take (i.e. without flattening).
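Steps 1 and 2 of that list can be sketched as a small function. This is a minimal sketch under stated assumptions, not an established "Max-Pooling Dropout" implementation: the name zero_max_positions is hypothetical, and instead of torch.take it uses scatter on the flattened spatial dimensions, because MaxPool2d indices are positions inside each (H, W) plane rather than positions in the whole tensor.

```python
import torch
import torch.nn.functional as F

def zero_max_positions(x: torch.Tensor, kernel_size: int = 2, stride: int = 2) -> torch.Tensor:
    """Hypothetical helper: copy of x (N, C, H, W) with each window's max set to 0."""
    # Step 1: max-pool and keep the argmax indices. Each index is a flat
    # position inside its own (H, W) plane, i.e. in the range [0, H*W).
    _, indices = F.max_pool2d(x, kernel_size, stride, return_indices=True)
    # Step 2: flatten only the spatial dims so the plane-local indices line up
    # with dim 2, then scatter zeros into exactly those positions.
    flat = x.flatten(2)                      # (N, C, H*W)
    flat_idx = indices.flatten(2)            # (N, C, H_out*W_out)
    zeroed = flat.scatter(2, flat_idx, 0.0)  # out-of-place: x is left untouched
    return zeroed.view_as(x)


x = torch.randn(2, 3, 6, 6)
y = zero_max_positions(x)
```

Scattering out-of-place keeps the caller's tensor intact; in a real layer you would probably also apply this only in training mode, which the sketch leaves out.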

Here is how to get the "index tensor" (I think that is what it's called; correct me if I'm wrong):

import torch
import torch.nn as nn

input1 = torch.randn(1, 1, 6, 6)
m = nn.MaxPool2d(2, 2, return_indices=True)
val, indx = m(input1)  # indx: argmax position of each 2x2 window
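To see what indx actually contains, the sketch below prints its shape and decodes each entry into a (row, col) pair inside the 6x6 plane. The fixed seed is an addition for reproducibility only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed, only for reproducibility
input1 = torch.randn(1, 1, 6, 6)
m = nn.MaxPool2d(2, 2, return_indices=True)
val, indx = m(input1)

# One int64 index per pooling window, same shape as the pooled output.
print(indx.shape, indx.dtype)  # torch.Size([1, 1, 3, 3]) torch.int64

# Each entry is a flat position inside the 6x6 plane: idx = row * W + col.
W = input1.shape[-1]
rows = torch.div(indx, W, rounding_mode="floor")
cols = indx % W
print(torch.equal(input1[0, 0, rows[0, 0], cols[0, 0]], val[0, 0]))  # True

# The same lookup via gather on the flattened spatial dims.
recovered = input1.flatten(2).gather(2, indx.flatten(2)).view_as(val)
print(torch.equal(recovered, val))  # True
```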

indx is the "index tensor", which can be used directly:

torch.take(input1, indx)

No flattening and no dimension argument are needed. I think this makes sense, since indx is generated from input1.
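One caveat worth checking: torch.take treats its input as a 1-D tensor, whereas MaxPool2d indices are positions within each (H, W) plane. The two coincide in the example above only because the batch and channel sizes are both 1. The sketch below (N = 2, C = 3 chosen arbitrarily) shows the mismatch and a gather-based lookup that works for any N and C.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 6, 6)  # N=2, C=3: more than one (H, W) plane
m = nn.MaxPool2d(2, 2, return_indices=True)
val, indx = m(x)

# torch.take flattens the WHOLE tensor, but indx only indexes within a plane,
# so every lookup lands in the first plane, x[0, 0].
taken = torch.take(x, indx)
print(torch.equal(taken[0, 0], val[0, 0]))  # True: the first plane is fine
print(torch.equal(taken, val))              # False: the other planes are wrong

# Flattening only the spatial dims keeps the (N, C) layout aligned with indx.
gathered = x.flatten(2).gather(2, indx.flatten(2)).view_as(val)
print(torch.equal(gathered, val))  # True
```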

Question: how do I set the values of input1 pointed to by indx to 0, in the same "torch.take" style? I have seen answers such as "Indexing a multi-dimensional tensor with a tensor in PyTorch", but I doubt FB (Facebook) would return an "index tensor" that cannot be applied directly. (Maybe I'm wrong.)

Is there something like the following?

torch.set_value(input1, indx, 0)
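As far as I know there is no torch.set_value, but the in-place counterpart of torch.take is Tensor.put_, which uses the same 1-D view of the tensor. A minimal sketch for the single-plane case in the question (N = C = 1), followed by a scatter_ variant that respects the plane-local indices for any N and C:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input1 = torch.randn(1, 1, 6, 6)
m = nn.MaxPool2d(2, 2, return_indices=True)
val, indx = m(input1)

# "torch.take style": Tensor.put_ treats the tensor as 1-D, exactly like take,
# so (also like take) it is only correct here because N = C = 1.
a = input1.clone()
a.put_(indx, torch.zeros_like(val))

# General version for any N and C: view the spatial dims as one axis and
# scatter the scalar 0 into the plane-local indices, in place.
b = input1.clone()
n, c = b.shape[:2]
b.view(n, c, -1).scatter_(2, indx.reshape(n, c, -1), 0.0)

print(torch.equal(a, b))    # True
print(int((a == 0).sum()))  # 9: one zeroed maximum per 2x2 window
```

The scatter_ version is the one to keep in a real layer, since batches and multiple channels would break the put_ version the same way they break torch.take.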