Interpretability

[1]:
from molgraph.chemistry import MolecularGraphEncoder
from molgraph.chemistry import Featurizer
from molgraph.chemistry import features

from molgraph.models import GradientActivationMapping
from molgraph.models import IntegratedSaliencyMapping

import tensorflow as tf

import numpy as np
import pandas as pd

# Notebook-wide NumPy display option: suppress=True asks NumPy to print small
# floats in fixed-point rather than scientific notation when arrays are shown.
np.set_printoptions(suppress=True)

Construct a MolecularGraphEncoder

[2]:
# Atom-level descriptors. Symbol and Hybridization each reserve one
# out-of-vocabulary slot (oov_size=1) for values outside the listed sets.
atom_features = [
    features.Symbol({'C', 'N', 'O'}, oov_size=1),
    features.Hybridization({'SP', 'SP2', 'SP3'}, oov_size=1),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.Hetero(),
]
atom_encoder = Featurizer(atom_features)

# Bond-level descriptors.
bond_features = [
    features.BondType({'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'}),
    features.Rotatable(),
]
bond_encoder = Featurizer(bond_features)

# Combine both featurizers into one encoder; no positional encoding is requested.
mol_encoder = MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=None,
)

# Bare last expression so the notebook displays the encoder's repr.
mol_encoder
[2]:
MolecularGraphEncoder(atom_encoder=Featurizer(features=[Symbol(allowable_set=['[OOV:0]', 'C', 'N', 'O'], ordinal=False, oov_size=1), Hybridization(allowable_set=['[OOV:0]', 'SP', 'SP2', 'SP3'], ordinal=False, oov_size=1), HydrogenDonor(), HydrogenAcceptor(), Hetero()]), bond_encoder=Featurizer(features=[BondType(allowable_set=['AROMATIC', 'DOUBLE', 'SINGLE', 'TRIPLE'], ordinal=False, oov_size=0), Rotatable()]), positional_encoding_dim=None, self_loops=False)

Obtain dataset

[3]:
# Location of the ESOL (aqueous solubility) CSV used throughout this notebook.
esol_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/ESOL.csv'

# Fetch the file via Keras' download helper and get back its local path.
path = tf.keras.utils.get_file(fname='ESOL.csv', origin=esol_url)

# Load the table and preview the first three rows.
df = pd.read_csv(path)
df.head(3)
[3]:
Compound ID ESOL predicted log solubility in mols per litre Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings Number of Rotatable Bonds Polar Surface Area measured log solubility in mols per litre smiles
0 Amigdalin -0.974 1 457.432 7 3 7 202.32 -0.77 OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...
1 Fenfuram -2.885 1 201.225 1 2 2 42.24 -3.30 Cc1occc1C(=O)Nc2ccccc2
2 citral -2.579 1 152.237 0 0 4 17.07 -2.06 CC(C)=CCCC(C)=CC(=O)

Obtain SMILES x and associated labels y

[4]:
# Inputs: SMILES strings. Targets: measured log solubility.
label_column = 'measured log solubility in mols per litre'
x = df['smiles'].values
y = df[label_column].values

Obtain GraphTensor from x

[5]:
# Encode every SMILES string into a single GraphTensor.
x_data = mol_encoder(x)

# Show the tensor summary followed by a blank line.
print(x_data, end='\n\n')

# Report the shape of each component; labels are padded to a common width.
shape_report = (
    ('node_feature shape:', x_data.node_feature),
    ('edge_dst shape:    ', x_data.edge_dst),
    ('edge_src shape:    ', x_data.edge_src),
    ('edge_feature shape:', x_data.edge_feature),
)
for label, component in shape_report:
    print(label, component.shape)
GraphTensor(
  sizes=<tf.Tensor: shape=(1128,), dtype=int64>,
  node_feature=<tf.Tensor: shape=(14991, 11), dtype=float32>,
  edge_src=<tf.Tensor: shape=(30856,), dtype=int32>,
  edge_dst=<tf.Tensor: shape=(30856,), dtype=int32>,
  edge_feature=<tf.Tensor: shape=(30856, 5), dtype=float32>)

node_feature shape: (14991, 11)
edge_dst shape:     (30856,)
edge_src shape:     (30856,)
edge_feature shape: (30856, 5)

1. Build Keras GNN model with GNN layers

[34]:
from molgraph.layers import GCNConv, Readout

# Seed the Python, NumPy and TensorFlow RNGs *before* any weights are created,
# so that weight initialisation and the shuffling performed by `fit` are
# repeatable across "Restart & Run All". Without this the trained model, and
# therefore the saliency maps shown later, differ on every run.
# (Bit-exact results are still not guaranteed on every hardware backend.)
tf.keras.utils.set_random_seed(42)

# Four graph convolutions -> graph-level readout -> dense regression head.
# The conv layers are named explicitly because they are later referenced by
# name when computing activation maps.
sequential_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=x_data.spec),
    GCNConv(128, name='conv_1'),
    GCNConv(128, name='conv_2'),
    GCNConv(128, name='conv_3'),
    GCNConv(128, name='conv_4'),
    Readout(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1)
])

sequential_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv_1 (GCNConv)            (None, None, 128)         4608

 conv_2 (GCNConv)            (None, None, 128)         33152

 conv_3 (GCNConv)            (None, None, 128)         33152

 segment_pooling_readout_6   (None, 128)               0
 (SegmentPoolingReadout)

 dense_12 (Dense)            (None, 512)               66048

 dense_13 (Dense)            (None, 1)                 513

=================================================================
Total params: 137473 (537.00 KB)
Trainable params: 137473 (537.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

2. Compile and fit GNN model

[35]:
# Regression setup: Adam optimiser, mean-squared-error loss, MAE as a metric.
sequential_model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae'],
)

# Train on the full encoded dataset; verbose=2 logs one line per epoch and the
# trailing semicolon hides the returned History object's repr.
sequential_model.fit(x=x_data, y=y, epochs=50, verbose=2);
Epoch 1/50
36/36 - 3s - loss: 4.2682 - mae: 1.6230 - 3s/epoch - 84ms/step
Epoch 2/50
36/36 - 0s - loss: 3.3620 - mae: 1.4817 - 179ms/epoch - 5ms/step
Epoch 3/50
36/36 - 0s - loss: 3.3828 - mae: 1.4810 - 188ms/epoch - 5ms/step
Epoch 4/50
36/36 - 0s - loss: 3.1614 - mae: 1.4371 - 186ms/epoch - 5ms/step
Epoch 5/50
36/36 - 0s - loss: 2.9661 - mae: 1.3852 - 183ms/epoch - 5ms/step
Epoch 6/50
36/36 - 0s - loss: 2.8416 - mae: 1.3507 - 188ms/epoch - 5ms/step
Epoch 7/50
36/36 - 0s - loss: 2.7258 - mae: 1.3335 - 176ms/epoch - 5ms/step
Epoch 8/50
36/36 - 0s - loss: 2.5141 - mae: 1.2632 - 186ms/epoch - 5ms/step
Epoch 9/50
36/36 - 0s - loss: 2.6962 - mae: 1.2898 - 186ms/epoch - 5ms/step
Epoch 10/50
36/36 - 0s - loss: 2.5491 - mae: 1.2679 - 180ms/epoch - 5ms/step
Epoch 11/50
36/36 - 0s - loss: 2.3768 - mae: 1.2089 - 180ms/epoch - 5ms/step
Epoch 12/50
36/36 - 0s - loss: 2.3904 - mae: 1.2302 - 177ms/epoch - 5ms/step
Epoch 13/50
36/36 - 0s - loss: 2.1581 - mae: 1.1444 - 185ms/epoch - 5ms/step
Epoch 14/50
36/36 - 0s - loss: 2.4210 - mae: 1.2226 - 187ms/epoch - 5ms/step
Epoch 15/50
36/36 - 0s - loss: 2.2568 - mae: 1.1956 - 197ms/epoch - 5ms/step
Epoch 16/50
36/36 - 0s - loss: 2.1405 - mae: 1.1333 - 201ms/epoch - 6ms/step
Epoch 17/50
36/36 - 0s - loss: 1.9765 - mae: 1.0873 - 203ms/epoch - 6ms/step
Epoch 18/50
36/36 - 0s - loss: 1.9263 - mae: 1.0707 - 204ms/epoch - 6ms/step
Epoch 19/50
36/36 - 0s - loss: 1.9688 - mae: 1.0793 - 186ms/epoch - 5ms/step
Epoch 20/50
36/36 - 0s - loss: 1.8725 - mae: 1.0718 - 170ms/epoch - 5ms/step
Epoch 21/50
36/36 - 0s - loss: 1.9170 - mae: 1.0576 - 176ms/epoch - 5ms/step
Epoch 22/50
36/36 - 0s - loss: 1.9656 - mae: 1.0912 - 181ms/epoch - 5ms/step
Epoch 23/50
36/36 - 0s - loss: 1.7675 - mae: 1.0313 - 185ms/epoch - 5ms/step
Epoch 24/50
36/36 - 0s - loss: 1.7883 - mae: 1.0274 - 172ms/epoch - 5ms/step
Epoch 25/50
36/36 - 0s - loss: 1.9393 - mae: 1.0915 - 182ms/epoch - 5ms/step
Epoch 26/50
36/36 - 0s - loss: 1.7453 - mae: 1.0155 - 179ms/epoch - 5ms/step
Epoch 27/50
36/36 - 0s - loss: 1.7054 - mae: 1.0131 - 182ms/epoch - 5ms/step
Epoch 28/50
36/36 - 0s - loss: 1.7013 - mae: 1.0174 - 179ms/epoch - 5ms/step
Epoch 29/50
36/36 - 0s - loss: 1.7120 - mae: 1.0070 - 179ms/epoch - 5ms/step
Epoch 30/50
36/36 - 0s - loss: 1.7977 - mae: 1.0334 - 180ms/epoch - 5ms/step
Epoch 31/50
36/36 - 0s - loss: 1.6766 - mae: 0.9960 - 186ms/epoch - 5ms/step
Epoch 32/50
36/36 - 0s - loss: 1.6374 - mae: 0.9735 - 180ms/epoch - 5ms/step
Epoch 33/50
36/36 - 0s - loss: 1.6437 - mae: 0.9936 - 180ms/epoch - 5ms/step
Epoch 34/50
36/36 - 0s - loss: 1.5678 - mae: 0.9722 - 184ms/epoch - 5ms/step
Epoch 35/50
36/36 - 0s - loss: 1.4954 - mae: 0.9241 - 179ms/epoch - 5ms/step
Epoch 36/50
36/36 - 0s - loss: 1.5489 - mae: 0.9540 - 175ms/epoch - 5ms/step
Epoch 37/50
36/36 - 0s - loss: 1.4601 - mae: 0.9223 - 183ms/epoch - 5ms/step
Epoch 38/50
36/36 - 0s - loss: 1.5156 - mae: 0.9418 - 176ms/epoch - 5ms/step
Epoch 39/50
36/36 - 0s - loss: 1.5280 - mae: 0.9391 - 178ms/epoch - 5ms/step
Epoch 40/50
36/36 - 0s - loss: 1.5135 - mae: 0.9387 - 179ms/epoch - 5ms/step
Epoch 41/50
36/36 - 0s - loss: 1.4961 - mae: 0.9250 - 181ms/epoch - 5ms/step
Epoch 42/50
36/36 - 0s - loss: 1.4935 - mae: 0.9296 - 176ms/epoch - 5ms/step
Epoch 43/50
36/36 - 0s - loss: 1.4748 - mae: 0.9282 - 179ms/epoch - 5ms/step
Epoch 44/50
36/36 - 0s - loss: 1.4566 - mae: 0.9082 - 170ms/epoch - 5ms/step
Epoch 45/50
36/36 - 0s - loss: 1.3972 - mae: 0.9027 - 186ms/epoch - 5ms/step
Epoch 46/50
36/36 - 0s - loss: 1.4849 - mae: 0.9236 - 176ms/epoch - 5ms/step
Epoch 47/50
36/36 - 0s - loss: 1.3397 - mae: 0.8696 - 183ms/epoch - 5ms/step
Epoch 48/50
36/36 - 0s - loss: 1.4117 - mae: 0.8968 - 172ms/epoch - 5ms/step
Epoch 49/50
36/36 - 0s - loss: 1.2797 - mae: 0.8465 - 177ms/epoch - 5ms/step
Epoch 50/50
36/36 - 0s - loss: 1.2962 - mae: 0.8538 - 179ms/epoch - 5ms/step

3. Pass GNN model to GradientActivationMapping

[36]:
# Layers whose activations feed the maps; the names must match the `name=`
# arguments given to the GCNConv layers when the model was built.
conv_layer_names = ['conv_1', 'conv_2', 'conv_3', 'conv_4']

# Wrap the trained model; no output activation is applied and negative
# values are kept.
gam_model = GradientActivationMapping(
    sequential_model,
    conv_layer_names,
    output_activation=None,
    discard_negative_values=False,
)

# Compute the maps on the output of `x_data.separate()`.
separated_graphs = x_data.separate()
gam = gam_model(separated_graphs)

4. Visualize maps on molecule

[9]:
from molgraph.chemistry import vis

# Pick one molecule and overlay its activation map on the drawn structure.
# The same index selects both the SMILES string and its map.
molecule_index = 42
vis.visualize_maps(molecule=x[molecule_index], maps=gam[molecule_index])
[9]:
../../_images/examples_walk_through_04_activation_maps_17_0.png