Data splitting#

[1]:
from chemprop.data import SplitType, make_split_indices, split_data_by_indices

These are example datapoints to split.

[2]:
import numpy as np

from chemprop.data import MoleculeDatapoint

smis = ["C" * i for i in range(1, 11)]
ys = np.random.rand(len(smis), 1)
datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

Data splits#

A typical Chemprop workflow uses three sets of data. The first is used to train the model. The second is used as validation for early stopping and hyperparameter optimization. The third is used to test the final model’s performance as an estimate for how it will perform on future data.

Chemprop provides helper functions to split data into these training, validation, and test sets. Available splitting schemes are listed in SplitType. All of these rely on astartes (https://github.com/JacksonBurns/astartes) in the backend.

[3]:
for splittype in SplitType:
    print(splittype)
scaffold_balanced
random_with_repeated_smiles
random
kennard_stone
kmeans

Splitting steps#

  1. Collect the rdkit.Chem.Mol objects for each datapoint. These are required for structure-based splits.

  2. Generate the splitting indices.

  3. Split the data using those indices.

The make_split_indices function includes a num_replicates argument to perform repeated splits (each with a different random seed) with your sampler of choice. Any sampler can be used for replicates, though deterministic samplers (e.g. Kennard-Stone) produce the same split for every replicate. Splits are returned as a 2- or 3-member tuple containing num_replicates-length lists of training, validation, and testing indices.
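To make the shape of that return value concrete, here is a minimal standard-library sketch of a replicated random split. It is not Chemprop's or astartes' implementation: the helper name `replicate_random_splits` is hypothetical, and deriving each replicate's seed as `seed + replicate` is an assumption made only for illustration.

```python
import random


def replicate_random_splits(n_items, sizes=(0.8, 0.1, 0.1), seed=0, num_replicates=1):
    """Return (train, val, test); each is a list holding one index list per replicate."""
    n_train = round(sizes[0] * n_items)
    n_val = round(sizes[1] * n_items)
    train, val, test = [], [], []
    for replicate in range(num_replicates):
        # Assumption: each replicate gets its own seed, derived from the base seed.
        rng = random.Random(seed + replicate)
        order = list(range(n_items))
        rng.shuffle(order)
        train.append(order[:n_train])
        val.append(order[n_train:n_train + n_val])
        test.append(order[n_train + n_val:])
    return train, val, test


train_indices, val_indices, test_indices = replicate_random_splits(10, num_replicates=3)

# The outer lists have one entry per replicate; the inner lists hold the indices.
print(len(train_indices), len(train_indices[0]))
```

The outer length tracks the number of replicates, while the inner lengths track the split sizes.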

[4]:
mols = [d.mol for d in datapoints]

train_indices, val_indices, test_indices = make_split_indices(mols)

train_data, val_data, test_data = split_data_by_indices(
    datapoints, train_indices, val_indices, test_indices
)

The default splitting scheme is a random split with 80% of the data used to train, 10% to validate, and 10% to test.

[5]:
len(train_data), len(val_data), len(test_data)
[5]:
(1, 1, 1)

Each of these is length 1 because we only requested 1 replicate (the default). The inner list for each set contains the datapoints assigned to that set.

[6]:
len(train_data[0]), len(val_data[0]), len(test_data[0])
[6]:
(8, 1, 1)
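The nesting can be reproduced with plain lists. The sketch below uses a hypothetical helper, `take_by_indices`, to show how per-replicate index lists map onto per-replicate data lists. It mirrors the idea behind split_data_by_indices but is not that function's implementation, and the SMILES strings stand in for datapoints.

```python
def take_by_indices(data, indices_per_replicate):
    """Build one list of items per replicate from one list of indices per replicate."""
    return [[data[i] for i in indices] for indices in indices_per_replicate]


# Stand-ins for datapoints: the same ten SMILES strings used earlier.
data = ["C" * i for i in range(1, 11)]

# One replicate, so each outer list has a single inner list of indices.
train_indices = [[8, 4, 9, 1, 6, 7, 3, 0]]
val_indices = [[5]]
test_indices = [[2]]

train = take_by_indices(data, train_indices)
val = take_by_indices(data, val_indices)
test = take_by_indices(data, test_indices)

print(len(train), len(train[0]))  # prints: 1 8
```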

Split randomness#

All split randomness uses a default seed of 0 and numpy.random. The seed can be changed to get different splits.

[7]:
make_split_indices(datapoints)
[7]:
([[8, 4, 9, 1, 6, 7, 3, 0]], [[5]], [[2]])
[8]:
make_split_indices(datapoints, seed=12)
[8]:
([[8, 7, 0, 4, 9, 3, 2, 1]], [[6]], [[5]])
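The effect of the seed can be illustrated without Chemprop. This sketch uses Python's `random.Random` rather than numpy.random, so the orderings it produces will not match the outputs above. The point is only that a fixed seed reproduces the same ordering, and the helper name `shuffled_indices` is hypothetical.

```python
import random


def shuffled_indices(n_items, seed=0):
    """Shuffle the indices 0..n_items-1 with a dedicated, seeded generator."""
    rng = random.Random(seed)
    order = list(range(n_items))
    rng.shuffle(order)
    return order


first = shuffled_indices(10, seed=0)
again = shuffled_indices(10, seed=0)
other = shuffled_indices(10, seed=12)

print(first == again)  # same seed, same ordering
print(other)           # a different seed draws its own ordering of the same indices
```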

Split fractions#

The split sizes can also be changed. Set the middle value to 0 for a two-way split. If the data cannot be split to exactly the specified proportions, astartes issues a warning that reports the actual sizes used. If the specified sizes do not sum to 1, they are first rescaled to sum to 1.

[9]:
make_split_indices(datapoints, sizes=(0.4, 0.3, 0.3))
[9]:
([[8, 4, 9, 1]], [[6, 7, 3]], [[0, 5, 2]])
[10]:
make_split_indices(datapoints, sizes=(0.6, 0.0, 0.4))
[10]:
([[8, 4, 9, 1, 6, 7]], [[]], [[3, 0, 5, 2]])
[11]:
make_split_indices(datapoints, sizes=(0.5, 0.25, 0.25))
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested validation size of 0.25, got 0.30. Requested test size of 0.25, got 0.30.
  warn(
[11]:
([[8, 4, 9, 1, 6]], [[7, 3]], [[0, 5, 2]])
[12]:
make_split_indices(datapoints, sizes=(0.5, 0.5, 0.5))
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:381: NormalizationWarning: Requested train/val/test split (0.50, 0.50, 0.50) do not sum to 1.0, normalizing to train=0.33, val=0.33, test=0.33.
  warn(
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested train size of 0.33, got 0.30. Requested test size of 0.33, got 0.20.
  warn(
[12]:
([[8, 4, 9]], [[1, 6, 7, 3, 0]], [[5, 2]])
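The rescaling step is simple arithmetic: divide each requested size by their total. The sketch below shows that, plus one possible way to turn fractions into whole counts. The rounding rule (round train and validation, give the remainder to test) is an assumption chosen for illustration. astartes may round differently, as the warnings above suggest, and both helper names are hypothetical.

```python
def normalize_sizes(sizes):
    """Rescale the requested sizes so they sum to 1."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("sizes must have a positive sum")
    return tuple(size / total for size in sizes)


def sizes_to_counts(n_items, sizes):
    """Convert fractional sizes to integer counts that always sum to n_items."""
    train_frac, val_frac, _ = normalize_sizes(sizes)
    # Assumed rounding rule: round train and val, then test takes what is left.
    n_train = min(n_items, round(train_frac * n_items))
    n_val = min(n_items - n_train, round(val_frac * n_items))
    n_test = n_items - n_train - n_val
    return n_train, n_val, n_test


print(normalize_sizes((0.5, 0.5, 0.5)))        # each size becomes one third
print(sizes_to_counts(10, (0.6, 0.0, 0.4)))    # prints: (6, 0, 4)
```

Whatever rounding rule is used, ten datapoints cannot be divided into exact thirds, which is why an imperfect-split warning accompanies such requests.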

Random with repeated molecules#

If your dataset has repeated molecules, all duplicates of a molecule should go in the same split. This split type requires the rdkit.Chem.Mol objects of the datapoints. It first removes duplicates, uses astartes to make the random splits on the unique molecules, and then adds the duplicate datapoints back in.

[13]:
smis = ["O", "O"] + ["C" * i for i in range(1, 10)]
ys = np.random.rand(len(smis), 1)
repeat_datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in repeat_datapoints]
[14]:
make_split_indices(mols, split="random_with_repeated_smiles")
[14]:
([[10, 6, 0, 1, 3, 8, 9, 5, 2]], [[7]], [[4]])
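The remove-split-restore idea can be sketched with the standard library. This is an illustration under assumptions, not Chemprop's code: the helper `split_with_repeats` is hypothetical, the keys are assumed to be already-canonical SMILES strings, the sizes are applied to the unique molecules, and the indices it returns will not match the output above.

```python
import random


def split_with_repeats(keys, sizes=(0.8, 0.1, 0.1), seed=0):
    """Randomly split unique keys, then expand each key back to all of its indices."""
    # Group datapoint indices by key, keeping first-seen order.
    groups = {}
    for index, key in enumerate(keys):
        groups.setdefault(key, []).append(index)

    unique_keys = list(groups)
    random.Random(seed).shuffle(unique_keys)

    # Assumption: the requested fractions apply to the number of unique keys.
    n_train = round(sizes[0] * len(unique_keys))
    n_val = round(sizes[1] * len(unique_keys))
    parts = (
        unique_keys[:n_train],
        unique_keys[n_train:n_train + n_val],
        unique_keys[n_train + n_val:],
    )
    # Every index that shares a key lands in the same part.
    return tuple([i for key in part for i in groups[key]] for part in parts)


smis = ["O", "O"] + ["C" * i for i in range(1, 10)]
train, val, test = split_with_repeats(smis)
print(train, val, test)
```

Because indices 0 and 1 share the key "O", they always end up in the same set.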

Structure based splits#

Placing all similar molecules in only one of the datasets can give a more realistic estimate of how a model will perform on unseen chemistry. These splits use the rdkit.Chem.Mol representation of the molecules. See the astartes documentation for details about the Kennard-Stone, k-means, and scaffold-balanced splitting schemes.

[15]:
smis = [
    "Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14",
    "COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23",
    "COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl",
    "OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3",
    "Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1",
    "OC1(CN2CCC1CC2)C#Cc3ccc(cc3)c4ccccc4",
    "COc1cc(OC)c(cc1NC(=O)CCC(=O)O)S(=O)(=O)NCc2ccccc2N3CCCCC3",
    "CNc1cccc(CCOc2ccc(C[C@H](NC(=O)c3c(Cl)cccc3Cl)C(=O)O)cc2C)n1",
    "COc1ccc(cc1)C2=COc3cc(OC)cc(OC)c3C2=O",
    "Oc1ncnc2scc(c3ccsc3)c12",
    "CS(=O)(=O)c1ccc(Oc2ccc(cc2)C#C[C@]3(O)CN4CCC3CC4)cc1",
    "C[C@H](Nc1nc(Nc2cc(C)[nH]n2)c(C)nc1C#N)c3ccc(F)cn3",
    "O=C1CCCCCN1",
    "CCCSc1ncccc1C(=O)N2CCCC2c3ccncc3",
    "CC1CCCCC1NC(=O)c2cnn(c2NS(=O)(=O)c3ccc(C)cc3)c4ccccc4",
    "Nc1ccc(cc1)c2nc3ccc(O)cc3s2",
    "COc1ccc(cc1)N2CCN(CC2)C(=O)[C@@H]3CCCC[C@H]3C(=O)NCC#N",
    "CCC(COC(=O)c1cc(OC)c(OC)c(OC)c1)(N(C)C)c2ccccc2",
    "COc1cc(ccc1N2CC[C@@H](O)C2)N3N=Nc4cc(sc4C3=O)c5ccc(Cl)cc5",
    "CO[C@H]1CN(CCN2C(=O)C=Cc3ccc(cc23)C#N)CC[C@H]1NCc4ccc5OCC(=O)Nc5n4",
]

ys = np.random.rand(len(smis), 1)
datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in datapoints]
[16]:
make_split_indices(mols, split="kmeans")
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested train size of 0.80, got 0.85. Requested test size of 0.10, got 0.05.
  warn(
[16]:
([[0, 1, 2, 3, 4, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19]],
 [[5, 10]],
 [[7]])
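For intuition about why Kennard-Stone is deterministic, here is a small sketch of its max-min selection rule on toy one-dimensional points. It is not the astartes implementation: the helper `kennard_stone_order` is hypothetical, plain coordinate tuples stand in for whatever molecular features the real sampler uses, and treating the first selected points as the training set is an assumption.

```python
import math


def kennard_stone_order(points):
    """Order point indices by the Kennard-Stone max-min distance rule."""
    n = len(points)
    if n < 2:
        return list(range(n))
    dist = [[math.dist(p, q) for q in points] for p in points]

    # Start from the two points that are farthest apart.
    first, second = max(
        ((i, j) for i in range(n) for j in range(i + 1, n)),
        key=lambda pair: dist[pair[0]][pair[1]],
    )
    selected = [first, second]
    remaining = set(range(n)) - set(selected)

    while remaining:
        # Add the point whose nearest already-selected neighbour is farthest away.
        nxt = max(sorted(remaining), key=lambda i: min(dist[i][j] for j in selected))
        selected.append(nxt)
        remaining.remove(nxt)
    return selected


points = [(0.0,), (1.0,), (2.0,), (10.0,)]
order = kennard_stone_order(points)
print(order)  # prints: [0, 3, 2, 1]

# Assumed convention: the earliest-selected points form the training set.
train, test = order[:3], order[3:]
```

No random number generator is involved, so repeating the call always returns the same ordering. This is why replicates of a deterministic sampler do not differ.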