Suppose you have stacked two sequences of 3-dimensional embeddings into a single ragged tensor:
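import tensorflow as tf

def foo(*args):
    # Dense float32 tensor of shape `args` filled with 0, 1, 2, ...
    n_elements = tf.reduce_prod(args)
    # tf.Tensor has no .reshape() method unless NumPy behavior is enabled, so use tf.reshape
    return tf.reshape(tf.range(n_elements, dtype=tf.float32), args)

c = tf.ragged.stack((foo(2, 3), foo(5, 3)), axis=0)
assert c.shape.as_list() == [2, None, None]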
How can I convert c to shape [2, None, 3], given that I know every innermost row has length 3?
Try using tf.RaggedTensor.from_row_splits:
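A minimal sketch of that approach, assuming TensorFlow 2.x in eager mode. The idea: c.values is a ragged tensor of shape [7, None] holding all seven embeddings, and because every one of those rows has length 3 it can be densified to [7, 3]. Re-attaching the outer row partition (c.row_splits, here [0, 2, 7]) with from_row_splits then gives a ragged tensor whose only ragged dimension is the middle one. The names dense_rows and c_fixed are just illustrative.

```python
import tensorflow as tf

def foo(*args):
    # Dense float32 tensor of shape `args` filled with 0, 1, 2, ...
    n_elements = tf.reduce_prod(args)
    return tf.reshape(tf.range(n_elements, dtype=tf.float32), args)

# Ragged tensor of shape [2, None, None] (ragged_rank=2).
c = tf.ragged.stack((foo(2, 3), foo(5, 3)), axis=0)

# c.values has shape [7, None]; every row has length 3, so this is [7, 3].
dense_rows = c.values.to_tensor()

# Rebuild the outer partition on top of the dense rows -> shape [2, None, 3].
c_fixed = tf.RaggedTensor.from_row_splits(dense_rows, row_splits=c.row_splits)

print(c_fixed.shape.as_list())  # [2, None, 3]
```

Note that to_tensor() pads short rows with zeros, so it relies on your guarantee that every inner row really has length 3. An alternative that fails loudly when the flat size is not a multiple of 3 is to build the dense rows with tf.reshape(c.flat_values, [-1, 3]) instead.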