{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark DTNN model on the QM7 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "\n",
    "from molgraph.chemistry.benchmark import configs\n",
    "from molgraph.chemistry.benchmark import tf_records\n",
    "from molgraph.chemistry import datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Build **MolecularGraphEncoder3D**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from molgraph.chemistry import features\n",
    "from molgraph.chemistry import Featurizer\n",
    "from molgraph.chemistry import MolecularGraphEncoder3D\n",
    "\n",
    "# Atom-level features handed to the featurizer. Each feature is listed\n",
    "# exactly once; ChiralCenter and Ring must not be repeated further down.\n",
    "atom_encoder = Featurizer([\n",
    "    features.Symbol(),\n",
    "    features.Hybridization(),\n",
    "    features.FormalCharge(),\n",
    "    features.TotalNumHs(),\n",
    "    features.TotalValence(),\n",
    "    features.NumRadicalElectrons(),\n",
    "    features.Degree(),\n",
    "    features.ChiralCenter(),\n",
    "    features.Aromatic(),\n",
    "    features.Ring(),\n",
    "    features.Hetero(),\n",
    "    features.HydrogenDonor(),\n",
    "    features.HydrogenAcceptor(),\n",
    "    features.CIPCode(),\n",
    "    features.RingSize(),\n",
    "    features.CrippenLogPContribution(),\n",
    "    features.CrippenMolarRefractivityContribution(),\n",
    "    features.TPSAContribution(),\n",
    "    features.LabuteASAContribution(),\n",
    "    features.GasteigerCharge(),\n",
    "])\n",
    "\n",
    "# 3D graph encoder built on top of the atom featurizer above.\n",
    "encoder = MolecularGraphEncoder3D(\n",
    "    atom_encoder,\n",
    "    conformer_generator=None,  # qm7 encodes conformers\n",
    "    edge_radius=None,  # max radius\n",
    "    coulomb=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 
Build **TF dataset** from **MolecularGraphEncoder3D**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "qm7 = datasets.get('qm7')\n",
    "\n",
    "# Encode the molecules of every split; the targets are taken as provided.\n",
    "train_split = qm7['train']\n",
    "val_split = qm7['validation']\n",
    "test_split = qm7['test']\n",
    "\n",
    "x_train, y_train = encoder(train_split['x']), train_split['y']\n",
    "x_val, y_val = encoder(val_split['x']), val_split['y']\n",
    "x_test, y_test = encoder(test_split['x']), test_split['y']\n",
    "\n",
    "# Spec of the encoded training inputs, needed later by the model's Input layer.\n",
    "type_spec = x_train.spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_batches(inputs, targets, shuffle_buffer=None):\n",
    "    \"\"\"Slice (inputs, targets) into examples and return batches of 32.\n",
    "\n",
    "    Examples are shuffled first when `shuffle_buffer` is given. Prefetching\n",
    "    uses tf.data.AUTOTUNE, which is the named constant for -1.\n",
    "    \"\"\"\n",
    "    ds = tf.data.Dataset.from_tensor_slices((inputs, targets))\n",
    "    if shuffle_buffer is not None:\n",
    "        ds = ds.shuffle(shuffle_buffer)\n",
    "    return ds.batch(32).prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "\n",
    "# Only the training split is shuffled.\n",
    "train_ds = _to_batches(x_train, y_train, shuffle_buffer=1024)\n",
    "val_ds = _to_batches(x_val, y_val)\n",
    "test_ds = _to_batches(x_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 
Modeling" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "180/180 - 3s - loss: 1257.5872 - val_loss: 903.2003 - lr: 1.0000e-04 - 3s/epoch - 18ms/step\n", "Epoch 2/100\n", "180/180 - 1s - loss: 545.0237 - val_loss: 442.4483 - lr: 1.0000e-04 - 937ms/epoch - 5ms/step\n", "Epoch 3/100\n", "180/180 - 1s - loss: 163.0737 - val_loss: 219.5774 - lr: 1.0000e-04 - 961ms/epoch - 5ms/step\n", "Epoch 4/100\n", "180/180 - 1s - loss: 127.5755 - val_loss: 100.0004 - lr: 1.0000e-04 - 952ms/epoch - 5ms/step\n", "Epoch 5/100\n", "180/180 - 1s - loss: 120.0276 - val_loss: 72.0875 - lr: 1.0000e-04 - 924ms/epoch - 5ms/step\n", "Epoch 6/100\n", "180/180 - 1s - loss: 108.8816 - val_loss: 179.9970 - lr: 1.0000e-04 - 940ms/epoch - 5ms/step\n", "Epoch 7/100\n", "180/180 - 1s - loss: 101.8370 - val_loss: 80.3074 - lr: 1.0000e-04 - 950ms/epoch - 5ms/step\n", "Epoch 8/100\n", "180/180 - 1s - loss: 97.8691 - val_loss: 66.1778 - lr: 1.0000e-04 - 920ms/epoch - 5ms/step\n", "Epoch 9/100\n", "180/180 - 1s - loss: 92.5363 - val_loss: 60.9984 - lr: 1.0000e-04 - 939ms/epoch - 5ms/step\n", "Epoch 10/100\n", "180/180 - 1s - loss: 87.6198 - val_loss: 69.3062 - lr: 1.0000e-04 - 933ms/epoch - 5ms/step\n", "Epoch 11/100\n", "180/180 - 1s - loss: 82.5370 - val_loss: 107.2218 - lr: 1.0000e-04 - 891ms/epoch - 5ms/step\n", "Epoch 12/100\n", "180/180 - 1s - loss: 79.2450 - val_loss: 57.9821 - lr: 1.0000e-04 - 889ms/epoch - 5ms/step\n", "Epoch 13/100\n", "180/180 - 1s - loss: 78.0266 - val_loss: 73.6456 - lr: 1.0000e-04 - 887ms/epoch - 5ms/step\n", "Epoch 14/100\n", "180/180 - 1s - loss: 72.0752 - val_loss: 43.4730 - lr: 1.0000e-04 - 904ms/epoch - 5ms/step\n", "Epoch 15/100\n", "180/180 - 1s - loss: 70.5660 - val_loss: 89.5029 - lr: 1.0000e-04 - 881ms/epoch - 5ms/step\n", "Epoch 16/100\n", "180/180 - 1s - loss: 65.7411 - val_loss: 46.4776 - lr: 1.0000e-04 - 880ms/epoch - 5ms/step\n", "Epoch 
17/100\n", "180/180 - 1s - loss: 62.6743 - val_loss: 92.0640 - lr: 1.0000e-04 - 930ms/epoch - 5ms/step\n", "Epoch 18/100\n", "180/180 - 1s - loss: 60.8892 - val_loss: 44.1200 - lr: 1.0000e-04 - 888ms/epoch - 5ms/step\n", "Epoch 19/100\n", "180/180 - 1s - loss: 60.1340 - val_loss: 68.0694 - lr: 1.0000e-04 - 902ms/epoch - 5ms/step\n", "Epoch 20/100\n", "180/180 - 1s - loss: 55.1986 - val_loss: 44.6196 - lr: 1.0000e-04 - 899ms/epoch - 5ms/step\n", "Epoch 21/100\n", "180/180 - 1s - loss: 57.4396 - val_loss: 85.4759 - lr: 1.0000e-04 - 934ms/epoch - 5ms/step\n", "Epoch 22/100\n", "180/180 - 1s - loss: 55.1108 - val_loss: 62.0642 - lr: 1.0000e-04 - 875ms/epoch - 5ms/step\n", "Epoch 23/100\n", "180/180 - 1s - loss: 55.1789 - val_loss: 73.6033 - lr: 1.0000e-04 - 915ms/epoch - 5ms/step\n", "Epoch 24/100\n", "180/180 - 1s - loss: 51.7645 - val_loss: 46.4600 - lr: 1.0000e-04 - 909ms/epoch - 5ms/step\n", "Epoch 25/100\n", "180/180 - 1s - loss: 52.2889 - val_loss: 38.5979 - lr: 1.0000e-05 - 887ms/epoch - 5ms/step\n", "Epoch 26/100\n", "180/180 - 1s - loss: 52.9256 - val_loss: 36.1527 - lr: 1.0000e-05 - 903ms/epoch - 5ms/step\n", "Epoch 27/100\n", "180/180 - 1s - loss: 52.4452 - val_loss: 39.4368 - lr: 1.0000e-05 - 919ms/epoch - 5ms/step\n", "Epoch 28/100\n", "180/180 - 1s - loss: 52.2758 - val_loss: 37.5381 - lr: 1.0000e-05 - 913ms/epoch - 5ms/step\n", "Epoch 29/100\n", "180/180 - 1s - loss: 52.6289 - val_loss: 38.1596 - lr: 1.0000e-05 - 875ms/epoch - 5ms/step\n", "Epoch 30/100\n", "180/180 - 1s - loss: 53.6993 - val_loss: 42.3064 - lr: 1.0000e-05 - 876ms/epoch - 5ms/step\n", "Epoch 31/100\n", "180/180 - 1s - loss: 50.7391 - val_loss: 38.4957 - lr: 1.0000e-05 - 886ms/epoch - 5ms/step\n", "Epoch 32/100\n", "180/180 - 1s - loss: 50.3886 - val_loss: 45.2156 - lr: 1.0000e-05 - 880ms/epoch - 5ms/step\n", "Epoch 33/100\n", "180/180 - 1s - loss: 52.0881 - val_loss: 36.2220 - lr: 1.0000e-05 - 885ms/epoch - 5ms/step\n", "Epoch 34/100\n", "180/180 - 1s - loss: 53.4462 - val_loss: 37.4831 
- lr: 1.0000e-05 - 876ms/epoch - 5ms/step\n", "Epoch 35/100\n", "180/180 - 1s - loss: 53.2242 - val_loss: 36.2767 - lr: 1.0000e-05 - 880ms/epoch - 5ms/step\n", "Epoch 36/100\n", "180/180 - 1s - loss: 50.5077 - val_loss: 39.5337 - lr: 1.0000e-05 - 881ms/epoch - 5ms/step\n", "Epoch 37/100\n", "180/180 - 1s - loss: 50.2126 - val_loss: 28.2422 - lr: 1.0000e-06 - 887ms/epoch - 5ms/step\n", "Epoch 38/100\n", "180/180 - 1s - loss: 50.5770 - val_loss: 28.2901 - lr: 1.0000e-06 - 909ms/epoch - 5ms/step\n", "Epoch 39/100\n", "180/180 - 1s - loss: 51.7723 - val_loss: 28.5813 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step\n", "Epoch 40/100\n", "180/180 - 1s - loss: 51.6077 - val_loss: 28.0917 - lr: 1.0000e-06 - 894ms/epoch - 5ms/step\n", "Epoch 41/100\n", "180/180 - 1s - loss: 50.7641 - val_loss: 28.4171 - lr: 1.0000e-06 - 900ms/epoch - 5ms/step\n", "Epoch 42/100\n", "180/180 - 1s - loss: 51.5076 - val_loss: 28.2041 - lr: 1.0000e-06 - 885ms/epoch - 5ms/step\n", "Epoch 43/100\n", "180/180 - 1s - loss: 50.7273 - val_loss: 28.4535 - lr: 1.0000e-06 - 893ms/epoch - 5ms/step\n", "Epoch 44/100\n", "180/180 - 1s - loss: 51.1730 - val_loss: 28.3620 - lr: 1.0000e-06 - 888ms/epoch - 5ms/step\n", "Epoch 45/100\n", "180/180 - 1s - loss: 54.1880 - val_loss: 28.1676 - lr: 1.0000e-06 - 882ms/epoch - 5ms/step\n", "Epoch 46/100\n", "180/180 - 1s - loss: 52.0131 - val_loss: 28.1627 - lr: 1.0000e-06 - 894ms/epoch - 5ms/step\n", "Epoch 47/100\n", "180/180 - 1s - loss: 52.4500 - val_loss: 28.0132 - lr: 1.0000e-06 - 891ms/epoch - 5ms/step\n", "Epoch 48/100\n", "180/180 - 1s - loss: 51.3941 - val_loss: 27.9784 - lr: 1.0000e-06 - 885ms/epoch - 5ms/step\n", "Epoch 49/100\n", "180/180 - 1s - loss: 51.4006 - val_loss: 27.8817 - lr: 1.0000e-06 - 956ms/epoch - 5ms/step\n", "Epoch 50/100\n", "180/180 - 1s - loss: 49.5279 - val_loss: 28.5968 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step\n", "Epoch 51/100\n", "180/180 - 1s - loss: 50.8220 - val_loss: 27.8653 - lr: 1.0000e-06 - 900ms/epoch - 5ms/step\n", "Epoch 
52/100\n", "180/180 - 1s - loss: 50.0744 - val_loss: 27.8798 - lr: 1.0000e-06 - 891ms/epoch - 5ms/step\n", "Epoch 53/100\n", "180/180 - 1s - loss: 49.3442 - val_loss: 28.0141 - lr: 1.0000e-06 - 949ms/epoch - 5ms/step\n", "Epoch 54/100\n", "180/180 - 1s - loss: 50.8177 - val_loss: 27.7711 - lr: 1.0000e-06 - 908ms/epoch - 5ms/step\n", "Epoch 55/100\n", "180/180 - 1s - loss: 50.5057 - val_loss: 27.5619 - lr: 1.0000e-06 - 950ms/epoch - 5ms/step\n", "Epoch 56/100\n", "180/180 - 1s - loss: 50.3818 - val_loss: 27.8007 - lr: 1.0000e-06 - 907ms/epoch - 5ms/step\n", "Epoch 57/100\n", "180/180 - 1s - loss: 51.0867 - val_loss: 27.9196 - lr: 1.0000e-06 - 941ms/epoch - 5ms/step\n", "Epoch 58/100\n", "180/180 - 1s - loss: 50.8288 - val_loss: 27.6047 - lr: 1.0000e-06 - 922ms/epoch - 5ms/step\n", "Epoch 59/100\n", "180/180 - 1s - loss: 51.8877 - val_loss: 27.5723 - lr: 1.0000e-06 - 901ms/epoch - 5ms/step\n", "Epoch 60/100\n", "180/180 - 1s - loss: 52.7291 - val_loss: 27.8524 - lr: 1.0000e-06 - 906ms/epoch - 5ms/step\n", "Epoch 61/100\n", "180/180 - 1s - loss: 51.4810 - val_loss: 27.7864 - lr: 1.0000e-06 - 893ms/epoch - 5ms/step\n", "Epoch 62/100\n", "180/180 - 1s - loss: 50.9071 - val_loss: 27.9767 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step\n", "Epoch 63/100\n", "180/180 - 1s - loss: 50.3483 - val_loss: 27.9526 - lr: 1.0000e-06 - 921ms/epoch - 5ms/step\n", "Epoch 64/100\n", "180/180 - 1s - loss: 50.5335 - val_loss: 27.7428 - lr: 1.0000e-06 - 955ms/epoch - 5ms/step\n", "Epoch 65/100\n", "180/180 - 1s - loss: 51.9558 - val_loss: 27.6396 - lr: 1.0000e-06 - 929ms/epoch - 5ms/step\n", "Epoch 66/100\n", "180/180 - 1s - loss: 52.4079 - val_loss: 27.3344 - lr: 1.0000e-06 - 914ms/epoch - 5ms/step\n", "Epoch 67/100\n", "180/180 - 1s - loss: 50.7852 - val_loss: 27.5177 - lr: 1.0000e-06 - 953ms/epoch - 5ms/step\n", "Epoch 68/100\n", "180/180 - 1s - loss: 51.9260 - val_loss: 27.8164 - lr: 1.0000e-06 - 901ms/epoch - 5ms/step\n", "Epoch 69/100\n", "180/180 - 1s - loss: 51.0719 - val_loss: 27.2370 
- lr: 1.0000e-06 - 905ms/epoch - 5ms/step\n", "Epoch 70/100\n", "180/180 - 1s - loss: 49.1517 - val_loss: 27.6438 - lr: 1.0000e-06 - 902ms/epoch - 5ms/step\n", "Epoch 71/100\n", "180/180 - 1s - loss: 50.2312 - val_loss: 27.6387 - lr: 1.0000e-06 - 897ms/epoch - 5ms/step\n", "Epoch 72/100\n", "180/180 - 1s - loss: 50.9665 - val_loss: 27.7089 - lr: 1.0000e-06 - 917ms/epoch - 5ms/step\n", "Epoch 73/100\n", "180/180 - 1s - loss: 49.8624 - val_loss: 27.6984 - lr: 1.0000e-06 - 924ms/epoch - 5ms/step\n", "Epoch 74/100\n", "180/180 - 1s - loss: 49.9697 - val_loss: 27.4741 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step\n", "Epoch 75/100\n", "180/180 - 1s - loss: 50.5036 - val_loss: 27.9261 - lr: 1.0000e-06 - 969ms/epoch - 5ms/step\n", "Epoch 76/100\n", "180/180 - 1s - loss: 51.3877 - val_loss: 27.4382 - lr: 1.0000e-06 - 920ms/epoch - 5ms/step\n", "Epoch 77/100\n", "180/180 - 1s - loss: 50.8661 - val_loss: 27.3089 - lr: 1.0000e-06 - 912ms/epoch - 5ms/step\n", "Epoch 78/100\n", "180/180 - 1s - loss: 52.2328 - val_loss: 27.6424 - lr: 1.0000e-06 - 950ms/epoch - 5ms/step\n", "Epoch 79/100\n", "180/180 - 1s - loss: 50.5149 - val_loss: 27.5051 - lr: 1.0000e-06 - 904ms/epoch - 5ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 80/100\n", "180/180 - 1s - loss: 49.7845 - val_loss: 27.6885 - lr: 1.0000e-06 - 927ms/epoch - 5ms/step\n", "Epoch 81/100\n", "180/180 - 1s - loss: 50.6684 - val_loss: 28.2240 - lr: 1.0000e-06 - 961ms/epoch - 5ms/step\n", "Epoch 82/100\n", "180/180 - 1s - loss: 48.3892 - val_loss: 27.5653 - lr: 1.0000e-06 - 932ms/epoch - 5ms/step\n", "Epoch 83/100\n", "180/180 - 1s - loss: 50.5615 - val_loss: 27.0092 - lr: 1.0000e-06 - 919ms/epoch - 5ms/step\n", "Epoch 84/100\n", "180/180 - 1s - loss: 51.8766 - val_loss: 27.6167 - lr: 1.0000e-06 - 959ms/epoch - 5ms/step\n", "Epoch 85/100\n", "180/180 - 1s - loss: 49.8121 - val_loss: 27.5435 - lr: 1.0000e-06 - 918ms/epoch - 5ms/step\n", "Epoch 86/100\n", "180/180 - 1s - loss: 49.2227 - val_loss: 27.5950 - 
lr: 1.0000e-06 - 927ms/epoch - 5ms/step\n", "Epoch 87/100\n", "180/180 - 1s - loss: 48.8999 - val_loss: 27.2340 - lr: 1.0000e-06 - 953ms/epoch - 5ms/step\n", "Epoch 88/100\n", "180/180 - 1s - loss: 52.5623 - val_loss: 27.4560 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step\n", "Epoch 89/100\n", "180/180 - 1s - loss: 48.5449 - val_loss: 27.5492 - lr: 1.0000e-06 - 897ms/epoch - 5ms/step\n", "Epoch 90/100\n", "180/180 - 1s - loss: 51.4166 - val_loss: 27.3065 - lr: 1.0000e-06 - 957ms/epoch - 5ms/step\n", "Epoch 91/100\n", "180/180 - 1s - loss: 50.9244 - val_loss: 27.2508 - lr: 1.0000e-06 - 957ms/epoch - 5ms/step\n", "Epoch 92/100\n", "180/180 - 1s - loss: 49.4454 - val_loss: 27.1114 - lr: 1.0000e-06 - 919ms/epoch - 5ms/step\n", "Epoch 93/100\n", "180/180 - 1s - loss: 52.1682 - val_loss: 27.3367 - lr: 1.0000e-06 - 937ms/epoch - 5ms/step\n", "Epoch 94/100\n", "180/180 - 1s - loss: 51.8086 - val_loss: 27.5010 - lr: 1.0000e-06 - 947ms/epoch - 5ms/step\n", "Epoch 95/100\n", "180/180 - 1s - loss: 49.6977 - val_loss: 27.4728 - lr: 1.0000e-06 - 941ms/epoch - 5ms/step\n", "Epoch 96/100\n", "180/180 - 1s - loss: 48.9664 - val_loss: 27.3309 - lr: 1.0000e-06 - 895ms/epoch - 5ms/step\n", "Epoch 97/100\n", "180/180 - 1s - loss: 50.3207 - val_loss: 26.9764 - lr: 1.0000e-06 - 944ms/epoch - 5ms/step\n", "Epoch 98/100\n", "180/180 - 1s - loss: 47.7961 - val_loss: 27.4807 - lr: 1.0000e-06 - 903ms/epoch - 5ms/step\n", "Epoch 99/100\n", "180/180 - 1s - loss: 52.1776 - val_loss: 27.9628 - lr: 1.0000e-06 - 943ms/epoch - 5ms/step\n", "Epoch 100/100\n", "180/180 - 1s - loss: 49.4292 - val_loss: 27.1357 - lr: 1.0000e-06 - 889ms/epoch - 5ms/step\n", "23/23 [==============================] - 0s 3ms/step - loss: 28.4436\n", "28.443632125854492\n" ] } ],
   "source": [
    "from molgraph.layers import DTNNConv\n",
    "from molgraph.layers import Readout\n",
    "from molgraph.layers import MinMaxScaling\n",
    "\n",
    "# Min-max scalers for the node and edge features, sharing the same settings.\n",
    "scaler_kwargs = dict(feature_range=(0, 1), threshold=True)\n",
    "node_preprocessing = MinMaxScaling(feature='node_feature', **scaler_kwargs)\n",
    "edge_preprocessing = MinMaxScaling(feature='edge_feature', **scaler_kwargs)\n",
    "\n",
    "# Fit both scalers on the training inputs only (the targets are dropped).\n",
    "train_inputs = train_ds.map(lambda x, *args: x)\n",
    "node_preprocessing.adapt(train_inputs)\n",
    "edge_preprocessing.adapt(train_inputs)\n",
    "\n",
    "# Input -> feature scaling -> 3 x DTNNConv -> Readout -> dense head.\n",
    "model_layers = [\n",
    "    keras.layers.Input(type_spec=type_spec),\n",
    "    node_preprocessing,\n",
    "    edge_preprocessing,\n",
    "]\n",
    "model_layers += [DTNNConv(normalization='batch_norm') for _ in range(3)]\n",
    "model_layers += [\n",
    "    Readout(),\n",
    "    keras.layers.Dense(1024, activation='relu'),\n",
    "    keras.layers.Dense(1024, activation='relu'),\n",
    "    # One output unit per target column.\n",
    "    keras.layers.Dense(y_train.shape[-1]),\n",
    "]\n",
    "model = keras.Sequential(model_layers)\n",
    "\n",
    "optimizer = keras.optimizers.Adam(learning_rate=1e-4)\n",
    "loss = keras.losses.MeanAbsoluteError(name='mae')\n",
    "\n",
    "# Both callbacks monitor the validation loss.\n",
    "reduce_lr = keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_loss',\n",
    "    factor=0.1,\n",
    "    patience=10,\n",
    "    min_lr=1e-6,\n",
    "    mode='min',\n",
    ")\n",
    "early_stopping = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=20,\n",
    "    mode='min',\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "callbacks = [reduce_lr, early_stopping]\n",
    "\n",
    "model.compile(optimizer=optimizer, loss=loss)\n",
    "history = model.fit(\n",
    "    train_ds,\n",
    "    validation_data=val_ds,\n",
    "    epochs=100,\n",
    "    callbacks=callbacks,\n",
    "    verbose=2,\n",
    ")\n",
    "\n",
    "# Evaluate on the held-out test split (the loss is MAE).\n",
    "score = model.evaluate(test_ds)\n",
    "print(score)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}