I am implementing a modified vision transformer based on the GitHub implementation (the author has also published a YouTube video explaining it). However, this implementation has no provision for `src_key_padding_mask`, which the built-in transformer encoder accepts as a parameter. I understand that I have to perform some mathematical operation with this mask in the `forward` method of the `Attention` module. The mask contains `True` where there is a padding token and `False` elsewhere.
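
For reference, this is roughly how such a mask is built and passed to the built-in encoder; the sizes and padding layout below are toy values chosen purely for illustration:

```python
import torch
import torch.nn as nn

# Toy sizes, chosen only to illustrate the mask's shape and meaning.
n_samples, seq_len, d_model = 2, 5, 16

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)

src = torch.randn(n_samples, seq_len, d_model)

# (n_samples, seq_len): True marks padding positions, False marks real tokens.
src_key_padding_mask = torch.tensor([
    [False, False, False, True,  True],
    [False, False, False, False, True],
])

out = encoder(src, src_key_padding_mask=src_key_padding_mask)
# out: (n_samples, seq_len, d_model)
```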
If I simply insert `dp = dp @ mask` right after the dot product of query and key, will it serve the same purpose as the `src_key_padding_mask` used in the built-in version?
dp = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
dp = dp @ mask  # proposed masking step (the line in question)
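
For context, here is a minimal, self-contained sketch of the kind of `Attention.forward` I mean. The layer names (`qkv`, `proj`), the head count, and the exact reshaping are my own assumptions about a typical ViT attention block, not necessarily the author's exact code; `q`, `k_t`, `self.scale`, and `dp` are the names from the snippet above, and the commented line marks where the proposed change would go:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Minimal sketch of a ViT-style multi-head self-attention block."""

    def __init__(self, dim, n_heads=12):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        # x:    (n_samples, n_patches + 1, dim)
        # mask: (n_samples, n_patches + 1), True at padding positions
        n_samples, n_tokens, dim = x.shape

        qkv = self.qkv(x)  # (n_samples, n_tokens, 3 * dim)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, n_samples, n_heads, n_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_tokens)

        dp = (q @ k_t) * self.scale  # (n_samples, n_heads, n_tokens, n_tokens)
        # Proposed change would go here:
        # dp = dp @ mask

        attn = dp.softmax(dim=-1)  # attention weights over key positions
        weighted_avg = (attn @ v).transpose(1, 2).reshape(n_samples, n_tokens, dim)
        return self.proj(weighted_avg)
```

With an input of `n_patches + 1` tokens (the patches plus the class token), `mask` would then be a boolean tensor of shape `(n_samples, n_patches + 1)`.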