How to do matmul on 3D and 4D matrices?


I have an array a whose size is torch.Size([330, 330, 36])

The original data structure is [330, 6, 330, 6]

The meaning is:

I have 330 atoms in the system, each atom has 6 orbitals

I want to know the interactions between all atoms and all orbitals.

I want to perform these operations:

(1) a.reshape(330, 330, 6, 6).permute(0, 2, 1, 3).reshape(1980, 1980)

This converts the array to a (330 x 6) x (330 x 6) matrix.

(2) torch.sum(torch.diag(b@b)[1:6])

This performs a matmul and sums diagonal elements 1 to 5.
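Put together, the two steps look like this (with random stand-in data, since the real array is not shown):

```python
import torch

a = torch.randn(330, 330, 36)  # random stand-in for the real data

# (1) split the 36 axis into 6 x 6 and regroup to (atom, orbital) x (atom, orbital)
b = a.reshape(330, 330, 6, 6).permute(0, 2, 1, 3).reshape(1980, 1980)

# (2) matmul, then sum diagonal entries 1..5
result = torch.sum(torch.diag(b @ b)[1:6])
```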

I want to know if there is any way to perform the matmul without reshaping the 330x330x36 array.

Thanks a lot.


What if I have a list of matrices, how to do matmul operations in a single command?


You asked about a couple of things, and in addition, what you are doing is inefficient, so I will address that as well.

Matmul without reshape

As I will explain below, you should not do this contraction at all. But assume you want to. You cannot avoid the reshape that "splits" the axis 36 -> 6 * 6, but you can avoid combining 6 * 330 -> 1980 by using torch.tensordot. In your case that would be

b = a.reshape(330, 330, 6, 6)
c = torch.tensordot(b, b, ([1, 3], [0, 2]))  # shape [330, 6, 330, 6]
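To see that this really is the same contraction, you can compare it against the explicit reshape-and-matmul version. A quick check with random data (using smaller stand-in sizes, 10 atoms and 3 orbitals, purely for speed):

```python
import torch

# smaller stand-in sizes (10 atoms, 3 orbitals) just for the check
a = torch.randn(10, 10, 9, dtype=torch.float64)
b = a.reshape(10, 10, 3, 3)
c = torch.tensordot(b, b, ([1, 3], [0, 2]))  # shape [10, 3, 10, 3]

# reference: build the flattened matrix explicitly and multiply
B = b.permute(0, 2, 1, 3).reshape(30, 30)
ref = (B @ B).reshape(10, 3, 10, 3)
assert torch.allclose(c, ref)
```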

List of matrices

If it is a list of torch.Tensors, you cannot get around doing a loop of some kind, so no, there is no "one command" solution. If you have a single tensor, created e.g. via torch.stack, say xs with xs.shape == (42, 330, 330, 36) for 42 different "matrices" (note that as is a reserved word in Python, so it cannot be a variable name), you can batch the operations:

bs = xs.reshape(42, 330, 330, 6, 6)
# torch.tensordot has no batch dimension, so use einsum for the batched contraction
cs = torch.einsum('aijmn,ajkno->aimko', bs, bs)  # shape [42, 330, 6, 330, 6]
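As a sanity check, the batched contraction can be compared against a plain per-matrix loop. Here with smaller stand-in sizes (batch of 4, 12 atoms, 3 orbitals), chosen only to keep the check fast:

```python
import torch

# smaller stand-in sizes for a quick check
xs = torch.randn(4, 12, 12, 9, dtype=torch.float64)
bs = xs.reshape(4, 12, 12, 3, 3)

# batched B @ B for every matrix in the stack
cs = torch.einsum('aijmn,ajkno->aimko', bs, bs)  # shape [4, 12, 3, 12, 3]

# reference: plain loop over the batch
for b, c in zip(bs, cs):
    B = b.permute(0, 2, 1, 3).reshape(36, 36)
    assert torch.allclose(c, (B @ B).reshape(12, 3, 12, 3))
```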

More efficient way to compute what you are after

It seems that you are only interested in a few diagonal entries of the matrix product: in your case only 5 of the 1980 * 1980 total entries. So you should compute only those entries, as computing the other roughly 4 million entries is wasted work. For example

b = a.reshape(330, 330, 6, 6)
# diagonal entries 1..5 live in atom block (0, 0); the transpose aligns the orbital axes
c = torch.sum(b[0, :, 1:6, :] * b[:, 0, :, 1:6].transpose(1, 2))

should give the same result as your snippets above. Note that due to C-style (row-major) reshaping, your index 1:6 becomes 0 and 1:6, e.g.

after_reshape = before_reshape.reshape(330, 6)
before_reshape[1:6] == after_reshape[0, 1:6]
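For completeness, here is the cheap version checked against the full reshape-matmul-diag pipeline from the question, with random data at the full sizes:

```python
import torch

a = torch.randn(330, 330, 36, dtype=torch.float64)
b = a.reshape(330, 330, 6, 6)

# cheap: compute only the 5 diagonal entries that are actually needed
cheap = torch.sum(b[0, :, 1:6, :] * b[:, 0, :, 1:6].transpose(1, 2))

# full: the original reshape + matmul + diag from the question
B = b.permute(0, 2, 1, 3).reshape(1980, 1980)
full = torch.sum(torch.diag(B @ B)[1:6])
assert torch.allclose(cheap, full)
```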