NaN output after masked TransformerDecoder

What is wrong with my mask, or why doesn't it work? I am trying to predict a masked token based only on the first token and the masked token itself. For this I created a special create_causal_mask method that builds this mask for multi-head attention. When I run it, a NaN tensor is returned. The pad mask and the mask of masked tokens do not intersect, and I fill my attention mask as described in the PyTorch documentation for torch.nn.TransformerDecoder. Also, the attention output shouldn't be empty, because there are some False values in the attention mask. So why doesn't it work?
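As far as I understand, if a whole row of attention scores ends up fully masked, softmax returns NaN for that row (a minimal standalone check below, assuming the boolean mask is applied as -inf scores before the softmax), but I expect that the rows I actually read, the first token and the masked token, always have at least one allowed position:

import torch

# softmax over a row where every score is -inf, i.e. every key is masked out
scores = torch.full((1, 8), float("-inf"))
print(torch.softmax(scores, dim=-1))  # prints a row of nan values

Here is my full code: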

import torch
from torch import nn

torch.manual_seed(0)


class LookOnFirstDecoder(nn.Module):
    def __init__(self, depth, d_model, nhead, d_ff,
                 dropout, activation,
                sent_length, n_tokens, pad_idx
    ):
        super().__init__()
        """
        :param sent_length: max length of a sentence
        :param n_tokens: vocabulary size, including the mask and padding tokens
        :param pad_idx: index of the padding token (no gradient is computed for its embedding)
        """
        self.d_model = d_model
        self.nhead = nhead
        self.n_tokens = n_tokens
        self.emb = nn.Embedding(
            num_embeddings=n_tokens,
            embedding_dim=d_model,
            padding_idx=pad_idx
        )

        self.pos_embed = nn.Parameter(
            torch.zeros(1, sent_length, d_model),
            requires_grad=True
        )
        torch.nn.init.normal_(self.pos_embed, std=.02)

        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_ff,
                dropout=dropout,
                activation=activation,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=depth,
        )

        self.fin_lin = nn.Linear(d_model, n_tokens)

    def create_causal_mask(self, mask):
        """
        Create a mask that allows every non-first token to attend
        only to the first token and to itself.
        :param mask: (B, L)
        :return: (B * nhead, L, L)
        """

        mask[:, 0] = True  # treat the first token as "masked" too, so every masked token is allowed to attend to it
        b, l = mask.shape
        batch_causal_mask = ~torch.tril(mask.unsqueeze(-1) * mask.unsqueeze(-2))  # (B, L, L)
        # batch_causal_mask = torch.tril(torch.ones((b, l, l))).to("cuda") == 0

        # batch_causal_mask = torch.where(batch_causal_mask, 0, float('-inf'))
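        # Boolean mask convention in nn.TransformerDecoder: True marks a (query, key)
        # pair that is NOT allowed to attend, so a row that is entirely True leaves
        # that query with no attendable key at all.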
        print(f"Batch causal mask: \n{batch_causal_mask}")

        causal_mask = (
            batch_causal_mask.
            unsqueeze(1).  # (B, 1, L, L)
            expand(b, self.nhead, l, l).  # (B, nhead, L, L)
            reshape(b * self.nhead, l, l)  # (B * nhead, L, L)
        )

        return causal_mask

    def forward(self, tgt, memory, is_masked_mask, is_pad_mask):
        """
        :param tgt: (B, L)
        :param memory: (B, L1, D)
        :param is_masked_mask: (B, L), True where the token is the mask token
        :param is_pad_mask: (B, L), True where the token is a padding token
        :return: tensor of shape (number of masked tokens, n_tokens)
        """
        b, l = tgt.shape
        tgt_tokens = self.emb(tgt) + self.pos_embed[:, :l].expand(b, l, self.d_model)

        tgt_tokens = self.transformer(
            tgt_tokens,
            memory,
            tgt_mask=self.create_causal_mask(is_masked_mask.clone()),
            tgt_is_causal=True,
            tgt_key_padding_mask=is_pad_mask
        )  # (B, L, D)

        fin_tokens = self.fin_lin(tgt_tokens[is_masked_mask])
        return fin_tokens


# my vocabulary
n_tokens = 10  # pad_idx - 9, mask_idx - 8
pad_idx = n_tokens - 1
mask_idx = n_tokens - 2

d_model = 4
nhead = 2
b, l = 3, 8

model = LookOnFirstDecoder(
    depth=2,
    d_model=4,
    nhead=2,
    d_ff=8,
    dropout=0.1,
    activation="gelu",
    sent_length=l,
    n_tokens=n_tokens,
    pad_idx=pad_idx
)

memory = torch.randn(b, l, d_model)

# create some random tokens, without padding or mask tokens
in_tokens = torch.randint(0, mask_idx - 1, (b, l))

# add mask and padding tokens manually
in_tokens[0, 6:] = pad_idx
in_tokens[0, 5] = mask_idx

in_tokens[1, 7:] = pad_idx
in_tokens[1, 4] = mask_idx

in_tokens[2, 5:] = pad_idx
in_tokens[2, 0] = mask_idx

is_masked_mask = in_tokens == mask_idx
is_pad_mask = in_tokens == pad_idx

pred = model(in_tokens, memory, is_masked_mask, is_pad_mask)

print(f"In tokens: \n{in_tokens}")
print(f"Pad mask: \n{is_pad_mask}")
print(f"Masked mask: \n{is_masked_mask}")
print(f"Pred: \n{pred}")

This is my requirements.txt:

torch == 2.1.1
torchvision == 0.16.1
xformers
albumentations==1.3.1

numpy == 1.26.2
scipy == 1.11.4
scikit-learn == 1.3.2
pandas == 2.1.4
matplotlib == 3.8.2
seaborn == 0.13.0

This is the result after execution:

Batch causal mask: 
tensor([[[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True, False,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]]])
In tokens: 
tensor([[2, 1, 4, 0, 3, 8, 9, 9],
        [6, 4, 0, 6, 8, 0, 5, 9],
        [8, 2, 5, 2, 6, 9, 9, 9]])
Pad mask: 
tensor([[False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False,  True,  True,  True]])
Masked mask: 
tensor([[False, False, False, False, False,  True, False, False],
        [False, False, False, False,  True, False, False, False],
        [ True, False, False, False, False, False, False, False]])
Pred: 
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       grad_fn=<AddmmBackward0>)