A mask for the input is given, with shape [batch size, number of timesteps]. Using it, I need to collect X embeddings from an embedding tensor of shape [batch size, timesteps, embedding size], namely the X timesteps that come just before each group of False values.
Say the mask for a batch size of 1 is T,T,T,F,F,F,F | T,T,F,F,F,F,F and X=2 (the '|' marks a break between row splits, each of length 7). Then I should get a list of the concatenated embeddings at indices (1,2) and (8,9).
This should work for a variable batch size without processing each batch element individually, because my batch size is quite large. The output should be [ [ (1,2), (.,.), ... for the other batch elements at the first row split (0:7) ], [ (8,9), (.,.), ... for the other batch elements at the second row split (7:14) ], ... ]
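A minimal vectorized sketch in NumPy, under these assumptions: the embeddings are an array of shape [batch, timesteps, emb]; the number of timesteps is an exact multiple of the row split length; and every row split has at least X True values before its first False. The helper names `split_indices` and `gather_before_padding` are hypothetical. The idea is to reshape the mask so each row split becomes its own row, find the first False in every split with `argmax`, build the X indices just before it, add each split's offset, and gather for all batch rows at once with `np.take_along_axis`. With 0-based indexing, the question's example mask gives (1, 2) for the first split and (7, 8) for the second; (8, 9) corresponds to a second split of T,T,T,F,F,F,F, which is what the second row of the usage example uses.

```python
import numpy as np


def split_indices(mask, seg_len, x):
    """Hypothetical helper: for each row split, the global timestep indices of
    the x positions just before its first False. Shape [num_splits, batch, x]."""
    batch, steps = mask.shape
    num_splits = steps // seg_len  # assumes steps is a multiple of seg_len
    # One row per row split: [batch, num_splits, seg_len]
    m = mask.reshape(batch, num_splits, seg_len)
    # Position of the first False in each split; seg_len if the split is all True
    has_false = (~m).any(axis=-1)
    first_false = np.where(has_false, np.argmax(~m, axis=-1), seg_len)
    if (first_false < x).any():
        raise ValueError("a row split has fewer than x timesteps before its first False")
    # Local indices first_false - x .. first_false - 1: [batch, num_splits, x]
    local = first_false[..., None] - x + np.arange(x)
    # Add each split's starting offset (0, seg_len, 2 * seg_len, ...)
    global_idx = local + (np.arange(num_splits) * seg_len)[None, :, None]
    # Reorder to [num_splits, batch, x] to match the requested output layout
    return global_idx.transpose(1, 0, 2)


def gather_before_padding(embeddings, mask, seg_len, x):
    """Hypothetical helper: gather the selected embeddings for every batch row
    at once. Returns shape [num_splits, batch, x, emb]."""
    batch, steps, emb = embeddings.shape
    idx = split_indices(mask, seg_len, x)  # [num_splits, batch, x]
    num_splits = idx.shape[0]
    # take_along_axis needs the batch axis first: [batch, num_splits * x, 1]
    flat = idx.transpose(1, 0, 2).reshape(batch, num_splits * x)[..., None]
    picked = np.take_along_axis(embeddings, flat, axis=1)  # [batch, num_splits * x, emb]
    return picked.reshape(batch, num_splits, x, emb).transpose(1, 0, 2, 3)


T, F = True, False
mask = np.array([
    [T, T, T, F, F, F, F, T, T, F, F, F, F, F],  # the mask from the question
    [T, T, T, T, F, F, F, T, T, T, F, F, F, F],
])
emb = np.arange(2 * 14 * 3).reshape(2, 14, 3)  # toy embeddings, emb size 3

print(split_indices(mask, seg_len=7, x=2).tolist())
# [[[1, 2], [2, 3]], [[7, 8], [8, 9]]]
out = gather_before_padding(emb, mask, seg_len=7, x=2)
concat = out.reshape(out.shape[0], out.shape[1], -1)  # [num_splits, batch, x * emb]
```

Reshaping the result to [num_splits, batch, X * emb], as in the last line, gives the concatenated embeddings per row split. If the data lives in framework tensors instead of arrays, the same steps map onto ops such as `tf.gather(..., batch_dims=1)` in TensorFlow or `torch.gather` in PyTorch.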