How to log training and validation on the same plot in torch lightning 2.2.0

30 Views Asked by At

I want to log my training and validation metrics on the same plot as two lines of different colors. I used to do that using torch lightning 1.7.7 with the following code:

self.log_dict({'output_1 Loss': {'VALIDATION': 0}, 'output_2 Loss': {'VALIDATION': 0}})

After updating to 2.2.0+post0, I get the error:

ValueError: `self.log(output_1 Loss, {'VALIDATION': 0})` was called, but `dict` values cannot be logged

How should I edit my code?

Here's an example that you can use to reproduce the error:

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DummyNet(nn.Module):

    def __init__(self):
        super(DummyNet, self).__init__()
        self.linear = nn.Linear(10, 1)
        self.linear

    def forward(self, x):
        return self.linear(x)


class DummyModelModule(pl.LightningModule):

    def __init__(self, net: DummyNet):
        super(DummyModelModule, self).__init__()
        self.net = net
        self.loss = torch.nn.MSELoss()
        self.optimizer = torch.optim.SGD(net.parameters(), lr=1e-8, momentum=0.9)

    def forward(self, x: torch.Tensor):
        return self.net(x)

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_id):
        inputs, y = batch
        y_pred = self.forward(inputs)
        loss = self.loss(y_pred, y)
        self.log('loss', loss, prog_bar=True)
        # let's assume that the model had two outputs
        self.log_dict({
            'output_1 Loss': {'TRAINING': loss},
            'output_2 Loss': {'TRAINING': loss * 2}}
        )
        return {'loss': loss}

    def validation_step(self, batch, batch_id):
        inputs, y = batch
        y_pred = self.forward(inputs)
        loss = self.loss(y_pred, y)
        # let's assume that the model had two outputs
        self.log_dict({
            'output_1 Loss': {'VALIDATION': loss},
            'output_2 Loss': {'VALIDATION': loss * 2}}
        )
        return {'loss': loss}

    def test_step(self, batch, batch_id):
        inputs, y = batch
        y_pred = self.forward(inputs)
        loss = self.loss(y_pred, y)
        return {'loss': loss}


class DummyDataset(Dataset):

    def __init__(self):
        self.items = [
            (
                torch.ones((1, 10), dtype=torch.float32) * i,
                torch.tensor([2 * i - 5], dtype=torch.float32) + torch.randint(-10, 10, (1, ))
            )
            for i in range(100)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index: int):
        return self.items[index]


class DummyDataModule(pl.LightningDataModule):

    def __init__(self):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        pass

    def setup(self, stage = None):
        self.train_dataset = DummyDataset()
        self.val_dataset = DummyDataset()
        self.test_dataset = DummyDataset()

    @property
    def _get_data_loaders_common_kwargs(self):
        return dict(
            batch_size=8,
            num_workers=1,
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=1,
            drop_last=False
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        return DataLoader(self.val_dataset)

    def test_dataloader(self):
        return DataLoader(self.test_dataset)


if __name__ == '__main__':
    torch.multiprocessing.set_start_method("fork")

    net = DummyNet()
    model = DummyModelModule(net)
    datamodule = DummyDataModule()
    datamodule.setup()
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, datamodule=datamodule)

Thank you!

0

There are 0 best solutions below