What do the attention weights returned by torch_geometric.nn.conv.GATConv represent?

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads, dropout=0.3):
        super(GAT, self).__init__()
        self.dropout = dropout
        # Two multi-head layers (head outputs concatenated), followed by a single-head output layer.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads, dropout=dropout)
        self.conv3 = GATConv(hidden_channels * num_heads, out_channels, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        # With return_attention_weights=True, each GATConv also returns its attention weights.
        x, attention_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x, attention_weights2 = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x, attention_weights3 = self.conv3(x, edge_index, return_attention_weights=True)

        # Note: when using CrossEntropyLoss, the softmax is included in the loss function,
        # so we return the raw logits.
        # out = self.softmax(x)
        return x, attention_weights1, attention_weights2, attention_weights3
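
For context, this is roughly how we call the model (the toy graph and feature sizes below are made up just for illustration; in our experiments we use our own dataset):

# Toy undirected graph: every undirected edge appears in both directions in edge_index.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.randn(3, 8)  # 3 nodes, 8 input features

model = GAT(in_channels=8, hidden_channels=16, out_channels=2, num_heads=4)
model.eval()  # disable dropout so the attention coefficients are not perturbed
out, attention_weights1, attention_weights2, attention_weights3 = model(x, edge_index)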

Using this implementation, we trained a GAT on an undirected graph to perform binary classification. We then extracted the attention weights of the last layer (attention_weights3 in the code). According to the paper this implementation is based on (https://arxiv.org/abs/1710.10903), for every pair of connected nodes (i, j) there should be one attention weight from node i to node j and another from node j to node i. However, for each edge we receive only one attention weight. Have we misunderstood what the attention weights in this implementation represent?
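
This is how we look at what is returned (assuming the second element of the tuple has the (edge_index, alpha) format when return_attention_weights=True):

edge_index_out, alpha = attention_weights3
print(edge_index_out.shape)  # [2, E'] -- directed edges; E' can be larger than the input E if self-loops were added
print(alpha.shape)           # [E', 1] -- one coefficient per directed edge (the last layer uses a single head)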

We tried extracting the attention weights for the last layer, but it gives only one weight per edge. Furthermore, when we select one node and add up the attention weights of all edges connected to that node, the values do not sum to 1 (see the sketch below). If these were the normalized attention coefficients described in the paper, the coefficients associated with a particular node should sum to 1.
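
Roughly, this is the check we ran (num_nodes is the number of nodes in our graph, and we grouped the coefficients by the first row of the returned edge index):

edge_index_out, alpha = attention_weights3
num_nodes = x.size(0)
# For each node, sum the coefficients of the edges attached to it.
per_node_sum = torch.zeros(num_nodes).index_add_(0, edge_index_out[0], alpha.squeeze(-1))
print(per_node_sum)  # we expected 1.0 for every node, but that is not what we get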
