{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Layers and models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from molgraph.chemistry import MolecularGraphEncoder\n", "from molgraph.chemistry import Featurizer \n", "from molgraph.chemistry import features\n", "\n", "import tensorflow as tf\n", "\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Construct a `MolecularGraphEncoder`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [], "source": [ "atom_encoder = Featurizer([\n", " features.Symbol({'C', 'N', 'O'}, oov_size=1),\n", " features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),\n", " features.HydrogenDonor(),\n", " features.HydrogenAcceptor(),\n", " features.Hetero()\n", "])\n", "\n", "bond_encoder = Featurizer([\n", " features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),\n", " features.Rotatable(),\n", "])\n", "\n", "encoder = MolecularGraphEncoder(atom_encoder, bond_encoder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Obtain dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Compound IDESOL predicted log solubility in mols per litreMinimum DegreeMolecular WeightNumber of H-Bond DonorsNumber of RingsNumber of Rotatable BondsPolar Surface Areameasured log solubility in mols per litresmiles
0Amigdalin-0.9741457.432737202.32-0.77OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...
1Fenfuram-2.8851201.22512242.24-3.30Cc1occc1C(=O)Nc2ccccc2
2citral-2.5791152.23700417.07-2.06CC(C)=CCCC(C)=CC(=O)
\n", "
" ], "text/plain": [ " Compound ID ESOL predicted log solubility in mols per litre \\\n", "0 Amigdalin -0.974 \n", "1 Fenfuram -2.885 \n", "2 citral -2.579 \n", "\n", " Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings \\\n", "0 1 457.432 7 3 \n", "1 1 201.225 1 2 \n", "2 1 152.237 0 0 \n", "\n", " Number of Rotatable Bonds Polar Surface Area \\\n", "0 7 202.32 \n", "1 2 42.24 \n", "2 4 17.07 \n", "\n", " measured log solubility in mols per litre \\\n", "0 -0.77 \n", "1 -3.30 \n", "2 -2.06 \n", "\n", " smiles \n", "0 OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... \n", "1 Cc1occc1C(=O)Nc2ccccc2 \n", "2 CC(C)=CCCC(C)=CC(=O) " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = tf.keras.utils.get_file(\n", " fname='ESOL.csv',\n", " origin='http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/ESOL.csv',\n", ")\n", "df = pd.read_csv(path)\n", "df.head(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Obtain SMILES `x`and associated labels `y`" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "x, y = df['smiles'].values, df['measured log solubility in mols per litre'].values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Obtain `GraphTensor` from `x`, via `MolecularGraphEncoder`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=,\n", " node_position=)\n", "\n", "node_feature shape: (14991, 11)\n", "edge_dst shape: (30856,)\n", "edge_src shape: (30856,)\n", "edge_feature shape: (30856, 5)\n" ] } ], "source": [ "x = encoder(x)\n", "\n", "print(x, end='\\n\\n')\n", "print('node_feature shape:', x.node_feature.shape)\n", "print('edge_dst shape: ', x.edge_dst.shape)\n", "print('edge_src shape: ', x.edge_src.shape)\n", "print('edge_feature 
shape:', x.edge_feature.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Import GNN **layers**" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from molgraph import layers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Use GNN **layers**" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=,\n", " node_position=)\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=,\n", " node_position=)\n" ] } ], "source": [ "layer = layers.GATConv(units=128, use_edge_features=True)\n", "\n", "out1 = layer(x.separate()) # with nested ragged tensors\n", "out2 = layer(x) # with nested tensors\n", "\n", "print(out1, end='\\n\\n')\n", "print(out2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 
Pass GNN **layers** to **Keras models**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split data into train/test" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "SEED = 42  # fixed seed so the train/test split is identical on every run\n", "NUM_TRAIN = 800  # molecules used for training; the remainder is held out for testing\n", "\n", "rng = np.random.default_rng(SEED)\n", "random_indices = rng.permutation(np.arange(x.shape[0]))\n", "\n", "train_indices = random_indices[:NUM_TRAIN]\n", "test_indices = random_indices[NUM_TRAIN:]\n", "\n", "x_train, x_test = x[train_indices], x[test_indices]\n", "y_train, y_test = y[train_indices], y[test_indices]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Option 1: Keras Sequential API" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " gin_conv (GINConv) (None, None, 128) 21187 \n", " \n", " gin_conv_1 (GINConv) (None, None, 128) 51073 \n", " \n", " gin_conv_2 (GINConv) (None, None, 128) 49537 \n", " \n", " segment_pooling_readout (S (None, 128) 0 \n", " egmentPoolingReadout) \n", " \n", " dense_2 (Dense) (None, 512) 66048 \n", " \n", " dense_3 (Dense) (None, 1) 513 \n", " \n", "=================================================================\n", "Total params: 188358 (735.77 KB)\n", "Trainable params: 188358 (735.77 KB)\n", "Non-trainable params: 0 (0.00 Byte)\n", "_________________________________________________________________\n" ] } ], "source": [ "sequential_model = tf.keras.Sequential([\n", " tf.keras.layers.Input(type_spec=x_train.spec),\n", " layers.GINConv(128),\n", " layers.GINConv(128),\n", " layers.GINConv(128),\n", " layers.Readout(),\n", " tf.keras.layers.Dense(512, activation='relu'),\n", " tf.keras.layers.Dense(1)\n", "])\n", "\n", "sequential_model.summary()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "Epoch 1/30\n", "25/25 - 4s - loss: 7.0829 - mae: 2.0933 - 4s/epoch - 167ms/step\n", "Epoch 2/30\n", "25/25 - 0s - loss: 3.7235 - mae: 1.5299 - 138ms/epoch - 6ms/step\n", "Epoch 3/30\n", "25/25 - 0s - loss: 3.0047 - mae: 1.3911 - 138ms/epoch - 6ms/step\n", "Epoch 4/30\n", "25/25 - 0s - loss: 2.9660 - mae: 1.3739 - 153ms/epoch - 6ms/step\n", "Epoch 5/30\n", "25/25 - 0s - loss: 2.7608 - mae: 1.3286 - 152ms/epoch - 6ms/step\n", "Epoch 6/30\n", "25/25 - 0s - loss: 2.5752 - mae: 1.2790 - 140ms/epoch - 6ms/step\n", "Epoch 7/30\n", "25/25 - 0s - loss: 2.3623 - mae: 1.2191 - 123ms/epoch - 5ms/step\n", "Epoch 8/30\n", "25/25 - 0s - loss: 2.2896 - mae: 1.1918 - 130ms/epoch - 5ms/step\n", "Epoch 9/30\n", "25/25 - 0s - loss: 2.2290 - mae: 1.1800 - 128ms/epoch - 5ms/step\n", "Epoch 10/30\n", "25/25 - 0s - loss: 1.7473 - mae: 1.0566 - 128ms/epoch - 5ms/step\n", "Epoch 11/30\n", "25/25 - 0s - loss: 1.7529 - mae: 1.0464 - 138ms/epoch - 6ms/step\n", "Epoch 12/30\n", "25/25 - 0s - loss: 1.6870 - mae: 1.0037 - 131ms/epoch - 5ms/step\n", "Epoch 13/30\n", "25/25 - 0s - loss: 1.5346 - mae: 0.9839 - 133ms/epoch - 5ms/step\n", "Epoch 14/30\n", "25/25 - 0s - loss: 1.4849 - mae: 0.9538 - 142ms/epoch - 6ms/step\n", "Epoch 15/30\n", "25/25 - 0s - loss: 1.2754 - mae: 0.8829 - 133ms/epoch - 5ms/step\n", "Epoch 16/30\n", "25/25 - 0s - loss: 1.2463 - mae: 0.8781 - 143ms/epoch - 6ms/step\n", "Epoch 17/30\n", "25/25 - 0s - loss: 1.3979 - mae: 0.9110 - 134ms/epoch - 5ms/step\n", "Epoch 18/30\n", "25/25 - 0s - loss: 1.2012 - mae: 0.8581 - 136ms/epoch - 5ms/step\n", "Epoch 19/30\n", "25/25 - 0s - loss: 1.0255 - mae: 0.8057 - 144ms/epoch - 6ms/step\n", "Epoch 20/30\n", "25/25 - 0s - loss: 1.0765 - mae: 0.8157 - 141ms/epoch - 6ms/step\n", "Epoch 21/30\n", "25/25 - 0s - loss: 1.1615 - mae: 0.8261 - 132ms/epoch - 5ms/step\n", "Epoch 22/30\n", "25/25 - 0s - loss: 1.0205 - mae: 0.7694 - 138ms/epoch - 6ms/step\n", "Epoch 23/30\n", "25/25 - 0s - loss: 1.1099 - mae: 0.8121 - 139ms/epoch - 
6ms/step\n", "Epoch 24/30\n", "25/25 - 0s - loss: 0.8393 - mae: 0.7029 - 138ms/epoch - 6ms/step\n", "Epoch 25/30\n", "25/25 - 0s - loss: 0.8663 - mae: 0.7098 - 143ms/epoch - 6ms/step\n", "Epoch 26/30\n", "25/25 - 0s - loss: 0.9310 - mae: 0.7355 - 141ms/epoch - 6ms/step\n", "Epoch 27/30\n", "25/25 - 0s - loss: 0.8394 - mae: 0.6957 - 140ms/epoch - 6ms/step\n", "Epoch 28/30\n", "25/25 - 0s - loss: 1.1528 - mae: 0.8242 - 143ms/epoch - 6ms/step\n", "Epoch 29/30\n", "25/25 - 0s - loss: 0.9932 - mae: 0.7647 - 129ms/epoch - 5ms/step\n", "Epoch 30/30\n", "25/25 - 0s - loss: 0.8268 - mae: 0.6887 - 129ms/epoch - 5ms/step\n", "11/11 [==============================] - 0s 2ms/step - loss: 0.9816 - mae: 0.7799\n", "mse = 0.982\n", "mae = 0.780\n" ] } ], "source": [ "sequential_model.compile('adam', 'mse', ['mae'])\n", "sequential_model.fit(x_train, y_train, epochs=30, verbose=2)\n", "mse, mae = sequential_model.evaluate(x_test, y_test)\n", "print(f\"{mse = :.3f}\\n{mae = :.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Option 2: Keras Functional API" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " input_2 (InputLayer) [(None, None, 11)] 0 \n", " \n", " gin_conv_3 (GINConv) (None, None, 128) 21187 \n", " \n", " gin_conv_4 (GINConv) (None, None, 128) 51073 \n", " \n", " gin_conv_5 (GINConv) (None, None, 128) 49537 \n", " \n", " segment_pooling_readout_1 (None, 128) 0 \n", " (SegmentPoolingReadout) \n", " \n", " dense_4 (Dense) (None, 512) 66048 \n", " \n", " dense_5 (Dense) (None, 1) 513 \n", " \n", "=================================================================\n", "Total params: 188358 (735.77 KB)\n", "Trainable params: 188358 (735.77 KB)\n", "Non-trainable 
params: 0 (0.00 Byte)\n", "_________________________________________________________________\n" ] } ], "source": [ "inputs = tf.keras.layers.Input(type_spec=x_train.spec)\n", "# Track the hidden state in `h` so the dataset `x` (a GraphTensor) is not overwritten\n", "h = layers.GINConv(128)(inputs)\n", "h = layers.GINConv(128)(h)\n", "h = layers.GINConv(128)(h)\n", "h = layers.Readout()(h)\n", "h = tf.keras.layers.Dense(512, activation='relu')(h)\n", "outputs = tf.keras.layers.Dense(1)(h)\n", "functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", "functional_model.summary()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "25/25 - 4s - loss: 6.9203 - mae: 2.0325 - 4s/epoch - 148ms/step\n", "Epoch 2/30\n", "25/25 - 0s - loss: 3.4010 - mae: 1.4839 - 133ms/epoch - 5ms/step\n", "Epoch 3/30\n", "25/25 - 0s - loss: 2.9173 - mae: 1.3693 - 148ms/epoch - 6ms/step\n", "Epoch 4/30\n", "25/25 - 0s - loss: 2.7455 - mae: 1.3209 - 146ms/epoch - 6ms/step\n", "Epoch 5/30\n", "25/25 - 0s - loss: 2.8170 - mae: 1.3229 - 140ms/epoch - 6ms/step\n", "Epoch 6/30\n", "25/25 - 0s - loss: 2.4958 - mae: 1.2679 - 144ms/epoch - 6ms/step\n", "Epoch 7/30\n", "25/25 - 0s - loss: 2.3044 - mae: 1.1992 - 144ms/epoch - 6ms/step\n", "Epoch 8/30\n", "25/25 - 0s - loss: 1.9932 - mae: 1.1212 - 139ms/epoch - 6ms/step\n", "Epoch 9/30\n", "25/25 - 0s - loss: 1.8481 - mae: 1.0773 - 141ms/epoch - 6ms/step\n", "Epoch 10/30\n", "25/25 - 0s - loss: 1.7411 - mae: 1.0312 - 143ms/epoch - 6ms/step\n", "Epoch 11/30\n", "25/25 - 0s - loss: 1.5762 - mae: 0.9794 - 140ms/epoch - 6ms/step\n", "Epoch 12/30\n", "25/25 - 0s - loss: 1.5218 - mae: 0.9651 - 146ms/epoch - 6ms/step\n", "Epoch 13/30\n", "25/25 - 0s - loss: 1.7555 - mae: 1.0286 - 143ms/epoch - 6ms/step\n", "Epoch 14/30\n", "25/25 - 0s - loss: 1.2763 - mae: 0.9069 - 144ms/epoch - 6ms/step\n", "Epoch 15/30\n", "25/25 - 0s - loss: 1.1931 - mae: 0.8552 - 136ms/epoch - 5ms/step\n", "Epoch 16/30\n", "25/25 - 0s - loss: 1.2232 - mae: 
0.8649 - 125ms/epoch - 5ms/step\n", "Epoch 17/30\n", "25/25 - 0s - loss: 1.3285 - mae: 0.9049 - 122ms/epoch - 5ms/step\n", "Epoch 18/30\n", "25/25 - 0s - loss: 1.2162 - mae: 0.8614 - 127ms/epoch - 5ms/step\n", "Epoch 19/30\n", "25/25 - 0s - loss: 1.0204 - mae: 0.7888 - 132ms/epoch - 5ms/step\n", "Epoch 20/30\n", "25/25 - 0s - loss: 1.2871 - mae: 0.8817 - 130ms/epoch - 5ms/step\n", "Epoch 21/30\n", "25/25 - 0s - loss: 0.9439 - mae: 0.7681 - 133ms/epoch - 5ms/step\n", "Epoch 22/30\n", "25/25 - 0s - loss: 0.8779 - mae: 0.7186 - 133ms/epoch - 5ms/step\n", "Epoch 23/30\n", "25/25 - 0s - loss: 1.0503 - mae: 0.7944 - 136ms/epoch - 5ms/step\n", "Epoch 24/30\n", "25/25 - 0s - loss: 0.8557 - mae: 0.7086 - 135ms/epoch - 5ms/step\n", "Epoch 25/30\n", "25/25 - 0s - loss: 0.9817 - mae: 0.7609 - 140ms/epoch - 6ms/step\n", "Epoch 26/30\n", "25/25 - 0s - loss: 0.9477 - mae: 0.7526 - 141ms/epoch - 6ms/step\n", "Epoch 27/30\n", "25/25 - 0s - loss: 0.8323 - mae: 0.6940 - 139ms/epoch - 6ms/step\n", "Epoch 28/30\n", "25/25 - 0s - loss: 0.7691 - mae: 0.6539 - 136ms/epoch - 5ms/step\n", "Epoch 29/30\n", "25/25 - 0s - loss: 0.9321 - mae: 0.7294 - 139ms/epoch - 6ms/step\n", "Epoch 30/30\n", "25/25 - 0s - loss: 1.1205 - mae: 0.8220 - 137ms/epoch - 5ms/step\n", "11/11 [==============================] - 0s 2ms/step - loss: 0.7704 - mae: 0.6764\n", "mse = 0.770\n", "mae = 0.676\n" ] } ], "source": [ "functional_model.compile('adam', 'mse', ['mae'])\n", "functional_model.fit(x_train, y_train, epochs=30, verbose=2)\n", "mse, mae = functional_model.evaluate(x_test, y_test)\n", "print(f\"{mse = :.3f}\\n{mae = :.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Option 3: Keras Model subclassing\n", "\n", "Creating a custom Keras model allows for more flexibility. Let's perform some arbitrary skip connections." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"my_model\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " gin_conv_6 (GINConv) multiple 21187 \n", " \n", " gin_conv_7 (GINConv) multiple 51073 \n", " \n", " gin_conv_8 (GINConv) multiple 49537 \n", " \n", " segment_pooling_readout_2 multiple 0 \n", " (SegmentPoolingReadout) \n", " \n", " dense_6 (Dense) multiple 197120 \n", " \n", " dense_7 (Dense) multiple 513 \n", " \n", "=================================================================\n", "Total params: 319430 (1.22 MB)\n", "Trainable params: 319430 (1.22 MB)\n", "Non-trainable params: 0 (0.00 Byte)\n", "_________________________________________________________________\n" ] } ], "source": [ "class MyModel(tf.keras.Model):\n", "    \"\"\"Three stacked GINConv layers whose per-layer readouts are concatenated.\n", "\n", "    Args:\n", "        gnn_units: number of units of each GINConv layer.\n", "        dense_units: number of units of the hidden dense layer.\n", "    \"\"\"\n", "\n", "    def __init__(self, gnn_units=128, dense_units=512):\n", "        super().__init__()\n", "        self.gin_conv1 = layers.GINConv(gnn_units)\n", "        self.gin_conv2 = layers.GINConv(gnn_units)\n", "        self.gin_conv3 = layers.GINConv(gnn_units)\n", "        self.readout = layers.Readout()\n", "        # Honour the `dense_units` argument instead of a hard-coded width\n", "        self.dense_1 = tf.keras.layers.Dense(dense_units, activation='relu')\n", "        self.dense_2 = tf.keras.layers.Dense(1)\n", "\n", "    def call(self, inputs):\n", "        # Outputs of the three successive GIN layers\n", "        h1 = self.gin_conv1(inputs)\n", "        h2 = self.gin_conv2(h1)\n", "        h3 = self.gin_conv3(h2)\n", "        # Skip connections: read out every depth and concatenate the results\n", "        z = tf.concat(\n", "            [self.readout(h1), self.readout(h2), self.readout(h3)], axis=1)\n", "        return self.dense_2(self.dense_1(z))\n", "\n", "\n", "my_model = MyModel()\n", "\n", "my_model(x_train)  # call once to build the weights\n", "\n", "my_model.summary()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "25/25 - 4s - loss: 6.9835 - mae: 2.0324 
- 4s/epoch - 154ms/step\n", "Epoch 2/30\n", "25/25 - 0s - loss: 3.5191 - mae: 1.4901 - 130ms/epoch - 5ms/step\n", "Epoch 3/30\n", "25/25 - 0s - loss: 3.0028 - mae: 1.3773 - 127ms/epoch - 5ms/step\n", "Epoch 4/30\n", "25/25 - 0s - loss: 3.0950 - mae: 1.4046 - 127ms/epoch - 5ms/step\n", "Epoch 5/30\n", "25/25 - 0s - loss: 3.0432 - mae: 1.3907 - 124ms/epoch - 5ms/step\n", "Epoch 6/30\n", "25/25 - 0s - loss: 2.8056 - mae: 1.3432 - 134ms/epoch - 5ms/step\n", "Epoch 7/30\n", "25/25 - 0s - loss: 2.4625 - mae: 1.2536 - 135ms/epoch - 5ms/step\n", "Epoch 8/30\n", "25/25 - 0s - loss: 2.3279 - mae: 1.2204 - 134ms/epoch - 5ms/step\n", "Epoch 9/30\n", "25/25 - 0s - loss: 2.0883 - mae: 1.1648 - 165ms/epoch - 7ms/step\n", "Epoch 10/30\n", "25/25 - 0s - loss: 1.7619 - mae: 1.0528 - 214ms/epoch - 9ms/step\n", "Epoch 11/30\n", "25/25 - 0s - loss: 1.5656 - mae: 0.9921 - 185ms/epoch - 7ms/step\n", "Epoch 12/30\n", "25/25 - 0s - loss: 1.5067 - mae: 0.9590 - 163ms/epoch - 7ms/step\n", "Epoch 13/30\n", "25/25 - 0s - loss: 1.3769 - mae: 0.9156 - 154ms/epoch - 6ms/step\n", "Epoch 14/30\n", "25/25 - 0s - loss: 1.3757 - mae: 0.9190 - 125ms/epoch - 5ms/step\n", "Epoch 15/30\n", "25/25 - 0s - loss: 1.2272 - mae: 0.8522 - 131ms/epoch - 5ms/step\n", "Epoch 16/30\n", "25/25 - 0s - loss: 1.3874 - mae: 0.9079 - 127ms/epoch - 5ms/step\n", "Epoch 17/30\n", "25/25 - 0s - loss: 1.2847 - mae: 0.8731 - 129ms/epoch - 5ms/step\n", "Epoch 18/30\n", "25/25 - 0s - loss: 1.0207 - mae: 0.7711 - 126ms/epoch - 5ms/step\n", "Epoch 19/30\n", "25/25 - 0s - loss: 0.9642 - mae: 0.7421 - 125ms/epoch - 5ms/step\n", "Epoch 20/30\n", "25/25 - 0s - loss: 1.0932 - mae: 0.8059 - 128ms/epoch - 5ms/step\n", "Epoch 21/30\n", "25/25 - 0s - loss: 1.3735 - mae: 0.8976 - 126ms/epoch - 5ms/step\n", "Epoch 22/30\n", "25/25 - 0s - loss: 0.9599 - mae: 0.7704 - 124ms/epoch - 5ms/step\n", "Epoch 23/30\n", "25/25 - 0s - loss: 0.9566 - mae: 0.7487 - 122ms/epoch - 5ms/step\n", "Epoch 24/30\n", "25/25 - 0s - loss: 1.0032 - mae: 0.7642 - 
127ms/epoch - 5ms/step\n", "Epoch 25/30\n", "25/25 - 0s - loss: 1.1856 - mae: 0.8178 - 117ms/epoch - 5ms/step\n", "Epoch 26/30\n", "25/25 - 0s - loss: 1.0833 - mae: 0.7991 - 125ms/epoch - 5ms/step\n", "Epoch 27/30\n", "25/25 - 0s - loss: 0.9256 - mae: 0.7378 - 124ms/epoch - 5ms/step\n", "Epoch 28/30\n", "25/25 - 0s - loss: 0.8480 - mae: 0.6993 - 126ms/epoch - 5ms/step\n", "Epoch 29/30\n", "25/25 - 0s - loss: 1.0073 - mae: 0.7551 - 125ms/epoch - 5ms/step\n", "Epoch 30/30\n", "25/25 - 0s - loss: 0.8465 - mae: 0.6929 - 126ms/epoch - 5ms/step\n", "11/11 [==============================] - 0s 3ms/step - loss: 0.7891 - mae: 0.6715\n", "mse = 0.789\n", "mae = 0.672\n" ] } ], "source": [ "my_model.compile('adam', 'mse', ['mae'])\n", "my_model.fit(x_train, y_train, epochs=30, verbose=2)\n", "mse, mae = my_model.evaluate(x_test, y_test)\n", "print(f\"{mse = :.3f}\\n{mae = :.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model with **tf.data.Dataset**" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "ds_train = ds_train.shuffle(800).batch(32)\n", "\n", "ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n", "ds_test = ds_test.batch(32)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "25/25 - 4s - loss: 2.0662 - mae: 1.1023 - 4s/epoch - 169ms/step\n", "Epoch 2/30\n", "25/25 - 0s - loss: 0.9658 - mae: 0.7642 - 146ms/epoch - 6ms/step\n", "Epoch 3/30\n", "25/25 - 0s - loss: 0.7833 - mae: 0.6706 - 193ms/epoch - 8ms/step\n", "Epoch 4/30\n", "25/25 - 0s - loss: 0.8129 - mae: 0.6789 - 224ms/epoch - 9ms/step\n", "Epoch 5/30\n", "25/25 - 0s - loss: 0.7564 - mae: 0.6587 - 160ms/epoch - 6ms/step\n", "Epoch 6/30\n", "25/25 - 0s - loss: 0.6944 - mae: 0.6358 - 167ms/epoch - 7ms/step\n", "Epoch 7/30\n", "25/25 - 0s - 
loss: 0.7331 - mae: 0.6370 - 152ms/epoch - 6ms/step\n", "Epoch 8/30\n", "25/25 - 0s - loss: 0.7645 - mae: 0.6606 - 156ms/epoch - 6ms/step\n", "Epoch 9/30\n", "25/25 - 0s - loss: 0.7225 - mae: 0.6401 - 176ms/epoch - 7ms/step\n", "Epoch 10/30\n", "25/25 - 0s - loss: 0.6720 - mae: 0.6212 - 176ms/epoch - 7ms/step\n", "Epoch 11/30\n", "25/25 - 0s - loss: 0.6687 - mae: 0.6162 - 140ms/epoch - 6ms/step\n", "Epoch 12/30\n", "25/25 - 0s - loss: 0.7627 - mae: 0.6656 - 139ms/epoch - 6ms/step\n", "Epoch 13/30\n", "25/25 - 0s - loss: 0.8233 - mae: 0.6914 - 184ms/epoch - 7ms/step\n", "Epoch 14/30\n", "25/25 - 0s - loss: 0.6556 - mae: 0.6026 - 136ms/epoch - 5ms/step\n", "Epoch 15/30\n", "25/25 - 0s - loss: 0.7744 - mae: 0.6573 - 184ms/epoch - 7ms/step\n", "Epoch 16/30\n", "25/25 - 0s - loss: 0.8383 - mae: 0.6919 - 139ms/epoch - 6ms/step\n", "Epoch 17/30\n", "25/25 - 0s - loss: 0.7492 - mae: 0.6450 - 137ms/epoch - 5ms/step\n", "Epoch 18/30\n", "25/25 - 0s - loss: 0.5925 - mae: 0.5758 - 139ms/epoch - 6ms/step\n", "Epoch 19/30\n", "25/25 - 0s - loss: 0.5963 - mae: 0.5831 - 153ms/epoch - 6ms/step\n", "Epoch 20/30\n", "25/25 - 0s - loss: 0.6683 - mae: 0.6151 - 133ms/epoch - 5ms/step\n", "Epoch 21/30\n", "25/25 - 0s - loss: 0.7187 - mae: 0.6552 - 136ms/epoch - 5ms/step\n", "Epoch 22/30\n", "25/25 - 0s - loss: 0.6949 - mae: 0.6189 - 137ms/epoch - 5ms/step\n", "Epoch 23/30\n", "25/25 - 0s - loss: 0.6306 - mae: 0.5957 - 134ms/epoch - 5ms/step\n", "Epoch 24/30\n", "25/25 - 0s - loss: 0.5399 - mae: 0.5503 - 142ms/epoch - 6ms/step\n", "Epoch 25/30\n", "25/25 - 0s - loss: 0.5650 - mae: 0.5578 - 137ms/epoch - 5ms/step\n", "Epoch 26/30\n", "25/25 - 0s - loss: 0.6170 - mae: 0.5973 - 143ms/epoch - 6ms/step\n", "Epoch 27/30\n", "25/25 - 0s - loss: 0.5971 - mae: 0.5838 - 136ms/epoch - 5ms/step\n", "Epoch 28/30\n", "25/25 - 0s - loss: 0.5966 - mae: 0.5822 - 133ms/epoch - 5ms/step\n", "Epoch 29/30\n", "25/25 - 0s - loss: 0.5543 - mae: 0.5507 - 132ms/epoch - 5ms/step\n", "Epoch 30/30\n", "25/25 - 0s - 
loss: 0.6288 - mae: 0.5837 - 139ms/epoch - 6ms/step\n", "11/11 [==============================] - 0s 2ms/step - loss: 0.6447 - mae: 0.6042\n", "mse = 0.645\n", "mae = 0.604\n" ] } ], "source": [ "sequential_model.compile('adam', 'mse', ['mae'])\n", "sequential_model.fit(ds_train, epochs=30, verbose=2)\n", "mse, mae = sequential_model.evaluate(ds_test)  # evaluate on the batched test dataset\n", "print(f\"{mse = :.3f}\\n{mae = :.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Save and load GNN **model** " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Option 1: with `tf.saved_model`" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(800, 1)\n" ] } ], "source": [ "import os\n", "import tempfile\n", "\n", "# The temporary directory is removed when the `with` block exits, even if saving/loading fails\n", "with tempfile.TemporaryDirectory() as tmp_dir:\n", "    model_path = os.path.join(tmp_dir, 'saved_model')\n", "    tf.saved_model.save(sequential_model, model_path)\n", "    loaded_model = tf.saved_model.load(model_path)\n", "    print(loaded_model(x_train).shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Option 2: with `tf.keras`" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "25/25 - 0s - loss: 0.5887 - mae: 0.5814 - 150ms/epoch - 6ms/step\n", "Epoch 2/5\n", "25/25 - 0s - loss: 0.5724 - mae: 0.5628 - 153ms/epoch - 6ms/step\n", "Epoch 3/5\n", "25/25 - 0s - loss: 0.5588 - mae: 0.5565 - 151ms/epoch - 6ms/step\n", "Epoch 4/5\n", "25/25 - 0s - loss: 0.5512 - mae: 0.5565 - 154ms/epoch - 6ms/step\n", "Epoch 5/5\n", "25/25 - 0s - loss: 0.4906 - mae: 0.5297 - 144ms/epoch - 6ms/step\n" ] } ], "source": [ "import os\n", "import tempfile\n", "\n", 
"# The temporary directory is removed when the `with` block exits, even if saving/loading fails\n", "with tempfile.TemporaryDirectory() as tmp_dir:\n", "    model_path = os.path.join(tmp_dir, 'keras_model')\n", "    sequential_model.save(model_path)\n", "    loaded_model = tf.keras.models.load_model(model_path)\n", "    loaded_model.fit(ds_train, epochs=5, verbose=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }