Unable to implement tgt_mask and tgt_key_padding_mask properly in a Transformer decoder model


I have implemented a Transformer decoder for next-token prediction. I am passing tgt_mask and tgt_key_padding_mask so that the model does not attend to future tokens and ignores the padding positions, but I keep getting the following error:

Training Epoch 1/1:   0%|                                                                                       | 0/563 [00:00<?, ?it/s]

tgt_emb torch.Size([16, 875, 256])
tgt_mask torch.Size([875, 875])
padding_mask torch.Size([16, 875])

/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/functional.py:5076: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
  warnings.warn(
Training Epoch 1/1:   0%|                                                                                       | 0/563 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/scratch/harsha.vasamsetti/decoder_aug_made/main.py", line 146, in <module>
    output = model(input_batch)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/harsha.vasamsetti/decoder_aug_made/transformer.py", line 47, in forward
    output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 846, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 855, in _sa_block
    x = self.self_attn(x, x, x,
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/functional.py", line 5318, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([875, 875]), but should be (16, 16).

I read the PyTorch documentation and understood that tgt_mask should have shape (T, T), where T is the sequence length, and that tgt_key_padding_mask should have shape (batch_size, T). As you can see from the shapes printed above, that is exactly what I am passing, yet I still get the error. I don't know where I am going wrong.
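To make that concrete, this is what I understand the documentation to mean (a toy-sized snippet with boolean masks, not my actual code; in my real run T = 875 and batch_size = 16, and my model builds a float mask with -inf instead):

import torch

# Toy sizes just to illustrate the documented shapes
# (my actual run uses T = 875 and batch_size = 16).
T, batch_size, pad_idx = 5, 2, 0

# Causal mask: (T, T), True means "do not attend to this position".
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Padding mask: (batch_size, T), True marks a padding token.
tokens = torch.tensor([[4, 7, 9, pad_idx, pad_idx],
                       [3, 5, pad_idx, pad_idx, pad_idx]])
padding_mask = (tokens == pad_idx)

print(tgt_mask.shape)      # torch.Size([5, 5])
print(padding_mask.shape)  # torch.Size([2, 5])

Here is my full model: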

import torch.nn as nn
import torch
import math


# Define the Transformer model class
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, pad_idx, n_embd, n_head, n_layers, max_length, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx  # index of the padding token


        self.embed = nn.Embedding(vocab_size, n_embd)
        decoder_layer = nn.TransformerDecoderLayer(d_model=n_embd, nhead=n_head, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.pos_encoder = PositionalEncoding(n_embd, dropout, max_length)
        self.n_embd = n_embd
        self.generator = nn.Linear(n_embd, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt):
        tgt_emb = self.pos_encoder(self.embed(tgt) * math.sqrt(self.n_embd))  # (batch_size, seq_len, emb_dim)
        print("tgt_emb", tgt_emb.shape)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)  # tgt.size(1) is seq_len
        print("tgt_mask", tgt_mask.shape)

        # Create padding mask based on EOS token used for padding
        if self.pad_idx is not None:
            padding_mask = (tgt == self.pad_idx)  # (batch_size, seq_len)
            print("padding_mask", padding_mask.shape)
        else:
            padding_mask = None

        memory = torch.zeros_like(tgt_emb)  # Simplified memory initialization
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask)
        return self.generator(output.transpose(0, 1))  # Adjust generator input if necessary

# Positional encoding class adds information about the order of tokens
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
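
While debugging I also tried this standalone shape check. Note that it sets batch_first=True on the decoder layer (my model above does not) and uses boolean masks for both arguments, so it is only a sketch of how I read the docs, not my training code:

import torch
import torch.nn as nn

# Standalone shape check with toy sizes.
# NOTE: batch_first=True here is an assumption that differs from my model above.
batch_size, seq_len, d_model, n_head = 4, 10, 32, 4

layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt_emb = torch.randn(batch_size, seq_len, d_model)   # (N, T, E)
memory = torch.zeros_like(tgt_emb)                    # (N, S, E), S == T here

# Both masks boolean so their dtypes match (avoids the deprecation warning).
tgt_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # (T, T)
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)                  # (N, T)

out = decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask)
print(out.shape)  # torch.Size([4, 10, 32])

This toy version runs without the shape error, which makes me suspect the problem is in how the decoder interprets the batch and sequence dimensions of my input, but I am not sure.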
