Speeding up sparse inversion, element-wise multiplication and addition in python


I am trying to invert 10000x10000 sparse tridiagonal matrices, but scipy.sparse.linalg.inv() is too slow, so I have tried to use spsolve() instead. The main issue is that whenever I perform element-wise multiplication with *, the sparse matrices change shape. My original code is:

import numpy as np
from scipy import sparse
import scipy.sparse.linalg  # makes sparse.linalg (inv, spsolve) available


def generate_tri_diagonal_matrix(size):

    main_diagonal = np.random.rand(size)
    upper_diagonal = np.random.rand(size - 1)
    lower_diagonal = np.random.rand(size - 1)


    tri_diagonal_matrix = np.diag(main_diagonal) + np.diag(upper_diagonal, k=1) + np.diag(lower_diagonal, k=-1)

    return sparse.csc_matrix(tri_diagonal_matrix)


M=10000
matrix_size = M

c = 1.53    # Some random float
delta_t = 0.01  # Time step; not defined in the original snippet, value assumed


A = generate_tri_diagonal_matrix(matrix_size)
A_0 = A*2
A_1 = A*3
A_2 = A*4

P = np.random.rand(M)
I = np.eye(M)


inv_ini_1 = sparse.csc_matrix(I - c*A_1)
inv_1 = sparse.linalg.inv(inv_ini_1)

inv_ini_2 = sparse.csc_matrix(I - c*A_2)
inv_2 = sparse.linalg.inv(inv_ini_2)

Y_0 = P + delta_t*A*P
Y_1 = inv_1 * (Y_0 - c*A_1*P)
Y_2 = inv_2 * (Y_1 - c*A_2*P)

P = Y_2

But when I try to use:

# Attempt

inv_1 = sparse.csc_matrix(I - c*A_1)
inv_2 = sparse.csc_matrix(I - c*A_2)

Y_0 = sparse.csc_matrix(P + delta_t*A*P)
Y_1 = sparse.linalg.spsolve(inv_1 , Y_0 - c*A_1*P)
Y_2 = sparse.linalg.spsolve(inv_2 , Y_1 - c*A_2*P)

P = Y_2

The Y_0 - c*A_1*P term inside Y_1 changes from a (10k,10k) matrix to a (10k,) array. Why does this happen, and how can I speed up the inversion? Each linalg.inv() call takes 15 seconds on my computer.

Answer by ev-br:

First and foremost, from a numerical-stability point of view, solving a linear system is almost always preferable to computing an explicit inverse.
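A minimal sketch of the question's update step with every inv(...) replaced by a linear solve, assuming the same A, A_1, A_2, c and delta_t as in the question (the helper names make_tridiagonal, step_spsolve and make_stepper are hypothetical). The inverse of a tridiagonal matrix is generally dense, so inv() has to produce up to M*M entries, while a solve only touches the three diagonals. The sketch keeps P and every right-hand side as 1-D arrays of shape (M,): for a scipy.sparse matrix, A @ P (like A * P) is a matrix-vector product, so a 1-D result is expected, whereas wrapping Y_0 in csc_matrix turns it into a (1, M) row. The splu variant is an addition not mentioned in the answer: if the matrices stay fixed between time steps, factorizing once and reusing the factors avoids redoing that work every step.

```python
import numpy as np
from scipy import sparse
from scipy.sparse import linalg as spla


def make_tridiagonal(size, rng):
    """Random tridiagonal matrix built directly in CSC form (no dense size x size array)."""
    main = rng.random(size)
    upper = rng.random(size - 1)
    lower = rng.random(size - 1)
    return sparse.diags([lower, main, upper], offsets=[-1, 0, 1], format="csc")


def step_spsolve(P, A, A_1, A_2, c, delta_t):
    """One step of the question's scheme, with each inv(...) replaced by spsolve."""
    I = sparse.identity(P.shape[0], format="csc")
    # With a 1-D P, A @ P is a matrix-vector product and returns shape (M,).
    Y_0 = P + delta_t * (A @ P)
    # Solve (I - c*A_1) Y_1 = Y_0 - c*A_1 P instead of forming the inverse.
    Y_1 = spla.spsolve((I - c * A_1).tocsc(), Y_0 - c * (A_1 @ P))
    Y_2 = spla.spsolve((I - c * A_2).tocsc(), Y_1 - c * (A_2 @ P))
    return Y_2


def make_stepper(A, A_1, A_2, c, delta_t):
    """Factorize (I - c*A_k) once with splu and reuse the factors on every step."""
    I = sparse.identity(A.shape[0], format="csc")
    lu_1 = spla.splu((I - c * A_1).tocsc())
    lu_2 = spla.splu((I - c * A_2).tocsc())

    def step(P):
        Y_0 = P + delta_t * (A @ P)
        Y_1 = lu_1.solve(Y_0 - c * (A_1 @ P))
        return lu_2.solve(Y_1 - c * (A_2 @ P))

    return step
```

For repeated time stepping, step = make_stepper(A, A_1, A_2, c, delta_t) followed by P = step(P) in a loop reuses the two factorizations, so each step costs only sparse triangular solves.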

Second, if you know your matrices are tridiagonal, you may want to look at the *_banded and *_tridiagonal families of functions in scipy.linalg, or even at the bare LAPACK wrappers in scipy.linalg.lapack, such as dgtsv: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.lapack.dgtsv.html
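A minimal sketch of the two tridiagonal routes mentioned above, assuming B is a square tridiagonal matrix such as I - c*A_1 from the question (either a scipy.sparse matrix or a dense array, since both provide .diagonal(k)). The helper names are hypothetical. solve_banded takes the matrix in LAPACK banded storage, where ab[u + i - j, j] == B[i, j], so for (l, u) = (1, 1) the superdiagonal sits in row 0 shifted right and the subdiagonal in row 2 shifted left. dgtsv takes the three diagonals directly and returns (du2, d, du, x, info).

```python
import numpy as np
from scipy.linalg import solve_banded, lapack


def tridiagonal_parts(B):
    """Return copies of the (lower, main, upper) diagonals of a tridiagonal matrix B."""
    return (np.array(B.diagonal(-1), dtype=float),
            np.array(B.diagonal(0), dtype=float),
            np.array(B.diagonal(1), dtype=float))


def solve_tridiagonal_banded(B, rhs):
    """Solve B x = rhs with scipy.linalg.solve_banded and (l, u) = (1, 1)."""
    lower, main, upper = tridiagonal_parts(B)
    n = main.shape[0]
    ab = np.zeros((3, n))
    ab[0, 1:] = upper   # superdiagonal, shifted right by one
    ab[1, :] = main     # main diagonal
    ab[2, :-1] = lower  # subdiagonal, shifted left by one
    return solve_banded((1, 1), ab, rhs)


def solve_tridiagonal_gtsv(B, rhs):
    """Solve B x = rhs with the raw LAPACK wrapper dgtsv (elimination with partial pivoting)."""
    lower, main, upper = tridiagonal_parts(B)
    # Pass the right-hand side as a 2-D (n, 1) array and take the single solution column.
    _, _, _, x, info = lapack.dgtsv(lower, main, upper, rhs.reshape(-1, 1))
    if info != 0:
        raise np.linalg.LinAlgError(f"dgtsv failed with info={info}")
    return x[:, 0]
```

Both routes do O(M) work for a tridiagonal system, so for M = 10000 a single solve should be far cheaper than building the dense inverse with sparse.linalg.inv.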