All values replaced with pool max in PyTorch

Given z equal to

tensor([[[[0.0908, 0.1286, 0.6942, 0.5161],
          [0.4227, 0.2154, 0.5990, 0.8666],
          [0.3009, 0.2399, 0.1818, 0.7551],
          [0.2396, 0.4485, 0.4027, 0.5303]],
         [[0.8251, 0.7457, 0.2091, 0.7313],
          [0.3823, 0.7351, 0.3823, 0.2072],
          [0.0863, 0.5489, 0.6515, 0.3855],
          [0.5247, 0.8685, 0.6078, 0.6181]]]])

We have that

torch.nn.MaxUnpool2d(2)(*torch.nn.MaxPool2d(2, return_indices=True)(z))
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000],
          [0.4227, 0.0000, 0.0000, 0.8666],
          [0.0000, 0.0000, 0.0000, 0.7551],
          [0.0000, 0.4485, 0.0000, 0.0000]],
         [[0.8251, 0.0000, 0.0000, 0.7313],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.6515, 0.0000],
          [0.0000, 0.8685, 0.0000, 0.0000]]]])
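To make the role of the indices concrete: with return_indices=True, MaxPool2d also returns, for every window, the flat position of its maximum within each channel's H×W plane, and MaxUnpool2d writes each max back to that position in an all-zero tensor. The sketch below reproduces the result above with a plain scatter_. It assumes non-overlapping windows (kernel 2, stride 2) so that no two windows share an index; the helper name manual_unpool is hypothetical.

```python
import torch

def manual_unpool(z: torch.Tensor, k: int = 2) -> torch.Tensor:
    # Pool and keep the argmax positions (flat indices into each H*W plane).
    pooled, idx = torch.nn.MaxPool2d(k, return_indices=True)(z)
    n, c, h, w = z.shape
    # Start from zeros and scatter every window max to its argmax position.
    out = torch.zeros(n, c, h * w, dtype=z.dtype, device=z.device)
    out.scatter_(2, idx.flatten(2), pooled.flatten(2))
    return out.view(n, c, h, w)

z = torch.rand(1, 2, 4, 4)
pooled, idx = torch.nn.MaxPool2d(2, return_indices=True)(z)
reference = torch.nn.MaxUnpool2d(2)(pooled, idx)
print(torch.equal(manual_unpool(z), reference))
```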

How can I idiomatically compute a similar inverse where, instead of the non-max values being zero, every position is filled with the max of its pooling window? In this case I want:

tensor([[[[0.4227, 0.4227, 0.8666, 0.8666],
          [0.4227, 0.4227, 0.8666, 0.8666],
          [0.4485, 0.4485, 0.7551, 0.7551],
          [0.4485, 0.4485, 0.7551, 0.7551]],
         [[0.8251, 0.8251, 0.7313, 0.7313],
          [0.8251, 0.8251, 0.7313, 0.7313],
          [0.8685, 0.8685, 0.6515, 0.6515],
          [0.8685, 0.8685, 0.6515, 0.6515]]]])
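For this particular example (2×2 windows, stride 2, no padding or dilation), every input position lies in exactly one window, so the target is just the pooled tensor with each value expanded into a 2×2 block. A minimal sketch of that special case only, using torch.kron with a block of ones and, equivalently, repeat_interleave along the two spatial dimensions; it does not handle overlapping windows, padding, or dilation.

```python
import torch

z = torch.tensor([[[[0.0908, 0.1286, 0.6942, 0.5161],
                    [0.4227, 0.2154, 0.5990, 0.8666],
                    [0.3009, 0.2399, 0.1818, 0.7551],
                    [0.2396, 0.4485, 0.4027, 0.5303]],
                   [[0.8251, 0.7457, 0.2091, 0.7313],
                    [0.3823, 0.7351, 0.3823, 0.2072],
                    [0.0863, 0.5489, 0.6515, 0.3855],
                    [0.5247, 0.8685, 0.6078, 0.6181]]]])

pooled = torch.nn.MaxPool2d(2)(z)  # shape (1, 2, 2, 2)

# Kronecker product with a (1, 1, 2, 2) block of ones copies each max into a 2x2 block.
via_kron = torch.kron(pooled, torch.ones(1, 1, 2, 2))

# Same result by repeating each row and then each column twice.
via_repeat = pooled.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)

print(torch.equal(via_kron, via_repeat))
print(via_kron)
```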

The solution should work for any kernel size, padding, and dilation, and it should be fast. In this simple case, for example, I could easily compute what I want as a Kronecker product of torch.nn.MaxPool2d(2)(z) with a 2×2 block of ones, but this does not generalize.
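One possible generalization, sketched under stated assumptions rather than offered as a definitive answer: run F.unfold on a tensor of flat position indices to find which input positions each pooling window covers, then scatter each window's max back to all of those positions with Tensor.scatter_reduce_ (PyTorch 1.12 or later). Two behaviours are assumptions, because the question leaves them open: where windows overlap (stride smaller than kernel), a position keeps the largest of the maxima of the windows covering it (reduce="amax"; "mean" would be another choice), and positions covered by no window (stride larger than kernel) are set to 0. It also assumes a floating-point input and ceil_mode=False, so the window count from unfold matches max_pool2d's output. The name broadcast_unpool2d is hypothetical.

```python
import torch
import torch.nn.functional as F

def broadcast_unpool2d(z, kernel_size, stride=None, padding=0, dilation=1):
    """Replace every input value by the max of the pooling window(s) covering it.

    Assumptions (not from the question): overlaps resolve with amax,
    uncovered positions become 0, float input, ceil_mode=False.
    """
    if stride is None:
        stride = kernel_size  # same default as max_pool2d
    n, c, h, w = z.shape
    pooled = F.max_pool2d(z, kernel_size, stride, padding, dilation)
    # Flat index + 1 of every position; unfold's zero padding then marks padded cells as 0.
    pos = torch.arange(1, h * w + 1, dtype=torch.float64, device=z.device).view(1, 1, h, w)
    cols = F.unfold(pos, kernel_size, dilation=dilation, padding=padding, stride=stride)
    cols = cols[0].long()  # (K, L): K cells per window, L windows
    K, L = cols.shape
    # Padded cells are routed to an extra "dump" slot at index h * w, dropped at the end.
    target = torch.where(cols > 0, cols - 1, torch.full_like(cols, h * w))
    # Each window's max, repeated for each of its K cells: (n, c, K, L).
    vals = pooled.reshape(n, c, 1, L).expand(n, c, K, L)
    index = target.expand(n, c, K, L).reshape(n, c, K * L)
    out = torch.full((n, c, h * w + 1), float("-inf"), dtype=z.dtype, device=z.device)
    out.scatter_reduce_(2, index, vals.reshape(n, c, K * L), reduce="amax")
    # Positions no window touches are still -inf; set them to 0 instead.
    covered = torch.zeros(h * w + 1, dtype=torch.bool, device=z.device)
    covered[target.flatten()] = True
    out = out.masked_fill(~covered, 0.0)
    return out[..., : h * w].reshape(n, c, h, w)
```

With kernel_size=2 this reproduces the 2×2 block expansion of the example. Everything is vectorized (one max_pool2d, one unfold, one scatter), and cols depends only on the input shape and pooling parameters, so it could be cached across calls; memory grows with N·C·K·L.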
