Source code for chemprop.featurizers.molgraph.cache

from abc import abstractmethod
from collections.abc import Sequence
from typing import Generic, Iterable

import numpy as np

from chemprop.data.molgraph import MolGraph
from chemprop.featurizers.base import Featurizer, S
from chemprop.utils import parallel_execute


class MolGraphCacheFacade(Sequence[MolGraph], Generic[S]):
    r"""A :class:`MolGraphCacheFacade` provides an interface for caching
    :class:`~chemprop.data.molgraph.MolGraph`\s.

    .. note::
        This class only provides a facade for a cached dataset, but it *does not guarantee*
        whether the underlying data is truly cached.

    Concrete subclasses must implement ``__init__`` as well as the abstract
    :class:`~collections.abc.Sequence` methods ``__len__`` and ``__getitem__``.

    Parameters
    ----------
    inputs : Iterable[S]
        The inputs to be featurized.
    V_fs : Iterable[np.ndarray | None]
        The node features for each input. ``None`` entries are accepted and are handed to
        ``featurizer`` unchanged by the concrete subclasses in this module.
    E_fs : Iterable[np.ndarray | None]
        The edge features for each input. ``None`` entries are accepted and are handed to
        ``featurizer`` unchanged by the concrete subclasses in this module.
    featurizer : Featurizer[S, MolGraph]
        The featurizer with which to generate the :class:`~chemprop.data.molgraph.MolGraph`\s.
    """

    @abstractmethod
    def __init__(
        self,
        inputs: Iterable[S],
        V_fs: Iterable[np.ndarray | None],
        E_fs: Iterable[np.ndarray | None],
        featurizer: Featurizer[S, MolGraph],
    ):
        pass
class MolGraphCache(MolGraphCacheFacade):
    r"""A :class:`MolGraphCache` precomputes the corresponding
    :class:`~chemprop.data.molgraph.MolGraph`\s and caches them in memory.

    Parameters
    ----------
    inputs : Iterable[S]
        The inputs to be featurized.
    V_fs : Iterable[np.ndarray | None]
        The node features for each input (one entry per input).
    E_fs : Iterable[np.ndarray | None]
        The edge features for each input (one entry per input).
    featurizer : Featurizer[S, MolGraph]
        The featurizer called as ``featurizer(input, V_f, E_f)`` for every input.
    n_workers : int, default=0
        Forwarded to :func:`~chemprop.utils.parallel_execute`. Presumably ``0`` means
        "featurize in the current process" -- confirm against ``parallel_execute``.

    Raises
    ------
    ValueError
        If ``inputs``, ``V_fs`` and ``E_fs`` do not all contain the same number of items.
    """

    def __init__(
        self,
        inputs: Iterable[S],
        V_fs: Iterable[np.ndarray | None],
        E_fs: Iterable[np.ndarray | None],
        featurizer: Featurizer[S, MolGraph],
        n_workers: int = 0,
    ):
        inputs = list(inputs)
        V_fs = list(V_fs)
        E_fs = list(E_fs)
        # ``zip`` silently truncates to its shortest argument, which would drop inputs (or
        # features) without any error. Fail fast, before any (possibly expensive, possibly
        # multi-worker) featurization has been started.
        if not len(inputs) == len(V_fs) == len(E_fs):
            raise ValueError(
                "inputs, V_fs and E_fs must have the same length, but got "
                f"{len(inputs)}, {len(V_fs)} and {len(E_fs)}"
            )

        self._mgs = parallel_execute(featurizer, zip(inputs, V_fs, E_fs), n_workers=n_workers)

    def __len__(self) -> int:
        """Return the number of precomputed graphs."""
        return len(self._mgs)

    def __getitem__(self, index: int) -> MolGraph:
        """Return the precomputed graph at ``index`` (no featurization happens here)."""
        return self._mgs[index]
class MolGraphCacheOnTheFly(MolGraphCacheFacade):
    r"""A :class:`MolGraphCacheOnTheFly` computes the corresponding
    :class:`~chemprop.data.molgraph.MolGraph`\s as they are requested.

    Nothing is stored between calls: every ``__getitem__`` invokes ``featurizer`` again.
    Only the inputs and their features are kept (materialized as lists).

    Parameters
    ----------
    inputs : Iterable[S]
        The inputs to be featurized.
    V_fs : Iterable[np.ndarray | None]
        The node features for each input (one entry per input).
    E_fs : Iterable[np.ndarray | None]
        The edge features for each input (one entry per input).
    featurizer : Featurizer[S, MolGraph]
        The featurizer called as ``featurizer(input, V_f, E_f)`` on each access.

    Raises
    ------
    ValueError
        If ``inputs``, ``V_fs`` and ``E_fs`` do not all contain the same number of items.
    """

    def __init__(
        self,
        inputs: Iterable[S],
        V_fs: Iterable[np.ndarray | None],
        E_fs: Iterable[np.ndarray | None],
        featurizer: Featurizer[S, MolGraph],
    ):
        inputs = list(inputs)
        V_fs = list(V_fs)
        E_fs = list(E_fs)
        # ``__len__`` is driven by ``inputs`` alone, so a shorter feature list would only
        # surface as an ``IndexError`` when the offending index is eventually requested, and
        # a longer one would be silently ignored. Validate once, up front.
        if not len(inputs) == len(V_fs) == len(E_fs):
            raise ValueError(
                "inputs, V_fs and E_fs must have the same length, but got "
                f"{len(inputs)}, {len(V_fs)} and {len(E_fs)}"
            )

        self._inputs = inputs
        self._V_fs = V_fs
        self._E_fs = E_fs
        self._featurizer = featurizer

    def __len__(self) -> int:
        """Return the number of inputs."""
        return len(self._inputs)

    def __getitem__(self, index: int) -> MolGraph:
        """Featurize and return the graph for the input at ``index`` (recomputed each call)."""
        return self._featurizer(self._inputs[index], self._V_fs[index], self._E_fs[index])