I am trying to use a Convolutional Vision Transformer (CvT) as the encoder and a UNet as the decoder. Here is my model architecture:

import torch
import torch.nn as nn
from torchvision.models import resnet34, ResNet34_Weights
from segmentation_models_pytorch.base import SegmentationModel
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder  # import path may differ across segmentation_models_pytorch versions

# Define the CvT encoder part here (Example with ResNet backbone)
class CvTEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        # Define the output channels in the order of the encoder's forward pass
        # Assuming that these are the channels after each layer/block
        self.out_channels = [64, 64, 128, 256, 512]  # from top layer to bottom layer
        self.output_stride = 32

    def forward(self, x):
        # Forward pass through CvT encoder
        # You need to adjust this forward method to collect the feature maps from the intermediate layers.
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x1 = self.backbone.maxpool(x)

        x2 = self.backbone.layer1(x1)
        x3 = self.backbone.layer2(x2)
        x4 = self.backbone.layer3(x3)
        x5 = self.backbone.layer4(x4)

        features = [x2, x3, x4, x5]
        return features

class CvTUnet(SegmentationModel):
    def __init__(self):
        super().__init__()
        
        self.encoder = CvTEncoder()
        
        # The `encoder_channels` should be reversed because the UnetDecoder will use them in reverse order
        # It starts from the deepest encoder features to the shallowest
        reversed_encoder_channels = self.encoder.out_channels[::-1]

        # Define the decoder channels such that each decoder block expects the number of channels
        # that is the sum of the encoder's corresponding output and the previous decoder's output
        # Assuming that the first decoder block receives only the last encoder's output (512 channels)
        # and does not concatenate it with any previous decoder output
        self.decoder = UnetDecoder(
            encoder_channels=reversed_encoder_channels,
            decoder_channels=[512, 256, 128, 64, 32],  # This should be aligned with the output of the encoder
            n_blocks=len(reversed_encoder_channels),
            use_batchnorm=True,
            center=False  # Set to False if the last encoder block has downsampling
        )
        self.segmentation_head = nn.Conv2d(
            in_channels=32,  # This should match the last decoder_channels entry
            out_channels=1,  # Assuming binary segmentation
            kernel_size=3, 
            padding=1
        )
        self.classification_head = None
        self.initialize()

I believe this error is caused by a mismatch between the output shapes of the encoder and the input shapes the decoder expects, but I cannot spot where the mismatch is.
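To make the comparison concrete, here is a minimal shape probe I can run (the 1×3×256×256 dummy input is just an assumption; any size divisible by the output stride of 32 should work). It runs the encoder on a dummy batch and prints the channel counts of the returned feature maps next to the declared out_channels list:

encoder = CvTEncoder()
encoder.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)         # arbitrary test batch: (batch, channels, height, width)
    features = encoder(dummy)

for i, feat in enumerate(features):
    print(f"feature {i}: {tuple(feat.shape)}")  # (batch, channels, H, W) of each skip connection

print("declared out_channels:", encoder.out_channels)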

I tried using different shape parameters as well. I am new to this, so I have mostly been going by trial and error (a quick end-to-end check is sketched below).
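This is the kind of end-to-end check, again only a sketch with an assumed input size, that reproduces the failure: pushing the same dummy batch through the assembled model so the traceback shows in which decoder block the shapes first disagree.

model = CvTUnet()
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)   # assumed input size, divisible by output_stride = 32
    out = model(dummy)                     # if the channel lists don't line up, this is where it fails
print("output shape:", tuple(out.shape))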
