Source code for chemprop.models.utils

from os import PathLike

import torch

from chemprop.models.model import MPNN
from chemprop.models.mol_atom_bond import MolAtomBondMPNN
from chemprop.models.multi import MulticomponentMPNN


def save_model(
    path: PathLike,
    model: MPNN | MolAtomBondMPNN | MulticomponentMPNN,
    output_columns: list[str]
    | tuple[list[str] | None, list[str] | None, list[str] | None]
    | None = None,
) -> None:
    """Serialize a model checkpoint to ``path`` with ``torch.save``.

    The file holds a single dictionary with three entries:

    * ``"hyper_parameters"`` -- ``model.hparams``
    * ``"state_dict"`` -- the result of ``model.state_dict()``
    * ``"output_columns"`` -- the ``output_columns`` argument, stored as given

    Parameters
    ----------
    path : PathLike
        Destination handed to ``torch.save``.
    model : MPNN | MolAtomBondMPNN | MulticomponentMPNN
        Model whose hyperparameters and weights are written.
    output_columns : list[str] | tuple[list[str] | None, ...] | None, default=None
        Optional column names stored next to the weights; either a flat list
        or a 3-tuple of optional lists.
    """
    checkpoint = {
        "hyper_parameters": model.hparams,
        "state_dict": model.state_dict(),
        "output_columns": output_columns,
    }
    torch.save(checkpoint, path)
def load_model(
    path: PathLike, multicomponent: bool = False, mol_atom_bond: bool = False
) -> MPNN | MulticomponentMPNN | MolAtomBondMPNN:
    """Load a saved model from ``path`` onto the CPU.

    The model class is chosen from the two flags and loading is delegated to
    that class's ``load_from_file``:

    * neither flag set -> ``MPNN``
    * ``multicomponent`` only -> ``MulticomponentMPNN``
    * ``mol_atom_bond`` only -> ``MolAtomBondMPNN``

    Parameters
    ----------
    path : PathLike
        Location of the saved model file.
    multicomponent : bool, default=False
        Select the multicomponent model class.
    mol_atom_bond : bool, default=False
        Select the mol/atom/bond model class.

    Returns
    -------
    MPNN | MulticomponentMPNN | MolAtomBondMPNN
        Whatever the selected class's ``load_from_file`` returns.

    Raises
    ------
    ValueError
        If both ``multicomponent`` and ``mol_atom_bond`` are set; that
        combination is not supported.
    """
    # Reject the unsupported combination up front. Previously the lookup table
    # yielded the message string itself, and calling ``.load_from_file`` on it
    # failed with an unrelated AttributeError that hid the real reason.
    if mol_atom_bond and multicomponent:
        raise ValueError("Atom/Bond predictions not supported for multicomponent")

    if mol_atom_bond:
        model_cls = MolAtomBondMPNN
    elif multicomponent:
        model_cls = MulticomponentMPNN
    else:
        model_cls = MPNN

    return model_cls.load_from_file(path, map_location=torch.device("cpu"))
[docs] def load_output_columns( path: PathLike, ) -> list[str] | tuple[list[str] | None, list[str] | None, list[str] | None] | None: model_file = torch.load(path, map_location=torch.device("cpu"), weights_only=False) return model_file.get("output_columns")