How to train DeepLabv3 on a custom dataset in PyTorch?


I am trying to do image segmentation and came across Google's DeepLabv3.

This is the reference to the paper: https://arxiv.org/abs/1706.05587

Chen, L.C., Papandreou, G., Schroff, F. and Adam, H., 2017. Rethinking atrous convolution for semantic image segmentation. arXiv preprint arXiv:1706.05587.

This architecture is trained to do segmentation of the 20+1 classes of the Pascal VOC 2012 Dataset (20 foreground and 1 background class).

PyTorch provides a DeepLabv3 model pre-trained on the Pascal VOC classes, and I would like to train the same architecture on Cityscapes, which has a different set of classes. What is an efficient way to do this?

For now this is the only code I wrote:

import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
model.eval()

There is 1 solution below

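  • Write a custom dataset class that inherits from torch.utils.data.Dataset and implements at least two methods, __len__ and __getitem__. A DataLoader then wraps this dataset to produce batches.
  • Replace the classifier head of the pretrained DeepLabv3 with one that has your own number of output channels (one per class).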
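from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision.models.segmentation.deeplabv3 import DeepLabHead


def custom_DeepLabv3(out_channel):
    # Load DeepLabv3 with a ResNet-101 backbone and pretrained weights
    model = deeplabv3_resnet101(pretrained=True, progress=True)
    # Replace the head; 2048 is the channel count of the ResNet-101 backbone output
    model.classifier = DeepLabHead(2048, out_channel)
    # The auxiliary head (model.aux_classifier) is left unchanged and keeps its
    # original 21 output channels, so compute the loss on output['out'] only
    # Set the model in training mode
    model.train()
    return model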
  • Train and evaluate the model.
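
The first and last steps above could look roughly like the sketch below. It is only an illustration under some assumptions: SegmentationDataset, train_one_epoch and TinySegNet are hypothetical names, the data is random tensors held in memory instead of real Cityscapes files, 19 classes and an ignore label of 255 are the values commonly used for Cityscapes, and TinySegNet is a tiny stand-in that returns a dict with an 'out' key the way torchvision segmentation models do, so the loop can be checked without downloading weights.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class SegmentationDataset(Dataset):
    """Hypothetical dataset holding (image, mask) tensor pairs in memory."""

    def __init__(self, images, masks):
        # images: float tensor [N, 3, H, W]; masks: long tensor [N, H, W] of class ids
        assert len(images) == len(masks)
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx]


def train_one_epoch(model, loader, optimizer, device="cpu", ignore_index=255):
    """Run one pass over the loader and return the mean batch loss."""
    # Pixels labelled ignore_index do not contribute to the loss (assumed 255 here)
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(images)["out"]  # per-pixel class scores, [B, C, H, W]
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


class TinySegNet(nn.Module):
    """Stand-in model with the same {'out': ...} output convention as DeepLabv3."""

    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        return {"out": self.conv(x)}


torch.manual_seed(0)
num_classes = 19  # assumption: number of classes commonly used for Cityscapes
images = torch.rand(4, 3, 16, 16)
masks = torch.randint(0, num_classes, (4, 16, 16))
loader = DataLoader(SegmentationDataset(images, masks), batch_size=2, shuffle=True)

model = TinySegNet(num_classes)  # swap in custom_DeepLabv3(num_classes) for real training
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
mean_loss = train_one_epoch(model, loader, optimizer)
print(f"mean loss: {mean_loss:.4f}")
```

For real training you would replace the random tensors with images and label masks read from disk in __getitem__, and pass the model returned by custom_DeepLabv3 instead of the stand-in. Keep the batch size above 1 in training mode, since DeepLabv3 contains batch-norm layers that fail on a single sample.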