Source code for chemprop.models.multi

import logging
from typing import Iterable

import torch
from torch import Tensor

from chemprop.data import BatchMolGraph, MulticomponentTrainingBatch
from chemprop.models.model import MPNN
from chemprop.nn import Aggregation, MulticomponentMessagePassing, Predictor
from chemprop.nn.metrics import ChempropMetric
from chemprop.nn.transforms import ScaleTransform

logger = logging.getLogger(__name__)


[docs] class MulticomponentMPNN(MPNN): def __init__( self, message_passing: MulticomponentMessagePassing, agg: Aggregation, predictor: Predictor, batch_norm: bool = False, metrics: Iterable[ChempropMetric] | None = None, warmup_epochs: int = 2, init_lr: float = 1e-4, max_lr: float = 1e-3, final_lr: float = 1e-4, X_d_transform: ScaleTransform | None = None, ): super().__init__( message_passing, agg, predictor, batch_norm, metrics, warmup_epochs, init_lr, max_lr, final_lr, X_d_transform, ) self.message_passing: MulticomponentMessagePassing
[docs] def fingerprint( self, bmgs: Iterable[BatchMolGraph], V_ds: Iterable[Tensor | None], X_d: Tensor | None = None, ) -> Tensor: H_vs: list[Tensor] = self.message_passing(bmgs, V_ds) Hs = [self.agg(H_v, bmg.batch) for H_v, bmg in zip(H_vs, bmgs)] H = torch.cat(Hs, 1) H = self.bn(H) return H if X_d is None else torch.cat((H, self.X_d_transform(X_d)), 1)
[docs] def on_validation_model_eval(self) -> None: self.eval() for block in self.message_passing.blocks: block.V_d_transform.train() block.graph_transform.train() self.X_d_transform.train() self.predictor.output_transform.train()
[docs] def get_batch_size(self, batch: MulticomponentTrainingBatch) -> int: return len(batch[0][0])
@classmethod def _load(cls, path, map_location, **submodules): try: d = torch.load(path, map_location, weights_only=False) except AttributeError: logger.error( f"Model loading failed! It is possible this checkpoint was generated in v2.0 and needs to be converted to v2.1\n Please run 'chemprop convert --conversion v2_0_to_v2_1 -i {path}' and load the converted checkpoint." ) try: hparams = d["hyper_parameters"] state_dict = d["state_dict"] except KeyError: raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.") if hparams["metrics"] is not None: hparams["metrics"] = [ cls._rebuild_metric(metric) if not hasattr(metric, "_defaults") or (not torch.cuda.is_available() and metric.device.type != "cpu") else metric for metric in hparams["metrics"] ] if hparams["predictor"]["criterion"] is not None: metric = hparams["predictor"]["criterion"] if not hasattr(metric, "_defaults") or ( not torch.cuda.is_available() and metric.device.type != "cpu" ): hparams["predictor"]["criterion"] = cls._rebuild_metric(metric) hparams["message_passing"]["blocks"] = [ block_hparams.pop("cls")(**block_hparams) for block_hparams in hparams["message_passing"]["blocks"] ] submodules |= { key: hparams[key].pop("cls")(**hparams[key]) for key in ("message_passing", "agg", "predictor") if key not in submodules } return submodules, state_dict, hparams