In short, we are given a 4D tensor y of shape (B // s2, D2 // s1, s1, s2), where y[i, j, ...] is a matrix of shape (s1, s2). These are the blocks used to construct an overall large matrix of shape (D2, B), so there are (B // s2) * (D2 // s1) such blocks in total. Here we assume all the numbers involved are integers, i.e. B is divisible by s2 and D2 is divisible by s1. I am clear on how to do it using for loops:
import torch

# y has shape (B // s2, D2 // s1, s1, s2); y[j, i] is the (s1, s2) block that
# goes in block-row i and block-column j of the (D2, B) result
result = torch.zeros(D2, B)
for i in range(D2 // s1):        # block-row index
    for j in range(B // s2):     # block-column index
        result[i * s1:(i + 1) * s1, j * s2:(j + 1) * s2] = y[j, i, ...]
I know the assignment can be done in parallel. Can we use PyTorch built-in functions to eliminate the two for loops?
This is called a fold operation; nn.Fold and F.fold are made for this purpose. If you look at the documentation, the input is expected to have shape (N, C × ∏(kernel_size), L), where L is the total number of blocks. In your case, your input tensor is shaped (B//s2, D2//s1, s1, s2). To get to the specs, you have C = 1, ∏(kernel_size) = s1 * s2 and L = (B//s2) * (D2//s1). Since the function expects L in the last position, you will need to do some permutation before flattening.
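Something along these lines should work (a sketch: fold enumerates blocks row-major over the output, so the block-row index must vary slower than the block-column index, and elements within each block are also laid out row-major):

# y: (B//s2, D2//s1, s1, s2)  ->  y_: (s1*s2, (D2//s1)*(B//s2))
# after the permute the dims are (s1, s2, block-row i, block-col j)
y_ = y.permute(2, 3, 1, 0).reshape(s1 * s2, -1)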
Now y_ is shaped (s1*s2, (D2//s1)*(B//s2)). Finally you can apply the fold:
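Something like this (a sketch; the unsqueeze just adds a batch dimension, since older versions of F.fold only accept batched input, and the two leading size-1 dims are dropped at the end):

# stride = kernel_size tiles the blocks edge to edge, with no overlap
result = F.fold(y_.unsqueeze(0), output_size=(D2, B),
                kernel_size=(s1, s2), stride=(s1, s2))[0, 0]

A quick sanity check with some illustrative sizes (B = 6, D2 = 4, s1 = 2, s2 = 3 are arbitrary values, chosen only so everything divides evenly):

import torch
import torch.nn.functional as F

B, D2, s1, s2 = 6, 4, 2, 3                  # illustrative sizes
y = torch.arange(D2 * B, dtype=torch.float).reshape(B // s2, D2 // s1, s1, s2)

y_ = y.permute(2, 3, 1, 0).reshape(s1 * s2, -1)
folded = F.fold(y_.unsqueeze(0), output_size=(D2, B),
                kernel_size=(s1, s2), stride=(s1, s2))[0, 0]

# reference result built with the double loop from the question
expected = torch.zeros(D2, B)
for i in range(D2 // s1):
    for j in range(B // s2):
        expected[i * s1:(i + 1) * s1, j * s2:(j + 1) * s2] = y[j, i]

assert torch.equal(folded, expected)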