I am learning the basic ideas behind the Transformer model. According to the paper and a tutorial I followed, the attention layer uses small neural networks (linear layers) to compute the 'value', the 'key', and the 'query'.
Here is the attention layer I learned from an online tutorial:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to queries, keys and values with learned
    ``embed_size x embed_size`` linear layers and split into ``heads``
    sub-spaces of size ``head_dim``. Each head attends independently, and the
    per-head outputs are concatenated and mixed by ``fc_out``.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        # Project the *full* embedding before splitting into heads. Each head
        # then gets its own slice of the projection weights; projecting after
        # the split would force every head to share one head_dim x head_dim
        # matrix.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: Tensor reshaped below to (N, value_len, heads, head_dim),
                so its last dimension must be ``embed_size``.
            keys: Same layout as ``values`` with length key_len; the einsum
                below requires key_len == value_len.
            query: Same layout with length query_len.
            mask: Optional tensor broadcastable to (N, heads, query_len,
                key_len); positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into self.heads pieces:
        # (N, len, embed_size) -> (N, len, heads, head_dim)
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # Scale by sqrt(d_k), where d_k is the per-head key size (head_dim),
        # not the full embedding size.
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # finfo(...).min is representable in every float dtype (a literal
            # like -1e20 overflows fp16) and still drives softmax weights to 0.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # attention: (N, heads, query_len, key_len), values: (N, value_len, heads, head_dim)
        # -> (N, query_len, heads, head_dim), then flatten the last two dims.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)
One thing I am very confused about is why we need the 'key', 'value', and 'query' at all. Could I use only one of them? Or could I add more projections besides these three? It looks like all three are just the same input transformed by three separate single-layer neural networks (linear projections).