# Source code for chemprop.data.dataloader

import logging

from torch.utils.data import DataLoader

from chemprop.data.collate import (
    collate_batch,
    collate_cuik_batch,
    collate_mol_atom_bond_batch,
    collate_multicomponent,
)
from chemprop.data.datasets import (
    CuikmolmakerDataset,
    MolAtomBondDataset,
    MoleculeDataset,
    MulticomponentDataset,
    ReactionDataset,
)
from chemprop.data.samplers import ClassBalanceSampler, SeededSampler

# Module-level logger named after this module; build_dataloader uses it to warn
# when the last batch is dropped automatically.
logger = logging.getLogger(__name__)


def build_dataloader(
    dataset: MoleculeDataset
    | CuikmolmakerDataset
    | MolAtomBondDataset
    | ReactionDataset
    | MulticomponentDataset,
    batch_size: int = 64,
    num_workers: int = 0,
    class_balance: bool = False,
    seed: int | None = None,
    shuffle: bool = True,
    drop_last: bool | None = None,
    **kwargs,
):
    r"""Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s

    Parameters
    ----------
    dataset : MoleculeDataset | CuikmolmakerDataset | MolAtomBondDataset | ReactionDataset | MulticomponentDataset
        The dataset containing the molecules or reactions to load.
    batch_size : int, default=64
        the batch size to load.
    num_workers : int, default=0
        the number of workers used to build batches.
    class_balance : bool, default=False
        Whether to perform class balancing (i.e., use an equal number of positive and negative
        molecules). Class balance is only available for single task classification datasets. Set
        shuffle to True in order to get a random subset of the larger class. When enabled, the
        targets are read from ``dataset.Y``.
    seed : int, default=None
        the random seed to use for shuffling (only used when `shuffle` is `True`). It is also
        forwarded to the class-balance sampler when `class_balance` is `True`.
    shuffle : bool, default=True
        whether to shuffle the data during sampling.
    drop_last : bool, default=None
        Whether to drop the last (incomplete) batch. If None, this is set automatically: the last
        batch is dropped only when it would contain a single sample (needed if using batchnorm
        during training), and a warning is logged in that case.
    **kwargs
        Additional keyword arguments forwarded unchanged to
        :obj:`~torch.utils.data.DataLoader`.

    Returns
    -------
    DataLoader
        A dataloader over ``dataset`` using the collate function that matches the dataset type.
    """
    # Pick an explicit sampler only when one is needed; otherwise the DataLoader's own
    # shuffling (or sequential order) is used.
    if class_balance:
        sampler = ClassBalanceSampler(dataset.Y, seed, shuffle)
    elif shuffle and seed is not None:
        sampler = SeededSampler(len(dataset), seed)
    else:
        sampler = None

    # Choose the collate function that matches the dataset type.
    if isinstance(dataset, MulticomponentDataset):
        collate_fn = collate_multicomponent
    elif isinstance(dataset, CuikmolmakerDataset):
        collate_fn = collate_cuik_batch
    elif isinstance(dataset, MolAtomBondDataset):
        collate_fn = collate_mol_atom_bond_batch
    else:
        collate_fn = collate_batch

    if drop_last is None:
        # The size of the final batch is determined by how many indices are drawn per epoch.
        # A sampler may yield a different number of indices than ``len(dataset)``, so prefer
        # its length when it reports one and fall back to the dataset size otherwise.
        n_samples = len(dataset)
        if sampler is not None and hasattr(sampler, "__len__"):
            n_samples = len(sampler)

        drop_last = n_samples % batch_size == 1
        if drop_last:
            logger.warning(
                "Dropping last batch of size 1 to avoid issues with batch normalization "
                "(dataset size = %d, samples per epoch = %d, batch_size = %d)",
                len(dataset),
                n_samples,
                batch_size,
            )

    # ``shuffle`` and ``sampler`` are mutually exclusive in DataLoader, so shuffling is only
    # requested from the DataLoader itself when no explicit sampler was built above.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and shuffle,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=drop_last,
        **kwargs,
    )