chemprop.models.multi
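Attributes
- logger
Classes
- MulticomponentMPNN – An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine.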
Module Contents
- chemprop.models.multi.logger
- class chemprop.models.multi.MulticomponentMPNN(message_passing, agg, predictor, batch_norm=False, metrics=None, warmup_epochs=2, init_lr=0.0001, max_lr=0.001, final_lr=0.0001, X_d_transform=None)
Bases: chemprop.models.model.MPNN
An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine. The first two modules calculate learned fingerprints from an input molecule or reaction graph, and the final module takes these learned fingerprints as input to calculate a final prediction, i.e., the following operation:
\[\mathtt{MPNN}(\mathcal{G}) = \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))\]
The full model is trained end-to-end.
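The composition predictor(agg(message_passing(G))) can be illustrated with a toy, pure-Python sketch. `ToyGraph`, the sum-of-neighbors update, the sum aggregation, and the linear predictor are hypothetical stand-ins chosen for illustration; they are not chemprop's `MessagePassing`, `Aggregation`, or `Predictor` implementations, and nothing here is learned.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGraph:
    """Hypothetical graph representation (not chemprop's BatchMolGraph)."""
    features: List[float]        # one scalar feature per node
    adjacency: List[List[int]]   # adjacency[v] lists the neighbors of node v


def message_passing(graph: ToyGraph, depth: int = 2) -> List[float]:
    """Toy update: each step adds the sum of neighbor states to a node's state."""
    h = list(graph.features)
    for _ in range(depth):
        # The comprehension reads the previous step's states before h is reassigned.
        h = [h[v] + sum(h[u] for u in graph.adjacency[v]) for v in range(len(h))]
    return h


def agg(node_states: List[float]) -> float:
    """Sum aggregation: collapse node-level states into one graph-level fingerprint."""
    return sum(node_states)


def predictor(fingerprint: float, weight: float = 0.5, bias: float = 1.0) -> float:
    """A fixed linear map standing in for a trained predictor."""
    return weight * fingerprint + bias


def mpnn(graph: ToyGraph) -> float:
    # MPNN(G) = predictor(agg(message_passing(G)))
    return predictor(agg(message_passing(graph)))
```

For a three-node path graph with features `[1.0, 2.0, 3.0]`, two update steps give node states `[9.0, 14.0, 11.0]`, which sum to `34.0` and map to a prediction of `18.0`.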
- Parameters:
message_passing (MessagePassing) – the message passing block to use to calculate learned fingerprints
agg (Aggregation) – the aggregation operation to use during molecule-level prediction
predictor (Predictor) – the function to use to calculate the final prediction
batch_norm (bool, default=False) – if True, apply batch normalization to the output of the aggregation operation
metrics (Iterable[Metric] | None, default=None) – the metrics to use to evaluate the model during training and evaluation
warmup_epochs (int, default=2) – the number of epochs to use for the learning rate warmup
init_lr (float, default=1e-4) – the initial learning rate
max_lr (float, default=1e-3) – the maximum learning rate
final_lr (float, default=1e-4) – the final learning rate
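X_d_transform (chemprop.nn.transforms.ScaleTransform | None, default=None) – the transform to apply to the extra datapoint descriptors X_d before they are concatenated to the learned fingerprint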
- Raises:
ValueError – if the output dimension of the message passing block does not match the input dimension of the predictor function
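The `warmup_epochs`, `init_lr`, `max_lr`, and `final_lr` parameters describe a learning-rate schedule. The sketch below assumes a Noam-like shape: a linear ramp from `init_lr` to `max_lr` during warmup, then exponential decay toward `final_lr`. This section does not state the exact schedule, and `total_epochs` and `steps_per_epoch` are hypothetical parameters added only to make the example concrete.

```python
def lr_at_step(
    step: int,
    *,
    warmup_epochs: int = 2,
    total_epochs: int = 10,      # hypothetical: total training length
    steps_per_epoch: int = 100,  # hypothetical: optimizer steps per epoch
    init_lr: float = 1e-4,
    max_lr: float = 1e-3,
    final_lr: float = 1e-4,
) -> float:
    """Assumed Noam-like schedule: linear warmup, then exponential decay."""
    if total_epochs <= warmup_epochs:
        raise ValueError("total_epochs must exceed warmup_epochs")
    warmup_steps = warmup_epochs * steps_per_epoch
    cooldown_steps = (total_epochs - warmup_epochs) * steps_per_epoch
    if step < warmup_steps:
        # Linear ramp from init_lr (step 0) up to max_lr (end of warmup).
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    # Fraction of the cooldown completed, clamped to 1 after training ends.
    progress = min(step - warmup_steps, cooldown_steps) / cooldown_steps
    # Geometric interpolation from max_lr (progress 0) to final_lr (progress 1).
    return max_lr * (final_lr / max_lr) ** progress
```

With the defaults, the rate starts at `1e-4`, reaches `1e-3` after two epochs (step 200), and decays back to roughly `1e-4` by step 1000.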
- message_passing: chemprop.nn.MulticomponentMessagePassing
- fingerprint(bmgs, V_ds, X_d=None)
Calculate the learned fingerprints for the input molecules.
- Parameters:
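bmgs (Iterable[chemprop.data.BatchMolGraph]) – the batched molecular graphs, one per component
V_ds (Iterable[torch.Tensor | None]) – the extra atom-level descriptors for each component, or None for a component without them
X_d (torch.Tensor | None, default=None) – the extra datapoint-level descriptors, if any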
- Return type:
torch.Tensor
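For a multicomponent input, `fingerprint` takes one graph (and one optional atom-descriptor tensor) per component. A plausible reading, sketched below in pure Python for a single datapoint, is that each component is aggregated separately, the per-component fingerprints are concatenated, and the optional `X_d` is appended after passing through `X_d_transform`. This is an assumption rather than the library's code: the mean aggregation, the `(mean, std)` standardization standing in for `ScaleTransform`, and both function names are hypothetical, and batch normalization is omitted.

```python
from typing import List, Optional, Sequence, Tuple


def mean_agg(node_states: Sequence[List[float]]) -> List[float]:
    """Hypothetical mean aggregation over the node vectors of one component."""
    n = len(node_states)
    return [sum(column) / n for column in zip(*node_states)]


def multicomponent_fingerprint(
    components: Sequence[Sequence[List[float]]],
    X_d: Optional[List[float]] = None,
    x_d_scale: Optional[Tuple[List[float], List[float]]] = None,
) -> List[float]:
    """Concatenate per-component fingerprints, then append (scaled) X_d if given."""
    fp: List[float] = []
    for node_states in components:
        # One fixed-length fingerprint per component, appended in order.
        fp.extend(mean_agg(node_states))
    if X_d is not None:
        if x_d_scale is not None:
            # Stand-in for X_d_transform: standardize each descriptor.
            mean, std = x_d_scale
            X_d = [(x - m) / s for x, m, s in zip(X_d, mean, std)]
        fp.extend(X_d)
    return fp
```

With two components whose node states are `[[1.0, 2.0], [3.0, 4.0]]` and `[[0.0, 6.0]]`, plus `X_d=[10.0]`, the result is `[2.0, 3.0, 0.0, 6.0, 10.0]`; the fingerprint length grows with the number of components.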
- on_validation_model_eval()
Called when the validation loop starts.
The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().
- Return type:
None
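Overriding `on_validation_model_eval` lets a model choose which submodules leave training mode during validation. The sketch below uses hypothetical `ToyModule` and `ToyModel` classes that mimic the `train()`/`eval()` mode flags of `torch.nn.Module` (no PyTorch or Lightning required). The override puts the whole model in eval mode and then switches a hypothetical `scaler` submodule back to training mode. It illustrates the hook pattern only and does not describe what chemprop's own override does.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module's training-mode bookkeeping."""

    def __init__(self, *children: "ToyModule") -> None:
        self.training = True
        self.children = list(children)

    def train(self, mode: bool = True) -> "ToyModule":
        # Set this module's flag and propagate it to every child.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self) -> "ToyModule":
        return self.train(False)


class ToyModel(ToyModule):
    def __init__(self) -> None:
        self.encoder = ToyModule()  # hypothetical submodule
        self.scaler = ToyModule()   # hypothetical submodule kept in train mode
        super().__init__(self.encoder, self.scaler)

    def on_validation_model_eval(self) -> None:
        # Default behavior would stop after self.eval().
        self.eval()
        # Custom behavior: re-enable training mode for the scaler only.
        self.scaler.train()
```

After calling `on_validation_model_eval()`, the model and its `encoder` report `training == False` while the `scaler` still reports `training == True`.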