I have a neural network that is trained to output learning rates:
import torch
import torch.nn as nn
import torch.optim as optim
criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Meta_Model(nn.Module):
def __init__(self):
super(Meta_Model, self).__init__()
self.fc1 = nn.Linear(1,32)
self.fc2 = nn.Linear(32,32)
self.fc3 = nn.Linear(32,32)
self.fc4 = nn.Linear(32,1)
self.lky = nn.LeakyReLU(0.1)
def forward(self, x):
x = self.lky(self.fc1(x))
x = self.lky(self.fc2(x))
x = self.lky(self.fc3(x))
x = self.fc4(x)
return x # x should be some learning rate
meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
I have some inputs and a function I'm trying to learn:
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
I'm trying to update one trainable parameter to solve this function:
meta_model_epochs = 10
w_epochs = 5
for _ in range(meta_model_epochs):
torch.manual_seed(42) # reset seed for reproducibility
w1 = torch.rand(1, requires_grad=True) # reset **trainable weight**
weight_opt = optim.SGD([w1], lr=1e-1) # reset weight optimizer
meta_loss = 0 # reset meta loss
for _ in range(w_epochs):
predicted_tensor = w1 * input_tensor
loss = criterion(predicted_tensor, label_tensor)
meta_loss += loss # add to meta loss
meta_model_output = meta_model(loss.detach().unsqueeze(0)) # input to the meta model is the loss
weight_opt.zero_grad()
loss.backward(retain_graph=True) # get grads
w1 = w1 - meta_model_output * w1.grad # step --> this is the issue
meta_model_opt.zero_grad()
meta_loss.backward()
meta_model_opt.step()
print('meta_loss', meta_loss.item())
So the setting is that the meta model should learn to output the optimal learning rate to update the trainable parameter w1
based on the current loss.
The issue is that I'm getting RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 2; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I also tried replacing the update step with w1.data = w1.data - meta_model_output * w1.grad # step
, which resolves the issue, but then the meta model is not updating (i.e., the loss stays the same)
Update 1:
Tried @VonC idea of computing the updated value of w1 (using a clone of w1: w1_updated_value) and setting it as the data of w1:
w1_clone = w1.clone()
w1_clone = w1_clone - meta_model_output * w1.grad # step
w1.data = w1_clone
While this removes the error, it results in the same issue of the meta model is not updating (i.e., the loss stays the same).
Update 2: After lots of readings on buffers and updating leaf tensors I got a solution that updates a whole network:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import functools
import math
criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Meta_Model(nn.Module):
def __init__(self):
super(Meta_Model, self).__init__()
self.fc1 = nn.Linear(1,32)
self.fc2 = nn.Linear(32,32)
self.fc3 = nn.Linear(32,32)
self.fc4 = nn.Linear(32,1)
self.lky = nn.LeakyReLU(0.1)
def forward(self, x):
x = self.lky(self.fc1(x))
x = self.lky(self.fc2(x))
x = self.lky(self.fc3(x))
x = self.fc4(x)
return x # x should be some learning rate
meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
meta_model_epochs = 10
w_epochs = 5
############## the new solution
def rsetattr(obj, attr, val):
pre, _, post = attr.rpartition('.')
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
def rgetattr(obj, attr, *args):
def _getattr(obj, attr):
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split('.'))
class MetaLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
# Initialize weights and biases to zero
# The line below is nearly identical to "self.weight = ...", but we get all of the added PyTorch features.
self.register_buffer('weight', torch.zeros(out_features, in_features, requires_grad=True))
if bias:
self.register_buffer('bias', torch.zeros(out_features, requires_grad=True))
else:
self.bias = None
# Fancy initialization from https://discuss.pytorch.org/t/how-are-layer-weights-and-biases-initialized-by-default/13073
stdv = 2. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, x):
return F.linear(x, self.weight, self.bias)
class Weights_network(nn.Module):
def __init__(self):
super(Weights_network, self).__init__()
self.fc1 = MetaLinear(1,8, bias=True)
self.fc2 = MetaLinear(8,8, bias=True)
self.fc3 = MetaLinear(8,8, bias=True)
self.fc4 = MetaLinear(8,1, bias=True)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = torch.relu(self.fc4(x))
return x
class CustomeOptimizer():
def __init__(self, model):
self.named_buffers = model.named_buffers()
self.model = model
def zero_grad(self):
for name, param in self.named_buffers:
if param.grad:
param.grad.zero_()
def step(self, meta_output_lr):
for name, param in self.model.named_buffers():
clipping_value = 1e-2
clipped_gradient = torch.clip(param.grad.detach().clone(), min = -clipping_value, max = clipping_value)
new_param = (param.clone() - meta_output_lr.to(device) * clipped_gradient)
new_param.retain_grad()
rsetattr(self.model, name, new_param)
##############
for _ in range(meta_model_epochs): # meta training loop
meta_loss = 0 # meta loss
torch.manual_seed(42)
w1 = Weights_network().to(device) # reset network
# maintain grad
for name, param in w1.named_buffers():
param.retain_grad()
weight_opt = CustomeOptimizer(w1) # reset optimizer
for _ in range(w_epochs): # weights training loop
predicted_tensor = w1(input_tensor)
loss = criterion(predicted_tensor, label_tensor)
meta_loss += loss
meta_model_output = meta_model(loss.detach().unsqueeze(0)).to(device)
weight_opt.zero_grad()
loss.backward(retain_graph=True)
weight_opt.step(meta_model_output)
meta_model_opt.zero_grad()
meta_loss.backward()
meta_model_opt.step()
print('meta_loss', meta_loss.item())
>>>
meta_loss 0.9203591346740723
meta_loss 0.4630056917667389
meta_loss 5.195590972900391
meta_loss 0.44623494148254395
meta_loss 0.45180386304855347
meta_loss 0.5693209767341614
meta_loss 0.43741166591644287
meta_loss 0.5331400632858276
meta_loss 0.5698808431625366
meta_loss 0.4637502133846283
Though, 1) I'm not sure why this works as the update step seems similar to what I tried. 2) Since I don't understand why this works I can't figure out how to use this on my initial setup where I have a simple parameter w1
that is not a whole network
Update 3: Tried @VonC suggestion on wrapping my tensor of weights as a buffer:
import torch
import torch.nn as nn
import torch.optim as optim
import functools
criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Meta_Model(nn.Module):
def __init__(self):
super(Meta_Model, self).__init__()
self.fc1 = nn.Linear(1,32)
self.fc2 = nn.Linear(32,32)
self.fc3 = nn.Linear(32,32)
self.fc4 = nn.Linear(32,1)
self.lky = nn.LeakyReLU(0.1)
def forward(self, x):
x = self.lky(self.fc1(x))
x = self.lky(self.fc2(x))
x = self.lky(self.fc3(x))
x = self.fc4(x)
return x # x should be some learning rate
class Weights(nn.Module):
def __init__(self):
super(Weights, self).__init__()
self.register_buffer('w1', torch.rand(1, requires_grad=True))
meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
meta_model_epochs = 10
w_epochs = 5
for _ in range(meta_model_epochs):
torch.manual_seed(42)
w1 = Weights().w1
weight_opt = optim.SGD([w1], lr=0) # Set learning rate to zero
meta_loss = 0
for _ in range(w_epochs):
predicted_tensor = w1 * input_tensor
loss = criterion(predicted_tensor, label_tensor)
meta_loss += loss
meta_model_output = meta_model(loss.detach().unsqueeze(0))
weight_opt.zero_grad()
loss.backward(retain_graph=True)
# Compute updated value of w1
w1_updated = w1.clone() - meta_model_output * w1.grad
w1_updated.retain_grad()
# Use w1_updated as input to the next iteration
w1 = w1_updated
meta_model_opt.zero_grad()
meta_loss.backward()
meta_model_opt.step()
print('meta_loss', meta_loss.item())
But I'm still getting the inplace operation runtime error.