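Benchmark an MPNN model on the Tox21 dataset (with a masked loss). Each Tox21 split comes with a label mask (`y_mask`) that excludes missing labels from both the loss and the ROC-AUC metric.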

[4]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets
from molgraph.losses import MaskedBinaryCrossentropy

1. Build the MolecularGraphEncoder

[2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder

# Atom (node) featurizer. Each feature appears exactly once: an earlier
# version listed `ChiralCenter()` and `Ring()` twice, which only duplicated
# identical columns in the node-feature matrix.
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Bond (edge) featurizer.
bond_encoder = Featurizer([
    features.BondType(),
    features.Conjugated(),
    features.Rotatable(),
    features.Ring(),
    features.Stereo(),
])

# Combines both featurizers into the encoder that turns molecules into
# graph tensors. NOTE(review): positional_encoding_dim=16 presumably sizes
# the positional encoding consumed by LaplacianPositionalEncoding later;
# confirm against molgraph docs.
encoder = MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=16,
    self_loops=False
)

2. Encode the molecules and build the TF datasets

[3]:
tox21 = datasets.get('tox21')


def _encode_split(split_name):
    """Return (encoded graphs, labels, label mask) for one Tox21 split.

    The split's 'x' entries are passed through `encoder`; 'y' and 'y_mask'
    are returned unchanged.
    """
    split = tox21[split_name]
    return encoder(split['x']), split['y'], split['y_mask']


x_train, y_train, y_mask_train = _encode_split('train')
x_val, y_val, y_mask_val = _encode_split('validation')
x_test, y_test, y_mask_test = _encode_split('test')

# Input spec of the encoded training graphs; used for the model's Input layer.
type_spec = x_train.spec
[5]:
def _to_batched_dataset(x, y, y_mask, shuffle=False):
    """Build a batched, prefetched (x, y, y_mask) tf.data pipeline.

    Only the training pipeline is shuffled (buffer of 1024 elements).
    Batch size is 32. `tf.data.AUTOTUNE` (== -1) lets tf.data choose the
    prefetch buffer size.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y, y_mask))
    if shuffle:
        dataset = dataset.shuffle(1024)
    return dataset.batch(32).prefetch(tf.data.AUTOTUNE)


train_ds = _to_batched_dataset(x_train, y_train, y_mask_train, shuffle=True)
val_ds = _to_batched_dataset(x_val, y_val, y_mask_val)
test_ds = _to_batched_dataset(x_test, y_test, y_mask_test)

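3. Modeling

The label mask is the third element of every dataset tuple, so Keras passes it as `sample_weight` to the masked loss and to the weighted ROC-AUC metric.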

[7]:
from molgraph.layers import MPNNConv
from molgraph.layers import LaplacianPositionalEncoding
from molgraph.layers import SetGatherReadout
from molgraph.layers import MinMaxScaling

# One MinMaxScaling layer per graph feature (node first, then edge), both
# configured with feature_range=(0, 1) and threshold=True.
node_preprocessing, edge_preprocessing = (
    MinMaxScaling(feature=feature_name, feature_range=(0, 1), threshold=True)
    for feature_name in ('node_feature', 'edge_feature')
)

# Fit the scaling statistics on the training graphs only; labels and masks
# are dropped from each (x, y, y_mask) element before adapting.
train_graphs_only = train_ds.map(lambda graph, *_: graph)
for scaling_layer in (node_preprocessing, edge_preprocessing):
    scaling_layer.adapt(train_graphs_only)

# Architecture: graph input -> node/edge scaling -> Laplacian positional
# encoding -> three batch-normalized MPNN convolutions -> SetGatherReadout ->
# two 1024-unit ReLU layers -> one sigmoid output per label column of y_train.
model = tf.keras.Sequential(
    [
        keras.layers.Input(type_spec=type_spec),
        node_preprocessing,
        edge_preprocessing,
        LaplacianPositionalEncoding(),
    ]
    + [MPNNConv(normalization='batch_norm') for _ in range(3)]
    + [SetGatherReadout()]
    + [keras.layers.Dense(1024, activation='relu') for _ in range(2)]
    + [keras.layers.Dense(y_train.shape[-1], activation='sigmoid')]
)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
loss = MaskedBinaryCrossentropy(name='bce')

# AUC deals with masks: it is compiled as a weighted metric, so the label
# mask (third dataset element) is applied to it as a sample weight.
metrics = [keras.metrics.AUC(name='roc_auc', multi_label=True)]

# Both callbacks track validation ROC-AUC, where higher is better.
_val_auc_monitor = dict(monitor='val_roc_auc', mode='max')

callbacks = [
    # Divide the learning rate by 10 after 5 stagnant epochs, floored at 1e-6.
    keras.callbacks.ReduceLROnPlateau(
        factor=0.1, patience=5, min_lr=1e-6, **_val_auc_monitor
    ),
    # Stop after 10 stagnant epochs and restore the best weights seen.
    keras.callbacks.EarlyStopping(
        patience=10, restore_best_weights=True, **_val_auc_monitor
    ),
]

# Each dataset element is (x, y, y_mask); Keras treats the third element as
# sample_weight, which feeds the masked loss and the weighted AUC metric.
model.compile(optimizer=optimizer, loss=loss, weighted_metrics=metrics)
history = model.fit(
    train_ds,
    callbacks=callbacks,
    validation_data=val_ds,
    epochs=100,
    verbose=2,
)

# Held-out test evaluation. `score` stays the raw [loss, metric...] list;
# the last expression displays it labeled by metric name instead of
# printing a bare list.
score = model.evaluate(test_ds)
dict(zip(model.metrics_names, score))
Epoch 1/100
196/196 - 20s - loss: 0.2953 - roc_auc: 0.5798 - val_loss: 0.2795 - val_roc_auc: 0.6407 - lr: 1.0000e-04 - 20s/epoch - 103ms/step
Epoch 2/100
196/196 - 13s - loss: 0.2739 - roc_auc: 0.6427 - val_loss: 0.2414 - val_roc_auc: 0.7301 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 3/100
196/196 - 13s - loss: 0.2616 - roc_auc: 0.6961 - val_loss: 0.2321 - val_roc_auc: 0.7534 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 4/100
196/196 - 13s - loss: 0.2544 - roc_auc: 0.7180 - val_loss: 0.2270 - val_roc_auc: 0.7654 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 5/100
196/196 - 13s - loss: 0.2513 - roc_auc: 0.7234 - val_loss: 0.2551 - val_roc_auc: 0.6889 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 6/100
196/196 - 13s - loss: 0.2476 - roc_auc: 0.7372 - val_loss: 0.2193 - val_roc_auc: 0.7653 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 7/100
196/196 - 13s - loss: 0.2396 - roc_auc: 0.7570 - val_loss: 0.2144 - val_roc_auc: 0.8030 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 8/100
196/196 - 14s - loss: 0.2336 - roc_auc: 0.7711 - val_loss: 0.2159 - val_roc_auc: 0.8067 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 9/100
196/196 - 13s - loss: 0.2269 - roc_auc: 0.7851 - val_loss: 0.2066 - val_roc_auc: 0.8119 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 10/100
196/196 - 13s - loss: 0.2259 - roc_auc: 0.7859 - val_loss: 0.2018 - val_roc_auc: 0.8181 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 11/100
196/196 - 13s - loss: 0.2173 - roc_auc: 0.8001 - val_loss: 0.2020 - val_roc_auc: 0.8195 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 12/100
196/196 - 13s - loss: 0.2129 - roc_auc: 0.8072 - val_loss: 0.1937 - val_roc_auc: 0.8251 - lr: 1.0000e-04 - 13s/epoch - 69ms/step
Epoch 13/100
196/196 - 13s - loss: 0.2077 - roc_auc: 0.8152 - val_loss: 0.1953 - val_roc_auc: 0.8307 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 14/100
196/196 - 13s - loss: 0.2052 - roc_auc: 0.8196 - val_loss: 0.2205 - val_roc_auc: 0.7939 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 15/100
196/196 - 13s - loss: 0.2074 - roc_auc: 0.8175 - val_loss: 0.1950 - val_roc_auc: 0.8337 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 16/100
196/196 - 13s - loss: 0.1984 - roc_auc: 0.8299 - val_loss: 0.1907 - val_roc_auc: 0.8428 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 17/100
196/196 - 14s - loss: 0.1956 - roc_auc: 0.8349 - val_loss: 0.1897 - val_roc_auc: 0.8395 - lr: 1.0000e-04 - 14s/epoch - 69ms/step
Epoch 18/100
196/196 - 13s - loss: 0.1925 - roc_auc: 0.8373 - val_loss: 0.1957 - val_roc_auc: 0.8390 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 19/100
196/196 - 13s - loss: 0.1893 - roc_auc: 0.8456 - val_loss: 0.1918 - val_roc_auc: 0.8321 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 20/100
196/196 - 14s - loss: 0.1846 - roc_auc: 0.8509 - val_loss: 0.1845 - val_roc_auc: 0.8433 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 21/100
196/196 - 14s - loss: 0.1798 - roc_auc: 0.8556 - val_loss: 0.1935 - val_roc_auc: 0.8357 - lr: 1.0000e-04 - 14s/epoch - 73ms/step
Epoch 22/100
196/196 - 14s - loss: 0.1796 - roc_auc: 0.8556 - val_loss: 0.1967 - val_roc_auc: 0.8243 - lr: 1.0000e-04 - 14s/epoch - 71ms/step
Epoch 23/100
196/196 - 13s - loss: 0.1852 - roc_auc: 0.8493 - val_loss: 0.1884 - val_roc_auc: 0.8417 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 24/100
196/196 - 13s - loss: 0.1773 - roc_auc: 0.8613 - val_loss: 0.1896 - val_roc_auc: 0.8332 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 25/100
196/196 - 14s - loss: 0.1738 - roc_auc: 0.8630 - val_loss: 0.1889 - val_roc_auc: 0.8374 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 26/100
196/196 - 13s - loss: 0.1610 - roc_auc: 0.8809 - val_loss: 0.1820 - val_roc_auc: 0.8523 - lr: 1.0000e-05 - 13s/epoch - 69ms/step
Epoch 27/100
196/196 - 13s - loss: 0.1568 - roc_auc: 0.8856 - val_loss: 0.1822 - val_roc_auc: 0.8463 - lr: 1.0000e-05 - 13s/epoch - 69ms/step
Epoch 28/100
196/196 - 13s - loss: 0.1556 - roc_auc: 0.8864 - val_loss: 0.1839 - val_roc_auc: 0.8342 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 29/100
196/196 - 13s - loss: 0.1546 - roc_auc: 0.8869 - val_loss: 0.1832 - val_roc_auc: 0.8461 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 30/100
196/196 - 13s - loss: 0.1539 - roc_auc: 0.8891 - val_loss: 0.1848 - val_roc_auc: 0.8429 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 31/100
196/196 - 13s - loss: 0.1522 - roc_auc: 0.8900 - val_loss: 0.1871 - val_roc_auc: 0.8400 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 32/100
196/196 - 13s - loss: 0.1514 - roc_auc: 0.8916 - val_loss: 0.1866 - val_roc_auc: 0.8410 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 33/100
196/196 - 13s - loss: 0.1509 - roc_auc: 0.8908 - val_loss: 0.1831 - val_roc_auc: 0.8476 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 34/100
196/196 - 13s - loss: 0.1510 - roc_auc: 0.8920 - val_loss: 0.1876 - val_roc_auc: 0.8430 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 35/100
196/196 - 13s - loss: 0.1507 - roc_auc: 0.8925 - val_loss: 0.1864 - val_roc_auc: 0.8394 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 36/100
196/196 - 13s - loss: 0.1500 - roc_auc: 0.8932 - val_loss: 0.1840 - val_roc_auc: 0.8397 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
25/25 [==============================] - 1s 21ms/step - loss: 0.2212 - roc_auc: 0.8150
[0.22120867669582367, 0.8149662017822266]