I want to log my training and validation metrics on the same plot as two lines of different colors. I used to do that using torch lightning 1.7.7 with the following code:
self.log_dict({'output_1 Loss': {'VALIDATION': 0}, 'output_2 Loss': {'VALIDATION': 0}})
After updating to 2.2.0+post0, I get the error:
ValueError: `self.log(output_1 Loss, {'VALIDATION': 0})` was called, but `dict` values cannot be logged
How should I edit my code?
Here's an example that you can use to reproduce the error:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
class DummyNet(nn.Module):
def __init__(self):
super(DummyNet, self).__init__()
self.linear = nn.Linear(10, 1)
self.linear
def forward(self, x):
return self.linear(x)
class DummyModelModule(pl.LightningModule):
def __init__(self, net: DummyNet):
super(DummyModelModule, self).__init__()
self.net = net
self.loss = torch.nn.MSELoss()
self.optimizer = torch.optim.SGD(net.parameters(), lr=1e-8, momentum=0.9)
def forward(self, x: torch.Tensor):
return self.net(x)
def configure_optimizers(self):
return self.optimizer
def training_step(self, batch, batch_id):
inputs, y = batch
y_pred = self.forward(inputs)
loss = self.loss(y_pred, y)
self.log('loss', loss, prog_bar=True)
# let's assume that the model had two outputs
self.log_dict({
'output_1 Loss': {'TRAINING': loss},
'output_2 Loss': {'TRAINING': loss * 2}}
)
return {'loss': loss}
def validation_step(self, batch, batch_id):
inputs, y = batch
y_pred = self.forward(inputs)
loss = self.loss(y_pred, y)
# let's assume that the model had two outputs
self.log_dict({
'output_1 Loss': {'VALIDATION': loss},
'output_2 Loss': {'VALIDATION': loss * 2}}
)
return {'loss': loss}
def test_step(self, batch, batch_id):
inputs, y = batch
y_pred = self.forward(inputs)
loss = self.loss(y_pred, y)
return {'loss': loss}
class DummyDataset(Dataset):
def __init__(self):
self.items = [
(
torch.ones((1, 10), dtype=torch.float32) * i,
torch.tensor([2 * i - 5], dtype=torch.float32) + torch.randint(-10, 10, (1, ))
)
for i in range(100)
]
def __len__(self):
return len(self.items)
def __getitem__(self, index: int):
return self.items[index]
class DummyDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.save_hyperparameters(logger=False)
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
def prepare_data(self):
pass
def setup(self, stage = None):
self.train_dataset = DummyDataset()
self.val_dataset = DummyDataset()
self.test_dataset = DummyDataset()
@property
def _get_data_loaders_common_kwargs(self):
return dict(
batch_size=8,
num_workers=1,
pin_memory=True,
persistent_workers=True,
prefetch_factor=1,
drop_last=False
)
def train_dataloader(self):
return DataLoader(self.train_dataset)
def val_dataloader(self):
return DataLoader(self.val_dataset)
def test_dataloader(self):
return DataLoader(self.test_dataset)
if __name__ == '__main__':
torch.multiprocessing.set_start_method("fork")
net = DummyNet()
model = DummyModelModule(net)
datamodule = DummyDataModule()
datamodule.setup()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, datamodule=datamodule)
Thank you!