{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark MPNN model on the Tox21 dataset (with Masked Loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "\n",
    "from molgraph.chemistry.benchmark import configs\n",
    "from molgraph.chemistry.benchmark import tf_records\n",
    "from molgraph.chemistry import datasets\n",
    "from molgraph.losses import MaskedBinaryCrossentropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Build **MolecularGraphEncoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from molgraph.chemistry import features\n",
    "from molgraph.chemistry import Featurizer\n",
    "from molgraph.chemistry import MolecularGraphEncoder\n",
    "\n",
    "# Atom-level featurizer. Every feature is listed exactly once: a repeated\n",
    "# entry (e.g. a second ChiralCenter or Ring) would presumably only append\n",
    "# a redundant copy of columns that are already present.\n",
    "atom_encoder = Featurizer([\n",
    "    features.Symbol(),\n",
    "    features.Hybridization(),\n",
    "    features.FormalCharge(),\n",
    "    features.TotalNumHs(),\n",
    "    features.TotalValence(),\n",
    "    features.NumRadicalElectrons(),\n",
    "    features.Degree(),\n",
    "    features.ChiralCenter(),\n",
    "    features.Aromatic(),\n",
    "    features.Ring(),\n",
    "    features.Hetero(),\n",
    "    features.HydrogenDonor(),\n",
    "    features.HydrogenAcceptor(),\n",
    "    features.CIPCode(),\n",
    "    features.RingSize(),\n",
    "    features.CrippenLogPContribution(),\n",
    "    features.CrippenMolarRefractivityContribution(),\n",
    "    features.TPSAContribution(),\n",
    "    features.LabuteASAContribution(),\n",
    "    features.GasteigerCharge(),\n",
    "])\n",
    "\n",
    "# Bond-level featurizer (its Ring feature is separate from the atom one).\n",
    "bond_encoder = Featurizer([\n",
    "    features.BondType(),\n",
    "    features.Conjugated(),\n",
    "    features.Rotatable(),\n",
    "    features.Ring(),\n",
    "    features.Stereo(),\n",
    "])\n",
    "\n",
    "# Combines both featurizers into the encoder that is applied to each data split.\n",
    "encoder = MolecularGraphEncoder(\n",
    "    atom_encoder,\n",
    " 
bond_encoder,\n",
    "    positional_encoding_dim=16,\n",
    "    self_loops=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Build **TF dataset** from **MolecularGraphEncoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tox21 = datasets.get('tox21')\n",
    "\n",
    "# For every split: encode the inputs with the molecular graph encoder and\n",
    "# take the labels and their masks from the dataset unchanged.\n",
    "x_train = encoder(tox21['train']['x'])\n",
    "y_train = tox21['train']['y']\n",
    "y_mask_train = tox21['train']['y_mask']\n",
    "\n",
    "x_val = encoder(tox21['validation']['x'])\n",
    "y_val = tox21['validation']['y']\n",
    "y_mask_val = tox21['validation']['y_mask']\n",
    "\n",
    "x_test = encoder(tox21['test']['x'])\n",
    "y_test = tox21['test']['y']\n",
    "y_mask_test = tox21['test']['y_mask']\n",
    "\n",
    "# Spec of the encoded training inputs; the model's Input layer is built from it.\n",
    "type_spec = x_train.spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline settings, gathered here so they can be tuned in one place.\n",
    "BATCH_SIZE = 32\n",
    "SHUFFLE_BUFFER_SIZE = 1024\n",
    "SEED = 42\n",
    "\n",
    "# Seed Python, NumPy and TensorFlow before anything random is created, so a\n",
    "# fresh Restart & Run All repeats the same shuffle order and initial weights\n",
    "# (some GPU kernels may still be nondeterministic).\n",
    "keras.utils.set_random_seed(SEED)\n",
    "\n",
    "\n",
    "def make_dataset(x, y, y_mask, shuffle=False):\n",
    "    \"\"\"Build a batched, prefetched tf.data pipeline for one data split.\n",
    "\n",
    "    Args:\n",
    "        x: encoded inputs of the split.\n",
    "        y: labels of the split.\n",
    "        y_mask: label mask of the split; it is emitted as the third tuple\n",
    "            element, which Keras' fit/evaluate treat as sample weights.\n",
    "        shuffle: if True, shuffle (with a fixed seed) before batching.\n",
    "\n",
    "    Returns:\n",
    "        A tf.data.Dataset yielding (x, y, y_mask) batches.\n",
    "    \"\"\"\n",
    "    ds = tf.data.Dataset.from_tensor_slices((x, y, y_mask))\n",
    "    if shuffle:\n",
    "        ds = ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=SEED)\n",
    "    # tf.data.AUTOTUNE lets tf.data choose the prefetch buffer size.\n",
    "    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "\n",
    "# Only the training split is shuffled.\n",
    "train_ds = make_dataset(x_train, y_train, y_mask_train, shuffle=True)\n",
    "val_ds = make_dataset(x_val, y_val, y_mask_val)\n",
    "test_ds = make_dataset(x_test, y_test, y_mask_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 
Modeling" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "196/196 - 20s - loss: 0.2953 - roc_auc: 0.5798 - val_loss: 0.2795 - val_roc_auc: 0.6407 - lr: 1.0000e-04 - 20s/epoch - 103ms/step\n", "Epoch 2/100\n", "196/196 - 13s - loss: 0.2739 - roc_auc: 0.6427 - val_loss: 0.2414 - val_roc_auc: 0.7301 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 3/100\n", "196/196 - 13s - loss: 0.2616 - roc_auc: 0.6961 - val_loss: 0.2321 - val_roc_auc: 0.7534 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 4/100\n", "196/196 - 13s - loss: 0.2544 - roc_auc: 0.7180 - val_loss: 0.2270 - val_roc_auc: 0.7654 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 5/100\n", "196/196 - 13s - loss: 0.2513 - roc_auc: 0.7234 - val_loss: 0.2551 - val_roc_auc: 0.6889 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 6/100\n", "196/196 - 13s - loss: 0.2476 - roc_auc: 0.7372 - val_loss: 0.2193 - val_roc_auc: 0.7653 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 7/100\n", "196/196 - 13s - loss: 0.2396 - roc_auc: 0.7570 - val_loss: 0.2144 - val_roc_auc: 0.8030 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 8/100\n", "196/196 - 14s - loss: 0.2336 - roc_auc: 0.7711 - val_loss: 0.2159 - val_roc_auc: 0.8067 - lr: 1.0000e-04 - 14s/epoch - 70ms/step\n", "Epoch 9/100\n", "196/196 - 13s - loss: 0.2269 - roc_auc: 0.7851 - val_loss: 0.2066 - val_roc_auc: 0.8119 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 10/100\n", "196/196 - 13s - loss: 0.2259 - roc_auc: 0.7859 - val_loss: 0.2018 - val_roc_auc: 0.8181 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 11/100\n", "196/196 - 13s - loss: 0.2173 - roc_auc: 0.8001 - val_loss: 0.2020 - val_roc_auc: 0.8195 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 12/100\n", "196/196 - 13s - loss: 0.2129 - roc_auc: 0.8072 - val_loss: 0.1937 - val_roc_auc: 0.8251 - lr: 1.0000e-04 - 13s/epoch - 69ms/step\n", "Epoch 13/100\n", 
"196/196 - 13s - loss: 0.2077 - roc_auc: 0.8152 - val_loss: 0.1953 - val_roc_auc: 0.8307 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 14/100\n", "196/196 - 13s - loss: 0.2052 - roc_auc: 0.8196 - val_loss: 0.2205 - val_roc_auc: 0.7939 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 15/100\n", "196/196 - 13s - loss: 0.2074 - roc_auc: 0.8175 - val_loss: 0.1950 - val_roc_auc: 0.8337 - lr: 1.0000e-04 - 13s/epoch - 67ms/step\n", "Epoch 16/100\n", "196/196 - 13s - loss: 0.1984 - roc_auc: 0.8299 - val_loss: 0.1907 - val_roc_auc: 0.8428 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 17/100\n", "196/196 - 14s - loss: 0.1956 - roc_auc: 0.8349 - val_loss: 0.1897 - val_roc_auc: 0.8395 - lr: 1.0000e-04 - 14s/epoch - 69ms/step\n", "Epoch 18/100\n", "196/196 - 13s - loss: 0.1925 - roc_auc: 0.8373 - val_loss: 0.1957 - val_roc_auc: 0.8390 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 19/100\n", "196/196 - 13s - loss: 0.1893 - roc_auc: 0.8456 - val_loss: 0.1918 - val_roc_auc: 0.8321 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 20/100\n", "196/196 - 14s - loss: 0.1846 - roc_auc: 0.8509 - val_loss: 0.1845 - val_roc_auc: 0.8433 - lr: 1.0000e-04 - 14s/epoch - 70ms/step\n", "Epoch 21/100\n", "196/196 - 14s - loss: 0.1798 - roc_auc: 0.8556 - val_loss: 0.1935 - val_roc_auc: 0.8357 - lr: 1.0000e-04 - 14s/epoch - 73ms/step\n", "Epoch 22/100\n", "196/196 - 14s - loss: 0.1796 - roc_auc: 0.8556 - val_loss: 0.1967 - val_roc_auc: 0.8243 - lr: 1.0000e-04 - 14s/epoch - 71ms/step\n", "Epoch 23/100\n", "196/196 - 13s - loss: 0.1852 - roc_auc: 0.8493 - val_loss: 0.1884 - val_roc_auc: 0.8417 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 24/100\n", "196/196 - 13s - loss: 0.1773 - roc_auc: 0.8613 - val_loss: 0.1896 - val_roc_auc: 0.8332 - lr: 1.0000e-04 - 13s/epoch - 68ms/step\n", "Epoch 25/100\n", "196/196 - 14s - loss: 0.1738 - roc_auc: 0.8630 - val_loss: 0.1889 - val_roc_auc: 0.8374 - lr: 1.0000e-04 - 14s/epoch - 70ms/step\n", "Epoch 26/100\n", "196/196 - 13s - loss: 
0.1610 - roc_auc: 0.8809 - val_loss: 0.1820 - val_roc_auc: 0.8523 - lr: 1.0000e-05 - 13s/epoch - 69ms/step\n", "Epoch 27/100\n", "196/196 - 13s - loss: 0.1568 - roc_auc: 0.8856 - val_loss: 0.1822 - val_roc_auc: 0.8463 - lr: 1.0000e-05 - 13s/epoch - 69ms/step\n", "Epoch 28/100\n", "196/196 - 13s - loss: 0.1556 - roc_auc: 0.8864 - val_loss: 0.1839 - val_roc_auc: 0.8342 - lr: 1.0000e-05 - 13s/epoch - 67ms/step\n", "Epoch 29/100\n", "196/196 - 13s - loss: 0.1546 - roc_auc: 0.8869 - val_loss: 0.1832 - val_roc_auc: 0.8461 - lr: 1.0000e-05 - 13s/epoch - 67ms/step\n", "Epoch 30/100\n", "196/196 - 13s - loss: 0.1539 - roc_auc: 0.8891 - val_loss: 0.1848 - val_roc_auc: 0.8429 - lr: 1.0000e-05 - 13s/epoch - 67ms/step\n", "Epoch 31/100\n", "196/196 - 13s - loss: 0.1522 - roc_auc: 0.8900 - val_loss: 0.1871 - val_roc_auc: 0.8400 - lr: 1.0000e-05 - 13s/epoch - 67ms/step\n", "Epoch 32/100\n", "196/196 - 13s - loss: 0.1514 - roc_auc: 0.8916 - val_loss: 0.1866 - val_roc_auc: 0.8410 - lr: 1.0000e-06 - 13s/epoch - 67ms/step\n", "Epoch 33/100\n", "196/196 - 13s - loss: 0.1509 - roc_auc: 0.8908 - val_loss: 0.1831 - val_roc_auc: 0.8476 - lr: 1.0000e-06 - 13s/epoch - 67ms/step\n", "Epoch 34/100\n", "196/196 - 13s - loss: 0.1510 - roc_auc: 0.8920 - val_loss: 0.1876 - val_roc_auc: 0.8430 - lr: 1.0000e-06 - 13s/epoch - 67ms/step\n", "Epoch 35/100\n", "196/196 - 13s - loss: 0.1507 - roc_auc: 0.8925 - val_loss: 0.1864 - val_roc_auc: 0.8394 - lr: 1.0000e-06 - 13s/epoch - 67ms/step\n", "Epoch 36/100\n", "196/196 - 13s - loss: 0.1500 - roc_auc: 0.8932 - val_loss: 0.1840 - val_roc_auc: 0.8397 - lr: 1.0000e-06 - 13s/epoch - 67ms/step\n", "25/25 [==============================] - 1s 21ms/step - loss: 0.2212 - roc_auc: 0.8150\n", "[0.22120867669582367, 0.8149662017822266]\n" ] } ], "source": [ "from molgraph.layers import MPNNConv\n", "from molgraph.layers import LaplacianPositionalEncoding\n", "from molgraph.layers import SetGatherReadout\n", "from molgraph.layers import MinMaxScaling\n", "\n", 
"# Min-max scalers for the node and edge features, squeezing both into [0, 1].\n",
    "scaling_options = dict(feature_range=(0, 1), threshold=True)\n",
    "node_preprocessing = MinMaxScaling(feature='node_feature', **scaling_options)\n",
    "edge_preprocessing = MinMaxScaling(feature='edge_feature', **scaling_options)\n",
    "\n",
    "# Both scalers are fitted on the training inputs only: the mapping drops\n",
    "# everything but the first element of each (x, y, y_mask) batch.\n",
    "train_inputs_ds = train_ds.map(lambda x, *args: x)\n",
    "node_preprocessing.adapt(train_inputs_ds)\n",
    "edge_preprocessing.adapt(train_inputs_ds)\n",
    "\n",
    "# Preprocessing -> positional encoding -> three message-passing layers ->\n",
    "# readout -> dense head with one sigmoid unit per label column of y_train.\n",
    "model_layers = [\n",
    "    keras.layers.Input(type_spec=type_spec),\n",
    "    node_preprocessing,\n",
    "    edge_preprocessing,\n",
    "    LaplacianPositionalEncoding(),\n",
    "    MPNNConv(normalization='batch_norm'),\n",
    "    MPNNConv(normalization='batch_norm'),\n",
    "    MPNNConv(normalization='batch_norm'),\n",
    "    SetGatherReadout(),\n",
    "    keras.layers.Dense(1024, activation='relu'),\n",
    "    keras.layers.Dense(1024, activation='relu'),\n",
    "    keras.layers.Dense(y_train.shape[-1], activation='sigmoid'),\n",
    "]\n",
    "model = keras.Sequential(model_layers)\n",
    "\n",
    "optimizer = keras.optimizers.Adam(learning_rate=1e-4)\n",
    "loss = MaskedBinaryCrossentropy(name='bce')\n",
    "\n",
    "# The metric is registered under `weighted_metrics` below, so it is given\n",
    "# the per-batch sample weights -- here the label mask, which Keras reads\n",
    "# from the third element of each batch.\n",
    "metrics = [keras.metrics.AUC(name='roc_auc', multi_label=True)]\n",
    "\n",
    "# Both callbacks watch the validation ROC-AUC (higher is better): the\n",
    "# learning rate is cut tenfold after 5 stagnant epochs, and training stops\n",
    "# after 10, restoring the best weights seen.\n",
    "reduce_lr = keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_roc_auc',\n",
    "    factor=0.1,\n",
    "    patience=5,\n",
    "    min_lr=1e-6,\n",
    "    mode='max',\n",
    ")\n",
    "early_stopping = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_roc_auc',\n",
    "    patience=10,\n",
    "    mode='max',\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "callbacks = [reduce_lr, early_stopping]\n",
    "\n",
    "model.compile(optimizer=optimizer, loss=loss, weighted_metrics=metrics)\n",
    "history = model.fit(\n",
    "    train_ds,\n",
    "    validation_data=val_ds,\n",
    "    epochs=100,\n",
    "    callbacks=callbacks,\n",
    "    verbose=2,\n",
    ")\n",
    "\n",
    "# Held-out test split: evaluate returns [loss, roc_auc].\n",
    "score = model.evaluate(test_ds)\n",
    "print(score)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer":
"ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }