Dataloaders#

[1]:
from chemprop.data.dataloader import build_dataloader

This is an example dataset to load.

[2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset

smis = ["C" * i for i in range(1, 4)]
ys = np.random.rand(len(smis), 1)
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])

Torch dataloaders#

Chemprop uses native torch.utils.data.DataLoader objects to batch data as input to a model. build_dataloader is a helper function that constructs one.

[3]:
dataloader = build_dataloader(dataset)

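build_dataloader changes the DataLoader defaults (batch_size=1, shuffle=False) to a batch size of 64 with shuffling turned on. It also automatically selects the correct collate function for the dataset (single-component vs. multi-component). Building the equivalent dataloader by hand for a single-component dataset looks like this: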

[4]:
from torch.utils.data import DataLoader
from chemprop.data.collate import collate_batch, collate_multicomponent

dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)

Collate function#

The collate function takes an iterable of dataset outputs and combines them into a single batch. During training, the Lightning Trainer iterates through the batches automatically.

[5]:
collate_batch([dataset[0], dataset[1]])
[5]:
TrainingBatch(bmg=<chemprop.data.collate.BatchMolGraph object at 0x77d2af901710>, V_d=None, X_d=None, Y=tensor([[0.2622],
        [0.0477]]), w=tensor([[1.],
        [1.]]), lt_mask=None, gt_mask=None)
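
To make the batching step concrete, here is a minimal pure-Python sketch of what a collate function does in general. It is not Chemprop's collate_batch: the names ToyDatapoint, ToyBatch, and toy_collate are hypothetical, and a real batch holds a BatchMolGraph and tensors rather than plain lists.

```python
from typing import NamedTuple, Optional, Sequence


class ToyDatapoint(NamedTuple):
    """Hypothetical stand-in for one dataset output (not a Chemprop class)."""
    n_atoms: int
    y: Optional[float]
    weight: float = 1.0


class ToyBatch(NamedTuple):
    """Hypothetical stand-in for a collated batch (not Chemprop's TrainingBatch)."""
    n_atoms: list
    Y: Optional[list]
    w: list


def toy_collate(samples: Sequence[ToyDatapoint]) -> ToyBatch:
    """Transpose a sequence of per-sample records into per-field columns."""
    n_atoms, ys, ws = zip(*samples)
    # A field that is missing in any sample stays None in the batch.
    Y = None if any(y is None for y in ys) else [[y] for y in ys]
    return ToyBatch(n_atoms=list(n_atoms), Y=Y, w=[[w] for w in ws])


batch = toy_collate([ToyDatapoint(1, 0.26), ToyDatapoint(2, 0.05)])
print(batch)
```

The key idea is the transpose: a list of per-sample records becomes one record of per-field columns, which is the same shape change visible in the Y and w fields of the output above.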

Shuffling#

Shuffling the data helps improve model training, so build_dataloader has shuffle=True as the default. Shuffling should be turned off for validation and test dataloaders. Lightning gives a warning if a dataloader with shuffling is used during prediction.

[6]:
train_loader = build_dataloader(dataset)
val_loader = build_dataloader(dataset, shuffle=False)
test_loader = build_dataloader(dataset, shuffle=False)
[7]:
from lightning import pytorch as pl
from chemprop import models, nn

trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)
chemprop_model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
[8]:
preds = trainer.predict(chemprop_model, dataloader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install
"ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')
[9]:
preds = trainer.predict(chemprop_model, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Parallel data loading#

As datapoints are sampled from the dataset, the MolGraph data structures are generated on the fly, which requires featurizing the molecular graphs. Giving the dataloader multiple workers can speed up data loading by preparing datapoints in parallel. Note that this is not compatible with Windows (the process hangs) or with some versions of macOS.

build_dataloader(dataset, num_workers=8)
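
Because multiple workers can hang on some platforms, one option is to choose num_workers from the platform at runtime. The sketch below uses only the standard library. The helper name pick_num_workers and the cap of 8 are assumptions for illustration and are not part of Chemprop, and treating every macOS version as unsupported is a deliberately conservative choice.

```python
import os
import sys


def pick_num_workers(max_workers: int = 8) -> int:
    """Hypothetical helper: 0 workers on Windows and macOS, else up to max_workers."""
    # Fall back to sequential loading where worker processes may be unreliable.
    if sys.platform.startswith("win") or sys.platform == "darwin":
        return 0
    # os.cpu_count() can return None, so default to a single core.
    return min(max_workers, os.cpu_count() or 1)


num_workers = pick_num_workers()
print(num_workers)
```

The result could then be passed along as build_dataloader(dataset, num_workers=num_workers).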

Caching the MolGraphs in the dataset before making the dataloader can also speed up sequential data loading (num_workers=0).

[10]:
dataset.cache = True
build_dataloader(dataset)
[10]:
<torch.utils.data.dataloader.DataLoader at 0x77d3497d9790>
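
Why caching helps can be seen with a toy dataset that counts how often its featurization step runs. This is only a sketch: ToyDataset is a hypothetical class, its featurizer is a placeholder, and it fills the cache lazily on first access, which may differ from how MoleculeDataset populates its cache.

```python
class ToyDataset:
    """Hypothetical dataset that featurizes on access, with an optional cache."""

    def __init__(self, smis):
        self.smis = list(smis)
        self.cache = False
        self._cached = {}
        self.n_featurizations = 0  # counts how often the expensive step runs

    def _featurize(self, smi):
        # Stand-in for building a molecular graph; here it just counts characters.
        self.n_featurizations += 1
        return {"n_atoms": len(smi)}

    def __getitem__(self, idx):
        if not self.cache:
            return self._featurize(self.smis[idx])
        if idx not in self._cached:
            self._cached[idx] = self._featurize(self.smis[idx])
        return self._cached[idx]


toy = ToyDataset(["C", "CC", "CCC"])
toy.cache = True
for _ in range(2):  # two "epochs" over the data
    for i in range(3):
        toy[i]
print(toy.n_featurizations)  # 3 with caching; it would be 6 without
```

With the cache enabled, each datapoint is featurized once no matter how many epochs follow, at the cost of keeping every featurized item in memory.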

Drop last batch#

build_dataloader drops the last batch if it contains a single datapoint, because batch normalization (enabled by default) requires at least two datapoints per batch. To keep that datapoint, adjust the batch size or, if you aren't using batch normalization, build the dataloader manually.

[11]:
dataloader = build_dataloader(dataset, batch_size=2)
Dropping last batch of size 1 to avoid issues with batch normalization     (dataset size = 3, batch_size = 2)
[12]:
dataloader = build_dataloader(dataset, batch_size=3)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
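
The rule described above can be written as a remainder check. The sketch below is an illustration of that rule under the stated description, not a copy of Chemprop's code; the function names drops_last_batch and n_batches are hypothetical.

```python
def drops_last_batch(dataset_size: int, batch_size: int) -> bool:
    """Return True when the final batch would contain exactly one datapoint."""
    return dataset_size % batch_size == 1


def n_batches(dataset_size: int, batch_size: int) -> int:
    """Number of batches per epoch after applying the rule above."""
    full, remainder = divmod(dataset_size, batch_size)
    if remainder == 0 or drops_last_batch(dataset_size, batch_size):
        return full
    return full + 1


print(drops_last_batch(3, 2), n_batches(3, 2))  # True 1
print(drops_last_batch(3, 3), n_batches(3, 3))  # False 1
```

For the three-molecule dataset, batch_size=2 leaves a remainder of one and so loses a datapoint each epoch, while batch_size=3 fits everything into a single batch.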

Samplers#

The default sampler for a torch.utils.data.DataLoader is a torch.utils.data.sampler.SequentialSampler when shuffle=False, or a torch.utils.data.sampler.RandomSampler when shuffle=True.

build_dataloader can be given a seed, in which case it uses a chemprop.data.samplers.SeededSampler for reproducibility. Chemprop also offers chemprop.data.samplers.ClassBalanceSampler, which samples positive and negative classes equally for binary classification tasks.

[13]:
build_dataloader(dataset, seed=0)
[13]:
<torch.utils.data.dataloader.DataLoader at 0x77d1a3202ba0>
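
The idea behind a seeded sampler can be sketched with the standard library alone. ToySeededSampler below is a hypothetical class written for illustration. It is not Chemprop's SeededSampler and may differ from it in detail, but it shows how a private, seeded random number generator makes the shuffle order repeatable across runs.

```python
import random


class ToySeededSampler:
    """Hypothetical sampler: shuffled indices drawn from a private, seeded RNG."""

    def __init__(self, n: int, seed: int):
        self.n = n
        self.rng = random.Random(seed)

    def __iter__(self):
        idxs = list(range(self.n))
        self.rng.shuffle(idxs)  # advances the RNG, so each epoch gets a new order
        return iter(idxs)

    def __len__(self):
        return self.n


a = ToySeededSampler(10, seed=0)
b = ToySeededSampler(10, seed=0)
epoch1_a, epoch2_a = list(a), list(a)
epoch1_b, epoch2_b = list(b), list(b)
print(epoch1_a == epoch1_b, epoch2_a == epoch2_b)  # True True
```

Two samplers built with the same seed yield the same sequence of epochs, and using a private generator keeps the order independent of other code that draws from the global random state.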
[14]:
smis = ["C" * i for i in range(1, 11)]
ys = np.random.randint(low=0, high=2, size=(len(smis), 1))
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])

dataloader = build_dataloader(dataset, class_balance=True)

_, _, _, Y, *_ = next(iter(dataloader))
print(Y)
tensor([[1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.]])
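
The alternating labels in the output above suggest how class balancing can work. The following is a minimal sketch of one way to do it, assuming binary 0/1 labels. The function balanced_indices is hypothetical and is not Chemprop's sampler, whose details may differ.

```python
import random


def balanced_indices(labels, seed=None):
    """Interleave equal numbers of positive and negative indices."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    rng.shuffle(pos)
    rng.shuffle(neg)
    n = min(len(pos), len(neg))  # the minority class limits the epoch length
    order = []
    for p, q in zip(pos[:n], neg[:n]):
        order.extend([p, q])
    return order


labels = [1, 1, 1, 0, 1, 0, 1, 1, 0, 1]
order = balanced_indices(labels, seed=0)
print([labels[i] for i in order])  # [1, 0, 1, 0, 1, 0]
```

In this sketch an epoch contains only twice the size of the minority class, so with imbalanced data some majority-class datapoints are skipped in any given epoch.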