I was trying to undertand how does CTC loss work and generated some "input data" and "ground truth data", send them to CTC loss function and got strange results.
For example:
*"|" - is vocab[4] - blank char*
Try #1 "exelent guess":
2 words: "acb" "acc"
2 guesses: "a|ccc|b" "aaa|ccc|cccc"
loss: tensor([16.9247, 18.9556])
Try #2 "exelent match, no blank"
words: "b" "b"
guesses: "b" "b"
loss tensor([22.0504, 22.0504])
Try #3 "absolute miss":
2 words: "bba" "bbb"
2 guesses: "a|ccc|b" "aaa|ccc|cccc"
loss tensor([17.4538, 18.9426])
Loss seems to be almoste the same for all 3 tries. I was expecting to get a low loss on first two tries and a big loss on third try.
What am i doing wrong?
And here:
vocab_test = ['a', 'b', 'c',' ', '|']
vocab_dict_test = {'a':1,'b':2,'c':3,' ':4, '|':5}
vocab_length = 5
batch_size = 2
label_len = 7
input_len = 15
def loss_check(word1, word2, guess1, guess2):
# convert words into digital torches of size needed
input_word1 = word1.ljust(label_len)
guessed_word1 = guess1.ljust(input_len)
input_word_numbers1 = get_numbers(vocab_dict_test, input_word1)
guessed_word_numbers1 = get_numbers(vocab_dict_test, guessed_word1)
input_word2 = word2.ljust(label_len)
guessed_word2 = guess2.ljust(input_len)
input_word_numbers2 = get_numbers(vocab_dict_test, input_word2)
guessed_word_numbers2 = get_numbers(vocab_dict_test, guessed_word2)
print('words:', '"' + word1 + '" / "' + word2 + '"', 'converted:', input_word_numbers1, input_word_numbers2)
print('guesses:', '"' + guess1 + '" / "' + guess2 + '"', 'converted:', guessed_word_numbers1, guessed_word_numbers2)
# special torches for loss func
input_len_size = torch.IntTensor([input_len] * batch_size).to(device)
label_len_size = torch.IntTensor([label_len] * batch_size).to(device)
truth = torch.from_numpy(np.array([input_word_numbers1,input_word_numbers2])).float()
#print(truth)
#print(truth.size())
logits = [torch.from_numpy(np.array(guessed_word_numbers1)),torch.from_numpy(np.array(guessed_word_numbers2))]
logits_converted = [[[1 if (logits[j][h] == i+1) else 0 for i in range(vocab_length)] for h in range(input_len)] for j in range(batch_size)]
logits_converted = torch.from_numpy(np.array(logits_converted)).float()
#print(logits_converted)
Softmax = nn.LogSoftmax(dim=2)
logits_log = Softmax(logits_converted)
#print(logits_log)
logits_log_formatted = logits_log.transpose(1,0)
#print(logits_log_formatted)
#print(logits_log_formatted.size())
loss_fn = nn.CTCLoss(blank=4, zero_infinity=True, reduction='none').to(device)
#log_probs, targets, input_lengths, target_lengths, self.blank,
loss = loss_fn(logits_log_formatted, truth, input_len_size, label_len_size)
return loss
print(loss_check('acb', 'acc','a|ccc|b', 'aaa|ccc|cccc'))
print(loss_check('bba', 'bbb','a|ccc|b', 'aaa|ccc|cccc'))
print(loss_check('b', 'b','b', 'b'))