Using PyTorch to compute the Hessian of a network: the value is quite different from the theory


I'm currently training a two-layer neural network and want to compute the Hessian w.r.t. the second-layer weights. However, the output of the code is quite different from the theory. Specifically, I'm training a two-layer linear network to solve the following optimization problem:

$$\min_{W_1, W_2} \frac{1}{2}\lVert Y - W_2 W_1 X\rVert_F^2$$

where $X, Y \in \mathbb{R}^{5\times 5}$ are data matrices and $W_1 \in \mathbb{R}^{100\times 5}$, $W_2 \in \mathbb{R}^{5\times 100}$ are the parameters (hidden width 100). I'm interested in the condition number of the Hessian of the loss w.r.t. $W_2$ when the iterates are near the minimum. I use two ways to compute it.
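For reference, differentiating the loss with respect to $W_2$ gives the gradient

$$\nabla_{W_2}\, \frac{1}{2}\lVert Y - W_2 W_1 X\rVert_F^2 = (W_2 W_1 X - Y)(W_1 X)^\top,$$

and the Hessian I am after is the derivative of this gradient with respect to $\operatorname{vec}(W_2)$.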

  1. Method one

I implemented a Hessian computation in PyTorch. I first train the model until the relative loss reaches $10^{-20}$, then compute the Hessian; it gives a condition number of around 174.

  2. Method two

Exactly at the global minimum, we can derive that the Hessian has the analytic form

$$H = (W_1 X X^\top W_1^\top) \otimes I_5$$

When I extract the value of $W_1$ from the model and compute the condition number of the $H$ above, the value is 5345, which is quite different.
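For completeness, here is a sketch of where this form comes from. Write $Z = W_1 X$ and use the column-major $\operatorname{vec}$ convention together with the identity $\operatorname{vec}(AB) = (B^\top \otimes I)\operatorname{vec}(A)$, so that

$$\frac{1}{2}\lVert Y - W_2 Z\rVert_F^2 = \frac{1}{2}\bigl\lVert \operatorname{vec}(Y) - (Z^\top \otimes I_5)\operatorname{vec}(W_2)\bigr\rVert_2^2,$$

and therefore the Hessian with respect to $\operatorname{vec}(W_2)$ is

$$(Z^\top \otimes I_5)^\top (Z^\top \otimes I_5) = (Z Z^\top) \otimes I_5 = (W_1 X X^\top W_1^\top) \otimes I_5.$$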

Below is the code.

import numpy as np
import torch
import torch.nn as nn
from scipy.stats import ortho_group


class NN_linear(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NN_linear, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
        # Initialize weights with custom numpy matrices; nn.Linear stores its
        # weight as (out_features, in_features), hence the transpose below
        np.random.seed(0)
        self.fc1.weight = nn.Parameter(torch.from_numpy(np.random.normal(0, 1, size=
                            (input_size, hidden_size)).T).double())
        self.fc2.weight = nn.Parameter(torch.from_numpy(np.random.normal(0, 1, size=
                            (hidden_size, output_size)).T).double())

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# data generation
kappa = 100
input = 5
output = 5
width = 100
N = 5

np.random.seed(123)
theta = np.random.normal(0, 1, size=(input, output)).astype(np.float64)
u = ortho_group.rvs(dim=input).astype(np.float64)
v = ortho_group.rvs(dim=output).astype(np.float64)
s = np.diag(np.linspace(start=1, stop=np.sqrt(kappa), endpoint=True, num=N))
x = u.dot(s.dot(v))
y = x.dot(theta)
dataset = torch.utils.data.TensorDataset(torch.tensor(x, dtype=torch.double), 
                                         torch.tensor(y, dtype=torch.double))
dataloader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=N, shuffle=False)
optimal_lr = 0.00023400934009340095
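
# Optional sanity check (not part of the original script): by construction the
# singular values of x are linspace(1, sqrt(kappa)), so its condition number
# should be sqrt(kappa) = 10.
print(f"cond(x) = {np.linalg.cond(x):.2f}")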

# training
net = NN_linear(input, width, output)
optimizer = torch.optim.SGD(net.parameters(), lr=optimal_lr)

# record the initial loss for the relative stopping criterion
xs0, ys0 = next(iter(dataloader))
loss_initial = (torch.norm(net(xs0) - ys0, p="fro") ** 2 / 2).item()

epoch = 0
while epoch == 0 or loss_v.detach().item() > 1e-20 * loss_initial:

    for xs, ys in dataloader:
        optimizer.zero_grad()
        pred = net(xs)
        loss_v = torch.norm(pred-ys, p="fro") ** 2 / 2
        loss_v.backward()
        optimizer.step()
        epoch += 1
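
# Optional progress report (not in the original): how far the relative loss
# dropped and how many SGD steps it took.
print(f"steps: {epoch}, relative loss: {loss_v.item() / loss_initial:.3e}")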

# Compute the gradient of the loss w.r.t. the second-layer weights
input_data, Y = next(iter(dataloader))
loss_v = torch.norm(net(input_data) - Y, p="fro") ** 2 / 2
grads = torch.autograd.grad(loss_v, net.fc2.weight, create_graph=True)
grads = grads[0]

# Flatten the gradient
grad_vector = grads.reshape(-1)
num_params = grad_vector.shape[0]

# Compute the Hessian matrix
hessian_matrix = torch.zeros((num_params, num_params)).double()
for i in range(num_params):
    optimizer.zero_grad()
    grad_elem = grad_vector[i]  # i-th gradient entry, kept in double precision
    hessian_row = torch.autograd.grad(grad_elem, net.fc2.weight, retain_graph=True)
    hessian_row = torch.cat([grad.detach().reshape(-1) for grad in hessian_row])
    hessian_matrix[i] = hessian_row
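
# Optional cross-check (not in the original code): torch.autograd.functional.hessian
# computes the same 500 x 500 matrix in one call. loss_wrt_w2 is a helper added
# here that evaluates the loss as a function of the flattened fc2 weight only.
def loss_wrt_w2(w2_flat):
    hidden = input_data @ net.fc1.weight.detach().T
    pred = hidden @ w2_flat.reshape(net.fc2.weight.shape).T
    return torch.norm(pred - Y, p="fro") ** 2 / 2

hessian_check = torch.autograd.functional.hessian(
    loss_wrt_w2, net.fc2.weight.detach().reshape(-1))
print(torch.allclose(hessian_matrix, hessian_check))  # expected: True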

# compute condition number of Hessian
uh, sh, vh = torch.linalg.svd(hessian_matrix)
r = torch.linalg.matrix_rank(hessian_matrix, hermitian=True).item()
print(f"condition number is {sh[0] / sh[r-1]}, rank of hessian is: {r}")
# condition number is 174

# compute Hessian using analytic form
w1 = net.fc1.weight.detach().numpy()
wx = w1.dot(x)
H = np.kron(wx.dot(wx.T), np.eye(output))
U, sh, Vh = np.linalg.svd(H, full_matrices=True)
r = np.linalg.matrix_rank(H)
print(f"condition number of by numpy is {sh[0] / sh[r-1]}_rank of hessian is: {r}")
# condition number is 5345

I expect the condition numbers given by the two approaches to match; however, this is not the case.
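For reference, here is a minimal sketch of a more fine-grained check (assuming hessian_matrix and H from the code above are still in scope): since the condition number depends only on the nonzero singular values, the two spectra can be compared directly.

sv_torch = np.linalg.svd(hessian_matrix.numpy(), compute_uv=False)
sv_kron = np.linalg.svd(H, compute_uv=False)
# both matrices should have rank 25 (= 5 * 5); compare the leading singular values
print(sv_torch[:25])
print(sv_kron[:25])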
