Assuming I have an environment where an object can be at a position consisting of x,y with x=(0,3) and y=(0,3). I have images at all these positions and generate a dataset where all images are paired with all others. The images look like the following (without the x in the image):
I also have labels that indicate if the x-position is the same or not for a pair. I try to train a siamese network that is able to reliably predict those labels. However, My siamese network is not learning at all. Is a siamese network not a good choice for such a task? Or is my implementation faulty?
class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
# Define the architecture for each branch of the siamese network
self.cnn = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=10),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, kernel_size=7),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(128, 128, kernel_size=4),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, kernel_size=4),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 256),
)
def forward_one(self, x):
# Forward pass for one branch of the siamese network
output = self.cnn(x)
return output
def forward(self, input1, input2):
# Forward pass for both branches and compute the L1 distance
output1 = self.forward_one(input1)
output2 = self.forward_one(input2)
# L1 distance
distance = F.pairwise_distance(output1, output2)
return distance
train_dataset = SiameseDataset(image_folder='path/to/images', transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
siamese_net = SiameseNetwork()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
for batch in train_loader:
input1, input2, labels = batch
optimizer.zero_grad()
outputs = siamese_net(input1, input2)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')