Chemprop MPNN models
[1]:
from chemprop.models.model import MPNN
Composition
A Chemprop MPNN model is made up of several submodules: a message passing layer, an aggregation layer, an optional batch normalization layer, and a predictor feed-forward network (FFN). MPNN also defines the training and prediction logic that Lightning uses when a Chemprop model is run within the Lightning framework.
[2]:
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
mp = BondMessagePassing()
agg = NormAggregation()
ffn = RegressionFFN()
basic_model = MPNN(mp, agg, ffn)
basic_model
[2]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
(agg): NormAggregation()
(bn): Identity()
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): Identity()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
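To make the data flow between these submodules concrete, here is a minimal pure-Python sketch of the forward pass. Message passing produces one vector per atom, aggregation pools them into one vector per molecule, batch normalization (here the default Identity) is applied, and the predictor maps each molecule vector to an output. Every function below is a hypothetical stand-in rather than a Chemprop class, and the aggregation assumes that NormAggregation sums the atom vectors of each molecule and divides by a constant norm of 100.

```python
# Toy sketch of the MPNN forward pass; all functions are hypothetical stand-ins.

def message_passing(atom_features):
    # Stand-in: return one hidden vector per atom (copied unchanged here).
    return [list(v) for v in atom_features]

def norm_aggregation(atom_hidden, batch_index, n_mols, norm=100.0):
    # Sum the atom vectors of each molecule, then divide by a constant norm
    # (assumed to be 100, mirroring NormAggregation's default).
    dim = len(atom_hidden[0])
    mol_vecs = [[0.0] * dim for _ in range(n_mols)]
    for vec, mol in zip(atom_hidden, batch_index):
        for j, x in enumerate(vec):
            mol_vecs[mol][j] += x
    return [[x / norm for x in v] for v in mol_vecs]

def batch_norm(mol_vecs):
    # Identity, matching the default (bn): Identity() in the printout.
    return mol_vecs

def predictor(mol_vecs, weights, bias):
    # Stand-in single-output linear layer: weights . v + bias per molecule.
    return [sum(w * x for w, x in zip(weights, v)) + bias for v in mol_vecs]

def mpnn_forward(atom_features, batch_index, n_mols, weights, bias):
    H_v = message_passing(atom_features)            # per-atom vectors
    H = norm_aggregation(H_v, batch_index, n_mols)  # per-molecule vectors
    H = batch_norm(H)
    return predictor(H, weights, bias)              # one output per molecule

# Three molecules with 1, 2 and 3 heavy atoms (like C, CC, CCC), 2 features each.
atoms = [[1.0, 0.0]] * 6
batch_index = [0, 1, 1, 2, 2, 2]  # which molecule each atom belongs to
result = mpnn_forward(atoms, batch_index, 3, weights=[100.0, 0.0], bias=0.0)
```

With these toy inputs the outputs grow with molecule size (approximately 1.0, 2.0 and 3.0), which shows how a sum-based aggregation keeps information about the number of atoms.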
Batch normalization
Batch normalization can improve training by keeping the inputs to the FFN on a consistent scale and centered around zero. It is off by default but can be turned on by passing batch_norm=True.
[3]:
MPNN(mp, agg, ffn, batch_norm=True)
[3]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
(agg): NormAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): Identity()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
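As a rough illustration of what the BatchNorm1d layer above does in training mode, the sketch below standardizes each feature (column) of a batch of molecule vectors to zero mean and unit variance, using the printed eps=1e-05, and then applies a scale gamma and shift beta (PyTorch initializes these to 1 and 0). This is a simplified pure-Python sketch: it ignores the running statistics that track_running_stats=True maintains for evaluation mode, and the input values are made up.

```python
import math

def batch_norm_train(batch, eps=1e-5, gamma=1.0, beta=0.0):
    """Standardize each column over the batch (training-mode batch norm),
    then apply the affine scale gamma and shift beta."""
    n = len(batch)
    dim = len(batch[0])
    out = [[0.0] * dim for _ in range(n)]
    for j in range(dim):
        col = [row[j] for row in batch]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n  # biased (population) variance
        for i, x in enumerate(col):
            out[i][j] = gamma * (x - mean) / math.sqrt(var + eps) + beta
    return out

# Hypothetical fingerprints for 3 molecules with 2 features on very different scales.
fingerprints = [[3.0, 100.0], [5.0, 200.0], [7.0, 300.0]]
normalized = batch_norm_train(fingerprints)
```

After normalization both columns have mean close to 0 and variance close to 1, even though the raw features differ in scale by about two orders of magnitude.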
Optimizer
MPNN also configures the optimizer that Lightning uses during training: torch.optim.Adam combined with a Noam learning rate scheduler (defined in chemprop.scheduler.NoamLR). The following parameters are customizable:
- warmup_epochs: the number of warmup epochs, defaults to 2
- init_lr: the initial learning rate, defaults to \(10^{-4}\)
- max_lr: the maximum learning rate, defaults to \(10^{-3}\)
- final_lr: the final learning rate, defaults to \(10^{-4}\)
[4]:
model = MPNN(mp, agg, ffn, warmup_epochs=5, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)
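A minimal sketch of a Noam-like schedule, assuming the scheduler is stepped once per batch: the learning rate rises linearly from init_lr to max_lr over the warmup steps, then decays exponentially to final_lr by the last step. The function name noam_like_lr and the steps_per_epoch and total epoch values are hypothetical, and the sketch illustrates the general shape of such a schedule rather than reproducing chemprop.scheduler.NoamLR exactly.

```python
def noam_like_lr(step, warmup_steps, total_steps, init_lr, max_lr, final_lr):
    """Learning rate at a given step (assumes total_steps > warmup_steps)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup: init_lr at step 0, approaching max_lr at warmup_steps.
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    # Exponential decay: max_lr at warmup_steps, final_lr at total_steps.
    gamma = (final_lr / max_lr) ** (1 / (total_steps - warmup_steps))
    return max_lr * gamma ** (step - warmup_steps)

# Hypothetical run: 10 batches per epoch, 50 epochs, and the values from the cell above.
steps_per_epoch = 10
warmup_steps = 5 * steps_per_epoch   # warmup_epochs=5
total_steps = 50 * steps_per_epoch
lrs = [
    noam_like_lr(s, warmup_steps, total_steps, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)
    for s in range(total_steps + 1)
]
```

The warmup keeps early updates small while the randomly initialized weights settle; the long exponential tail then lets training fine-tune at a progressively smaller step size.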
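Metrics
During the validation and testing loops, Lightning uses the metrics stored in MPNN to evaluate the model's current performance. By default, MPNN uses a metric determined by the type of predictor; for the RegressionFFN used here this is MSE, as shown in the metrics entry of the printed model. A different list of metrics can be passed through the metrics argument to be used instead.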
[5]:
from chemprop.nn import metrics
metrics_list = [metrics.RMSE(), metrics.MAE()]
model = MPNN(mp, agg, ffn, metrics=metrics_list)
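For intuition about the two metrics passed above, the sketch below computes unweighted, single-task RMSE and MAE in plain Python. It is an illustrative assumption, not Chemprop's implementation, which also handles details such as the task weights seen in the printed metrics that this sketch ignores. The predictions and targets are made up.

```python
import math

def rmse(preds, targets):
    """Root-mean-square error over paired predictions and targets."""
    n = len(preds)
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / n)

def mae(preds, targets):
    """Mean absolute error over paired predictions and targets."""
    n = len(preds)
    return sum(abs(p - t) for p, t in zip(preds, targets)) / n

preds = [1.0, 2.0, 3.0]
targets = [1.0, 2.0, 5.0]
scores = (rmse(preds, targets), mae(preds, targets))
```

Because RMSE squares errors before averaging, the single error of 2 dominates it (about 1.155), while MAE weights all errors equally (about 0.667).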
Fingerprinting and encoding
MPNN has two helper methods, fingerprint and encoding, that return the hidden representations at different parts of the model. The fingerprint is the learned representation produced by the message passing layer after aggregation and batch normalization. The encoding is the hidden representation after the first i layers of the predictor, where i is passed as an argument. See the predictor notebook for more details. Note that the 0th encoding (i=0) is equivalent to the fingerprint.
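The relationship between the two methods can be sketched with a toy predictor that mirrors the two Sequential blocks in the printed MLP. Assuming encoding(i) applies only the first i blocks of the predictor to the fingerprint, i=0 applies none of them and returns the fingerprint unchanged. The weights, the fingerprint stand-in, and the helper names below are hypothetical, and dropout is omitted because it is inactive (p=0.0) in the printed model.

```python
def relu(v):
    return [max(0.0, x) for x in v]

def linear(weights, bias):
    # Returns a function computing weights @ v + bias for a single vector v.
    def apply(v):
        return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]
    return apply

# Hypothetical 2-block predictor mirroring the printed MLP structure:
# block 0 is a Linear layer; block 1 is ReLU followed by a Linear layer to 1 output.
block0 = linear([[1.0, -1.0], [0.5, 2.0]], [0.0, 0.0])
final = linear([[1.0, 1.0]], [0.0])
blocks = [block0, lambda v: final(relu(v))]

def fingerprint(mol_vec):
    # Stand-in: in Chemprop this is message passing + aggregation + batch norm.
    return mol_vec

def encoding(mol_vec, i):
    """Hidden representation after the first i predictor blocks."""
    h = fingerprint(mol_vec)
    for block in blocks[:i]:  # i=0 -> empty slice -> fingerprint unchanged
        h = block(h)
    return h

x = [2.0, 1.0]
```

Here encoding(x, 1) is the 2-dimensional output of the first block, just as encoding(..., i=1) keeps the 300 hidden dimensions in the notebook, while applying both blocks yields the single predicted value.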
The cell below builds an example batch for the model. See the data notebooks for more details.
[6]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset
from chemprop.data import build_dataloader
smis = ["C" * i for i in range(1, 4)]
ys = np.random.rand(len(smis), 1)
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
dataloader = build_dataloader(dataset)
batch = next(iter(dataloader))
bmg, V_d, X_d, *_ = batch
[ ]:
basic_model(bmg, V_d, X_d)
tensor([[0.0333],
[0.0331],
[0.0332]], grad_fn=<AddmmBackward0>)
[8]:
basic_model.fingerprint(bmg, V_d, X_d).shape
[8]:
torch.Size([3, 300])
[9]:
basic_model.encoding(bmg, V_d, X_d, i=1).shape
[9]:
torch.Size([3, 300])
[10]:
(basic_model.fingerprint(bmg, V_d, X_d) == basic_model.encoding(bmg, V_d, X_d, i=0)).all()
[10]:
tensor(True)