Heat map of an image using ResNet

I need to make a heat map that locates the cat from one.jpg in two.jpg (which contains many cats), using ResNet features and F.conv2d.

I have the code, but the heat map comes out completely black and I don't really understand why. I cut the model down after its fourth child module to get the feature maps out. If you can help me understand what I'm doing wrong or what direction I should think in, I'd be very grateful, because I'm completely confused.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Pretrained ResNet-18 (weights enum replaces the deprecated pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Keep the first four children: conv1, bn1, relu, maxpool
model = nn.Sequential(*(list(model.children())[:4]))
model.eval()

def preprocess_image(image_path):
    # convert("RGB") guards against grayscale or RGBA files
    image = Image.open(image_path).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocess(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
    return image

one = preprocess_image('one.jpg')
two = preprocess_image('two.jpg')

# Inference only, so no autograd graph is needed
with torch.no_grad():
    one_features = model(one)
    two_features = model(two)
    heatmap = F.conv2d(one_features, two_features)

heatmap = heatmap.squeeze(0).numpy()

plt.imshow(heatmap, cmap='hot')
plt.axis('off')
plt.show()