Graph tensor
Import modules
[1]:
from molgraph import chemistry
from molgraph import layers
from molgraph import GraphTensor #####
import tensorflow as tf
from tensorflow import keras
Construct a GraphTensor
Although a GraphTensor can be created directly via its constructor, here we obtain one by passing a list of SMILES strings to a MolecularGraphEncoder.
[2]:
# Per-atom descriptors. Together they yield the 11 node-feature columns seen
# in the printed GraphTensor (presumably 3 symbols + 1 OOV, 3 hybridizations
# + 1 OOV, and three single-column flags -- confirm against molgraph docs).
atom_features = [
    chemistry.features.Symbol({'C', 'N', 'O'}, oov_size=1),
    chemistry.features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),
    chemistry.features.HydrogenDonor(),
    chemistry.features.HydrogenAcceptor(),
    chemistry.features.Hetero(),
]
atom_encoder = chemistry.Featurizer(atom_features)

# Per-bond descriptors (5 edge-feature columns in the printed GraphTensor).
bond_features = [
    chemistry.features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),
    chemistry.features.Rotatable(),
]
bond_encoder = chemistry.Featurizer(bond_features)

# The encoder combines both featurizers; positional encoding is switched off.
mol_encoder = chemistry.MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=None,
)

# Three molecules, given as SMILES strings.
smiles_list = [
    'OCC1OC(C(C1O)O)n1cnc2c1ncnc2N',
    'C(C(=O)O)N',
    'C1=CC(=CC=C1CC(C(=O)O)N)O',
]

# Encode every molecule into one GraphTensor and show its fields.
graph_tensor = mol_encoder(smiles_list)
print(graph_tensor)
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)
.separate() – Separate subgraphs of GraphTensor
[3]:
# Split the GraphTensor into its subgraphs. In the printed result the node and
# edge fields are tf.RaggedTensors with one row per subgraph.
graph_tensor = graph_tensor.separate()
print(graph_tensor)
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32, ragged_rank=1>,
edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
edge_feature=<tf.RaggedTensor: shape=(3, None, 5), dtype=float32, ragged_rank=1>)
.merge() – Merge subgraphs of GraphTensor
[4]:
# Merge the subgraphs back together. In the printed result the node and edge
# fields are plain tf.Tensors again.
graph_tensor = graph_tensor.merge()
print(graph_tensor)
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)
.propagate() – Propagate node information with the GraphTensor
[5]:
# Print the node features, propagate once, then print them again.
# NOTE(review): judging by the printed values, each node's row becomes the sum
# of its neighbours' rows -- confirm against the molgraph documentation.
print('Node features before:\n', graph_tensor.node_feature, end='\n\n')
graph_tensor = graph_tensor.propagate()
print('Node features after:\n', graph_tensor.node_feature)
Node features before:
tf.Tensor(
[[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
[0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]
[0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
[0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]
[0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
[0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1.]], shape=(37, 11), dtype=float32)
Node features after:
tf.Tensor(
[[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 1. 0. 0. 0. 2. 1. 1. 1.]
[0. 2. 0. 1. 0. 0. 0. 3. 0. 1. 1.]
[0. 2. 0. 0. 0. 0. 0. 2. 0. 0. 0.]
[0. 1. 1. 1. 0. 0. 1. 2. 0. 2. 2.]
[0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]
[0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]
[0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 2. 1. 0. 0. 0. 3. 0. 0. 1. 1.]
[0. 1. 2. 0. 0. 0. 3. 0. 0. 2. 2.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 1. 2. 0. 0. 0. 3. 0. 1. 2. 2.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]
[0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 2. 0. 1. 0. 0. 3. 0. 1. 1. 1.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
[0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]
[0. 2. 0. 0. 0. 0. 1. 1. 0. 0. 0.]
[0. 2. 1. 0. 0. 0. 1. 2. 1. 1. 1.]
[0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]], shape=(37, 11), dtype=float32)
.update() – Update data of the GraphTensor
[6]:
# All dimensions of `node_feature` except the last one; both random tensors
# below share them so that they line up with the existing nodes.
node_dims = graph_tensor.node_feature.shape[:-1]

node_supplementary_data = tf.random.uniform(shape=node_dims + [1])
node_feature_updated = tf.random.uniform(shape=node_dims + [128])

# A single call both adds a new field (`node_supplementary_data`) and
# overwrites an existing one (`node_feature`).
new_data = {
    'node_supplementary_data': node_supplementary_data,
    'node_feature': node_feature_updated,
}
graph_tensor = graph_tensor.update(new_data)
print(graph_tensor)
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>,
node_supplementary_data=<tf.Tensor: shape=(37, 1), dtype=float32>)
.remove() – Remove data from GraphTensor
[7]:
# Names of the fields to drop from the GraphTensor.
fields_to_remove = ['node_supplementary_data', 'edge_feature']

graph_tensor = graph_tensor.remove(fields_to_remove)
print(graph_tensor)
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)
__getitem__ – Index lookup on the GraphTensor
The GraphTensor can be indexed either by passing a str (to obtain a specific field of GraphTensor) or int, list[int] or slice (to extract specific subgraphs (molecules) from GraphTensor). (Alternatively, the GraphTensor can be passed to tf.gather to extract specific subgraphs.)
[8]:
# Horizontal rule printed between the sections of the output.
separator = "---" * 20

print("Complete graph:\n")
print(separator)
print(graph_tensor, end='\n\n')
print(separator)

# A list of indices selects the listed subgraphs: here the 2nd and 3rd.
print("Subgraph (2) and (3) of graph:\n")
print(graph_tensor[[1, 2]], end='\n\n')
print(separator)

# A slice works as well; `[:2]` keeps the 1st and 2nd subgraphs, so the
# heading names (1) and (2).
print("Subgraph (1) and (2) of graph:\n")
print(graph_tensor[:2], end='\n\n')
Complete graph:
------------------------------------------------------------
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)
------------------------------------------------------------
Subgraph (2) and (3) of graph:
GraphTensor(
sizes=<tf.Tensor: shape=(2,), dtype=int64>,
node_feature=<tf.Tensor: shape=(18, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(34,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(34,), dtype=int32>)
------------------------------------------------------------
Subgraph (1) and (2) of graph:
GraphTensor(
sizes=<tf.Tensor: shape=(2,), dtype=int64>,
node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)
__getattr__ – Attribute lookup on the GraphTensor
[9]:
# Fields of the GraphTensor are reachable as attributes.
print("Access `node_feature` field:\n")
print("---" * 20)
print(graph_tensor.node_feature, end='\n\n')
print("---" * 20)
print("Access `edge_src` field:\n")
print(graph_tensor.edge_src, end='\n\n')
print("---" * 20)
# `graph_indicator` is not among the fields listed when printing the
# GraphTensor; its printed value holds one subgraph index per node.
print("Access `graph_indicator` field:\n")
print(graph_tensor.graph_indicator, end='\n\n')
Access `node_feature` field:
------------------------------------------------------------
tf.Tensor(
[[0.30606592 0.01332998 0.28550065 ... 0.30522108 0.43709052 0.2496804 ]
[0.47505558 0.6802629 0.12628877 ... 0.54731417 0.85908985 0.01080072]
[0.32505012 0.16541815 0.9268564 ... 0.19977057 0.6975106 0.63107324]
...
[0.06981373 0.0497787 0.7329197 ... 0.72168195 0.992267 0.4002931 ]
[0.6254629 0.77454865 0.4750824 ... 0.21217322 0.10769343 0.71567035]
[0.29524624 0.7836231 0.7198993 ... 0.94255567 0.926514 0.62505746]], shape=(37, 128), dtype=float32)
------------------------------------------------------------
Access `edge_src` field:
tf.Tensor(
[ 0 1 1 2 2 2 3 3 4 4 4 5 5 5 6 6 6 7 8 9 9 9 10 10
11 11 12 12 12 13 13 13 14 14 15 15 16 16 17 17 17 18 19 19 20 20 20 21
22 23 24 24 25 25 26 26 26 27 27 28 28 29 29 29 30 30 31 31 31 32 32 32
33 34 35 36], shape=(76,), dtype=int32)
------------------------------------------------------------
Access `graph_indicator` field:
tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2], shape=(37,), dtype=int64)
tf.concat – Concatenating multiple GraphTensor instances
[10]:
# Case 1: both inputs are in the non-ragged (merged) state.
print("Concatenating two graphs in non-ragged states:\n")
merged_inputs = [graph_tensor, graph_tensor]
graph_tensor_concat = tf.concat(merged_inputs, axis=0)
print(graph_tensor_concat, end='\n\n')

# The printed indicator numbers the subgraphs of the result from 0 to 5.
print("Inspect `graph_indicator`:\n")
print(graph_tensor_concat.graph_indicator, end='\n\n')

print('---' * 20)

# Case 2: both inputs are in the ragged (separated) state.
print("Concatenating two graphs in ragged states")
ragged_inputs = [graph_tensor.separate(), graph_tensor.separate()]
graph_tensor_concat = tf.concat(ragged_inputs, axis=0)
print(graph_tensor_concat, end='\n\n')
Concatenating two graphs in non-ragged states:
GraphTensor(
sizes=<tf.Tensor: shape=(6,), dtype=int64>,
node_feature=<tf.Tensor: shape=(74, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(152,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(152,), dtype=int32>)
Inspect `graph_indicator`:
tf.Tensor(
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2
3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5], shape=(74,), dtype=int64)
------------------------------------------------------------
Concatenating two graphs in ragged states
GraphTensor(
sizes=<tf.Tensor: shape=(6,), dtype=int64>,
node_feature=<tf.RaggedTensor: shape=(6, None, 128), dtype=float32, ragged_rank=1>,
edge_src=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>,
edge_dst=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>)
tf.split – Splits a GraphTensor into multiple GraphTensor instances
[11]:
# Merge the ragged concatenation back to the non-ragged state, then split it
# into six GraphTensors -- one per subgraph it contains.
tf.split(graph_tensor_concat.merge(), num_or_size_splits=6)
[11]:
[GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(5, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(8,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(8,), dtype=int32>),
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(26,), dtype=int32>),
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(5, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(8,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(8,), dtype=int32>),
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(26,), dtype=int32>)]
.spec – The spec of the GraphTensor
[12]:
# Print the spec of the GraphTensor: one TensorSpec per present field; in the
# printed result, absent optional fields show up as None.
print(graph_tensor.spec)
GraphTensor.Spec(sizes=TensorSpec(shape=(None,), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})
.shape – Partial shape of the GraphTensor
[13]:
# Prints (3, None, 128) here.
# NOTE(review): looks like (subgraphs, nodes per subgraph, node-feature dim),
# derived from `node_feature` -- confirm against the molgraph documentation.
print('(partial) shape:', graph_tensor.shape)
(partial) shape: (3, None, 128)
.dtype – Partial dtype of the GraphTensor
[14]:
# Prints float32 here, which matches the dtype of `node_feature`.
# NOTE(review): presumably taken from `node_feature` -- confirm.
print('(partial) dtype:', graph_tensor.dtype.name)
(partial) dtype: float32
.rank – Partial rank of the GraphTensor
[15]:
# Prints 3 here, i.e. the number of dimensions in the (partial) shape.
print('(partial) rank: ', graph_tensor.rank)
(partial) rank: 3
tf.data.Dataset – Creating a TF dataset from a GraphTensor
[16]:
# Slice the GraphTensor into a dataset; per the printed element spec, each
# element is a single subgraph (scalar `sizes`).
ds = tf.data.Dataset.from_tensor_slices(graph_tensor)
print('Dataset object:\n', ds)
print('\n' + '---' * 20)

# Batch two subgraphs at a time; the identity `map` only demonstrates that a
# GraphTensor passes through dataset transformations.
batches = ds.batch(2).map(lambda x: x)

# Loop over the dataset, numbering the batches from 1.
for batch_number, batch in enumerate(batches, start=1):
    print(f"\nbatch {batch_number}:\n")
    print(batch)
    print('\n' + '---' * 20)
Dataset object:
<_TensorSliceDataset element_spec=GraphTensor.Spec(sizes=TensorSpec(shape=(), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})>
------------------------------------------------------------
batch 1:
GraphTensor(
sizes=<tf.Tensor: shape=(2,), dtype=int64>,
node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)
------------------------------------------------------------
batch 2:
GraphTensor(
sizes=<tf.Tensor: shape=(1,), dtype=int64>,
node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(26,), dtype=int32>)
------------------------------------------------------------
layers – Passing a GraphTensor to a layer
The GraphTensor can be passed to GNN layers either as a single disjoint graph (non-ragged state) or as separate subgraphs (ragged state).
[17]:
# One GIN convolution layer, constructed with 128 as its argument; both
# printed results keep a node-feature dimension of 128.
gin_conv = layers.GINConv(128)
# The layer accepts the merged (non-ragged) GraphTensor ...
print("Pass GraphTensor in non-ragged state:\n")
print(gin_conv(graph_tensor), end='\n\n')
print('---' * 20)
# ... as well as the separated (ragged) one; the printed result stays in the
# same state as the input.
print('\nPass GraphTensor in ragged state:\n')
print(gin_conv(graph_tensor.separate()), end='\n\n')
Pass GraphTensor in non-ragged state:
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)
------------------------------------------------------------
Pass GraphTensor in ragged state:
GraphTensor(
sizes=<tf.Tensor: shape=(3,), dtype=int64>,
node_feature=<tf.RaggedTensor: shape=(3, None, 128), dtype=float32, ragged_rank=1>,
edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>)
Model – Passing a GraphTensor to a model
[18]:
# Two graph convolutions, a readout, and a dense head producing one value per
# molecule.
gnn_layers = [
    layers.GCNConv(),
    layers.GCNConv(),
    layers.Readout(),
    keras.layers.Dense(1),
]
model = tf.keras.Sequential(gnn_layers)

# One dummy target per molecule in `graph_tensor`.
y_dummy = tf.constant([[1.], [2.], [3.]])

model.compile(optimizer='adam', loss='huber')

# Option 1: hand the GraphTensor and the labels to `fit` directly.
print("Using (graph_tensor, label) pair as input:\n")
model.fit(graph_tensor, y_dummy, batch_size=2, epochs=5)

print('\n------------------------------\n')

# Option 2: wrap both in a tf.data.Dataset and batch it explicitly.
print("Using tf.data.Dataset as input:\n")
dataset = tf.data.Dataset.from_tensor_slices((graph_tensor, y_dummy))
# The trailing semicolon keeps the notebook from echoing the return value.
model.fit(dataset.batch(2), epochs=5);
Using (graph_tensor, label) pair as input:
Epoch 1/5
2/2 [==============================] - 3s 8ms/step - loss: 0.3226
Epoch 2/5
2/2 [==============================] - 0s 7ms/step - loss: 8.0285
Epoch 3/5
2/2 [==============================] - 0s 8ms/step - loss: 4.7673
Epoch 4/5
2/2 [==============================] - 0s 6ms/step - loss: 1.8421
Epoch 5/5
2/2 [==============================] - 0s 6ms/step - loss: 0.1327
------------------------------
Using tf.data.Dataset as input:
Epoch 1/5
2/2 [==============================] - 0s 7ms/step - loss: 0.8891
Epoch 2/5
2/2 [==============================] - 0s 9ms/step - loss: 1.2614
Epoch 3/5
2/2 [==============================] - 0s 6ms/step - loss: 1.0532
Epoch 4/5
2/2 [==============================] - 0s 6ms/step - loss: 0.6833
Epoch 5/5
2/2 [==============================] - 0s 6ms/step - loss: 0.4317
[ ]: