The DeepLabV3+(with PyTorch) model using the Xception backbone was implemented by referring to the paper as much as possible, but the performance of the training using the VOC2012 dataset is too low.
One-hot-processed datasets for SegmentationClass of datasets were used for learning, and the following aspects occurred when the loss and miou graph were checked using WandB:
- It shows a lower value of validation loss than training, and it shows a flat pattern rather than a decreasing pattern from the first epoch until it stops.
- miou is similarly flat, and verification miou hardly changes.
The optimizer applied AdamW, and the scheduler used ReduceLROnPlateau (mode='max') for miou.
I've tried a lot of other codes, but the results were always the same, so I couldn't find an answer anymore, so I left a question here.
I'll write down the model code here, so please try it and let me know if you've found the answer.
Separble convolution:
class SeparableConv2d(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3,
stride=1, padding=0, dilation=1, bias=True, depthwise=False
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride=stride
self.padding = padding
self.dilation = dilation
self.depth = depthwise
if self.depth:
self.depthwise = nn.Sequential(
nn.Conv2d(self.in_channels, self.in_channels, self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=in_channels, bias=False),
nn.BatchNorm2d(self.in_channels),
nn.ReLU()
)
else:
self.depthwise = nn.Conv2d(self.in_channels, self.in_channels, self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=in_channels, bias=False)
self.pointwise=nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, bias=bias)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
Residual block:
class residualBlock(nn.Module):
def __init__(self, input, output, atrous:list, strides:list, relus:list,
residual=True, last_depth=True):
super(residualBlock, self).__init__()
self.input=input
self.output=output
self.residual = residual
self.atrous = atrous
self.strides = strides
self.relus = relus
self.last_depth = last_depth
if self.atrous == None:
atrous = [1]*3
elif isinstance(self.atrous, int):
atrous = [self.atrous]*3
def block(input, output, padding, dilation, stride=1, relu=False, depth=False):
if relu:
return nn.Sequential(
nn.ReLU(),
SeparableConv2d(input, output, dilation=dilation, stride=stride, bias=False),
)
else:
return SeparableConv2d(input, output, dilation=dilation, stride=stride, bias=False, depthwise=depth)
#block3
#get additional separable convolution instead of max pooling in original xception model
def residualblock(input, output):
return nn.Sequential(
nn.Conv2d(input, output, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(output),
nn.ReLU()
)
self.block1 = block(self.input, self.output, padding=atrous[0], dilation=atrous[0], stride=self.strides[0], relu=self.relus[0])
self.block2 = block(self.output, self.output, padding=atrous[1], dilation=atrous[1], stride=self.strides[1], relu=self.relus[1])
self.block3 = block(output, output, padding=atrous[2], dilation=atrous[2], stride=self.strides[2], relu=self.relus[2], depth = self.last_depth)
self.residualblock = residualblock(input, output)
def forward(self, x):
res = x
#block1
x = self.block1(x)
#block2
x = self.block2(x)
#block3
x = self.block3(x)
#add residual
if self.residual:
resblock = self.residualblock(res)
#resizing x to resblock
x = F.interpolate(x, size=resblock.size()[2:], mode='bilinear', align_corners=True)
x = x + resblock
else:
x = F.interpolate(x, size=res.size()[2:], mode='bilinear', align_corners=True)
x = x + res
return x
Backbone:
class Xception(nn.Module):
def __init__(self, nInputChannels=3, os=16):
super(Xception, self).__init__()
stride_list = None
self.os = os
self.input = nInputChannels
if self.os == 8:
stride_list = [2,1,1]
elif self.os == 16:
stride_list = [2,2,1]
def strideconv(input, output, checkstride):
if checkstride:
return nn.Sequential(
nn.Conv2d(input, output, kernel_size=3, stride=2, bias=False),
nn.ReLU()
)
else:
return nn.Sequential(
nn.Conv2d(input, output, kernel_size=3, bias=False),
nn.ReLU()
)
# Entry flow
self.entry_conv1 = strideconv(self.input, 32, True)
self.entry_conv2 = strideconv(32, 64, False)
self.entry_conv3 = residualBlock(64, 128, atrous=stride_list[0], strides=[1,1,2], relus=[False, True, False])
self.entry_conv4 = residualBlock(128, 256, atrous=stride_list[0], strides=[1,1,2], relus=[True, True, False])
self.entry_conv5 = residualBlock(256, 728, atrous=stride_list[0], strides=[1,1,2], relus=[True, True, False])
# Middle flow
self.mid01 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid02 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid03 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid04 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid05 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid06 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid07 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid08 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid09 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid10 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid11 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid12 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid13 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid14 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid15 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
self.mid16 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
# Exit flow
self.exit_residual = residualBlock(728,1024,stride_list[2], strides=[1,1,2], relus=[True, True, False])
self.exit_conv1 = SeparableConv2d(1024, 1536, kernel_size=3, stride=stride_list[2], bias=False)
self.exit_relu1 = nn.ReLU()
self.exit_conv2 = SeparableConv2d(1536, 1536, kernel_size=3, stride=stride_list[2], bias=False)
self.exit_relu2 = nn.ReLU()
self.exit_conv3 = SeparableConv2d(1536, 2048, kernel_size=3, stride=stride_list[2], bias=False)
def forward(self, x):
# Entry flow
entry_out1 = self.entry_conv1(x)
entry_out2 = self.entry_conv2(entry_out1)
entry_out3 = self.entry_conv3(entry_out2)
entry_out4 = self.entry_conv4(entry_out3)
entry_out5 = self.entry_conv5(entry_out4)
low_level_features = entry_out5
# Middle flow
mid_out01 = self.mid01(entry_out5)
mid_out02 = self.mid02(mid_out01)
mid_out03 = self.mid03(mid_out02)
mid_out04 = self.mid04(mid_out03)
mid_out05 = self.mid05(mid_out04)
mid_out06 = self.mid06(mid_out05)
mid_out07 = self.mid07(mid_out06)
mid_out08 = self.mid08(mid_out07)
mid_out09 = self.mid09(mid_out08)
mid_out10 = self.mid10(mid_out09)
mid_out11 = self.mid11(mid_out10)
mid_out12 = self.mid12(mid_out11)
mid_out13 = self.mid13(mid_out12)
mid_out14 = self.mid13(mid_out13)
mid_out15 = self.mid15(mid_out14)
mid_out16 = self.mid16(mid_out15)
# Exit flow
exit_out1 = self.exit_residual(mid_out16)
exit_out2 = self.exit_conv1(exit_out1)
exit_out3 = self.exit_conv2(exit_out2)
out = self.exit_conv3(exit_out3)
return out, low_level_features
ASPP:
class ASPPConv(nn.Module):
def __init__(self, in_channels, out_channels, dilation):
super(ASPPConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
conv = self.conv(x)
bn = self.bn(conv)
out = self.relu(bn)
return out
class ASPPPooling(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super(ASPPPooling, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
size = x.shape[-2:]
x = self.pool(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return F.interpolate(x, size=size, mode="bilinear", align_corners=True)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels = 256):
super(ASPP, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.atrous_rates = atrous_rates
modules = []
modules.append(
nn.Sequential(
nn.Conv2d(self.in_channels, self.out_channels, 1, bias=False),
nn.BatchNorm2d(self.out_channels),
nn.ReLU())
)
rates = self.atrous_rates
for i in range(len(rates)):
modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))
modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))
modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))
modules.append(ASPPPooling(self.in_channels, self.out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(len(self.convs) * self.out_channels, self.out_channels, 1, bias=False),
nn.BatchNorm2d(self.out_channels),
nn.ReLU(),
nn.Dropout(0.5),
)
def forward(self, x):
_res = []
for conv in self.convs:
_res.append(conv(x))
res = torch.cat(_res, dim=1)
return self.project(res)
Decoder:
class Decoder(nn.Module):
def __init__(self, in_channels, num_classes):
super(Decoder, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
#1x1 convolution of low-level features
self.conv1 = nn.Conv2d(self.in_channels, 48, kernel_size=1, stride=1)
self.bn1 = nn.BatchNorm2d(48)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.relu2 = nn.ReLU()
self.drop2 = nn.Dropout(0.5)
#3x3 convolution
self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(256)
self.relu3 = nn.ReLU()
self.drop3 = nn.Dropout(0.1)
self.conv4 = nn.Conv2d(256, self.num_classes, kernel_size=1, stride=1)
def forward(self, x, low_level_features):
#1x1 convolution of low-level features
low_level_features = self.conv1(low_level_features)
low_level_features = self.bn1(low_level_features)
low_level_features = self.relu1(low_level_features)
#concatenation
#resize low_level_features to x size
x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, low_level_features), dim=1)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.drop2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.drop3(x)
x = self.conv4(x)
return x
DeepLabV3+:
class DeepLabV3_plus(nn.Module):
def __init__(self, num_classes=1, shape=(512, 512),output_stride=16):
super(DeepLabV3_plus, self).__init__()
self.num_classes = num_classes
self.output_stride = output_stride
self.shape = shape
if self.output_stride == 16:
oslist = [6,12,18]
elif self.output_stride == 8:
oslist = [12,24,36]
#backbone
self.backbone = Xception(os=self.output_stride)
#ASPP
self.aspp = ASPP(2048, oslist)
#decoder
self.decoder = Decoder(728, num_classes=self.num_classes)
def forward(self, x):
x, low_level_features = self.backbone(x)
#ASPP
x = self.aspp(x)
#decoder
x = self.decoder(x, low_level_features)
#need resize to 512, 512
x = F.interpolate(x, size=self.shape, mode='bilinear', align_corners=False)
return x
train:
def train(model, train_loader, valid_loader, epoch, val_term):
model = model.to(DEVICE[0])
model.train()
class_weight = []
for i in range(n_classes):
class_weight.append(1) if i != 0 else class_weight.append(0)
for i in range(epoch):
ious = 0
epoch_loss = 0
if not model.training:
model.train()
for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data.float())
target = target.permute(0,3,1,2)
loss_output = loss(output.float(), target.float())
iou_score = iou(output.int(), target.int(), n_classes)
ious += iou_score
iter_loss = loss_output.item()
loss_output.backward()
epoch_loss += iter_loss
if batch_idx%(len(train_loader)//4) == 0:
wandb.log({"loss": iter_loss})
for j in range(n_classes):
wandb.log({f"iou_class{j}": iou_score[j]})
wandb.log({"miou": torch.nanmean(iou_score)})
epoch_ious = ious/len(train_loader)
epoch_miou = torch.nanmean(epoch_ious)
epoch_loss = epoch_loss/len(train_loader)
print(f"Epoch {i}/{epoch}\nLoss: {epoch_loss:.6f}\nclass IoU: {epoch_ious.numpy()}\nmIoU: {epoch_miou.numpy():.6f}")
if i % val_term == val_term-1 or i==0:
val_miou = validation(model, valid_loader)
torch.save(model.state_dict(), 'epoch_{}.pth'.format(i))
optimizer.step()
scheduler.step(val_miou)