Siamese Network - Detect whether two images have the same x position


Assume I have an environment in which an object can be at a position (x, y), with x and y each in the range 0 to 3. I have an image for every one of these positions and generate a dataset in which every image is paired with every other image. The images look like the following (without the x in the image):

[Three example images, not included here]

I also have labels that indicate whether the two images in a pair have the same x position. I am trying to train a Siamese network that reliably predicts these labels. However, my Siamese network is not learning at all. Is a Siamese network a poor choice for this task, or is my implementation faulty?
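The pairing and labelling described above can be sketched with positions alone (no images). This is a minimal sketch that assumes "paired with all others" means unordered pairs without self-pairs, and that a label of 1.0 means "same x position"; the names `positions` and `pairs` are illustrative only.

```python
from itertools import combinations, product

# 4 x 4 grid of positions: x and y each range from 0 to 3
positions = list(product(range(4), range(4)))

# Every position paired with every other one (unordered, no self-pairs).
# Label 1.0 if the x coordinates match, otherwise 0.0.
pairs = [
    (a, b, 1.0 if a[0] == b[0] else 0.0)
    for a, b in combinations(positions, 2)
]

num_same = sum(int(label) for _, _, label in pairs)
print(len(pairs), num_same)  # 120 24
```

Under these assumptions only 24 of the 120 pairs (20%) share an x position, so a model that always answers "different" already reaches 80% accuracy. That baseline is worth keeping in mind when judging whether the network has learned anything.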

import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # Shared architecture used by both branches of the Siamese network.
        # Expects single-channel (grayscale) input; 256 * 6 * 6 matches 105 x 105 images.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=7),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, kernel_size=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=4),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 256),
        )

    def forward_one(self, x):
        # Forward pass through one branch (both branches share these weights)
        return self.cnn(x)

    def forward(self, input1, input2):
        # Embed both inputs with the same network
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)

        # L1 distance between the embeddings (the default p=2 would give the Euclidean distance)
        distance = F.pairwise_distance(output1, output2, p=1)

        return distance
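The `in_features` of `nn.Linear(256 * 6 * 6, 4096)` has to match the flattened output of the convolutional stack, and that depends on the input image size, which the question does not state. The helper below is a sketch that mirrors the layers above (stride-1, unpadded convolutions and `MaxPool2d(2)`) and traces the spatial size through them; the function names are illustrative only.

```python
def conv_out(size: int, kernel: int) -> int:
    # nn.Conv2d with stride 1 and no padding: out = in - kernel + 1
    return size - kernel + 1


def pool_out(size: int) -> int:
    # nn.MaxPool2d(2): kernel 2, stride 2, rounding down (ceil_mode=False)
    return size // 2


def final_spatial_size(size: int) -> int:
    # Same layer order as SiameseNetwork.cnn
    size = pool_out(conv_out(size, 10))
    size = pool_out(conv_out(size, 7))
    size = pool_out(conv_out(size, 4))
    return conv_out(size, 4)


for side in (100, 105):
    s = final_spatial_size(side)
    print(f"{side}x{side} input -> {s}x{s} feature map, flattened size {256 * s * s}")
# 100x100 input -> 5x5 feature map, flattened size 6400
# 105x105 input -> 6x6 feature map, flattened size 9216
```

So `256 * 6 * 6` corresponds to 105 x 105 inputs; any other size would raise a shape mismatch in the first `nn.Linear`. Also note that `Conv2d(1, ...)` expects one channel, whereas `ToTensor()` on an RGB image produces three, so colour images would need converting to grayscale first.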



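import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# SiameseDataset is my own dataset class (not shown); it yields (image1, image2, label) triples
train_dataset = SiameseDataset(image_folder='path/to/images', transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

siamese_net = SiameseNetwork()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    for batch in train_loader:
        input1, input2, labels = batch

        optimizer.zero_grad()
        outputs = siamese_net(input1, input2)

        # BCEWithLogitsLoss expects float targets with the same shape as the outputs
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

    # Loss of the last batch in this epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')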
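One thing the loop above does is pass a distance straight into `nn.BCEWithLogitsLoss`, which treats it as a logit. Because `F.pairwise_distance` never returns a negative value, `sigmoid(distance)` is always at least 0.5, so the model can never output a probability below 0.5 and pairs with target 0 can never get a loss below ln 2. The sketch below demonstrates this and shows one swapped-in alternative, a common form of the contrastive loss (Hadsell et al., 2006). It assumes label 1 means "same x position"; `contrastive_loss` and `margin` are illustrative names, not part of the original code.

```python
import torch
import torch.nn.functional as F

# Distances are >= 0, and sigmoid(d) >= 0.5 for every d >= 0
distances = torch.tensor([0.0, 0.5, 3.0])
probs = torch.sigmoid(distances)
print(probs)

# For target 0 the best reachable logit is 0, giving a loss of ln 2 (about 0.6931)
floor = F.binary_cross_entropy_with_logits(torch.tensor([0.0]), torch.tensor([0.0]))
print(floor.item())


def contrastive_loss(distance, label, margin=1.0):
    # Assumes label == 1 means "same x position" and label == 0 means "different".
    label = label.float()
    # Same-x pairs are pulled together: any distance is penalised.
    same_term = label * distance.pow(2)
    # Different-x pairs are pushed apart until they are at least `margin` away.
    diff_term = (1 - label) * F.relu(margin - distance).pow(2)
    return (same_term + diff_term).mean()
```

In the training loop this would replace `criterion(outputs, labels)` with `contrastive_loss(outputs, labels)`. At prediction time a pair could then be classified as "same x" when its distance falls below a threshold (for example `margin / 2`), which would need tuning. Another option that keeps `BCEWithLogitsLoss` is to turn the distance into a logit with a learnable scale and bias, so that small distances can map to high probabilities and large distances to low ones.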