I have a tensor of shape (m*n, m*n) and I want to extract a tensor of shape (n, m*n) containing the m blocks of shape (n, n) that lie on the diagonal. For example:
>>> a
tensor([[1, 2, 0, 0],
        [3, 4, 0, 0],
        [0, 0, 5, 6],
        [0, 0, 7, 8]])
I want a function extract(a, m, n) that outputs:
>>> extract(a, 2, 2)
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
I've thought of using some kind of slicing, because the blocks can be expressed by:
>>> for i in range(m):
...     print(a[i*n: i*n + n, i*n: i*n + n])
tensor([[1, 2],
        [3, 4]])
tensor([[5, 6],
        [7, 8]])
You can take advantage of reshape and slicing:
Example:
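
A minimal sketch of that idea, assuming a is a PyTorch tensor of shape (m*n, m*n). The names blocks, idx and diag are illustrative choices, not taken from the original answer. Reshaping to (m, n, m, n) turns a into an m x m grid of n x n blocks, indexing both block axes with the same index tensor keeps only the diagonal blocks, and a final permute and reshape lays them side by side.

```python
import torch

def extract(a, m, n):
    # View the (m*n, m*n) matrix as an m x m grid of n x n blocks:
    # blocks[i, r, j, c] == a[i*n + r, j*n + c]
    blocks = a.reshape(m, n, m, n)
    idx = torch.arange(m)
    # Keep only the diagonal blocks (i == j); result has shape (m, n, n)
    diag = blocks[idx, :, idx, :]
    # Put the row axis first, then merge (block, column): (n, m, n) -> (n, m*n)
    return diag.permute(1, 0, 2).reshape(n, m * n)

a = torch.tensor([[1, 2, 0, 0],
                  [3, 4, 0, 0],
                  [0, 0, 5, 6],
                  [0, 0, 7, 8]])
print(extract(a, 2, 2))
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])
```

This avoids a Python loop over the blocks. If readability matters more than speed, concatenating the slices from the loop in the question with torch.cat along dim=1 should give the same result.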