I am working on a triplet loss based model for text embedding.
Short description:
I have an online shop database, and I need to find the most suitable products when a user types a query into the search bar. I want a model that works better than plain string matching and can capture the user's intent. I defined a triplet network as follows: the input is (query text [anchor], the product the user views after searching [positive], a random product [negative]). I built an encoder based on a bi-LSTM and trained it so that the distance between anchor and positive is minimized and the distance between anchor and negative is maximized, using a triplet loss.
I tried to implement this network.
Reference: https://arxiv.org/pdf/2104.08558.pdf
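For context, this is roughly how I build the triplets from my search logs (a simplified sketch, not my exact pipeline; encode_text stands in for my tokenizer/vocabulary lookup):

import random
from torch.utils.data import Dataset

class TripletDataset(Dataset):
    """Yields (anchor, positive, negative) triplets from search logs."""
    def __init__(self, queries, viewed_products, all_products, encode_text):
        self.queries = queries              # query strings typed into the search bar
        self.viewed = viewed_products       # product viewed after each query (positive)
        self.all_products = all_products    # full product catalogue, sampled for negatives
        self.encode_text = encode_text      # str -> LongTensor of token ids (placeholder)

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        anchor = self.encode_text(self.queries[idx])
        positive = self.encode_text(self.viewed[idx])
        negative = self.encode_text(random.choice(self.all_products))
        return anchor, positive, negative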
My encoderNet
import torch
import torch.nn as nn

class encodeNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers,
                 bidirectional, dropout):
        super().__init__()
        # embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.directions = bidirectional
        # LSTM layer
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)
        self.fc1 = nn.Linear(hidden_dim * 2, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 512)
        self.dropout = nn.Dropout(p=0.3)
        self.batchnorm1 = nn.BatchNorm1d(1024)
        self.batchnorm2 = nn.BatchNorm1d(512)
        self.relu = nn.ReLU()
        self.P1 = nn.MaxPool1d(2, stride=2)
        self.act = nn.Sigmoid()  # defined but not used below

    def LM(self, text):
        # text: (batch, seq_len) token ids
        embedded = self.embedding(text)
        # output: (batch, seq_len, hidden_dim * 2); hidden: (n_layers * 2, batch, hidden_dim)
        output, (hidden, cell) = self.lstm(embedded)
        # concatenate the final forward and backward hidden states
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        hidden = self.fc1(hidden)
        hidden = self.batchnorm1(hidden)
        hidden = self.relu(hidden)
        hidden = self.fc2(hidden)
        hidden = self.batchnorm2(hidden)
        hidden = self.fc3(hidden)
        return hidden

    def forward(self, anchor, pos, neg):
        # encode all three inputs with the same (shared-weight) encoder
        anchor = self.LM(anchor)
        pos = self.LM(pos)
        neg = self.LM(neg)
        # max-pool each 512-d embedding down to 256-d
        anchor = self.P1(anchor)
        pos = self.P1(pos)
        neg = self.P1(neg)
        return anchor, pos, neg
As the loss function I used PyTorch's triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).
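For reference, the training step looks roughly like this (a simplified sketch: the hyperparameters are placeholders, and train_loader is assumed to be a DataLoader yielding padded anchor/positive/negative id batches):

import torch
import torch.nn as nn

# placeholder hyperparameters, not my exact settings
model = encodeNet(vocab_size=8572, embedding_dim=300, hidden_dim=256,
                  n_layers=2, bidirectional=True, dropout=0.3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

model.train()
for anchor_ids, pos_ids, neg_ids in train_loader:   # padded LongTensor batches
    optimizer.zero_grad()
    anchor_emb, pos_emb, neg_emb = model(anchor_ids, pos_ids, neg_ids)
    loss = triplet_loss(anchor_emb, pos_emb, neg_emb)
    loss.backward()
    optimizer.step()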
The result: on the training set the loss decreased very quickly to a small value, but on the validation set the loss was not meaningful at all; it just went up and down as if at random.
I trained the model with a vocabulary of 8,572 tokens and 81,822 training samples. Is the dataset too small?
Can you help me figure out what the issue in my approach is?
I suggest you use hard triplets, i.e. mine difficult negatives instead of sampling them at random. You can learn more about this in the FaceNet paper. I hope this helps.
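For example, you can mine the hardest negative inside each batch rather than pairing every anchor with a random product. A rough sketch (assuming each batch carries matched anchor/positive embedding pairs; the function name and shapes are illustrative, not taken from your code):

import torch
import torch.nn.functional as F

def batch_hard_negative_loss(anchor_emb, pos_emb, margin=1.0):
    """Triplet loss with in-batch hard-negative mining.

    anchor_emb, pos_emb: (batch, dim) tensors where row i of pos_emb is the
    product matching query i. Every other positive in the batch is treated
    as a candidate negative for anchor i.
    """
    dist = torch.cdist(anchor_emb, pos_emb, p=2)        # (batch, batch) pairwise distances
    pos_dist = dist.diagonal()                          # distance to the true product
    # mask the matching pair, then take the closest non-matching product
    self_mask = torch.eye(dist.size(0), dtype=torch.bool, device=dist.device)
    neg_dist = dist.masked_fill(self_mask, float('inf')).min(dim=1).values
    return F.relu(pos_dist - neg_dist + margin).mean()

Hard negatives keep the loss informative for much longer than random negatives, because randomly sampled products quickly become trivially far from the query and stop contributing gradient.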