GNN with Stable baselines


I want to use DGL or PyTorch Geometric to build my policy and value networks in Stable Baselines, but I am struggling to figure out how to pass in observations. An observation must belong to one of the gym space classes, and I am not sure how to send a graph object through such a space so that DGL or PyTorch Geometric can use it.

My fundamental question is how to send graph observations, and where to do the preprocessing needed to use DGL or PyTorch Geometric in a custom Stable Baselines network. Can I pack the graph into a Stable Baselines observation space in a form that DGL or PyTorch Geometric can consume?

Note: If anyone has a GitHub link to code that does this, please let me know. I have looked everywhere.

There is 1 best solution below

You can serialize your DGL graph object with pickle and convert the resulting byte string into a vector of numbers, one per byte (each value lies in the range 0 to 255).

import dgl
import numpy as np
import pickle

def serialize_graph(graph: dgl.DGLGraph) -> np.ndarray:
    # Pickle the graph into a byte string.
    as_byte_string = pickle.dumps(graph)
    # Iterating over bytes yields ints in [0, 255], so no explicit cast is needed.
    as_int_list = list(as_byte_string)
    # float32 represents every integer in [0, 255] exactly.
    as_float_array = np.array(as_int_list, dtype=np.float32)
    return as_float_array

You can then apply the same operations in reverse to deserialize the vector representation of the graph inside your custom feature extractor.

import dgl
import pickle
import torch as th

def deserialize_graph(observation: th.Tensor) -> dgl.DGLGraph:
    # Expects a single flat (1-D) observation; cast the floats back to byte values.
    as_int_tensor = observation.to(dtype=th.uint8)
    # tolist() yields plain Python ints, which bytes() accepts directly.
    as_byte_string = bytes(as_int_tensor.tolist())
    as_dgl_graph = pickle.loads(as_byte_string)
    return as_dgl_graph
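
One caveat worth sketching: a gym observation space has a fixed shape, while the length of a pickled graph varies with its size. A possible workaround is to zero-pad the byte vector to a fixed length and declare the space as something like `gym.spaces.Box(low=0, high=255, shape=(MAX_BYTES,), dtype=np.float32)`. The sketch below is only an illustration under assumptions: `MAX_BYTES`, `pad_serialized` and `unpad_deserialize` are hypothetical names, and a plain dict stands in for the DGL graph so the example runs without DGL installed. It relies on CPython's unpickler stopping at the end of the pickled object, so the trailing zero padding is ignored.

```python
import pickle

import numpy as np

MAX_BYTES = 4096  # hypothetical upper bound on the pickled size; tune it for your graphs


def pad_serialized(obj, max_bytes: int = MAX_BYTES) -> np.ndarray:
    """Pickle obj and zero-pad its byte values to a fixed-length float32 vector."""
    raw = pickle.dumps(obj)
    if len(raw) > max_bytes:
        raise ValueError(f"pickled object needs {len(raw)} bytes, limit is {max_bytes}")
    padded = np.zeros(max_bytes, dtype=np.float32)
    # Copy the byte values (0..255) into the front of the fixed-size vector.
    padded[: len(raw)] = np.frombuffer(raw, dtype=np.uint8)
    return padded


def unpad_deserialize(vector: np.ndarray):
    """Recover the object from a padded vector produced by pad_serialized."""
    raw = vector.astype(np.uint8).tobytes()
    # The unpickler stops at the end of the pickled object, so the padding is not read.
    return pickle.loads(raw)


# Stand-in for a graph (assumption: any picklable object, e.g. a dgl.DGLGraph, behaves the same way).
toy_graph = {"num_nodes": 3, "edges": [(0, 1), (1, 2)]}
vec = pad_serialized(toy_graph)
restored = unpad_deserialize(vec)
```

Every observation then has shape `(MAX_BYTES,)` regardless of graph size. Graphs whose pickle exceeds the bound raise an error rather than being silently truncated, since a truncated pickle cannot be loaded.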