How does tf.contrib.seq2seq.gather_tree work?

How exactly does gather_tree in contrib.seq2seq work? I can see that it takes the predicted ids and the beam parent ids and somehow returns the final beams, but what is actually going on under the hood? There doesn't seem to be any Python code I could examine to figure it out, and the API documentation isn't very explanatory.

Is there any source code for tf.contrib.seq2seq.gather_tree? I am using TensorFlow 1.3, and looking inside gen_beam_search_ops.py doesn't seem to help.

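gen_beam_search_ops.py is an auto-generated Python wrapper around a compiled (C++) op, which is why it contains no readable logic. What the op computes can be reproduced with the following pure-Python version from the seq2seq beam search code (linked below):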

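import numpy as np
import tensorflow as tf


def gather_tree_py(values, parents):
  """Gathers paths through a tree backwards from the leaf nodes. Used
  to reconstruct beams given their parents.

  values, parents: arrays of shape [beam_length, num_beams]; parents[t][b]
  is the index of the beam at step t - 1 that beam b at step t extends.
  """

  beam_length = values.shape[0]
  num_beams = values.shape[1]
  res = np.zeros_like(values)
  # The last step needs no reordering: each final beam keeps its own token.
  res[-1, :] = values[-1, :]
  for beam_id in range(num_beams):
    # Follow the parent pointers from the last step back to the first.
    parent = parents[-1][beam_id]
    for level in reversed(range(beam_length - 1)):
      res[level, beam_id] = values[level][parent]
      parent = parents[level][parent]
  return np.array(res).astype(values.dtype)


def gather_tree(values, parents):
  """Tensor version of gather_tree_py (wraps it with tf.py_func, TF 1.x)."""

  res = tf.py_func(
      func=gather_tree_py, inp=[values, parents], Tout=values.dtype)
  # py_func loses the static shape, so restore it from the input tensor.
  res.set_shape(values.get_shape().as_list())
  return res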

Source: seq2seq on GitHub (beam_search)
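To see the backtracking on concrete numbers, here is a small self-contained trace with 3 steps and 2 beams (the ids are made up for illustration). It repeats the NumPy function from above, lightly condensed, and adds a hypothetical `gather_tree_vectorized` that follows all beams at once with fancy indexing instead of looping over beams; the two should agree. This is only a sketch of the core idea: the real TensorFlow op also deals with a batch dimension and sequence lengths, which are left out here.

```python
import numpy as np


def gather_tree_py(values, parents):
    """Loop version (condensed copy of the answer's function)."""
    beam_length, num_beams = values.shape
    res = np.zeros_like(values)
    res[-1, :] = values[-1, :]
    for beam_id in range(num_beams):
        parent = parents[-1][beam_id]
        for level in reversed(range(beam_length - 1)):
            res[level, beam_id] = values[level][parent]
            parent = parents[level][parent]
    return res


def gather_tree_vectorized(values, parents):
    """Hypothetical helper: same backtracking, all beams at once."""
    beam_length, num_beams = values.shape
    res = np.empty_like(values)
    # At the last step, final beam b is simply column b.
    beam_idx = np.arange(num_beams)
    for level in reversed(range(beam_length)):
        # Read the token each tracked beam holds at this step.
        res[level] = values[level, beam_idx]
        # Jump to the beam each tracked beam came from at the previous step.
        beam_idx = parents[level, beam_idx]
    return res


# values[t][b]: token chosen by beam b at step t.
values = np.array([[5, 7],
                   [3, 4],
                   [9, 8]])
# parents[t][b]: beam at step t - 1 that beam b at step t extends (row 0 unused).
parents = np.array([[0, 0],
                    [1, 1],
                    [1, 0]])

paths = gather_tree_py(values, parents)
print(paths.T)  # one row per final beam
# [[7 4 9]
#  [7 3 8]]
assert np.array_equal(paths, gather_tree_vectorized(values, parents))
```

Reading the result column by column gives each final beam's full path: final beam 0 extends beam 1 at step 1 (token 4), which in turn extends beam 1 at step 0 (token 7), hence 7, 4, 9. Note that the raw `values` column for beam 0, namely 5, 3, 9, is not a real hypothesis; that mismatch is exactly what gather_tree fixes.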