chemprop.data
=============

.. py:module:: chemprop.data


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/chemprop/data/collate/index
   /autoapi/chemprop/data/dataloader/index
   /autoapi/chemprop/data/datapoints/index
   /autoapi/chemprop/data/datasets/index
   /autoapi/chemprop/data/molgraph/index
   /autoapi/chemprop/data/samplers/index
   /autoapi/chemprop/data/splitting/index


Attributes
----------

.. autoapisummary::

   chemprop.data.MolGraphDataset


Classes
-------

.. autoapisummary::

   chemprop.data.BatchMolAtomBondGraph
   chemprop.data.BatchMolGraph
   chemprop.data.MolAtomBondTrainingBatch
   chemprop.data.MulticomponentTrainingBatch
   chemprop.data.TrainingBatch
   chemprop.data.LazyMoleculeDatapoint
   chemprop.data.MolAtomBondDatapoint
   chemprop.data.MoleculeDatapoint
   chemprop.data.ReactionDatapoint
   chemprop.data.CuikmolmakerDataset
   chemprop.data.Datum
   chemprop.data.MolAtomBondDataset
   chemprop.data.MolAtomBondDatum
   chemprop.data.MoleculeDataset
   chemprop.data.MulticomponentDataset
   chemprop.data.ReactionDataset
   chemprop.data.MolGraph
   chemprop.data.ClassBalanceSampler
   chemprop.data.SeededSampler
   chemprop.data.SplitType


Functions
---------

.. autoapisummary::

   chemprop.data.collate_batch
   chemprop.data.collate_mol_atom_bond_batch
   chemprop.data.collate_multicomponent
   chemprop.data.build_dataloader
   chemprop.data.make_split_indices
   chemprop.data.split_data_by_indices


Package Contents
----------------

.. py:class:: BatchMolAtomBondGraph

   Bases: :py:obj:`BatchMolGraph`


   A :class:`BatchMolAtomBondGraph` is a :class:`BatchMolGraph` of individual :class:`MolGraph`\s
   that also records which :class:`MolGraph` each bond belongs to.

   It has all the attributes of a :class:`BatchMolGraph` with the addition of the ``bond_batch``
   attribute. This class is intended for use with data loading, so it uses
   :obj:`~torch.Tensor`\s to store data.


   .. py:attribute:: bond_batch
      :type:  torch.Tensor

      A tensor of indices showing which :class:`MolGraph` in the batch each bond belongs to


   .. py:method:: __post_init__(mgs)


   .. py:method:: to(device)


.. py:class:: BatchMolGraph

   A :class:`BatchMolGraph` represents a batch of individual :class:`MolGraph`\s.

   It has all the attributes of a ``MolGraph`` with the addition of the ``batch`` attribute. This
   class is intended for use with data loading, so it uses :obj:`~torch.Tensor`\s to store data.

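   The batching scheme can be illustrated with a pure-Python sketch. It is not chemprop's
   implementation: the helper ``batch_graphs`` and the ``ToyGraph`` tuple are hypothetical, and
   the sketch assumes that atom indices in each graph's ``edge_index`` are shifted by the number
   of atoms in the preceding graphs, that ``rev_edge_index`` is shifted by the number of
   preceding edges, and that ``batch`` tags every atom with the position of its parent graph.

   .. code-block:: python

      from typing import NamedTuple


      class ToyGraph(NamedTuple):
          """Hypothetical stand-in for a MolGraph (feature matrices omitted)."""
          n_atoms: int
          edge_index: list[list[int]]  # 2 x E: source row, then destination row
          rev_edge_index: list[int]    # length E


      def batch_graphs(graphs: list[ToyGraph]):
          """Concatenate graphs, offsetting atom and edge indices (hypothetical helper)."""
          src, dst, rev, batch = [], [], [], []
          atom_offset = edge_offset = 0
          for i, g in enumerate(graphs):
              # shift atom indices so they point into the concatenated atom matrix
              src += [u + atom_offset for u in g.edge_index[0]]
              dst += [v + atom_offset for v in g.edge_index[1]]
              # shift reverse-edge indices so they point into the concatenated edge list
              rev += [e + edge_offset for e in g.rev_edge_index]
              # every atom of graph i is tagged with i
              batch += [i] * g.n_atoms
              atom_offset += g.n_atoms
              edge_offset += len(g.edge_index[0])
          return [src, dst], rev, batch


      # a 3-atom chain (2 bonds -> 4 directed edges) and a 2-atom graph (1 bond -> 2 edges)
      g1 = ToyGraph(3, [[0, 1, 1, 2], [1, 0, 2, 1]], [1, 0, 3, 2])
      g2 = ToyGraph(2, [[0, 1], [1, 0]], [1, 0])
      edge_index, rev_edge_index, batch = batch_graphs([g1, g2])
      print(batch)  # [0, 0, 0, 1, 1]

   Keeping the offsets consistent is what lets the concatenated ``edge_index`` and
   ``rev_edge_index`` index directly into the stacked atom and edge feature matrices.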

   .. py:attribute:: mgs
      :type:  dataclasses.InitVar[Sequence[chemprop.data.molgraph.MolGraph]]

      A list of individual :class:`MolGraph`\s to be batched together


   .. py:attribute:: V
      :type:  torch.Tensor

      the atom feature matrix


   .. py:attribute:: E
      :type:  torch.Tensor

      the bond feature matrix


   .. py:attribute:: edge_index
      :type:  torch.Tensor

      a tensor of shape ``2 x E`` containing the edges of the graph in COO format


   .. py:attribute:: rev_edge_index
      :type:  torch.Tensor

      A tensor of shape ``E`` that maps each edge index to the index of its reverse edge in the
      ``edge_index`` attribute.


   .. py:attribute:: batch
      :type:  torch.Tensor

      a tensor giving, for each atom in the batched graph, the index of its parent :class:`MolGraph`


   .. py:method:: __post_init__(mgs)


   .. py:method:: __len__()

      the number of individual :class:`MolGraph`\s in this batch



   .. py:method:: to(device)


.. py:class:: MolAtomBondTrainingBatch

   Bases: :py:obj:`NamedTuple`


   .. py:attribute:: bmg
      :type:  BatchMolAtomBondGraph


   .. py:attribute:: V_d
      :type:  torch.Tensor | None


   .. py:attribute:: E_d
      :type:  torch.Tensor | None


   .. py:attribute:: X_d
      :type:  torch.Tensor | None


   .. py:attribute:: Ys
      :type:  tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]


   .. py:attribute:: w
      :type:  tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]


   .. py:attribute:: lt_masks
      :type:  tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]


   .. py:attribute:: gt_masks
      :type:  tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]


   .. py:attribute:: constraints
      :type:  tuple[torch.Tensor | None, torch.Tensor | None]


.. py:class:: MulticomponentTrainingBatch

   Bases: :py:obj:`NamedTuple`


   .. py:attribute:: bmgs
      :type:  list[BatchMolGraph]


   .. py:attribute:: V_ds
      :type:  list[torch.Tensor | None]


   .. py:attribute:: X_d
      :type:  torch.Tensor | None


   .. py:attribute:: Y
      :type:  torch.Tensor | None


   .. py:attribute:: w
      :type:  torch.Tensor


   .. py:attribute:: lt_mask
      :type:  torch.Tensor | None


   .. py:attribute:: gt_mask
      :type:  torch.Tensor | None


.. py:class:: TrainingBatch

   Bases: :py:obj:`NamedTuple`


   .. py:attribute:: bmg
      :type:  BatchMolGraph | chemprop.featurizers.molgraph.molecule.BatchCuikMolGraph


   .. py:attribute:: V_d
      :type:  torch.Tensor | None


   .. py:attribute:: X_d
      :type:  torch.Tensor | None


   .. py:attribute:: Y
      :type:  torch.Tensor | None


   .. py:attribute:: w
      :type:  torch.Tensor


   .. py:attribute:: lt_mask
      :type:  torch.Tensor | None


   .. py:attribute:: gt_mask
      :type:  torch.Tensor | None


.. py:function:: collate_batch(batch)

.. py:function:: collate_mol_atom_bond_batch(batch)

.. py:function:: collate_multicomponent(batches)

.. py:function:: build_dataloader(dataset, batch_size = 64, num_workers = 0, class_balance = False, seed = None, shuffle = True, drop_last = None, **kwargs)

   Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s.

   :param dataset: The dataset containing the molecules or reactions to load.
   :type dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset
   :param batch_size: the batch size to load.
   :type batch_size: int, default=64
   :param num_workers: the number of workers used to build batches.
   :type num_workers: int, default=0
   :param class_balance: Whether to perform class balancing (i.e., use an equal number of positive and negative
                         molecules). Class balance is only available for single task classification datasets. Set
                         shuffle to True in order to get a random subset of the larger class.
   :type class_balance: bool, default=False
   :param seed: the random seed to use for shuffling (only used when ``shuffle`` is ``True``).
   :type seed: int | None, default=None
   :param shuffle: whether to shuffle the data during sampling.
   :type shuffle: bool, default=True
   :param drop_last: Whether to drop the last batch if it contains a single sample (needed when
                     using batch normalization during training). If ``None``, this is set
                     automatically.
   :type drop_last: bool | None, default=None

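   The ``drop_last`` option exists because batch normalization cannot compute batch statistics
   from a single sample in training mode. The hypothetical helpers below only compute the batch
   sizes a loader would yield for ``n`` samples and whether the final batch would have size 1;
   how :func:`build_dataloader` decides ``drop_last`` when it is ``None`` is not shown here.

   .. code-block:: python

      def batch_sizes(n: int, batch_size: int, drop_last: bool = False) -> list[int]:
          """Sizes of the batches yielded for n samples (hypothetical helper)."""
          full, rest = divmod(n, batch_size)
          sizes = [batch_size] * full
          if rest and not drop_last:
              sizes.append(rest)  # the incomplete final batch
          return sizes


      def last_batch_is_singleton(n: int, batch_size: int) -> bool:
          """True if, without drop_last, the final batch would contain exactly one sample."""
          return n % batch_size == 1


      print(batch_sizes(129, 64))                  # [64, 64, 1]
      print(last_batch_is_singleton(129, 64))      # True
      print(batch_sizes(129, 64, drop_last=True))  # [64, 64]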

.. py:class:: LazyMoleculeDatapoint

   Bases: :py:obj:`_DatapointMixin`, :py:obj:`_LazyMoleculeDatapointMixin`


   A :class:`LazyMoleculeDatapoint` contains a single SMILES string and all the attributes needed
   to form an ``rdkit.Chem.Mol`` object. The molecule is computed lazily when the ``mol``
   attribute is accessed.

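   The lazy pattern can be sketched with a plain Python property: the expensive object is built
   only when ``mol`` is read, not when the datapoint is constructed. ``ToyLazyDatapoint`` and
   ``parse_smiles`` are hypothetical stand-ins (the parser returns a small dict instead of
   calling RDKit). Whether :class:`LazyMoleculeDatapoint` caches the molecule after the first
   access is not stated here, so the sketch rebuilds it on every access.

   .. code-block:: python

      from dataclasses import dataclass

      calls = 0


      def parse_smiles(smi: str) -> dict:
          """Hypothetical stand-in for an RDKit parse; records how often it runs."""
          global calls
          calls += 1
          return {"smiles": smi, "n_chars": len(smi)}


      @dataclass
      class ToyLazyDatapoint:
          smi: str

          @property
          def mol(self) -> dict:
              # built on demand from the stored SMILES string
              return parse_smiles(self.smi)


      dp = ToyLazyDatapoint("CCO")
      assert calls == 0  # constructing the datapoint does not parse anything
      _ = dp.mol
      assert calls == 1  # parsing happens on access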

   .. py:attribute:: V_f
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``V x d_vf``, where ``V`` is the number of atoms in the molecule, and
      ``d_vf`` is the number of additional features that will be concatenated to atom-level features
      *before* message passing


   .. py:attribute:: E_f
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``E x d_ef``, where ``E`` is the number of bonds in the molecule, and
      ``d_ef`` is the number of additional features that will be concatenated to bond-level
      features *before* message passing


   .. py:attribute:: V_d
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``V x d_vd``, where ``V`` is the number of atoms in the molecule, and
      ``d_vd`` is the number of additional descriptors that will be concatenated to atom-level
      descriptors *after* message passing


   .. py:method:: __post_init__()


   .. py:method:: __len__()


.. py:class:: MolAtomBondDatapoint

   Bases: :py:obj:`MoleculeDatapoint`


   A :class:`MolAtomBondDatapoint` is a :class:`MoleculeDatapoint` that can additionally carry
   atom- and bond-level targets, descriptors, and constraints.


   .. py:attribute:: E_d
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``E x d_ed``, where ``E`` is the number of bonds in the molecule, and
      ``d_ed`` is the number of additional descriptors that will be concatenated to bond-level
      descriptors *after* message passing


   .. py:attribute:: atom_y
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``V x v_t``, where ``V`` is the number of atoms in the molecule, and
      ``v_t`` is the number of atom targets. The order of atoms in the array should match the order of
      atoms in the mol. Unknown targets are indicated by ``nan``.


   .. py:attribute:: atom_gt_mask
      :type:  numpy.ndarray | None
      :value: None


      Indicates whether each atom target is an inequality regression target of the form ``>x``


   .. py:attribute:: atom_lt_mask
      :type:  numpy.ndarray | None
      :value: None


      Indicates whether each atom target is an inequality regression target of the form ``<x``


   .. py:attribute:: bond_y
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``E x e_t``, where ``E`` is the number of bonds in the molecule, and
      ``e_t`` is the number of bond targets. The order of bonds in the array should match the order of
      bonds in the mol. Unknown targets are indicated by ``nan``.


   .. py:attribute:: bond_gt_mask
      :type:  numpy.ndarray | None
      :value: None


      Indicates whether each bond target is an inequality regression target of the form ``>x``


   .. py:attribute:: bond_lt_mask
      :type:  numpy.ndarray | None
      :value: None


      Indicates whether each bond target is an inequality regression target of the form ``<x``


   .. py:attribute:: atom_constraint
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``1 x v_t`` containing the values that the atom property predictions
      should be constrained to sum to, with ``np.nan`` indicating no constraint for that property


   .. py:attribute:: bond_constraint
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``1 x e_t`` containing the values that the bond property predictions
      should be constrained to sum to, with ``np.nan`` indicating no constraint for that property

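   To illustrate the constraint semantics, the hypothetical helper below compares per-atom
   predictions (``V x v_t``) against an ``atom_constraint`` row (``1 x v_t``, given here as a flat
   list): for each property with a finite constraint it reports how far the column sum is from
   the target, and it skips properties whose constraint is ``nan``. It only checks a constraint;
   it does not show how chemprop enforces one.

   .. code-block:: python

      from __future__ import annotations

      import math


      def constraint_residuals(preds: list[list[float]], constraint: list[float]) -> list[float | None]:
          """Column sum minus constraint for each property; None where unconstrained (nan)."""
          residuals: list[float | None] = []
          for j, target in enumerate(constraint):
              if math.isnan(target):
                  residuals.append(None)  # nan means: no constraint on property j
              else:
                  col_sum = sum(row[j] for row in preds)
                  residuals.append(col_sum - target)
          return residuals


      # 3 atoms, 2 atom-level properties; only the first is constrained (to sum to 0)
      preds = [[0.25, 1.0], [-0.5, 2.0], [0.25, 3.0]]
      print(constraint_residuals(preds, [0.0, float("nan")]))  # [0.0, None]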

   .. py:method:: __post_init__()


   .. py:method:: from_smi(smi, *args, keep_h = False, add_h = False, ignore_stereo = False, reorder_atoms = True, **kwargs)
      :classmethod:



.. py:class:: MoleculeDatapoint

   Bases: :py:obj:`_DatapointMixin`, :py:obj:`_MoleculeDatapointMixin`


   A :class:`MoleculeDatapoint` contains a single molecule and its associated features and targets.


   .. py:attribute:: V_f
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``V x d_vf``, where ``V`` is the number of atoms in the molecule, and
      ``d_vf`` is the number of additional features that will be concatenated to atom-level features
      *before* message passing


   .. py:attribute:: E_f
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``E x d_ef``, where ``E`` is the number of bonds in the molecule, and
      ``d_ef`` is the number of additional features that will be concatenated to bond-level
      features *before* message passing


   .. py:attribute:: V_d
      :type:  numpy.ndarray | None
      :value: None


      A numpy array of shape ``V x d_vd``, where ``V`` is the number of atoms in the molecule, and
      ``d_vd`` is the number of additional descriptors that will be concatenated to atom-level
      descriptors *after* message passing


   .. py:method:: __post_init__()


   .. py:method:: __len__()


.. py:class:: ReactionDatapoint

   Bases: :py:obj:`_DatapointMixin`, :py:obj:`_ReactionDatapointMixin`


   A :class:`ReactionDatapoint` contains a single reaction and its associated features and targets.


   .. py:method:: __post_init__()


   .. py:method:: __len__()


.. py:class:: CuikmolmakerDataset

   Bases: :py:obj:`MoleculeDataset`


   A :class:`CuikmolmakerDataset` composed of :class:`LazyMoleculeDatapoint`\s and a
   :class:`CuikmolmakerMolGraphFeaturizer`.

   A :class:`CuikmolmakerDataset` produces featurized data for a batch of molecules for ingestion
   by a :class:`MPNN` model. Data featurization is always performed on the fly using the
   cuik-molmaker package. This batched processing is significantly faster and consumes less memory
   than the default featurization method when caching is not possible.

   :param data: the data from which to create a dataset
   :type data: Iterable[LazyMoleculeDatapoint]
   :param featurizer: the featurizer with which to generate MolGraphs of the molecules
   :type featurizer: CuikmolmakerMolGraphFeaturizer


   .. py:attribute:: data
      :type:  list[chemprop.data.datapoints.LazyMoleculeDatapoint]


   .. py:attribute:: featurizer
      :type:  chemprop.featurizers.molgraph.CuikmolmakerMolGraphFeaturizer


   .. py:property:: smiles
      :type: list[str]


      the SMILES strings associated with the dataset


   .. py:method:: __getitem__(idx)


   .. py:method:: __getitems__(indexes)


.. py:class:: Datum

   Bases: :py:obj:`NamedTuple`


   a single training data point


   .. py:attribute:: mg
      :type:  chemprop.data.molgraph.MolGraph


   .. py:attribute:: V_d
      :type:  numpy.ndarray | None


   .. py:attribute:: x_d
      :type:  numpy.ndarray | None


   .. py:attribute:: y
      :type:  numpy.ndarray | None


   .. py:attribute:: weight
      :type:  float


   .. py:attribute:: lt_mask
      :type:  numpy.ndarray | None


   .. py:attribute:: gt_mask
      :type:  numpy.ndarray | None


.. py:class:: MolAtomBondDataset

   Bases: :py:obj:`MoleculeDataset`, :py:obj:`MolAtomBondGraphDataset`


   A :class:`MolAtomBondDataset` composed of :class:`MolAtomBondDatapoint`\s

   A :class:`MolAtomBondDataset` produces featurized data, including atom- and bond-level targets,
   for input to a :class:`MPNN` model. Typically, data featurization is performed on-the-fly
   and parallelized across multiple workers via the :class:`~torch.utils.data.DataLoader`
   class. However, for small datasets, it may be more efficient to featurize the data in
   advance and cache the results. This can be done by setting ``MolAtomBondDataset.cache=True``.

   :param data: the data from which to create a dataset
   :type data: Iterable[MolAtomBondDatapoint]
   :param featurizer: the featurizer with which to generate MolGraphs of the molecules
   :type featurizer: MoleculeFeaturizer
   :param n_workers: number of workers to use for cache calculation
   :type n_workers: int, optional


   .. py:attribute:: data
      :type:  list[chemprop.data.datapoints.MolAtomBondDatapoint]


   .. py:method:: __getitem__(idx)


   .. py:property:: atom_Y
      :type: list[numpy.ndarray]


      the (scaled) atom targets of the dataset


   .. py:property:: atom_constraints
      :type: numpy.ndarray



   .. py:property:: bond_Y
      :type: list[numpy.ndarray]


      the (scaled) bond targets of the dataset


   .. py:property:: bond_constraints
      :type: numpy.ndarray



   .. py:property:: atom_gt_mask
      :type: numpy.ndarray



   .. py:property:: atom_lt_mask
      :type: numpy.ndarray



   .. py:property:: bond_gt_mask
      :type: numpy.ndarray



   .. py:property:: bond_lt_mask
      :type: numpy.ndarray



   .. py:property:: E_ds
      :type: list[numpy.ndarray]


      the (scaled) bond descriptors of the dataset


   .. py:property:: d_ed
      :type: int


      the extra bond descriptor dimension, if any


   .. py:method:: normalize_targets(key = 'mol', scaler = None)

      Normalizes the targets of this dataset using a :obj:`StandardScaler`

      The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
      each task independently. NOTE: This should only be used for regression datasets.

      :returns: a scaler fit to the targets.
      :rtype: StandardScaler

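      The per-task standardization can be sketched in pure Python. The sketch mirrors what a
      scikit-learn-style ``StandardScaler`` computes (each column's mean and population standard
      deviation, with a scale of 1 for constant columns) but it is not the scaler chemprop
      returns; skipping ``nan`` targets when fitting is an assumption of the sketch.

      .. code-block:: python

         import math
         import statistics


         def fit_standard_scaler(Y: list[list[float]]) -> tuple[list[float], list[float]]:
             """Per-column mean and population std, ignoring nan entries."""
             means, stds = [], []
             for col in zip(*Y):
                 vals = [y for y in col if not math.isnan(y)]
                 mu = statistics.fmean(vals)
                 sd = statistics.pstdev(vals, mu)
                 means.append(mu)
                 stds.append(sd if sd > 0 else 1.0)  # avoid dividing by zero for constant tasks
             return means, stds


         def transform(Y, means, stds):
             """Subtract each task's mean and divide by its standard deviation."""
             return [[(y - m) / s for y, m, s in zip(row, means, stds)] for row in Y]


         Y = [[1.0, 10.0], [3.0, float("nan")], [5.0, 30.0]]
         means, stds = fit_standard_scaler(Y)
         Y_scaled = transform(Y, means, stds)  # nan targets stay nan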


   .. py:method:: normalize_inputs(key = 'X_d', scaler = None)


   .. py:method:: reset()

      Reset the atom and bond features; atom and extra descriptors; and targets of each
      datapoint to their initial, unnormalized values.



.. py:class:: MolAtomBondDatum

   Bases: :py:obj:`NamedTuple`


   a single training data point that supports atom- and bond-level targets


   .. py:attribute:: mg
      :type:  chemprop.data.molgraph.MolGraph


   .. py:attribute:: V_d
      :type:  numpy.ndarray | None


   .. py:attribute:: E_d
      :type:  numpy.ndarray | None


   .. py:attribute:: x_d
      :type:  numpy.ndarray | None


   .. py:attribute:: ys
      :type:  tuple[numpy.ndarray | None, numpy.ndarray | None, numpy.ndarray | None]


   .. py:attribute:: weight
      :type:  float


   .. py:attribute:: lt_masks
      :type:  tuple[numpy.ndarray | None, numpy.ndarray | None, numpy.ndarray | None]


   .. py:attribute:: gt_masks
      :type:  tuple[numpy.ndarray | None, numpy.ndarray | None, numpy.ndarray | None]


   .. py:attribute:: constraints
      :type:  tuple[numpy.ndarray | None, numpy.ndarray | None]


.. py:class:: MoleculeDataset

   Bases: :py:obj:`_MolGraphDatasetMixin`, :py:obj:`MolGraphDataset`


   A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s

   A :class:`MoleculeDataset` produces featurized data for input to a
   :class:`MPNN` model. Typically, data featurization is performed on-the-fly
   and parallelized across multiple workers via the :class:`~torch.utils.data.DataLoader`
   class. However, for small datasets, it may be more efficient to
   featurize the data in advance and cache the results. This can be done by
   setting ``MoleculeDataset.cache=True``.

   :param data: the data from which to create a dataset
   :type data: Iterable[MoleculeDatapoint]
   :param featurizer: the featurizer with which to generate MolGraphs of the molecules
   :type featurizer: MoleculeFeaturizer
   :param n_workers: number of workers to use for cache calculation
   :type n_workers: int, optional

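   The trade-off between on-the-fly featurization and caching can be sketched with a toy dataset
   that counts featurizer calls. ``ToyDataset`` and ``featurize`` are hypothetical (character codes
   stand in for a :class:`MolGraph`); the sketch only shows that once the cache is enabled,
   repeated indexing reuses precomputed results instead of featurizing again.

   .. code-block:: python

      from __future__ import annotations

      from dataclasses import dataclass, field

      n_calls = 0


      def featurize(smi: str) -> tuple[int, ...]:
          """Hypothetical featurizer: character codes stand in for a MolGraph."""
          global n_calls
          n_calls += 1
          return tuple(ord(c) for c in smi)


      @dataclass
      class ToyDataset:
          smiles: list[str]
          _cache: list | None = field(default=None, repr=False)

          @property
          def cache(self) -> bool:
              return self._cache is not None

          @cache.setter
          def cache(self, value: bool) -> None:
              # featurize everything up front when enabled; drop results when disabled
              self._cache = [featurize(s) for s in self.smiles] if value else None

          def __getitem__(self, idx: int):
              if self._cache is not None:
                  return self._cache[idx]
              return featurize(self.smiles[idx])  # on-the-fly featurization


      ds = ToyDataset(["CCO", "c1ccccc1"])
      first = ds[0]
      again = ds[0]            # featurized a second time: no cache yet
      ds.cache = True          # featurizes every datapoint once, up front
      cached = [ds[0], ds[1]]  # served from the cache, no new featurizer calls
      print(n_calls)           # 4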

   .. py:attribute:: data
      :type:  list[chemprop.data.datapoints.MoleculeDatapoint]


   .. py:attribute:: featurizer
      :type:  chemprop.featurizers.base.Featurizer[rdkit.Chem.Mol, chemprop.data.molgraph.MolGraph]


   .. py:attribute:: n_workers
      :type:  int
      :value: 0



   .. py:method:: __post_init__()


   .. py:method:: __getitem__(idx)


   .. py:property:: cache
      :type: bool



   .. py:property:: smiles
      :type: list[str]


      the SMILES strings associated with the dataset


   .. py:property:: mols
      :type: list[rdkit.Chem.Mol]


      the molecules associated with the dataset


   .. py:property:: V_fs
      :type: list[numpy.ndarray]


      the (scaled) atom features of the dataset


   .. py:property:: E_fs
      :type: list[numpy.ndarray]


      the (scaled) bond features of the dataset


   .. py:property:: V_ds
      :type: list[numpy.ndarray]


      the (scaled) atom descriptors of the dataset


   .. py:property:: d_vf
      :type: int


      the extra atom feature dimension, if any


   .. py:property:: d_ef
      :type: int


      the extra bond feature dimension, if any


   .. py:property:: d_vd
      :type: int


      the extra atom descriptor dimension, if any


   .. py:method:: normalize_inputs(key = 'X_d', scaler = None)


   .. py:method:: reset()

      Reset the atom and bond features; atom and extra descriptors; and targets of each
      datapoint to their initial, unnormalized values.



.. py:type:: MolGraphDataset
   :canonical: Dataset[Datum]


.. py:class:: MulticomponentDataset

   Bases: :py:obj:`_MolGraphDatasetMixin`, :py:obj:`torch.utils.data.Dataset`


   A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel
   :class:`MoleculeDataset`\s and :class:`ReactionDataset`\s

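   Conceptually, the components are indexed in lockstep: datapoint ``i`` is made up of item ``i``
   of every component. The sketch below is a hypothetical stand-in that uses plain lists in place of
   :class:`MoleculeDataset` and :class:`ReactionDataset`; it assumes that all components must
   have the same length and that indexing returns one item per component.

   .. code-block:: python

      from dataclasses import dataclass
      from typing import Any, Sequence


      @dataclass
      class ToyMulticomponentDataset:
          """Hypothetical sketch: row i of every component describes datapoint i."""
          datasets: list[Sequence[Any]]

          def __post_init__(self) -> None:
              # parallel components must line up datapoint for datapoint
              lengths = {len(d) for d in self.datasets}
              if len(lengths) > 1:
                  raise ValueError(f"components have different lengths: {sorted(lengths)}")

          def __len__(self) -> int:
              return len(self.datasets[0])

          @property
          def n_components(self) -> int:
              return len(self.datasets)

          def __getitem__(self, idx: int) -> list[Any]:
              # one item per component, in component order
              return [d[idx] for d in self.datasets]


      solutes = ["CCO", "c1ccccc1"]
      solvents = ["O", "CO"]
      mc = ToyMulticomponentDataset([solutes, solvents])
      print(mc[1])  # ['c1ccccc1', 'CO']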

   .. py:attribute:: datasets
      :type:  list[MoleculeDataset | ReactionDataset]

      the parallel datasets


   .. py:method:: __post_init__()


   .. py:method:: __len__()


   .. py:property:: n_components
      :type: int



   .. py:method:: __getitem__(idx)


   .. py:property:: smiles
      :type: list[list[str]]



   .. py:property:: names
      :type: list[list[str]]



   .. py:property:: mols
      :type: list[list[rdkit.Chem.Mol]]



   .. py:method:: normalize_targets(scaler = None)

      Normalizes the targets of this dataset using a :obj:`StandardScaler`

      The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
      each task independently. NOTE: This should only be used for regression datasets.

      :returns: a scaler fit to the targets.
      :rtype: StandardScaler



   .. py:method:: normalize_inputs(key = 'X_d', scaler = None)


   .. py:method:: reset()

      Reset the atom and bond features; atom and extra descriptors; and targets of each
      datapoint to their initial, unnormalized values.



   .. py:property:: d_xd
      :type: list[int]


      the extra molecule descriptor dimension, if any


   .. py:property:: d_vf
      :type: list[int]



   .. py:property:: d_ef
      :type: list[int]



   .. py:property:: d_vd
      :type: list[int]



.. py:class:: ReactionDataset

   Bases: :py:obj:`_MolGraphDatasetMixin`, :py:obj:`MolGraphDataset`


   A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s

   .. note::
       The featurized data provided by this class may be cached, similar to a
       :class:`MoleculeDataset`. To enable the cache, set
       ``ReactionDataset.cache=True``.


   .. py:attribute:: data
      :type:  list[chemprop.data.datapoints.ReactionDatapoint]

      the dataset from which to load


   .. py:attribute:: featurizer
      :type:  chemprop.featurizers.base.Featurizer[chemprop.types.Rxn, chemprop.data.molgraph.MolGraph]

      the featurizer with which to generate MolGraphs of the input


   .. py:attribute:: n_workers
      :type:  int
      :value: 0


      number of workers to use for cache calculation


   .. py:method:: __post_init__()


   .. py:property:: cache
      :type: bool



   .. py:method:: __getitem__(idx)


   .. py:property:: smiles
      :type: list[tuple]



   .. py:property:: mols
      :type: list[chemprop.types.Rxn]



   .. py:property:: d_vf
      :type: int



   .. py:property:: d_ef
      :type: int



   .. py:property:: d_vd
      :type: int



.. py:class:: MolGraph

   Bases: :py:obj:`NamedTuple`


   A :class:`MolGraph` represents the graph featurization of a molecule.


   .. py:attribute:: V
      :type:  numpy.ndarray

      an array of shape ``V x d_v`` containing the atom features of the molecule


   .. py:attribute:: E
      :type:  numpy.ndarray

      an array of shape ``E x d_e`` containing the bond features of the molecule


   .. py:attribute:: edge_index
      :type:  numpy.ndarray

      an array of shape ``2 x E`` containing the edges of the graph in COO format


   .. py:attribute:: rev_edge_index
      :type:  numpy.ndarray

      An array of shape ``E`` that maps each edge index to the index of its reverse edge in the
      :attr:`edge_index` attribute.

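   Each bond of a molecule contributes two directed edges. Assuming, as this sketch does, that the
   two directions of every bond are stored next to each other (``u -> v`` then ``v -> u``),
   ``rev_edge_index`` simply swaps each pair. The helper ``directed_edges`` is hypothetical, and
   the pairwise storage order is an assumption rather than a documented guarantee of
   :class:`MolGraph`.

   .. code-block:: python

      def directed_edges(bonds: list[tuple[int, int]]):
          """Build a COO edge_index (2 x E) and rev_edge_index from undirected bonds."""
          src, dst = [], []
          for u, v in bonds:
              src += [u, v]  # forward edge u -> v ...
              dst += [v, u]  # ... immediately followed by its reverse v -> u
          # edge 2k and edge 2k + 1 are reverses of each other
          rev = [e ^ 1 for e in range(len(src))]
          return [src, dst], rev


      # a chain 0-1-2: 2 bonds -> E = 4 directed edges
      edge_index, rev_edge_index = directed_edges([(0, 1), (1, 2)])
      print(edge_index)      # [[0, 1, 1, 2], [1, 0, 2, 1]]
      print(rev_edge_index)  # [1, 0, 3, 2]

      # every edge's reverse runs between the same two atoms in the opposite direction
      for e, r in enumerate(rev_edge_index):
          assert edge_index[0][r] == edge_index[1][e] and edge_index[1][r] == edge_index[0][e]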

.. py:class:: ClassBalanceSampler(Y, seed = None, shuffle = False)

   Bases: :py:obj:`torch.utils.data.Sampler`


   A :class:`ClassBalanceSampler` samples data from a :class:`MolGraphDataset` such that
   positive and negative classes are equally sampled

   :param Y: the targets of the dataset from which to sample, with one row per datapoint
   :type Y: numpy.ndarray
   :param seed: the random seed to use for shuffling (only used when ``shuffle`` is ``True``)
   :type seed: int | None, default=None
   :param shuffle: whether to shuffle the data during sampling
   :type shuffle: bool, default=False

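   One way to sample positive and negative classes equally is to alternate between the two index
   pools and stop when the smaller pool is exhausted, giving ``2 * min(n_pos, n_neg)`` indices per
   epoch. The class below is a hypothetical, standard-library sketch of that scheme; treating a
   datapoint as positive when any of its targets is nonzero is an assumption of the sketch.

   .. code-block:: python

      from __future__ import annotations

      import random
      from collections.abc import Iterator


      class ToyClassBalanceSampler:
          """Hypothetical sketch: interleave positive and negative indices."""

          def __init__(self, Y: list[list[int]], seed: int | None = None, shuffle: bool = False):
              self.shuffle = shuffle
              self.rg = random.Random(seed)
              # a datapoint counts as positive if any of its targets is nonzero
              self.pos_idxs = [i for i, y in enumerate(Y) if any(y)]
              self.neg_idxs = [i for i, y in enumerate(Y) if not any(y)]
              self.length = 2 * min(len(self.pos_idxs), len(self.neg_idxs))

          def __iter__(self) -> Iterator[int]:
              if self.shuffle:
                  # shuffling selects a different subset of the larger class each epoch
                  self.rg.shuffle(self.pos_idxs)
                  self.rg.shuffle(self.neg_idxs)
              for pos_i, neg_i in zip(self.pos_idxs, self.neg_idxs):
                  yield pos_i
                  yield neg_i

          def __len__(self) -> int:
              return self.length


      Y = [[1], [0], [0], [1], [0], [0]]
      sampler = ToyClassBalanceSampler(Y)
      print(list(sampler))  # [0, 1, 3, 2]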

   .. py:attribute:: shuffle
      :value: False



   .. py:attribute:: rg


   .. py:attribute:: pos_idxs


   .. py:attribute:: neg_idxs


   .. py:attribute:: length


   .. py:method:: __iter__()

      an iterator over indices to sample.



   .. py:method:: __len__()

      the number of indices that will be sampled.



.. py:class:: SeededSampler(N, seed)

   Bases: :py:obj:`torch.utils.data.Sampler`


   A :class:`SeededSampler` iterates through a dataset in a random order that is reproducible
   for a given seed.

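   A seeded sampler can be sketched with the standard library's ``random.Random``: two samplers
   built with the same ``N`` and ``seed`` yield the same sequence of permutations. This is an
   illustration only; the reshuffle-per-epoch behaviour and the use of ``random.Random`` (rather
   than whatever generator :class:`SeededSampler` uses internally) are assumptions of the sketch.

   .. code-block:: python

      import random
      from collections.abc import Iterator


      class ToySeededSampler:
          """Hypothetical sketch: a reproducible random permutation of range(N) per epoch."""

          def __init__(self, N: int, seed: int):
              self.idxs = list(range(N))
              self.rg = random.Random(seed)

          def __iter__(self) -> Iterator[int]:
              # reshuffle in place; the sequence of epochs is fixed by the seed
              self.rg.shuffle(self.idxs)
              # iterate over a copy so a later reshuffle cannot disturb this epoch
              return iter(list(self.idxs))

          def __len__(self) -> int:
              return len(self.idxs)


      a, b = ToySeededSampler(5, seed=42), ToySeededSampler(5, seed=42)
      assert list(a) == list(b)                   # same seed, same first epoch
      assert sorted(list(a)) == [0, 1, 2, 3, 4]   # every epoch is a permutation of range(N)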

   .. py:attribute:: idxs


   .. py:attribute:: rg


   .. py:method:: __iter__()

      an iterator over indices to sample.



   .. py:method:: __len__()

      the number of indices that will be sampled.



.. py:class:: SplitType

   Bases: :py:obj:`chemprop.utils.utils.EnumMapping`


   An enumeration of the supported split types. Members are also (and must be) strings.


   .. py:attribute:: SCAFFOLD_BALANCED


   .. py:attribute:: RANDOM_WITH_REPEATED_SMILES


   .. py:attribute:: RANDOM


   .. py:attribute:: KENNARD_STONE


   .. py:attribute:: KMEANS


.. py:function:: make_split_indices(mols, split = 'random', sizes = (0.8, 0.1, 0.1), seed = 0, num_replicates = 1, num_folds = None)

   Splits data into training, validation, and test splits.

   :param mols: Sequence of RDKit molecules to use for structure-based splitting, or any object with
                a length equal to the number of datapoints if using random splitting
   :type mols: Sequence[Chem.Mol] | Sized
   :param split: Split type, one of :class:`SplitType`, by default ``"random"``
   :type split: SplitType | str, optional
   :param sizes: 3-tuple with the proportions of data in the train, validation, and test sets, by default
                 ``(0.8, 0.1, 0.1)``. Set the middle value to 0 for a two-way split.
   :type sizes: tuple[float, float, float], optional
   :param seed: The random seed passed to astartes, by default 0
   :type seed: int, optional
   :param num_replicates: Number of replicates, by default 1
   :type num_replicates: int, optional
   :param num_folds: This argument was removed in v2.1; use ``num_replicates`` instead.
   :type num_folds: None, optional

   :returns: 2- or 3-member tuple of lists, each of length ``num_replicates``, containing the
             training, validation, and testing indices of each replicate.

             .. important::
                 Validation may or may not be present
   :rtype: tuple[list[list[int]], ...]

   :raises ValueError: Requested split sizes tuple not of length 3
   :raises ValueError: Unsupported split method requested

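   The shape of the return value can be illustrated with a simple random split. The function below
   is a hypothetical stand-in that uses the standard library instead of astartes and supports only
   random splitting; deriving a different seed per replicate (``seed + rep``) is an assumption of
   the sketch. It shows how ``sizes`` and ``num_replicates`` map onto a tuple of train, validation,
   and test lists, each holding one list of indices per replicate.

   .. code-block:: python

      import random


      def random_split_indices(n: int, sizes=(0.8, 0.1, 0.1), seed: int = 0, num_replicates: int = 1):
          """Hypothetical random splitter returning (train, val, test), each a list per replicate."""
          if len(sizes) != 3:
              raise ValueError("sizes must be a 3-tuple")
          train, val, test = [], [], []
          for rep in range(num_replicates):
              idxs = list(range(n))
              random.Random(seed + rep).shuffle(idxs)  # a different seed per replicate
              n_train = int(sizes[0] * n)
              n_val = int(sizes[1] * n)
              train.append(idxs[:n_train])
              val.append(idxs[n_train:n_train + n_val])
              test.append(idxs[n_train + n_val:])  # the remainder goes to the test set
          return train, val, test


      train, val, test = random_split_indices(10, num_replicates=2)
      print([len(t) for t in train], [len(v) for v in val], [len(t) for t in test])
      # [8, 8] [1, 1] [1, 1]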

.. py:function:: split_data_by_indices(data, train_indices = None, val_indices = None, test_indices = None)

   Splits data into training, validation, and test groups based on the given split indices.

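   Given per-replicate index lists such as those returned by :func:`make_split_indices`, the
   selection step can be sketched as below. The helper ``split_by_indices`` is hypothetical; it
   assumes each ``*_indices`` argument is a list of index lists (one per replicate) and that a
   split passed as ``None`` stays ``None``.

   .. code-block:: python

      def split_by_indices(data, train_indices=None, val_indices=None, test_indices=None):
          """Hypothetical sketch: select data[i] for each index list of each replicate."""

          def take(indices):
              if indices is None:
                  return None  # a missing split stays missing
              return [[data[i] for i in rep] for rep in indices]

          return take(train_indices), take(val_indices), take(test_indices)


      data = ["a", "b", "c", "d"]
      train, val, test = split_by_indices(data, [[0, 1]], None, [[2, 3]])
      print(train, val, test)  # [['a', 'b']] None [['c', 'd']]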

