chemprop.models.multi

Module Contents

Classes

MulticomponentMPNN

An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine.

class chemprop.models.multi.MulticomponentMPNN(message_passing, agg, predictor, batch_norm=True, metrics=None, warmup_epochs=2, init_lr=0.0001, max_lr=0.001, final_lr=0.0001, X_d_transform=None)

Bases: chemprop.models.model.MPNN

An MPNN is a sequence of message passing layers, an aggregation routine, and a predictor routine.

The first two modules calculate learned fingerprints from an input molecule or reaction graph, and the final module takes these learned fingerprints as input to calculate a final prediction. That is, the model computes:

\[\mathtt{MPNN}(\mathcal{G}) = \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))\]

The full model is trained end-to-end.
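To make the composition above concrete, here is a framework-free toy sketch. The three functions are hypothetical stand-ins for Chemprop's MessagePassing, Aggregation, and Predictor modules: message passing is a deliberately simplified node-level update (each node adds its neighbors' features, with no learned weights), aggregation is a mean over nodes, and the predictor is a fixed linear readout. It is not Chemprop's implementation, which operates on batched graphs with learned parameters.

```python
# Toy sketch of MPNN(G) = predictor(agg(message_passing(G))).
# All names and the graph encoding are hypothetical, for illustration only.

def message_passing(graph, depth=2):
    """Refine each node's features by adding its neighbors' features, `depth` times."""
    h = [list(x) for x in graph["x"]]
    for _ in range(depth):
        new_h = []
        for i, h_i in enumerate(h):
            msg = [0.0] * len(h_i)
            for j in graph["adj"][i]:
                msg = [m + v for m, v in zip(msg, h[j])]
            new_h.append([a + b for a, b in zip(h_i, msg)])
        h = new_h
    return h


def agg(node_feats):
    """Mean-pool node features into one graph-level fingerprint."""
    n = len(node_feats)
    return [sum(col) / n for col in zip(*node_feats)]


def predictor(fingerprint, weights, bias):
    """Linear readout from the fingerprint to a scalar prediction."""
    return sum(w * f for w, f in zip(weights, fingerprint)) + bias


def mpnn(graph, weights, bias):
    return predictor(agg(message_passing(graph)), weights, bias)


# A 3-node path graph 0-1-2 with 2-dimensional node features.
graph = {"x": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], "adj": [[1], [0, 2], [1]]}
print(mpnn(graph, weights=[0.5, 0.25], bias=0.0))
```

Because each stage is a separate function, replacing one (for example, sum instead of mean pooling) leaves the other two untouched, mirroring how the constructor accepts each block as a separate argument.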

Parameters:
  • message_passing (MessagePassing) – the message passing block to use to calculate learned fingerprints

  • agg (Aggregation) – the aggregation operation to use during molecule-level prediction

  • predictor (Predictor) – the function to use to calculate the final prediction

  • batch_norm (bool, default=True) – if True, apply batch normalization to the output of the aggregation operation

  • metrics (Iterable[Metric] | None, default=None) – the metrics to use to evaluate the model during training and evaluation

  • warmup_epochs (int, default=2) – the number of epochs to use for the learning rate warmup

  • init_lr (float, default=1e-4) – the initial learning rate

  • max_lr (float, default=1e-3) – the maximum learning rate

  • final_lr (float, default=1e-4) – the final learning rate

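  • X_d_transform (chemprop.nn.transforms.ScaleTransform | None, default=None) – an optional transform applied to the extra datapoint descriptors X_d before they are concatenated to the learned fingerprint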
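This page does not state how warmup_epochs, init_lr, max_lr, and final_lr combine. The sketch below assumes a Noam-like schedule: a linear ramp from init_lr to max_lr over the warmup period, then exponential decay to final_lr by the end of training. The function name, the per-step granularity, and the conversion from epochs to steps are assumptions for illustration; check the Chemprop source for the scheduler it actually builds.

```python
def noam_like_lr(step, warmup_steps, total_steps,
                 init_lr=1e-4, max_lr=1e-3, final_lr=1e-4):
    """Learning rate at optimizer step `step` under an ASSUMED Noam-like schedule."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup: init_lr at step 0, rising toward max_lr.
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    # Exponential decay: per-step factor chosen so max_lr reaches final_lr at total_steps.
    decay_steps = max(total_steps - warmup_steps, 1)
    gamma = (final_lr / max_lr) ** (1.0 / decay_steps)
    # Hold at final_lr if called past total_steps.
    return max_lr * gamma ** min(step - warmup_steps, decay_steps)


# Hypothetical run: 10 batches per epoch, 20 epochs, warmup_epochs=2.
steps_per_epoch, epochs, warmup_epochs = 10, 20, 2
schedule = [
    noam_like_lr(s, warmup_epochs * steps_per_epoch, epochs * steps_per_epoch)
    for s in range(epochs * steps_per_epoch + 1)
]
```

With the class defaults, the rate peaks at max_lr exactly when warmup ends and returns to final_lr on the last step.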

Raises:

ValueError – if the output dimension of the message passing block does not match the input dimension of the predictor function
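The ValueError condition above reduces to arithmetic on feature widths. This sketch assumes that a multicomponent fingerprint is the concatenation of one aggregated vector per component, optionally followed by the extra descriptors X_d, so the predictor's input width must equal the sum of those widths. The helper check_predictor_input and its arguments are hypothetical, not part of Chemprop.

```python
def check_predictor_input(component_dims, predictor_input_dim, d_xd=0):
    """Raise ValueError if the concatenated fingerprint width does not match the predictor.

    component_dims: output width of each component's message passing block (hypothetical).
    d_xd: width of the optional extra descriptors X_d appended to the fingerprint (assumed).
    """
    expected = sum(component_dims) + d_xd
    if expected != predictor_input_dim:
        raise ValueError(
            f"predictor expects input dim {predictor_input_dim}, "
            f"but the fingerprint has dim {expected} "
            f"({' + '.join(map(str, component_dims))} + X_d {d_xd})"
        )
    return expected


# Two components with 300-dimensional fingerprints plus 10 extra descriptors.
print(check_predictor_input([300, 300], predictor_input_dim=610, d_xd=10))
```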

fingerprint(bmgs, V_ds, X_d=None)

Return the learned fingerprints for the input molecules.

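Parameters:
  • bmgs (Iterable[BatchMolGraph]) – the batched molecular graphs, one per component

  • V_ds (Iterable[Tensor | None]) – the extra atom descriptors for each component, or None

  • X_d (Tensor | None, default=None) – the extra datapoint-level descriptors, if any
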
Return type:

torch.Tensor
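A minimal sketch of what fingerprint computes for a single datapoint, assuming each component's node embeddings are pooled independently, the pooled vectors are concatenated in component order, and X_d, if given, is appended last. Sum pooling, the list-based inputs, and the function names are hypothetical; the real method works on batched tensors, and the sketch omits batch normalization (batch_norm=True) and X_d_transform.

```python
def sum_agg(node_feats):
    """Sum-pool one component's node embeddings into a single vector."""
    return [sum(col) for col in zip(*node_feats)]


def multicomponent_fingerprint(component_node_feats, X_d=None, agg=sum_agg):
    """Concatenate per-component fingerprints, then append optional descriptors X_d.

    component_node_feats: one list of node embeddings per component, a hypothetical
    stand-in for the message passing output of each graph in `bmgs`.
    """
    fingerprint = []
    for node_feats in component_node_feats:
        fingerprint.extend(agg(node_feats))  # one pooled vector per component
    if X_d is not None:
        fingerprint.extend(X_d)  # extra datapoint descriptors go last
    return fingerprint


# Two components (e.g. solvent with 2 atoms, solute with 1 atom) plus one descriptor.
solvent = [[1.0, 2.0], [3.0, 4.0]]
solute = [[0.5, 0.5]]
print(multicomponent_fingerprint([solvent, solute], X_d=[9.0]))
```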

classmethod load_submodules(checkpoint_path, **kwargs)

classmethod load_from_file(model_path, map_location=None, strict=True)
Return type:

chemprop.models.model.MPNN