# Source code for chemprop.data.samplers
from itertools import chain
from typing import Iterator, Optional
import numpy as np
from torch.utils.data import Sampler
class SeededSampler(Sampler):
    """A :class:`SeededSampler` is a class for iterating through a dataset in a randomly
    seeded fashion

    Parameters
    ----------
    N : int
        the number of indices to sample from, i.e., the indices ``0, 1, ..., N - 1``
    seed : int
        the random seed used to initialize the internal random generator. Must not be ``None``.

    Raises
    ------
    ValueError
        if ``seed`` is ``None``
    """

    def __init__(self, N: int, seed: int):
        if seed is None:
            raise ValueError("arg 'seed' was `None`! A SeededSampler must be seeded!")

        self.idxs = np.arange(N)
        self.rg = np.random.default_rng(seed)

    def __iter__(self) -> Iterator[int]:
        """an iterator over indices to sample.

        The stored indices are shuffled in place on every call, so each call starts from the
        previous call's permutation and the whole sequence of orders is determined by the seed.
        """
        self.rg.shuffle(self.idxs)

        # iterate over a copy: a later call to `iter()` reshuffles `self.idxs` in place, which
        # would otherwise reorder (and duplicate/skip indices of) an iterator still in use
        return iter(self.idxs.copy())

    def __len__(self) -> int:
        """the number of indices that will be sampled."""
        return len(self.idxs)
class ClassBalanceSampler(Sampler):
    """A :class:`ClassBalanceSampler` samples data such that positive and negative classes
    are equally sampled.

    A datapoint is *positive* if any of its targets is nonzero (NaN is nonzero, so it counts as
    positive) and *negative* otherwise. Iteration alternates positive, negative, positive, ...
    and stops once the smaller class is exhausted, so ``2 * min(#positive, #negative)`` indices
    are yielded.

    Parameters
    ----------
    Y : np.ndarray
        the targets of the dataset: either a 2-D array with one row per datapoint or a 1-D array
        with one target per datapoint. Any array-like accepted by :func:`numpy.asarray` works.
    seed : int, default=None
        the random seed to use for shuffling (only used when `shuffle` is `True`)
    shuffle : bool, default=False
        whether to shuffle the data during sampling
    """

    def __init__(self, Y: np.ndarray, seed: Optional[int] = None, shuffle: bool = False):
        self.shuffle = shuffle
        self.rg = np.random.default_rng(seed)

        Y = np.asarray(Y)
        idxs = np.arange(len(Y))
        # a 1-D target array has no axis 1, so each entry is its own "any" result
        actives = Y.astype(bool) if Y.ndim == 1 else Y.any(axis=1)

        self.pos_idxs = idxs[actives]
        self.neg_idxs = idxs[~actives]

        self.length = 2 * min(len(self.pos_idxs), len(self.neg_idxs))

    def __iter__(self) -> Iterator[int]:
        """an iterator over indices to sample, alternating positive and negative indices."""
        if self.shuffle:
            self.rg.shuffle(self.pos_idxs)
            self.rg.shuffle(self.neg_idxs)

        # iterate over copies so a later call to `iter()` (which may reshuffle in place) cannot
        # reorder an iterator still being consumed. `zip` truncates to the smaller class,
        # which matches `self.length`
        return chain.from_iterable(zip(self.pos_idxs.copy(), self.neg_idxs.copy()))

    def __len__(self) -> int:
        """the number of indices that will be sampled."""
        return self.length