from typing import List, Union, Tuple
import numpy as np
from rdkit import Chem
import torch
import torch.nn as nn
from .mpn import MPN
from .ffn import build_ffn, MultiReadout
from chemprop.args import TrainArgs
from chemprop.features import BatchMolGraph
from chemprop.nn_utils import initialize_weights
class MoleculeModel(nn.Module):
    """A :class:`MoleculeModel` is a model which contains a message passing network followed by feed-forward layers."""

    def __init__(self, args: TrainArgs):
        """
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        """
        super(MoleculeModel, self).__init__()

        self.classification = args.dataset_type == "classification"
        self.multiclass = args.dataset_type == "multiclass"
        self.loss_function = args.loss_function

        # ``args`` may not define ``train_class_sizes``; fall back to None in that case.
        self.train_class_sizes = getattr(args, "train_class_sizes", None)

        # When using cross entropy losses, no sigmoid or softmax is applied during training
        # (the loss consumes raw logits), but they are needed for e.g. the MCC loss.
        # Always defined (False for non-classification datasets) so the attribute never goes missing.
        self.no_training_normalization = (self.classification or self.multiclass) and args.loss_function in [
            "cross_entropy",
            "binary_cross_entropy",
        ]

        self.is_atom_bond_targets = args.is_atom_bond_targets
        if self.is_atom_bond_targets:
            self.atom_targets, self.bond_targets = args.atom_targets, args.bond_targets
            self.atom_constraints, self.bond_constraints = (
                args.atom_constraints,
                args.bond_constraints,
            )
            self.adding_bond_types = args.adding_bond_types

        # Number of raw readout outputs per task (multiplied by num_tasks for molecule-level targets).
        self.relative_output_size = 1
        if self.multiclass:
            self.relative_output_size *= args.multiclass_num_classes
        if self.loss_function == "mve":
            self.relative_output_size *= 2  # return means and variances
        if self.loss_function == "dirichlet" and self.classification:
            self.relative_output_size *= 2  # return dirichlet parameters for positive and negative class
        if self.loss_function == "evidential":
            self.relative_output_size *= 4  # return four evidential parameters: gamma, lambda, alpha, beta

        if self.classification:
            self.sigmoid = nn.Sigmoid()
        if self.multiclass:
            self.multiclass_softmax = nn.Softmax(dim=2)
        if self.loss_function in ["mve", "evidential", "dirichlet"]:
            self.softplus = nn.Softplus()

        self.create_encoder(args)
        self.create_ffn(args)

        initialize_weights(self)

    def create_encoder(self, args: TrainArgs) -> None:
        """
        Creates the message passing encoder for the model.

        When a frozen checkpoint is given (``args.checkpoint_frzn``), either only the first
        sub-encoder (``args.freeze_first_only``) or every encoder parameter is frozen.

        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        """
        self.encoder = MPN(args)

        if args.checkpoint_frzn is not None:
            if args.freeze_first_only:  # Freeze only the first encoder
                for param in list(self.encoder.encoder.children())[0].parameters():
                    param.requires_grad = False
            else:  # Freeze all encoders
                for param in self.encoder.parameters():
                    param.requires_grad = False

    def create_ffn(self, args: TrainArgs) -> None:
        """
        Creates the feed-forward layers for the model.

        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        """
        self.multiclass = args.dataset_type == "multiclass"
        if self.multiclass:
            self.num_classes = args.multiclass_num_classes

        # Width of the encoder output that feeds the first linear layer.
        if args.features_only:
            first_linear_dim = args.features_size
        else:
            if args.reaction_solvent:
                first_linear_dim = args.hidden_size + args.hidden_size_solvent
            else:
                first_linear_dim = args.hidden_size * args.number_of_molecules
            if args.use_input_features:
                first_linear_dim += args.features_size

        # Descriptor-style atom/bond inputs widen the corresponding first layer.
        if args.atom_descriptors == "descriptor":
            atom_first_linear_dim = first_linear_dim + args.atom_descriptors_size
        else:
            atom_first_linear_dim = first_linear_dim

        if args.bond_descriptors == "descriptor":
            bond_first_linear_dim = first_linear_dim + args.bond_descriptors_size
        else:
            bond_first_linear_dim = first_linear_dim

        # Create FFN layers
        if self.is_atom_bond_targets:
            self.readout = MultiReadout(
                atom_features_size=atom_first_linear_dim,
                bond_features_size=bond_first_linear_dim,
                atom_hidden_size=args.ffn_hidden_size + args.atom_descriptors_size,
                bond_hidden_size=args.ffn_hidden_size + args.bond_descriptors_size,
                num_layers=args.ffn_num_layers,
                output_size=self.relative_output_size,
                dropout=args.dropout,
                activation=args.activation,
                atom_constraints=args.atom_constraints,
                bond_constraints=args.bond_constraints,
                shared_ffn=args.shared_atom_bond_ffn,
                weights_ffn_num_layers=args.weights_ffn_num_layers,
            )
        else:
            self.readout = build_ffn(
                first_linear_dim=atom_first_linear_dim,
                hidden_size=args.ffn_hidden_size + args.atom_descriptors_size,
                num_layers=args.ffn_num_layers,
                output_size=self.relative_output_size * args.num_tasks,
                dropout=args.dropout,
                activation=args.activation,
                dataset_type=args.dataset_type,
                spectra_activation=args.spectra_activation,
            )

        # Optionally freeze the leading FFN layers when fine-tuning from a frozen checkpoint.
        if args.checkpoint_frzn is not None and args.frzn_ffn_layers > 0:
            num_frozen = args.frzn_ffn_layers
            if not self.is_atom_bond_targets:
                self._freeze_leading_layers(self.readout, num_frozen)
            elif args.shared_atom_bond_ffn:
                self._freeze_leading_layers(self.readout.atom_ffn_base, num_frozen)
                self._freeze_leading_layers(self.readout.bond_ffn_base, num_frozen)
            else:
                for ffn in self.readout.ffn_list:
                    # Constrained readouts keep their layers in ``ffn``; the others in ``ffn_readout``.
                    target = ffn.ffn if ffn.constraint else ffn.ffn_readout
                    self._freeze_leading_layers(target, num_frozen)

    @staticmethod
    def _freeze_leading_layers(module: nn.Module, num_layers: int) -> None:
        """
        Disables gradients for the first ``2 * num_layers`` parameters of ``module``.

        Each layer is assumed to contribute a weight and a bias, so this freezes the weights and
        biases of the first ``num_layers`` layers.

        :param module: The module whose leading parameters should be frozen.
        :param num_layers: Number of leading layers to freeze.
        """
        for param in list(module.parameters())[0 : 2 * num_layers]:
            param.requires_grad = False

    def fingerprint(
        self,
        batch: Union[
            List[List[str]],
            List[List[Chem.Mol]],
            List[List[Tuple[Chem.Mol, Chem.Mol]]],
            List[BatchMolGraph],
        ],
        features_batch: List[np.ndarray] = None,
        atom_descriptors_batch: List[np.ndarray] = None,
        atom_features_batch: List[np.ndarray] = None,
        bond_descriptors_batch: List[np.ndarray] = None,
        bond_features_batch: List[np.ndarray] = None,
        fingerprint_type: str = "MPN",
    ) -> torch.Tensor:
        """
        Encodes the latent representations of the input molecules from intermediate stages of the model.

        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.
                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :param atom_features_batch: A list of numpy arrays containing additional atom features.
        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.
        :param bond_features_batch: A list of numpy arrays containing additional bond features.
        :param fingerprint_type: The choice of which type of latent representation to return as the molecular fingerprint.
                                 Currently supported: ``MPN`` for the output of the MPNN portion of the model, or
                                 ``last_FFN`` for the input to the final readout layer.
        :return: The latent fingerprint vectors.
        :raises ValueError: If ``fingerprint_type`` is not supported.
        """
        # Validate before encoding so an unsupported type fails fast without running the encoder.
        if fingerprint_type not in ("MPN", "last_FFN"):
            raise ValueError(f"Unsupported fingerprint type {fingerprint_type}.")

        encodings = self.encoder(
            batch,
            features_batch,
            atom_descriptors_batch,
            atom_features_batch,
            bond_descriptors_batch,
            bond_features_batch,
        )
        if fingerprint_type == "MPN":
            return encodings
        # last_FFN: run every readout layer except the final one.
        return self.readout[:-1](encodings)

    def forward(
        self,
        batch: Union[
            List[List[str]],
            List[List[Chem.Mol]],
            List[List[Tuple[Chem.Mol, Chem.Mol]]],
            List[BatchMolGraph],
        ],
        features_batch: List[np.ndarray] = None,
        atom_descriptors_batch: List[np.ndarray] = None,
        atom_features_batch: List[np.ndarray] = None,
        bond_descriptors_batch: List[np.ndarray] = None,
        bond_features_batch: List[np.ndarray] = None,
        constraints_batch: List[torch.Tensor] = None,
        bond_types_batch: List[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Runs the :class:`MoleculeModel` on input.

        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.
                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :param atom_features_batch: A list of numpy arrays containing additional atom features.
        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.
        :param bond_features_batch: A list of numpy arrays containing additional bond features.
        :param constraints_batch: A list of PyTorch tensors which applies constraint on atomic/bond properties.
        :param bond_types_batch: A list of PyTorch tensors storing bond types of each bond determined by RDKit molecules.
        :return: The output of the :class:`MoleculeModel`, containing a list of property predictions.
                 For atom/bond targets this is a list of tensors, otherwise a single tensor.
        """
        encodings = self.encoder(
            batch,
            features_batch,
            atom_descriptors_batch,
            atom_features_batch,
            bond_descriptors_batch,
            bond_features_batch,
        )
        if self.is_atom_bond_targets:
            # The multi-readout output is handled below as a sequence of tensors.
            output = self.readout(encodings, constraints_batch, bond_types_batch)
        else:
            output = self.readout(encodings)

        # Don't apply sigmoid during training when using BCEWithLogitsLoss
        if (
            self.classification
            and not (self.training and self.no_training_normalization)
            and self.loss_function != "dirichlet"
        ):
            output = self._map_outputs(output, self.sigmoid)

        if self.multiclass:
            # batch size x num targets x num classes per target
            output = output.reshape((output.shape[0], -1, self.num_classes))
            # Probabilities during evaluation, but not during training when using CrossEntropyLoss.
            if (
                not (self.training and self.no_training_normalization)
                and self.loss_function != "dirichlet"
            ):
                output = self.multiclass_softmax(output)

        # Post-process the outputs of multi-parameter loss functions.
        loss_transforms = {
            "mve": self._mve_transform,
            "evidential": self._evidential_transform,
            "dirichlet": self._dirichlet_transform,
        }
        transform = loss_transforms.get(self.loss_function)
        if transform is not None:
            output = self._map_outputs(output, transform)

        return output

    def _map_outputs(self, output, transform):
        """
        Applies ``transform`` to each tensor of a multi-readout (atom/bond) output, or to the single output tensor.

        :param output: A list of tensors for atom/bond targets, otherwise a single tensor.
        :param transform: A callable mapping a tensor to a tensor.
        :return: A list of transformed tensors or a single transformed tensor, mirroring ``output``.
        """
        if self.is_atom_bond_targets:
            return [transform(x) for x in output]
        return transform(output)

    def _mve_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Splits ``x`` along dim 1 into means and variances, and makes the variances positive with softplus."""
        means, variances = torch.split(x, x.shape[1] // 2, dim=1)
        return torch.cat([means, self.softplus(variances)], dim=1)

    def _evidential_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Splits ``x`` along dim 1 into the four evidential parameters and constrains lambda, alpha and beta."""
        means, lambdas, alphas, betas = torch.split(x, x.shape[1] // 4, dim=1)
        lambdas = self.softplus(lambdas)  # + min_val
        alphas = self.softplus(alphas) + 1  # + min_val # add 1 for numerical constraints of Gamma function
        betas = self.softplus(betas)  # + min_val
        return torch.cat([means, lambdas, alphas, betas], dim=1)

    @staticmethod
    def _dirichlet_transform(x: torch.Tensor) -> torch.Tensor:
        """Maps raw outputs to Dirichlet parameters via ``softplus(x) + 1``."""
        return nn.functional.softplus(x) + 1