What is the tensordot equivalent of this 4D einsum operation?

Here's some simple code that "batch multiplies" a 4D array a by a 3D array b:

from functools import reduce
import numpy as np
from operator import mul

def einsum(a, b):
    return np.einsum('ijkl,jkl->ikl', a, b)

def original(a, b):
    s0, s1, s2, s3 = a.shape
    c = np.empty((s0, s2, s3))
    # j indexes axis 2 (size s2) and i indexes axis 3 (size s3)
    for j in range(s2):
        for i in range(s3):
            c[:, j, i] = np.dot(a[:, :, j, i], b[:, j, i])
    return c

sz_a = (16, 4, 512, 512)
sz_b = (4, 512, 512)

a = np.random.random(reduce(mul, sz_a)).reshape(sz_a)
b = np.random.random(reduce(mul, sz_b)).reshape(sz_b)

For timing:

%timeit original(a, b)
395 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit einsum(a, b)
23.1 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

I'd like to test tensordot's performance to see how it compares, but I'm having trouble wrapping my head around how to use it here. If anyone is familiar enough with it to guide me, it would be greatly appreciated. Thank you!

My original thought was:

np.tensordot(a, b, axes=((1),(0)))

But that gives me a MemoryError so I don't think that's right...

There is 1 best solution below

Time comparisons of your einsum with a matmul equivalent:

In [910]: timeit (a.transpose(2,3,0,1)@b[:,None].transpose(2,3,0,1)).transpose(2,3,0,1)[:,0]
90.5 ms ± 92.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [911]: timeit np.einsum('ijkl,jkl->ikl', a, b)
92.7 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

The times are close enough that I suspect einsum's optimization is actually using matmul. Originally einsum used its own compiled sum-of-products iteration, but with recent changes it can use a variety of methods, including dot and matmul if they fit.
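As a sanity check that the matmul expression above really computes the same thing as the einsum, here is a minimal sketch on much smaller arrays. The shapes and the names a_t, b_t, ref and mm are arbitrary choices for illustration and are not from the question:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((16, 4, 8, 6))   # axes (i, j, k, l)
b = rng.random((4, 8, 6))       # axes (j, k, l)

ref = np.einsum('ijkl,jkl->ikl', a, b)

# Move the stacking axes (k, l) to the front so matmul sees a
# (k, l) stack of (i, j) @ (j, 1) matrix products.
a_t = a.transpose(2, 3, 0, 1)            # (k, l, i, j)
b_t = b[:, None].transpose(2, 3, 0, 1)   # (k, l, j, 1)

# Result of the matmul is (k, l, i, 1); transpose back and drop the
# singleton axis to get (i, k, l).
mm = (a_t @ b_t).transpose(2, 3, 0, 1)[:, 0]

print(ref.shape, mm.shape)
print(np.allclose(ref, mm))
```

The transposes only create views, so the extra cost over the einsum call is mostly whatever matmul does internally to handle the non-contiguous inputs.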

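matmul was created to handle the case where the leading dimensions represent a stack of matrices. In your problem the last 2 dimensions are this stack, with the dot acting on the first 2, which is why the transposes are needed to move the stack to the front. dot, and its derivative tensordot, don't handle that kind of stacking: they sum over the chosen axes and keep every remaining axis of both arguments, so your tensordot call tries to build a (16, 512, 512, 512, 512) result, hence the MemoryError.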
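To make the stacking point concrete, here is a small sketch, with tiny made-up sizes so that it fits in memory, of what tensordot returns for this contraction and how the einsum result sits inside it. The sizes and the names full and diag are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
s0, s1, s2, s3 = 3, 4, 5, 6
a = rng.random((s0, s1, s2, s3))   # axes (i, j, k, l)
b = rng.random((s1, s2, s3))       # axes (j, k, l)

# tensordot sums over j only. All other axes are kept, so the (k, l)
# axes of a and the (k, l) axes of b both appear in the result.
full = np.tensordot(a, b, axes=(1, 0))
print(full.shape)   # (3, 5, 6, 5, 6)

# The einsum result is the "diagonal" where a's (k, l) equal b's (k, l).
k = np.arange(s2)[:, None]
l = np.arange(s3)[None, :]
diag = full[:, k, l, k, l]          # shape (s0, s2, s3)

ref = np.einsum('ijkl,jkl->ikl', a, b)
print(np.allclose(diag, ref))

# Size of the same intermediate for the shapes in the question,
# in TiB of float64.
n_elements = 16 * 512**4
print(n_elements * 8 / 2**40)
```

So tensordot can reproduce the result only by computing a much larger array and discarding everything off the diagonal, which is why it is not a practical replacement here.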