PyTorch CTC loss behaves strangely on my test data

I was trying to understand how CTC loss works, so I generated some "input data" and "ground truth data", fed them to the CTC loss function, and got strange results.
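
For reference, this is my understanding of the nn.CTCLoss call on a toy random batch (shapes as in the PyTorch docs; the sizes match my setup below):

import torch
import torch.nn as nn

T, N, C, S = 15, 2, 5, 7                         # time steps, batch size, classes, target length
log_probs = torch.randn(T, N, C).log_softmax(2)  # (T, N, C) log-probabilities
targets = torch.randint(0, C - 1, (N, S))        # (N, S) class indices, excluding the blank
input_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), S, dtype=torch.int32)
loss = nn.CTCLoss(blank=C - 1)(log_probs, targets, input_lengths, target_lengths)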

For example:

*"|" - is vocab[4] - blank char*

Try #1 "exelent guess":
2 words: "acb"  "acc" 
2 guesses: "a|ccc|b"  "aaa|ccc|cccc" 

loss: tensor([16.9247, 18.9556])

Try #2 "exelent match, no blank"
words: "b"  "b" 
guesses: "b"  "b" 

loss: tensor([22.0504, 22.0504])

Try #3 "absolute miss":
2 words: "bba"  "bbb" 
2 guesses: "a|ccc|b"  "aaa|ccc|cccc" 

loss: tensor([17.4538, 18.9426])

The loss is almost the same for all 3 tries. I was expecting a low loss on the first two tries and a big loss on the third.

What am I doing wrong?

The full code is here: https://colab.research.google.com/drive/1KXgTuNqg-uO-oDvdwDDc-3UbnO7BvWUI?usp=sharing#scrollTo=yFzLIj663DQr

And here:
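
(To make this self-contained, here are the imports plus a minimal version of the notebook's get_numbers helper; as used below, it just looks each character up in the vocab dict.)

import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_numbers(vocab_dict, word):
    # map each character to its vocab index
    return [vocab_dict[ch] for ch in word]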

vocab_test = ['a', 'b', 'c',' ', '|']
vocab_dict_test = {'a':1,'b':2,'c':3,' ':4, '|':5}

vocab_length = 5
batch_size = 2
label_len = 7
input_len = 15

def loss_check(word1, word2, guess1, guess2):

    # convert words into index tensors of the required size
    input_word1 = word1.ljust(label_len)
    guessed_word1 = guess1.ljust(input_len)

    input_word_numbers1 = get_numbers(vocab_dict_test, input_word1)
    guessed_word_numbers1 = get_numbers(vocab_dict_test, guessed_word1)

    input_word2 = word2.ljust(label_len)
    guessed_word2 = guess2.ljust(input_len)

    input_word_numbers2 = get_numbers(vocab_dict_test, input_word2)
    guessed_word_numbers2 = get_numbers(vocab_dict_test, guessed_word2)

    print('words:', '"' + word1 + '" / "' + word2 + '"', 'converted:', input_word_numbers1, input_word_numbers2)
    print('guesses:', '"' + guess1 + '" / "' + guess2 + '"', 'converted:', guessed_word_numbers1, guessed_word_numbers2)

    # length tensors for the loss function
    input_len_size = torch.IntTensor([input_len] * batch_size).to(device)
    label_len_size = torch.IntTensor([label_len] * batch_size).to(device)

    truth = torch.from_numpy(np.array([input_word_numbers1,input_word_numbers2])).float()
    #print(truth)
    #print(truth.size())

    logits = [torch.from_numpy(np.array(guessed_word_numbers1)), torch.from_numpy(np.array(guessed_word_numbers2))]
    # one-hot encode the guesses: vocab value v goes to class index v - 1,
    # so '|' (5) lands on index 4, which matches blank=4 below
    logits_converted = [[[1 if logits[j][h] == i + 1 else 0 for i in range(vocab_length)]
                         for h in range(input_len)]
                        for j in range(batch_size)]
    logits_converted = torch.from_numpy(np.array(logits_converted)).float()
    #print(logits_converted)
    log_softmax = nn.LogSoftmax(dim=2)
    logits_log = log_softmax(logits_converted)
    #print(logits_log)

    # (N, T, C) -> (T, N, C), the layout nn.CTCLoss expects
    logits_log_formatted = logits_log.transpose(1, 0)
    #print(logits_log_formatted)
    #print(logits_log_formatted.size())
    loss_fn = nn.CTCLoss(blank=4, zero_infinity=True, reduction='none').to(device)

    # loss_fn(log_probs, targets, input_lengths, target_lengths)
    loss = loss_fn(logits_log_formatted, truth, input_len_size, label_len_size)

    return loss


print(loss_check('acb', 'acc','a|ccc|b', 'aaa|ccc|cccc'))
print(loss_check('bba', 'bbb','a|ccc|b', 'aaa|ccc|cccc'))
print(loss_check('b', 'b','b', 'b'))
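
One thing I noticed while experimenting (not sure whether it is relevant): because the raw logits are only 0s and 1s, every frame's distribution after LogSoftmax is close to uniform. For a single one-hot frame:

print(torch.log_softmax(torch.tensor([1., 0., 0., 0., 0.]), dim=0))
# tensor([-0.9048, -1.9048, -1.9048, -1.9048, -1.9048])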