I have implemented a transformer decoder for next-token prediction. I pass tgt_mask and tgt_key_padding_mask so that future tokens are not attended to and padding positions are ignored, but I keep getting the following error:
Training Epoch 1/1: 0%| | 0/563 [00:00<?, ?it/s]
tgt_emb torch.Size([16, 875, 256])
tgt_mask torch.Size([875, 875])
padding_mask torch.Size([16, 875])
/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/functional.py:5076: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
warnings.warn(
Training Epoch 1/1: 0%| | 0/563 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/scratch/harsha.vasamsetti/decoder_aug_made/main.py", line 146, in <module>
output = model(input_batch)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/harsha.vasamsetti/decoder_aug_made/transformer.py", line 47, in forward
output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 460, in forward
output = mod(output, memory, tgt_mask=tgt_mask,
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 846, in forward
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 855, in _sa_block
x = self.self_attn(x, x, x,
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1241, in forward
attn_output, attn_output_weights = F.multi_head_attention_forward(
File "/home2/harsha.vasamsetti/miniconda3/envs/slices/lib/python3.9/site-packages/torch/nn/functional.py", line 5318, in multi_head_attention_forward
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([875, 875]), but should be (16, 16).
I read the PyTorch documentation and understood that tgt_mask should have shape (T, T), where T is the sequence length, and that tgt_key_padding_mask should have shape (batch_size, T). As you can see from the shapes printed above, those are exactly the shapes I am passing, yet I still get the error, and I don't know where I am going wrong.
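To double-check my reading of the docs, I reproduced the mask shapes in isolation. This is just a shape sanity check, not my training code: the sizes and pad id below are placeholders, and the causal mask is built differently from my generate_square_subsequent_mask, but the shapes come out the same.

import torch

batch_size, seq_len = 16, 875
pad_idx = 0                                           # placeholder padding id
tgt = torch.randint(1, 100, (batch_size, seq_len))    # dummy batch-first token ids

# Causal mask: (T, T), -inf above the diagonal so future positions are masked out
tgt_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
# Padding mask: (batch_size, T), True where a position is padding
padding_mask = (tgt == pad_idx)

print(tgt_mask.shape)        # torch.Size([875, 875])
print(padding_mask.shape)    # torch.Size([16, 875])

Here is the model code (transformer.py):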
import torch.nn as nn
import torch
import math


# Define the Transformer model class
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, pad_idx, n_embd, n_head, n_layers, max_length, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx  # token id used for padding
        self.embed = nn.Embedding(vocab_size, n_embd)
        decoder_layer = nn.TransformerDecoderLayer(d_model=n_embd, nhead=n_head, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.pos_encoder = PositionalEncoding(n_embd, dropout, max_length)
        self.n_embd = n_embd
        self.generator = nn.Linear(n_embd, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: 0.0 on and below the diagonal, -inf above it
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt):
        tgt_emb = self.pos_encoder(self.embed(tgt) * math.sqrt(self.n_embd))  # (batch_size, seq_len, emb_dim)
        print("tgt_emb", tgt_emb.shape)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)  # (seq_len, seq_len)
        print("tgt_mask", tgt_mask.shape)
        # Create padding mask based on the EOS token used for padding
        if self.pad_idx is not None:
            padding_mask = (tgt == self.pad_idx)  # (batch_size, seq_len)
            print("padding_mask", padding_mask.shape)
        else:
            padding_mask = None
        memory = torch.zeros_like(tgt_emb)  # simplified memory initialization (decoder-only setup)
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=padding_mask)
        return self.generator(output.transpose(0, 1))  # project decoder output to vocabulary logits


# Positional encoding adds information about the order of tokens
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
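For completeness, this is roughly how the model is built and called from main.py. Only the embedding size (256), batch size (16), and sequence length (875) are taken from the shapes printed above; the vocabulary size, pad id, head count, layer count, and max length below are placeholders:

import torch
from transformer import TransformerModel  # the class shown above

model = TransformerModel(
    vocab_size=1000,   # placeholder
    pad_idx=0,         # placeholder
    n_embd=256,        # matches the printed embedding size
    n_head=8,          # placeholder
    n_layers=4,        # placeholder
    max_length=1024,   # placeholder, larger than the 875-token sequences
)

input_batch = torch.randint(1, 1000, (16, 875))  # (batch_size, seq_len)
output = model(input_batch)                      # the call where the RuntimeError above is raised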