I have an array a whose size is torch.Size([330, 330, 36])
The original data structure is [330, 6, 330, 6]
The meaning is:
I have 330 atoms in the system, and each atom has 6 orbitals.
I want to know the interactions between all atoms and all orbitals.
I want to perform these operations:
(1) a.reshape(330,330,6,6).permute(0,2,1,3).reshape(1980, 1980)
convert the matrix to (330 x 6) x (330 x 6)
(2) torch.sum(torch.diag(b@b)[1:6])
perform a matmul operation and sum the diagonal elements 1-5
I want to know if there is any way to perform the matmul operation without reshaping the 330x330x36 tensor.
Thanks a lot.
What if I have a list of matrices? How can I do the matmul operations in a single command?
You asked for a couple of things, and what you are doing is inefficient.
Matmul without reshape
As I will explain below, you should not do this contraction at all. But assume you want to. You cannot avoid the reshape that "splits" the axis 36 -> 6 * 6, but you can avoid combining 6 * 330 -> 1980 by using torch.tensordot.
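In your case, that could look like the following sketch (`a` is a stand-in for your data; the tensordot result keeps the (330, 6, 330, 6) layout instead of merging to 1980 x 1980):

```python
import torch

a = torch.randn(330, 330, 36)      # placeholder for your data
c = a.reshape(330, 330, 6, 6)      # only split 36 -> 6 * 6, no permute/merge

# Contract atom axis 1 and orbital axis 3 of the first factor with
# atom axis 0 and orbital axis 2 of the second factor:
# prod4[i, alpha, j, beta] = sum_{k, gamma} c[i, k, alpha, gamma] * c[k, j, gamma, beta]
prod4 = torch.tensordot(c, c, dims=([1, 3], [0, 2]))   # shape (330, 6, 330, 6)
```

This gives the same numbers as b @ b with b = c.permute(0, 2, 1, 3).reshape(1980, 1980), just without the final merge of the atom and orbital axes.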
List of matrices
If it is a list of torch.Tensors, you cannot get around doing a loop of some kind, so no, there is no "one command" solution. If you have a single Tensor, created e.g. via torch.tensor, say of shape as.shape == (42, 330, 330, 36) for 42 different "matrices", you can batch the torch operations.
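For example, a batched sketch (a_stack stands in for the stacked tensor called "as" above, since as is a reserved word in Python; torch.matmul broadcasts over the leading axis):

```python
import torch

a_stack = torch.randn(42, 330, 330, 36)   # stand-in for your stacked data

c = a_stack.reshape(42, 330, 330, 6, 6)               # split 36 -> 6 x 6 per matrix
b = c.permute(0, 1, 3, 2, 4).reshape(42, 1980, 1980)  # batch of 1980 x 1980 matrices

prod = b @ b                                     # one matmul per leading index
diag = torch.diagonal(prod, dim1=-2, dim2=-1)    # shape (42, 1980)
results = diag[:, 1:6].sum(dim=1)                # one scalar per "matrix"
```

Note that this materializes all 42 product matrices at once, so it trades memory for avoiding the Python loop.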
More efficient way to compute what you are after

It seems that you are only interested in a few diagonal entries of the matrix product: in your case only 5 out of the 1980 * 1980 total entries. So you should only compute those entries; computing the other roughly 4,000,000 entries is wasted work. Note that due to C-style reshaping, your flat index range 1:6 on the diagonal corresponds to atom index 0 and orbital indices 1:6 in the (330, 6, 330, 6) view (flat index n maps to atom n // 6, orbital n % 6), e.g. flat index 5 means atom 0, orbital 5.
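One way to compute just those five entries, sketched here with torch.einsum (`a` is a stand-in for your data); it should give the same value as the snippets in the question:

```python
import torch

a = torch.randn(330, 330, 36)     # placeholder for your data
c = a.reshape(330, 330, 6, 6)     # c[i, k, alpha, gamma]

# diag(b @ b)[6*i + alpha] = sum_{k, gamma} c[i, k, alpha, gamma] * c[k, i, gamma, alpha]
# Only atom i = 0 is needed for flat indices 1:6:
diag0 = torch.einsum('kag,kga->a', c[0], c[:, 0])   # 6 diagonal entries for atom 0
result = diag0[1:6].sum()

# Reference: the expensive route from the question.
b = c.permute(0, 2, 1, 3).reshape(1980, 1980)
assert torch.allclose(result, torch.sum(torch.diag(b @ b)[1:6]), rtol=1e-3, atol=1e-2)
```

The einsum touches only the row block c[0] and column block c[:, 0], so it does a tiny fraction of the work of the full 1980 x 1980 matmul.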