Training torch.TransformerDecoder with causal mask


I use torch.TransformerDecoder to generate a sequence in which each next token depends on itself and on the first two tokens: [CLS] and the first predicted one. The inference steps I need are listed below (a mask sketch follows the list):

  1. Start from the sequence [CLS], [MASK], add positional embeddings, and generate the first prediction, attending to itself and the [CLS] token (just a causal mask).
  2. Append a few [MASK] tokens (I know how many based on the first prediction) to the sequence [CLS], [PRED1].
  3. Suppose we now have the sequence [CLS], [PRED1], [MASK], [MASK], [MASK], [MASK].
  4. Add positional embeddings and compute each token based on [CLS], [PRED1], and itself, so I assume the attention mask should allow exactly those positions.
  5. Run the prediction. When I do, [PRED1] is correct, the class right after it is also correct (the output of the first [MASK] after [PRED1]), but all the remaining masks just repeat the last predicted token (the output of the first [MASK]). It looks like this: [PRED1], [PRED2], [PRED2], [PRED2], [PRED2].
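
A minimal sketch of the attention mask I mean in step 4, assuming the intended pattern is that every position can attend to [CLS] (position 0), [PRED1] (position 1), and itself (here True means "blocked", matching PyTorch's boolean attn_mask convention):

    import torch

    seq_len = 6  # [CLS], [PRED1], [MASK], [MASK], [MASK], [MASK]

    allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    allowed[:, :2] = True                            # everyone sees [CLS] and [PRED1]
    allowed |= torch.eye(seq_len, dtype=torch.bool)  # ...and itself

    attn_mask = ~allowed  # True = position is not allowed to attend
    print(attn_mask.int())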

Steps of the sequential process (a loop sketch follows the list):

  1. The same as above.
  2. Append only one [MASK] token.
  3. Predict, so we get the sequence [CLS], [PRED1], [PRED2].
  4. Only then add a new [MASK] token, giving [CLS], [PRED1], [PRED2], [MASK], and predict again -> [CLS], [PRED1], [PRED2], [PRED3], where [PRED3] is correct.
  5. If I repeat this until the end, it works.
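
Roughly, the sequential loop looks like this (the model call signature, the embedding/positional-encoding handling, and the token ids are hypothetical placeholders):

    import torch

    def generate_sequential(model, cls_id, mask_id, num_steps, device="cpu"):
        # Start from [CLS], [MASK]; each iteration fills the [MASK] and appends a new one.
        seq = torch.tensor([[cls_id, mask_id]], device=device)
        for _ in range(num_steps):
            causal = torch.nn.Transformer.generate_square_subsequent_mask(seq.size(1)).to(device)
            logits = model(seq, tgt_mask=causal)      # assumed interface: returns (1, len, vocab)
            pred = logits[:, -1].argmax(dim=-1)       # prediction at the [MASK] slot
            seq[:, -1] = pred                         # replace [MASK] with the prediction
            seq = torch.cat([seq, torch.tensor([[mask_id]], device=device)], dim=1)
        return seq[:, :-1]                            # drop the trailing [MASK]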

But I need parallel processing of all predictions starting from the second one.

How it was trained:

  1. In the Dataset I get [CLS], [PRED1], [PRED2], [PRED3], [MASK]; a mask of what I am going to predict, masked_mask = [0, 0, 0, 0, 1] (the last position); and a mask of what is visible (always the first two positions plus masked_mask): mask = [1, 1, 0, 0, 0] + masked_mask = [1, 1, 0, 0, 1]. (I also add padding for batch training, but it is zeroed out in mask as well.)
  2. In the model I add positional embeddings to the input sequence and then build the causal mask for prediction like this: causal_mask = torch.tril(mask.unsqueeze(-1) * mask.unsqueeze(-2)) == 0. (A small reproduction of this construction follows the list.)
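
For reference, running that construction on the example visibility mask [1, 1, 0, 0, 1] gives the following (True/1 means the position is blocked from attending):

    import torch

    mask = torch.tensor([1, 1, 0, 0, 1])

    causal_mask = torch.tril(mask.unsqueeze(-1) * mask.unsqueeze(-2)) == 0
    print(causal_mask.int())
    # tensor([[0, 1, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1],
    #         [0, 0, 1, 1, 0]])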

It's not convenient to show the full code, which is why I tried to explain it instead. If you have any ideas about why it's not working, thanks!

I don't know what to do.


There is 1 answer below.


I think you're confusing masked language modeling (MLM) with causal language modeling (CLM).

MLM is used for models like BERT. For MLM training, a percentage of input tokens are replaced with [MASK] tokens, and the model predicts the true value of the [MASK] tokens using the other tokens in the sequence.
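
As a toy illustration of the MLM setup (the token ids here are made up), the input has a [MASK] where the original token was, and only that position contributes to the loss:

    import torch

    tokens  = torch.tensor([11, 42, 7, 99, 23])  # hypothetical original token ids
    MASK_ID = 103                                # hypothetical [MASK] id

    input_ids = tokens.clone()
    input_ids[2] = MASK_ID                       # corrupt one position

    labels = torch.full_like(tokens, -100)       # -100 is ignored by CrossEntropyLoss
    labels[2] = tokens[2]                        # predict only the masked position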

CLM is used for autoregressive models like GPT. CLM doesn't mask input tokens (i.e., there is no [MASK] token at all), but instead uses masking in the attention layers. Attention masking prevents tokens from attending to subsequent tokens, forcing each token to use only information from previous tokens.
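
PyTorch ships a helper that builds exactly this kind of causal mask; for a length-5 sequence it is 0 on and below the diagonal and -inf above it:

    import torch

    causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(5)
    print(causal_mask)
    # tensor([[0., -inf, -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf, -inf],
    #         [0.,   0.,   0., -inf, -inf],
    #         [0.,   0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.,   0.]])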

For your use case of generating sequences, you want to use a CLM approach.
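
A rough sketch of CLM-style training with nn.TransformerDecoder, under assumed dimensions and a placeholder encoder memory; the important parts are the shifted targets (each position predicts the next token) and the causal tgt_mask:

    import torch
    import torch.nn as nn

    vocab_size, d_model, seq_len = 1000, 128, 16  # hypothetical sizes

    embed = nn.Embedding(vocab_size, d_model)
    decoder = nn.TransformerDecoder(
        nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True),
        num_layers=2,
    )
    to_logits = nn.Linear(d_model, vocab_size)

    tokens = torch.randint(0, vocab_size, (4, seq_len))  # dummy batch
    memory = torch.zeros(4, 1, d_model)                  # placeholder encoder output

    causal = nn.Transformer.generate_square_subsequent_mask(seq_len - 1)
    hidden = decoder(embed(tokens[:, :-1]), memory, tgt_mask=causal)
    logits = to_logits(hidden)

    # Teacher forcing: position i is trained to predict token i+1.
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size), tokens[:, 1:].reshape(-1)
    )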

For inference, you necessarily need to predict one token at a time, but you can use KV caching to speed things up and reduce redundant computation.
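
Without a cache, the simplest decoding loop just re-feeds the growing prefix each step (the model interface below is an assumption); a KV cache would store each layer's key/value projections so only the newest token needs to be processed per step:

    import torch

    @torch.no_grad()
    def greedy_decode(model, start_ids, max_new_tokens, eos_id=None):
        # `model(ids)` is assumed to return logits of shape (batch, seq, vocab).
        ids = start_ids
        for _ in range(max_new_tokens):
            logits = model(ids)  # full prefix recomputed every step
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
            if eos_id is not None and (next_id == eos_id).all():
                break
        return ids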