How to do RL+graph neural network using stable-baselines3


I am new to stable-baselines3 and am trying to get a toy graph neural network problem to work. I previously had a bit-flipping example using an array. The problem is this: given a list of 10 random bits and an operation that flips a single bit, find a sequence of flips that sets them all to 1. Clearly you can solve this by flipping exactly the bits that are currently 0, but the system has to learn this.
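
Just to illustrate the task (this hand-coded "oracle" is not part of the learning setup, only what the agent should eventually discover):

# Optimal hand-coded policy: flip every bit that is currently 0
bits = [1, 0, 0, 1, 0, 1, 1, 0, 1, 0]
for i, b in enumerate(bits):
    if b == 0:
        bits[i] ^= 1  # one flip action per zero bit
assert all(bits)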

I would like to do the same thing where the input is a simple linear graph with node weights instead of an array, but I am not sure how to do this. The following code snippet builds a linear graph with 10 nodes, adds a random 0/1 weight to each node, and converts it to a DGL graph:

import networkx as nx
import random
import dgl

# Create edges for a linear (path) graph with N nodes
edges = []
N = 10
for i in range(N-1):
    edges.append((i, i+1))

# Create the graph, attach a random 0/1 weight to each node,
# and convert it into a DGL graph
G = nx.DiGraph()
G.add_edges_from(edges)
for i in range(len(G.nodes)):
    G.nodes[i]['weight'] = random.choice([0, 1])
dgl_graph = dgl.from_networkx(G, node_attrs=["weight"])
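
After the conversion, the node weights end up as a tensor on the DGL graph, so they are easy to inspect:

print(dgl_graph.ndata['weight'])  # tensor of length 10 holding the 0/1 node weights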

When I was using a linear array for the bit-flipping example, my environment was this:

import numpy as np
import random
import gym
from gym import spaces
class BitFlipEnv(gym.Env):
    def __init__(self, array_length=10):
        super(BitFlipEnv, self).__init__()
        # Size of the 1D grid
        self.array_length = array_length
        # Initialize the array of bits to be random
        self.agent_pos = random.choices([0, 1], k=array_length)

        # Define action and observation space
        # They must be gym.spaces objects
        # One discrete action per bit position (the bit to flip)
        self.action_space = spaces.Discrete(array_length)
        # The observation is the current bit array,
        # described here as a Box space
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(array_length,), dtype=np.uint8)

    def reset(self):
        # Initialize the array to have random values
        self.time = 0
        print(self.agent_pos)
        self.agent_pos = random.choices([0, 1], k=self.array_length)
        return np.array(self.agent_pos)

    def step(self, action):
        self.time += 1
        if not 0 <= action < self.array_length:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        self.agent_pos[action] ^= 1  # flip the bit
        if self.agent_pos[action] == 1:
            reward = 1
        else:
            reward = -1

        done = all(self.agent_pos)

        info = {}

        return np.array(self.agent_pos), reward, done, info

    def render(self, mode='console'):
        print(self.agent_pos)

    def close(self):
        pass
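
Before training I sanity-check the environment by hand with random actions (no learning involved, just to confirm the step/reset logic):

env = BitFlipEnv(array_length=10)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()   # pick a random bit to flip
    obs, reward, done, info = env.step(action)
env.render()  # should print an array of all 1s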

The last few lines to complete the code in the array version are simply:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env(lambda: BitFlipEnv(array_length=50), n_envs=12)
# Train the agent
model = PPO('MlpPolicy', env,  verbose=1).learn(500000)
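
After training, I roll out the learned policy on a fresh environment like this (the deterministic flag is optional):

test_env = BitFlipEnv(array_length=50)
obs = test_env.reset()
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = test_env.step(action)
test_env.render()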

I can't use stable-baselines3's spaces any more for the graph, so what is the right way to get stable-baselines3 to interface with my DGL graph for this toy problem?
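
The only fallback I can see (a rough, untested sketch, reusing the imports from the snippets above plus torch) is to keep the DGL graph entirely inside the environment and expose just the node weights as a flat Box observation, but that throws away the graph structure, which is exactly what I want the policy to use:

import torch

class GraphFlipEnv(gym.Env):
    def __init__(self, n_nodes=10):
        super().__init__()
        self.n_nodes = n_nodes
        self.action_space = spaces.Discrete(n_nodes)
        # Observation is only the flat array of node weights, not the graph itself
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(n_nodes,), dtype=np.uint8)
        self.reset()

    def _build_graph(self):
        # Linear (path) graph with the current weights attached to each node
        g = nx.path_graph(self.n_nodes, create_using=nx.DiGraph)
        for i in range(self.n_nodes):
            g.nodes[i]['weight'] = self.weights[i]
        return dgl.from_networkx(g, node_attrs=["weight"])

    def reset(self):
        self.weights = random.choices([0, 1], k=self.n_nodes)
        self.graph = self._build_graph()
        return np.array(self.weights, dtype=np.uint8)

    def step(self, action):
        self.weights[action] ^= 1  # flip the chosen node's weight
        self.graph.ndata['weight'] = torch.tensor(self.weights)  # keep the DGL graph in sync
        reward = 1 if self.weights[action] == 1 else -1
        done = all(self.weights)
        return np.array(self.weights, dtype=np.uint8), reward, done, {}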
