What kinds of optimization are used in PyTorch methods?


I'm using PyTorch to implement an intensive sequence of matrix operations, using methods such as torch.mm or torch.dot. I was wondering whether PyTorch uses multithreading or other optimization mechanisms to speed up the process. I am not using a GPU. I would appreciate it if you could tell me how fast these methods are and whether I need to do anything to help the process.

Best answer

PyTorch uses an efficient BLAS implementation and multithreading (OpenMP, if I'm not wrong) to parallelize such operations across multiple cores. Some performance loss comes from Python itself: since it is an interpreted language, no significant compiler-like optimization can be done. You can use the torch.jit module to speed up the "wrapper" code around the matrix multiplies, but for anything more than very small matrices this overhead is probably negligible.
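To check how much the multithreading helps on your own machine, something like the following sketch may be useful. It reads the intra-op thread count with torch.get_num_threads(), times torch.mm with the default setting, and then repeats the measurement with a single thread via torch.set_num_threads(1). The helper name time_mm, the 512 x 512 size, and the repeat count are arbitrary choices for illustration, and the numbers will depend on your CPU and on the BLAS library your PyTorch build uses.

```python
import time
import torch


def time_mm(n=512, repeats=5):
    """Average wall-clock seconds for one n x n torch.mm call (hypothetical helper)."""
    a = torch.randn(n, n)
    b = torch.randn(n, n)
    torch.mm(a, b)  # warm-up call so one-time setup is not measured
    start = time.perf_counter()
    for _ in range(repeats):
        torch.mm(a, b)
    return (time.perf_counter() - start) / repeats


# Number of threads PyTorch uses to parallelize a single CPU operation.
default_threads = torch.get_num_threads()
t_default = time_mm()

# Restrict PyTorch to one thread and measure again for comparison.
torch.set_num_threads(1)
t_single = time_mm()

# Restore the original setting.
torch.set_num_threads(default_threads)

print(f"{default_threads} thread(s): {t_default * 1e3:.2f} ms per torch.mm")
print(f"1 thread: {t_single * 1e3:.2f} ms per torch.mm")
```

If the two timings are close, the matrices are probably too small for threading to matter, which is also the regime where the Python overhead is most visible.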

One big improvement you may be able to get manually, but which PyTorch doesn't apply automatically when you chain ordinary multiplications, is to order the matrix multiplies properly. As you probably know, depending on the matrix shapes, a product ABCD can require very different amounts of work when computed as A(B(CD)) than when computed as (AB)(CD), and so on.
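To make the effect of the ordering concrete, here is a small sketch that counts scalar multiplications for the two parenthesizations mentioned above. The shapes are made up for illustration, and the cost model (m * k * n multiplications for an m x k by k x n product) ignores memory traffic and BLAS details, so treat it as a rough guide rather than a timing prediction.

```python
def matmul_cost(a_shape, b_shape):
    """Scalar multiplications and result shape for an (m x k) @ (k x n) product."""
    (m, k), (k2, n) = a_shape, b_shape
    assert k == k2, "inner dimensions must match"
    return m * k * n, (m, n)


# Hypothetical shapes, chosen so that the ordering matters a lot.
A, B, C, D = (1000, 10), (10, 1000), (1000, 10), (10, 1000)

# A(B(CD)): multiply from right to left.
cost_cd, CD = matmul_cost(C, D)
cost_b_cd, BCD = matmul_cost(B, CD)
cost_a_bcd, _ = matmul_cost(A, BCD)
cost_right_to_left = cost_cd + cost_b_cd + cost_a_bcd

# (AB)(CD): multiply the two pairs first, then combine them.
cost_ab, AB = matmul_cost(A, B)
cost_ab_cd, _ = matmul_cost(AB, CD)
cost_pairwise = cost_ab + cost_cd + cost_ab_cd

print("A(B(CD)):", cost_right_to_left)  # 30000000
print("(AB)(CD):", cost_pairwise)       # 1020000000
```

With these shapes the pairwise order does about 34 times as many multiplications, because it builds two 1000 x 1000 intermediates and then multiplies them together. If you would rather not work the order out by hand, torch.linalg.multi_dot([A, B, C, D]) chooses a multiplication order that minimizes this kind of operation count for a chain of 2-D tensors.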