I'm currently training a two-layer neural network and want to compute the Hessian of the loss with respect to the second-layer weights. However, the output of the code is quite different from the theory. Specifically, I'm training a two-layer linear network to solve the optimization problem
$$\min_{W_1, W_2} \frac{1}{2}\lVert Y - W_2 W_1 X\rVert_F^2$$
where $X, Y \in \mathbb{R}^{5\times 5}$ are data matrices and $W_1 \in \mathbb{R}^{100\times 5}$, $W_2 \in \mathbb{R}^{5\times 100}$ are the parameters (hidden width 100). I'm interested in the condition number of the Hessian of the loss w.r.t. $W_2$ when the iterates are near the minimum. I compute it in two ways.
- Method one
I implement a Hessian computation in PyTorch. I first train the model until the loss reaches $10^{-20}$ relative to its initial value, then compute the Hessian row by row; this gives a condition number of about 174.
- Method two
At the global minimum, the Hessian of the loss w.r.t. $W_2$ has the analytic form
$$H = (W_1 X X^\top W_1^\top) \otimes I_5$$
(a short derivation sketch follows this list). When I extract the value of $W_1$ from the trained model and compute the condition number of this $H$, I get 5345, which is quite different.
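For reference, here is a sketch of where the analytic form comes from, using column-major $\operatorname{vec}$ and the identity $\operatorname{vec}(ABC) = (C^\top \otimes A)\operatorname{vec}(B)$. With $A = W_1 X$,
$$\operatorname{vec}(W_2 A) = (A^\top \otimes I_5)\operatorname{vec}(W_2), \qquad \mathcal{L}(W_2) = \tfrac{1}{2}\lVert \operatorname{vec}(Y) - (A^\top \otimes I_5)\operatorname{vec}(W_2)\rVert_2^2,$$
so
$$\nabla^2_{\operatorname{vec}(W_2)} \mathcal{L} = (A^\top \otimes I_5)^\top (A^\top \otimes I_5) = (A A^\top) \otimes I_5 = (W_1 X X^\top W_1^\top) \otimes I_5.$$
Since the loss is quadratic in $W_2$ for fixed $W_1$, this Hessian is the same at every $W_2$, not only at the minimum.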
Below is the code.
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import ortho_group

class NN_linear(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NN_linear, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
        # Initialize weights with custom numpy matrices
        np.random.seed(0)
        self.fc1.weight = nn.Parameter(torch.from_numpy(np.random.normal(
            0, 1, size=(input_size, hidden_size)).T).double())
        self.fc2.weight = nn.Parameter(torch.from_numpy(np.random.normal(
            0, 1, size=(hidden_size, output_size)).T).double())

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
# data generation
kappa = 100
input = 5
output = 5
width = 100
N = 5
np.random.seed(123)
theta = np.random.normal(0, 1, size=(input, output)).astype(np.float64)
u = ortho_group.rvs(dim=input).astype(np.float64)
v = ortho_group.rvs(dim=output).astype(np.float64)
s = np.diag(np.linspace(start=1, stop=np.sqrt(kappa), endpoint=True, num=N))
x = u.dot(s.dot(v))
y = x.dot(theta)
dataset = torch.utils.data.TensorDataset(torch.tensor(x, dtype=torch.double),
torch.tensor(y, dtype=torch.double))
dataloader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=N, shuffle=False)
optimal_lr = 0.00023400934009340095
# training
net = NN_linear(input, width, output)
optimizer = torch.optim.SGD(net.parameters(), lr=optimal_lr)
# record the initial loss so the stopping criterion is relative to it
with torch.no_grad():
    xs0, ys0 = next(iter(dataloader))
    loss_initial = (torch.norm(net(xs0) - ys0, p="fro") ** 2 / 2).item()
epoch = 0
while epoch == 0 or loss_v.detach().item() > 1e-20 * loss_initial:
    for xs, ys in dataloader:
        optimizer.zero_grad()
        pred = net(xs)
        loss_v = torch.norm(pred - ys, p="fro") ** 2 / 2
        loss_v.backward()
        optimizer.step()
    epoch += 1
# Compute the gradients (recompute the loss with a graph so we can differentiate twice)
input_data, Y = next(iter(dataloader))
loss_v = torch.norm(net(input_data) - Y, p="fro") ** 2 / 2
grads = torch.autograd.grad(loss_v, net.fc2.weight, create_graph=True)[0]
# Flatten the gradient into a vector (row-major, following PyTorch's layout)
grad_vector = grads.reshape(-1)
num_params = grad_vector.shape[0]
# Compute the Hessian matrix row by row
hessian_matrix = torch.zeros((num_params, num_params)).double()
for i in range(num_params):
    optimizer.zero_grad()
    grad_elem = grad_vector[i].float()
    hessian_row = torch.autograd.grad(grad_elem, net.fc2.weight, retain_graph=True)
    hessian_row = torch.cat([grad.detach().reshape(-1) for grad in hessian_row])
    hessian_matrix[i] = hessian_row
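As an independent check on the loop above, the whole matrix can also be obtained in a single call with `torch.autograd.functional.hessian`, staying in float64 throughout. This is only a minimal sketch under the setup above (`net`, `input_data`, `Y`, `num_params`); the helper `loss_wrt_w2` is introduced here for illustration:

# One-call Hessian of the loss w.r.t. fc2.weight, everything in float64
w1_fixed = net.fc1.weight.detach()
def loss_wrt_w2(w2):
    # same loss as above, written explicitly as a function of W_2
    return torch.norm(input_data @ w1_fixed.T @ w2.T - Y, p="fro") ** 2 / 2
H_check = torch.autograd.functional.hessian(loss_wrt_w2, net.fc2.weight.detach())
H_check = H_check.reshape(num_params, num_params)  # row-major flattening of W_2

Comparing `H_check` against `hessian_matrix` with `torch.allclose` localizes any numerical discrepancy between the two computations.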
# compute condition number of Hessian
uh, sh, vh = torch.linalg.svd(hessian_matrix)
r = torch.linalg.matrix_rank(hessian_matrix, hermitian=True).item()
print(f"condition number is {sh[0] / sh[r-1]}, rank of hessian is {r}")
# condition number is 174
# compute Hessian using the analytic form
w1 = net.fc1.weight.detach().numpy()
wx = w1.dot(x)  # W_1 X, shape (width, N)
H = np.kron(wx.dot(wx.T), np.eye(output))
U, sh, Vh = np.linalg.svd(H, full_matrices=True)
r = np.linalg.matrix_rank(H)
print(f"condition number by numpy is {sh[0] / sh[r-1]}, rank of hessian is {r}")
# condition number is 5345
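One caveat when comparing the two matrices entrywise: PyTorch's `reshape(-1)` flattens the $5 \times 100$ weight row-major, so the autograd Hessian is ordered as $I_5 \otimes (W_1 X X^\top W_1^\top)$ rather than $(W_1 X X^\top W_1^\top) \otimes I_5$. The two orderings are permutation-similar and have identical singular values, so the condition number is unaffected, but a direct comparison must use the matching order. A minimal sketch, assuming `hessian_matrix`, `wx`, and `output` from the code above:

# analytic Hessian in PyTorch's row-major ordering (same spectrum as H)
H_row_major = np.kron(np.eye(output), wx.dot(wx.T))
print("max entrywise difference:", np.abs(hessian_matrix.numpy() - H_row_major).max())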
I expect the condition numbers given by the two approaches to match; however, this is not the case. What is causing the discrepancy?