PyTorch keypoint detector not learning


I'm making a keypoint detector based on the MPII dataset, and for some reason, no matter what I do, my model doesn't learn and always outputs the same values for all pixels on all images.

To start, I convert the image to grayscale, normalise the input image to values between 0 and 1, generate a bounding box to crop around the person (to reduce the image size without losing too much data), shift the joint coordinates into the cropped image, and then rescale the image to 96x96 while also scaling the coordinates to the new size.
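In code, that step looks roughly like this (a simplified sketch; the bbox format and the variable names here are just illustrative, using OpenCV):

import cv2
import numpy as np

def preprocess(image, joints, bbox, new_size=96):
    # image: HxWx3 BGR array; joints: (16, 2) array of (x, y) pixel coords, -1 = missing;
    # bbox: (x_min, y_min, x_max, y_max) person box -- the format is an assumption
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0   # grayscale, [0, 1]
    x_min, y_min, x_max, y_max = bbox
    crop = gray[y_min:y_max, x_min:x_max]
    joints = joints.astype(np.float32).copy()
    visible = joints[:, 0] >= 0
    joints[visible] -= [x_min, y_min]                   # shift coords into the crop
    h, w = crop.shape
    crop = cv2.resize(crop, (new_size, new_size))       # rescale image to 96x96
    joints[visible] *= [new_size / w, new_size / h]     # rescale coords to match
    return crop, joints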

This is one of the images that gets passed into the network later on; I've drawn the target keypoints onto it to show that the normalisation and coordinate scaling are correct.

For the target tensor, I create a heatmap for each joint coordinate pair using this code:

import torch

def create_gaussian_map(center, output_shape, sigma):
    # center: (x, y) joint coordinate, -1 marks a missing joint
    height, width = output_shape
    x = torch.arange(0, width, 1, dtype=torch.float32, device=center.device)
    y = torch.arange(0, height, 1, dtype=torch.float32, device=center.device)
    x, y = torch.meshgrid(x, y, indexing='ij')
    x0, y0 = center[0], center[1]
    if x0 == -1 or y0 == -1:
        # missing joint: return an all-zero heatmap
        return torch.zeros((width, height), dtype=torch.float32, device=center.device)
    heatmap = torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    return heatmap

Then I add each heatmap to a tensor, and at the same time use the freshly generated heatmap to create a weight map, where all values where the keypoint exists in the heatmap are set to 1 and the background is set to 0, per heatmap (a sketch of create_weights follows the loop below):

for i in range(batch_size):
    for j in range(16):
        joint_center = regression_target[i, j, :]   # (x, y) for joint j of sample i
        joint_map = create_gaussian_map(joint_center, (New_width, New_height), sigma=3)
        weight = create_weights(joint_map)           # 1 on the keypoint blob, 0 on the background
        weight_ten[i, j, :, :] = weight
        joint_mapss[i, j, :, :] = joint_map
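create_weights just follows the description above; a minimal version of it (the threshold value here is arbitrary, not from my actual code) looks like:

def create_weights(joint_map, threshold=0.01):
    # 1 wherever the Gaussian blob is present, 0 everywhere else (background)
    return (joint_map > threshold).float()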

So I get heatmaps and weights of shape (batch_size, N_joints, width, height), which I then reshape to (batch_size, N_joints * width * height) for the loss calculation.
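That is, something along the lines of:

# flatten so each sample becomes one long vector of per-joint heatmap values
joint_mapss = joint_mapss.reshape(batch_size, 16 * New_width * New_height)
weight_ten = weight_ten.reshape(batch_size, 16 * New_width * New_height)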

Here is the Model:

import torch
import torch.nn as nn

class HourglassFCN(nn.Module):
    def __init__(self, in_channels=1, out_channels=16):
        super(HourglassFCN, self).__init__()
        input_height, input_width = New_height, New_width
        output_height, output_width = input_height, input_width
        n = 32 * 6
        nClasses = 16
        Block1 = 64
        Block2 = 128
        self.conv1 = nn.Conv2d(1, Block1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(Block1, Block1, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(Block1, Block2, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(Block2, Block2, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(Block2, Block2, kernel_size=3, padding=1)   # defined but not used in forward
        self.conv6 = nn.Conv2d(Block2, Block2, kernel_size=3, padding=1)   # defined but not used in forward
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck1 = nn.Conv2d(Block2, n, kernel_size=(input_height//4, input_width//4), padding='same')
        self.bottleneck2 = nn.Conv2d(n, n, kernel_size=1, padding='same')
        self.convT = nn.ConvTranspose2d(n, nClasses, kernel_size=4, stride=4, padding=0)
        self.conv_last = nn.Conv2d(n, nClasses, kernel_size=1, padding=0)  # defined but not used in forward
 
    def forward(self, x):
        x = nn.RReLU()(self.conv1(x))
        x = nn.RReLU()(self.conv2(x))
        x = self.pool1(x)
        x = nn.RReLU()(self.conv3(x))
        x = nn.RReLU()(self.conv4(x))
        x = self.pool2(x)
        x = nn.RReLU()(self.bottleneck1(x))
        x = nn.RReLU()(self.bottleneck2(x))
        x = self.convT(x)
        x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
        return x
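For reference, with New_height = New_width = 96 (as set during preprocessing), the flattened output works out to 16 * 96 * 96 = 147456 values per sample:

New_height, New_width = 96, 96          # as used during preprocessing above
model = HourglassFCN()
dummy = torch.randn(2, 1, New_height, New_width)
print(model(dummy).shape)               # torch.Size([2, 147456]), i.e. (batch, 16 * 96 * 96)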

For the loss I'm using MSELoss, and since the PyTorch version doesn't accept weights, I set reduction='none', compute the element-wise loss, multiply it by the weights, and return the mean:

gtconst = gt_heatmap * 10                                    # scale the target heatmaps by a constant
loss = self.regression_loss_fn(regression_output, gtconst)   # element-wise MSE (reduction='none')
return torch.mean(loss * weights)                            # weight each element, then average
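with the loss function set up as:

self.regression_loss_fn = nn.MSELoss(reduction='none')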

And thus I end up with an output during training that is the same for all inputs across all epochs. The Pred_re image is the model's output during training; the rest are self-explanatory.

No matter what I try, I can never get the network to learn; the output always looks like the image above, or at least very close to it.

If I use the weights differently, or none at all, I get an insanely low loss of around 0.0003 on the first batch. Using the method above, I get a loss of 0.3 on the first batch, and it never decreases. My validation values are also around 0.02 with any method.

I've tried different networks (both simpler and more complex), different learning rates, different activation functions (e.g. sigmoid), different loss functions, and using the weights differently: instead of 1 where the keypoint exists and 0 in the background, setting every value in a given heatmap to 1 if that joint is recorded and 0 if it's missing.

Now I'm completely lost. I've probably messed up something basic, but I've spent too much time on this one problem and read too many different articles, so I can't wrap my head around it anymore. I'd appreciate it if someone could point me in the right direction.
