Interpretability with Monte Carlo Tree Search#
Based on the paper by Jin et al., "Multi-Objective Molecule Generation using Interpretable Substructures", and modified from Chemprop v1's interpret.py.
Please scroll past the helper functions to change the model and data inputs and to run the interpretation algorithm.
Note:
The interpret function does not yet work with additional atom or bond features, as the substructure extracted doesn’t necessarily have the corresponding additional atom or bond features readily available.
It currently only works with single-molecule models.
[1]:
# Install chemprop from GitHub if running in Google Colab
import os

# The install steps only run when the COLAB_RELEASE_TAG environment variable is set (and non-empty)
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        # chemprop is not importable yet: clone the repository and install it from source
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        # Move into examples/; later cells build file paths from the parent of the working directory
        %cd examples
Import packages#
[2]:
from dataclasses import dataclass, field
import math
from pathlib import Path
import time
from typing import Callable, Union, Iterable
from lightning import pytorch as pl
import numpy as np
import pandas as pd
from rdkit import Chem
import torch
from chemprop import data, featurizers, models
from chemprop.models import MPNN
Define helper function to make model predictions from SMILES#
[3]:
def make_prediction(
    models: list[MPNN],
    trainer: pl.Trainer,
    smiles: list[str],
) -> np.ndarray:
    """Predict with every model on a list of SMILES and average the results.

    Parameters
    ----------
    models : list[MPNN]
        The models to predict with. Their predictions are averaged with equal weight.
    trainer : pl.Trainer
        The Lightning trainer whose ``predict`` method is used to run each model.
    smiles : list[str]
        The SMILES strings to make predictions on.

    Returns
    -------
    np.ndarray
        The element-wise mean of the per-model predictions. Presumably of shape
        ``(len(smiles), n_tasks)`` -- callers in this file index it as ``[:, task]``.
    """
    datapoints = [data.MoleculeDatapoint.from_smi(smi) for smi in smiles]
    loader = data.build_dataloader(
        data.MoleculeDataset(datapoints), batch_size=1, num_workers=0, shuffle=False
    )

    with torch.inference_mode():
        # One array per model: concatenate the per-batch outputs, then move to NumPy
        per_model_preds = [
            torch.cat(trainer.predict(model, loader), 0).cpu().numpy() for model in models
        ]

    # Ensemble predictions
    return sum(per_model_preds) / len(models)
Classes/functions relevant to Monte Carlo Tree Search#
Mostly similar to the code in Chemprop v1's interpret.py, with additional documentation.
[4]:
@dataclass
class MCTSNode:
    """Represents a node in a Monte Carlo Tree Search.

    Parameters
    ----------
    smiles : str
        The SMILES for the substructure at this node.
    atoms : Iterable[int]
        The atom indices in the substructure at this node. Stored as a ``set`` after
        initialization.
    W : float
        The total action value, which indicates how likely the deletion will lead to a good
        rationale. It accumulates the scores of the rollouts that passed through this node.
    N : int
        The visit count, which indicates how many times this node has been visited. It is used
        to balance exploration and exploitation.
    P : float
        The predicted property score of the new subgraph after the deletion, shown as R in the
        original paper.
    children : list[MCTSNode]
        The nodes reachable from this node by one deletion. Empty until the node is expanded.
    """

    smiles: str
    atoms: Iterable[int]
    W: float = 0
    N: int = 0
    P: float = 0
    children: list["MCTSNode"] = field(default_factory=list)

    def __post_init__(self):
        # Normalize to a set so that subset tests and set differences work on the atoms
        self.atoms = set(self.atoms)

    def Q(self) -> float:
        """
        Returns
        -------
        float
            The mean action value of the node, or 0 if the node has never been visited.
        """
        return self.W / self.N if self.N > 0 else 0

    def U(self, n: int, c_puct: float = 10.0) -> float:
        """
        Parameters
        ----------
        n : int
            The total visit count over all children of this node's parent (this node and its
            siblings).
        c_puct : float
            A constant that controls the level of exploration.

        Returns
        -------
        float
            The exploration value of the node.
        """
        return c_puct * self.P * math.sqrt(n) / (1 + self.N)
[5]:
def find_clusters(mol: Chem.Mol) -> tuple[list[tuple[int, ...]], list[list[int]]]:
    """Splits a molecule into clusters: one per non-ring bond and one per ring.

    Jin et al. [1]_ only allow deletion of one peripheral non-aromatic bond or one peripheral
    ring from each state, so the clusters here are defined as non-ring bonds and the smallest
    set of smallest rings.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule to find clusters in.

    Returns
    -------
    tuple
        A tuple containing:

        - list of tuples: Each tuple contains the atom indices in a cluster. Non-ring bonds
          come first, followed by the rings.
        - list of lists of int: For each atom, the indices of the clusters it belongs to.

    References
    ----------
    .. [1] Jin, Wengong, Regina Barzilay, and Tommi Jaakkola. "Multi-objective molecule generation using interpretable substructures." International conference on machine learning. PMLR, 2020. https://arxiv.org/abs/2002.03244
    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:  # special case: a lone atom forms its own cluster
        return [(0,)], [[0]]

    # Every bond outside a ring is a two-atom cluster
    clusters = [
        (bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx())
        for bond in mol.GetBonds()
        if not bond.IsInRing()
    ]
    # Every ring in the smallest set of smallest rings is a cluster
    clusters.extend(tuple(ring) for ring in Chem.GetSymmSSSR(mol))

    # Invert the mapping: atom index -> indices of the clusters containing it
    atom_cls = [[] for _ in range(n_atoms)]
    for cluster_idx, cluster in enumerate(clusters):
        for atom in cluster:
            atom_cls[atom].append(cluster_idx)

    return clusters, atom_cls
[6]:
def extract_subgraph_from_mol(mol: Chem.Mol, selected_atoms: set[int]) -> tuple[Chem.Mol, list[int]]:
    """Builds the sub-molecule induced by a set of atom indices of an RDKit molecule.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule from which to extract a subgraph. It is copied, not modified.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:

        - RDKit molecule: The subgraph.
        - list of int: Root atom indices, i.e. the selected atoms that have at least one
          neighbor outside the selection (indices refer to the input molecule).
    """
    selected_atoms = set(selected_atoms)

    # Roots are the selected atoms where the subgraph was cut from the rest of the molecule
    roots = [
        idx
        for idx in selected_atoms
        if any(
            nei.GetIdx() not in selected_atoms
            for nei in mol.GetAtomWithIdx(idx).GetNeighbors()
        )
    ]

    editable = Chem.RWMol(mol)
    for root_idx in roots:
        root_atom = editable.GetAtomWithIdx(root_idx)
        root_atom.SetAtomMapNum(1)  # tag attachment points with atom map number 1
        keeps_aromatic_bond = any(
            bond.GetBondType() == Chem.rdchem.BondType.AROMATIC
            and bond.GetBeginAtom().GetIdx() in selected_atoms
            and bond.GetEndAtom().GetIdx() in selected_atoms
            for bond in root_atom.GetBonds()
        )
        # A root with no aromatic bond left inside the selection can no longer be aromatic
        if not keeps_aromatic_bond:
            root_atom.SetIsAromatic(False)

    # Delete from the highest index down so the remaining indices stay valid
    doomed = sorted(
        (atom.GetIdx() for atom in editable.GetAtoms() if atom.GetIdx() not in selected_atoms),
        reverse=True,
    )
    for atom_idx in doomed:
        editable.RemoveAtom(atom_idx)

    return editable.GetMol(), roots
[7]:
def extract_subgraph(
    smiles: str, selected_atoms: set[int]
) -> tuple[Union[str, None], Union[list[int], None]]:
    """Extracts a subgraph from a SMILES given a set of atom indices.

    The extraction is first attempted on the kekulized molecule. If the result cannot be parsed
    back or is not a substructure of the original molecule, it is retried on the
    non-kekulized molecule.

    Parameters
    ----------
    smiles : str
        The SMILES string from which to extract a subgraph.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:

        - str: SMILES representing the subgraph.
        - list of int: Root atom indices from the selected indices.

        If neither attempt yields a parsable subgraph, ``(None, None)`` is returned, so callers
        must check the first element before using it.
    """
    # try with kekulization
    mol = Chem.MolFromSmiles(smiles)
    Chem.Kekulize(mol)
    subgraph, roots = extract_subgraph_from_mol(mol, selected_atoms)
    try:
        subgraph = Chem.MolFromSmiles(Chem.MolToSmiles(subgraph, kekuleSmiles=True))
    except Exception:
        # Deliberately broad: any failure on this path just triggers the fallback below
        subgraph = None

    mol = Chem.MolFromSmiles(smiles)  # de-kekulize
    if subgraph is not None and mol.HasSubstructMatch(subgraph):
        return Chem.MolToSmiles(subgraph), roots

    # If fails, try without kekulization
    subgraph, roots = extract_subgraph_from_mol(mol, selected_atoms)
    subgraph = Chem.MolFromSmiles(Chem.MolToSmiles(subgraph))
    if subgraph is None:
        return None, None
    return Chem.MolToSmiles(subgraph), roots
[8]:
def mcts_rollout(
    node: MCTSNode,
    state_map: dict[str, MCTSNode],
    orig_smiles: str,
    clusters: list[set[int]],
    atom_cls: list[set[int]],
    nei_cls: list[set[int]],
    scoring_function: Callable[[list[str]], list[float]],
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> float:
    """Performs one Monte Carlo Tree Search rollout starting at a given MCTSNode.

    Parameters
    ----------
    node : MCTSNode
        The MCTSNode from which to begin the rollout.
    state_map : dict
        A mapping from SMILES to MCTSNode. Expanded nodes are registered here.
    orig_smiles : str
        The original SMILES of the molecule.
    clusters : list
        For each cluster, the set of atom indices it contains.
    atom_cls : list
        For each atom, the set of indices of the clusters that contain it.
    nei_cls : list
        For each cluster, the set of indices of its neighboring clusters.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    min_atoms : int
        The minimum number of atoms in a subgraph.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    float
        The score of this MCTS rollout, i.e. ``P`` of the node where the rollout stopped.
    """
    # Stop once the subgraph is at or below the minimum size
    cur_atoms = node.atoms
    if len(cur_atoms) <= min_atoms:
        return node.P

    # Expand if this node has no children yet
    if not node.children:
        # Indices of the clusters whose atoms are all present in the current subgraph
        cur_cls = {i for i, cluster in enumerate(clusters) if cluster <= cur_atoms}

        for i in cur_cls:
            # Leaf atoms are atoms that belong to exactly one of the current clusters
            leaf_atoms = [a for a in clusters[i] if len(atom_cls[a] & cur_cls) == 1]

            # A cluster may be peeled off when either
            # 1. it has exactly one neighboring cluster in the current subgraph (so removing it
            #    cannot disconnect the graph), or
            # 2. it is a two-atom cluster with exactly one leaf atom.
            single_neighbor = len(nei_cls[i] & cur_cls) == 1
            dangling_bond = len(clusters[i]) == 2 and len(leaf_atoms) == 1
            if not (single_neighbor or dangling_bond):
                continue

            new_atoms = cur_atoms - set(leaf_atoms)
            new_smiles, _ = extract_subgraph(orig_smiles, new_atoms)
            # merge identical states: reuse the registered node when there is one
            new_node = (
                state_map[new_smiles]
                if new_smiles in state_map
                else MCTSNode(new_smiles, new_atoms)
            )
            # A falsy SMILES means the extraction failed; such states are dropped
            if new_smiles:
                node.children.append(new_node)

        state_map[node.smiles] = node
        if not node.children:
            return node.P  # cannot find leaves

        # Score all children in one call and store each score as the child's prior
        child_scores = scoring_function([child.smiles for child in node.children])
        for child, child_score in zip(node.children, child_scores):
            child.P = child_score

    # Descend into the child with the highest Q + U
    sum_count = sum(child.N for child in node.children)
    selected_node = max(
        node.children, key=lambda child: child.Q() + child.U(sum_count, c_puct=c_puct)
    )
    v = mcts_rollout(
        selected_node,
        state_map,
        orig_smiles,
        clusters,
        atom_cls,
        nei_cls,
        scoring_function,
        min_atoms=min_atoms,
        c_puct=c_puct,
    )

    # Back up the rollout score into the selected child's statistics
    selected_node.W += v
    selected_node.N += 1
    return v
[9]:
def mcts(
    smiles: str,
    scoring_function: Callable[[list[str]], list[float]],
    n_rollout: int,
    max_atoms: int,
    prop_delta: float,
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> list[MCTSNode]:
    """Runs the Monte Carlo Tree Search algorithm.

    Parameters
    ----------
    smiles : str
        The SMILES of the molecule to perform the search on.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    n_rollout : int
        The number of MCTS rollouts to perform.
    max_atoms : int
        The maximum number of atoms allowed in an extracted rationale.
    prop_delta : float
        The minimum required property value for a satisfactory rationale.
    min_atoms : int
        The minimum number of atoms in a subgraph.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    list
        A list of rationales each represented by a MCTSNode. A rationale is any expanded state
        with at most ``max_atoms`` atoms and a score of at least ``prop_delta``.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit returns None for an unparsable SMILES; fail with a clear message here
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    clusters, atom_cls = find_clusters(mol)

    # Convert to sets: the rollout uses subset tests and set intersections
    atom_cls = [set(cluster_ids) for cluster_ids in atom_cls]
    # Neighbors of a cluster are all other clusters that share at least one atom with it
    nei_cls = [
        {nei for atom in cluster for nei in atom_cls[atom]} - {i}
        for i, cluster in enumerate(clusters)
    ]
    clusters = [set(cluster) for cluster in clusters]

    root = MCTSNode(smiles, set(range(mol.GetNumAtoms())))
    state_map = {smiles: root}
    for _ in range(n_rollout):
        mcts_rollout(
            root,
            state_map,
            smiles,
            clusters,
            atom_cls,
            nei_cls,
            scoring_function,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )

    return [
        node
        for node in state_map.values()
        if len(node.atoms) <= max_atoms and node.P >= prop_delta
    ]
Load model#
[10]:
# Use the parent of the current working directory as the chemprop repository root
# (assumes the notebook is run from the repository's `examples/` folder -- TODO confirm for your setup)
chemprop_dir = Path.cwd().parent
model_path = (
    chemprop_dir / "tests" / "data" / "example_model_v2_regression_mol.pt"
)  # path to model checkpoint (.ckpt) or model file (.pt)
[11]:
mpnn = models.MPNN.load_from_file(model_path) # this is a dummy model for testing purposes
mpnn
[11]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
Load data to run interpretation for#
[12]:
# Parent of the current working directory (same expression as used for the model path)
chemprop_dir = Path.cwd().parent
test_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv"
smiles_column = "smiles"  # name of the CSV column that holds the SMILES strings
[13]:
df_test = pd.read_csv(test_path)
df_test
[13]:
| smiles | lipo | |
|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
| ... | ... | ... |
| 95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
| 96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
| 97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
| 98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
| 99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
Set up trainer#
[14]:
# Single-device CPU trainer with the logger and progress bar disabled; it is used to run `predict`
trainer = pl.Trainer(logger=None, enable_progress_bar=False, accelerator="cpu", devices=1)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Running interpretation#
[15]:
# MCTS options
rollout = 10  # number of MCTS rollouts to perform (passed to mcts() as n_rollout). If mol.GetNumAtoms() > 50, consider setting rollout = 1 to avoid long computation time
c_puct = 10.0  # constant that controls the level of exploration
max_atoms = 20  # maximum number of atoms allowed in an extracted rationale
min_atoms = 8  # minimum number of atoms in an extracted rationale
prop_delta = 0.5  # Minimum score to count as positive.
# In this algorithm, a substructure is considered satisfactory if its predicted property is at least prop_delta (>=).
# The search is only run for molecules whose own predicted property is strictly greater than prop_delta.
# This value depends on the property you want to interpret. 0.5 is a dummy value for demonstration purposes
num_rationales_to_keep = 5  # number of rationales to keep for each molecule
[16]:
# Define the scoring function. "Score" for a substructure is the predicted property value of the substructure.
# NOTE: the list is deliberately not named `models`; that would rebind the `chemprop.models` module
# imported at the top and break re-running the model-loading cell (`models.MPNN.load_from_file`).
ensemble_models = [mpnn]

property_for_interpretation = "lipo"

property_id = (
    df_test.columns.get_loc(property_for_interpretation) - 1
)  # property index in the dataset; -1 for the SMILES column


def scoring_function(smiles: list[str]) -> list[float]:
    """Scores substructures by their predicted value for the property being interpreted.

    Parameters
    ----------
    smiles : list[str]
        The SMILES strings to score.

    Returns
    -------
    list[float]
        One score per SMILES: column ``property_id`` of the averaged predictions returned by
        ``make_prediction`` (a one-dimensional NumPy array in practice).
    """
    return make_prediction(
        models=ensemble_models,
        trainer=trainer,
        smiles=smiles,
    )[:, property_id]
[17]:
# only use the first 5 SMILES for demonstration purposes
all_smiles = df_test[smiles_column].tolist()[:5]
all_smiles
[17]:
['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1']
[18]:
%%time
# Column-oriented results: one list per output column, filled in lockstep (one entry per molecule)
results_df = {"smiles": [], property_for_interpretation: []}
for i in range(num_rationales_to_keep):
    results_df[f"rationale_{i}"] = []
    results_df[f"rationale_{i}_score"] = []

for smiles in all_smiles:
    score = scoring_function([smiles])[0]
    # Only search for rationales when the whole molecule scores above the threshold
    rationales = (
        mcts(
            smiles=smiles,
            scoring_function=scoring_function,
            n_rollout=rollout,
            max_atoms=max_atoms,
            prop_delta=prop_delta,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )
        if score > prop_delta
        else []
    )

    results_df["smiles"].append(smiles)
    results_df[property_for_interpretation].append(score)

    if rationales:
        # Keep only the smallest rationales, best score first
        min_size = min(len(x.atoms) for x in rationales)
        rats = sorted(
            (x for x in rationales if len(x.atoms) == min_size),
            key=lambda x: x.P,
            reverse=True,
        )
    else:
        rats = []

    # Pad with None so every molecule contributes exactly num_rationales_to_keep entries
    for i in range(num_rationales_to_keep):
        kept = rats[i] if i < len(rats) else None
        results_df[f"rationale_{i}"].append(None if kept is None else kept.smiles)
        results_df[f"rationale_{i}_score"].append(None if kept is None else kept.P)
/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
[18:30:57] Can't kekulize mol. Unkekulized atoms: 10 11 12 13 14
[18:30:57] Can't kekulize mol. Unkekulized atoms: 11 12 13 14 15
[18:30:57] Can't kekulize mol. Unkekulized atoms: 8 9 10 11 12
[18:30:57] Can't kekulize mol. Unkekulized atoms: 7 8 9 10 11
[18:30:57] Can't kekulize mol. Unkekulized atoms: 1 2 3 4 5
[18:30:57] Can't kekulize mol. Unkekulized atoms: 0 1 3 4 5
[18:30:57] Can't kekulize mol. Unkekulized atoms: 0 1 2 3 4
[18:30:58] Can't kekulize mol. Unkekulized atoms: 11 12 13 14 15
[18:30:58] Can't kekulize mol. Unkekulized atoms: 8 9 10 11 12
[18:30:58] Can't kekulize mol. Unkekulized atoms: 7 8 9 10 11
[18:30:58] Can't kekulize mol. Unkekulized atoms: 10 11 12 13 14
CPU times: user 26min 27s, sys: 901 ms, total: 26min 28s
Wall time: 1min 43s
[19]:
# Turn the dict of column lists into a DataFrame (rebinds `results_df` from dict to DataFrame)
results_df = pd.DataFrame(results_df)
results_df
[19]:
| smiles | lipo | rationale_0 | rationale_0_score | rationale_1 | rationale_1_score | rationale_2 | rationale_2_score | rationale_3 | rationale_3_score | rationale_4 | rationale_4_score | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 2.253542 | c1ccc2c(c1)n[cH:1][nH:1]2 | 2.275024 | None | NaN | None | NaN | None | NaN | None | NaN |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | 2.235016 | O=[SH:1]c1c[cH:1][cH:1]cc1[OH:1] | 2.252582 | c1c([OH:1])c([S:1][NH2:1])c[cH:1][cH:1]1 | 2.252185 | c1c(N[CH3:1])[cH:1]c[cH:1]c1[SH:1] | 2.251067 | c1c([S:1][NH2:1])[cH:1]cc([OH:1])[cH:1]1 | 2.250288 | c1c([NH2:1])[cH:1]c[cH:1]c1[S:1][NH2:1] | 2.249267 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 2.245891 | c1cc[cH:1]c([CH2:1][CH2:1][OH:1])c1 | 2.249289 | O=[CH:1][CH2:1]c1cccc[cH:1]1 | 2.249207 | c1cc[cH:1]c([C@@H]([CH3:1])[NH2:1])c1 | 2.247827 | Clc1ccccc1[CH2:1][NH2:1] | 2.245391 | Clc1ccccc1[CH2:1][CH3:1] | 2.243280 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 2.249847 | c1c([CH3:1])[nH]c2s[cH:1]cc12 | 2.267990 | Clc1cc2c[cH:1][nH]c2s1 | 2.267004 | O=C1N(C[CH3:1])[CH:1]=[CH:1]C[CH2:1]1 | 2.211323 | None | NaN | None | NaN |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 2.228097 | c1cc(C[CH2:1][NH2:1])c[cH:1]c1 | 2.247070 | c1cc(C[CH2:1][CH3:1])c[cH:1]c1 | 2.245314 | Cn1nc([CH3:1])cc1[CH2:1][NH2:1] | 2.225729 | C[CH2:1]c1cc([CH2:1][NH2:1])[nH:1]n1 | 2.223793 | c1c([CH3:1])n[nH:1]c1[CH2:1]N[CH3:1] | 2.223478 |