Scaling inputs and outputs
[1]:
import torch
from chemprop.models import MPNN
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
from chemprop.nn.transforms import ScaleTransform, UnscaleTransform, GraphTransform
This is an example dataset with extra atom and bond features, extra atom descriptors, and extra datapoint descriptors.
[2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset
smis = ["CC", "CN", "CO", "CF", "CP", "CS", "CI"]
ys = np.random.rand(len(smis), 1) * 100
n_datapoints = len(smis)
n_atoms = 2
n_bonds = 1
n_extra_atom_features = 3
n_extra_bond_features = 4
n_extra_atom_descriptors = 5
n_extra_datapoint_descriptors = 6
extra_atom_features = np.random.rand(n_datapoints, n_atoms, n_extra_atom_features)
extra_bond_features = np.random.rand(n_datapoints, n_bonds, n_extra_bond_features)
extra_atom_descriptors = np.random.rand(n_datapoints, n_atoms, n_extra_atom_descriptors)
extra_datapoint_descriptors = np.random.rand(n_datapoints, n_extra_datapoint_descriptors)
datapoints = [
    MoleculeDatapoint.from_smi(smi, y, x_d=x_d, V_f=V_f, E_f=E_f, V_d=V_d)
    for smi, y, x_d, V_f, E_f, V_d in zip(
        smis,
        ys,
        extra_datapoint_descriptors,
        extra_atom_features,
        extra_bond_features,
        extra_atom_descriptors,
    )
]
train_dset = MoleculeDataset(datapoints[:3])
val_dset = MoleculeDataset(datapoints[3:5])
test_dset = MoleculeDataset(datapoints[5:])
Scaling targets - FFN
Scaling the target values before training can improve model performance and speed up training. The target scaler should be fit to the training dataset and then applied to the validation dataset. It is not applied to the test dataset. Instead, the scaler is used to make an UnscaleTransform, which is given to the predictor (FFN) layer. During inference, the UnscaleTransform is applied automatically so that predictions are returned in the original units of the targets.
Note that the output_transform is currently saved both in the model’s state_dict and in the model’s hyperparameters. This may change in the future to align with Lightning’s recommendations. You can ignore any messages about this.
[3]:
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)
# test_dset targets not scaled
output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
ffn = RegressionFFN(output_transform=output_transform)
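The arithmetic behind target scaling can be sketched with scikit-learn's StandardScaler on its own. This is an illustration, not Chemprop's code. The random targets and the "predictions" are stand-ins. The sketch also assumes that UnscaleTransform undoes the standardization as y * scale + mean, which is the inverse of the scaler's (y - mean) / scale.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical targets on the same 0-100 range as the example dataset
y_train = rng.random((3, 1)) * 100
y_val = rng.random((2, 1)) * 100

# Fit the scaler on the training targets only
scaler = StandardScaler().fit(y_train)
y_train_scaled = scaler.transform(y_train)
# Reuse the training statistics for the validation targets
y_val_scaled = scaler.transform(y_val)

# A model trained on scaled targets predicts in scaled units; undoing the
# standardization (y * scale + mean) returns values in the original units
preds_scaled = y_val_scaled  # stand-in for model predictions
preds = preds_scaled * scaler.scale_ + scaler.mean_

print(np.allclose(preds, y_val))  # True
print(np.allclose(preds, scaler.inverse_transform(preds_scaled)))  # True
```

The scaler is never refit on validation or test data. This keeps the test set out of the preprocessing statistics.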
Scaling extra atom and bond features - Message Passing
The atom and bond features generated by Chemprop featurizers are either multi-hot or on the order of 1. We recommend scaling extra atom and bond features to also be on the order of 1. Like the target scaler, these scalers are fit to the training data, applied to the validation data, and then saved to the model (in this case the message passing layer) so that they are applied automatically to the test dataset during inference.
[4]:
V_f_scaler = train_dset.normalize_inputs("V_f")
E_f_scaler = train_dset.normalize_inputs("E_f")
val_dset.normalize_inputs("V_f", V_f_scaler)
val_dset.normalize_inputs("E_f", E_f_scaler)
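Atom features are stored per molecule, with one array of shape (n_atoms, n_features) for each datapoint. The sketch below shows one way a single scaler can be fit to such data. It assumes, without confirmation from this notebook, that every atom in the training set is treated as one sample, so the per-molecule arrays are stacked row-wise before fitting. It uses scikit-learn directly rather than normalize_inputs, and the arrays are random stand-ins.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_extra_atom_features = 3
# Hypothetical per-molecule atom-feature arrays, each of shape (n_atoms, n_features)
train_V_fs = [rng.random((2, n_extra_atom_features)) * 50 for _ in range(3)]
val_V_fs = [rng.random((2, n_extra_atom_features)) * 50 for _ in range(2)]

# Assumption: pool all training atoms into one (total_atoms, n_features) matrix
scaler = StandardScaler().fit(np.concatenate(train_V_fs, axis=0))

# Apply the training statistics molecule by molecule; shapes are preserved
train_V_fs_scaled = [scaler.transform(V_f) for V_f in train_V_fs]
val_V_fs_scaled = [scaler.transform(V_f) for V_f in val_V_fs]

stacked = np.concatenate(train_V_fs_scaled, axis=0)
print(stacked.shape)  # (6, 3)
```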
The scalers are used to make ScaleTransforms, which are combined into a GraphTransform that is given to the message passing module. Note that ScaleTransform acts on the whole feature vector, not just the extra features. To ensure that only the extra features are actually scaled, the ScaleTransform’s mean array is padded with zeros and its scale array is padded with ones, with one entry for each default feature. The amount of padding required is therefore the number of default features produced by the featurizer, that is, the total feature dimension minus the extra feature dimension.
[5]:
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
featurizer = SimpleMoleculeMolGraphFeaturizer(
    extra_atom_fdim=n_extra_atom_features, extra_bond_fdim=n_extra_bond_features
)
n_V_features = featurizer.atom_fdim - featurizer.extra_atom_fdim
n_E_features = featurizer.bond_fdim - featurizer.extra_bond_fdim
V_f_transform = ScaleTransform.from_standard_scaler(V_f_scaler, pad=n_V_features)
E_f_transform = ScaleTransform.from_standard_scaler(E_f_scaler, pad=n_E_features)
graph_transform = GraphTransform(V_f_transform, E_f_transform)
mp = BondMessagePassing(graph_transform=graph_transform)
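The effect of the padding can be checked with plain NumPy. This sketch mirrors the description above but is not Chemprop's implementation. It assumes that the extra features come after the default features in each vector, so the zeros and ones are prepended. The feature counts are arbitrary stand-ins.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_default = 4  # hypothetical number of default featurizer features
n_extra = 3

# The scaler is fit on the extra features only
extra_train = rng.random((6, n_extra)) * 50
scaler = StandardScaler().fit(extra_train)

# Pad the mean with zeros and the scale with ones for the default features
mean = np.concatenate([np.zeros(n_default), scaler.mean_])
scale = np.concatenate([np.ones(n_default), scaler.scale_])

# Full feature vectors: multi-hot-like default features, then the extra features
default_part = rng.integers(0, 2, size=(6, n_default)).astype(float)
full = np.concatenate([default_part, extra_train], axis=1)

# (x - 0) / 1 leaves the default features untouched
scaled = (full - mean) / scale
print(np.allclose(scaled[:, :n_default], default_part))  # True
```

Padding with the identity values (0 for the mean, 1 for the scale) lets a single elementwise transform cover the whole vector without slicing it.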
If you only have one of extra atom features or extra bond features, you can set the transform for the unused option to torch.nn.Identity.
[6]:
graph_transform = GraphTransform(V_transform=torch.nn.Identity(), E_transform=E_f_transform)
Scaling extra atom descriptors - Message Passing
The atom descriptors produced by message passing (before aggregation) are also likely to be on the order of 1, so extra atom descriptors should be scaled as well. Unlike above, no padding is needed, because this scaling is applied only to the extra atom descriptors. The ScaleTransform is given to the message passing module so that it is used during inference.
[7]:
V_d_scaler = train_dset.normalize_inputs("V_d")
val_dset.normalize_inputs("V_d", V_d_scaler)
V_d_transform = ScaleTransform.from_standard_scaler(V_d_scaler)
mp = BondMessagePassing(V_d_transform=V_d_transform)
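The following plain-torch sketch shows why no padding is needed here. It standardizes a block of stand-in atom descriptors on its own and then appends that block to stand-in atom hidden states. It assumes, without confirmation from this notebook, that the extra atom descriptors are concatenated to the learned atom representations after message passing. The shapes and values are arbitrary. For brevity, the statistics are computed from the same data, whereas in the notebook they come from the scaler fit on the training set.

```python
import torch

torch.manual_seed(0)
n_atoms, d_h, d_vd = 5, 8, 5

H = torch.randn(n_atoms, d_h)  # stand-in atom hidden states, roughly unit scale
V_d = torch.rand(n_atoms, d_vd) * 1000  # raw extra atom descriptors on a large scale

# The transform only ever sees the descriptor columns, so no padding is required
mean = V_d.mean(dim=0)
scale = ((V_d - mean) ** 2).mean(dim=0).sqrt()  # population standard deviation
V_d_scaled = (V_d - mean) / scale

# After scaling, both blocks are on a comparable scale when concatenated
H_cat = torch.cat([H, V_d_scaled], dim=1)
print(H_cat.shape)  # torch.Size([5, 13])
```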
A GraphTransform and a ScaleTransform can both be given to the same message passing module.
[8]:
mp = BondMessagePassing(graph_transform=graph_transform, V_d_transform=V_d_transform)
Scaling extra datapoint descriptors - MPNN
The molecule/reaction descriptors from message passing (after aggregation) are batch normalized by default to be on the order of 1 (this can be turned off; see the model notebook). We therefore also recommend scaling the extra datapoint-level descriptors. The ScaleTransform for these is given to the MPNN or MulticomponentMPNN module.
[9]:
X_d_scaler = train_dset.normalize_inputs("X_d")
val_dset.normalize_inputs("X_d", X_d_scaler)
X_d_transform = ScaleTransform.from_standard_scaler(X_d_scaler)
chemprop_model = MPNN(
    BondMessagePassing(), NormAggregation(), RegressionFFN(), X_d_transform=X_d_transform
)
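Here is a rough sketch of the datapoint-level case in plain torch. A BatchNorm1d layer normalizes stand-in molecule fingerprints, and separately standardized descriptors are appended to form the FFN input. The sketch assumes that the extra descriptors are concatenated after batch normalization, which is consistent with the paragraph above but is not shown in this notebook. The widths (300 and 6) are arbitrary stand-ins. As before, the descriptor statistics are computed from the same batch only for brevity, whereas in the notebook they come from the training-set scaler.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, d_fp, d_xd = 7, 300, 6

# Stand-in aggregated fingerprints with an arbitrary offset and spread
fingerprints = torch.randn(batch_size, d_fp) * 20 + 5
bn = nn.BatchNorm1d(d_fp)
bn.train()  # normalize with batch statistics, as during training
fp_normed = bn(fingerprints)

# Raw extra datapoint descriptors, standardized column by column
X_d = torch.rand(batch_size, d_xd) * 100
X_d_mean = X_d.mean(dim=0)
X_d_scaled = (X_d - X_d_mean) / ((X_d - X_d_mean) ** 2).mean(dim=0).sqrt()

# FFN input: normalized fingerprint followed by the scaled descriptors
ffn_input = torch.cat([fp_normed, X_d_scaled], dim=1)
print(ffn_input.shape)  # torch.Size([7, 306])
```

Without the descriptor scaling, the six descriptor columns would be on a 0 to 100 scale next to 300 columns of roughly unit scale.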