Given a few block matrices, get the overall large matrix


In short, we are given a 4D tensor y of shape (B // s2, D2 // s1, s1, s2), where y[i, j, ...] is a block matrix of shape (s1, s2). These blocks are used to construct an overall large matrix of shape (D2, B), so there are (B // s2) * (D2 // s1) blocks in total. Assume every division here is exact (all quotients are integers). I know how to do it with for loops:

import torch

# y has shape (B // s2, D2 // s1, s1, s2)
result = torch.zeros(D2, B)
for i in range(D2 // s1):
    for j in range(B // s2):
        result[i * s1:(i + 1) * s1, j * s2:(j + 1) * s2] = y[j, i, ...]

I know the assignment can be done in parallel. Can we use pytorch built-in functions to eliminate the two for loops?
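To make the block layout concrete, here is a small runnable version of the loop (the sizes are arbitrary, chosen only so that the divisions are exact), together with an equivalent single permute-and-reshape that produces the same tiling:

```python
import torch

# Illustrative sizes: any values with D2 % s1 == 0 and B % s2 == 0 work.
B, D2, s1, s2 = 6, 4, 2, 3
y = torch.randn(B // s2, D2 // s1, s1, s2)

# Loop version from the question.
result = torch.zeros(D2, B)
for i in range(D2 // s1):
    for j in range(B // s2):
        result[i * s1:(i + 1) * s1, j * s2:(j + 1) * s2] = y[j, i, ...]

# Vectorized tiling: move each block index next to its within-block axis,
# then merge (D2 // s1, s1) -> D2 and (B // s2, s2) -> B.
vectorized = y.permute(1, 2, 0, 3).reshape(D2, B)
print(torch.equal(result, vectorized))  # True
```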

Answered by Ivan:

This is called a fold operation, nn.Fold and F.fold are made for this purpose. If you look at the documentation, it reads:

Combines an array of sliding local blocks into a large containing tensor.
Consider a batched input tensor containing sliding local blocks, e.g., patches of images, of shape (N,C×∏(kernel_size),L), where:

  • N is batch dimension,
  • C×∏(kernel_size) is the number of values within a block (a block has ∏(kernel_size) spatial locations each containing a C-channeled vector),
  • L is the total number of blocks.

This is exactly the same specification as the output shape of Unfold. This operation combines these local blocks into the large output tensor of shape (N,C,output_size[0],output_size[1],…) by summing the overlapping values.

In your case, the input tensor is shaped (B//s2, D2//s1, s1, s2). To match that specification, take C = 1, so ∏(kernel_size) = s1 * s2, and L = (B//s2) * (D2//s1). Since the function expects L in the last position, you need to permute before flattening:

y_ = y.permute(2,3,1,0).reshape(s1*s2,-1)

Now y_ is shaped (s1*s2, B//s2 * D2//s1). Finally, apply the fold with stride equal to the kernel size, so the blocks tile the output without overlapping:

F.fold(y_, output_size=(D2,B), kernel_size=(s1,s2), stride=(s1,s2))
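Putting it all together, a minimal sketch that checks the fold-based version against the loop-based reference (the sizes are illustrative; a leading batch dimension is added for compatibility with PyTorch versions whose fold only accepts batched 3D input):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: any values with D2 % s1 == 0 and B % s2 == 0 work.
B, D2, s1, s2 = 6, 4, 2, 3
y = torch.randn(B // s2, D2 // s1, s1, s2)

# Loop-based reference from the question.
expected = torch.zeros(D2, B)
for i in range(D2 // s1):
    for j in range(B // s2):
        expected[i * s1:(i + 1) * s1, j * s2:(j + 1) * s2] = y[j, i, ...]

# Fold-based version: flatten each (s1, s2) block into a column, with blocks
# ordered row-major over the output grid, then fold with stride == kernel_size.
y_ = y.permute(2, 3, 1, 0).reshape(s1 * s2, -1)
folded = F.fold(y_.unsqueeze(0), output_size=(D2, B),
                kernel_size=(s1, s2), stride=(s1, s2))

print(torch.equal(folded.reshape(D2, B), expected))  # True
```

Because the stride equals the kernel size, no blocks overlap, so fold's summation of overlapping values never changes any entry; each column of y_ is simply placed into its (s1, s2) slot of the output.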