Graph tensor

Import modules

[1]:
from molgraph import chemistry
from molgraph import layers
from molgraph import GraphTensor #####

import tensorflow as tf
from tensorflow import keras

Construct a GraphTensor

Although a GraphTensor can be instantiated directly via its constructor, here we obtain one by encoding a list of SMILES strings with a MolecularGraphEncoder.

[2]:
# Atom (node) featurizer: the listed features together make up each atom's
# feature vector. The printed GraphTensor below reports 11 node-feature
# columns -- presumably 3 symbols + 1 OOV slot, 3 hybridizations + 1 OOV slot,
# and one column for each of the three remaining features; TODO confirm
# against the molgraph documentation.
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol({'C', 'N', 'O'}, oov_size=1),
    chemistry.features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),
    chemistry.features.HydrogenDonor(),
    chemistry.features.HydrogenAcceptor(),
    chemistry.features.Hetero()
])

# Bond (edge) featurizer: the printed GraphTensor below reports 5 edge-feature
# columns -- presumably 4 bond types plus the rotatable flag; TODO confirm.
bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),
    chemistry.features.Rotatable()
])

# Combine both featurizers into one encoder; positional encoding is
# explicitly turned off by passing None.
mol_encoder = chemistry.MolecularGraphEncoder(
    atom_encoder, bond_encoder, positional_encoding_dim=None)

# Three molecules given as SMILES strings; each becomes one subgraph.
smiles_list = [
    'OCC1OC(C(C1O)O)n1cnc2c1ncnc2N',
    'C(C(=O)O)N',
    'C1=CC(=CC=C1CC(C(=O)O)N)O'
]

# Calling the encoder on the list yields a single GraphTensor holding all
# three molecules (sizes has shape (3,) in the printout).
graph_tensor = mol_encoder(smiles_list)

print(graph_tensor)
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)

.separate() – Separate subgraphs of GraphTensor

[3]:
# Rebind `graph_tensor` to its separated form. In the printout the node and
# edge fields become tf.RaggedTensor values with one row per subgraph, while
# `sizes` stays a regular tensor.
graph_tensor = graph_tensor.separate()
print(graph_tensor)
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(3, None, 11), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_feature=<tf.RaggedTensor: shape=(3, None, 5), dtype=float32, ragged_rank=1>)

.merge() – Merge subgraphs of GraphTensor

[4]:
# Rebind `graph_tensor` to its merged form. The printout shows the fields as
# regular (non-ragged) tensors again, with the same shapes as before the
# earlier `.separate()` call.
graph_tensor = graph_tensor.merge()
print(graph_tensor)
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 11), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>)

.propagate() – Propagate node information within the GraphTensor

[5]:
# Print the node features, propagate, then print them again for comparison.
# NOTE(review): in the recorded output each row after propagation looks like
# the sum of the neighbouring atoms' rows from before (e.g. the first row
# afterwards equals the second row beforehand) -- confirm the exact
# aggregation against the molgraph documentation.
print('Node features before:\n', graph_tensor.node_feature, end='\n\n')
graph_tensor = graph_tensor.propagate()
print('Node features after:\n', graph_tensor.node_feature)
Node features before:
 tf.Tensor(
[[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]
 [0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]
 [0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1.]], shape=(37, 11), dtype=float32)

Node features after:
 tf.Tensor(
[[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 2. 1. 1. 1.]
 [0. 2. 0. 1. 0. 0. 0. 3. 0. 1. 1.]
 [0. 2. 0. 0. 0. 0. 0. 2. 0. 0. 0.]
 [0. 1. 1. 1. 0. 0. 1. 2. 0. 2. 2.]
 [0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]
 [0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]
 [0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 2. 1. 0. 0. 0. 3. 0. 0. 1. 1.]
 [0. 1. 2. 0. 0. 0. 3. 0. 0. 2. 2.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 1. 2. 0. 0. 0. 3. 0. 1. 2. 2.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 2. 0. 1. 0. 0. 3. 0. 1. 1. 1.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
 [0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]
 [0. 2. 0. 0. 0. 0. 1. 1. 0. 0. 0.]
 [0. 2. 1. 0. 0. 0. 1. 2. 1. 1. 1.]
 [0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]], shape=(37, 11), dtype=float32)

.update() – Update data of the GraphTensor

[6]:
# Every dimension of the node features except the last (feature) axis; both
# new tensors below share these leading dimensions.
node_dims = graph_tensor.node_feature.shape[:-1]

# A brand-new single-column field, drawn uniformly at random.
node_supplementary_data = tf.random.uniform(shape=node_dims + [1])

# Replacement node features with 128 columns, drawn uniformly at random.
node_feature_updated = tf.random.uniform(shape=node_dims + [128])

# One update call both adds the new field and overwrites the existing
# `node_feature` field of the GraphTensor.
updated_fields = {
    'node_supplementary_data': node_supplementary_data,
    'node_feature': node_feature_updated,
}
graph_tensor = graph_tensor.update(updated_fields)
print(graph_tensor)
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(76, 5), dtype=float32>,
  node_supplementary_data=<tf.Tensor: shape=(37, 1), dtype=float32>)

.remove() – Remove data from GraphTensor

[7]:
# Drop two fields by name. The printout afterwards lists only sizes,
# node_feature, edge_src and edge_dst.
graph_tensor = graph_tensor.remove([
    'node_supplementary_data', 'edge_feature'
])
print(graph_tensor)
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)

__getitem__ – Index lookup on the GraphTensor

The GraphTensor can be indexed either with a str (to obtain a specific field of the GraphTensor) or with an int, list[int] or slice (to extract specific subgraphs, i.e. molecules, from the GraphTensor). Alternatively, the GraphTensor can be passed to tf.gather to extract specific subgraphs.

[8]:
# Show the complete GraphTensor first, then two ways of selecting subgraphs.
print("Complete graph:\n")
print("---" * 20)
print(graph_tensor, end='\n\n')

# Indexing with a list of ints selects the subgraphs at those positions:
# 0-based indices 1 and 2, i.e. the second and third molecules.
print("---" * 20)
print("Subgraph (2) and (3) of graph:\n")
print(graph_tensor[[1, 2]], end='\n\n')

# Slicing with [:2] selects the first two subgraphs, so the label reads
# (1) and (2) instead of repeating the (2) and (3) label used above.
print("---" * 20)
print("Subgraph (1) and (2) of graph:\n")
print(graph_tensor[:2], end='\n\n')
Complete graph:

------------------------------------------------------------
GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)

------------------------------------------------------------
Subgraph (2) and (3) of graph:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(18, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(34,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(34,), dtype=int32>)

------------------------------------------------------------
Subgraph (1) and (2) of graph:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)

__getattr__ – Attribute lookup on the GraphTensor

[9]:
# Fields of the GraphTensor are read as plain attributes.
print("Access `node_feature` field:\n")
print("---" * 20)
print(graph_tensor.node_feature, end='\n\n')

# Source node index of every edge (shape (76,) in the recorded output).
print("---" * 20)
print("Access `edge_src` field:\n")
print(graph_tensor.edge_src, end='\n\n')

# `graph_indicator` is not among the fields listed when the GraphTensor is
# printed, yet it is accessible here; the recorded output shows one subgraph
# index (0, 1 or 2) per node.
print("---" * 20)
print("Access `graph_indicator` field:\n")
print(graph_tensor.graph_indicator, end='\n\n')
Access `node_feature` field:

------------------------------------------------------------
tf.Tensor(
[[0.30606592 0.01332998 0.28550065 ... 0.30522108 0.43709052 0.2496804 ]
 [0.47505558 0.6802629  0.12628877 ... 0.54731417 0.85908985 0.01080072]
 [0.32505012 0.16541815 0.9268564  ... 0.19977057 0.6975106  0.63107324]
 ...
 [0.06981373 0.0497787  0.7329197  ... 0.72168195 0.992267   0.4002931 ]
 [0.6254629  0.77454865 0.4750824  ... 0.21217322 0.10769343 0.71567035]
 [0.29524624 0.7836231  0.7198993  ... 0.94255567 0.926514   0.62505746]], shape=(37, 128), dtype=float32)

------------------------------------------------------------
Access `edge_src` field:

tf.Tensor(
[ 0  1  1  2  2  2  3  3  4  4  4  5  5  5  6  6  6  7  8  9  9  9 10 10
 11 11 12 12 12 13 13 13 14 14 15 15 16 16 17 17 17 18 19 19 20 20 20 21
 22 23 24 24 25 25 26 26 26 27 27 28 28 29 29 29 30 30 31 31 31 32 32 32
 33 34 35 36], shape=(76,), dtype=int32)

------------------------------------------------------------
Access `graph_indicator` field:

tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2], shape=(37,), dtype=int64)

tf.concat – Concatenating multiple GraphTensor instances

[10]:
# Non-ragged case: concatenate the merged GraphTensor with itself.
print("Concatenating two graphs in non-ragged states:\n")
graph_tensor_concat = tf.concat([graph_tensor] * 2, axis=0)
print(graph_tensor_concat, end='\n\n')

# The indicator of the result continues counting (0..5 in the recorded
# output) instead of restarting for the second copy.
print("Inspect `graph_indicator`:\n")
print(graph_tensor_concat.graph_indicator, end='\n\n')

print('---' * 20)

# Ragged case: the same concatenation, applied to separated GraphTensors.
print("Concatenating two graphs in ragged states")
ragged_pair = [graph_tensor.separate(), graph_tensor.separate()]
graph_tensor_concat = tf.concat(ragged_pair, axis=0)
print(graph_tensor_concat, end='\n\n')
Concatenating two graphs in non-ragged states:

GraphTensor(
  sizes=<tf.Tensor: shape=(6,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(74, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(152,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(152,), dtype=int32>)

Inspect `graph_indicator`:

tf.Tensor(
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5], shape=(74,), dtype=int64)

------------------------------------------------------------
Concatenating two graphs in ragged states
GraphTensor(
  sizes=<tf.Tensor: shape=(6,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(6, None, 128), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(6, None), dtype=int32, ragged_rank=1>)

tf.split – Splits a GraphTensor into multiple GraphTensor instances

[11]:
# `graph_tensor_concat` was last assigned in ragged state, so it is merged
# first; splitting into 6 parts then gives one GraphTensor per subgraph
# (each has sizes of shape (1,) in the recorded output).
tf.split(graph_tensor_concat.merge(), num_or_size_splits=6)
[11]:
[GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(5, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(8,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(8,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(26,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(19, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(42,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(42,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(5, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(8,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(8,), dtype=int32>),
 GraphTensor(
   sizes=<tf.Tensor: shape=(1,), dtype=int64>,
   node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
   edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
   edge_dst=<tf.Tensor: shape=(26,), dtype=int32>)]

.spec – The spec of the GraphTensor

[12]:
# The spec describes each field by shape and dtype; in the recorded output
# the removed `edge_feature` field appears as None.
print(graph_tensor.spec)
GraphTensor.Spec(sizes=TensorSpec(shape=(None,), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})

.shape – Partial shape of the GraphTensor

[13]:
# Recorded output is (3, None, 128): 3 subgraphs and 128 node-feature
# columns, with None in between.
print('(partial) shape:', graph_tensor.shape)
(partial) shape: (3, None, 128)

.dtype – Partial dtype of the GraphTensor

[14]:
# Recorded output is float32, which matches the dtype of `node_feature`.
print('(partial) dtype:', graph_tensor.dtype.name)
(partial) dtype: float32

.rank – Partial rank of the GraphTensor

[15]:
# Recorded output is 3, i.e. the number of entries in the partial shape
# (3, None, 128).
print('(partial) rank: ', graph_tensor.rank)
(partial) rank:  3

tf.data.Dataset – Creating a TF dataset from a GraphTensor

[16]:
# Slice the GraphTensor into a dataset; per the recorded element_spec each
# element has a scalar `sizes`, i.e. it is a single subgraph.
ds = tf.data.Dataset.from_tensor_slices(graph_tensor)
print('Dataset object:\n', ds)

divider = '\n' + '---' * 20
print(divider)

# Batch two subgraphs at a time; the identity `map` shows that GraphTensor
# batches pass through dataset transformations.
batched = ds.batch(2).map(lambda x: x)
for batch_number, x in enumerate(batched, start=1):
    print(f"\nbatch {batch_number}:\n")
    print(x)
    print(divider)
Dataset object:
 <_TensorSliceDataset element_spec=GraphTensor.Spec(sizes=TensorSpec(shape=(), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})>

------------------------------------------------------------

batch 1:

GraphTensor(
  sizes=<tf.Tensor: shape=(2,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(24, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(50,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(50,), dtype=int32>)

------------------------------------------------------------

batch 2:

GraphTensor(
  sizes=<tf.Tensor: shape=(1,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(13, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(26,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(26,), dtype=int32>)

------------------------------------------------------------

layers – Passing a GraphTensor to a layer

The GraphTensor can be passed to GNN layers either as a single disjoint graph (non-ragged state) or as separate subgraphs (ragged state).

[17]:
# One GINConv layer (constructed with 128) is reused for both input layouts;
# the non-ragged call comes first, as in the recorded output.
gin_conv = layers.GINConv(128)
divider = '---' * 20

print("Pass GraphTensor in non-ragged state:\n")
flat_result = gin_conv(graph_tensor)
print(flat_result, end='\n\n')

print(divider)

print('\nPass GraphTensor in ragged state:\n')
ragged_result = gin_conv(graph_tensor.separate())
print(ragged_result, end='\n\n')
Pass GraphTensor in non-ragged state:

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(37, 128), dtype=float32>,
  edge_src=<tf.Tensor: shape=(76,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(76,), dtype=int32>)

------------------------------------------------------------

Pass GraphTensor in ragged state:

GraphTensor(
  sizes=<tf.Tensor: shape=(3,), dtype=int64>,
  node_feature=<tf.RaggedTensor: shape=(3, None, 128), dtype=float32, ragged_rank=1>,
  edge_src=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>,
  edge_dst=<tf.RaggedTensor: shape=(3, None), dtype=int32, ragged_rank=1>)

Model – Passing a GraphTensor to a model

[18]:
# Layer stack: two graph convolutions, a readout, and a single-unit dense
# head. (Readout presumably reduces node features to one vector per
# subgraph -- TODO confirm against the molgraph documentation.)
layer_stack = [
    layers.GCNConv(),
    layers.GCNConv(),
    layers.Readout(),
    keras.layers.Dense(1),
]
model = tf.keras.Sequential(layer_stack)

# One dummy target per subgraph in `graph_tensor`.
y_dummy = tf.constant([[1.], [2.], [3.]])

model.compile(optimizer='adam', loss='huber')

# Variant 1: hand the GraphTensor and the labels to `fit` directly.
print("Using (graph_tensor, label) pair as input:\n")
model.fit(graph_tensor, y_dummy, batch_size=2, epochs=5)

print('\n------------------------------\n')

# Variant 2: wrap (GraphTensor, labels) in a batched tf.data.Dataset.
print("Using tf.data.Dataset as input:\n")
dataset = tf.data.Dataset.from_tensor_slices((graph_tensor, y_dummy))
batched_dataset = dataset.batch(2)
# The trailing semicolon keeps the notebook from displaying the return value.
model.fit(batched_dataset, epochs=5);
Using (graph_tensor, label) pair as input:

Epoch 1/5
2/2 [==============================] - 3s 8ms/step - loss: 0.3226
Epoch 2/5
2/2 [==============================] - 0s 7ms/step - loss: 8.0285
Epoch 3/5
2/2 [==============================] - 0s 8ms/step - loss: 4.7673
Epoch 4/5
2/2 [==============================] - 0s 6ms/step - loss: 1.8421
Epoch 5/5
2/2 [==============================] - 0s 6ms/step - loss: 0.1327

------------------------------

Using tf.data.Dataset as input:

Epoch 1/5
2/2 [==============================] - 0s 7ms/step - loss: 0.8891
Epoch 2/5
2/2 [==============================] - 0s 9ms/step - loss: 1.2614
Epoch 3/5
2/2 [==============================] - 0s 6ms/step - loss: 1.0532
Epoch 4/5
2/2 [==============================] - 0s 6ms/step - loss: 0.6833
Epoch 5/5
2/2 [==============================] - 0s 6ms/step - loss: 0.4317
[ ]: