How to modify parameters that require gradients in meta learning?


I have a neural network that is trained to output learning rates:

import torch 
import torch.nn as nn
import torch.optim as optim

criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Meta_Model(nn.Module):
    def __init__(self):
        super(Meta_Model, self).__init__()

        self.fc1 = nn.Linear(1,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,32)
        self.fc4 = nn.Linear(32,1)

        self.lky = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.lky(self.fc1(x))
        x = self.lky(self.fc2(x))
        x = self.lky(self.fc3(x))
        x = self.fc4(x)
        return x # x should be some learning rate

meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)

I have some inputs and a function I'm trying to learn:

input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn

I'm trying to update one trainable parameter to solve this function:

meta_model_epochs = 10
w_epochs = 5

for _ in range(meta_model_epochs):
    torch.manual_seed(42) # reset seed for reproducibility
    w1 = torch.rand(1, requires_grad=True) # reset **trainable weight**
    weight_opt = optim.SGD([w1], lr=1e-1) # reset weight optimizer
    meta_loss = 0 # reset meta loss
    for _ in range(w_epochs):
        predicted_tensor = w1 * input_tensor 
        loss = criterion(predicted_tensor, label_tensor)
        meta_loss += loss # add to meta loss
        meta_model_output = meta_model(loss.detach().unsqueeze(0)) # input to the meta model is the loss
        weight_opt.zero_grad()
        loss.backward(retain_graph=True) # get grads

        w1 = w1 - meta_model_output * w1.grad # step --> this is the issue
    
    meta_model_opt.zero_grad()
    meta_loss.backward()
    meta_model_opt.step()
    print('meta_loss', meta_loss.item())

So the setting is that the meta model should learn to output the optimal learning rate to update the trainable parameter w1 based on the current loss.

The issue is that I'm getting RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 2; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
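
As the hint suggests, anomaly detection can be enabled before the training loop to pinpoint the operation that fails; it only adds diagnostics and doesn't change the behaviour:

torch.autograd.set_detect_anomaly(True)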

I also tried replacing the update step with w1.data = w1.data - meta_model_output * w1.grad # step, which resolves the error, but then the meta model does not update (i.e., the loss stays the same), presumably because the assignment to .data happens outside of autograd, so no gradient flows back to the meta model's output.
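
For clarity, this is the kind of fully differentiable inner update I believe I need, written as a rough sketch (it assumes torch.autograd.grad with create_graph=True keeps the update inside the graph; device handling is omitted as above, and I haven't verified it in my setup):

# Sketch only: reuses criterion, meta_model, meta_model_opt, input_tensor, label_tensor, w_epochs from above
w1 = torch.rand(1, requires_grad=True)
meta_loss = 0
for _ in range(w_epochs):
    loss = criterion(w1 * input_tensor, label_tensor)
    meta_loss = meta_loss + loss
    lr = meta_model(loss.detach().unsqueeze(0))  # learned learning rate for this step
    grad_w1, = torch.autograd.grad(loss, w1, create_graph=True)  # differentiable gradient
    w1 = w1 - lr * grad_w1  # new (non-leaf) w1; the meta model's output stays in the graph

meta_model_opt.zero_grad()
meta_loss.backward()  # should reach the meta model's parameters through lr
meta_model_opt.step()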

Update 1:

Tried @VonC's idea of computing the updated value of w1 (using a clone of w1: w1_updated_value) and setting it as the data of w1:

w1_clone = w1.clone()
w1_clone = w1_clone - meta_model_output * w1.grad # step
w1.data = w1_clone

While this removes the error, it results in the same issue: the meta model does not update (i.e., the loss stays the same), presumably again because assigning to w1.data bypasses autograd.

Update 2: After a lot of reading on buffers and on updating leaf tensors, I got a solution that updates a whole network:

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import functools 
import math 


criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Meta_Model(nn.Module):
    def __init__(self):
        super(Meta_Model, self).__init__()

        self.fc1 = nn.Linear(1,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,32)
        self.fc4 = nn.Linear(32,1)

        self.lky = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.lky(self.fc1(x))
        x = self.lky(self.fc2(x))
        x = self.lky(self.fc3(x))
        x = self.fc4(x)
        return x # x should be some learning rate

meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
meta_model_epochs = 10
w_epochs = 5

############## the new solution
def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)

# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
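
# Quick illustration of the two helpers above (hypothetical names, not part of the training code):
#   rgetattr(net, 'fc1.weight')          is equivalent to  net.fc1.weight
#   rsetattr(net, 'fc1.weight', new_val) is equivalent to  net.fc1.weight = new_val
# i.e. getattr/setattr that accept dotted paths into nested modules.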

class MetaLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # Initialize weights and biases to zero
        # The line below is nearly identical to "self.weight = ...", but we get all of the added PyTorch features.
        self.register_buffer('weight', torch.zeros(out_features, in_features, requires_grad=True))
        if bias:
            self.register_buffer('bias', torch.zeros(out_features, requires_grad=True))
        else:
            self.bias = None
        
        # Fancy initialization from https://discuss.pytorch.org/t/how-are-layer-weights-and-biases-initialized-by-default/13073
        stdv = 2. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
            
    def forward(self, x):
        return F.linear(x, self.weight, self.bias)    
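
    # Sanity note (hypothetical usage): MetaLinear(1, 8)(torch.rand(5, 1)) has shape (5, 8),
    # just like nn.Linear(1, 8), except weight and bias are buffers rather than nn.Parameters.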

class Weights_network(nn.Module):
    def __init__(self):
        super(Weights_network, self).__init__()

        self.fc1 = MetaLinear(1,8, bias=True)
        self.fc2 = MetaLinear(8,8, bias=True)
        self.fc3 = MetaLinear(8,8, bias=True)
        self.fc4 = MetaLinear(8,1, bias=True)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        return x

class CustomeOptimizer():
    def __init__(self, model):
        self.model = model

    def zero_grad(self):
        # named_buffers() returns a fresh generator, so re-query it on every call
        for name, param in self.model.named_buffers():
            if param.grad is not None:
                param.grad.zero_()

    def step(self, meta_output_lr): 
        for name, param in self.model.named_buffers():
            clipping_value = 1e-2
            clipped_gradient = torch.clip(param.grad.detach().clone(), min = -clipping_value, max = clipping_value)
            new_param = (param.clone() - meta_output_lr.to(device) * clipped_gradient)            
            new_param.retain_grad()
            rsetattr(self.model, name, new_param)        
##############

for _ in range(meta_model_epochs): # meta training loop
    meta_loss = 0 # meta loss
    torch.manual_seed(42)
    w1 = Weights_network().to(device) # reset network
    # maintain grad
    for name, param in w1.named_buffers():
        param.retain_grad()     
    weight_opt = CustomeOptimizer(w1) # reset optimizer
    for _ in range(w_epochs): # weights training loop
        predicted_tensor = w1(input_tensor)
        loss = criterion(predicted_tensor, label_tensor)
        meta_loss += loss
        meta_model_output = meta_model(loss.detach().unsqueeze(0)).to(device)
        weight_opt.zero_grad()
        loss.backward(retain_graph=True)
        weight_opt.step(meta_model_output)

    meta_model_opt.zero_grad()
    meta_loss.backward()
    meta_model_opt.step()
    print('meta_loss', meta_loss.item())
>>> 
meta_loss 0.9203591346740723
meta_loss 0.4630056917667389
meta_loss 5.195590972900391
meta_loss 0.44623494148254395
meta_loss 0.45180386304855347
meta_loss 0.5693209767341614
meta_loss 0.43741166591644287
meta_loss 0.5331400632858276
meta_loss 0.5698808431625366
meta_loss 0.4637502133846283

Though: 1) I'm not sure why this works, since the update step seems similar to what I tried before, and 2) because I don't understand why it works, I can't figure out how to apply it to my initial setup, where I have a single parameter w1 rather than a whole network.

Update 3: Tried @VonC's suggestion of wrapping my tensor of weights as a buffer:

import torch 
import torch.nn as nn
import torch.optim as optim
import functools 

criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Meta_Model(nn.Module):
    def __init__(self):
        super(Meta_Model, self).__init__()

        self.fc1 = nn.Linear(1,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,32)
        self.fc4 = nn.Linear(32,1)

        self.lky = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.lky(self.fc1(x))
        x = self.lky(self.fc2(x))
        x = self.lky(self.fc3(x))
        x = self.fc4(x)
        return x # x should be some learning rate

class Weights(nn.Module):
    def __init__(self):
        super(Weights, self).__init__()
        self.register_buffer('w1', torch.rand(1, requires_grad=True))
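        # Note: register_buffer stores the tensor as self.w1; buffers are kept in state_dict
        # and moved by .to(device), but they are not returned by model.parameters().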

meta_model = Meta_Model().to(device)
meta_model_opt = optim.Adam(meta_model.parameters(), lr=1e-1)
input_tensor = torch.rand(1000,1) # some inputs
label_tensor = 2 * input_tensor # function to learn
meta_model_epochs = 10
w_epochs = 5

for _ in range(meta_model_epochs):
    torch.manual_seed(42)
    w1 = Weights().w1
    weight_opt = optim.SGD([w1], lr=0)  # Set learning rate to zero
    meta_loss = 0
    for _ in range(w_epochs):
        predicted_tensor = w1 * input_tensor
        loss = criterion(predicted_tensor, label_tensor)
        meta_loss += loss
        meta_model_output = meta_model(loss.detach().unsqueeze(0))
        weight_opt.zero_grad()
        loss.backward(retain_graph=True)
        # Compute updated value of w1
        w1_updated = w1.clone() - meta_model_output * w1.grad
        w1_updated.retain_grad()
        # Use w1_updated as input to the next iteration
        w1 = w1_updated

    meta_model_opt.zero_grad()
    meta_loss.backward()
    meta_model_opt.step()
    print('meta_loss', meta_loss.item())

But I'm still getting the inplace operation runtime error.
