How can I convert a multi-head attention layer from Tensorflow to Pytorch where key_dim * num_heads != embed_dim?

I am trying to implement a Pytorch version of some code that was previously written in Tensorflow. In the code I am starting with, there exists a multi-head attention layer that is instantiated in the following way:

import tensorflow as tf

encoder_input = tf.keras.layers.Input(shape=(5, 3))
xl = tf.keras.layers.MultiHeadAttention(num_heads=10, key_dim=2)(encoder_input, encoder_input)

Both the input to and the output from this layer have shape (None, 5, 3).

As I began rewriting the code, I realized that this does not translate directly to Pytorch: the embedding dimension here is 3, but nn.MultiheadAttention has no separate key_dim and instead requires embed_dim to be divisible by num_heads, using embed_dim / num_heads as the per-head size. The Tensorflow layer apparently projects the inputs to key_dim * num_heads internally, no matter what the embedding dimension is, and yet its output depth is always the embedding dimension.
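
As a sanity check (this snippet is mine, not taken from the original code), building the layer once and printing its weight shapes seems to bear that out:

import tensorflow as tf

# Build the layer by calling it once, then inspect how it is parameterised.
mha = tf.keras.layers.MultiHeadAttention(num_heads=10, key_dim=2)
x = tf.random.normal((1, 5, 3))
out = mha(x, x)

print(out.shape)  # (1, 5, 3) -- same depth as the input
for w in mha.weights:
    # I expect the query/key/value kernels to map 3 -> (10, 2)
    # and the output kernel to map (10, 2) -> 3.
    print(w.name, w.shape)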

The code I am rewriting is functional, and I don't want to change what it does, at least not until I have a working Pytorch version. So what I really need to know is:

  1. What is the Tensorflow version of the multi-head attention layer doing in this case, where key_dim * num_heads does not match the embedding dimension? (My rough guess at this is sketched just after this list.)

  2. How can I replicate whatever it is doing in pytorch, so that my version of the code will work?
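
To make question 1 concrete, here is my rough guess at what the Tensorflow layer might be doing, written out as a Pytorch module (I have not checked this against the Keras source, and all the names and details below are my own):

import torch
import torch.nn as nn
import torch.nn.functional as F


class GuessedKerasMHA(nn.Module):
    # My guess: project the embed_dim inputs to num_heads * key_dim for Q/K/V,
    # run scaled dot-product attention per head, concatenate the heads, and
    # project back down to embed_dim. None of this is verified against Keras.
    def __init__(self, embed_dim=3, num_heads=10, key_dim=2):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        inner = num_heads * key_dim                    # 20, independent of embed_dim
        self.q_proj = nn.Linear(embed_dim, inner)
        self.k_proj = nn.Linear(embed_dim, inner)
        self.v_proj = nn.Linear(embed_dim, inner)
        self.out_proj = nn.Linear(inner, embed_dim)    # back down to 3

    def forward(self, query, value):
        b, t, _ = query.shape

        def split_heads(x):                            # (b, len, inner) -> (b, heads, len, key_dim)
            return x.view(b, -1, self.num_heads, self.key_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query))
        k = split_heads(self.k_proj(value))
        v = split_heads(self.v_proj(value))
        out = F.scaled_dot_product_attention(q, k, v)  # (b, heads, t, key_dim)
        out = out.transpose(1, 2).reshape(b, t, -1)    # concatenate heads -> (b, t, inner)
        return self.out_proj(out)                      # (b, t, embed_dim)


# Quick shape check: should print torch.Size([1, 5, 3])
x = torch.randn(1, 5, 3)
print(GuessedKerasMHA()(x, x).shape)

If that is roughly what Keras is doing, then question 2 comes down to whether there is a built-in Pytorch equivalent or whether I need to hand-roll something like this.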

My initial idea for a Pytorch version, shown below, was to put a linear layer before the multi-head attention layer to increase the embedding dimension to 20, so that I could keep the same key dimension and number of heads, and then use another linear layer afterwards to reduce it back down. But I'm fairly sure that isn't equivalent to whatever Tensorflow is doing, and I don't know whether it's even a reasonable approach.

import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self, emb_dim, heads, n_feature):
        super().__init__()
        # Project n_feature up to emb_dim=20 so that emb_dim / heads gives key_dim = 2
        self.up = nn.Linear(n_feature, emb_dim)
        self.attn = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=heads)
        # Project back down to the original feature size
        self.down = nn.Linear(emb_dim, n_feature)
        self.feed_forward = nn.Sequential(
            nn.Linear(n_feature, 100),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(100, n_feature),
        )

    def forward(self, x):
        x1 = self.up(x)
        x1, _ = self.attn(x1, x1, x1)  # nn.MultiheadAttention returns (output, attn_weights)
        x1 = self.down(x1)
        res = x1 + x                   # residual connection around the attention block
        out = self.feed_forward(res)
        return out
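
For reference, this is roughly how I picture calling it (the sizes are placeholders, and note that nn.MultiheadAttention defaults to batch_first=False, so the first dimension is treated as the sequence):

model = Network(emb_dim=20, heads=10, n_feature=3)
x = torch.randn(5, 8, 3)   # (seq_len=5, batch=8, n_feature=3)
out = model(x)
print(out.shape)           # torch.Size([5, 8, 3])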