chemprop.models.multi

Attributes

logger

Classes

MulticomponentMPNN

An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine.

Module Contents

chemprop.models.multi.logger
class chemprop.models.multi.MulticomponentMPNN(message_passing, agg, predictor, batch_norm=False, metrics=None, warmup_epochs=2, init_lr=0.0001, max_lr=0.001, final_lr=0.0001, X_d_transform=None)

Bases: chemprop.models.model.MPNN

An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine.

The first two modules calculate learned fingerprints from an input molecule or reaction graph, and the final module takes these learned fingerprints as input to calculate a final prediction. That is, the model computes:

\[\mathtt{MPNN}(\mathcal{G}) = \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))\]

The full model is trained end-to-end.
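The composition above can be made concrete with a toy sketch in plain Python. The neighbour-sum update, mean pooling and single linear output are illustrative assumptions chosen for readability; chemprop's actual blocks are learned PyTorch modules, and none of the helper names below are part of its API.

```python
# Toy sketch of MPNN(G) = predictor(agg(message_passing(G))).
# All three stages are hypothetical stand-ins, not chemprop's implementation.


def message_passing(node_feats, edges, depth=2):
    """At each step, add to every node the sum of its neighbours' current features."""
    h = [list(f) for f in node_feats]
    for _ in range(depth):
        new_h = [list(f) for f in h]
        for u, v in edges:  # undirected edge: pass messages in both directions
            for k in range(len(h[u])):
                new_h[u][k] += h[v][k]
                new_h[v][k] += h[u][k]
        h = new_h
    return h


def agg(node_embeddings):
    """Mean-pool node embeddings into a single graph-level fingerprint."""
    n = len(node_embeddings)
    return [sum(col) / n for col in zip(*node_embeddings)]


def predictor(fingerprint, weights, bias=0.0):
    """A single linear output: w . x + b."""
    return sum(w * x for w, x in zip(weights, fingerprint)) + bias


def mpnn(node_feats, edges, weights):
    """Compose the three stages exactly as in the equation above."""
    return predictor(agg(message_passing(node_feats, edges)), weights)
```

On a three-node path graph with node features [1, 0], [0, 1] and [1, 1], edges (0, 1) and (1, 2), and predictor weights [3.0, 1.0], `mpnn(...)` returns 14.0 (up to floating-point rounding). Because the stages are plain function composition, end-to-end training means the loss on the prediction also updates the parameters of the aggregation and message passing stages.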

Parameters:
  • message_passing (MessagePassing) – the message passing block to use to calculate learned fingerprints

  • agg (Aggregation) – the aggregation operation to use during molecule-level prediction

  • predictor (Predictor) – the function to use to calculate the final prediction

  • batch_norm (bool, default=False) – if True, apply batch normalization to the output of the aggregation operation

  • metrics (Iterable[Metric] | None, default=None) – the metrics to use to evaluate the model during training and evaluation

  • warmup_epochs (int, default=2) – the number of epochs to use for the learning rate warmup

  • init_lr (float, default=1e-4) – the initial learning rate

  • max_lr (float, default=1e-3) – the maximum learning rate

  • final_lr (float, default=1e-4) – the final learning rate

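  • X_d_transform (chemprop.nn.transforms.ScaleTransform | None, default=None) – the transform to apply to the extra datapoint descriptors X_d before they are concatenated to the learned fingerprint; if None, the descriptors are used unchanged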
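The parameters warmup_epochs, init_lr, max_lr and final_lr together describe a learning-rate schedule. The sketch below assumes a Noam-style schedule (linear warmup from init_lr to max_lr, then exponential decay that reaches final_lr at the end of training) and that epochs are converted to optimizer steps; the function name and its steps_per_epoch and total_epochs arguments are hypothetical.

```python
def noam_like_lr(step, steps_per_epoch, total_epochs, warmup_epochs=2,
                 init_lr=1e-4, max_lr=1e-3, final_lr=1e-4):
    """Learning rate at a given optimizer step (hypothetical helper).

    Assumes linear warmup from init_lr to max_lr over warmup_epochs, followed by
    exponential decay that reaches final_lr at the last training step.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    if step < warmup_steps:
        # Linear interpolation between init_lr and max_lr.
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No room left for decay: stay at the peak rate.
        return max_lr
    # Per-step factor chosen so that max_lr * gamma**decay_steps == final_lr.
    gamma = (final_lr / max_lr) ** (1 / decay_steps)
    return max_lr * gamma ** (step - warmup_steps)
```

With steps_per_epoch=10 and total_epochs=50, this rate rises from 1e-4 at step 0 to 1e-3 at step 20 and decays back to 1e-4 by step 500.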

Raises:

ValueError – if the output dimension of the message passing block does not match the input dimension of the predictor function
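A minimal sketch of the kind of consistency check behind this ValueError. It assumes that the multicomponent fingerprint is the concatenation of the per-component fingerprints, so its width is the sum of the per-component message passing output dimensions, optionally widened by extra descriptors X_d. The function check_predictor_input and its arguments are hypothetical, not chemprop's code.

```python
def check_predictor_input(block_output_dims, predictor_input_dim, x_d_dim=0):
    """Raise ValueError if the fingerprint width does not match the predictor.

    Hypothetical helper: assumes per-component fingerprints are concatenated,
    so their widths add up, and that x_d_dim extra descriptors are appended.
    """
    fingerprint_dim = sum(block_output_dims) + x_d_dim
    if fingerprint_dim != predictor_input_dim:
        raise ValueError(
            f"predictor expects input dimension {predictor_input_dim}, "
            f"but the fingerprint has dimension {fingerprint_dim}"
        )
    return fingerprint_dim
```

For two components with 300-dimensional fingerprints, a predictor expecting 600 inputs passes the check, while one expecting 300 raises.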

message_passing: chemprop.nn.MulticomponentMessagePassing
fingerprint(bmgs, V_ds, X_d=None)

Return the learned fingerprints for the input molecules.

Return type:

torch.Tensor
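A plain-Python sketch of how a multicomponent fingerprint can be assembled, assuming that each component is pooled separately, the per-datapoint vectors are concatenated component by component, and the extra descriptors X_d (after an optional transform) are appended at the end. Lists stand in for tensors, the node embeddings are taken as already message-passed, and pool_by_graph, multicomponent_fingerprint and the sum pooling are hypothetical stand-ins rather than chemprop's code.

```python
def pool_by_graph(node_embeddings, batch_index, n_graphs):
    """Sum node embeddings into one vector per graph; batch_index[i] is node i's graph."""
    dim = len(node_embeddings[0])
    pooled = [[0.0] * dim for _ in range(n_graphs)]
    for h, g in zip(node_embeddings, batch_index):
        for k in range(dim):
            pooled[g][k] += h[k]
    return pooled


def multicomponent_fingerprint(components, X_d=None, x_d_transform=None):
    """Build one fingerprint row per datapoint from several components.

    components: list of (node_embeddings, batch_index) pairs, one per component.
    Every component is assumed to contain exactly one graph per datapoint.
    """
    n_graphs = max(components[0][1]) + 1
    per_component = [pool_by_graph(H, idx, n_graphs) for H, idx in components]
    fps = []
    for i in range(n_graphs):
        row = []
        for pooled in per_component:  # concatenate component fingerprints in order
            row.extend(pooled[i])
        fps.append(row)
    if X_d is not None:
        transform = x_d_transform or (lambda x: x)  # identity when no transform is given
        fps = [row + list(transform(x)) for row, x in zip(fps, X_d)]
    return fps
```

With two datapoints, a first component whose pooled vectors are [3, 0] and [0, 3] and a second whose pooled vectors are [1] and [6], the fingerprints are [3, 0, 1] and [0, 3, 6]; appending descriptors widens each row by the descriptor length.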

on_validation_model_eval()

Called when the validation loop starts.

The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().

Return type:

None
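The override pattern can be illustrated without Lightning. In this sketch, toy modules copy the recursive train()/eval() behaviour of torch.nn.Module, and the overriding hook first calls .eval() (the default described above) and then returns one selected submodule to training mode. ToyModule, ToyModel and the output_transform submodule are hypothetical; which submodules chemprop itself keeps in training mode is not stated here.

```python
class ToyModule:
    """Minimal stand-in mimicking torch.nn.Module's recursive train()/eval()."""

    def __init__(self, **children):
        self.training = True
        self.children = children
        for name, child in children.items():
            setattr(self, name, child)

    def train(self, mode=True):
        self.training = mode
        for child in self.children.values():
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class ToyModel(ToyModule):
    def on_validation_model_eval(self):
        # Default behaviour: put the model and every submodule in evaluation mode...
        self.eval()
        # ...then (hypothetically) switch one submodule back to training mode.
        self.output_transform.train()
```

After calling `on_validation_model_eval()` on `ToyModel(encoder=ToyModule(), output_transform=ToyModule())`, the model and its encoder are in evaluation mode while output_transform stays in training mode.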

get_batch_size(batch)
Parameters:

batch (chemprop.data.MulticomponentTrainingBatch)

Return type:

int
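A sketch of one way to count the datapoints in a multicomponent batch, assuming the batch's first field holds one batched graph per component and that every component contains exactly one graph per datapoint, so all components should agree on the count. ToyBatchGraph, its n_graphs field and get_batch_size_sketch are hypothetical; they do not describe the fields of chemprop.data.MulticomponentTrainingBatch.

```python
from dataclasses import dataclass


@dataclass
class ToyBatchGraph:
    """Hypothetical stand-in for a batched graph: only the number of graphs is kept."""
    n_graphs: int


def get_batch_size_sketch(batch):
    """Return the number of datapoints in a toy multicomponent batch.

    batch[0] is assumed to be a list with one ToyBatchGraph per component.
    """
    sizes = {bg.n_graphs for bg in batch[0]}
    if len(sizes) != 1:
        raise ValueError(f"components disagree on batch size: {sorted(sizes)}")
    return sizes.pop()
```

Checking that all components agree is a design choice of this sketch; reading the count from the first component alone would be cheaper but would hide a mismatched batch.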