Can the prefix beam search commonly used with CTC in speech recognition be implemented in this simpler way?

I have recently been learning about speech recognition, and I understand that the idea of prefix beam search is to merge paths that collapse to the same prefix, such as [1,1,_] and [_,1,_] (here _ denotes the blank symbol).
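The "same prefix" relation is the CTC collapsing rule: merge consecutive repeated symbols, then drop blanks. Here is a minimal sketch. The function name collapse and the use of the string "_" as the blank are illustrative choices, not taken from any particular library:

```python
def collapse(path, blank="_"):
    """CTC collapsing: merge consecutive repeats, then remove blanks."""
    out = []
    prev = None
    for symbol in path:
        # keep a symbol only if it is not blank and differs from the previous frame
        if symbol != blank and symbol != prev:
            out.append(symbol)
        prev = symbol
    return tuple(out)


print(collapse([1, 1, "_"]))   # (1,)
print(collapse(["_", 1, "_"]))  # (1,)
print(collapse([1, "_", 1]))   # (1, 1): a blank separates genuine repeats
```

The last case explains why a trailing blank matters during decoding. [1] and [1,_] collapse to the same prefix, but appending 1 to each of them gives different results.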

Based on this understanding, I implemented my own version. Simplified, it looks like this:

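import numpy as np
from collections import defaultdict

ninf = float("-inf")

def norm_prefix(path, blank):
    # Hypothetical reconstruction of the helper described below: collapse
    # repeats and drop blanks, but keep one trailing blank, so that
    # [1,1,_,2] ==> (1, 2) and [1,1,_] ==> (1, _). Returns a hashable tuple.
    out = []
    prev = None
    for k in path:
        if k != prev and k != blank:
            out.append(k)
        prev = k
    if path and path[-1] == blank:
        out.append(blank)
    return tuple(out)

def prefix_beam_search(y, beam_size, blank):
    seq_len, n_class = y.shape
    logY = np.log(y)
    beam = [((), 0.0)]

    for t in range(seq_len):
        # extend every beam entry by every symbol, blank included
        buff = []
        for prefix, p in beam:
            for i in range(n_class):
                new_prefix = list(prefix) + [i]
                new_p = p + logY[t][i]
                buff.append((new_prefix, new_p))

        # merge the paths with the same normalized prefix
        new_beam = defaultdict(lambda: ninf)
        for prefix, p in buff:
            # 'norm_prefix' simplifies the path, [1,1,_,2] ==> [1,2]
            # However, the ending 'blank' is retained, [1,1,_] ==> [1,_]
            prefix = norm_prefix(prefix, blank)
            new_beam[prefix] = np.logaddexp(new_beam[prefix], p)

        # choose the best paths (each normalized prefix is ranked on its own)
        new_beam = sorted(new_beam.items(), key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    return beam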

However, most of the versions I found online (which follow the paper) look like this:

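import math
from collections import defaultdict

import numpy as np

ninf = float("-inf")

def logsumexp(*args):
    # stable log(sum(exp(a))); returns -inf when every term is -inf
    if all(a == ninf for a in args):
        return ninf
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

def _prefix_beam_decode(y, beam_size, blank):
    T, V = y.shape
    log_y = np.log(y)
    # each prefix keeps two log scores: (ends in blank, ends in non-blank)
    beam = [(tuple(), (0.0, ninf))]

    for t in range(T):
        new_beam = defaultdict(lambda: (ninf, ninf))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):
                p = log_y[t, i]
                if i == blank:
                    # a blank never changes the prefix; it only adds to p_b
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                end_t = prefix[-1] if prefix else None
                new_prefix = prefix + (i,)
                new_p_b, new_p_nb = new_beam[new_prefix]
                if i != end_t:
                    new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                else:
                    # a repeated symbol extends the prefix only after a blank
                    new_p_nb = logsumexp(new_p_nb, p_b + p)
                new_beam[new_prefix] = (new_p_b, new_p_nb)
                if i == end_t:
                    # without a blank in between, the repeat collapses
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_nb = logsumexp(new_p_nb, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)

        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]
    return beam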

The two versions give different results, and mine tends to return longer strings. There are two main points I don't understand:

  1. Is there anything in my version that I have not thought through?
  2. The common version always generates a new prefix with new_prefix = prefix + (i,), regardless of whether the last symbol of the previous prefix is the same as the new symbol. For example, if the old prefix is [a,a,b] and a new character b is added, both [a,a,b] and [a,a,b,b] are kept. What is the purpose of this? And does it cause double counting?
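One way to check question 2 empirically is to compare the decoder's prefix scores against a brute-force sum over every alignment on a tiny input. The sketch below restates the update rules of the common version in plain probabilities rather than log space. This keeps it readable but makes it suitable only for short inputs. The beam is wide enough that nothing is pruned. The names prefix_beam_probs and brute_force, and blank index 0, are assumptions made for illustration:

```python
import itertools
from collections import defaultdict

import numpy as np


def prefix_beam_probs(y, beam_size, blank=0):
    """Common prefix beam search, restated in plain (non-log) probabilities.

    Each prefix keeps [p_b, p_nb]: mass of alignments ending in blank / non-blank.
    """
    T, V = y.shape
    beam = {(): [1.0, 0.0]}
    for t in range(T):
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beam.items():
            last = prefix[-1] if prefix else None
            for i in range(V):
                p = y[t, i]
                if i == blank:
                    nxt[prefix][0] += (p_b + p_nb) * p      # prefix unchanged
                elif i == last:
                    nxt[prefix][1] += p_nb * p              # "b" then "b" collapses
                    nxt[prefix + (i,)][1] += p_b * p        # "b", "_", "b" extends
                else:
                    nxt[prefix + (i,)][1] += (p_b + p_nb) * p
        ranked = sorted(nxt.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beam = dict(ranked[:beam_size])
    return {prefix: p_b + p_nb for prefix, (p_b, p_nb) in beam.items()}


def brute_force(y, blank=0):
    """Sum the probability of every alignment, grouped by its collapsed labels."""
    T, V = y.shape
    totals = defaultdict(float)
    for path in itertools.product(range(V), repeat=T):
        prob = float(np.prod([y[t, k] for t, k in enumerate(path)]))
        labels = tuple(k for j, k in enumerate(path)
                       if k != blank and (j == 0 or k != path[j - 1]))
        totals[labels] += prob
    return totals


rng = np.random.default_rng(0)
y = rng.random((4, 3))
y /= y.sum(axis=1, keepdims=True)   # each frame: a distribution over {blank, 1, 2}

exact = brute_force(y)
approx = prefix_beam_probs(y, beam_size=1000)   # wide enough that nothing is pruned
# largest disagreement; should be at the level of floating-point rounding
print(max(abs(approx.get(k, 0.0) - v) for k, v in exact.items()))
```

The two agree because a repeated symbol splits the old prefix's mass into disjoint parts. The p_nb part (the last frame was b) stays on [a,a,b]. The p_b part (the last frame was a blank) moves to [a,a,b,b]. Nothing is counted twice.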

Looking forward to your answer, thanks in advance!


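When you choose the best paths in your code, you should not distinguish between [1,_] and [1] when ranking, since both correspond to the same prefix [1]. Otherwise the probability of that prefix is split across two entries, and it is ranked lower than its true probability.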

If you have for example:

[1], [1,_], [1,2]

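then you want both [1] and [1,_] to be ranked by the sum of their two probabilities:

score([1]) = probability([1]) + probability([1,_])

score([1,_]) = probability([1]) + probability([1,_])

In the log-space code from the question, this sum becomes a logsumexp of the two log probabilities. Only the ranking key is merged; the two entries keep their separate probabilities in the beam.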

After sorting by these merged scores, keep as many entries as needed for the beam to contain beam_size distinct true prefixes.

For example, suppose the candidates are [1], [1,_], [2] and [3], with probabilities 0.1, 0.08, 0.11 and 0.15.

The scores used for sorting are then 0.18, 0.18, 0.11 and 0.15, respectively (0.18 = 0.1 + 0.08).

Sorted: [1]: 0.18, [1,_]: 0.18, [3]: 0.15, [2]: 0.11

With beam_size 2, for example, you would keep [1], [1,_] and [3], so that the beam holds 2 prefixes: [1] and [1,_] count as the same prefix. They still have to be tracked separately because they behave differently when the next character is 1. [1] followed by 1 collapses back to [1], while [1,_] followed by 1 becomes [1,1].
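The pruning rule above can be sketched as follows. This sketch assumes plain probabilities (as in the worked example) and "_" as the blank. The helper names true_prefix, merged_scores and prune are hypothetical and do not come from the question's code:

```python
from collections import defaultdict

BLANK = "_"  # assumed blank symbol, as in the examples above


def true_prefix(prefix):
    # (1, "_") and (1,) both stand for the label prefix (1,)
    return prefix[:-1] if prefix and prefix[-1] == BLANK else prefix


def merged_scores(beam):
    """Sum the probabilities of entries that share a true prefix."""
    scores = defaultdict(float)
    for prefix, p in beam.items():
        scores[true_prefix(prefix)] += p
    return scores


def prune(beam, beam_size):
    """Keep entries until beam_size distinct true prefixes are kept.

    beam maps a normalized prefix (possibly ending in BLANK) to its probability.
    Entries are ranked by the merged score but keep their own probability.
    """
    scores = merged_scores(beam)
    ranked = sorted(beam, key=lambda pre: scores[true_prefix(pre)], reverse=True)
    kept, seen = {}, set()
    for prefix in ranked:
        tp = true_prefix(prefix)
        if tp not in seen:
            if len(seen) == beam_size:
                continue  # the beam already holds beam_size true prefixes
            seen.add(tp)
        kept[prefix] = beam[prefix]
    return kept


beam = {(1,): 0.1, (1, BLANK): 0.08, (2,): 0.11, (3,): 0.15}
print(prune(beam, 2))  # {(1,): 0.1, (1, '_'): 0.08, (3,): 0.15}
```

At the end of decoding, the same merged_scores can select the output, for example max(merged_scores(beam).items(), key=lambda kv: kv[1]). In a log-space decoder, the accumulation would instead start from -inf and use np.logaddexp.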