This is what my graph looks like:
cust_prod_graph = Graph(num_nodes={'customer': 8813, 'product': 157466},
      num_edges={('customer', 'browsed', 'product'): 860771, ('customer', 'purchased', 'product'): 68367},
      metagraph=[('customer', 'product', 'browsed'), ('customer', 'product', 'purchased')])
Customer nodes have 932 features, and product nodes have 5641 features.
in_feats_dict = {'product': tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.6545e-03,
5.2197e-04, 9.4348e-04],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -4.3188e-03,
-8.2494e-03, -2.5112e-04],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -3.3462e-04,
-1.2548e-03, 8.8542e-04],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -7.1546e-04,
9.2454e-04, -1.4747e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.1572e-05,
-5.2803e-04, -4.2493e-04],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -1.4136e-02,
-4.1191e-03, 7.5153e-03]], dtype=torch.float64),
'customer': tensor([[ 3., 1., 1., ..., 0., 0., 0.],
[15., 5., 4., ..., 0., 0., 0.],
[90., 14., 11., ..., 0., 0., 0.],
...,
[14., 2., 2., ..., 0., 0., 0.],
[ 3., 2., 2., ..., 0., 0., 0.],
[ 2., 2., 2., ..., 0., 0., 0.]], dtype=torch.float64)}
I am defining the model in the following fashion:
import torch.nn as nn
import dgl.nn as dglnn

hid_feats_customer = 20
hid_feats_product = 10
rel_names = cust_prod_graph.etypes

# One HeteroGraphConv per node type, each holding one GraphConv per relation
x = nn.ModuleDict({
    node_type: dglnn.HeteroGraphConv({
        rel: dglnn.GraphConv(
            in_feats_dict[node_type].shape[1],
            hid_feats_customer if node_type == 'customer' else hid_feats_product,
        ).double()
        for rel in rel_names
    }, aggregate='sum')
    for node_type in in_feats_dict
})
The model, x, looks like this:
ModuleDict(
  (product): HeteroGraphConv(
    (mods): ModuleDict(
      (browsed): GraphConv(in=5641, out=10, normalization=both, activation=None)
      (purchased): GraphConv(in=5641, out=10, normalization=both, activation=None)
    )
  )
  (customer): HeteroGraphConv(
    (mods): ModuleDict(
      (browsed): GraphConv(in=932, out=20, normalization=both, activation=None)
      (purchased): GraphConv(in=932, out=20, normalization=both, activation=None)
    )
  )
)
Now, here is my problem.
When I run
x['customer'](cust_prod_graph, in_feats_dict)
I expect the embeddings of the customer nodes, i.e. an output tensor of shape (num_customers, hid_feats_customer).
However, my output tensor has shape (num_products, hid_feats_product), which is super weird. I understand that message passing aggregates features from a node's neighbours. So does this mean my output holds the embeddings generated for the products? I am thoroughly confused.
What baffles me even more is that when I run
x['product'](cust_prod_graph, in_feats_dict)
I get the following error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8813x932 and 5641x10)
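To sanity-check the shapes involved, here is a shape-only sketch in plain Python (no DGL or PyTorch). It rests on an assumption I have not confirmed: that for each relation (src_type, rel, dst_type) the per-relation GraphConv multiplies the source node features by its (in, out) weight and writes the result onto the destination nodes. The helper `hetero_conv_output_shapes` is hypothetical and only does bookkeeping on tuples.

```python
# Shape-only sketch: bookkeeping on tuples, not a real DGL forward pass.
# ASSUMPTION: for each relation (src_type, rel, dst_type), the per-relation
# GraphConv multiplies the *source* features by an (in_dim, out_dim) weight
# and the aggregated result lands on the *destination* nodes.

NUM_NODES = {"customer": 8813, "product": 157466}
FEAT_DIMS = {"customer": 932, "product": 5641}
CANONICAL_ETYPES = [
    ("customer", "browsed", "product"),
    ("customer", "purchased", "product"),
]


def hetero_conv_output_shapes(in_dim, out_dim):
    """Return {dst_type: (num_dst_nodes, out_dim)}; raise on a shape mismatch."""
    outputs = {}
    for src_type, rel, dst_type in CANONICAL_ETYPES:
        src_shape = (NUM_NODES[src_type], FEAT_DIMS[src_type])
        if src_shape[1] != in_dim:
            raise ValueError(
                f"{rel}: cannot multiply {src_shape[0]}x{src_shape[1]} "
                f"by {in_dim}x{out_dim}"
            )
        # Both relations point at the same destination type, so the
        # summed result keeps this shape.
        outputs[dst_type] = (NUM_NODES[dst_type], out_dim)
    return outputs


# Layer built like x['customer'] (in=932, out=20)
customer_layer_shapes = hetero_conv_output_shapes(in_dim=932, out_dim=20)
print(customer_layer_shapes)

# Layer built like x['product'] (in=5641, out=10)
try:
    hetero_conv_output_shapes(in_dim=5641, out_dim=10)
    product_layer_error = None
except ValueError as err:
    product_layer_error = str(err)
print(product_layer_error)
```

If the assumption holds, the first call can only produce an entry for 'product', because both relations go from customer to product, and the second call fails on the same 8813x932 versus 5641x10 mismatch as the RuntimeError above. The sketch gives 20 columns for the first call, so the second dimension of my actual output is worth double-checking.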
Any sort of explanation/help is super appreciated!