How to include a model's parameter in my custom loss function


I am using PyTorch Lightning and have defined my model as follows:

import torch
import torch.nn as nn


class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)

        self.model_parameter = nn.Parameter(
            torch.rand(...)
        )
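The parameter can be reached from outside the model at all because an nn.Parameter assigned as a module attribute is registered with that module. A minimal sketch, using a hypothetical ToyModel (a plain nn.Module standing in for MyModel and MyBaseClass, with an arbitrary shape of 3):

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for MyModel; nn.Module replaces MyBaseClass."""

    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it with the module.
        self.model_parameter = nn.Parameter(torch.rand(3))  # shape (3,) is arbitrary


model = ToyModel()
# The parameter is listed by named_parameters(), so an optimizer built from
# model.parameters() will update it ...
registered = dict(model.named_parameters())
print(sorted(registered))
# ... and any code holding a reference to the model can read it as an attribute.
print(model.model_parameter.shape)
```

So any object that is handed the model instance, including a loss function, can read model.model_parameter.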

I also use a custom loss function:

import pytorch_lightning as pl
import torch


class MyCustomLoss(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, outputs, targets):
        loss = ...
        scalar_loss = torch.mean(loss)
        return scalar_loss

In my config file, I set the class_path as follows:

model:
    class_path: ...path_to_MyModel
    init_args:
        criterion:
            class_path: ...path_to_MyCustomLoss

However, I need to access model_parameter inside my custom loss function, because the loss is calculated from it. How can I make my model's parameters available to the loss function?

Best answer

You need to pass the model instance to your custom loss function. Once the loss function has access to the model instance, it can read model_parameter directly from it. Here's how:

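import torch
import torch.nn as nn


# A plain nn.Module is enough for a loss (a LightningModule also works).
class MyCustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    # Implement forward() instead of overriding __call__, so that
    # nn.Module.__call__ still runs its hooks before dispatching here.
    def forward(self, outputs, targets, model):
        # access the parameter defined on the model
        model_parameter = model.model_parameter
        # proceed to calculate the element-wise loss from outputs, targets and model_parameter...
        loss = ...
        return torch.mean(loss)
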
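A minimal runnable sketch of such a loss, written with plain torch.nn so it runs without Lightning. The names ParamAwareLoss, ToyModel and penalty_weight are hypothetical, and the formula (mean squared error plus an L2 penalty on model_parameter) is only an illustration; substitute your own calculation.

```python
import torch
import torch.nn as nn


class ParamAwareLoss(nn.Module):
    """Hypothetical loss that reads a parameter from the model passed at call time."""

    def __init__(self, penalty_weight: float = 0.1):
        super().__init__()
        self.penalty_weight = penalty_weight  # hypothetical weighting factor

    def forward(self, outputs, targets, model):
        # Read the parameter directly from the model instance.
        model_parameter = model.model_parameter
        data_term = torch.mean((outputs - targets) ** 2)
        # Illustrative use of the parameter: an L2 penalty on it.
        penalty = torch.sum(model_parameter ** 2)
        return data_term + self.penalty_weight * penalty


class ToyModel(nn.Module):
    """Hypothetical model exposing `model_parameter`."""

    def __init__(self):
        super().__init__()
        self.model_parameter = nn.Parameter(torch.ones(2))


criterion = ParamAwareLoss(penalty_weight=0.5)
model = ToyModel()
outputs = torch.tensor([1.0, 2.0])
targets = torch.tensor([1.0, 4.0])

loss = criterion(outputs, targets, model)
print(loss.item())  # 2.0 (MSE) + 0.5 * 2.0 (penalty) = 3.0

loss.backward()
# The gradient reaches the model's parameter through the loss.
print(model.model_parameter.grad)
```

Because the parameter is read from the model rather than copied, autograd tracks it, and the optimizer that owns the model's parameters will update it from this loss term.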

Then, when you call the loss function in your training step, pass the model instance (self) as the extra argument:

import torch
import torch.nn as nn


class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)
        self.model_parameter = nn.Parameter(torch.rand(...))
        self.criterion = MyCustomLoss(...)  # your custom loss function

    def forward(self, x):
        # define the forward pass
        ...

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        # compute the loss, passing the model itself so the criterion can read model_parameter
        loss = self.criterion(outputs, targets, self)
        return loss
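The question's config supplies the criterion through init_args rather than constructing it inside the model. A sketch of that variant, assuming MyModel accepts a criterion constructor argument; with LightningCLI the argument most likely needs a class type hint (here nn.Module) for a class_path entry to be accepted, which is an assumption to check against the LightningCLI documentation. To stay runnable without Lightning, ToyModel is a plain nn.Module and training_step is called by hand with a small SGD loop; ToyModel, ParamAwareLoss and the data are all hypothetical.

```python
import torch
import torch.nn as nn


class ParamAwareLoss(nn.Module):
    """Hypothetical criterion; reads `model_parameter` from the model it is given."""

    def forward(self, outputs, targets, model):
        data_term = torch.mean((outputs - targets) ** 2)
        return data_term + 0.1 * torch.sum(model.model_parameter ** 2)


class ToyModel(nn.Module):
    """Plain nn.Module stand-in for MyModel (a LightningModule in the real code)."""

    def __init__(self, criterion: nn.Module):
        super().__init__()
        # Assumption: this is the argument a class_path entry under init_args fills in.
        self.criterion = criterion
        self.linear = nn.Linear(1, 1)
        self.model_parameter = nn.Parameter(torch.tensor([2.0]))

    def forward(self, x):
        # Hypothetical use of the parameter in the forward pass as well.
        return self.linear(x) * self.model_parameter

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        # Pass the model itself so the criterion can read model_parameter.
        return self.criterion(outputs, targets, self)


torch.manual_seed(0)
model = ToyModel(criterion=ParamAwareLoss())
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
batch = (torch.randn(8, 1), torch.randn(8, 1))

before = model.model_parameter.detach().clone()
for step in range(5):
    optimizer.zero_grad()
    loss = model.training_step(batch, step)
    loss.backward()
    optimizer.step()

print(before.item(), model.model_parameter.item())
```

Passing the model at call time, instead of storing it on the criterion, keeps the criterion from holding a reference back to the module that owns it, so the model is never registered as a submodule of its own loss.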