Beam Search in Python

I am implementing a Seq2Seq model in Keras. However, Keras does not provide a beam search option for the decoder. I therefore considered pynlpl's BeamSearch, but its documentation on search doesn't contain any information about how to implement it. Could you please give an example of how beam search can be implemented?

There is a similar answer in "How to implement a custom beam search in TensorFlow?", but it's not clear.

TensorFlow Addons provides a beam search facility. Quoting from the official documentation:

Here is the class signature:

tfa.seq2seq.BeamSearchDecoder(
    cell: tf.keras.layers.Layer,
    beam_width: int,
    embedding_fn: Optional[Callable] = None,
    output_layer: Optional[tf.keras.layers.Layer] = None,
    length_penalty_weight: tfa.types.FloatTensorLike = 0.0,
    coverage_penalty_weight: tfa.types.FloatTensorLike = 0.0,
    reorder_tensor_arrays: bool = True,
    **kwargs
)
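
The length_penalty_weight argument controls how strongly longer hypotheses are favoured. As far as I can tell from the TensorFlow Addons source, it applies the GNMT length penalty ((5 + length) / 6) ** alpha and divides each beam's summed log-probability by it. The sketch below reproduces that formula in plain Python so the effect is visible. The helper names gnmt_length_penalty and beam_score are hypothetical (not part of tfa), the coverage penalty is ignored, and the formula is my reading of the source, so check it against your installed version.

```python
def gnmt_length_penalty(length: int, alpha: float) -> float:
    """GNMT-style length penalty ((5 + length) / 6) ** alpha.

    Assumption: this mirrors what tfa applies for length_penalty_weight=alpha.
    """
    return ((5.0 + length) / 6.0) ** alpha


def beam_score(total_log_prob: float, length: int, alpha: float) -> float:
    """Rank score for one hypothesis: summed log-probability / length penalty.

    The coverage penalty (coverage_penalty_weight) is ignored here.
    """
    return total_log_prob / gnmt_length_penalty(length, alpha)


# Two hypothetical hypotheses: (summed log-probability, length in tokens).
short_hyp = (-2.0, 2)
long_hyp = (-3.0, 6)

for alpha in (0.0, 1.0):
    print(f"alpha={alpha}: short={beam_score(*short_hyp, alpha):.3f}, "
          f"long={beam_score(*long_hyp, alpha):.3f}")
```

With the default alpha = 0.0 the penalty is 1, so the shorter hypothesis wins (-2.000 vs -3.000); with alpha = 1.0 the ranking flips (-1.714 vs -1.636). A positive length_penalty_weight is one way to counter beam search's tendency to prefer short outputs.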

And here is an example:

import tensorflow_addons as tfa

# Repeat every encoder-side tensor beam_width times along the batch axis.
tiled_encoder_outputs = tfa.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tfa.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tfa.seq2seq.tile_batch(
    sequence_length, multiplier=beam_width)
# MyFavoriteAttentionMechanism is a placeholder, e.g. tfa.seq2seq.BahdanauAttention.
# Its memory must be the *tiled* encoder outputs.
attention_mechanism = MyFavoriteAttentionMechanism(
    units=attention_depth,
    memory=tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
attention_cell = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism, ...)
# The decoder state must cover true_batch_size * beam_width rows.
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=true_batch_size * beam_width, dtype=dtype)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
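
The tiling calls matter because tile_batch repeats each batch entry beam_width times in place (rows x0, x0, ..., x1, x1, ...) instead of repeating the whole batch. That is also why the decoder state is created with batch_size=true_batch_size * beam_width. The NumPy sketch below mimics that row layout with np.repeat; tile_batch_np is a hypothetical helper for illustration, not a tfa function, and the toy array stands in for real encoder outputs.

```python
import numpy as np


def tile_batch_np(x: np.ndarray, multiplier: int) -> np.ndarray:
    """Hypothetical NumPy analogue of tile_batch.

    Each batch entry's copies are contiguous ([x0, x0, ..., x1, x1, ...]),
    which is what np.repeat produces (np.tile would interleave them instead).
    """
    return np.repeat(x, multiplier, axis=0)


encoder_outputs = np.arange(6).reshape(2, 3)  # toy [batch=2, features=3]
tiled = tile_batch_np(encoder_outputs, multiplier=3)
print(tiled.shape)  # (6, 3)
print(tiled[:, 0])  # [0 0 0 3 3 3]
```

Keeping each sample's copies adjacent means rows i * beam_width through (i + 1) * beam_width - 1 all belong to sample i, which is the grouping the beam search decoder relies on.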

Generally speaking, you can do it like this:

  1. Tile the original batch beam_size times along the first (batch) dimension, including the encoder outputs (used by the attention) and the encoder's final states (used as the initial decoder state). Call each group of beam_size repeated copies of one sample beam_i.

  2. Run one decoding step and, for each beam_i, take the top beam_size * 2 indices and their probabilities from the vocabulary-sized output.

  3. Keep the probabilities of all previously generated characters, compute the average probability from them together with the new probability from step 2, and keep the top beam_size candidates.

  4. Put the hypotheses that produce a stop_decode symbol into a list; once that list holds beam_size hypotheses for a given beam_i, that beam ends.
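
The four steps can be sketched in plain Python for a single source sentence (one beam_i), leaving out the batch tiling of step 1. This is a minimal sketch under stated assumptions, not pynlpl's or TensorFlow's implementation: step_fn is a hypothetical stand-in for one decoder step that returns log-probabilities over the vocabulary, and I read the "average probability" of step 3 as the mean log-probability per generated token. Taking 2 * beam_size candidates per hypothesis in step 2 leaves enough live candidates even when some of them emit the stop symbol.

```python
import math
from typing import Callable, List, Sequence, Tuple

# A hypothesis is (token ids including the start token, per-step log-probs).
Hypothesis = Tuple[List[int], List[float]]


def beam_search(
    step_fn: Callable[[List[int]], Sequence[float]],
    start_id: int,
    stop_id: int,
    beam_size: int,
    max_len: int,
) -> List[Tuple[List[int], float]]:
    """Return up to beam_size (tokens, average log-prob) pairs, best first."""
    assert max_len >= 1 and beam_size >= 1

    def avg(logps: List[float]) -> float:
        return sum(logps) / len(logps)

    alive: List[Hypothesis] = [([start_id], [])]
    finished: List[Hypothesis] = []

    for _ in range(max_len):
        candidates: List[Hypothesis] = []
        for tokens, logps in alive:
            scores = step_fn(tokens)
            # Step 2: the 2 * beam_size most likely next tokens.
            top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
            for tok in top[: 2 * beam_size]:
                candidates.append((tokens + [tok], logps + [scores[tok]]))

        # Step 3: rank every extension by its average log-probability.
        candidates.sort(key=lambda h: avg(h[1]), reverse=True)

        # Step 4: stopped hypotheses go to `finished`; keep beam_size live ones.
        alive = []
        for tokens, logps in candidates:
            if tokens[-1] == stop_id:
                finished.append((tokens, logps))
            else:
                alive.append((tokens, logps))
                if len(alive) == beam_size:
                    break

        # The beam ends once beam_size hypotheses have stopped.
        if len(finished) >= beam_size or not alive:
            break

    # If max_len is reached before anything stops, fall back to live hypotheses.
    pool = finished or alive
    pool.sort(key=lambda h: avg(h[1]), reverse=True)
    return [(tokens, avg(logps)) for tokens, logps in pool[:beam_size]]


# Toy decoder whose next-token distribution depends only on the last token.
# Hypothetical vocabulary: 0 = <stop>, 1 = "a", 2 = "b"; 3 is the start token.
TOY_PROBS = {3: [0.1, 0.5, 0.4], 1: [0.4, 0.3, 0.3], 2: [0.9, 0.05, 0.05]}


def toy_step(prefix: List[int]) -> List[float]:
    return [math.log(p) for p in TOY_PROBS[prefix[-1]]]


for tokens, score in beam_search(toy_step, start_id=3, stop_id=0, beam_size=2, max_len=5):
    print(tokens, round(score, 3))
```

In this toy model greedy decoding would pick "a" (0.5) first and end at [3, 1, 0], whereas beam search keeps "b" alive and returns [3, 2, 0] first, whose average log-probability ln 0.6 ≈ -0.511 beats ln(0.2) / 2 ≈ -0.805.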

But this is rather abstract, so you may want to refer to this (official) example, implemented by Denny Britz from Google, and to this very simple one.