Pytorch: Compute gradient dx given Ax=b without solving the whole system


Given a linear system Ax = b with x = [x1; x2], where A and b are given and x1 is also already known. A is a square matrix. I want to compute the gradients of x1 with respect to A and b, i.e. dx1/dA and dx1/db. It seems that torch.autograd.backward() can only compute these gradients by solving the whole linear system again. I want to make this more efficient by not solving for x again, since x1 is already given. Is that feasible? If so, how can it be implemented in PyTorch?
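For reference, the calculus behind the backward pass of a linear solve is a standard result (not specific to PyTorch), obtained by differentiating Ax = b:

```latex
% Differentiating A x = b:
%   dA\,x + A\,dx = db \;\Rightarrow\; dx = A^{-1}\,(db - dA\,x),
% so for a loss L with upstream gradient v = \partial L / \partial x:
\frac{\partial L}{\partial b} = A^{-\top} v,
\qquad
\frac{\partial L}{\partial A} = -\bigl(A^{-\top} v\bigr)\, x^{\top}.
```

Since x is already available, evaluating these formulas needs only one solve with the transpose of A, not a fresh forward solve for x.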

For now, I only know how to get the gradient by solving the whole system:

# [A1 A2; A3 A4] @ [x1; x2] = [b1; b2]
import torch

# Block sizes: A1 is m1 x n1, A3 is m3 x n3, etc.
m1 = 10
n1 = 10
m3 = 5
n3 = n1
m2 = m1
n2 = m1 + m3 - n1
m4 = m3
n4 = n2

A = torch.rand((m1 + m3, n1 + n2), requires_grad=True)
A1 = A[:m1, :n1]
A2 = A[:m1, n1:]
A3 = A[m1:, :n1]
A4 = A[m1:, n1:]

# x1 is already given
x1 = torch.ones((n1, 1), requires_grad=True)
# x2 needs to be solved for
x2_gt = torch.rand((n2, 1))

b1 = (A1 @ x1 + A2 @ x2_gt).detach().requires_grad_(True)
b2 = (A3 @ x1 + A4 @ x2_gt).detach().requires_grad_(True)
b = torch.vstack((b1, b2))

x = torch.linalg.solve(A, b)            # solves the whole system again
x.backward(torch.ones(m1 + m3, 1))      # gradients land in A.grad, b1.grad, b2.grad
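Not an authoritative answer, but one standard way to avoid the second forward solve is to compute the vector-Jacobian product by hand: for x = A⁻¹b with upstream gradient v, the adjoint equations give grad_b = A⁻ᵀ v and grad_A = −(A⁻ᵀ v) xᵀ, which reuse the known x and need only a single solve with Aᵀ. A minimal self-contained sketch (the shapes and seed here are illustrative, not taken from the question):

```python
import torch

torch.manual_seed(0)
n = 15
# Diagonally dominant square matrix, so the solve is well conditioned
A = torch.rand(n, n) + n * torch.eye(n)
x = torch.rand(n, 1)          # pretend x is "already given"
b = A @ x

v = torch.ones(n, 1)          # upstream gradient dL/dx

# Adjoint equations: one solve with A^T, no forward solve for x
grad_b = torch.linalg.solve(A.T, v)   # A^{-T} v
grad_A = -grad_b @ x.T                # -(A^{-T} v) x^T

# Check against autograd, which re-solves the system internally
A_ = A.clone().requires_grad_(True)
b_ = b.clone().requires_grad_(True)
x_ = torch.linalg.solve(A_, b_)
x_.backward(v)

print(torch.allclose(grad_b, b_.grad, atol=1e-5))
print(torch.allclose(grad_A, A_.grad, atol=1e-4))
```

This computes the same gradients that `x.backward(...)` produces, but the cost is one triangular/transpose solve instead of a full forward solve plus backward.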
