It seems that torch.optim.SGD may have a bug when momentum is used. As I understand it, one can get SGD with momentum simply by passing a value for the momentum argument, such as
torch.optim.SGD(params, lr=0.01, momentum=0.9)
I suspect a potential bug because I tried to replicate the PyTorch Lightning tutorial on optimizers here. Rather than implementing the optimizers from scratch as in the tutorial, I used the classes from torch.optim directly. In particular, in cell [32], I replaced
SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))
by
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
The result appears to be much worse than indicated here. The result for Nesterov accelerated gradient, as implemented below, also appears worse than one would expect.
NAG_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9, nesterov=True))
Maybe my understanding is incorrect, but I can't really spot anything suspicious, and I would really appreciate it if someone else could check whether they get the same discrepancy.
As requested in the comments, I have added the full code below. Please note that nothing has changed except the few lines that compute SGD_points, SGDMom_points, and Adam_points.
from matplotlib import cm
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import numpy as np
def pathological_curve_loss(w1, w2):
    # Example of a pathological curvature. There are many more possible, feel free to experiment here!
    x1_loss = torch.tanh(w1) ** 2 + 0.01 * torch.abs(w1)
    x2_loss = torch.sigmoid(w2)
    return x1_loss + x2_loss
def plot_curve(
    curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d') if plot_3d else fig.gca()
    # ax = fig.gca(projection="3d") if plot_3d else fig.gca()
    x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
    y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
    x, y = torch.meshgrid([x, y])
    z = curve_fn(x, y)
    x, y, z = x.numpy(), y.numpy(), z.numpy()
    if plot_3d:
        ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
        ax.set_zlabel("loss")
    else:
        ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
    plt.title(title)
    ax.set_xlabel(r"$w_1$")
    ax.set_ylabel(r"$w_2$")
    plt.tight_layout()
    return ax
# sns.reset_orig()
# _ = plot_curve(pathological_curve_loss, plot_3d=True)
# plt.show()
from torch import nn
def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=[5, 5]):
    """
    Args:
        optimizer_func: Constructor of the optimizer to use. Should only take a parameter list
        curve_func: Loss function (e.g. pathological curvature)
        num_updates: Number of updates/steps to take when optimizing
        init: Initial values of parameters. Must be a list/tuple with two elements representing w_1 and w_2

    Returns:
        Numpy array of shape [num_updates, 3] with [t,:2] being the parameter values at step t, and [t,2] the loss at t.
    """
    weights = nn.Parameter(torch.FloatTensor(init), requires_grad=True)
    optim = optimizer_func([weights])

    list_points = []
    for _ in range(num_updates):
        loss = curve_func(weights[0], weights[1])
        list_points.append(torch.cat([weights.data.detach(), loss.unsqueeze(dim=0).detach()], dim=0))
        optim.zero_grad()
        loss.backward()
        optim.step()
    points = torch.stack(list_points, dim=0).numpy()
    return points
# BEGIN only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
SGD_points = train_curve(lambda params: torch.optim.SGD(params, lr=10))
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
Adam_points = train_curve(lambda params: torch.optim.Adam(params, lr=1))
# END only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(
    pathological_curve_loss,
    x_range=(-np.absolute(all_points[:, 0]).max(), np.absolute(all_points[:, 0]).max()),
    y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
    plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=3, label="Adam")
plt.legend()
plt.show()
Okay, it turns out that PyTorch implements momentum exactly as described in the Note here. If the tutorial's SGDMomentum update rule is modified to follow that formulation, the two implementations match precisely; a sketch of the modification is given below.
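Roughly, the difference is the following (this side-by-side is my own reconstruction, not code from the tutorial or from PyTorch internals; the tutorial's factor of (1 - momentum) on the incoming gradient is inferred from the 1/(1 - momentum) learning-rate relationship discussed next):

def sgd_momentum_step(param, grad, buf, lr, momentum, pytorch_style=True):
    # PyTorch-style momentum (Note in the torch.optim.SGD docs, dampening = 0):
    #     buf <- momentum * buf + grad
    # Tutorial-style momentum (exponential moving average of the gradient):
    #     buf <- momentum * buf + (1 - momentum) * grad
    # In both cases the parameter update is: param <- param - lr * buf
    scale = 1.0 if pytorch_style else (1.0 - momentum)
    buf = momentum * buf + scale * grad
    return param - lr * buf, buf

Dropping the (1 - momentum) factor in the tutorial's update (i.e. using the pytorch_style branch above) reproduces the behaviour of torch.optim.SGD with momentum.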
Alternatively, if we keep the original SGDMomentum implementation, the effective learning rate of the PyTorch implementation is a factor of 1/(1 - self.momentum) larger than in the tutorial presentation: PyTorch adds the raw gradient to its momentum buffer, while the tutorial scales it by (1 - momentum), so in steady state PyTorch's buffer is about 1/(1 - momentum) times larger. So if we scale the learning rate passed to torch.optim.SGD down by a factor of (1 - momentum), it matches the curve in the tutorial precisely as well.
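Concretely (my reconstruction of the changed line, since with momentum=0.9 the rescaling factor is 1 - 0.9 = 0.1):

# lr = 10 * (1 - 0.9) = 1 gives the same effective step size as lr=10 in the tutorial's SGDMomentum
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10 * (1 - 0.9), momentum=0.9))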