chemprop.models.multi
Module Contents
Classes
- class chemprop.models.multi.MulticomponentMPNN(message_passing, agg, predictor, batch_norm=True, metrics=None, warmup_epochs=2, init_lr=0.0001, max_lr=0.001, final_lr=0.0001, X_d_transform=None)
Bases: chemprop.models.model.MPNN
An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine. The first two modules calculate learned fingerprints from an input molecule or reaction graph, and the final module takes these learned fingerprints as input to calculate a final prediction, i.e., the model computes the following operation:
\[\mathtt{MPNN}(\mathcal{G}) = \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))\]
The full model is trained end-to-end.
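The composition above can be illustrated with a toy, dependency-free sketch. Everything here is hypothetical: a "graph" is reduced to a list of per-atom feature vectors, `toy_message_passing` merely rescales features instead of passing messages, `sum_aggregation` stands in for an Aggregation, and `make_linear_predictor` stands in for a Predictor. It only shows how the three stages chain together, not how Chemprop implements them.

```python
from typing import Callable, List

Vector = List[float]
Graph = List[Vector]  # hypothetical: a graph reduced to its per-atom feature vectors


def toy_message_passing(graph: Graph) -> Graph:
    # Hypothetical stand-in for a MessagePassing block: it only rescales each
    # atom's features and does not exchange messages along bonds.
    return [[2.0 * x for x in atom] for atom in graph]


def sum_aggregation(node_vectors: Graph) -> Vector:
    # Collapse the per-atom vectors into one molecule-level fingerprint.
    return [sum(column) for column in zip(*node_vectors)]


def make_linear_predictor(weights: Vector, bias: float) -> Callable[[Vector], float]:
    # Hypothetical single-output linear head standing in for a Predictor.
    def predictor(fingerprint: Vector) -> float:
        return sum(w * h for w, h in zip(weights, fingerprint)) + bias

    return predictor


def mpnn(
    graph: Graph,
    message_passing: Callable[[Graph], Graph],
    agg: Callable[[Graph], Vector],
    predictor: Callable[[Vector], float],
) -> float:
    # MPNN(G) = predictor(agg(message_passing(G)))
    return predictor(agg(message_passing(graph)))


graph = [[1.0, 0.0], [0.5, 1.0]]  # two atoms, two features each
head = make_linear_predictor([1.0, -1.0], bias=0.5)
print(mpnn(graph, toy_message_passing, sum_aggregation, head))  # 1.5
```

Because each stage is differentiable in the real model, gradients of the prediction loss flow back through the predictor, the aggregation and the message passing layers alike, which is what "trained end-to-end" refers to.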
- Parameters:
message_passing (MessagePassing) – the message passing block to use to calculate learned fingerprints
agg (Aggregation) – the aggregation operation to use during molecule-level prediction
predictor (Predictor) – the function to use to calculate the final prediction
batch_norm (bool, default=True) – if True, apply batch normalization to the output of the aggregation operation
metrics (Iterable[Metric] | None, default=None) – the metrics to use to evaluate the model during training and evaluation
warmup_epochs (int, default=2) – the number of epochs to use for the learning rate warmup
init_lr (float, default=1e-4) – the initial learning rate
max_lr (float, default=1e-3) – the maximum learning rate
final_lr (float, default=1e-4) – the final learning rate
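X_d_transform (chemprop.nn.transforms.ScaleTransform | None, default=None) – the transform to apply to the extra datapoint descriptors X_d before they are concatenated to the learned fingerprint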
- Raises:
ValueError – if the output dimension of the message passing block does not match the input dimension of the predictor function
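The four learning-rate parameters (warmup_epochs, init_lr, max_lr, final_lr) suggest a warmup-then-decay schedule. The sketch below assumes a Noam-style shape: a linear ramp from init_lr to max_lr over the warmup steps, then exponential decay to final_lr over the remaining steps. That shape, the per-step granularity, and the helper name `noam_like_lr` are assumptions for illustration; check Chemprop's own scheduler code for the exact formula.

```python
def noam_like_lr(
    step: int,
    warmup_steps: int,
    total_steps: int,
    init_lr: float = 1e-4,
    max_lr: float = 1e-3,
    final_lr: float = 1e-4,
) -> float:
    """Learning rate at a given optimizer step under an assumed Noam-style schedule."""
    if step < warmup_steps:
        # Linear warmup from init_lr up to max_lr.
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    if step < total_steps:
        # Exponential decay from max_lr down to final_lr over the remaining steps.
        cooldown_steps = total_steps - warmup_steps
        progress = (step - warmup_steps) / cooldown_steps
        return max_lr * (final_lr / max_lr) ** progress
    # Hold at final_lr once the schedule is exhausted.
    return final_lr


# Hypothetical run: 10 optimizer steps per epoch, 10 epochs, warmup_epochs=2.
steps_per_epoch, epochs, warmup_epochs = 10, 10, 2
warmup_steps = warmup_epochs * steps_per_epoch
total_steps = epochs * steps_per_epoch
for step in (0, 10, 20, 60, 100):
    print(step, noam_like_lr(step, warmup_steps, total_steps))
```

With the defaults, the rate starts at 1e-4, peaks at 1e-3 when warmup ends (step 20 here), and decays back to 1e-4 by the last step.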
- fingerprint(bmgs, V_ds, X_d=None)
the learned fingerprints for the input molecules
- Parameters:
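bmgs (Iterable[chemprop.data.BatchMolGraph]) – the batched molecular graphs, one per component
V_ds (Iterable[torch.Tensor | None]) – the extra atom descriptors for each component, or None for a component without them
X_d (torch.Tensor | None) – the extra datapoint-level descriptors to append to the fingerprint, if any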
- Return type:
torch.Tensor
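For a multicomponent input, a natural reading of fingerprint is: aggregate each component's atom vectors separately, concatenate the per-component vectors in order, then append the (transformed) datapoint descriptors. The pure-Python sketch below assumes exactly that for a single datapoint rather than a batched tensor; it omits batch normalization, and `mean_aggregation`, `multicomponent_fingerprint`, and the divide-by-10 lambda standing in for a ScaleTransform are all hypothetical.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def mean_aggregation(node_vectors: Sequence[Vector]) -> Vector:
    # Average one component's per-atom vectors (assumes at least one atom).
    n = len(node_vectors)
    return [sum(column) / n for column in zip(*node_vectors)]


def identity(v: Vector) -> Vector:
    return v


def multicomponent_fingerprint(
    component_node_vectors: Sequence[Sequence[Vector]],
    x_d: Optional[Vector] = None,
    x_d_transform: Callable[[Vector], Vector] = identity,
) -> Vector:
    # Aggregate each component separately, then concatenate in component order.
    fingerprint: Vector = []
    for node_vectors in component_node_vectors:
        fingerprint.extend(mean_aggregation(node_vectors))
    # Datapoint-level descriptors, if given, are transformed and appended last.
    if x_d is not None:
        fingerprint.extend(x_d_transform(x_d))
    return fingerprint


solvent = [[1.0, 3.0], [3.0, 5.0]]  # 2 atoms x 2 features (hypothetical)
solute = [[0.0, 1.0, 2.0]]          # 1 atom x 3 features (hypothetical)
fp = multicomponent_fingerprint(
    [solvent, solute], x_d=[10.0], x_d_transform=lambda v: [x / 10.0 for x in v]
)
print(fp)       # [2.0, 4.0, 0.0, 1.0, 2.0, 1.0]
print(len(fp))  # 6 = 2 + 3 + 1
```

Under this reading, the fingerprint width is the sum of the per-component output widths plus the descriptor width, so the predictor's input dimension has to match that total; a mismatch is the kind of error the ValueError above guards against.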