soft_dice_score output size / target size assertion error


I want to develop a model for image segmentation. While running the code, I got the following error:


File c:\Users\MSB\AppData\Local\Programs\Python\Python39\lib\site-packages\segmentation_models_pytorch\losses\_functional.py:179, in soft_dice_score(output, target, smooth, eps, dims)
    172 def soft_dice_score(
    173     output: torch.Tensor,
    174     target: torch.Tensor,
   (...)
    177     dims=None,
    178 ) -> torch.Tensor:
--> 179     assert output.size() == target.size()
    180     if dims is not None:
    181         intersection = torch.sum(output * target, dim=dims)

AssertionError: 

I checked my code; both prob_mask and target have a size of [1, 4, 1024, 1024]:

prob_mask= torch.Size([1, 4, 1024, 1024])
target= torch.Size([1, 4, 1024, 1024])
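
If I understand the traceback correctly, the failure can be reproduced in isolation with something like this (a minimal sketch; I'm assuming smp.losses.DiceLoss with mode="multiclass", matching the comment I found below):

    import torch
    import segmentation_models_pytorch as smp

    loss_fn = smp.losses.DiceLoss(mode="multiclass", from_logits=True)

    logits = torch.randn(1, 4, 1024, 1024)         # prediction (N, C, H, W)
    target = torch.zeros(1, 4, 1024, 1024).long()  # one-hot target, same shape
    target[:, 0] = 1                               # pretend every pixel is class 0

    # multiclass mode flattens the target to (N, C*H*W) and one-hot encodes it
    # again, giving (N, C, C*H*W); the prediction is reshaped to (N, C, H*W),
    # so output.size() == target.size() fails
    loss = loss_fn(logits, target)                 # raises AssertionError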

Then I searched the internet and found the following comment:

I checked the source code. The issue is that mode="multiclass" expects a target tensor that is not one-hot encoded, which is then one-hot encoded internally. When feeding in an already one-hot encoded tensor, the two issues below occur:

  1. F.one_hot() requires a long tensor.
  2. After one-hot encoding the already one-hot encoded tensor, the shapes of the prediction and target tensors no longer match (causing the failed assertion).

There are two options. You can use a target tensor that is not one-hot encoded (as expected by the implementation; this still requires a long tensor). Or you can use a one-hot encoded target with mode="multilabel", but this only works if you apply an activation before calculating the loss (which requires setting from_logits=False).

Multilabel uses the same shape (N, C, H, W) as a one-hot encoded target tensor. The difference is that the C dimension can contain more than a single 1 (a multi-hot encoding: each pixel can represent more than a single class). A binary classification is done for each C-layer; to avoid the built-in activation you have to apply your own activation beforehand.
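
If I read this correctly, the two options would look roughly like the sketch below (my own attempt, assuming the loss is smp.losses.DiceLoss; logits_mask and mask are the (N, C, H, W) tensors from my code further down):

    import segmentation_models_pytorch as smp

    # Option 1: keep mode="multiclass" and feed class indices of shape
    # (N, H, W), obtained by collapsing the one-hot mask over the channel dim
    loss_fn = smp.losses.DiceLoss(mode="multiclass", from_logits=True)
    loss = loss_fn(logits_mask, mask.argmax(dim=1).long())

    # Option 2: keep the one-hot (N, C, H, W) mask and use mode="multilabel";
    # with from_logits=False the activation (softmax here) must be applied
    # manually before the loss
    loss_fn = smp.losses.DiceLoss(mode="multilabel", from_logits=False)
    loss = loss_fn(logits_mask.softmax(dim=1), mask.float())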

But I am not sure how this fits into my shared_step:

    def shared_step(self, batch, stage):
        image = batch["image"]
        mask = batch["mask"]

        # Print the shape of the input tensor before it is passed to the model
        print()
        print(f"Input Image Shape: {image.shape}")
        print(f"Input Mask Shape: {mask.shape}")
        print()
        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32.
        # The encoder and decoder are connected by skip connections, and the
        # encoder usually has 5 stages of downsampling by factor 2 (2^5 = 32).
        # E.g. for a 65x65 image the encoder features have sizes
        # 33, 17, 9, 5, 3, but the decoder upsamples them back to
        # 6, 12, 24, 48, 96, and concatenating these features fails.
        h, w = image.shape[2:]
        print(f"Original Image Dimensions: ({h}, {w})")

        assert h % 32 == 0 and w % 32 == 0

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        # print("****************mask.shape*******************")
        # print(mask.shape)
        # print()
        assert mask.ndim ==4

        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Apply softmax activation to the logits
        prob_mask = logits_mask.softmax(dim=1)

        # Ensure that the model output has the expected number of classes
        assert logits_mask.size(1) == self.num_classes

        # The target is the one-hot encoded mask loaded above
        target = mask

        print("befor loss")
        print("prob_mask=",prob_mask.shape)
        print('targrt=',target.shape)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(prob_mask, target)
        
        print("***************")
        # Let's compute metrics for some threshold:
        # first convert the logits to probabilities, then apply thresholding
        prob_mask = logits_mask.softmax(dim=1)
        pred_mask = (prob_mask > 0.5).float()
        print("pred_mask=", pred_mask.shape)
        
        # We will compute the IoU metric in two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false
        # negative and true negative "pixels" for each image and class;
        # these values will be aggregated at the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="multiclass")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }
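
I also suspect the smp.metrics.get_stats call has a related problem: with mode="multiclass" it expects class-index tensors and a num_classes argument rather than one-hot masks, as far as I can tell. A sketch of what I think it should look like (num_classes=4 assumed from my shapes):

    # Collapse prediction and one-hot mask to (N, H, W) class indices
    pred_indices = logits_mask.argmax(dim=1)
    target_indices = mask.argmax(dim=1)
    tp, fp, fn, tn = smp.metrics.get_stats(
        pred_indices, target_indices, mode="multiclass", num_classes=4
    )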