Interpretation
chemprop.interpret.py uses Monte Carlo Tree Search (MCTS) to interpret trained Chemprop models by identifying the substructures of a molecule that are primarily responsible for Chemprop’s prediction.
- class chemprop.interpret.ChempropModel(args: InterpretArgs)
A ChempropModel is a wrapper around a MoleculeModel for interpretation.
- Parameters:
args – An InterpretArgs object containing arguments for interpretation.
- class chemprop.interpret.MCTSNode(smiles: str, atoms: List[int], W: float = 0, N: int = 0, P: float = 0)
An MCTSNode represents a node in a Monte Carlo Tree Search.
- Parameters:
smiles – The SMILES for the substructure at this node.
atoms – A list of atom indices represented by this node.
W – The W value of this node.
N – The N value of this node.
P – The P value of this node.
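The parameter descriptions do not define W, N, and P. The following is a minimal sketch, assuming the common MCTS convention in which W accumulates the total score of rollouts through a node, N counts visits, and P is the model's score for the node's own substructure. The `ToyNode` class, its `q` and `u` helpers, and the `c_puct` default are hypothetical illustrations, not part of the documented API.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ToyNode:
    """Hypothetical stand-in for MCTSNode; the field meanings are assumptions."""
    smiles: str
    atoms: List[int]
    W: float = 0.0  # assumed: total score accumulated over rollouts
    N: int = 0      # assumed: number of times the node was visited
    P: float = 0.0  # assumed: model score for this substructure

    def q(self) -> float:
        # Mean rollout score; zero for a node that was never visited.
        return self.W / self.N if self.N > 0 else 0.0

    def u(self, parent_visits: int, c_puct: float = 10.0) -> float:
        # PUCT-style exploration bonus that shrinks as the node is visited.
        return c_puct * self.P * math.sqrt(parent_visits) / (1 + self.N)


node = ToyNode(smiles="c1ccccc1", atoms=[0, 1, 2, 3, 4, 5], W=1.5, N=3, P=0.8)
print(node.q())   # mean score over three visits
print(node.u(9))  # exploration bonus given nine visits across siblings
```

Under these assumptions, a search would prefer children with a high `q() + u(...)`, which balances exploiting well-scoring substructures against exploring rarely visited ones.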
- chemprop.interpret.chemprop_interpret() → None
Runs interpretation of a Chemprop model.
This is the entry point for the command-line command chemprop_interpret.
- chemprop.interpret.extract_subgraph(smiles: str, selected_atoms: Set[int]) → Tuple[str, List[int]]
Extracts a subgraph from a SMILES given a set of atom indices.
- Parameters:
smiles – A SMILES from which to extract a subgraph.
selected_atoms – The atoms which form the subgraph to be extracted.
- Returns:
A tuple containing a SMILES representing the subgraph and a list of root atom indices from the selected indices.
- chemprop.interpret.find_clusters(mol: Mol) → Tuple[List[Tuple[int, ...]], List[List[int]]]
Finds clusters within the molecule.
- Parameters:
mol – An RDKit molecule.
- Returns:
A tuple containing a list of atom tuples representing the clusters and a list of lists of atoms in each cluster.
- chemprop.interpret.interpret(args: InterpretArgs) → None
Runs interpretation of a Chemprop model using the Monte Carlo Tree Search algorithm.
- Parameters:
args – An InterpretArgs object containing arguments for interpretation.
- chemprop.interpret.mcts(smiles: str, scoring_function: Callable[[List[str]], List[float]], n_rollout: int, max_atoms: int, prop_delta: float) → List[MCTSNode]
Runs the Monte Carlo Tree Search algorithm.
- Parameters:
smiles – The SMILES of the molecule to perform the search on.
scoring_function – A function for scoring subgraph SMILES using a Chemprop model.
n_rollout – The number of MCTS rollouts to perform.
max_atoms – The maximum number of atoms allowed in an extracted rationale.
prop_delta – The minimum required property value for a satisfactory rationale.
- Returns:
A list of rationales, each represented by an MCTSNode.
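The `scoring_function` argument takes a list of SMILES strings and returns one score per string. The sketch below illustrates that contract and shows how `max_atoms` and `prop_delta` could act as filters on candidate rationales. The toy scorer and the `select_rationales` helper are hypothetical, and the filtering semantics are an assumption drawn from the parameter descriptions rather than from the implementation.

```python
from typing import Callable, List, Tuple

# A candidate rationale: (subgraph SMILES, atom indices it covers).
Candidate = Tuple[str, List[int]]


def toy_scoring_function(smiles_list: List[str]) -> List[float]:
    # Placeholder for a trained model: the fraction of characters that are
    # 'N' or 'O'. A real scorer would run a Chemprop model on each SMILES.
    return [sum(ch in "NO" for ch in s) / max(len(s), 1) for s in smiles_list]


def select_rationales(
    candidates: List[Candidate],
    scoring_function: Callable[[List[str]], List[float]],
    max_atoms: int,
    prop_delta: float,
) -> List[Candidate]:
    # Keep only candidates that are small enough (assumed meaning of max_atoms).
    small = [c for c in candidates if len(c[1]) <= max_atoms]
    # Score the survivors in one batch, matching the list-in, list-out contract.
    scores = scoring_function([smiles for smiles, _ in small])
    # Keep candidates whose score reaches the threshold (assumed meaning of prop_delta).
    return [c for c, score in zip(small, scores) if score >= prop_delta]


candidates = [
    ("CCO", [0, 1, 2]),
    ("CCCCCC", [0, 1, 2, 3, 4, 5]),
    ("NO", [0, 1]),
]
rationales = select_rationales(candidates, toy_scoring_function, max_atoms=3, prop_delta=0.3)
print(rationales)
```

Scoring in batches, rather than one SMILES at a time, is what the `Callable[[List[str]], List[float]]` type in the signature suggests.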
- chemprop.interpret.mcts_rollout(node: MCTSNode, state_map: Dict[str, MCTSNode], orig_smiles: str, clusters: List[Set[int]], atom_cls: List[Set[int]], nei_cls: List[Set[int]], scoring_function: Callable[[List[str]], List[float]]) → float
A Monte Carlo Tree Search rollout from a given MCTSNode.
- Parameters:
node – The MCTSNode from which to begin the rollout.
state_map – A mapping from SMILES to MCTSNode.
orig_smiles – The original SMILES of the molecule.
clusters – Clusters of atoms.
atom_cls – Atom indices in the clusters.
nei_cls – Neighboring clusters.
scoring_function – A function for scoring subgraph SMILES using a Chemprop model.
- Returns:
The score of this MCTS rollout.
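As a rough illustration of what a rollout does, the sketch below runs a generic MCTS rollout on a small hand-built tree: it descends by picking the child with the highest mean score plus exploration bonus, returns the score of the leaf it reaches, and updates W and N on the way back. It assumes a PUCT-style selection rule and a prebuilt tree. The `SketchNode` class, the `rollout` function, and the `c_puct` value are hypothetical, and the sketch omits the cluster removal and SMILES scoring that the documented function's parameters imply.

```python
import math
from typing import List, Optional


class SketchNode:
    """Hypothetical tree node carrying the W, N, and P statistics."""

    def __init__(self, name: str, P: float,
                 children: Optional[List["SketchNode"]] = None) -> None:
        self.name = name
        self.P = P      # assumed: score of this node's substructure
        self.W = 0.0    # assumed: total score from rollouts through this node
        self.N = 0      # assumed: visit count
        self.children = children or []


def rollout(node: SketchNode, c_puct: float = 1.0) -> float:
    # A leaf ends the rollout; its own score is the rollout value.
    if not node.children:
        return node.P

    total_visits = sum(child.N for child in node.children)

    def selection_score(child: SketchNode) -> float:
        mean = child.W / child.N if child.N > 0 else 0.0
        bonus = c_puct * child.P * math.sqrt(total_visits) / (1 + child.N)
        return mean + bonus

    # Descend into the most promising child, then back up the result.
    chosen = max(node.children, key=selection_score)
    value = rollout(chosen, c_puct)
    chosen.W += value
    chosen.N += 1
    return value


a = SketchNode("a", P=0.2)
b = SketchNode("b", P=0.9)
root = SketchNode("root", P=0.5, children=[a, b])
values = [rollout(root) for _ in range(2)]
print(values, a.N, b.N)
```

On the first rollout no child has been visited, so every selection score is zero and `max` falls back to the first child; from the second rollout onward the exploration bonus steers the search toward the child with the higher P.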