"""Source code for chemprop.models.mpn."""

from typing import List, Union, Tuple
from functools import reduce

import numpy as np
from rdkit import Chem
import torch
import torch.nn as nn

from chemprop.args import TrainArgs
from chemprop.features import BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph
from chemprop.nn_utils import index_select_ND, get_activation_function


class MPNEncoder(nn.Module):
    """An :class:`MPNEncoder` is a message passing neural network for encoding a molecule."""

    def __init__(self, args: TrainArgs, atom_fdim: int, bond_fdim: int,
                 hidden_size: int = None, bias: bool = None, depth: int = None):
        """
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        :param atom_fdim: Atom feature vector dimension.
        :param bond_fdim: Bond feature vector dimension.
        :param hidden_size: Hidden layers dimension. Falls back to :code:`args.hidden_size` when None.
        :param bias: Whether to add bias to the input (:code:`W_i`) and message (:code:`W_h`) linear layers.
                     Falls back to :code:`args.bias` only when None, so an explicit :code:`False` is honoured.
        :param depth: Number of message passing steps. Falls back to :code:`args.depth` when None.
        """
        super(MPNEncoder, self).__init__()
        self.atom_fdim = atom_fdim
        self.bond_fdim = bond_fdim
        self.atom_messages = args.atom_messages
        # Use ``is None`` instead of ``or``: with ``or`` an explicit falsy override
        # (e.g. ``bias=False`` coming from ``args.bias_solvent``) was silently
        # replaced by the corresponding value in ``args``.
        self.hidden_size = args.hidden_size if hidden_size is None else hidden_size
        self.bias = args.bias if bias is None else bias
        self.depth = args.depth if depth is None else depth
        self.layers_per_message = 1
        self.undirected = args.undirected
        self.device = args.device
        self.aggregation = args.aggregation
        self.aggregation_norm = args.aggregation_norm
        self.is_atom_bond_targets = args.is_atom_bond_targets

        # Dropout
        self.dropout = nn.Dropout(args.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)

        # Cached zeros: used as the encoding of a molecule with no atoms (see readout in forward)
        self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)

        # Input
        input_dim = self.atom_fdim if self.atom_messages else self.bond_fdim
        self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        if self.atom_messages:
            # In atom-message mode the summed neighbour messages are concatenated
            # with neighbour bond features before W_h is applied.
            w_h_input_size = self.hidden_size + self.bond_fdim
        else:
            w_h_input_size = self.hidden_size

        # W_h is reused at every message passing step.
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)

        # Output layers use nn.Linear's default bias, independent of self.bias.
        self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)

        if self.is_atom_bond_targets:
            self.W_o_b = nn.Linear(self.bond_fdim + self.hidden_size, self.hidden_size)

        if args.atom_descriptors == 'descriptor':
            self.atom_descriptors_size = args.atom_descriptors_size
            self.atom_descriptors_layer = nn.Linear(self.hidden_size + self.atom_descriptors_size,
                                                    self.hidden_size + self.atom_descriptors_size,)

        if args.bond_descriptors == 'descriptor':
            self.bond_descriptors_size = args.bond_descriptors_size
            self.bond_descriptors_layer = nn.Linear(self.hidden_size + self.bond_descriptors_size,
                                                    self.hidden_size + self.bond_descriptors_size,)

    def forward(self,
                mol_graph: BatchMolGraph,
                atom_descriptors_batch: List[np.ndarray] = None,
                bond_descriptors_batch: List[np.ndarray] = None) -> torch.Tensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A :class:`~chemprop.features.featurization.BatchMolGraph` representing
                          a batch of molecular graphs.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atomic descriptors.
        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding
                 of each molecule. When :code:`is_atom_bond_targets` is set, instead returns the tuple
                 :code:`(atom_hiddens, a_scope, bond_hiddens, b_scope, b2br)`.
        :raises ValueError: If the number of atoms/bonds differs from the number of extra descriptor rows,
                            or if :code:`aggregation` is not one of 'mean', 'sum' or 'norm'.
        """
        if atom_descriptors_batch is not None:
            # Prepend a zero row to match the zero-padding first row of atom_hiddens.
            atom_descriptors_batch = [np.zeros([1, atom_descriptors_batch[0].shape[1]])] + atom_descriptors_batch
            atom_descriptors_batch = torch.from_numpy(
                np.concatenate(atom_descriptors_batch, axis=0)).float().to(self.device)

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
            atom_messages=self.atom_messages)
        f_atoms = f_atoms.to(self.device)
        f_bonds = f_bonds.to(self.device)
        a2b = a2b.to(self.device)
        b2a = b2a.to(self.device)
        b2revb = b2revb.to(self.device)

        if self.is_atom_bond_targets:
            b2br = mol_graph.get_b2br().to(self.device)
            if bond_descriptors_batch is not None:
                # Columns 0 / 1 of b2br hold the forward / backward directed-bond indices
                # (per the variable names); both directions receive the same descriptor row.
                # Rows that are never written (presumably the padding row) stay zero.
                # Vectorized assignment replaces a per-element loop over device tensors.
                forward_index = b2br[:, 0].cpu().numpy()
                backward_index = b2br[:, 1].cpu().numpy()
                descriptors_batch = np.concatenate(bond_descriptors_batch, axis=0)
                padded_descriptors = np.zeros([descriptors_batch.shape[0] * 2 + 1, descriptors_batch.shape[1]])
                padded_descriptors[forward_index] = descriptors_batch
                padded_descriptors[backward_index] = descriptors_batch
                bond_descriptors_batch = torch.from_numpy(padded_descriptors).float().to(self.device)

        if self.atom_messages:
            a2a = mol_graph.get_a2a().to(self.device)

        # Input
        if self.atom_messages:
            inputs = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            inputs = self.W_i(f_bonds)  # num_bonds x hidden_size
        message = self.act_func(inputs)  # num_bonds x hidden_size

        # Message passing
        for _ in range(self.depth - 1):
            if self.undirected:
                # Average each message with the one indexed by b2revb (its reverse bond, per the name).
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(inputs + message)  # num_bonds x hidden_size (skip connection to inputs)
            message = self.dropout(message)  # num_bonds x hidden

        # atom hidden
        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout(atom_hiddens)  # num_atoms x hidden

        # bond hidden
        if self.is_atom_bond_targets:
            b_input = torch.cat([f_bonds, message], dim=1)  # num_bonds x (bond_fdim + hidden)
            bond_hiddens = self.act_func(self.W_o_b(b_input))  # num_bonds x hidden
            bond_hiddens = self.dropout(bond_hiddens)  # num_bonds x hidden

        # concatenate the atom descriptors
        if atom_descriptors_batch is not None:
            if len(atom_hiddens) != len(atom_descriptors_batch):
                raise ValueError('The number of atoms is different from the length of the extra atom features')

            atom_hiddens = torch.cat([atom_hiddens, atom_descriptors_batch], dim=1)  # num_atoms x (hidden + descriptor size)
            atom_hiddens = self.atom_descriptors_layer(atom_hiddens)  # num_atoms x (hidden + descriptor size)
            atom_hiddens = self.dropout(atom_hiddens)  # num_atoms x (hidden + descriptor size)

        # concatenate the bond descriptors
        if self.is_atom_bond_targets and bond_descriptors_batch is not None:
            if len(bond_hiddens) != len(bond_descriptors_batch):
                raise ValueError('The number of bonds is different from the length of the extra bond features')

            bond_hiddens = torch.cat([bond_hiddens, bond_descriptors_batch], dim=1)  # num_bonds x (hidden + descriptor size)
            bond_hiddens = self.bond_descriptors_layer(bond_hiddens)  # num_bonds x (hidden + descriptor size)
            bond_hiddens = self.dropout(bond_hiddens)  # num_bonds x (hidden + descriptor size)

        # Readout
        if self.is_atom_bond_targets:
            # NOTE(review): the first row of atom_hiddens / bond_hiddens is zero padding and is
            # not removed here; presumably the caller strips it — confirm.
            return atom_hiddens, a_scope, bond_hiddens, b_scope, b2br

        mol_vecs = []
        for a_start, a_size in a_scope:
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
                continue

            cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)  # (a_size, hidden_size)
            if self.aggregation == 'mean':
                mol_vec = cur_hiddens.sum(dim=0) / a_size
            elif self.aggregation == 'sum':
                mol_vec = cur_hiddens.sum(dim=0)
            elif self.aggregation == 'norm':
                mol_vec = cur_hiddens.sum(dim=0) / self.aggregation_norm
            else:
                # Previously an unknown value left a 2-D per-molecule tensor that torch.stack
                # either rejected with an unrelated error or silently stacked into 3-D.
                raise ValueError(f'Unsupported aggregation "{self.aggregation}"; '
                                 f'expected "mean", "sum" or "norm".')
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        return mol_vecs  # num_molecules x hidden
class MPN(nn.Module):
    """An :class:`MPN` is a wrapper around :class:`MPNEncoder` which featurizes input as needed."""

    def __init__(self,
                 args: TrainArgs,
                 atom_fdim: int = None,
                 bond_fdim: int = None):
        """
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        :param atom_fdim: Atom feature vector dimension. Computed from :code:`args` when not given.
        :param bond_fdim: Bond feature vector dimension. Computed from :code:`args` when not given.
        """
        super(MPN, self).__init__()
        self.reaction = args.reaction
        self.reaction_solvent = args.reaction_solvent
        # Use args.reaction unless it is exactly False, in which case fall back to args.reaction_solvent.
        is_reaction = self.reaction if self.reaction is not False else self.reaction_solvent
        self.atom_fdim = atom_fdim or get_atom_fdim(overwrite_default_atom=args.overwrite_default_atom_features,
                                                     is_reaction=is_reaction)
        self.bond_fdim = bond_fdim or get_bond_fdim(overwrite_default_atom=args.overwrite_default_atom_features,
                                                    overwrite_default_bond=args.overwrite_default_bond_features,
                                                    atom_messages=args.atom_messages,
                                                    is_reaction=is_reaction)
        self.features_only = args.features_only
        self.use_input_features = args.use_input_features
        self.device = args.device
        self.atom_descriptors = args.atom_descriptors
        self.bond_descriptors = args.bond_descriptors
        self.overwrite_default_atom_features = args.overwrite_default_atom_features
        self.overwrite_default_bond_features = args.overwrite_default_bond_features

        if self.features_only:
            # No encoder is built: forward only returns the input features.
            return

        if not self.reaction_solvent:
            if args.mpn_shared:
                # List multiplication repeats the *same* encoder instance, so all molecules share weights.
                self.encoder = nn.ModuleList([MPNEncoder(args, self.atom_fdim, self.bond_fdim)]
                                             * args.number_of_molecules)
            else:
                self.encoder = nn.ModuleList([MPNEncoder(args, self.atom_fdim, self.bond_fdim)
                                              for _ in range(args.number_of_molecules)])
        else:
            self.encoder = MPNEncoder(args, self.atom_fdim, self.bond_fdim)
            # Set separate atom_fdim and bond_fdim for solvent molecules
            self.atom_fdim_solvent = get_atom_fdim(overwrite_default_atom=args.overwrite_default_atom_features,
                                                   is_reaction=False)
            self.bond_fdim_solvent = get_bond_fdim(overwrite_default_atom=args.overwrite_default_atom_features,
                                                   overwrite_default_bond=args.overwrite_default_bond_features,
                                                   atom_messages=args.atom_messages,
                                                   is_reaction=False)
            self.encoder_solvent = MPNEncoder(args, self.atom_fdim_solvent, self.bond_fdim_solvent,
                                              args.hidden_size_solvent, args.bias_solvent, args.depth_solvent)

    def _featurize(self,
                   batch: List[list],
                   atom_features_batch: List[np.ndarray],
                   bond_features_batch: List[np.ndarray]) -> List[BatchMolGraph]:
        """Convert per-position groups of molecules into one :class:`BatchMolGraph` per position."""
        # TODO: handle atom_descriptors_batch with multiple molecules per input
        uses_feature_descriptors = self.atom_descriptors == 'feature' or self.bond_descriptors == 'feature'
        if uses_feature_descriptors and len(batch) > 1:
            raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                      'per input (i.e., number_of_molecules = 1).')

        if self.atom_descriptors == 'feature':
            return [
                mol2graph(
                    mols=b,
                    atom_features_batch=atom_features_batch,
                    bond_features_batch=bond_features_batch,
                    overwrite_default_atom_features=self.overwrite_default_atom_features,
                    overwrite_default_bond_features=self.overwrite_default_bond_features
                )
                for b in batch
            ]
        if self.bond_descriptors == 'feature':
            return [
                mol2graph(
                    mols=b,
                    bond_features_batch=bond_features_batch,
                    overwrite_default_atom_features=self.overwrite_default_atom_features,
                    overwrite_default_bond_features=self.overwrite_default_bond_features
                )
                for b in batch
            ]
        return [mol2graph(b) for b in batch]

    def forward(self,
                batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
                features_batch: List[np.ndarray] = None,
                atom_descriptors_batch: List[np.ndarray] = None,
                atom_features_batch: List[np.ndarray] = None,
                bond_descriptors_batch: List[np.ndarray] = None,
                bond_features_batch: List[np.ndarray] = None) -> torch.Tensor:
        """
        Encodes a batch of molecules.

        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.
                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :param atom_features_batch: A list of numpy arrays containing additional atom features.
        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.
        :param bond_features_batch: A list of numpy arrays containing additional bond features.
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
        :raises NotImplementedError: If atom/bond descriptors are used with more than one molecule per input.
        """
        # isinstance (rather than an exact type comparison) also accepts BatchMolGraph subclasses.
        if not isinstance(batch[0], BatchMolGraph):
            # Group first molecules, second molecules, etc for mol2graph
            batch = [[mols[i] for mols in batch] for i in range(len(batch[0]))]
            batch = self._featurize(batch, atom_features_batch, bond_features_batch)

        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float().to(self.device)

            if self.features_only:
                return features_batch

        if self.atom_descriptors == 'descriptor' or self.bond_descriptors == 'descriptor':
            if len(batch) > 1:
                raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                          'per input (i.e., number_of_molecules = 1).')

            encodings = [enc(ba, atom_descriptors_batch, bond_descriptors_batch)
                         for enc, ba in zip(self.encoder, batch)]
        elif not self.reaction_solvent:
            encodings = [enc(ba) for enc, ba in zip(self.encoder, batch)]
        else:
            # Reaction graphs go to the main encoder; all other graphs go to the solvent encoder.
            encodings = [self.encoder(ba) if ba.is_reaction else self.encoder_solvent(ba) for ba in batch]

        output = encodings[0] if len(encodings) == 1 else torch.cat(encodings, dim=1)

        if self.use_input_features:
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(1, -1)

            output = torch.cat([output, features_batch], dim=1)

        return output