GNN with multi-agent RL in Ray RLlib and PyTorch Geometric

I am currently trying to add a custom_model for a GNN, written in PyTorch, to Ray RLlib. I am using multi-agent reinforcement learning (specifically PPO).

The scenario is inventory management, where I model each node in my supply chain as an agent. I define a connections dictionary, e.g. {0: [1], 1: [2], 2: []}, which denotes the chain 0 --> 1 --> 2.

I then create my network using:

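import numpy as np

def create_network(connections):
    # Highest node index that appears as a parent; nodes are numbered from 0.
    max_node = max(connections.keys())
    # network[parent][child] == 1 marks a directed edge parent -> child.
    network = np.zeros((max_node + 1, max_node + 1))
    for parent, children in connections.items():
        for child in children:
            network[parent][child] = 1
    return network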
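To make the mapping from the connections dictionary to a graph concrete, here is a minimal sketch that builds the adjacency matrix for the example chain and derives a (2, num_edges) edge list in the source/target layout that PyTorch Geometric uses for edge_index. The helper names adjacency_from_connections and edge_index_from_adjacency are hypothetical and not part of my environment; the sketch uses NumPy only.

```python
import numpy as np

def adjacency_from_connections(connections):
    """Hypothetical helper: dense adjacency matrix from a parent -> children dict."""
    # Collect every node that appears either as a parent or as a child.
    nodes = set(connections)
    for children in connections.values():
        nodes.update(children)
    num_nodes = max(nodes) + 1
    adjacency = np.zeros((num_nodes, num_nodes))
    for parent, children in connections.items():
        for child in children:
            adjacency[parent, child] = 1
    return adjacency

def edge_index_from_adjacency(adjacency):
    """Hypothetical helper: row 0 holds edge sources, row 1 holds edge targets."""
    sources, targets = np.nonzero(adjacency)
    return np.stack([sources, targets]).astype(np.int64)

connections = {0: [1], 1: [2], 2: []}
adjacency = adjacency_from_connections(connections)
edge_index = edge_index_from_adjacency(adjacency)
print(adjacency)
print(edge_index)
```

Each column of edge_index is one directed edge, so the chain 0 --> 1 --> 2 yields the columns (0, 1) and (1, 2).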

In my environment's __init__, I build the graph with graph = nx.from_numpy_array(self.network).

I then define another function, graph_from_state (see below), which converts my graph into a PyTorch Geometric data structure. For simplicity, I haven't yet added my obs vector to the node_features dictionary (but that is the plan), and as such I have defined my observation within a graph.

The plan was to use data.x and data.edge_index in my custom GNN model. However, I have realised that I can't do this, because the SampleBatch is built automatically and I am not able to change that. I also tried adding these two variables to my infos dictionary, but SampleBatch automatically converts the dictionary into a tensor.
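One direction I have been considering, sketched here under assumptions, is to carry the graph inside the observation itself instead of the infos dictionary, so that it travels through the SampleBatch like any other observation. For that, every observation needs a fixed shape, which means padding edge_index to a maximum number of edges. The names pack_graph_obs, unpack_graph_obs, MAX_EDGES and PAD below are hypothetical, and the sketch only shows the NumPy packing and unpacking, not the RLlib wiring.

```python
import numpy as np

MAX_EDGES = 4  # assumption: upper bound on the number of edges in any graph
PAD = -1       # assumption: marker for unused edge slots

def pack_graph_obs(x, edge_index, max_edges=MAX_EDGES):
    """Pad edge_index to a fixed width so all observations share one shape."""
    edge_index = np.asarray(edge_index, dtype=np.int64)
    num_edges = edge_index.shape[1]
    if num_edges > max_edges:
        raise ValueError("graph has more edges than max_edges")
    padded = np.full((2, max_edges), PAD, dtype=np.int64)
    padded[:, :num_edges] = edge_index
    return {"x": np.asarray(x, dtype=np.float32), "edge_index": padded}

def unpack_graph_obs(obs):
    """Strip the padding columns again, e.g. before calling the GNN layers."""
    edge_index = obs["edge_index"]
    valid = edge_index[0] != PAD
    return obs["x"], edge_index[:, valid]

# Example: 3 nodes with 4 features each, chain 0 --> 1 --> 2.
x = np.ones((3, 4))
edge_index = np.array([[0, 1], [1, 2]])
obs = pack_graph_obs(x, edge_index)
x_out, edge_index_out = unpack_graph_obs(obs)
```

The environment would then have to declare a matching dictionary observation space with fixed-shape entries for x and edge_index, and the custom model would unpack each row of the batch separately; both of those steps are assumptions that this sketch does not cover.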

When I fix both x and edge_index (just to help me debug), I run into an issue where the SampleBatch has different sizes for vpred_t and rewards. Why would this happen, and how would I fix it? The error was specifically:

delta_t = rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1]
ValueError: operands could not be broadcast together with shapes (32,) (12,)
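To narrow down the shapes involved, here is a small NumPy reproduction. It assumes the failing line belongs to a GAE-style postprocessing step in which vpred_t is the array of value predictions with a single bootstrap value appended; under that assumption, shapes (32,) and (12,) mean there were 32 rewards but only 12 value predictions. The function gae_delta and the variable names are hypothetical, and the guess that 12 comes from 3 nodes times 4 features (a value output per node feature rather than per timestep) is only a hypothesis.

```python
import numpy as np

def gae_delta(rewards, vf_preds, last_r, gamma):
    """Hypothetical stand-in for the failing line, under the stated assumption."""
    # Append one bootstrap value so vpred_t has len(vf_preds) + 1 entries.
    vpred_t = np.concatenate([vf_preds, np.array([last_r])])
    return rewards + gamma * vpred_t[1:] - vpred_t[:-1]

gamma = 0.99
last_r = 0.0
rewards = np.zeros(32)          # one reward per timestep in the rollout

per_step_values = np.zeros(32)  # one value prediction per timestep: shapes agree
delta_ok = gae_delta(rewards, per_step_values, last_r, gamma)

per_node_values = np.zeros(12)  # hypothesis: 3 nodes * 4 features, not per timestep
mismatch_message = None
try:
    gae_delta(rewards, per_node_values, last_r, gamma)
except ValueError as err:
    mismatch_message = str(err)

print(delta_ok.shape)
print(mismatch_message)
```

If this reading is right, the value output of the custom model would need exactly one entry per row of the input batch, independent of the number of nodes in the graph.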

I have tried many things without success. My question is: how do I go about integrating my GNN model with multi-agent training in ray.rllib? Does anyone have any examples?

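import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

def graph_from_state(self):
    # Work on a copy so the environment's base graph is left untouched.
    graph1 = self.graph.copy()

    # Static per-node attributes, keyed by node index.
    node_features = {
        i: {
            "node_price": self.node_price[i],
            "node_cost": self.node_cost[i],
            "order_max": self.order_max[i],
            "inv_max": self.inv_max[i],
        }
        for i in range(self.num_nodes)
    }
    print(node_features)  # debug output

    nx.set_node_attributes(graph1, node_features)

    for node in range(self.num_nodes):
        print(graph1.nodes[node])  # debug output

    self.network = create_network(self.connections)
    print(graph1.nodes())  # debug output
    # from_networkx exposes each node attribute as a tensor on the Data object.
    self.data = from_networkx(graph1)

    # Stack the attribute tensors into one [num_nodes, 4] feature matrix.
    node_feature = torch.stack(
        [self.data.node_price, self.data.node_cost,
         self.data.order_max, self.data.inv_max],
        dim=1,
    )

    self.data = Data(x=node_feature, edge_index=self.data.edge_index)
    self.edge_index = self.data.edge_index
    self.x = self.data.x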