I was trying to use Convolution Vision Transformer as Encoder and UNet as Decoder. Here is my model architecture
from torchvision.models import resnet34, ResNet34_Weights
# Define the CvT encoder part here (Example with ResNet backbone)
class CvTEncoder(nn.Module):
def __init__(self):
super().__init__()
self.backbone = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
# Define the output channels in the order of the encoder's forward pass
# Assuming that these are the channels after each layer/block
self.out_channels = [64, 64, 128, 256, 512] # from top layer to bottom layer
self.output_stride = 32
def forward(self, x):
# Forward pass through CvT encoder
# You need to adjust this forward method to collect the feature maps from the intermediate layers.
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x1 = self.backbone.maxpool(x)
x2 = self.backbone.layer1(x1)
x3 = self.backbone.layer2(x2)
x4 = self.backbone.layer3(x3)
x5 = self.backbone.layer4(x4)
features = [x2, x3, x4, x5]
return features
class CvTUnet(SegmentationModel):
def __init__(self):
super().__init__()
self.encoder = CvTEncoder()
# The `encoder_channels` should be reversed because the UnetDecoder will use them in reverse order
# It starts from the deepest encoder features to the shallowest
reversed_encoder_channels = self.encoder.out_channels[::-1]
# Define the decoder channels such that each decoder block expects the number of channels
# that is the sum of the encoder's corresponding output and the previous decoder's output
# Assuming that the first decoder block receives only the last encoder's output (512 channels)
# and does not concatenate it with any previous decoder output
self.decoder = UnetDecoder(
encoder_channels=reversed_encoder_channels,
decoder_channels=[512, 256, 128, 64, 32], # This should be aligned with the output of the encoder
n_blocks=len(reversed_encoder_channels),
use_batchnorm=True,
center=False # Set to False if the last encoder block has downsampling
)
self.segmentation_head = nn.Conv2d(
in_channels=32, # This should match the last decoder_channels entry
out_channels=1, # Assuming binary segmentation
kernel_size=3,
padding=1
)
self.classification_head = None
self.initialize()
I believe this error is because of mismatch in output shape of encoder and input shape of decoder but I see no error there.
I tried to use different shape parameters. Also I am new to this so I mostly just hit and try.