{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark GAT model on the ESOL dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Extend the import path four directory levels up (presumably the\n",
    "# repository root, so that the local molgraph package resolves - confirm).\n",
    "sys.path.append('../../../../')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "\n",
    "from molgraph.chemistry.benchmark import configs\n",
    "from molgraph.chemistry.benchmark import tf_records\n",
    "from molgraph.chemistry import datasets\n",
    "from molgraph.losses import masked_losses\n",
    "from molgraph.metrics import masked_metrics\n",
    "\n",
    "# Seed NumPy and TensorFlow so that dataset shuffling and weight\n",
    "# initialisation are repeatable after Restart Kernel -> Run All.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Build **MolecularGraphEncoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from molgraph.chemistry import features\n",
    "from molgraph.chemistry import Featurizer\n",
    "from molgraph.chemistry import MolecularGraphEncoder\n",
    "\n",
    "# Atom (node) features. Every feature appears exactly once: the earlier\n",
    "# version listed ChiralCenter and Ring twice, which only repeated entries\n",
    "# that were already in the list.\n",
    "atom_encoder = Featurizer([\n",
    "    features.Symbol(),\n",
    "    features.Hybridization(),\n",
    "    features.FormalCharge(),\n",
    "    features.TotalNumHs(),\n",
    "    features.TotalValence(),\n",
    "    features.NumRadicalElectrons(),\n",
    "    features.Degree(),\n",
    "    features.ChiralCenter(),\n",
    "    features.Aromatic(),\n",
    "    features.Ring(),\n",
    "    features.Hetero(),\n",
    "    features.HydrogenDonor(),\n",
    "    features.HydrogenAcceptor(),\n",
    "    features.CIPCode(),\n",
    "    features.RingSize(),\n",
    "    features.CrippenLogPContribution(),\n",
    "    features.CrippenMolarRefractivityContribution(),\n",
    "    features.TPSAContribution(),\n",
    "    features.LabuteASAContribution(),\n",
    "    features.GasteigerCharge(),\n",
    "])\n",
    "\n",
    "# Bond (edge) features.\n",
    "bond_encoder = Featurizer([\n",
    "    features.BondType(),\n",
    "    features.Conjugated(),\n",
    "    features.Rotatable(),\n",
    "    features.Ring(),\n",
    "    features.Stereo(),\n",
    "])\n",
    "\n",
    "# NOTE(review): positional_encoding_dim is presumably what the\n",
    "# LaplacianPositionalEncoding layer in the modeling cell consumes - confirm.\n",
    "encoder = MolecularGraphEncoder(\n",
    "    atom_encoder,\n",
    "    bond_encoder,\n",
    "    positional_encoding_dim=16,\n",
    "    self_loops=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Build **TF dataset** from **MolecularGraphEncoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ESOL comes pre-split into 'train', 'validation' and 'test', each holding\n",
    "# the inputs under 'x' and the targets under 'y'.\n",
    "esol = datasets.get('esol')\n",
    "\n",
    "x_train = encoder(esol['train']['x'])\n",
    "y_train = esol['train']['y']\n",
    "\n",
    "x_val = encoder(esol['validation']['x'])\n",
    "y_val = esol['validation']['y']\n",
    "\n",
    "x_test = encoder(esol['test']['x'])\n",
    "y_test = esol['test']['y']\n",
    "\n",
    "# Spec of the encoded graphs; passed to keras.layers.Input in the modeling cell.\n",
    "type_spec = x_train.spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training split is shuffled. prefetch(-1) is the value of\n",
    "# tf.data.AUTOTUNE, i.e. the prefetch buffer size is tuned at runtime.\n",
    "train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "    .shuffle(1024)\n",
    "    .batch(32)\n",
    "    .prefetch(-1)\n",
    ")\n",
    "\n",
    "val_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
    "    .batch(32)\n",
    "    .prefetch(-1)\n",
    ")\n",
    "\n",
    "test_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "    .batch(32)\n",
    "    .prefetch(-1)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 
Modeling" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "29/29 - 6s - loss: 1.6104 - val_loss: 2.8727 - lr: 1.0000e-04 - 6s/epoch - 209ms/step\n", "Epoch 2/100\n", "29/29 - 0s - loss: 1.0733 - val_loss: 2.8306 - lr: 1.0000e-04 - 388ms/epoch - 13ms/step\n", "Epoch 3/100\n", "29/29 - 0s - loss: 0.9292 - val_loss: 2.7933 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step\n", "Epoch 4/100\n", "29/29 - 0s - loss: 0.7980 - val_loss: 2.6992 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step\n", "Epoch 5/100\n", "29/29 - 0s - loss: 0.7400 - val_loss: 2.6170 - lr: 1.0000e-04 - 386ms/epoch - 13ms/step\n", "Epoch 6/100\n", "29/29 - 0s - loss: 0.6782 - val_loss: 2.5222 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step\n", "Epoch 7/100\n", "29/29 - 0s - loss: 0.6756 - val_loss: 2.3882 - lr: 1.0000e-04 - 392ms/epoch - 14ms/step\n", "Epoch 8/100\n", "29/29 - 0s - loss: 0.6457 - val_loss: 2.2180 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step\n", "Epoch 9/100\n", "29/29 - 0s - loss: 0.6356 - val_loss: 2.1035 - lr: 1.0000e-04 - 379ms/epoch - 13ms/step\n", "Epoch 10/100\n", "29/29 - 0s - loss: 0.5921 - val_loss: 2.0029 - lr: 1.0000e-04 - 397ms/epoch - 14ms/step\n", "Epoch 11/100\n", "29/29 - 0s - loss: 0.6023 - val_loss: 1.7675 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step\n", "Epoch 12/100\n", "29/29 - 0s - loss: 0.5469 - val_loss: 1.6520 - lr: 1.0000e-04 - 396ms/epoch - 14ms/step\n", "Epoch 13/100\n", "29/29 - 0s - loss: 0.5428 - val_loss: 1.4848 - lr: 1.0000e-04 - 397ms/epoch - 14ms/step\n", "Epoch 14/100\n", "29/29 - 0s - loss: 0.5678 - val_loss: 1.3723 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step\n", "Epoch 15/100\n", "29/29 - 0s - loss: 0.5733 - val_loss: 1.2860 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step\n", "Epoch 16/100\n", "29/29 - 0s - loss: 0.5791 - val_loss: 1.2071 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step\n", "Epoch 17/100\n", "29/29 - 0s - loss: 0.5403 - val_loss: 1.1961 - lr: 
1.0000e-04 - 393ms/epoch - 14ms/step\n", "Epoch 18/100\n", "29/29 - 0s - loss: 0.5152 - val_loss: 1.1848 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step\n", "Epoch 19/100\n", "29/29 - 0s - loss: 0.5947 - val_loss: 1.0619 - lr: 1.0000e-04 - 404ms/epoch - 14ms/step\n", "Epoch 20/100\n", "29/29 - 0s - loss: 0.5636 - val_loss: 0.9673 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step\n", "Epoch 21/100\n", "29/29 - 0s - loss: 0.4985 - val_loss: 1.4409 - lr: 1.0000e-04 - 374ms/epoch - 13ms/step\n", "Epoch 22/100\n", "29/29 - 0s - loss: 0.5274 - val_loss: 1.0172 - lr: 1.0000e-04 - 376ms/epoch - 13ms/step\n", "Epoch 23/100\n", "29/29 - 0s - loss: 0.5385 - val_loss: 0.8310 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step\n", "Epoch 24/100\n", "29/29 - 0s - loss: 0.4996 - val_loss: 0.7225 - lr: 1.0000e-04 - 392ms/epoch - 14ms/step\n", "Epoch 25/100\n", "29/29 - 0s - loss: 0.5324 - val_loss: 0.5973 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step\n", "Epoch 26/100\n", "29/29 - 0s - loss: 0.4900 - val_loss: 0.6436 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step\n", "Epoch 27/100\n", "29/29 - 0s - loss: 0.4709 - val_loss: 0.6613 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step\n", "Epoch 28/100\n", "29/29 - 0s - loss: 0.4714 - val_loss: 0.6028 - lr: 1.0000e-04 - 381ms/epoch - 13ms/step\n", "Epoch 29/100\n", "29/29 - 0s - loss: 0.4654 - val_loss: 0.6892 - lr: 1.0000e-04 - 387ms/epoch - 13ms/step\n", "Epoch 30/100\n", "29/29 - 0s - loss: 0.4792 - val_loss: 0.5300 - lr: 1.0000e-04 - 389ms/epoch - 13ms/step\n", "Epoch 31/100\n", "29/29 - 0s - loss: 0.4470 - val_loss: 0.6892 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step\n", "Epoch 32/100\n", "29/29 - 0s - loss: 0.4791 - val_loss: 0.5954 - lr: 1.0000e-04 - 375ms/epoch - 13ms/step\n", "Epoch 33/100\n", "29/29 - 0s - loss: 0.4156 - val_loss: 0.4791 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step\n", "Epoch 34/100\n", "29/29 - 0s - loss: 0.4756 - val_loss: 0.7215 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step\n", "Epoch 35/100\n", "29/29 - 0s - loss: 0.4974 - val_loss: 0.5423 - lr: 
1.0000e-04 - 378ms/epoch - 13ms/step\n", "Epoch 36/100\n", "29/29 - 0s - loss: 0.4495 - val_loss: 0.5344 - lr: 1.0000e-04 - 381ms/epoch - 13ms/step\n", "Epoch 37/100\n", "29/29 - 0s - loss: 0.5167 - val_loss: 0.8142 - lr: 1.0000e-04 - 374ms/epoch - 13ms/step\n", "Epoch 38/100\n", "29/29 - 0s - loss: 0.4763 - val_loss: 0.5238 - lr: 1.0000e-04 - 376ms/epoch - 13ms/step\n", "Epoch 39/100\n", "29/29 - 0s - loss: 0.4450 - val_loss: 0.5957 - lr: 1.0000e-04 - 382ms/epoch - 13ms/step\n", "Epoch 40/100\n", "29/29 - 0s - loss: 0.4635 - val_loss: 0.5614 - lr: 1.0000e-04 - 384ms/epoch - 13ms/step\n", "Epoch 41/100\n", "29/29 - 0s - loss: 0.4348 - val_loss: 0.4831 - lr: 1.0000e-04 - 387ms/epoch - 13ms/step\n", "Epoch 42/100\n", "29/29 - 0s - loss: 0.3967 - val_loss: 0.5873 - lr: 1.0000e-04 - 394ms/epoch - 14ms/step\n", "Epoch 43/100\n", "29/29 - 0s - loss: 0.4132 - val_loss: 0.5240 - lr: 1.0000e-04 - 372ms/epoch - 13ms/step\n", "Epoch 44/100\n", "29/29 - 0s - loss: 0.3935 - val_loss: 0.4802 - lr: 1.0000e-05 - 375ms/epoch - 13ms/step\n", "Epoch 45/100\n", "29/29 - 0s - loss: 0.3535 - val_loss: 0.4594 - lr: 1.0000e-05 - 388ms/epoch - 13ms/step\n", "Epoch 46/100\n", "29/29 - 0s - loss: 0.3636 - val_loss: 0.4413 - lr: 1.0000e-05 - 393ms/epoch - 14ms/step\n", "Epoch 47/100\n", "29/29 - 0s - loss: 0.3381 - val_loss: 0.4438 - lr: 1.0000e-05 - 383ms/epoch - 13ms/step\n", "Epoch 48/100\n", "29/29 - 0s - loss: 0.3362 - val_loss: 0.4697 - lr: 1.0000e-05 - 378ms/epoch - 13ms/step\n", "Epoch 49/100\n", "29/29 - 0s - loss: 0.3738 - val_loss: 0.4590 - lr: 1.0000e-05 - 373ms/epoch - 13ms/step\n", "Epoch 50/100\n", "29/29 - 0s - loss: 0.3348 - val_loss: 0.4344 - lr: 1.0000e-05 - 388ms/epoch - 13ms/step\n", "Epoch 51/100\n", "29/29 - 0s - loss: 0.3483 - val_loss: 0.4395 - lr: 1.0000e-05 - 382ms/epoch - 13ms/step\n", "Epoch 52/100\n", "29/29 - 0s - loss: 0.3833 - val_loss: 0.4444 - lr: 1.0000e-05 - 384ms/epoch - 13ms/step\n", "Epoch 53/100\n", "29/29 - 0s - loss: 0.3380 - val_loss: 0.4358 - lr: 
1.0000e-05 - 373ms/epoch - 13ms/step\n",
      "Epoch 54/100\n",
      "29/29 - 0s - loss: 0.3517 - val_loss: 0.4578 - lr: 1.0000e-05 - 378ms/epoch - 13ms/step\n",
      "Epoch 55/100\n",
      "29/29 - 0s - loss: 0.3465 - val_loss: 0.4576 - lr: 1.0000e-05 - 374ms/epoch - 13ms/step\n",
      "Epoch 56/100\n",
      "29/29 - 0s - loss: 0.3377 - val_loss: 0.4544 - lr: 1.0000e-05 - 387ms/epoch - 13ms/step\n",
      "Epoch 57/100\n",
      "29/29 - 0s - loss: 0.3469 - val_loss: 0.4457 - lr: 1.0000e-05 - 384ms/epoch - 13ms/step\n",
      "Epoch 58/100\n",
      "29/29 - 0s - loss: 0.3298 - val_loss: 0.4854 - lr: 1.0000e-05 - 372ms/epoch - 13ms/step\n",
      "Epoch 59/100\n",
      "29/29 - 0s - loss: 0.3758 - val_loss: 0.4491 - lr: 1.0000e-05 - 372ms/epoch - 13ms/step\n",
      "Epoch 60/100\n",
      "29/29 - 0s - loss: 0.3519 - val_loss: 0.4462 - lr: 1.0000e-05 - 375ms/epoch - 13ms/step\n",
      "Epoch 61/100\n",
      "29/29 - 0s - loss: 0.3234 - val_loss: 0.4383 - lr: 1.0000e-06 - 388ms/epoch - 13ms/step\n",
      "Epoch 62/100\n",
      "29/29 - 0s - loss: 0.3429 - val_loss: 0.4395 - lr: 1.0000e-06 - 387ms/epoch - 13ms/step\n",
      "Epoch 63/100\n",
      "29/29 - 0s - loss: 0.3258 - val_loss: 0.4392 - lr: 1.0000e-06 - 386ms/epoch - 13ms/step\n",
      "Epoch 64/100\n",
      "29/29 - 0s - loss: 0.3527 - val_loss: 0.4561 - lr: 1.0000e-06 - 414ms/epoch - 14ms/step\n",
      "Epoch 65/100\n",
      "29/29 - 0s - loss: 0.3089 - val_loss: 0.4505 - lr: 1.0000e-06 - 375ms/epoch - 13ms/step\n",
      "Epoch 66/100\n",
      "29/29 - 0s - loss: 0.3431 - val_loss: 0.4419 - lr: 1.0000e-06 - 375ms/epoch - 13ms/step\n",
      "Epoch 67/100\n",
      "29/29 - 0s - loss: 0.3460 - val_loss: 0.4373 - lr: 1.0000e-06 - 385ms/epoch - 13ms/step\n",
      "Epoch 68/100\n",
      "29/29 - 0s - loss: 0.3568 - val_loss: 0.4480 - lr: 1.0000e-06 - 387ms/epoch - 13ms/step\n",
      "Epoch 69/100\n",
      "29/29 - 0s - loss: 0.3029 - val_loss: 0.4462 - lr: 1.0000e-06 - 383ms/epoch - 13ms/step\n",
      "Epoch 70/100\n",
      "29/29 - 0s - loss: 0.3390 - val_loss: 0.4427 - lr: 1.0000e-06 - 390ms/epoch - 13ms/step\n",
      "4/4 [==============================] - 0s 4ms/step - loss: 0.4317\n",
      "0.4317157566547394\n"
     ]
    }
   ],
   "source": [
    "from molgraph.layers import GATConv\n",
    "from molgraph.layers import LaplacianPositionalEncoding\n",
    "from molgraph.layers import Readout\n",
    "from molgraph.layers import MinMaxScaling\n",
    "\n",
    "# Min-max scalers for the node and edge features, target range (0, 1).\n",
    "node_preprocessing = MinMaxScaling(\n",
    "    feature='node_feature', feature_range=(0, 1), threshold=True)\n",
    "edge_preprocessing = MinMaxScaling(\n",
    "    feature='edge_feature', feature_range=(0, 1), threshold=True)\n",
    "\n",
    "# Adapt both scalers on the training inputs only (labels are dropped), so\n",
    "# nothing from the validation or test splits leaks into the preprocessing.\n",
    "# The inputs-only view is built once and shared by both adapt calls.\n",
    "train_inputs_ds = train_ds.map(lambda x, *args: x)\n",
    "node_preprocessing.adapt(train_inputs_ds)\n",
    "edge_preprocessing.adapt(train_inputs_ds)\n",
    "\n",
    "# Scaling -> positional encoding -> three GAT layers with batch norm ->\n",
    "# graph readout -> two 1024-unit ReLU layers -> linear output with one unit\n",
    "# per target column.\n",
    "model = keras.Sequential([\n",
    "    keras.layers.Input(type_spec=type_spec),\n",
    "    node_preprocessing,\n",
    "    edge_preprocessing,\n",
    "    LaplacianPositionalEncoding(),\n",
    "    GATConv(normalization='batch_norm'),\n",
    "    GATConv(normalization='batch_norm'),\n",
    "    GATConv(normalization='batch_norm'),\n",
    "    Readout(),\n",
    "    keras.layers.Dense(1024, 'relu'),\n",
    "    keras.layers.Dense(1024, 'relu'),\n",
    "    keras.layers.Dense(y_train.shape[-1])\n",
    "])\n",
    "\n",
    "\n",
    "optimizer = keras.optimizers.Adam(1e-4)\n",
    "loss = keras.losses.MeanAbsoluteError(name='mae')\n",
    "callbacks = [\n",
    "    # Divide the learning rate by 10 (down to 1e-6 at most) after 10 epochs\n",
    "    # without improvement of the validation loss.\n",
    "    keras.callbacks.ReduceLROnPlateau(\n",
    "        monitor='val_loss',\n",
    "        factor=0.1,\n",
    "        patience=10,\n",
    "        min_lr=1e-6,\n",
    "        mode='min',\n",
    "    ),\n",
    "    # Stop after 20 epochs without improvement and roll back to the weights\n",
    "    # of the best validation epoch.\n",
    "    keras.callbacks.EarlyStopping(\n",
    "        monitor='val_loss',\n",
    "        patience=20,\n",
    "        mode='min',\n",
    "        restore_best_weights=True,\n",
    "    )\n",
    "]\n",
    "\n",
    "model.compile(optimizer, loss)\n",
    "history = model.fit(\n",
    "    train_ds,\n",
    "    callbacks=callbacks,\n",
    "    validation_data=val_ds,\n",
    "    epochs=100,\n",
    "    verbose=2,\n",
    ")\n",
    "# No extra metrics are compiled, so evaluate() returns the scalar test loss (MAE).\n",
    "score = model.evaluate(test_ds)\n",
    "print(score)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }