Benchmark a GAT model on the ESOL dataset

[6]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets
from molgraph.losses import masked_losses
from molgraph.metrics import masked_metrics

1. Build the MolecularGraphEncoder

[2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder

# Atom (node) features. Each feature is listed exactly once: an earlier
# version repeated ChiralCenter() and Ring(), which only added redundant
# copies of the same feature.
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Bond (edge) features.
bond_encoder = Featurizer([
    features.BondType(),
    features.Conjugated(),
    features.Rotatable(),
    features.Ring(),
    features.Stereo(),
])

# Encoder that turns molecules into graph tensors. The positional encoding
# is presumably what the LaplacianPositionalEncoding layer in the model
# consumes — confirm against the molgraph docs.
encoder = MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=16,
    self_loops=False
)

2. Build TF datasets with the MolecularGraphEncoder

[3]:
# Load ESOL and encode every split's molecules with the encoder built above.
esol = datasets.get('esol')


def _encode_split(split_name):
    """Return ``(encoded_graphs, targets)`` for one ESOL split."""
    split = esol[split_name]
    return encoder(split['x']), split['y']


x_train, y_train = _encode_split('train')
x_val, y_val = _encode_split('validation')
x_test, y_test = _encode_split('test')

# Type spec of the encoded training graphs; used for the model's Input layer.
type_spec = x_train.spec
[4]:
# Pipeline settings shared by the train/validation/test datasets.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1024
SHUFFLE_SEED = 42  # fixed so the training order is reproducible between runs


def make_dataset(x, y, shuffle=False):
    """Build a batched, prefetched ``tf.data`` pipeline from graphs and targets.

    Args:
        x: encoded graphs for one split.
        y: matching targets.
        shuffle: shuffle examples (with a fixed seed) before batching.

    Returns:
        A ``tf.data.Dataset`` yielding ``(graphs, targets)`` batches.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        ds = ds.shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    # AUTOTUNE (== -1) lets tf.data pick the prefetch buffer size.
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Only the training split is shuffled; evaluation splits keep their order.
train_ds = make_dataset(x_train, y_train, shuffle=True)
val_ds = make_dataset(x_val, y_val)
test_ds = make_dataset(x_test, y_test)

3. Modeling

[5]:
from molgraph.layers import GATConv
from molgraph.layers import LaplacianPositionalEncoding
from molgraph.layers import Readout
from molgraph.layers import MinMaxScaling

# Fix the global TensorFlow seed before any layers are created so weight
# initialization (and therefore the reported score) is reproducible.
tf.random.set_seed(42)

# Min-max scale node and edge features to [0, 1]. Statistics are fitted on
# the training split only, so nothing leaks from validation/test.
node_preprocessing = MinMaxScaling(
    feature='node_feature', feature_range=(0, 1), threshold=True)
edge_preprocessing = MinMaxScaling(
    feature='edge_feature', feature_range=(0, 1), threshold=True)

# Graph-only view of the training data (targets dropped) for adapt().
train_graphs = train_ds.map(lambda x, *args: x)
node_preprocessing.adapt(train_graphs)
edge_preprocessing.adapt(train_graphs)

model = keras.Sequential([
    keras.layers.Input(type_spec=type_spec),
    node_preprocessing,
    edge_preprocessing,
    LaplacianPositionalEncoding(),
    GATConv(normalization='batch_norm'),
    GATConv(normalization='batch_norm'),
    GATConv(normalization='batch_norm'),
    Readout(),
    keras.layers.Dense(1024, 'relu'),
    keras.layers.Dense(1024, 'relu'),
    # One output unit per target column; assumes y_train is 2-D
    # (n_samples, n_targets) — TODO confirm.
    keras.layers.Dense(y_train.shape[-1])
])


optimizer = keras.optimizers.Adam(1e-4)
loss = keras.losses.MeanAbsoluteError(name='mae')
callbacks = [
    # Multiply the learning rate by 0.1 after 10 epochs without a val_loss
    # improvement, down to a floor of 1e-6.
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=10,
        min_lr=1e-6,
        mode='min',
    ),
    # Stop after 20 epochs without a val_loss improvement and restore the
    # best weights seen so far.
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=20,
        mode='min',
        restore_best_weights=True,
    )
]

model.compile(optimizer, loss)
history = model.fit(
    train_ds,
    callbacks=callbacks,
    validation_data=val_ds,
    epochs=100,
    verbose=2,
)
# Test-set value of the compiled loss (MAE); shown via rich display.
score = model.evaluate(test_ds)
score
Epoch 1/100
29/29 - 6s - loss: 1.6104 - val_loss: 2.8727 - lr: 1.0000e-04 - 6s/epoch - 209ms/step
Epoch 2/100
29/29 - 0s - loss: 1.0733 - val_loss: 2.8306 - lr: 1.0000e-04 - 388ms/epoch - 13ms/step
Epoch 3/100
29/29 - 0s - loss: 0.9292 - val_loss: 2.7933 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 4/100
29/29 - 0s - loss: 0.7980 - val_loss: 2.6992 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step
Epoch 5/100
29/29 - 0s - loss: 0.7400 - val_loss: 2.6170 - lr: 1.0000e-04 - 386ms/epoch - 13ms/step
Epoch 6/100
29/29 - 0s - loss: 0.6782 - val_loss: 2.5222 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step
Epoch 7/100
29/29 - 0s - loss: 0.6756 - val_loss: 2.3882 - lr: 1.0000e-04 - 392ms/epoch - 14ms/step
Epoch 8/100
29/29 - 0s - loss: 0.6457 - val_loss: 2.2180 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 9/100
29/29 - 0s - loss: 0.6356 - val_loss: 2.1035 - lr: 1.0000e-04 - 379ms/epoch - 13ms/step
Epoch 10/100
29/29 - 0s - loss: 0.5921 - val_loss: 2.0029 - lr: 1.0000e-04 - 397ms/epoch - 14ms/step
Epoch 11/100
29/29 - 0s - loss: 0.6023 - val_loss: 1.7675 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step
Epoch 12/100
29/29 - 0s - loss: 0.5469 - val_loss: 1.6520 - lr: 1.0000e-04 - 396ms/epoch - 14ms/step
Epoch 13/100
29/29 - 0s - loss: 0.5428 - val_loss: 1.4848 - lr: 1.0000e-04 - 397ms/epoch - 14ms/step
Epoch 14/100
29/29 - 0s - loss: 0.5678 - val_loss: 1.3723 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step
Epoch 15/100
29/29 - 0s - loss: 0.5733 - val_loss: 1.2860 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step
Epoch 16/100
29/29 - 0s - loss: 0.5791 - val_loss: 1.2071 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step
Epoch 17/100
29/29 - 0s - loss: 0.5403 - val_loss: 1.1961 - lr: 1.0000e-04 - 393ms/epoch - 14ms/step
Epoch 18/100
29/29 - 0s - loss: 0.5152 - val_loss: 1.1848 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step
Epoch 19/100
29/29 - 0s - loss: 0.5947 - val_loss: 1.0619 - lr: 1.0000e-04 - 404ms/epoch - 14ms/step
Epoch 20/100
29/29 - 0s - loss: 0.5636 - val_loss: 0.9673 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step
Epoch 21/100
29/29 - 0s - loss: 0.4985 - val_loss: 1.4409 - lr: 1.0000e-04 - 374ms/epoch - 13ms/step
Epoch 22/100
29/29 - 0s - loss: 0.5274 - val_loss: 1.0172 - lr: 1.0000e-04 - 376ms/epoch - 13ms/step
Epoch 23/100
29/29 - 0s - loss: 0.5385 - val_loss: 0.8310 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step
Epoch 24/100
29/29 - 0s - loss: 0.4996 - val_loss: 0.7225 - lr: 1.0000e-04 - 392ms/epoch - 14ms/step
Epoch 25/100
29/29 - 0s - loss: 0.5324 - val_loss: 0.5973 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step
Epoch 26/100
29/29 - 0s - loss: 0.4900 - val_loss: 0.6436 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 27/100
29/29 - 0s - loss: 0.4709 - val_loss: 0.6613 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 28/100
29/29 - 0s - loss: 0.4714 - val_loss: 0.6028 - lr: 1.0000e-04 - 381ms/epoch - 13ms/step
Epoch 29/100
29/29 - 0s - loss: 0.4654 - val_loss: 0.6892 - lr: 1.0000e-04 - 387ms/epoch - 13ms/step
Epoch 30/100
29/29 - 0s - loss: 0.4792 - val_loss: 0.5300 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step
Epoch 31/100
29/29 - 0s - loss: 0.4470 - val_loss: 0.6892 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step
Epoch 32/100
29/29 - 0s - loss: 0.4791 - val_loss: 0.5954 - lr: 1.0000e-04 - 375ms/epoch - 13ms/step
Epoch 33/100
29/29 - 0s - loss: 0.4156 - val_loss: 0.4791 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step
Epoch 34/100
29/29 - 0s - loss: 0.4756 - val_loss: 0.7215 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step
Epoch 35/100
29/29 - 0s - loss: 0.4974 - val_loss: 0.5423 - lr: 1.0000e-04 - 378ms/epoch - 13ms/step
Epoch 36/100
29/29 - 0s - loss: 0.4495 - val_loss: 0.5344 - lr: 1.0000e-04 - 381ms/epoch - 13ms/step
Epoch 37/100
29/29 - 0s - loss: 0.5167 - val_loss: 0.8142 - lr: 1.0000e-04 - 374ms/epoch - 13ms/step
Epoch 38/100
29/29 - 0s - loss: 0.4763 - val_loss: 0.5238 - lr: 1.0000e-04 - 376ms/epoch - 13ms/step
Epoch 39/100
29/29 - 0s - loss: 0.4450 - val_loss: 0.5957 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step
Epoch 40/100
29/29 - 0s - loss: 0.4635 - val_loss: 0.5614 - lr: 1.0000e-04 - 384ms/epoch - 13ms/step
Epoch 41/100
29/29 - 0s - loss: 0.4348 - val_loss: 0.4831 - lr: 1.0000e-04 - 387ms/epoch - 13ms/step
Epoch 42/100
29/29 - 0s - loss: 0.3967 - val_loss: 0.5873 - lr: 1.0000e-04 - 394ms/epoch - 14ms/step
Epoch 43/100
29/29 - 0s - loss: 0.4132 - val_loss: 0.5240 - lr: 1.0000e-04 - 372ms/epoch - 13ms/step
Epoch 44/100
29/29 - 0s - loss: 0.3935 - val_loss: 0.4802 - lr: 1.0000e-05 - 375ms/epoch - 13ms/step
Epoch 45/100
29/29 - 0s - loss: 0.3535 - val_loss: 0.4594 - lr: 1.0000e-05 - 388ms/epoch - 13ms/step
Epoch 46/100
29/29 - 0s - loss: 0.3636 - val_loss: 0.4413 - lr: 1.0000e-05 - 393ms/epoch - 14ms/step
Epoch 47/100
29/29 - 0s - loss: 0.3381 - val_loss: 0.4438 - lr: 1.0000e-05 - 383ms/epoch - 13ms/step
Epoch 48/100
29/29 - 0s - loss: 0.3362 - val_loss: 0.4697 - lr: 1.0000e-05 - 378ms/epoch - 13ms/step
Epoch 49/100
29/29 - 0s - loss: 0.3738 - val_loss: 0.4590 - lr: 1.0000e-05 - 373ms/epoch - 13ms/step
Epoch 50/100
29/29 - 0s - loss: 0.3348 - val_loss: 0.4344 - lr: 1.0000e-05 - 388ms/epoch - 13ms/step
Epoch 51/100
29/29 - 0s - loss: 0.3483 - val_loss: 0.4395 - lr: 1.0000e-05 - 382ms/epoch - 13ms/step
Epoch 52/100
29/29 - 0s - loss: 0.3833 - val_loss: 0.4444 - lr: 1.0000e-05 - 384ms/epoch - 13ms/step
Epoch 53/100
29/29 - 0s - loss: 0.3380 - val_loss: 0.4358 - lr: 1.0000e-05 - 373ms/epoch - 13ms/step
Epoch 54/100
29/29 - 0s - loss: 0.3517 - val_loss: 0.4578 - lr: 1.0000e-05 - 378ms/epoch - 13ms/step
Epoch 55/100
29/29 - 0s - loss: 0.3465 - val_loss: 0.4576 - lr: 1.0000e-05 - 374ms/epoch - 13ms/step
Epoch 56/100
29/29 - 0s - loss: 0.3377 - val_loss: 0.4544 - lr: 1.0000e-05 - 387ms/epoch - 13ms/step
Epoch 57/100
29/29 - 0s - loss: 0.3469 - val_loss: 0.4457 - lr: 1.0000e-05 - 384ms/epoch - 13ms/step
Epoch 58/100
29/29 - 0s - loss: 0.3298 - val_loss: 0.4854 - lr: 1.0000e-05 - 372ms/epoch - 13ms/step
Epoch 59/100
29/29 - 0s - loss: 0.3758 - val_loss: 0.4491 - lr: 1.0000e-05 - 372ms/epoch - 13ms/step
Epoch 60/100
29/29 - 0s - loss: 0.3519 - val_loss: 0.4462 - lr: 1.0000e-05 - 375ms/epoch - 13ms/step
Epoch 61/100
29/29 - 0s - loss: 0.3234 - val_loss: 0.4383 - lr: 1.0000e-06 - 388ms/epoch - 13ms/step
Epoch 62/100
29/29 - 0s - loss: 0.3429 - val_loss: 0.4395 - lr: 1.0000e-06 - 387ms/epoch - 13ms/step
Epoch 63/100
29/29 - 0s - loss: 0.3258 - val_loss: 0.4392 - lr: 1.0000e-06 - 386ms/epoch - 13ms/step
Epoch 64/100
29/29 - 0s - loss: 0.3527 - val_loss: 0.4561 - lr: 1.0000e-06 - 414ms/epoch - 14ms/step
Epoch 65/100
29/29 - 0s - loss: 0.3089 - val_loss: 0.4505 - lr: 1.0000e-06 - 375ms/epoch - 13ms/step
Epoch 66/100
29/29 - 0s - loss: 0.3431 - val_loss: 0.4419 - lr: 1.0000e-06 - 375ms/epoch - 13ms/step
Epoch 67/100
29/29 - 0s - loss: 0.3460 - val_loss: 0.4373 - lr: 1.0000e-06 - 385ms/epoch - 13ms/step
Epoch 68/100
29/29 - 0s - loss: 0.3568 - val_loss: 0.4480 - lr: 1.0000e-06 - 387ms/epoch - 13ms/step
Epoch 69/100
29/29 - 0s - loss: 0.3029 - val_loss: 0.4462 - lr: 1.0000e-06 - 383ms/epoch - 13ms/step
Epoch 70/100
29/29 - 0s - loss: 0.3390 - val_loss: 0.4427 - lr: 1.0000e-06 - 390ms/epoch - 13ms/step
4/4 [==============================] - 0s 4ms/step - loss: 0.4317
0.4317157566547394