I am building a siamese model using Lstm, I have trained and tested the model but I condn’t inference it on sigle sample
Here’s the model
class SiameseLstm(nn.Module):
def __init__(self):
super(SiameseLstm, self).__init__()
self.embedding_model = nn.Embedding(model_vocab_size, embedding_dim)
self.embedding_model.weight.data.copy_(embedding_matrix_model)
self.embedding_student = nn.Embedding(student_vocab_size, embedding_dim)
self.embedding_student.weight.data.copy_(embedding_matrix_student)
self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=128, num_layers=2, batch_first=True, bidirectional=True)
self.flat = nn.Flatten()
self.dropout = nn.Dropout(0.1)
self.dense = nn.Linear(544, 1)
self.activation = nn.Sigmoid()
def forward_model(self, input):
embedded = self.embedding_model(input)
out,(h, c) = self.lstm(embedded)
#hidden = torch.cat((h[-2,:,:], h[-1,:,:]), dim = 1)
output = out[:, -1, :]
output = self.flat(output)
return output
def forward_student(self, input):
embedded = self.embedding_student(input)
out,(h, c) = self.lstm(embedded)
#hidden = torch.cat((h[-2,:,:], h[-1,:,:]), dim = 1)
output = out[:, -1, :]
output = self.flat(output)
return output
def forward(self, inp1, inp2):
out1 = self.forward_model(inp1)
out2 = self.forward_student(inp2)
x3 = torch.subtract(out1, out2)
x3 = torch.multiply(x3, x3)
x1_ = torch.multiply(out1, out1)
x2_ = torch.multiply(out2, out2)
x4 = torch.subtract(x1_, x2_)
x5 = torch.cdist(out1, out2)
merged = torch.concatenate((x5, x4, x3), dim=-1)
merged = self.dropout(merged)
merged = self.dense(merged)
merged = self.activation(merged)
return merged
train function:
Model.train()
for i, (inp1, inp2, label) in enumerate(train_dataloader):
inp1 = inp1.cuda()
inp2 = inp2.cuda()
label = label.cuda()
optimizer.zero_grad()
output = Model(inp1, inp2)
#print('output', output)
#print('label', label)
loss = criterion(output, label)
acc = calculate_acc(output, label)
TrainLoss += loss.item()
TrainAcc += acc.item()
loss.backward()
optimizer.step()
And this is the error in inference:
runtimeerror: mat1 and mat2 shapes cannot be multiplied (1x513 and 544x1)
I tried to change the batch size and it worked but I want only one sample