Benchmark DTNN model on the QM7 dataset

[7]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets

1. Build MolecularGraphEncoder3D

[2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder3D

# Per-atom descriptors used to build the node features of each molecular graph.
# Every feature is listed exactly once, in the order it was first declared:
# ChiralCenter and Ring previously appeared twice, which presumably only
# duplicated columns in the encoded node features (TODO confirm against the
# molgraph Featurizer implementation).
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Options for the 3D graph encoder, gathered in one place so they are easy
# to scan and tweak.
encoder_options = {
    'conformer_generator': None,  # qm7 encodes conformers
    'edge_radius': None,          # max radius
    'coulomb': True,
}

# Encoder that turns molecules into graphs using the atom features above.
encoder = MolecularGraphEncoder3D(atom_encoder, **encoder_options)

2. Build TF dataset from MolecularGraphEncoder3D

[3]:
qm7 = datasets.get('qm7')

# Look up each split once, then encode its inputs ('x') and keep its
# targets ('y') as provided by the dataset.
train_split = qm7['train']
val_split = qm7['validation']
test_split = qm7['test']

x_train, y_train = encoder(train_split['x']), train_split['y']
x_val, y_val = encoder(val_split['x']), val_split['y']
x_test, y_test = encoder(test_split['x']), test_split['y']

# Spec of the encoded training inputs; used later as the model's input spec.
type_spec = x_train.spec
[4]:
def _make_dataset(x, y, shuffle=False):
    """Slice (x, y) into examples, optionally shuffle, then batch and prefetch.

    Args:
        x: encoded inputs, sliced along their first axis.
        y: targets aligned with `x`.
        shuffle: when True, shuffle examples with a buffer of 1024 before
            batching.

    Returns:
        A `tf.data.Dataset` yielding batches of 32 `(x, y)` pairs.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        pipeline = pipeline.shuffle(1024)
    # tf.data.AUTOTUNE is the named constant for -1 (auto-tuned prefetch buffer).
    return pipeline.batch(32).prefetch(tf.data.AUTOTUNE)


# Only the training split is shuffled; validation and test keep their order.
train_ds = _make_dataset(x_train, y_train, shuffle=True)
val_ds = _make_dataset(x_val, y_val)
test_ds = _make_dataset(x_test, y_test)

3. Modeling

[6]:
from molgraph.layers import DTNNConv
from molgraph.layers import Readout
from molgraph.layers import MinMaxScaling

# Training inputs only (targets dropped); this is what the scalers adapt on.
train_inputs_ds = train_ds.map(lambda x, *args: x)

# One min-max scaler per graph component, both mapping into [0, 1].
node_preprocessing, edge_preprocessing = (
    MinMaxScaling(feature=feature_name, feature_range=(0, 1), threshold=True)
    for feature_name in ('node_feature', 'edge_feature')
)

# Fit the scaling statistics on the training data: node scaler first, then edge.
for scaler in (node_preprocessing, edge_preprocessing):
    scaler.adapt(train_inputs_ds)

# The layer stack is assembled in pieces, in the order the layers are applied:
# input -> feature scaling -> three DTNN convolutions -> readout -> dense head.
input_layer = keras.layers.Input(type_spec=type_spec)

conv_layers = [DTNNConv(normalization='batch_norm') for _ in range(3)]

readout_layer = Readout()

dense_head = [
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=1024, activation='relu'),
    # One output unit per target column of y_train.
    keras.layers.Dense(units=y_train.shape[-1]),
]

model = keras.Sequential(
    [input_layer, node_preprocessing, edge_preprocessing]
    + conv_layers
    + [readout_layer]
    + dense_head
)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
loss = keras.losses.MeanAbsoluteError(name='mae')

# Cut the learning rate tenfold (down to at most 1e-6) when the validation
# loss has not improved for 10 epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=10, min_lr=1e-6, mode='min')

# Stop after 20 epochs without validation improvement and roll the model
# back to its best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=20, mode='min', restore_best_weights=True)

callbacks = [reduce_lr, early_stopping]

model.compile(optimizer=optimizer, loss=loss)

# Train for at most 100 epochs, validating on val_ds after each epoch;
# verbose=2 prints one summary line per epoch.
history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks,
    verbose=2,
)

# Loss of the trained model on the held-out test split.
score = model.evaluate(test_ds)
print(score)
Epoch 1/100
180/180 - 3s - loss: 1257.5872 - val_loss: 903.2003 - lr: 1.0000e-04 - 3s/epoch - 18ms/step
Epoch 2/100
180/180 - 1s - loss: 545.0237 - val_loss: 442.4483 - lr: 1.0000e-04 - 937ms/epoch - 5ms/step
Epoch 3/100
180/180 - 1s - loss: 163.0737 - val_loss: 219.5774 - lr: 1.0000e-04 - 961ms/epoch - 5ms/step
Epoch 4/100
180/180 - 1s - loss: 127.5755 - val_loss: 100.0004 - lr: 1.0000e-04 - 952ms/epoch - 5ms/step
Epoch 5/100
180/180 - 1s - loss: 120.0276 - val_loss: 72.0875 - lr: 1.0000e-04 - 924ms/epoch - 5ms/step
Epoch 6/100
180/180 - 1s - loss: 108.8816 - val_loss: 179.9970 - lr: 1.0000e-04 - 940ms/epoch - 5ms/step
Epoch 7/100
180/180 - 1s - loss: 101.8370 - val_loss: 80.3074 - lr: 1.0000e-04 - 950ms/epoch - 5ms/step
Epoch 8/100
180/180 - 1s - loss: 97.8691 - val_loss: 66.1778 - lr: 1.0000e-04 - 920ms/epoch - 5ms/step
Epoch 9/100
180/180 - 1s - loss: 92.5363 - val_loss: 60.9984 - lr: 1.0000e-04 - 939ms/epoch - 5ms/step
Epoch 10/100
180/180 - 1s - loss: 87.6198 - val_loss: 69.3062 - lr: 1.0000e-04 - 933ms/epoch - 5ms/step
Epoch 11/100
180/180 - 1s - loss: 82.5370 - val_loss: 107.2218 - lr: 1.0000e-04 - 891ms/epoch - 5ms/step
Epoch 12/100
180/180 - 1s - loss: 79.2450 - val_loss: 57.9821 - lr: 1.0000e-04 - 889ms/epoch - 5ms/step
Epoch 13/100
180/180 - 1s - loss: 78.0266 - val_loss: 73.6456 - lr: 1.0000e-04 - 887ms/epoch - 5ms/step
Epoch 14/100
180/180 - 1s - loss: 72.0752 - val_loss: 43.4730 - lr: 1.0000e-04 - 904ms/epoch - 5ms/step
Epoch 15/100
180/180 - 1s - loss: 70.5660 - val_loss: 89.5029 - lr: 1.0000e-04 - 881ms/epoch - 5ms/step
Epoch 16/100
180/180 - 1s - loss: 65.7411 - val_loss: 46.4776 - lr: 1.0000e-04 - 880ms/epoch - 5ms/step
Epoch 17/100
180/180 - 1s - loss: 62.6743 - val_loss: 92.0640 - lr: 1.0000e-04 - 930ms/epoch - 5ms/step
Epoch 18/100
180/180 - 1s - loss: 60.8892 - val_loss: 44.1200 - lr: 1.0000e-04 - 888ms/epoch - 5ms/step
Epoch 19/100
180/180 - 1s - loss: 60.1340 - val_loss: 68.0694 - lr: 1.0000e-04 - 902ms/epoch - 5ms/step
Epoch 20/100
180/180 - 1s - loss: 55.1986 - val_loss: 44.6196 - lr: 1.0000e-04 - 899ms/epoch - 5ms/step
Epoch 21/100
180/180 - 1s - loss: 57.4396 - val_loss: 85.4759 - lr: 1.0000e-04 - 934ms/epoch - 5ms/step
Epoch 22/100
180/180 - 1s - loss: 55.1108 - val_loss: 62.0642 - lr: 1.0000e-04 - 875ms/epoch - 5ms/step
Epoch 23/100
180/180 - 1s - loss: 55.1789 - val_loss: 73.6033 - lr: 1.0000e-04 - 915ms/epoch - 5ms/step
Epoch 24/100
180/180 - 1s - loss: 51.7645 - val_loss: 46.4600 - lr: 1.0000e-04 - 909ms/epoch - 5ms/step
Epoch 25/100
180/180 - 1s - loss: 52.2889 - val_loss: 38.5979 - lr: 1.0000e-05 - 887ms/epoch - 5ms/step
Epoch 26/100
180/180 - 1s - loss: 52.9256 - val_loss: 36.1527 - lr: 1.0000e-05 - 903ms/epoch - 5ms/step
Epoch 27/100
180/180 - 1s - loss: 52.4452 - val_loss: 39.4368 - lr: 1.0000e-05 - 919ms/epoch - 5ms/step
Epoch 28/100
180/180 - 1s - loss: 52.2758 - val_loss: 37.5381 - lr: 1.0000e-05 - 913ms/epoch - 5ms/step
Epoch 29/100
180/180 - 1s - loss: 52.6289 - val_loss: 38.1596 - lr: 1.0000e-05 - 875ms/epoch - 5ms/step
Epoch 30/100
180/180 - 1s - loss: 53.6993 - val_loss: 42.3064 - lr: 1.0000e-05 - 876ms/epoch - 5ms/step
Epoch 31/100
180/180 - 1s - loss: 50.7391 - val_loss: 38.4957 - lr: 1.0000e-05 - 886ms/epoch - 5ms/step
Epoch 32/100
180/180 - 1s - loss: 50.3886 - val_loss: 45.2156 - lr: 1.0000e-05 - 880ms/epoch - 5ms/step
Epoch 33/100
180/180 - 1s - loss: 52.0881 - val_loss: 36.2220 - lr: 1.0000e-05 - 885ms/epoch - 5ms/step
Epoch 34/100
180/180 - 1s - loss: 53.4462 - val_loss: 37.4831 - lr: 1.0000e-05 - 876ms/epoch - 5ms/step
Epoch 35/100
180/180 - 1s - loss: 53.2242 - val_loss: 36.2767 - lr: 1.0000e-05 - 880ms/epoch - 5ms/step
Epoch 36/100
180/180 - 1s - loss: 50.5077 - val_loss: 39.5337 - lr: 1.0000e-05 - 881ms/epoch - 5ms/step
Epoch 37/100
180/180 - 1s - loss: 50.2126 - val_loss: 28.2422 - lr: 1.0000e-06 - 887ms/epoch - 5ms/step
Epoch 38/100
180/180 - 1s - loss: 50.5770 - val_loss: 28.2901 - lr: 1.0000e-06 - 909ms/epoch - 5ms/step
Epoch 39/100
180/180 - 1s - loss: 51.7723 - val_loss: 28.5813 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step
Epoch 40/100
180/180 - 1s - loss: 51.6077 - val_loss: 28.0917 - lr: 1.0000e-06 - 894ms/epoch - 5ms/step
Epoch 41/100
180/180 - 1s - loss: 50.7641 - val_loss: 28.4171 - lr: 1.0000e-06 - 900ms/epoch - 5ms/step
Epoch 42/100
180/180 - 1s - loss: 51.5076 - val_loss: 28.2041 - lr: 1.0000e-06 - 885ms/epoch - 5ms/step
Epoch 43/100
180/180 - 1s - loss: 50.7273 - val_loss: 28.4535 - lr: 1.0000e-06 - 893ms/epoch - 5ms/step
Epoch 44/100
180/180 - 1s - loss: 51.1730 - val_loss: 28.3620 - lr: 1.0000e-06 - 888ms/epoch - 5ms/step
Epoch 45/100
180/180 - 1s - loss: 54.1880 - val_loss: 28.1676 - lr: 1.0000e-06 - 882ms/epoch - 5ms/step
Epoch 46/100
180/180 - 1s - loss: 52.0131 - val_loss: 28.1627 - lr: 1.0000e-06 - 894ms/epoch - 5ms/step
Epoch 47/100
180/180 - 1s - loss: 52.4500 - val_loss: 28.0132 - lr: 1.0000e-06 - 891ms/epoch - 5ms/step
Epoch 48/100
180/180 - 1s - loss: 51.3941 - val_loss: 27.9784 - lr: 1.0000e-06 - 885ms/epoch - 5ms/step
Epoch 49/100
180/180 - 1s - loss: 51.4006 - val_loss: 27.8817 - lr: 1.0000e-06 - 956ms/epoch - 5ms/step
Epoch 50/100
180/180 - 1s - loss: 49.5279 - val_loss: 28.5968 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step
Epoch 51/100
180/180 - 1s - loss: 50.8220 - val_loss: 27.8653 - lr: 1.0000e-06 - 900ms/epoch - 5ms/step
Epoch 52/100
180/180 - 1s - loss: 50.0744 - val_loss: 27.8798 - lr: 1.0000e-06 - 891ms/epoch - 5ms/step
Epoch 53/100
180/180 - 1s - loss: 49.3442 - val_loss: 28.0141 - lr: 1.0000e-06 - 949ms/epoch - 5ms/step
Epoch 54/100
180/180 - 1s - loss: 50.8177 - val_loss: 27.7711 - lr: 1.0000e-06 - 908ms/epoch - 5ms/step
Epoch 55/100
180/180 - 1s - loss: 50.5057 - val_loss: 27.5619 - lr: 1.0000e-06 - 950ms/epoch - 5ms/step
Epoch 56/100
180/180 - 1s - loss: 50.3818 - val_loss: 27.8007 - lr: 1.0000e-06 - 907ms/epoch - 5ms/step
Epoch 57/100
180/180 - 1s - loss: 51.0867 - val_loss: 27.9196 - lr: 1.0000e-06 - 941ms/epoch - 5ms/step
Epoch 58/100
180/180 - 1s - loss: 50.8288 - val_loss: 27.6047 - lr: 1.0000e-06 - 922ms/epoch - 5ms/step
Epoch 59/100
180/180 - 1s - loss: 51.8877 - val_loss: 27.5723 - lr: 1.0000e-06 - 901ms/epoch - 5ms/step
Epoch 60/100
180/180 - 1s - loss: 52.7291 - val_loss: 27.8524 - lr: 1.0000e-06 - 906ms/epoch - 5ms/step
Epoch 61/100
180/180 - 1s - loss: 51.4810 - val_loss: 27.7864 - lr: 1.0000e-06 - 893ms/epoch - 5ms/step
Epoch 62/100
180/180 - 1s - loss: 50.9071 - val_loss: 27.9767 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step
Epoch 63/100
180/180 - 1s - loss: 50.3483 - val_loss: 27.9526 - lr: 1.0000e-06 - 921ms/epoch - 5ms/step
Epoch 64/100
180/180 - 1s - loss: 50.5335 - val_loss: 27.7428 - lr: 1.0000e-06 - 955ms/epoch - 5ms/step
Epoch 65/100
180/180 - 1s - loss: 51.9558 - val_loss: 27.6396 - lr: 1.0000e-06 - 929ms/epoch - 5ms/step
Epoch 66/100
180/180 - 1s - loss: 52.4079 - val_loss: 27.3344 - lr: 1.0000e-06 - 914ms/epoch - 5ms/step
Epoch 67/100
180/180 - 1s - loss: 50.7852 - val_loss: 27.5177 - lr: 1.0000e-06 - 953ms/epoch - 5ms/step
Epoch 68/100
180/180 - 1s - loss: 51.9260 - val_loss: 27.8164 - lr: 1.0000e-06 - 901ms/epoch - 5ms/step
Epoch 69/100
180/180 - 1s - loss: 51.0719 - val_loss: 27.2370 - lr: 1.0000e-06 - 905ms/epoch - 5ms/step
Epoch 70/100
180/180 - 1s - loss: 49.1517 - val_loss: 27.6438 - lr: 1.0000e-06 - 902ms/epoch - 5ms/step
Epoch 71/100
180/180 - 1s - loss: 50.2312 - val_loss: 27.6387 - lr: 1.0000e-06 - 897ms/epoch - 5ms/step
Epoch 72/100
180/180 - 1s - loss: 50.9665 - val_loss: 27.7089 - lr: 1.0000e-06 - 917ms/epoch - 5ms/step
Epoch 73/100
180/180 - 1s - loss: 49.8624 - val_loss: 27.6984 - lr: 1.0000e-06 - 924ms/epoch - 5ms/step
Epoch 74/100
180/180 - 1s - loss: 49.9697 - val_loss: 27.4741 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step
Epoch 75/100
180/180 - 1s - loss: 50.5036 - val_loss: 27.9261 - lr: 1.0000e-06 - 969ms/epoch - 5ms/step
Epoch 76/100
180/180 - 1s - loss: 51.3877 - val_loss: 27.4382 - lr: 1.0000e-06 - 920ms/epoch - 5ms/step
Epoch 77/100
180/180 - 1s - loss: 50.8661 - val_loss: 27.3089 - lr: 1.0000e-06 - 912ms/epoch - 5ms/step
Epoch 78/100
180/180 - 1s - loss: 52.2328 - val_loss: 27.6424 - lr: 1.0000e-06 - 950ms/epoch - 5ms/step
Epoch 79/100
180/180 - 1s - loss: 50.5149 - val_loss: 27.5051 - lr: 1.0000e-06 - 904ms/epoch - 5ms/step
Epoch 80/100
180/180 - 1s - loss: 49.7845 - val_loss: 27.6885 - lr: 1.0000e-06 - 927ms/epoch - 5ms/step
Epoch 81/100
180/180 - 1s - loss: 50.6684 - val_loss: 28.2240 - lr: 1.0000e-06 - 961ms/epoch - 5ms/step
Epoch 82/100
180/180 - 1s - loss: 48.3892 - val_loss: 27.5653 - lr: 1.0000e-06 - 932ms/epoch - 5ms/step
Epoch 83/100
180/180 - 1s - loss: 50.5615 - val_loss: 27.0092 - lr: 1.0000e-06 - 919ms/epoch - 5ms/step
Epoch 84/100
180/180 - 1s - loss: 51.8766 - val_loss: 27.6167 - lr: 1.0000e-06 - 959ms/epoch - 5ms/step
Epoch 85/100
180/180 - 1s - loss: 49.8121 - val_loss: 27.5435 - lr: 1.0000e-06 - 918ms/epoch - 5ms/step
Epoch 86/100
180/180 - 1s - loss: 49.2227 - val_loss: 27.5950 - lr: 1.0000e-06 - 927ms/epoch - 5ms/step
Epoch 87/100
180/180 - 1s - loss: 48.8999 - val_loss: 27.2340 - lr: 1.0000e-06 - 953ms/epoch - 5ms/step
Epoch 88/100
180/180 - 1s - loss: 52.5623 - val_loss: 27.4560 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step
Epoch 89/100
180/180 - 1s - loss: 48.5449 - val_loss: 27.5492 - lr: 1.0000e-06 - 897ms/epoch - 5ms/step
Epoch 90/100
180/180 - 1s - loss: 51.4166 - val_loss: 27.3065 - lr: 1.0000e-06 - 957ms/epoch - 5ms/step
Epoch 91/100
180/180 - 1s - loss: 50.9244 - val_loss: 27.2508 - lr: 1.0000e-06 - 957ms/epoch - 5ms/step
Epoch 92/100
180/180 - 1s - loss: 49.4454 - val_loss: 27.1114 - lr: 1.0000e-06 - 919ms/epoch - 5ms/step
Epoch 93/100
180/180 - 1s - loss: 52.1682 - val_loss: 27.3367 - lr: 1.0000e-06 - 937ms/epoch - 5ms/step
Epoch 94/100
180/180 - 1s - loss: 51.8086 - val_loss: 27.5010 - lr: 1.0000e-06 - 947ms/epoch - 5ms/step
Epoch 95/100
180/180 - 1s - loss: 49.6977 - val_loss: 27.4728 - lr: 1.0000e-06 - 941ms/epoch - 5ms/step
Epoch 96/100
180/180 - 1s - loss: 48.9664 - val_loss: 27.3309 - lr: 1.0000e-06 - 895ms/epoch - 5ms/step
Epoch 97/100
180/180 - 1s - loss: 50.3207 - val_loss: 26.9764 - lr: 1.0000e-06 - 944ms/epoch - 5ms/step
Epoch 98/100
180/180 - 1s - loss: 47.7961 - val_loss: 27.4807 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step
Epoch 99/100
180/180 - 1s - loss: 52.1776 - val_loss: 27.9628 - lr: 1.0000e-06 - 943ms/epoch - 5ms/step
Epoch 100/100
180/180 - 1s - loss: 49.4292 - val_loss: 27.1357 - lr: 1.0000e-06 - 889ms/epoch - 5ms/step
23/23 [==============================] - 0s 3ms/step - loss: 28.4436
28.443632125854492