{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Graph tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import modules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from molgraph import chemistry\n", "from molgraph import layers\n", "from molgraph import GraphTensor #####\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Construct a `GraphTensor`\n", "\n", "Although a `GraphTensor` can be instantiated directly through its constructor, here we obtain one from a `MolecularGraphEncoder`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=)\n" ] } ], "source": [ "atom_encoder = chemistry.Featurizer([\n", " chemistry.features.Symbol({'C', 'N', 'O'}, oov_size=1),\n", " chemistry.features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),\n", " chemistry.features.HydrogenDonor(),\n", " chemistry.features.HydrogenAcceptor(),\n", " chemistry.features.Hetero()\n", "])\n", "\n", "bond_encoder = chemistry.Featurizer([\n", " chemistry.features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),\n", " chemistry.features.Rotatable()\n", "])\n", "\n", "mol_encoder = chemistry.MolecularGraphEncoder(\n", " atom_encoder, bond_encoder, positional_encoding_dim=None)\n", "\n", "smiles_list = [\n", " 'OCC1OC(C(C1O)O)n1cnc2c1ncnc2N',\n", " 'C(C(=O)O)N',\n", " 'C1=CC(=CC=C1CC(C(=O)O)N)O'\n", "]\n", "\n", "graph_tensor = mol_encoder(smiles_list)\n", "\n", "print(graph_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.separate()` – Separate subgraphs of `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout",
"output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=)\n" ] } ], "source": [ "graph_tensor = graph_tensor.separate()\n", "print(graph_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.merge()` – Merge subgraphs of `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=)\n" ] } ], "source": [ "graph_tensor = graph_tensor.merge()\n", "print(graph_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.propagate()` – Propagate node information with the `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Node features before:\n", " tf.Tensor(\n", "[[0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]\n", " [0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1.]\n", " [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]\n", " [0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0.
0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1.]\n", " [0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1.]\n", " [0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1.]\n", " [0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1.]], shape=(37, 11), dtype=float32)\n", "\n", "Node features after:\n", " tf.Tensor(\n", "[[0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 1. 0. 0. 0. 2. 1. 1. 1.]\n", " [0. 2. 0. 1. 0. 0. 0. 3. 0. 1. 1.]\n", " [0. 2. 0. 0. 0. 0. 0. 2. 0. 0. 0.]\n", " [0. 1. 1. 1. 0. 0. 1. 2. 0. 2. 2.]\n", " [0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]\n", " [0. 2. 0. 1. 0. 0. 0. 3. 1. 1. 1.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]\n", " [0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 2. 1. 0. 0. 0. 3. 0. 0. 1. 1.]\n", " [0. 1. 2. 0. 0. 0. 3. 0. 0. 2. 2.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 0. 2. 0. 0. 0. 2. 0. 0. 2. 2.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 1. 2. 0. 0. 0. 3. 0. 1. 2. 2.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]\n", " [0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 2. 0. 1. 0. 0. 3. 0. 1. 1. 1.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 2. 0. 0. 0. 0. 2. 0. 0. 0. 0.]\n", " [0. 3. 0. 0. 0. 0. 2. 1. 0. 0. 0.]\n", " [0. 2. 0. 0. 0. 0. 1. 1. 0. 0. 0.]\n", " [0. 2. 1. 0. 0. 0. 1. 2. 1. 1. 1.]\n", " [0. 1. 0. 2. 0. 0. 2. 1. 1. 1. 2.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 
0. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]], shape=(37, 11), dtype=float32)\n" ] } ], "source": [ "print('Node features before:\\n', graph_tensor.node_feature, end='\\n\\n')\n", "graph_tensor = graph_tensor.propagate()\n", "print('Node features after:\\n', graph_tensor.node_feature)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.update()` – Update data of the `GraphTensor`\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=,\n", " edge_feature=,\n", " node_supplementary_data=)\n" ] } ], "source": [ "node_supplementary_data = tf.random.uniform(\n", " shape=graph_tensor.node_feature.shape[:-1] + [1])\n", "\n", "node_feature_updated = tf.random.uniform(\n", " shape=graph_tensor.node_feature.shape[:-1] + [128])\n", "\n", "# Both add new data and update existing data of the GraphTensor:\n", "graph_tensor = graph_tensor.update({\n", " 'node_supplementary_data': node_supplementary_data, \n", " 'node_feature': node_feature_updated\n", "})\n", "print(graph_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.remove()` – Remove data from `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n" ] } ], "source": [ "graph_tensor = graph_tensor.remove([\n", " 'node_supplementary_data', 'edge_feature'\n", "])\n", "print(graph_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `__getitem__` – Index lookup on the `GraphTensor`\n", "\n", "The `GraphTensor` can be indexed either by passing a `str` (to obtain a specific field of `GraphTensor`) or `int`, `list[int]` or `slice` (to extract specific subgraphs (molecules) from 
`GraphTensor`). (Alternatively, the `GraphTensor` can be passed to `tf.gather` to extract specific subgraphs.)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Complete graph:\n", "\n", "------------------------------------------------------------\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "------------------------------------------------------------\n", "Subgraph (2) and (3) of graph:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "------------------------------------------------------------\n", "Subgraph (1) and (2) of graph:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n" ] } ], "source": [ "print(\"Complete graph:\\n\")\n", "print(\"---\" * 20)\n", "print(graph_tensor, end='\\n\\n')\n", "\n", "# Index with a list of ints: [1, 2] selects the 2nd and 3rd subgraph.\n", "print(\"---\" * 20)\n", "print(\"Subgraph (2) and (3) of graph:\\n\")\n", "print(graph_tensor[[1, 2]], end='\\n\\n')\n", "\n", "# Index with a slice: [:2] selects the 1st and 2nd subgraph.\n", "print(\"---\" * 20)\n", "print(\"Subgraph (1) and (2) of graph:\\n\")\n", "print(graph_tensor[:2], end='\\n\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `__getattr__` – Attribute lookup on the `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Access `node_feature` field:\n", "\n", "------------------------------------------------------------\n", "tf.Tensor(\n", "[[0.30606592 0.01332998 0.28550065 ... 0.30522108 0.43709052 0.2496804 ]\n", " [0.47505558 0.6802629 0.12628877 ... 0.54731417 0.85908985 0.01080072]\n", " [0.32505012 0.16541815 0.9268564 ... 0.19977057 0.6975106 0.63107324]\n", " ...\n", " [0.06981373 0.0497787 0.7329197 ... 0.72168195 0.992267 0.4002931 ]\n", " [0.6254629 0.77454865 0.4750824 ...
0.21217322 0.10769343 0.71567035]\n", " [0.29524624 0.7836231 0.7198993 ... 0.94255567 0.926514 0.62505746]], shape=(37, 128), dtype=float32)\n", "\n", "------------------------------------------------------------\n", "Access `edge_src` field:\n", "\n", "tf.Tensor(\n", "[ 0 1 1 2 2 2 3 3 4 4 4 5 5 5 6 6 6 7 8 9 9 9 10 10\n", " 11 11 12 12 12 13 13 13 14 14 15 15 16 16 17 17 17 18 19 19 20 20 20 21\n", " 22 23 24 24 25 25 26 26 26 27 27 28 28 29 29 29 30 30 31 31 31 32 32 32\n", " 33 34 35 36], shape=(76,), dtype=int32)\n", "\n", "------------------------------------------------------------\n", "Access `graph_indicator` field:\n", "\n", "tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2], shape=(37,), dtype=int64)\n", "\n" ] } ], "source": [ "print(\"Access `node_feature` field:\\n\")\n", "print(\"---\" * 20)\n", "print(graph_tensor.node_feature, end='\\n\\n')\n", "\n", "print(\"---\" * 20)\n", "print(\"Access `edge_src` field:\\n\")\n", "print(graph_tensor.edge_src, end='\\n\\n')\n", "\n", "print(\"---\" * 20)\n", "print(\"Access `graph_indicator` field:\\n\")\n", "print(graph_tensor.graph_indicator, end='\\n\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `tf.concat` – Concatenating multiple `GraphTensor` instances" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Concatenating two graphs in non-ragged states:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "Inspect `graph_indicator`:\n", "\n", "tf.Tensor(\n", "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5], shape=(74,), dtype=int64)\n", "\n", "------------------------------------------------------------\n", "Concatenating two graphs in ragged states\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " 
edge_src=,\n", " edge_dst=)\n", "\n" ] } ], "source": [ "print(\"Concatenating two graphs in non-ragged states:\\n\")\n", "graph_tensor_concat = tf.concat([\n", " graph_tensor, \n", " graph_tensor], axis=0)\n", "print(graph_tensor_concat, end='\\n\\n')\n", "print(\"Inspect `graph_indicator`:\\n\")\n", "print(graph_tensor_concat.graph_indicator, end='\\n\\n')\n", "\n", "print('---' * 20)\n", "print(\"Concatenating two graphs in ragged states\")\n", "graph_tensor_concat = tf.concat([\n", " graph_tensor.separate(), \n", " graph_tensor.separate()], axis=0)\n", "print(graph_tensor_concat, end='\\n\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `tf.split` – Splits a `GraphTensor` into multiple `GraphTensor` instances" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=),\n", " GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=),\n", " GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=),\n", " GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=),\n", " GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=),\n", " GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.split(graph_tensor_concat.merge(), num_or_size_splits=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.spec` – The spec of the `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTensor.Spec(sizes=TensorSpec(shape=(None,), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), 
edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})\n" ] } ], "source": [ "print(graph_tensor.spec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.shape` – Partial shape of the `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(partial) shape: (3, None, 128)\n" ] } ], "source": [ "print('(partial) shape:', graph_tensor.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.dtype` – Partial dtype of the `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(partial) dtype: float32\n" ] } ], "source": [ "print('(partial) dtype:', graph_tensor.dtype.name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `.rank` – Partial rank of the `GraphTensor` " ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(partial) rank: 3\n" ] } ], "source": [ "print('(partial) rank:', graph_tensor.rank)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `tf.data.Dataset` – Creating a TF dataset from a `GraphTensor`" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset object:\n", " <_TensorSliceDataset element_spec=GraphTensor.Spec(sizes=TensorSpec(shape=(), dtype=tf.int64, name=None), node_feature=TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), edge_src=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_dst=TensorSpec(shape=(None,), dtype=tf.int32, name=None), edge_feature=None, edge_weight=None, node_position=None, auxiliary={})>\n", "\n", "------------------------------------------------------------\n", "\n", "batch 1:\n", "\n", "GraphTensor(\n",
" sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "------------------------------------------------------------\n", "\n", "batch 2:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "------------------------------------------------------------\n" ] } ], "source": [ "ds = tf.data.Dataset.from_tensor_slices(graph_tensor)\n", "print('Dataset object:\\n', ds)\n", "\n", "print('\\n' + '---' * 20)\n", "# Loop over dataset\n", "for i, x in enumerate(ds.batch(2).map(lambda x: x)):\n", " print(f\"\\nbatch {i + 1}:\\n\")\n", " print(x)\n", " print('\\n' + '---' * 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `layers` – Passing a `GraphTensor` to a layer\n", "\n", "The `GraphTensor` can be passed to GNN layers either as a single disjoint graph or subgraphs." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pass GraphTensor in non-ragged state:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n", "------------------------------------------------------------\n", "\n", "Pass GraphTensor in ragged state:\n", "\n", "GraphTensor(\n", " sizes=,\n", " node_feature=,\n", " edge_src=,\n", " edge_dst=)\n", "\n" ] } ], "source": [ "gin_conv = layers.GINConv(128)\n", "\n", "print(\"Pass GraphTensor in non-ragged state:\\n\")\n", "print(gin_conv(graph_tensor), end='\\n\\n')\n", "print('---' * 20)\n", "print('\\nPass GraphTensor in ragged state:\\n')\n", "print(gin_conv(graph_tensor.separate()), end='\\n\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `Model` – Passing a `GraphTensor` to a model" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using (graph_tensor, label) pair as input:\n", "\n", "Epoch 1/5\n", "2/2 [==============================] 
- 3s 8ms/step - loss: 0.3226\n", "Epoch 2/5\n", "2/2 [==============================] - 0s 7ms/step - loss: 8.0285\n", "Epoch 3/5\n", "2/2 [==============================] - 0s 8ms/step - loss: 4.7673\n", "Epoch 4/5\n", "2/2 [==============================] - 0s 6ms/step - loss: 1.8421\n", "Epoch 5/5\n", "2/2 [==============================] - 0s 6ms/step - loss: 0.1327\n", "\n", "------------------------------\n", "\n", "Using tf.data.Dataset as input:\n", "\n", "Epoch 1/5\n", "2/2 [==============================] - 0s 7ms/step - loss: 0.8891\n", "Epoch 2/5\n", "2/2 [==============================] - 0s 9ms/step - loss: 1.2614\n", "Epoch 3/5\n", "2/2 [==============================] - 0s 6ms/step - loss: 1.0532\n", "Epoch 4/5\n", "2/2 [==============================] - 0s 6ms/step - loss: 0.6833\n", "Epoch 5/5\n", "2/2 [==============================] - 0s 6ms/step - loss: 0.4317\n" ] } ], "source": [ "model = tf.keras.Sequential([\n", " layers.GCNConv(),\n", " layers.GCNConv(),\n", " layers.Readout(),\n", " keras.layers.Dense(1)\n", "])\n", "\n", "y_dummy = tf.constant([[1.], [2.], [3.]])\n", "\n", "\n", "model.compile('adam', 'huber')\n", "print(\"Using (graph_tensor, label) pair as input:\\n\")\n", "model.fit(graph_tensor, y_dummy, batch_size=2, epochs=5)\n", "\n", "print('\\n------------------------------\\n')\n", "print(\"Using tf.data.Dataset as input:\\n\")\n", "dataset = tf.data.Dataset.from_tensor_slices((graph_tensor, y_dummy))\n", "model.fit(dataset.batch(2), epochs=5);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }