Benchmark an MPNN model on the Tox21 dataset (with masked loss)
[4]:
import sys
sys.path.append('../../../../')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets
from molgraph.losses import MaskedBinaryCrossentropy
1. Build MolecularGraphEncoder
[2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder
# Atom (node) featurizer. Each entry contributes one group of per-atom
# features; every feature is listed exactly once so the encoded node features
# carry no redundant duplicate columns (ChiralCenter and Ring in particular
# must not be repeated).
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Bond (edge) featurizer.
bond_encoder = Featurizer([
    features.BondType(),
    features.Conjugated(),
    features.Rotatable(),
    features.Ring(),
    features.Stereo(),
])

# Encoder that turns molecules into graph inputs using the two featurizers.
# NOTE(review): positional_encoding_dim presumably sets the size of the
# positional encoding consumed by LaplacianPositionalEncoding in the model
# cell -- confirm against the MolGraph documentation.
encoder = MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=16,
    self_loops=False,
)
2. Build TF datasets from the MolecularGraphEncoder outputs
[3]:
# Fetch the Tox21 benchmark; it is indexed below by split name
# ('train', 'validation', 'test') and then by 'x', 'y' and 'y_mask'.
tox21 = datasets.get('tox21')


def _encode_split(split_name):
    """Return (encoded inputs, labels, label mask) for one Tox21 split."""
    split = tox21[split_name]
    return encoder(split['x']), split['y'], split['y_mask']


x_train, y_train, y_mask_train = _encode_split('train')
x_val, y_val, y_mask_val = _encode_split('validation')
x_test, y_test, y_mask_test = _encode_split('test')

# Spec of the encoded training inputs; used as the model's input signature.
type_spec = x_train.spec
[5]:
def _make_batched_dataset(x, y, y_mask, shuffle=False):
    """Slice (x, y, y_mask) per example, then batch by 32 and prefetch.

    When ``shuffle`` is true, examples pass through a 1024-element shuffle
    buffer before batching.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y, y_mask))
    if shuffle:
        dataset = dataset.shuffle(1024)
    # tf.data.AUTOTUNE is the named constant for the value -1.
    return dataset.batch(32).prefetch(tf.data.AUTOTUNE)


# Only the training pipeline is shuffled; validation and test keep their order.
train_ds = _make_batched_dataset(x_train, y_train, y_mask_train, shuffle=True)
val_ds = _make_batched_dataset(x_val, y_val, y_mask_val)
test_ds = _make_batched_dataset(x_test, y_test, y_mask_test)
3. Modeling
[7]:
from molgraph.layers import MPNNConv
from molgraph.layers import LaplacianPositionalEncoding
from molgraph.layers import SetGatherReadout
from molgraph.layers import MinMaxScaling
# --- Feature scaling -------------------------------------------------------
# Min-max scalers for node and edge features, fitted on the training inputs.
node_preprocessing = MinMaxScaling(
    feature='node_feature',
    feature_range=(0, 1),
    threshold=True,
)
edge_preprocessing = MinMaxScaling(
    feature='edge_feature',
    feature_range=(0, 1),
    threshold=True,
)

# Drop labels and masks so the scalers only see the graph inputs.
train_graphs = train_ds.map(lambda graph, *_: graph)
node_preprocessing.adapt(train_graphs)
edge_preprocessing.adapt(train_graphs)

# --- Model -----------------------------------------------------------------
# Scaling -> positional encoding -> three MPNN layers -> graph-level readout
# -> dense head with one sigmoid output per Tox21 task.
model = keras.Sequential([
    keras.layers.Input(type_spec=type_spec),
    node_preprocessing,
    edge_preprocessing,
    LaplacianPositionalEncoding(),
    *[MPNNConv(normalization='batch_norm') for _ in range(3)],
    SetGatherReadout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1], activation='sigmoid'),
])

# --- Training configuration ------------------------------------------------
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
loss = MaskedBinaryCrossentropy(name='bce')

# The metric is registered through `weighted_metrics` in compile() below.
# NOTE(review): presumably the third dataset element (y_mask) reaches the
# loss and this metric as sample weights, so masked-out labels do not
# contribute -- confirm against the Keras and MolGraph documentation.
metrics = [
    keras.metrics.AUC(name='roc_auc', multi_label=True),
]

# Both callbacks track validation ROC-AUC (higher is better): the learning
# rate is cut 10x after 5 stagnant epochs (floor 1e-6), and training stops
# after 10 stagnant epochs, restoring the best weights.
callbacks = [
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_roc_auc',
        mode='max',
        factor=0.1,
        patience=5,
        min_lr=1e-6,
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_roc_auc',
        mode='max',
        patience=10,
        restore_best_weights=True,
    ),
]

model.compile(optimizer=optimizer, loss=loss, weighted_metrics=metrics)

# --- Fit and evaluate ------------------------------------------------------
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks,
    verbose=2,
)

# evaluate() returns [loss, roc_auc] on the held-out test split.
score = model.evaluate(test_ds)
print(score)
Epoch 1/100
196/196 - 20s - loss: 0.2953 - roc_auc: 0.5798 - val_loss: 0.2795 - val_roc_auc: 0.6407 - lr: 1.0000e-04 - 20s/epoch - 103ms/step
Epoch 2/100
196/196 - 13s - loss: 0.2739 - roc_auc: 0.6427 - val_loss: 0.2414 - val_roc_auc: 0.7301 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 3/100
196/196 - 13s - loss: 0.2616 - roc_auc: 0.6961 - val_loss: 0.2321 - val_roc_auc: 0.7534 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 4/100
196/196 - 13s - loss: 0.2544 - roc_auc: 0.7180 - val_loss: 0.2270 - val_roc_auc: 0.7654 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 5/100
196/196 - 13s - loss: 0.2513 - roc_auc: 0.7234 - val_loss: 0.2551 - val_roc_auc: 0.6889 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 6/100
196/196 - 13s - loss: 0.2476 - roc_auc: 0.7372 - val_loss: 0.2193 - val_roc_auc: 0.7653 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 7/100
196/196 - 13s - loss: 0.2396 - roc_auc: 0.7570 - val_loss: 0.2144 - val_roc_auc: 0.8030 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 8/100
196/196 - 14s - loss: 0.2336 - roc_auc: 0.7711 - val_loss: 0.2159 - val_roc_auc: 0.8067 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 9/100
196/196 - 13s - loss: 0.2269 - roc_auc: 0.7851 - val_loss: 0.2066 - val_roc_auc: 0.8119 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 10/100
196/196 - 13s - loss: 0.2259 - roc_auc: 0.7859 - val_loss: 0.2018 - val_roc_auc: 0.8181 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 11/100
196/196 - 13s - loss: 0.2173 - roc_auc: 0.8001 - val_loss: 0.2020 - val_roc_auc: 0.8195 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 12/100
196/196 - 13s - loss: 0.2129 - roc_auc: 0.8072 - val_loss: 0.1937 - val_roc_auc: 0.8251 - lr: 1.0000e-04 - 13s/epoch - 69ms/step
Epoch 13/100
196/196 - 13s - loss: 0.2077 - roc_auc: 0.8152 - val_loss: 0.1953 - val_roc_auc: 0.8307 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 14/100
196/196 - 13s - loss: 0.2052 - roc_auc: 0.8196 - val_loss: 0.2205 - val_roc_auc: 0.7939 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 15/100
196/196 - 13s - loss: 0.2074 - roc_auc: 0.8175 - val_loss: 0.1950 - val_roc_auc: 0.8337 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 16/100
196/196 - 13s - loss: 0.1984 - roc_auc: 0.8299 - val_loss: 0.1907 - val_roc_auc: 0.8428 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 17/100
196/196 - 14s - loss: 0.1956 - roc_auc: 0.8349 - val_loss: 0.1897 - val_roc_auc: 0.8395 - lr: 1.0000e-04 - 14s/epoch - 69ms/step
Epoch 18/100
196/196 - 13s - loss: 0.1925 - roc_auc: 0.8373 - val_loss: 0.1957 - val_roc_auc: 0.8390 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 19/100
196/196 - 13s - loss: 0.1893 - roc_auc: 0.8456 - val_loss: 0.1918 - val_roc_auc: 0.8321 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 20/100
196/196 - 14s - loss: 0.1846 - roc_auc: 0.8509 - val_loss: 0.1845 - val_roc_auc: 0.8433 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 21/100
196/196 - 14s - loss: 0.1798 - roc_auc: 0.8556 - val_loss: 0.1935 - val_roc_auc: 0.8357 - lr: 1.0000e-04 - 14s/epoch - 73ms/step
Epoch 22/100
196/196 - 14s - loss: 0.1796 - roc_auc: 0.8556 - val_loss: 0.1967 - val_roc_auc: 0.8243 - lr: 1.0000e-04 - 14s/epoch - 71ms/step
Epoch 23/100
196/196 - 13s - loss: 0.1852 - roc_auc: 0.8493 - val_loss: 0.1884 - val_roc_auc: 0.8417 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 24/100
196/196 - 13s - loss: 0.1773 - roc_auc: 0.8613 - val_loss: 0.1896 - val_roc_auc: 0.8332 - lr: 1.0000e-04 - 13s/epoch - 68ms/step
Epoch 25/100
196/196 - 14s - loss: 0.1738 - roc_auc: 0.8630 - val_loss: 0.1889 - val_roc_auc: 0.8374 - lr: 1.0000e-04 - 14s/epoch - 70ms/step
Epoch 26/100
196/196 - 13s - loss: 0.1610 - roc_auc: 0.8809 - val_loss: 0.1820 - val_roc_auc: 0.8523 - lr: 1.0000e-05 - 13s/epoch - 69ms/step
Epoch 27/100
196/196 - 13s - loss: 0.1568 - roc_auc: 0.8856 - val_loss: 0.1822 - val_roc_auc: 0.8463 - lr: 1.0000e-05 - 13s/epoch - 69ms/step
Epoch 28/100
196/196 - 13s - loss: 0.1556 - roc_auc: 0.8864 - val_loss: 0.1839 - val_roc_auc: 0.8342 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 29/100
196/196 - 13s - loss: 0.1546 - roc_auc: 0.8869 - val_loss: 0.1832 - val_roc_auc: 0.8461 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 30/100
196/196 - 13s - loss: 0.1539 - roc_auc: 0.8891 - val_loss: 0.1848 - val_roc_auc: 0.8429 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 31/100
196/196 - 13s - loss: 0.1522 - roc_auc: 0.8900 - val_loss: 0.1871 - val_roc_auc: 0.8400 - lr: 1.0000e-05 - 13s/epoch - 67ms/step
Epoch 32/100
196/196 - 13s - loss: 0.1514 - roc_auc: 0.8916 - val_loss: 0.1866 - val_roc_auc: 0.8410 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 33/100
196/196 - 13s - loss: 0.1509 - roc_auc: 0.8908 - val_loss: 0.1831 - val_roc_auc: 0.8476 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 34/100
196/196 - 13s - loss: 0.1510 - roc_auc: 0.8920 - val_loss: 0.1876 - val_roc_auc: 0.8430 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 35/100
196/196 - 13s - loss: 0.1507 - roc_auc: 0.8925 - val_loss: 0.1864 - val_roc_auc: 0.8394 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
Epoch 36/100
196/196 - 13s - loss: 0.1500 - roc_auc: 0.8932 - val_loss: 0.1840 - val_roc_auc: 0.8397 - lr: 1.0000e-06 - 13s/epoch - 67ms/step
25/25 [==============================] - 1s 21ms/step - loss: 0.2212 - roc_auc: 0.8150
[0.22120867669582367, 0.8149662017822266]