How to implement src_key_padding_mask in a vision transformer

I am implementing a modified vision transformer based on the GitHub implementation. The author has also published a YouTube video explaining it. However, this implementation doesn't have any provision for incorporating src_key_padding_mask (the built-in transformer encoder accepts this as a parameter). I know that I have to perform some mathematical operation with this mask in the forward method of the Attention module. The mask contains True where there is a padding token and False elsewhere.
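
For reference, this is roughly how the mask is passed to the built-in encoder (a minimal sketch; the dimensions, layer sizes, and dummy data below are placeholders, not taken from the repository):

import torch
import torch.nn as nn

# Minimal usage sketch of the built-in encoder with a key padding mask.
encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

src = torch.randn(2, 10, 64)                   # (n_samples, seq_len, embed_dim)
mask = torch.zeros(2, 10, dtype=torch.bool)    # False = real token
mask[:, 7:] = True                             # True = padding token, ignored as a key

out = encoder(src, src_key_padding_mask=mask)  # (n_samples, seq_len, embed_dim)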

If I simply place dp = dp @ mask right after the dot product of the query and key, will it serve the purpose of src_key_padding_mask in the built-in version?

dp = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
dp = dp @ mask
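
For comparison, a common way to apply this kind of mask is not a matrix product but a fill: the scores of the masked key positions are set to -inf before the softmax so they end up with zero attention weight, which is effectively what the built-in src_key_padding_mask does. A minimal sketch in terms of the variables above (the mask shape (n_samples, n_patches + 1), with True at padding positions, is my assumption):

# Sketch only, not the repository's code.
# mask: (n_samples, n_patches + 1), True where the token is padding (assumed shape).
dp = (q @ k_t) * self.scale                                 # (n_samples, n_heads, n_patches + 1, n_patches + 1)
dp = dp.masked_fill(mask[:, None, None, :], float("-inf"))  # masked keys are excluded
attn = dp.softmax(dim=-1)                                   # padded positions get zero weight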