Batch-wise beam search in PyTorch


I'm trying to implement a beam search decoding strategy in a text generation model. This is the function that I am using to decode the output probabilities.

import torch


def beam_search_decoder(data, k):
    sequences = [[list(), 0.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        sequences = ordered[:k]
    return sequences

As you can see, this function is written with a batch size of 1 in mind. Adding another loop over the batch would make the algorithm O(n^4), and it is already slow as it is. Is there any way to speed this function up? My model output usually has size (32, 150, 9907), in the format (batch_size, max_len, vocab_size).
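To put the cost in numbers: the first step extends a single empty sequence, and every later step builds k × vocab_size candidate lists (each one copying its whole prefix via seq + [j]) and sorts all of them in Python. A quick count for the shapes above, assuming a beam width of k = 3 (the question does not state k):

```python
batch_size, max_len, vocab_size = 32, 150, 9907  # shapes from the question
k = 3  # assumed beam width; not given in the question

# Step 0 starts from one empty sequence; each later step extends k beams
# with every token in the vocabulary.
candidates_per_step = k * vocab_size
per_sequence = vocab_size + (max_len - 1) * candidates_per_step
per_batch = batch_size * per_sequence
print(candidates_per_step, per_sequence, per_batch)  # 29721 4438336 142026752
```

That is roughly 142 million Python-level list objects per batch, plus 149 sorts of about 30,000 items for each of the 32 sequences, which is why moving the scoring into tensor operations pays off.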


You can use this library

https://pypi.org/project/pytorch-beam-search/

It implements Beam Search, Greedy Search and sampling for PyTorch sequence models.

The following snippet builds and trains a Transformer seq2seq model and uses it to generate predictions.

#pip install pytorch-beam-search
from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)    # just a PyTorch model
model.fit(X, Y, epochs = 100)    # basic method included

# Generate new predictions
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)    # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new) 
output = [target_index.tensor2text(p) for p in predictions]
print(output)

Below is my implementation, which may be a little faster than the for-loop implementation because it processes the whole batch with tensor operations.

import torch


def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:

        post(Tensor) – the posterior of network.
        k(int) – beam size of decoder.

    Outputs:

        indices(Tensor) – a beam of index sequence.
        log_prob(Tensor) – a beam of log likelihood of sequence.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)

    """

    batch_size, seq_length, _ = post.shape
    log_post = post.log()
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
    return indices, log_prob
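A side note on the .repeat(1, k, 1) call in the loop above: it materializes k copies of the step's scores, while PyTorch broadcasting yields the same (batch_size, k, vocab_size) sum without the copy. A quick check with made-up sizes:

```python
import torch

batch_size, k, vocab_size = 4, 3, 10  # illustrative sizes, not from the question
log_prob = torch.randn(batch_size, k)       # current beam scores
step = torch.randn(batch_size, vocab_size)  # log-probabilities at one time step

# Explicit copy: (batch, k, 1) + (batch, k, vocab)
with_repeat = log_prob.unsqueeze(-1) + step.unsqueeze(1).repeat(1, k, 1)
# Broadcasting: (batch, k, 1) + (batch, 1, vocab) -> (batch, k, vocab)
with_broadcast = log_prob.unsqueeze(-1) + step.unsqueeze(1)

print(with_broadcast.shape)                      # torch.Size([4, 3, 10])
print(torch.equal(with_repeat, with_broadcast))  # True
```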
/!\ The most upvoted answer doesn't perform a correct beam search!

Based on the version proposed by 防暴队大盾, I implemented a version of the beam-search algorithm that does not overlook sequences sharing initial tokens. This is done by recovering the correct indices from the indices of the flattened array: each flat index encodes both the beam being extended and the new token.
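A small illustration of that index arithmetic, with made-up sizes (k = 2 beams, vocab_size = 3): after a (k, vocab_size) score table is flattened row by row, entry (beam, token) sits at position beam * vocab_size + token, so floor division and modulo recover the two parts.

```python
import torch

k, vocab_size = 2, 3  # illustrative sizes, not taken from the question
# Score of extending each of the k beams (rows) with each token (columns).
scores = torch.tensor([[0.1, 0.9, 0.3],
                       [0.8, 0.2, 0.7]])
values, flat = scores.reshape(-1).topk(k)  # positions in the flattened (k * vocab_size,) array
beam = torch.div(flat, vocab_size, rounding_mode="floor")  # row: which beam was extended
token = flat % vocab_size                                  # column: which token was appended
print(flat.tolist(), beam.tolist(), token.tolist())  # [1, 3] [0, 1] [1, 0]
```

Appending the flat index itself to the sequence would record 3 as a token (out of range for a 3-word vocabulary) when it actually means "beam 1, token 0", and the history of beam 1 would never be carried over.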

import torch


def beam_search(prediction, k=10):
    # prediction: (batch_size, seq_length, vocab_size) log-probabilities
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for n1 in range(1, seq_length):
        # score of extending each of the k beams with every token: (batch_size, k, vocab_size)
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = index_temp // vocab_size  # retrieve index of start sequence
        idx_concat = index_temp % vocab_size  # retrieve index of new token
        # allocate on the same device as the input so this also works on GPU
        new_indices = torch.zeros((batch_size, k, n1+1), dtype=torch.int64, device=prediction.device)
        for n2 in range(batch_size):
            new_indices[n2, :, :-1] = indices[n2][idx_begin[n2]]
            new_indices[n2, :, -1] = idx_concat[n2]
        indices = new_indices
    return indices, log_prob

This version assumes that prediction already holds log-probabilities (for example the output of log_softmax, the quantity cross-entropy is computed from), not probabilities, so there is no need to take the log here.
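If the model outputs raw logits, a minimal way to get such log-probabilities is torch.log_softmax over the vocabulary dimension. Summing log-probabilities along a path then equals the log of the product of the per-step probabilities, which is the sequence score beam search ranks. The shapes and the path below are made up for illustration:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 4, 5)                  # (batch_size, seq_length, vocab_size), made up
log_probs = torch.log_softmax(logits, dim=-1)  # normalized over the vocabulary
probs = log_probs.exp()

path = [1, 0, 4, 2]  # an arbitrary token sequence for batch entry 0
score = sum(log_probs[0, t, tok] for t, tok in enumerate(path))
product = torch.prod(torch.stack([probs[0, t, tok] for t, tok in enumerate(path)]))
print(torch.allclose(score, product.log()))    # True
```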

If someone knows how to avoid the innermost loop with some fancy indexing, this could probably be made even faster.
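One way to drop the per-batch loop is a sketch like the following (beam_search_vectorized is a hypothetical name; it assumes the same log-probability input and k <= vocab_size). Advanced indexing with a column of batch indices, indices[batch_idx, beam], gathers the history of the chosen beams for every batch entry at once:

```python
import torch


def beam_search_vectorized(prediction, k=10):
    """Hypothetical loop-free variant of the batched beam search above.

    prediction: (batch_size, seq_length, vocab_size) log-probabilities.
    Returns indices (batch_size, k, seq_length) and log_prob (batch_size, k).
    """
    batch_size, seq_length, vocab_size = prediction.shape
    # Initial beams: the k best tokens at the first step.
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)  # (batch_size, k, 1)
    # Column of batch indices, shape (batch_size, 1), broadcast against (batch_size, k).
    batch_idx = torch.arange(batch_size, device=prediction.device).unsqueeze(-1)
    for n1 in range(1, seq_length):
        # Broadcasting: (batch, k, 1) + (batch, 1, vocab) -> (batch, k, vocab).
        scores = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1)
        log_prob, flat = scores.reshape(batch_size, -1).topk(k, sorted=True)
        beam = torch.div(flat, vocab_size, rounding_mode="floor")  # beam being extended
        token = flat % vocab_size                                  # token being appended
        # indices[batch_idx, beam] has shape (batch, k, n1): history of each chosen beam.
        indices = torch.cat([indices[batch_idx, beam], token.unsqueeze(-1)], dim=-1)
    return indices, log_prob


torch.manual_seed(0)
post = torch.log_softmax(torch.randn(32, 20, 1000), dim=-1)
idx, lp = beam_search_vectorized(post, k=3)
print(idx.shape, lp.shape)  # torch.Size([32, 3, 20]) torch.Size([32, 3])
```

The only remaining Python loop is over time steps, which beam search needs anyway because each step depends on the beams chosen at the previous one.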