Models with multiple GraphTensor inputs (tfgnn, gnn)

35 Views Asked by At

I am trying to make a model that takes two different GraphTensors as input. Unfortunately, all the documentation is either about single GraphTensors or datasets containing one type of GraphTensor. My datasets are prepared as follows:

# This is how TFRecords are created (each graph type for train or val is written separately)
# This loop writes one serialized example per training index to the file at api_train_path.
with tf.io.TFRecordWriter(api_train_path) as writer:
    for index in range(train_len):
        # Build the pair of graphs for this index from the source dataframes.
        api_graph, sol_graph = GenTwoPairedGraphs(chemicals_df, dataset_df, atoms_df, bonds_df, index)
        # Only api_graph is serialized here; sol_graph is not used in this loop.
        # NOTE(review): presumably an analogous loop writes sol_graph to its own
        # TFRecord file -- confirm against the full script.
        example = tfgnn.write_example(api_graph)
        writer.write(example.SerializeToString())

## It would be best to combine everything here, but I don't know how to decode it later.

# Decode functions
def decode_fn_api(record_bytes):
    """Parse one serialized record into an ``(api_graph, label)`` pair.

    The record is parsed against ``api_graph_tensor_spec`` with validation
    enabled. The ``'label'`` entry is taken out of the context features so
    that the returned graph no longer carries it.

    Args:
        record_bytes: Serialized example to parse.

    Returns:
        A tuple ``(graph_without_label, label)``.
    """
    parsed = tfgnn.parse_single_example(
        api_graph_tensor_spec, record_bytes, validate=True)

    # Work on the features dict returned by the context, strip the label
    # from it, and rebuild the graph from the remaining context features.
    remaining_context = parsed.context.get_features_dict()
    target = remaining_context.pop('label')
    return parsed.replace_features(context=remaining_context), target

def decode_fn_sol(record_bytes):
    """Parse one serialized record into a solvent graph.

    The record is parsed against ``solvent_graph_tensor_spec`` with
    validation enabled; unlike ``decode_fn_api`` no label is extracted.

    Args:
        record_bytes: Serialized example to parse.

    Returns:
        The parsed graph.
    """
    return tfgnn.parse_single_example(
        solvent_graph_tensor_spec, record_bytes, validate=True)

def ImportGraphDataset():
    """Load the four TFRecord datasets (API/solvent x train/val).

    Each file is opened as a ``TFRecordDataset`` and mapped through the
    decoder matching its graph type: ``decode_fn_api`` for the API files
    and ``decode_fn_sol`` for the solvent files.

    Returns:
        A tuple ``(api_train_ds, api_val_ds, sol_train_ds, sol_val_ds)``.
    """
    def _load(path, decoder):
        # One dataset per file, decoded record by record.
        return tf.data.TFRecordDataset([path]).map(decoder)

    api_datasets = tuple(
        _load(path, decode_fn_api)
        for path in (
            "TFRecords/api_train_dataset.tfrecords",
            "TFRecords/api_val_dataset.tfrecords",
        )
    )
    sol_datasets = tuple(
        _load(path, decode_fn_sol)
        for path in (
            "TFRecords/sol_train_dataset.tfrecords",
            "TFRecords/sol_val_dataset.tfrecords",
        )
    )
    return api_datasets + sol_datasets
api_train_ds, api_val_ds, sol_train_ds, sol_val_ds = ImportGraphDataset()

My model is built as follows:

# First branch of the model: takes the API graph as input.

# Symbolic Keras input typed by api_graph_tensor_spec and named "api_graph".
api_input_graph = tf.keras.layers.Input(type_spec=api_graph_tensor_spec, name="api_graph")
# Merge the batched input into components of one graph before the layers below.
api_graph = api_input_graph.merge_batch_to_components()
# (...) is a placeholder for code omitted from this snippet
# (presumably the graph update layers applied to api_graph -- not shown).
(...)
# Mean-pool the "atoms" node set into the graph context.
api_readout_features = tfgnn.keras.layers.Pool(
    tfgnn.CONTEXT, "mean", node_set_name="atoms")(api_graph)

# Second branch of the model: takes the solvent graph as input.

# Symbolic Keras input typed by solvent_graph_tensor_spec and named "solvent_graph".
solvent_input_graph = tf.keras.layers.Input(type_spec=solvent_graph_tensor_spec, name="solvent_graph")
# Same preparation as the API branch.
solvent_graph = solvent_input_graph.merge_batch_to_components() 
# (...) is a placeholder for code omitted from this snippet.
(...)
# Mean-pool the "atoms" node set into the graph context.
solvent_readout_features = tfgnn.keras.layers.Pool(
        tfgnn.CONTEXT, "mean", node_set_name="atoms")(solvent_graph)

# Final layers that calculate output
# Join the two readouts along axis 1, then apply a 32-unit ReLU layer
# followed by a single-unit output layer named "label".
feat = tf.concat([api_readout_features, solvent_readout_features], axis=1)
final_dense = tf.keras.layers.Dense(32, activation="relu")(feat)
logits = tf.keras.layers.Dense(1, name = "label")(final_dense)

# The model takes both graph inputs, in the order [api, solvent].
# NOTE(review): the constructed Model is not bound to a name on this line;
# presumably the full script assigns it (e.g. to `model`) -- confirm.
tf.keras.Model(inputs=[api_input_graph, solvent_input_graph], outputs = [logits])

The problem is that model.fit does not accept more than one dataset as input, and I don't know how to place and decode a TFRecord with more than one type of GraphTensor.

So far I have tried using tf.data.Dataset.zip (input_ds = tf.data.Dataset.zip({"api_graph": api_train_ds, "solvent_graph": sol_train_ds}, label)) with each set stored separately, but I have not been able to run the model this way.

0

There are 0 best solutions below