from dataclasses import dataclass, field
from functools import cached_property
import logging
from typing import NamedTuple, TypeAlias
import numpy as np
from numpy.typing import ArrayLike
from rdkit import Chem
from rdkit.Chem import Mol
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from chemprop.data.datapoints import (
LazyMoleculeDatapoint,
MolAtomBondDatapoint,
MoleculeDatapoint,
ReactionDatapoint,
)
from chemprop.data.molgraph import MolGraph
from chemprop.featurizers.base import Featurizer
from chemprop.featurizers.molgraph import (
BatchCuikMolGraph,
CGRFeaturizer,
CuikmolmakerMolGraphFeaturizer,
SimpleMoleculeMolGraphFeaturizer,
)
from chemprop.featurizers.molgraph.cache import MolGraphCache, MolGraphCacheOnTheFly
from chemprop.types import Rxn
# Module-level logger, named after this module per the standard ``logging`` convention.
logger = logging.getLogger(__name__)
class Datum(NamedTuple):
    """a singular training data point

    Produced by :meth:`MoleculeDataset.__getitem__`, :meth:`CuikmolmakerDataset.__getitem__` and
    :meth:`ReactionDataset.__getitem__` (the latter always sets ``V_d`` to ``None``).
    """

    # the featurized molecular graph
    mg: MolGraph
    # the (scaled) atom descriptors of this data point, if any
    V_d: np.ndarray | None
    # the (scaled) extra molecule-level descriptors, if any
    x_d: np.ndarray | None
    # the (scaled) targets
    y: np.ndarray | None
    # the sample weight, copied from the datapoint
    weight: float
    # the datapoint's ``lt_mask``, passed through unchanged
    lt_mask: np.ndarray | None
    # the datapoint's ``gt_mask``, passed through unchanged
    gt_mask: np.ndarray | None
class MolAtomBondDatum(NamedTuple):
    """a singular training data point that supports atom and bond level targets

    Produced by :meth:`MolAtomBondDataset.__getitem__`, which supplies the grouped fields
    (``ys``, ``lt_masks``, ``gt_masks`` and ``constraints``) as lists; the annotations therefore
    accept either a tuple or a list.
    """

    # the featurized molecular graph
    mg: MolGraph
    # the (scaled) atom descriptors, if any
    V_d: np.ndarray | None
    # the (scaled) bond descriptors, if any
    E_d: np.ndarray | None
    # the (scaled) extra molecule-level descriptors, if any
    x_d: np.ndarray | None
    # (molecule, atom, bond) targets; any entry may be None
    ys: (
        tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
        | list[np.ndarray | None]
    )
    # the sample weight, copied from the datapoint
    weight: float
    # (molecule, atom, bond) ``lt`` masks, copied from the datapoint
    lt_masks: (
        tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
        | list[np.ndarray | None]
    )
    # (molecule, atom, bond) ``gt`` masks, copied from the datapoint
    gt_masks: (
        tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]
        | list[np.ndarray | None]
    )
    # (atom, bond) constraints, copied from the datapoint
    constraints: tuple[np.ndarray | None, np.ndarray | None] | list[np.ndarray | None]
class CuikBatchedDatum(NamedTuple):
    """a cuik-molmaker batch of data points

    Produced by :meth:`CuikmolmakerDataset.__getitems__`. Apart from ``bmg`` and ``weights``,
    every field is ``None`` when the dataset carries no corresponding data, so those fields are
    annotated as optional.
    """

    # the batched graph produced by the cuik-molmaker featurizer
    bmg: BatchCuikMolGraph
    # atom descriptors of all molecules in the batch, concatenated along the first axis
    V_d: np.ndarray | None
    # extra molecule-level descriptors, one row per molecule
    X_d: np.ndarray | None
    # targets, one row per molecule
    Y: np.ndarray | None
    # per-molecule sample weights
    weights: np.ndarray
    # per-molecule ``lt`` masks
    lt_mask: np.ndarray | None
    # per-molecule ``gt`` masks
    gt_mask: np.ndarray | None
# A torch ``Dataset`` whose items are :class:`Datum` instances.
MolGraphDataset: TypeAlias = Dataset[Datum]
# A torch ``Dataset`` whose items are :class:`MolAtomBondDatum` instances.
MolAtomBondGraphDataset: TypeAlias = Dataset[MolAtomBondDatum]
class _MolGraphDatasetMixin:
    """Shared target/descriptor handling for datasets wrapping a ``data`` list of datapoints.

    Subclasses must provide ``self.data``. Raw values (``_Y``, ``_X_d``) are read from the
    datapoints once and cached; the public ``Y``/``X_d`` attributes hold the (possibly scaled)
    values and are restored to the raw ones by :meth:`reset`.
    """

    def __len__(self) -> int:
        return len(self.data)

    @cached_property
    def _Y(self) -> np.ndarray:
        """the raw targets of the dataset"""
        return np.array([d.y for d in self.data], float)

    @property
    def Y(self) -> np.ndarray:
        """the (scaled) targets of the dataset"""
        return self.__Y

    @Y.setter
    def Y(self, Y: ArrayLike):
        self._validate_attribute(Y, "targets")

        self.__Y = np.array(Y, float)

    @cached_property
    def _X_d(self) -> np.ndarray:
        """the raw extra descriptors of the dataset"""
        return np.array([d.x_d for d in self.data])

    @property
    def X_d(self) -> np.ndarray:
        """the (scaled) extra descriptors of the dataset"""
        return self.__X_d

    @X_d.setter
    def X_d(self, X_d: ArrayLike):
        self._validate_attribute(X_d, "extra descriptors")

        self.__X_d = np.array(X_d)

    @property
    def weights(self) -> np.ndarray:
        """the per-datapoint weights (rebuilt from ``data`` on every access)"""
        return np.array([d.weight for d in self.data])

    @property
    def gt_mask(self) -> np.ndarray:
        """the per-datapoint ``gt_mask`` values (rebuilt from ``data`` on every access)"""
        return np.array([d.gt_mask for d in self.data])

    @property
    def lt_mask(self) -> np.ndarray:
        """the per-datapoint ``lt_mask`` values (rebuilt from ``data`` on every access)"""
        return np.array([d.lt_mask for d in self.data])

    @property
    def t(self) -> int | None:
        """the ``t`` attribute of the first datapoint, or ``None`` for an empty dataset"""
        return self.data[0].t if len(self.data) > 0 else None

    @property
    def d_xd(self) -> int:
        """the extra molecule descriptor dimension, if any"""
        X_d = self.X_d
        # an empty dataset (or one without descriptors) has no descriptor dimension; guard the
        # row-0 lookup so empty datasets don't raise IndexError
        if len(X_d) == 0 or X_d[0] is None:
            return 0

        return X_d.shape[1]

    @property
    def names(self) -> list[str]:
        """the ``name`` of each datapoint"""
        return [d.name for d in self.data]

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        """Normalizes the targets of this dataset using a :obj:`StandardScaler`

        The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
        each task independently. NOTE: This should only be used for regression datasets.

        Parameters
        ----------
        scaler : StandardScaler | None, default=None
            a pre-fit scaler to apply. If ``None``, a new scaler is fit to the raw targets.

        Returns
        -------
        StandardScaler
            a scaler fit to the targets.
        """
        if scaler is None:
            scaler = StandardScaler().fit(self._Y)

        # always transform the *raw* targets so repeated calls don't compound the scaling
        self.Y = scaler.transform(self._Y)

        return scaler

    def normalize_inputs(
        self, key: str = "X_d", scaler: StandardScaler | None = None
    ) -> StandardScaler | None:
        """Normalizes the extra descriptors of this dataset using a :obj:`StandardScaler`

        Parameters
        ----------
        key : str, default="X_d"
            the input to normalize. Only ``"X_d"`` is supported.
        scaler : StandardScaler | None, default=None
            a pre-fit scaler to apply. If ``None``, a new scaler is fit to the raw descriptors.

        Returns
        -------
        StandardScaler | None
            the scaler used, or the input ``scaler`` unchanged (possibly ``None``) when the
            dataset is empty or has no extra descriptors.

        Raises
        ------
        ValueError
            if ``key`` is not a supported input key
        """
        VALID_KEYS = {"X_d"}
        if key not in VALID_KEYS:
            raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")

        X = self._X_d
        # nothing to scale: empty dataset or datapoints without extra descriptors
        if len(X) == 0 or X[0] is None:
            return scaler

        if scaler is None:
            scaler = StandardScaler().fit(X)

        self.X_d = scaler.transform(X)

        return scaler

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        self.__Y = self._Y
        self.__X_d = self._X_d

    def _validate_attribute(self, X: np.ndarray, label: str):
        """Raise a :class:`ValueError` unless ``X`` has one entry per datapoint."""
        if not len(self.data) == len(X):
            raise ValueError(
                f"number of molecules ({len(self.data)}) and {label} ({len(X)}) "
                "must have same length!"
            )
@dataclass
class MoleculeDataset(_MolGraphDatasetMixin, MolGraphDataset):
    r"""A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s

    A :class:`MoleculeDataset` produces featurized data for input to a
    :class:`MPNN` model. Typically, data featurization is performed on-the-fly
    and parallelized across multiple workers via the
    :class:`~torch.utils.data.DataLoader` class. However, for small datasets, it may be more
    efficient to featurize the data in advance and cache the results. This can be done by
    setting ``MoleculeDataset.cache=True``.

    Parameters
    ----------
    data : Iterable[MoleculeDatapoint]
        the data from which to create a dataset
    featurizer : MoleculeFeaturizer
        the featurizer with which to generate MolGraphs of the molecules
    n_workers : int, optional
        number of workers to use for cache calculation
    """

    data: list[MoleculeDatapoint]
    featurizer: Featurizer[Mol, MolGraph] = field(default_factory=SimpleMoleculeMolGraphFeaturizer)
    n_workers: int = 0

    def __post_init__(self):
        if self.data is None:
            raise ValueError("Data cannot be None!")

        self.reset()
        self.cache = False

    def __getitem__(self, idx: int) -> Datum:
        d = self.data[idx]
        mg = self.mg_cache[idx]

        return Datum(mg, self.V_ds[idx], self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)

    @property
    def cache(self) -> bool:
        """whether MolGraphs are precomputed (``True``) or built on the fly (``False``)"""
        return self.__cache

    @cache.setter
    def cache(self, cache: bool = False):
        self.__cache = cache
        self._init_cache()

    def _init_cache(self):
        """initialize the cache from the current (possibly scaled) atom and bond features"""
        if self.cache:
            self.mg_cache = MolGraphCache(
                self.mols, self.V_fs, self.E_fs, self.featurizer, n_workers=self.n_workers
            )
        else:
            self.mg_cache = MolGraphCacheOnTheFly(self.mols, self.V_fs, self.E_fs, self.featurizer)

    @property
    def smiles(self) -> list[str]:
        """the SMILES strings associated with the dataset"""
        return [Chem.MolToSmiles(d.mol) for d in self.data]

    @property
    def mols(self) -> list[Chem.Mol]:
        """the molecules associated with the dataset"""
        return [d.mol for d in self.data]

    @property
    def _V_fs(self) -> list[np.ndarray]:
        """the raw atom features of the dataset"""
        return [d.V_f for d in self.data]

    @property
    def V_fs(self) -> list[np.ndarray]:
        """the (scaled) atom features of the dataset"""
        return self.__V_fs

    @V_fs.setter
    def V_fs(self, V_fs: list[np.ndarray]):
        self._validate_attribute(V_fs, "atom features")

        self.__V_fs = V_fs
        # the graph cache is built from V_fs, so it must be rebuilt
        self._init_cache()

    @property
    def _E_fs(self) -> list[np.ndarray]:
        """the raw bond features of the dataset"""
        return [d.E_f for d in self.data]

    @property
    def E_fs(self) -> list[np.ndarray]:
        """the (scaled) bond features of the dataset"""
        return self.__E_fs

    @E_fs.setter
    def E_fs(self, E_fs: list[np.ndarray]):
        self._validate_attribute(E_fs, "bond features")

        self.__E_fs = E_fs
        # the graph cache is built from E_fs, so it must be rebuilt
        self._init_cache()

    @property
    def _V_ds(self) -> list[np.ndarray]:
        """the raw atom descriptors of the dataset"""
        return [d.V_d for d in self.data]

    @property
    def V_ds(self) -> list[np.ndarray]:
        """the (scaled) atom descriptors of the dataset"""
        return self.__V_ds

    @V_ds.setter
    def V_ds(self, V_ds: list[np.ndarray]):
        self._validate_attribute(V_ds, "atom descriptors")

        self.__V_ds = V_ds

    @property
    def d_vf(self) -> int:
        """the extra atom feature dimension, if any"""
        V_fs = self.V_fs
        return 0 if len(V_fs) == 0 or V_fs[0] is None else V_fs[0].shape[1]

    @property
    def d_ef(self) -> int:
        """the extra bond feature dimension, if any"""
        E_fs = self.E_fs
        return 0 if len(E_fs) == 0 or E_fs[0] is None else E_fs[0].shape[1]

    @property
    def d_vd(self) -> int:
        """the extra atom descriptor dimension, if any"""
        V_ds = self.V_ds
        return 0 if len(V_ds) == 0 or V_ds[0] is None else V_ds[0].shape[1]

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        super().reset()
        self.__V_fs = self._V_fs
        self.__E_fs = self._E_fs
        self.__V_ds = self._V_ds

        # The graph cache was built from the previous V_fs/E_fs (see their setters), so rebuild
        # it from the restored raw features. During __post_init__, reset() runs before the cache
        # flag exists; the cache is then built by the subsequent ``self.cache = False``.
        if hasattr(self, "_MoleculeDataset__cache"):
            self._init_cache()
@dataclass
class CuikmolmakerDataset(MoleculeDataset):
    r"""A :class:`CuikmolmakerDataset` composed of :class:`LazyMoleculeDatapoint`\s and a
    :class:`CuikmolmakerMolGraphFeaturizer`

    A :class:`CuikmolmakerDataset` produces featurized data for a batch of molecules for ingestion
    by a :class:`MPNN` model. Data featurization is always performed on-the-fly and using the
    cuik-molmaker package. This batched processing is significantly faster and consumes less memory
    than the default featurization method when caching is not possible.

    Parameters
    ----------
    data : Iterable[LazyMoleculeDatapoint]
        the data from which to create a dataset
    featurizer : CuikmolmakerMolGraphFeaturizer
        the featurizer with which to generate MolGraphs of the molecules
    """

    data: list[LazyMoleculeDatapoint]
    featurizer: CuikmolmakerMolGraphFeaturizer = field(
        default_factory=CuikmolmakerMolGraphFeaturizer
    )

    @property
    def cache(self) -> bool:
        """always ``False``: this dataset never caches featurized graphs"""
        # Overriding only the setter would inherit a getter that reads a value this class never
        # stores, so reading ``cache`` would raise AttributeError.
        return False

    @cache.setter
    def cache(self, cache: bool = False):
        if cache:
            raise NotImplementedError("CuikmolmakerDataset is meant to be used without caching!")

    def _init_cache(self):
        # no graph cache: __getitem__/__getitems__ featurize directly with self.featurizer
        pass

    @property
    def smiles(self) -> list[str]:
        """the SMILES strings stored on the datapoints"""
        return [d.smiles for d in self.data]

    def __getitem__(self, idx: int) -> Datum:
        d = self.data[idx]
        bmg = self.featurizer([d.smiles], self.V_fs[idx], self.E_fs[idx])
        mg = MolGraph(
            bmg.V.numpy(), bmg.E.numpy(), bmg.edge_index.numpy(), bmg.rev_edge_index.numpy()
        )

        return Datum(mg, self.V_ds[idx], self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)

    def __getitems__(self, indexes: list[int]) -> CuikBatchedDatum:
        """Featurize and collate the data points at ``indexes`` in a single featurizer call."""
        points = [self.data[idx] for idx in indexes]
        smiles_list = [d.smiles for d in points]

        # np.concatenate (rather than the NumPy>=2.0-only np.concat alias) keeps NumPy 1.x working
        V_fs, E_fs, V_ds = self.V_fs, self.E_fs, self.V_ds
        V_f = np.concatenate([V_fs[idx] for idx in indexes]) if V_fs[0] is not None else None
        E_f = np.concatenate([E_fs[idx] for idx in indexes]) if E_fs[0] is not None else None
        bmg = self.featurizer(smiles_list, V_f, E_f)

        V_d = np.concatenate([V_ds[idx] for idx in indexes]) if V_ds[0] is not None else None
        X_d = self.X_d[indexes] if self.X_d[0] is not None else None
        Y = self.Y[indexes] if self.Y[0] is not None else None

        # Gather weights/masks for this batch only: the dataset-wide `weights`, `lt_mask` and
        # `gt_mask` properties rebuild arrays over every datapoint on each access. As before,
        # whether a mask is present is decided by the dataset's first datapoint.
        first = self.data[0]
        weights = np.array([d.weight for d in points])
        lt_mask = np.array([d.lt_mask for d in points]) if first.lt_mask is not None else None
        gt_mask = np.array([d.gt_mask for d in points]) if first.gt_mask is not None else None

        return CuikBatchedDatum(bmg, V_d, X_d, Y, weights, lt_mask, gt_mask)
class MolAtomBondDataset(MoleculeDataset, MolAtomBondGraphDataset):
    r"""A :class:`MoleculeDataset` of :class:`MolAtomBondDatapoint`\s that also exposes bond
    descriptors plus atom- and bond-level targets, masks, and constraints."""

    data: list[MolAtomBondDatapoint]

    def __getitem__(self, idx: int) -> MolAtomBondDatum:
        point = self.data[idx]
        mol_y = self.Y[idx]
        atom_targets = self.atom_Y
        bond_targets = self.bond_Y

        # a non-array molecule-level entry (e.g. when ``Y`` is 1-D) is reported as None
        ys = [
            mol_y if isinstance(mol_y, np.ndarray) else None,
            None if atom_targets is None else atom_targets[idx],
            None if bond_targets is None else bond_targets[idx],
        ]

        return MolAtomBondDatum(
            self.mg_cache[idx],
            self.V_ds[idx],
            self.E_ds[idx],
            self.X_d[idx],
            ys,
            point.weight,
            [point.lt_mask, point.atom_lt_mask, point.bond_lt_mask],
            [point.gt_mask, point.atom_gt_mask, point.bond_gt_mask],
            [point.atom_constraint, point.bond_constraint],
        )

    # ------------------------------------------------------------------ atom-level targets
    @property
    def _atom_Y(self) -> list[np.ndarray]:
        """the raw atom targets of the dataset"""
        return [point.atom_y for point in self.data]

    @property
    def atom_Y(self) -> list[np.ndarray]:
        """the (scaled) atom targets of the dataset"""
        return self.__atom_Y

    @atom_Y.setter
    def atom_Y(self, atom_Y: list[np.ndarray]):
        self._validate_attribute(atom_Y, "atom targets")
        self.__atom_Y = atom_Y

    @cached_property
    def _atom_constraints(self) -> np.ndarray:
        """the raw atom constraints of the dataset"""
        return np.array([point.atom_constraint for point in self.data])

    @property
    def atom_constraints(self) -> np.ndarray:
        """the (scaled) atom constraints of the dataset"""
        return self.__atom_constraints

    @atom_constraints.setter
    def atom_constraints(self, atom_constraints: ArrayLike):
        self._validate_attribute(atom_constraints, "atom constraints")
        self.__atom_constraints = np.array(atom_constraints)

    # ------------------------------------------------------------------ bond-level targets
    @property
    def _bond_Y(self) -> list[np.ndarray]:
        """the raw bond targets of the dataset"""
        return [point.bond_y for point in self.data]

    @property
    def bond_Y(self) -> list[np.ndarray]:
        """the (scaled) bond targets of the dataset"""
        return self.__bond_Y

    @bond_Y.setter
    def bond_Y(self, bond_Y: list[np.ndarray]):
        self._validate_attribute(bond_Y, "bond targets")
        self.__bond_Y = bond_Y

    @cached_property
    def _bond_constraints(self) -> np.ndarray:
        """the raw bond constraints of the dataset"""
        return np.array([point.bond_constraint for point in self.data])

    @property
    def bond_constraints(self) -> np.ndarray:
        """the (scaled) bond constraints of the dataset"""
        return self.__bond_constraints

    @bond_constraints.setter
    def bond_constraints(self, bond_constraints: ArrayLike):
        self._validate_attribute(bond_constraints, "bond constraints")
        self.__bond_constraints = np.array(bond_constraints)

    # ------------------------------------------------------------------ masks
    @staticmethod
    def _stack_attr(points: list, attr: str) -> np.ndarray:
        """vertically stack the attribute ``attr`` of every datapoint in ``points``"""
        return np.vstack([getattr(point, attr) for point in points])

    @property
    def atom_gt_mask(self) -> np.ndarray:
        return self._stack_attr(self.data, "atom_gt_mask")

    @property
    def atom_lt_mask(self) -> np.ndarray:
        return self._stack_attr(self.data, "atom_lt_mask")

    @property
    def bond_gt_mask(self) -> np.ndarray:
        return self._stack_attr(self.data, "bond_gt_mask")

    @property
    def bond_lt_mask(self) -> np.ndarray:
        return self._stack_attr(self.data, "bond_lt_mask")

    # ------------------------------------------------------------------ bond descriptors
    @property
    def _E_ds(self) -> list[np.ndarray]:
        """the raw bond descriptors of the dataset"""
        return [point.E_d for point in self.data]

    @property
    def E_ds(self) -> list[np.ndarray]:
        """the (scaled) bond descriptors of the dataset"""
        return self.__E_ds

    @E_ds.setter
    def E_ds(self, E_ds: list[np.ndarray]):
        self._validate_attribute(E_ds, "bond descriptors")
        self.__E_ds = E_ds

    @property
    def d_ed(self) -> int:
        """the extra bond descriptor dimension, if any"""
        first = self.E_ds[0]
        if first is None:
            return 0
        return first.shape[1]

    # ------------------------------------------------------------------ normalization
    @staticmethod
    def _scale_ragged(targets: list[np.ndarray], scaler: StandardScaler) -> list[np.ndarray]:
        """scale each per-molecule target array, leaving empty arrays untouched"""
        return [scaler.transform(y) if y.size > 0 else y for y in targets]

    @staticmethod
    def _scale_constraint_sums(
        raw: np.ndarray, counts: list[int], scaler: StandardScaler
    ) -> np.ndarray:
        """rescale per-molecule constraint rows that are sums over ``n`` atoms/bonds

        Summing ``n`` standardized values gives ``(sum - n * mean) / scale``.
        """
        return np.array([(row - n * scaler.mean_) / scaler.scale_ for row, n in zip(raw, counts)])

    def normalize_targets(
        self, key: str = "mol", scaler: StandardScaler | None = None
    ) -> StandardScaler:
        """Normalize the molecule-, atom- or bond-level targets with a :obj:`StandardScaler`.

        Parameters
        ----------
        key : str, default="mol"
            which targets to normalize: ``"mol"``, ``"atom"`` or ``"bond"``
        scaler : StandardScaler | None, default=None
            a pre-fit scaler to apply; if ``None``, one is fit to the raw targets selected by ``key``

        Returns
        -------
        StandardScaler
            the scaler used

        Raises
        ------
        ValueError
            if ``key`` is not one of the valid keys
        """
        VALID_KEYS = {"mol", "atom", "bond"}
        if key == "mol":
            X = self._Y
        elif key == "atom":
            X = np.concatenate(self._atom_Y, axis=0)
        elif key == "bond":
            X = np.concatenate(self._bond_Y, axis=0)
        else:
            raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")

        if scaler is None:
            scaler = StandardScaler().fit(X)

        if key == "mol":
            self.Y = scaler.transform(X)
        elif key == "atom":
            self.atom_Y = self._scale_ragged(self._atom_Y, scaler)
            if self.atom_constraints[0] is not None:
                n_atoms = [len(point.atom_y) for point in self.data]
                self.atom_constraints = self._scale_constraint_sums(
                    self._atom_constraints, n_atoms, scaler
                )
        else:  # key == "bond" (validated above)
            self.bond_Y = self._scale_ragged(self._bond_Y, scaler)
            if self.bond_constraints[0] is not None:
                n_bonds = [len(point.bond_y) for point in self.data]
                self.bond_constraints = self._scale_constraint_sums(
                    self._bond_constraints, n_bonds, scaler
                )

        return scaler

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        super().reset()
        self.__E_ds = self._E_ds
        self.__atom_Y, self.__bond_Y = self._atom_Y, self._bond_Y
        self.__atom_constraints = self._atom_constraints
        self.__bond_constraints = self._bond_constraints
@dataclass
class ReactionDataset(_MolGraphDatasetMixin, MolGraphDataset):
    r"""A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s

    .. note::
        The featurized data provided by this class may be cached, similar to a
        :class:`MoleculeDataset`. To enable the cache, set ``ReactionDataset.cache=True``.
    """

    data: list[ReactionDatapoint]
    """the dataset from which to load"""
    featurizer: Featurizer[Rxn, MolGraph] = field(default_factory=CGRFeaturizer)
    """the featurizer with which to generate MolGraphs of the input"""
    n_workers: int = 0
    """number of workers to use for cache calculation"""

    def __post_init__(self):
        if self.data is None:
            raise ValueError("Data cannot be None!")

        self.reset()
        self.cache = False

    @property
    def cache(self) -> bool:
        """whether MolGraphs are precomputed (``True``) or built on the fly (``False``)"""
        return self.__cache

    @cache.setter
    def cache(self, cache: bool = False):
        self.__cache = cache
        size = len(self)

        # reactions carry no extra atom/bond features: pass one placeholder list for each
        if not cache:
            self.mg_cache = MolGraphCacheOnTheFly(
                self.mols, [None] * size, [None] * size, self.featurizer
            )
            return

        self.mg_cache = MolGraphCache(
            self.mols, [None] * size, [None] * size, self.featurizer, n_workers=self.n_workers
        )

    def __getitem__(self, idx: int) -> Datum:
        point = self.data[idx]
        graph = self.mg_cache[idx]

        # reactions have no atom descriptors, hence ``V_d=None``
        return Datum(
            graph, None, self.X_d[idx], self.Y[idx], point.weight, point.lt_mask, point.gt_mask
        )

    @property
    def smiles(self) -> list[tuple]:
        """(reactant SMILES, product SMILES) for each reaction"""
        return [(Chem.MolToSmiles(rct), Chem.MolToSmiles(pdt)) for rct, pdt in self.mols]

    @property
    def mols(self) -> list[Rxn]:
        """(reactant, product) molecule pairs for each reaction"""
        return [(point.rct, point.pdt) for point in self.data]

    @property
    def d_vf(self) -> int:
        """always 0: reactions carry no extra atom features"""
        return 0

    @property
    def d_ef(self) -> int:
        """always 0: reactions carry no extra bond features"""
        return 0

    @property
    def d_vd(self) -> int:
        """always 0: reactions carry no atom descriptors"""
        return 0
@dataclass(repr=False, eq=False)
class MulticomponentDataset(_MolGraphDatasetMixin, Dataset):
    r"""A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel
    :class:`MoleculeDataset`\s and :class:`ReactionDataset`\s

    Item ``i`` is the list of item ``i`` from every component. Target normalization and the
    extra descriptor dimension are delegated to the first component.
    """

    datasets: list[MoleculeDataset | ReactionDataset]
    """the parallel datasets"""

    def __post_init__(self):
        sizes = [len(dset) for dset in self.datasets]
        if not all(sizes[0] == size for size in sizes[1:]):
            raise ValueError(f"Datasets must have all same length! got: {sizes}")

    def __len__(self) -> int:
        return len(self.datasets[0])

    @property
    def n_components(self) -> int:
        """the number of parallel component datasets"""
        return len(self.datasets)

    def __getitem__(self, idx: int) -> list[Datum]:
        return [dset[idx] for dset in self.datasets]

    @property
    def smiles(self) -> list[tuple]:
        """per-datapoint tuples with one SMILES entry per component"""
        return list(zip(*[dset.smiles for dset in self.datasets]))

    @property
    def names(self) -> list[tuple]:
        """per-datapoint tuples with one name per component"""
        return list(zip(*[dset.names for dset in self.datasets]))

    @property
    def mols(self) -> list[tuple]:
        """per-datapoint tuples with one molecule (or reaction) per component"""
        return list(zip(*[dset.mols for dset in self.datasets]))

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        """Normalize the targets of the first component and return the scaler used.

        ``scaler`` is passed by keyword so it binds correctly for components whose
        ``normalize_targets`` takes a leading ``key`` argument (e.g. ``MolAtomBondDataset``).
        """
        return self.datasets[0].normalize_targets(scaler=scaler)

    def reset(self):
        """Reset every component dataset to its initial, unnormalized values."""
        return [dset.reset() for dset in self.datasets]

    @property
    def d_xd(self) -> int:
        """the extra molecule descriptor dimension of the first component"""
        return self.datasets[0].d_xd

    @property
    def d_vf(self) -> int:
        """the total extra atom feature dimension across components"""
        return sum(dset.d_vf for dset in self.datasets)

    @property
    def d_ef(self) -> int:
        """the total extra bond feature dimension across components"""
        return sum(dset.d_ef for dset in self.datasets)

    @property
    def d_vd(self) -> int:
        """the total extra atom descriptor dimension across components"""
        return sum(dset.d_vd for dset in self.datasets)