GNN adaptability to a varying number of nodes in the input graph


I am looking at methods for a GNN to process an entity graph generated from objects detected in an image. Each node represents an object, and an edge connects two nodes if the objects' Euclidean distance is below a threshold. Say the graph is represented as G=(V,E). For every image there may be a different number of detected objects and a different number of connections. How can a GNN learn to process a varying number of nodes (each with a varying number of neighbours) as input, and at inference time how can the trained model be applied to such graphs? What would the weight shapes look like?
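For concreteness, this is roughly how each graph is built from the detections (a sketch; the function name, tensor names and threshold value are just illustrative):

import torch

def build_graph(centroids, features, threshold=50.0):
    # centroids: (N, 2) object centers in pixels, features: (N, F) per-object features.
    # Pairwise Euclidean distances between all detected objects.
    dist = torch.cdist(centroids, centroids)                               # (N, N)
    # Connect objects closer than the threshold, excluding self-loops.
    adj = (dist < threshold) & ~torch.eye(len(centroids), dtype=torch.bool)
    edge_index = adj.nonzero(as_tuple=False).t().contiguous()              # (2, num_edges)
    return features, edge_index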

I have thought about ways to pool or limit the number of nodes in a graph, but that does not seem like a smart way to do it. I am really struggling with this at the moment and would appreciate any help.

Thanks.


There is 1 answer below.


What are you trying to predict? There are different approaches for node-level classification and graph-level classification. Have you checked PyTorch Geometric? https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html#torch_geometric.nn.conv.GCNConv It is fine to have a varying number of nodes, especially for graph-level prediction: the layer weights are shaped by the feature dimensions only, not by the number of nodes, so the same model applies to graphs of any size. Edge weights and attributes are also handled by the PyG layers; you pass them as arguments to the convolution.
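For instance, here is a minimal sketch of passing per-edge weights to GCNConv (the sizes are made up for illustration):

import torch
from torch_geometric.nn import GCNConv

x = torch.randn(5, 16)                                    # 5 detected objects, 16 features each
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])   # COO connectivity, any number of edges
edge_weight = torch.rand(edge_index.size(1))              # one scalar weight per edge

conv = GCNConv(16, 32)                  # weights depend only on 16 -> 32, not on the node count
out = conv(x, edge_index, edge_weight)  # (5, 32) node embeddings, works for any number of nodes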

Here is an example of a fuller architecture using attention (GATConv) that also consumes edge attributes:


import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GCNAttentionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate, edge_dim=None):
        super().__init__()
        # The learnable weights are shaped by the feature dimensions only
        # (input_dim -> hidden_dim -> output_dim), so the same model handles
        # graphs with any number of nodes and edges.
        # edge_dim tells GATConv the dimensionality of the per-edge attributes.
        self.conv1 = GATConv(input_dim, hidden_dim, edge_dim=edge_dim)
        self.conv2 = GATConv(hidden_dim, output_dim, edge_dim=edge_dim)
        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, edge_attr, batch):
        # x: (num_nodes, input_dim), edge_index: (2, num_edges),
        # edge_attr: (num_edges, edge_dim); batch maps each node to its graph
        # and is used for graph-level pooling outside this module.
        x = self.conv1(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        return x
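And here is a sketch of how the model above could be run on graphs of different sizes, assuming PyTorch Geometric 2.x (the tensors and sizes are made up for illustration). The DataLoader concatenates the graphs into one big disconnected graph and provides a batch vector, and global_mean_pool turns the node outputs into one vector per graph for graph-level prediction:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

# Two hypothetical graphs with different numbers of nodes and edges.
g1 = Data(x=torch.randn(4, 16),
          edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
          edge_attr=torch.rand(3, 1))
g2 = Data(x=torch.randn(7, 16),
          edge_index=torch.tensor([[0, 3, 5, 6], [1, 4, 6, 5]]),
          edge_attr=torch.rand(4, 1))

loader = DataLoader([g1, g2], batch_size=2)
model = GCNAttentionModel(input_dim=16, hidden_dim=32, output_dim=8,
                          dropout_rate=0.5, edge_dim=1)

for batch in loader:
    node_out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)  # (11, 8)
    graph_out = global_mean_pool(node_out, batch.batch)                        # (2, 8), one row per graph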