The first calculation with torch.einsum is much slower

When I run several calculations with torch.einsum in a row, the first one is always much slower than the following calculations.

The following code and plot illustrate the problem:

import numpy as np
import torch as tor
from timeit import default_timer as timer

N = 1000
L = 10

time_arr = np.zeros(L)

for i in range(L):
    a = tor.randn(N, N).to("cuda:0")  # 3 random 1000x1000 matrices for each cycle
    b = tor.randn(N, N).to("cuda:0")
    c = tor.randn(N, N).to("cuda:0")

    time_start = timer()
    tor.einsum("ij, kj", tor.einsum("ij, ja", a, b), c)
    time_end = timer()

    time_arr[i] = time_end - time_start

Plot of the different times for each cycle of the loop
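A pattern like this is often seen when the first GPU call pays one-time setup costs (for example CUDA context and library initialization), and because CUDA kernels are launched asynchronously, a bare timer around the call may not measure the actual computation. Whether that is the cause here is not confirmed. The sketch below is one way to check: it runs a few untimed warm-up calls and synchronizes before reading the clock. The helper `time_calls` and its parameters are hypothetical names introduced for illustration; it uses only the standard library so it runs anywhere.

```python
from timeit import default_timer as timer


def time_calls(fn, sync=None, warmup=1, repeats=10):
    """Return a list of per-call wall times for fn().

    fn      -- zero-argument callable to measure (hypothetical helper API)
    sync    -- optional zero-argument callable that blocks until queued
               asynchronous work has finished
    warmup  -- number of untimed calls made before measuring
    repeats -- number of timed calls
    """
    def timed_call():
        if sync is not None:
            sync()  # make sure earlier queued work is not counted
        start = timer()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn() to finish
        return timer() - start

    # Warm-up calls absorb one-time setup costs; their times are discarded.
    for _ in range(warmup):
        timed_call()

    return [timed_call() for _ in range(repeats)]
```

With the code from the question, one could pass `fn=lambda: tor.einsum("ij, kj", tor.einsum("ij, ja", a, b), c)` and `sync=tor.cuda.synchronize`. If the timings are flat after the warm-up, the slow first iteration is likely initialization overhead rather than anything specific to `einsum`.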
