Source code for chemprop.data.datasets

from dataclasses import dataclass, field
from functools import cached_property
import logging
from typing import NamedTuple, TypeAlias

import numpy as np
from numpy.typing import ArrayLike
from rdkit import Chem
from rdkit.Chem import Mol
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

from chemprop.data.datapoints import (
    LazyMoleculeDatapoint,
    MolAtomBondDatapoint,
    MoleculeDatapoint,
    ReactionDatapoint,
)
from chemprop.data.molgraph import MolGraph
from chemprop.featurizers.base import Featurizer
from chemprop.featurizers.molgraph import (
    BatchCuikMolGraph,
    CGRFeaturizer,
    CuikmolmakerMolGraphFeaturizer,
    SimpleMoleculeMolGraphFeaturizer,
)
from chemprop.featurizers.molgraph.cache import MolGraphCache, MolGraphCacheOnTheFly
from chemprop.types import Rxn

# module-level logger, named after this module's import path
logger = logging.getLogger(__name__)


class Datum(NamedTuple):
    """A single featurized training example.

    Instances are built positionally by the dataset ``__getitem__`` methods in
    this module, so the field order below is part of the interface.
    """

    # the featurized molecular graph
    mg: MolGraph
    # atom descriptors of the example, or ``None``
    V_d: np.ndarray | None
    # extra (molecule-level) descriptors of the example, or ``None``
    x_d: np.ndarray | None
    # targets of the example, or ``None``
    y: np.ndarray | None
    # weight of the example, copied from the datapoint
    weight: float
    # the datapoint's ``lt_mask``, or ``None``
    lt_mask: np.ndarray | None
    # the datapoint's ``gt_mask``, or ``None``
    gt_mask: np.ndarray | None
class MolAtomBondDatum(NamedTuple):
    """A single featurized training example with molecule, atom and bond level targets.

    Grouped fields are ordered ``(molecule, atom, bond)``; ``constraints`` is
    ordered ``(atom, bond)``.
    """

    # the featurized molecular graph
    mg: MolGraph
    # atom descriptors, or ``None``
    V_d: np.ndarray | None
    # bond descriptors, or ``None``
    E_d: np.ndarray | None
    # extra (molecule-level) descriptors, or ``None``
    x_d: np.ndarray | None
    # (molecule, atom, bond) targets
    ys: tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
    # weight of the example, copied from the datapoint
    weight: float
    # (molecule, atom, bond) ``lt`` masks
    lt_masks: tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
    # (molecule, atom, bond) ``gt`` masks
    gt_masks: tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
    # (atom, bond) constraints
    constraints: tuple[np.ndarray | None, np.ndarray | None]
class CuikBatchedDatum(NamedTuple):
    """A cuik-molmaker batch of data points.

    Unlike :class:`Datum`, one instance describes a whole batch. Every field
    except ``bmg`` and ``weights`` may be ``None``: the batching code passes
    ``None`` whenever the corresponding per-datapoint value is absent.
    """

    # the batched molecular graph produced by the cuik-molmaker featurizer
    bmg: BatchCuikMolGraph
    # concatenated atom descriptors of the batch, or ``None``
    V_d: np.ndarray | None
    # extra (molecule-level) descriptors of the batch, or ``None``
    X_d: np.ndarray | None
    # targets of the batch, or ``None``
    Y: np.ndarray | None
    # per-datapoint weights of the batch
    weights: np.ndarray
    # per-datapoint ``lt_mask`` values of the batch, or ``None``
    lt_mask: np.ndarray | None
    # per-datapoint ``gt_mask`` values of the batch, or ``None``
    gt_mask: np.ndarray | None
MolGraphDataset: TypeAlias = Dataset[Datum]
MolAtomBondGraphDataset: TypeAlias = Dataset[MolAtomBondDatum]


class _MolGraphDatasetMixin:
    """Shared target/descriptor bookkeeping for datasets backed by ``self.data``.

    The host class must provide ``self.data`` (a sequence of datapoints) and
    call :meth:`reset` before reading :attr:`Y` or :attr:`X_d`; those values are
    only populated by :meth:`reset` or by their setters.
    """

    def __len__(self) -> int:
        return len(self.data)

    @cached_property
    def _Y(self) -> np.ndarray:
        """the raw targets of the dataset"""
        return np.array([d.y for d in self.data], float)

    @property
    def Y(self) -> np.ndarray:
        """the (scaled) targets of the dataset"""
        return self.__Y

    @Y.setter
    def Y(self, Y: ArrayLike):
        self._validate_attribute(Y, "targets")

        self.__Y = np.array(Y, float)

    @cached_property
    def _X_d(self) -> np.ndarray:
        """the raw extra descriptors of the dataset"""
        return np.array([d.x_d for d in self.data])

    @property
    def X_d(self) -> np.ndarray:
        """the (scaled) extra descriptors of the dataset"""
        return self.__X_d

    @X_d.setter
    def X_d(self, X_d: ArrayLike):
        self._validate_attribute(X_d, "extra descriptors")

        self.__X_d = np.array(X_d)

    @property
    def weights(self) -> np.ndarray:
        """the weight of every datapoint (rebuilt on each access)"""
        return np.array([d.weight for d in self.data])

    @property
    def gt_mask(self) -> np.ndarray:
        """the ``gt_mask`` of every datapoint (rebuilt on each access)"""
        return np.array([d.gt_mask for d in self.data])

    @property
    def lt_mask(self) -> np.ndarray:
        """the ``lt_mask`` of every datapoint (rebuilt on each access)"""
        return np.array([d.lt_mask for d in self.data])

    @property
    def t(self) -> int | None:
        """the ``t`` attribute of the first datapoint, or ``None`` for an empty dataset"""
        return self.data[0].t if len(self.data) > 0 else None

    @property
    def d_xd(self) -> int:
        """the extra molecule descriptor dimension, if any"""
        # An empty dataset has no row to inspect; report "no descriptors"
        # instead of failing on ``X_d[0]`` (mirrors how ``t`` handles empty data).
        if len(self.data) == 0:
            return 0

        return 0 if self.X_d[0] is None else self.X_d.shape[1]

    @property
    def names(self) -> list[str]:
        """the ``name`` of every datapoint"""
        return [d.name for d in self.data]

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        """Normalizes the targets of this dataset using a :obj:`StandardScaler`

        The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
        each task independently. NOTE: This should only be used for regression datasets.

        Parameters
        ----------
        scaler : StandardScaler | None, default=None
            a pre-fit scaler to apply. If ``None``, a new scaler is fit to the raw targets.

        Returns
        -------
        StandardScaler
            a scaler fit to the targets.
        """
        if scaler is None:
            scaler = StandardScaler().fit(self._Y)

        # always transform the *raw* targets so repeated calls do not compound
        self.Y = scaler.transform(self._Y)

        return scaler

    def normalize_inputs(
        self, key: str = "X_d", scaler: StandardScaler | None = None
    ) -> StandardScaler | None:
        """Normalize the extra descriptors of this dataset.

        Parameters
        ----------
        key : str, default="X_d"
            the input to normalize; only ``"X_d"`` is accepted here.
        scaler : StandardScaler | None, default=None
            a pre-fit scaler to apply. If ``None``, a new scaler is fit to the raw descriptors.

        Returns
        -------
        StandardScaler | None
            the scaler that was applied. When the dataset is empty or has no extra descriptors,
            nothing is changed and the ``scaler`` argument is returned as given (possibly
            ``None``).

        Raises
        ------
        ValueError
            if ``key`` is not a valid key
        """
        VALID_KEYS = {"X_d"}
        if key not in VALID_KEYS:
            raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")

        # nothing to scale: no datapoints, or datapoints without extra descriptors
        if len(self.data) == 0 or self._X_d[0] is None:
            return scaler

        X = self._X_d
        if scaler is None:
            scaler = StandardScaler().fit(X)

        self.X_d = scaler.transform(X)

        return scaler

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        # assigned directly (not via the setters): the raw arrays need no length validation
        self.__Y = self._Y
        self.__X_d = self._X_d

    def _validate_attribute(self, X: np.ndarray, label: str):
        """Raise a :class:`ValueError` unless ``X`` has one entry per datapoint."""
        if len(self.data) != len(X):
            raise ValueError(
                f"number of molecules ({len(self.data)}) and {label} ({len(X)}) "
                "must have same length!"
            )
[docs] @dataclass class MoleculeDataset(_MolGraphDatasetMixin, MolGraphDataset): r"""A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s A :class:`MoleculeDataset` produces featurized data for input to a :class:`MPNN` model. Typically, data featurization is performed on-the-fly and parallelized across multiple workers via the :class:`~torch.utils.data DataLoader` class. However, for small datasets, it may be more efficient to featurize the data in advance and cache the results. This can be done by setting ``MoleculeDataset.cache=True``. Parameters ---------- data : Iterable[MoleculeDatapoint] the data from which to create a dataset featurizer : MoleculeFeaturizer the featurizer with which to generate MolGraphs of the molecules n_workers : int, optional number of workers to use for cache calculation """ data: list[MoleculeDatapoint] featurizer: Featurizer[Mol, MolGraph] = field(default_factory=SimpleMoleculeMolGraphFeaturizer) n_workers: int = 0
[docs] def __post_init__(self): if self.data is None: raise ValueError("Data cannot be None!") self.reset() self.cache = False
[docs] def __getitem__(self, idx: int) -> Datum: d = self.data[idx] mg = self.mg_cache[idx] return Datum(mg, self.V_ds[idx], self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)
@property def cache(self) -> bool: return self.__cache @cache.setter def cache(self, cache: bool = False): self.__cache = cache self._init_cache() def _init_cache(self): """initialize the cache""" if self.cache: self.mg_cache = MolGraphCache( self.mols, self.V_fs, self.E_fs, self.featurizer, n_workers=self.n_workers ) else: self.mg_cache = MolGraphCacheOnTheFly(self.mols, self.V_fs, self.E_fs, self.featurizer) @property def smiles(self) -> list[str]: """the SMILES strings associated with the dataset""" return [Chem.MolToSmiles(d.mol) for d in self.data] @property def mols(self) -> list[Chem.Mol]: """the molecules associated with the dataset""" return [d.mol for d in self.data] @property def _V_fs(self) -> list[np.ndarray]: """the raw atom features of the dataset""" return [d.V_f for d in self.data] @property def V_fs(self) -> list[np.ndarray]: """the (scaled) atom descriptors of the dataset""" return self.__V_fs @V_fs.setter def V_fs(self, V_fs: list[np.ndarray]): """the (scaled) atom features of the dataset""" self._validate_attribute(V_fs, "atom features") self.__V_fs = V_fs self._init_cache() @property def _E_fs(self) -> list[np.ndarray]: """the raw bond features of the dataset""" return [d.E_f for d in self.data] @property def E_fs(self) -> list[np.ndarray]: """the (scaled) bond features of the dataset""" return self.__E_fs @E_fs.setter def E_fs(self, E_fs: list[np.ndarray]): self._validate_attribute(E_fs, "bond features") self.__E_fs = E_fs self._init_cache() @property def _V_ds(self) -> list[np.ndarray]: """the raw atom descriptors of the dataset""" return [d.V_d for d in self.data] @property def V_ds(self) -> list[np.ndarray]: """the (scaled) atom descriptors of the dataset""" return self.__V_ds @V_ds.setter def V_ds(self, V_ds: list[np.ndarray]): self._validate_attribute(V_ds, "atom descriptors") self.__V_ds = V_ds @property def d_vf(self) -> int: """the extra atom feature dimension, if any""" return 0 if self.V_fs[0] is None else self.V_fs[0].shape[1] 
@property def d_ef(self) -> int: """the extra bond feature dimension, if any""" return 0 if self.E_fs[0] is None else self.E_fs[0].shape[1] @property def d_vd(self) -> int: """the extra atom descriptor dimension, if any""" return 0 if self.V_ds[0] is None else self.V_ds[0].shape[1]
[docs] def normalize_inputs( self, key: str = "X_d", scaler: StandardScaler | None = None ) -> StandardScaler: VALID_KEYS = {"X_d", "V_f", "E_f", "V_d"} match key: case "X_d": X = None if self.d_xd == 0 else self._X_d case "V_f": X = None if self.d_vf == 0 else np.concatenate(self._V_fs, axis=0) case "E_f": X = None if self.d_ef == 0 else np.concatenate(self._E_fs, axis=0) case "V_d": X = None if self.d_vd == 0 else np.concatenate(self._V_ds, axis=0) case _: raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}") if X is None: return scaler if scaler is None: scaler = StandardScaler().fit(X) match key: case "X_d": self.X_d = scaler.transform(X) case "V_f": self.V_fs = [scaler.transform(V_f) if V_f.size > 0 else V_f for V_f in self._V_fs] case "E_f": self.E_fs = [scaler.transform(E_f) if E_f.size > 0 else E_f for E_f in self._E_fs] case "V_d": self.V_ds = [scaler.transform(V_d) if V_d.size > 0 else V_d for V_d in self._V_ds] case _: raise RuntimeError("unreachable code reached!") return scaler
[docs] def reset(self): """Reset the atom and bond features; atom and extra descriptors; and targets of each datapoint to their initial, unnormalized values.""" super().reset() self.__V_fs = self._V_fs self.__E_fs = self._E_fs self.__V_ds = self._V_ds
@dataclass
class CuikmolmakerDataset(MoleculeDataset):
    r"""A :class:`CuikmolmakerDataset` composed of :class:`LazyMoleculeDatapoint`\s and a
    :class:`CuikmolmakerMolGraphFeaturizer`

    A :class:`CuikmolmakerDataset` produces featurized data for a batch of molecules for
    ingestion by a :class:`MPNN` model. Data featurization is always performed on-the-fly and
    using the cuik-molmaker package. This batched processing is significantly faster and consumes
    less memory than the default featurization method when caching is not possible.

    Parameters
    ----------
    data : Iterable[LazyMoleculeDatapoint]
        the data from which to create a dataset
    featurizer : CuikmolmakerMolGraphFeaturizer
        the featurizer with which to generate MolGraphs of the molecules
    """

    data: list[LazyMoleculeDatapoint]
    featurizer: CuikmolmakerMolGraphFeaturizer = field(
        default_factory=CuikmolmakerMolGraphFeaturizer
    )

    @property
    def cache(self) -> bool:
        """always ``False``: this dataset never caches MolGraphs"""
        # The parent getter reads a private attribute that this class's setter
        # never assigns, so the getter is overridden as well.
        return False

    @cache.setter
    def cache(self, cache: bool = False):
        if cache:
            raise NotImplementedError("CuikmolmakerDataset is meant to be used without caching!")

    def _init_cache(self):
        """no-op: there is no MolGraph cache to build"""
        pass

    @property
    def smiles(self) -> list[str]:
        """the SMILES strings stored on the datapoints"""
        return [d.smiles for d in self.data]

    def __getitem__(self, idx: int) -> Datum:
        """Featurize the single datapoint at ``idx`` and return it as a :class:`Datum`."""
        d = self.data[idx]
        bmg = self.featurizer([d.smiles], self.V_fs[idx], self.E_fs[idx])
        # repackage the single-molecule batch as a plain MolGraph of numpy arrays
        mg = MolGraph(
            bmg.V.numpy(), bmg.E.numpy(), bmg.edge_index.numpy(), bmg.rev_edge_index.numpy()
        )

        return Datum(mg, self.V_ds[idx], self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)

    def __getitems__(self, indexes: list[int]) -> CuikBatchedDatum:
        """Featurize the datapoints at ``indexes`` in one call to the featurizer.

        Each optional field is ``None`` when the first datapoint of the dataset has no value for
        it; otherwise the per-datapoint values of the batch are gathered.
        """
        batch = [self.data[idx] for idx in indexes]
        first = self.data[0]

        smiles_list = [d.smiles for d in batch]
        V_f = (
            np.concatenate([self.V_fs[idx] for idx in indexes])
            if self.V_fs[0] is not None
            else None
        )
        E_f = (
            np.concatenate([self.E_fs[idx] for idx in indexes])
            if self.E_fs[0] is not None
            else None
        )
        bmg = self.featurizer(smiles_list, V_f, E_f)

        V_d = (
            np.concatenate([self.V_ds[idx] for idx in indexes])
            if self.V_ds[0] is not None
            else None
        )
        X_d = self.X_d[indexes] if self.X_d[0] is not None else None
        Y = self.Y[indexes] if self.Y[0] is not None else None

        # Gather only the batch rows; the dataset-wide ``weights`` / ``lt_mask`` /
        # ``gt_mask`` properties rebuild an array over every datapoint per access.
        weights = np.array([d.weight for d in batch])
        lt_mask = np.array([d.lt_mask for d in batch]) if first.lt_mask is not None else None
        gt_mask = np.array([d.gt_mask for d in batch]) if first.gt_mask is not None else None

        return CuikBatchedDatum(bmg, V_d, X_d, Y, weights, lt_mask, gt_mask)
[docs] class MolAtomBondDataset(MoleculeDataset, MolAtomBondGraphDataset): data: list[MolAtomBondDatapoint]
[docs] def __getitem__(self, idx: int) -> MolAtomBondDatum: d = self.data[idx] mg = self.mg_cache[idx] return MolAtomBondDatum( mg, self.V_ds[idx], self.E_ds[idx], self.X_d[idx], [ self.Y[idx] if isinstance(self.Y[idx], np.ndarray) else None, self.atom_Y[idx] if self.atom_Y is not None else None, self.bond_Y[idx] if self.bond_Y is not None else None, ], d.weight, [d.lt_mask, d.atom_lt_mask, d.bond_lt_mask], [d.gt_mask, d.atom_gt_mask, d.bond_gt_mask], [d.atom_constraint, d.bond_constraint], )
@property def _atom_Y(self) -> list[np.ndarray]: """the raw atom targets of the dataset""" return [d.atom_y for d in self.data] @property def atom_Y(self) -> list[np.ndarray]: """the (scaled) atom targets of the dataset""" return self.__atom_Y @atom_Y.setter def atom_Y(self, atom_Y: list[np.ndarray]): self._validate_attribute(atom_Y, "atom targets") self.__atom_Y = atom_Y @cached_property def _atom_constraints(self) -> np.ndarray: return np.array([d.atom_constraint for d in self.data]) @property def atom_constraints(self) -> np.ndarray: return self.__atom_constraints @atom_constraints.setter def atom_constraints(self, atom_constraints: ArrayLike): self._validate_attribute(atom_constraints, "atom constraints") self.__atom_constraints = np.array(atom_constraints) @property def _bond_Y(self) -> list[np.ndarray]: """the raw bond targets of the dataset""" return [d.bond_y for d in self.data] @property def bond_Y(self) -> list[np.ndarray]: """the (scaled) bond targets of the dataset""" return self.__bond_Y @bond_Y.setter def bond_Y(self, bond_Y: list[np.ndarray]): self._validate_attribute(bond_Y, "bond targets") self.__bond_Y = bond_Y @cached_property def _bond_constraints(self) -> np.ndarray: return np.array([d.bond_constraint for d in self.data]) @property def bond_constraints(self) -> np.ndarray: return self.__bond_constraints @bond_constraints.setter def bond_constraints(self, bond_constraints: ArrayLike): self._validate_attribute(bond_constraints, "bond constraints") self.__bond_constraints = np.array(bond_constraints) @property def atom_gt_mask(self) -> np.ndarray: return np.vstack([d.atom_gt_mask for d in self.data]) @property def atom_lt_mask(self) -> np.ndarray: return np.vstack([d.atom_lt_mask for d in self.data]) @property def bond_gt_mask(self) -> np.ndarray: return np.vstack([d.bond_gt_mask for d in self.data]) @property def bond_lt_mask(self) -> np.ndarray: return np.vstack([d.bond_lt_mask for d in self.data]) @property def _E_ds(self) -> list[np.ndarray]: 
"""the raw bond descriptors of the dataset""" return [d.E_d for d in self.data] @property def E_ds(self) -> list[np.ndarray]: """the (scaled) bond descriptors of the dataset""" return self.__E_ds @E_ds.setter def E_ds(self, E_ds: list[np.ndarray]): self._validate_attribute(E_ds, "bond descriptors") self.__E_ds = E_ds @property def d_ed(self) -> int: """the extra bond descriptor dimension, if any""" return 0 if self.E_ds[0] is None else self.E_ds[0].shape[1]
[docs] def normalize_targets( self, key: str = "mol", scaler: StandardScaler | None = None ) -> StandardScaler: VALID_KEYS = {"mol", "atom", "bond"} match key: case "mol": X = self._Y case "atom": X = np.concatenate(self._atom_Y, axis=0) case "bond": X = np.concatenate(self._bond_Y, axis=0) case _: raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}") if scaler is None: scaler = StandardScaler().fit(X) match key: case "mol": self.Y = scaler.transform(X) case "atom": self.atom_Y = [scaler.transform(y) if y.size > 0 else y for y in self._atom_Y] if self.atom_constraints[0] is not None: atoms_per_mol = [len(d.atom_y) for d in self.data] scaled_atom_constraints = [ (row - n * scaler.mean_) / scaler.scale_ for row, n in zip(self._atom_constraints, atoms_per_mol) ] self.atom_constraints = np.array(scaled_atom_constraints) case "bond": self.bond_Y = [scaler.transform(y) if y.size > 0 else y for y in self._bond_Y] if self.bond_constraints[0] is not None: bonds_per_mol = [len(d.bond_y) for d in self.data] scaled_bond_constraints = [ (row - n * scaler.mean_) / scaler.scale_ for row, n in zip(self._bond_constraints, bonds_per_mol) ] self.bond_constraints = np.array(scaled_bond_constraints) case _: raise RuntimeError("unreachable code reached!") return scaler
[docs] def normalize_inputs( self, key: str = "X_d", scaler: StandardScaler | None = None ) -> StandardScaler: VALID_KEYS = {"X_d", "V_f", "E_f", "V_d", "E_d"} match key: case "X_d": X = None if self.d_xd == 0 else self.X_d case "V_f": X = None if self.d_vf == 0 else np.concatenate(self._V_fs, axis=0) case "E_f": X = None if self.d_ef == 0 else np.concatenate(self._E_fs, axis=0) case "V_d": X = None if self.d_vd == 0 else np.concatenate(self._V_ds, axis=0) case "E_d": X = None if self.d_ed == 0 else np.concatenate(self._E_ds, axis=0) case _: raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}") if X is None: return scaler if scaler is None: scaler = StandardScaler().fit(X) match key: case "X_d": self.X_d = scaler.transform(X) case "V_f": self.V_fs = [scaler.transform(V_f) if V_f.size > 0 else V_f for V_f in self._V_fs] case "E_f": self.E_fs = [scaler.transform(E_f) if E_f.size > 0 else E_f for E_f in self._E_fs] case "V_d": self.V_ds = [scaler.transform(V_d) if V_d.size > 0 else V_d for V_d in self._V_ds] case "E_d": self.E_ds = [scaler.transform(E_d) if E_d.size > 0 else E_d for E_d in self._E_ds] case _: raise RuntimeError("unreachable code reached!") return scaler
[docs] def reset(self): """Reset the atom and bond features; atom and extra descriptors; and targets of each datapoint to their initial, unnormalized values.""" super().reset() self.__E_ds = self._E_ds self.__atom_Y = self._atom_Y self.__bond_Y = self._bond_Y self.__atom_constraints = self._atom_constraints self.__bond_constraints = self._bond_constraints
@dataclass
class ReactionDataset(_MolGraphDatasetMixin, MolGraphDataset):
    r"""A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s

    .. note::
        The featurized data provided by this class may be cached, similar to a
        :class:`MoleculeDataset`. To enable the cache, set ``ReactionDataset.cache=True``.
    """

    data: list[ReactionDatapoint]
    """the dataset from which to load"""
    featurizer: Featurizer[Rxn, MolGraph] = field(default_factory=CGRFeaturizer)
    """the featurizer with which to generate MolGraphs of the input"""
    n_workers: int = 0
    """number of workers to use for cache calculation"""

    def __post_init__(self):
        if self.data is None:
            raise ValueError("Data cannot be None!")

        self.reset()
        self.cache = False

    @property
    def cache(self) -> bool:
        """whether MolGraphs are featurized in advance and cached"""
        return self.__cache

    @cache.setter
    def cache(self, cache: bool = False):
        self.__cache = cache

        # reactions carry no extra atom/bond features, so both slots are all-None
        if not cache:
            self.mg_cache = MolGraphCacheOnTheFly(
                self.mols, [None] * len(self), [None] * len(self), self.featurizer
            )
            return

        self.mg_cache = MolGraphCache(
            self.mols,
            [None] * len(self),
            [None] * len(self),
            self.featurizer,
            n_workers=self.n_workers,
        )

    def __getitem__(self, idx: int) -> Datum:
        datapoint = self.data[idx]
        graph = self.mg_cache[idx]

        # there are no atom descriptors for reactions, hence the ``None`` in the V_d slot
        return Datum(
            graph,
            None,
            self.X_d[idx],
            self.Y[idx],
            datapoint.weight,
            datapoint.lt_mask,
            datapoint.gt_mask,
        )

    @property
    def smiles(self) -> list[tuple]:
        """the (reactant, product) SMILES pair of every datapoint"""
        pairs = []
        for d in self.data:
            pairs.append((Chem.MolToSmiles(d.rct), Chem.MolToSmiles(d.pdt)))

        return pairs

    @property
    def mols(self) -> list[Rxn]:
        """the (reactant, product) molecule pair of every datapoint"""
        return [(d.rct, d.pdt) for d in self.data]

    @property
    def d_vf(self) -> int:
        """always 0"""
        return 0

    @property
    def d_ef(self) -> int:
        """always 0"""
        return 0

    @property
    def d_vd(self) -> int:
        """always 0"""
        return 0
@dataclass(repr=False, eq=False)
class MulticomponentDataset(_MolGraphDatasetMixin, Dataset):
    r"""A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel
    :class:`MoleculeDatasets` and :class:`ReactionDataset`\s"""

    datasets: list[MoleculeDataset | ReactionDataset]
    """the parallel datasets"""

    def __post_init__(self):
        sizes = [len(dset) for dset in self.datasets]
        if not all(sizes[0] == size for size in sizes[1:]):
            raise ValueError(f"Datasets must have all same length! got: {sizes}")

    def __len__(self) -> int:
        # all components have the same length (checked in __post_init__)
        return len(self.datasets[0])

    @property
    def n_components(self) -> int:
        """the number of parallel datasets"""
        return len(self.datasets)

    def __getitem__(self, idx: int) -> list[Datum]:
        """Return the ``idx``-th item of every component, in component order."""
        return [dset[idx] for dset in self.datasets]

    @property
    def smiles(self) -> list[list[str]]:
        """the SMILES of every example: one tuple per example with one entry per component"""
        return list(zip(*[dset.smiles for dset in self.datasets]))

    @property
    def names(self) -> list[list[str]]:
        """the names of every example: one tuple per example with one entry per component"""
        return list(zip(*[dset.names for dset in self.datasets]))

    @property
    def mols(self) -> list[list[Chem.Mol]]:
        """the molecules of every example: one tuple per example with one entry per component"""
        return list(zip(*[dset.mols for dset in self.datasets]))

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        """Normalize the targets of the first component and return the scaler used.

        Only ``datasets[0]`` is touched; the other components are left unchanged.
        """
        return self.datasets[0].normalize_targets(scaler)

    def normalize_inputs(
        self, key: str = "X_d", scaler: list[StandardScaler] | None = None
    ) -> list[StandardScaler | None]:
        """Normalize the ``key`` inputs of every component.

        Parameters
        ----------
        key : str, default="X_d"
            which input to normalize. Components that are not a :class:`MoleculeDataset` are only
            normalized for ``"X_d"``; for any other key they are skipped.
        scaler : list[StandardScaler] | None, default=None
            one pre-fit scaler per component. If ``None``, each component fits its own scaler.

        Returns
        -------
        list[StandardScaler | None]
            one entry per component: whatever that component's ``normalize_inputs`` returned, or
            ``None`` for a skipped component.
        """
        RXN_VALID_KEYS = {"X_d"}

        if scaler is None:
            scalers = [None] * len(self.datasets)
        else:
            # kept as an assertion so callers see the same exception type as before
            assert len(scaler) == len(
                self.datasets
            ), "Number of scalers must match number of datasets!"
            scalers = scaler

        return [
            dset.normalize_inputs(key, s)
            if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS
            else None
            for dset, s in zip(self.datasets, scalers)
        ]

    def reset(self):
        """Reset every component; returns the list of the components' ``reset()`` results."""
        return [dset.reset() for dset in self.datasets]

    @property
    def d_xd(self) -> int:
        """the extra molecule descriptor dimension of the first component"""
        return self.datasets[0].d_xd

    @property
    def d_vf(self) -> int:
        """the extra atom feature dimension, summed over the components"""
        return sum(dset.d_vf for dset in self.datasets)

    @property
    def d_ef(self) -> int:
        """the extra bond feature dimension, summed over the components"""
        return sum(dset.d_ef for dset in self.datasets)

    @property
    def d_vd(self) -> int:
        """the extra atom descriptor dimension, summed over the components"""
        return sum(dset.d_vd for dset in self.datasets)