Can the prefix beam search commonly used in CTC speech recognition be implemented in this simpler way?


I have been learning about speech recognition recently, and I have learned that the idea of prefix beam search is to merge paths that share the same prefix, such as [1,1,_] and [_,1,_] (where _ denotes the blank token).

Based on this understanding, I implemented my own version, which can be expressed in simplified pseudocode like this:

import numpy as np
from collections import defaultdict
def prefix_beam_search(y, beam_size, blank):
    seq_len, n_class = y.shape
    logY = np.log(y)
    beam = [([], 0)]

    for t in range(seq_len):
        buff = []
        for prefix, p in beam:
            for i in range(n_class):
                new_prefix = list(prefix) + [i]
                new_p = p + logY[t][i]
                buff.append((new_prefix, new_p))

        # merge paths that normalize to the same prefix
        new_beam = defaultdict(lambda: ninf)
        for prefix, p in buff:
            # 'norm_prefix' can simplify the path, [1,1,_,2] ==> [1,2]
            # However, the ending 'blank' is retained, [1,1,_] ==> [1,_]
            prefix = norm_prefix(prefix, blank)
            new_beam[prefix] = logsumexp(new_beam[prefix], p)

        # choose the best paths
        new_beam = sorted(new_beam.items(), key=lambda x: x[1], reverse=True)
        beam = new_beam[: beam_size]

    return beam
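(The norm_prefix helper is not shown above; a minimal sketch consistent with the comments in the snippet, assuming blank is passed as an integer class index, might look like this:)

```python
def norm_prefix(path, blank):
    """Hypothetical helper: collapse repeats and drop blanks, CTC-style,
    but keep a single trailing blank so that [1,1,_] and [_,1,_] both
    map to (1, blank), while [1,1,_,2] maps to (1, 2)."""
    out = []
    prev = None
    for c in path:
        # append only when the class changes and is not blank
        if c != prev and c != blank:
            out.append(c)
        prev = c
    # retain a trailing blank marker, since it affects future merges
    if path and path[-1] == blank:
        out.append(blank)
    return tuple(out)
```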

But most of the versions I found online (following the paper) look like this:

def _prefix_beam_decode(y, beam_size, blank):
    T, V = y.shape
    log_y = np.log(y)
    beam = [(tuple(), (0, ninf))]

    for t in range(T):
        new_beam = defaultdict(lambda: (ninf, ninf))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):
                p = log_y[t, i]
                if i == blank:
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                end_t = prefix[-1] if prefix else None
                new_prefix = prefix + (i,)
                new_p_b, new_p_nb = new_beam[new_prefix]
                if i != end_t:
                    new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                else:
                    new_p_nb = logsumexp(new_p_nb, p_b + p)
                new_beam[new_prefix] = (new_p_b, new_p_nb)
                if i == end_t:
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_nb = logsumexp(new_p_nb, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)

        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]
    return beam
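(Both snippets assume a ninf constant and a variadic logsumexp over scalar log-probabilities; note that scipy.special.logsumexp takes an array rather than separate arguments, so a small helper like this sketch would be needed:)

```python
import math

ninf = float('-inf')

def logsumexp(*args):
    """Numerically stable log(exp(a1) + exp(a2) + ...) over scalars."""
    if all(a == ninf for a in args):
        # avoid math.exp(-inf - -inf) = exp(nan)
        return ninf
    m = max(args)
    return m + math.log(sum(math.exp(a - m) for a in args))
```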

The results of the two are different, and my version tends to return longer strings. There are two main aspects I don't quite understand:

  1. Are there any details my version fails to handle?
  2. The common version generates the new prefix with new_prefix = prefix + (i,) regardless of whether the last character of the old prefix is the same as the new character. For example, if the old prefix is [a,a,b] and a new character b is added, both [a,a,b] and [a,a,b,b] are saved. What is the purpose of this? And doesn't it cause double counting?

Looking forward to your answer, thanks in advance!

There is 1 answer below.

When you choose the best paths in your code, you don't want to differentiate between [1,_] and [1] since both correspond to the same prefix [1].

If you have for example:

[1], [1,_], [1,2]

then you want both [1] and [1,_] to be scored with the sum of their two probabilities:

probability([1]) = probability([1])+probability([1,_])

probability([1,_]) = probability([1])+probability([1,_])

And after sorting with these probabilities, you may want to keep enough entries that the number of distinct true prefixes equals beam_size.

For example, suppose you have [1], [1,_], [2], [3],

with probabilities 0.1, 0.08, 0.11, 0.15.

Then the probabilities with which you want to sort them are:

0.18, 0.18, 0.11, 0.15, respectively (0.18 = 0.1 + 0.08)

Sorted: [1]:0.18, [1,_]: 0.18, [3]:0.15, [2]:0.11

And if you have beam_size 2, for example, then you may want to keep

[1], [1,_] and [3], so that you have 2 distinct prefixes in your beam: [1] and [1,_] count as the same prefix as long as the next character is not 1, and that is exactly why they are tracked separately.
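As a sketch of that selection rule, using the toy numbers from the example above (with blank encoded as 0, so [1,_] is written (1, 0)):

```python
from collections import defaultdict

blank = 0

# toy candidates from the worked example; (1,) and (1, blank) are the
# same "true" prefix [1], so they must share a sort key
candidates = {(1,): 0.10, (1, blank): 0.08, (2,): 0.11, (3,): 0.15}

def true_prefix(path):
    # strip a single trailing blank: (1, blank) -> (1,)
    return path[:-1] if path and path[-1] == blank else path

# merged probability per true prefix: p([1]) = 0.10 + 0.08 = 0.18
merged = defaultdict(float)
for path, p in candidates.items():
    merged[true_prefix(path)] += p

# sort candidate paths by the merged probability of their true prefix
ranked = sorted(candidates.items(),
                key=lambda kv: merged[true_prefix(kv[0])],
                reverse=True)

# keep entries until beam_size distinct true prefixes are covered
beam_size = 2
kept, seen = [], set()
for path, p in ranked:
    tp = true_prefix(path)
    if tp not in seen and len(seen) >= beam_size:
        break
    seen.add(tp)
    kept.append(path)
# kept is now [(1,), (1, 0), (3,)]: three paths, two distinct prefixes
```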