How to visualize the model graph of a Graph Neural Network in TensorBoard


I am trying to visualize the computation graphs of the Graph Neural Networks I build to predict properties of molecules. The model is written in PyTorch and takes DGL graphs as input. The code snippet I use to try to visualize the model looks like this:

import datetime
import tensorboardX

train_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/train'
train_summary_writer = tensorboardX.SummaryWriter(train_log_dir)
train_summary_writer.add_graph(model, [transformer(dataset[0][0]), transformer(dataset[0][0])])

I get the error below: TensorBoardX fails to visualize the model graph because it refuses to accept DGL graphs as inputs and only accepts tensors. Is there any way I can visualize the model?

RuntimeError: Tracer cannot infer type of (Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'x': Scheme(shape=(10,), dtype=torch.float32)}
      edata_schemes={'w': Scheme(shape=(4,), dtype=torch.float32)}), Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'x': Scheme(shape=(10,), dtype=torch.float32)}
      edata_schemes={'w': Scheme(shape=(4,), dtype=torch.float32)}))
:Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type DGLHeteroGraph.

Process finished with exit code 1

Answer


I usually use SummaryWriter from the torch library (torch.utils.tensorboard). It works roughly like this:

...
from torch.utils.tensorboard import SummaryWriter
...

# initializing your model

model = ...
dummy_input = ...

...
writer = SummaryWriter('logs/net')
# dummy_input: a tensor (or tuple of tensors) shaped like the model's real input
writer.add_graph(model, dummy_input)
writer.close()

Then, after running your Python script, run this in a terminal:

tensorboard --logdir logs

TensorBoard then prints a link, usually http://localhost:6006, where you will find your visualized model graph under the Graphs tab. For more info: https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
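Note that add_graph builds the graph by tracing the model with torch.jit.trace, so dummy_input must itself be tensors (or nested tuples, lists or dicts of tensors). That is exactly the restriction the error in the question reports, so passing a DGL graph as dummy_input will fail the same way. One possible workaround, sketched here under assumptions, is to make the traced forward accept the graph's raw tensors (node features, edge endpoints, edge features) instead of a graph object. TinyMessagePassing, example_inputs and log_graph are hypothetical names for illustration only, and the layer uses plain PyTorch indexing rather than DGL message passing, so it is a stand-in for the asker's model, not a copy of it.

```python
import torch
import torch.nn as nn


class TinyMessagePassing(nn.Module):
    """Hypothetical toy GNN layer whose forward takes only tensors (no DGL graph)."""

    def __init__(self, in_feats: int, edge_feats: int, out_feats: int):
        super().__init__()
        self.node_lin = nn.Linear(in_feats, out_feats)
        self.edge_lin = nn.Linear(edge_feats, out_feats)

    def forward(self, x, src, dst, w):
        # x:   (num_nodes, in_feats) node features (cf. g.ndata['x'])
        # src: (num_edges,) source node ids; dst: (num_edges,) destination node ids
        # w:   (num_edges, edge_feats) edge features (cf. g.edata['w'])
        h = self.node_lin(x)
        # One message per edge: transformed source-node feature plus transformed edge feature.
        messages = h[src] + self.edge_lin(w)
        # Sum the incoming messages at each destination node.
        out = torch.zeros_like(h).index_add(0, dst, messages)
        return torch.relu(out)


def example_inputs():
    # Shapes mirror the graphs in the error message: 3 nodes, 4 edges,
    # 10-dim node features and 4-dim edge features.
    x = torch.randn(3, 10)
    src = torch.tensor([0, 1, 1, 2])
    dst = torch.tensor([1, 0, 2, 1])
    w = torch.randn(4, 4)
    return x, src, dst, w


def log_graph(model, inputs, logdir="logs/net"):
    # Imported here so the rest of the sketch works without the tensorboard package.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(logdir)
    # inputs is a tuple of tensors, which the tracer used by add_graph can handle.
    writer.add_graph(model, inputs)
    writer.close()
```

Calling log_graph(TinyMessagePassing(10, 4, 8), example_inputs()) and then running tensorboard --logdir logs should show the layer in the Graphs tab. For a DGL graph g like the ones in the error, the corresponding inputs would presumably be g.ndata['x'], the (src, dst) pair returned by g.edges(), and g.edata['w']; whether a forward pass that calls DGL functions internally also traces cleanly is not something this sketch checks.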