This is what my graph looks like:

cust_prod_graph = Graph(num_nodes={'customer': 8813, 'product': 157466},
      num_edges={('customer', 'browsed', 'product'): 860771, ('customer', 'purchased', 'product'): 68367},
      metagraph=[('customer', 'product', 'browsed'), ('customer', 'product', 'purchased')])

Customer nodes have 932 features, and product nodes have 5641 features. 

in_feats_dict = {'product': tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.6545e-03,
           5.2197e-04,  9.4348e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -4.3188e-03,
          -8.2494e-03, -2.5112e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.3462e-04,
          -1.2548e-03,  8.8542e-04],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.1546e-04,
           9.2454e-04, -1.4747e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -2.1572e-05,
          -5.2803e-04, -4.2493e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.4136e-02,
          -4.1191e-03,  7.5153e-03]], dtype=torch.float64),
 'customer': tensor([[ 3.,  1.,  1.,  ...,  0.,  0.,  0.],
         [15.,  5.,  4.,  ...,  0.,  0.,  0.],
         [90., 14., 11.,  ...,  0.,  0.,  0.],
         ...,
         [14.,  2.,  2.,  ...,  0.,  0.,  0.],
         [ 3.,  2.,  2.,  ...,  0.,  0.,  0.],
         [ 2.,  2.,  2.,  ...,  0.,  0.,  0.]], dtype=torch.float64)}

I am defining the model in the following fashion:

import torch.nn as nn
import dgl.nn as dglnn

hid_feats_customer=20
hid_feats_product=10
rel_names = cust_prod_graph.etypes

# One HeteroGraphConv per node type, each holding one GraphConv per relation
x = nn.ModuleDict({
    node_type: dglnn.HeteroGraphConv({
        rel: dglnn.GraphConv(in_feats_dict[node_type].shape[1], hid_feats_customer if node_type == 'customer' else hid_feats_product).double()
        for rel in rel_names
    }, aggregate='sum')
    for node_type in in_feats_dict
})

The model, x, looks like this:

ModuleDict(
  (product): HeteroGraphConv(
    (mods): ModuleDict(
      (browsed): GraphConv(in=5641, out=10, normalization=both, activation=None)
      (purchased): GraphConv(in=5641, out=10, normalization=both, activation=None)
    )
  )
  (customer): HeteroGraphConv(
    (mods): ModuleDict(
      (browsed): GraphConv(in=932, out=20, normalization=both, activation=None)
      (purchased): GraphConv(in=932, out=20, normalization=both, activation=None)
    )
  )
)

Now, here is my problem.

When I run

x['customer'](cust_prod_graph, in_feats_dict)

I expect the embeddings of the customer nodes, i.e. an output tensor of shape (num_customers, hid_feats_customer).

However, my output tensor is of the shape (num_products, hid_features_products), which is super weird. I understand that message passing involves aggregating features from the neighbours. So, does this mean my output represents the embeddings generated for the products? I am thoroughly confused.
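For reference, the routing that `HeteroGraphConv` performs can be traced without DGL. The plain-Python sketch below assumes the documented behaviour: for each canonical edge type `(src, rel, dst)`, that relation's submodule reads source-node features and writes its result onto the destination nodes, and results for the same destination type are combined with `aggregate`. The helper `trace_hetero_conv` is hypothetical and only tracks shapes; the edge types are taken from the `num_edges` keys above.

```python
# Hypothetical shape trace of dglnn.HeteroGraphConv (not DGL code).
# Assumption: each relation (src, rel, dst) maps src-node features onto dst nodes,
# and per-relation outputs for the same dst type are aggregated (here: 'sum').

NUM_NODES = {"customer": 8813, "product": 157466}
CANONICAL_ETYPES = [
    ("customer", "browsed", "product"),
    ("customer", "purchased", "product"),
]


def trace_hetero_conv(out_feats, canonical_etypes=CANONICAL_ETYPES, num_nodes=NUM_NODES):
    """Return {dst_type: (rows, cols)} that a HeteroGraphConv would produce."""
    outputs = {}
    for _src, _rel, dst in canonical_etypes:
        # The per-relation GraphConv emits one row per *destination* node.
        shape = (num_nodes[dst], out_feats)
        # Summing same-shaped outputs keeps the shape: one entry per dst type.
        outputs[dst] = shape
    return outputs


# x['customer'] was built with out_feats = hid_feats_customer = 20.
print(trace_hetero_conv(out_feats=20))  # {'product': (157466, 20)}
```

Under these assumptions the result has no `'customer'` key at all, because no edge type in this metagraph ends at `customer`; the rows belong to product nodes, and the width is whatever `out_feats` the module was built with.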

What baffles me even more is that when I run

x['product'](cust_prod_graph, in_feats_dict)

I get the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8813x932 and 5641x10)
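The numbers in that error line up with a plain matrix product. The sketch below assumes, as the message suggests, that each per-relation GraphConv multiplies the source-node feature matrix by its weight of shape (in_feats, out_feats). Both relations here have `customer` as the source type, so every module receives the 8813 x 932 customer matrix, while the modules in `x['product']` carry a 5641 x 10 weight. The helper `check_matmul` is hypothetical and only compares shapes.

```python
# Hypothetical shape check mirroring the matmul inside a per-relation GraphConv.
# Assumption: the layer computes feat_src @ weight, weight shape = (in_feats, out_feats).


def check_matmul(mat1_shape, mat2_shape):
    """Return the shape of mat1 @ mat2, or raise with a torch-style message on mismatch."""
    (n, k1), (k2, m) = mat1_shape, mat2_shape
    if k1 != k2:
        raise RuntimeError(
            f"mat1 and mat2 shapes cannot be multiplied ({n}x{k1} and {k2}x{m})"
        )
    return (n, m)


customer_feats = (8813, 932)        # source features for both relations
customer_module_weight = (932, 20)  # GraphConv(in=932, out=20) in x['customer']
product_module_weight = (5641, 10)  # GraphConv(in=5641, out=10) in x['product']

print(check_matmul(customer_feats, customer_module_weight))  # (8813, 20)
try:
    check_matmul(customer_feats, product_module_weight)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (8813x932 and 5641x10)
```

In this trace the customer-sized module succeeds only because its `in_feats` happens to match the source type (`customer`) of every relation, not because it produces customer embeddings.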

Any sort of explanation/help is super appreciated!
