Source code for chemprop.data.dataloader
import logging
from torch.utils.data import DataLoader
from chemprop.data.collate import (
collate_batch,
collate_cuik_batch,
collate_mol_atom_bond_batch,
collate_multicomponent,
)
from chemprop.data.datasets import (
CuikmolmakerDataset,
MolAtomBondDataset,
MoleculeDataset,
MulticomponentDataset,
ReactionDataset,
)
from chemprop.data.samplers import ClassBalanceSampler, SeededSampler
logger = logging.getLogger(__name__)
[docs]
def build_dataloader(
dataset: MoleculeDataset
| CuikmolmakerDataset
| MolAtomBondDataset
| ReactionDataset
| MulticomponentDataset,
batch_size: int = 64,
num_workers: int = 0,
class_balance: bool = False,
seed: int | None = None,
shuffle: bool = True,
drop_last: bool | None = None,
**kwargs,
):
r"""Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s
Parameters
----------
dataset : MoleculeDataset | ReactionDataset | MulticomponentDataset
The dataset containing the molecules or reactions to load.
batch_size : int, default=64
the batch size to load.
num_workers : int, default=0
the number of workers used to build batches.
class_balance : bool, default=False
Whether to perform class balancing (i.e., use an equal number of positive and negative
molecules). Class balance is only available for single task classification datasets. Set
shuffle to True in order to get a random subset of the larger class.
seed : int, default=None
the random seed to use for shuffling (only used when `shuffle` is `True`).
shuffle : bool, default=True
whether to shuffle the data during sampling.
drop_last : bool, default=None
Whether to drop the last batch if it is of size 1 (needed if using batchnorm during training).
If None, this will be set automatically.
"""
if class_balance:
sampler = ClassBalanceSampler(dataset.Y, seed, shuffle)
elif shuffle and seed is not None:
sampler = SeededSampler(len(dataset), seed)
else:
sampler = None
if isinstance(dataset, MulticomponentDataset):
collate_fn = collate_multicomponent
elif isinstance(dataset, CuikmolmakerDataset):
collate_fn = collate_cuik_batch
elif isinstance(dataset, MolAtomBondDataset):
collate_fn = collate_mol_atom_bond_batch
else:
collate_fn = collate_batch
if drop_last is None:
if len(dataset) % batch_size == 1:
logger.warning(
f"Dropping last batch of size 1 to avoid issues with batch normalization \
(dataset size = {len(dataset)}, batch_size = {batch_size})"
)
drop_last = True
else:
drop_last = False
return DataLoader(
dataset,
batch_size,
sampler is None and shuffle,
sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
**kwargs,
)