Training#
Import packages#
[19]:
import pandas as pd
from pathlib import Path
from lightning import pytorch as pl
from chemprop import data, featurizers, models, nn
Change data inputs here#
[20]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['lipo'] # list of names of the columns containing targets
Load data#
[21]:
df_input = pd.read_csv(input_path)
df_input
[21]:
| | smiles | lipo |
|---|---|---|
0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
... | ... | ... |
95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
Get SMILES and targets#
[22]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
[23]:
smis[:5] # show first 5 SMILES strings
[23]:
array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
dtype=object)
[24]:
ys[:5] # show first 5 targets
[24]:
array([[ 3.54],
[-1.18],
[ 3.69],
[ 3.37],
[ 3.1 ]])
Get molecule datapoints#
[25]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
Perform data splitting for training, validation, and testing#
[26]:
# available split types
list(data.SplitType.keys())
[26]:
['CV_NO_VAL',
'CV',
'SCAFFOLD_BALANCED',
'RANDOM_WITH_REPEATED_SMILES',
'RANDOM',
'KENNARD_STONE',
'KMEANS']
[27]:
mols = [d.mol for d in all_data] # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
all_data, train_indices, val_indices, test_indices
)
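To make the (0.8, 0.1, 0.1) split sizes concrete, here is a minimal standard-library sketch of a random index split. It is not chemprop's implementation: `random_split_indices` and its `seed` argument are hypothetical names introduced for illustration, and the real `make_split_indices` may shuffle and round differently.

```python
import random


def random_split_indices(n, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle range(n) and cut it into train/val/test index lists.

    Hypothetical helper for illustration only; not part of chemprop.
    """
    if abs(sum(sizes) - 1.0) > 1e-9:
        raise ValueError("split sizes must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = round(sizes[0] * n)
    n_val = round(sizes[1] * n)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains goes to test
    return train, val, test


train_idx, val_idx, test_idx = random_split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))  # 80 10 10
```

With the 100-row dataset loaded above, this gives 80 training, 10 validation, and 10 test indices, and every index lands in exactly one subset.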
Get MoleculeDataset#
[28]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()
val_dset = data.MoleculeDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data, featurizer)
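The cell above fits a scaler on the training targets and reuses it for the validation targets, while the test targets stay in their original units. A minimal sketch of that fit-once, reuse-everywhere pattern follows. It assumes standard scaling (subtract the mean, divide by the population standard deviation); `ToyTargetScaler` is a hypothetical stand-in, not the object that `normalize_targets()` returns.

```python
import statistics


class ToyTargetScaler:
    """Hypothetical single-column standard scaler, for illustration only."""

    def fit(self, ys):
        self.mean = statistics.fmean(ys)
        self.scale = statistics.pstdev(ys)  # population std (assumption)
        return self

    def transform(self, ys):
        return [(y - self.mean) / self.scale for y in ys]

    def inverse_transform(self, zs):
        return [z * self.scale + self.mean for z in zs]


train_ys = [3.54, -1.18, 3.69, 3.37]  # first four targets shown earlier
val_ys = [3.10]                       # fifth target, used here as "validation"

toy_scaler = ToyTargetScaler().fit(train_ys)      # statistics come from train only
train_scaled = toy_scaler.transform(train_ys)
val_scaled = toy_scaler.transform(val_ys)         # reuse the train statistics
restored = toy_scaler.inverse_transform(train_scaled)
```

Fitting on the training set alone keeps validation information out of the scaling statistics, and `inverse_transform` maps scaled values back to the original units.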
Get DataLoader#
[29]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
Change Message-Passing Neural Network (MPNN) inputs here#
Message Passing#
A message passing module takes molecular graphs as input and uses message passing to learn node-level hidden representations.
Options are mp = nn.BondMessagePassing() or mp = nn.AtomMessagePassing().
[30]:
mp = nn.BondMessagePassing()
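As a rough intuition for what message passing does, the sketch below runs a toy atom-level update on a three-atom chain using scalar node states. It is a deliberately simplified illustration with no learned weights, activation, or bond features, and it is not how `BondMessagePassing` or `AtomMessagePassing` is implemented. The function name and the feature values are made up.

```python
def toy_message_passing(features, edges, depth=3):
    """Toy message passing over an undirected graph with scalar node states."""
    n = len(features)
    neighbors = {v: [] for v in range(n)}
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)

    hidden = list(features)
    for _ in range(depth):
        # Each node collects ("receives") the current states of its neighbors.
        messages = [sum(hidden[u] for u in neighbors[v]) for v in range(n)]
        # The new state combines the node's input feature with its messages.
        hidden = [features[v] + messages[v] for v in range(n)]
    return hidden


# Three atoms in a chain (0-1-2), each with a made-up scalar feature of 1.0
node_states = toy_message_passing([1.0, 1.0, 1.0], [(0, 1), (1, 2)], depth=1)
print(node_states)  # [2.0, 3.0, 2.0]
```

After one step the middle atom has a larger state because it has two neighbors. More steps let information travel further along the graph.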
Aggregation#
An Aggregation is responsible for constructing a graph-level representation from the set of node-level representations after message passing.
Available options can be found in nn.agg.AggregationRegistry, including:
- agg = nn.MeanAggregation()
- agg = nn.SumAggregation()
- agg = nn.NormAggregation()
[31]:
print(nn.agg.AggregationRegistry)
ClassRegistry {
'mean': <class 'chemprop.nn.agg.MeanAggregation'>,
'sum': <class 'chemprop.nn.agg.SumAggregation'>,
'norm': <class 'chemprop.nn.agg.NormAggregation'>
}
[32]:
agg = nn.MeanAggregation()
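The three aggregation names can be illustrated on a tiny molecule with three atoms and two hidden features each. This is a sketch under assumptions: "norm" is taken to mean "sum, then divide by a fixed constant", and the constant 100.0 used here is an assumed value, not one confirmed by the notebook. The `aggregate` function is hypothetical.

```python
def aggregate(node_vectors, how="mean", norm=100.0):
    """Collapse per-node vectors into one graph-level vector.

    Hypothetical helper; `norm` is an assumed constant for the "norm" option.
    """
    n_features = len(node_vectors[0])
    sums = [sum(vec[j] for vec in node_vectors) for j in range(n_features)]
    if how == "sum":
        return sums
    if how == "mean":
        return [s / len(node_vectors) for s in sums]
    if how == "norm":
        return [s / norm for s in sums]
    raise ValueError(f"unknown aggregation: {how}")


# Made-up hidden states for a three-atom molecule
atoms = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(aggregate(atoms, "sum"))   # [9.0, 12.0]
print(aggregate(atoms, "mean"))  # [3.0, 4.0]
```

Mean aggregation gives outputs that do not grow with molecule size, whereas sum-based variants keep size information in the graph-level vector.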
Feed-Forward Network (FFN)#
An FFN takes the aggregated representations and makes target predictions.
Available options can be found in nn.PredictorRegistry.
For regression:
- ffn = nn.RegressionFFN()
- ffn = nn.MveFFN()
- ffn = nn.EvidentialFFN()
For classification:
- ffn = nn.BinaryClassificationFFN()
- ffn = nn.BinaryDirichletFFN()
- ffn = nn.MulticlassClassificationFFN()
- ffn = nn.MulticlassDirichletFFN()
For spectral:
- ffn = nn.SpectralFFN() # will be available in a future version
[33]:
print(nn.PredictorRegistry)
ClassRegistry {
'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}
[34]:
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
self.register_buffer("mean", torch.tensor(mean, dtype=torch.float).unsqueeze(0))
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:22: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
self.register_buffer("scale", torch.tensor(scale, dtype=torch.float).unsqueeze(0))
[35]:
ffn = nn.RegressionFFN(output_transform=output_transform)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
Batch Norm#
Batch normalization normalizes the outputs of the aggregation by re-centering and re-scaling them.
Set whether to use batch norm:
[36]:
batch_norm = True
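The "re-centering and re-scaling" can be sketched per feature column over a batch, as below. This simplified version covers only the training-time normalization. It omits the learnable affine parameters and running statistics that a full batch norm layer keeps. The `eps` default mirrors the `eps=1e-05` shown in the model printout later in this notebook, and the function name is hypothetical.

```python
import math


def toy_batch_norm(batch, eps=1e-5):
    """Normalize each feature column of a batch to ~zero mean, ~unit variance."""
    n = len(batch)
    n_features = len(batch[0])
    out = [[0.0] * n_features for _ in range(n)]
    for j in range(n_features):
        column = [row[j] for row in batch]
        mean = sum(column) / n                            # re-centering term
        var = sum((x - mean) ** 2 for x in column) / n    # biased variance
        for i in range(n):
            # eps keeps the division stable when the variance is near zero
            out[i][j] = (batch[i][j] - mean) / math.sqrt(var + eps)
    return out


# Two made-up graph-level vectors whose features live on different scales
normalized = toy_batch_norm([[1.0, 10.0], [3.0, 30.0]])
```

Both columns end up at roughly -1 and +1 even though the second started ten times larger, which is the point of normalizing before the FFN.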
Metrics#
Metrics evaluate the performance of model predictions.
Available options can be found in nn.metrics.MetricRegistry, including:
[37]:
print(nn.metrics.MetricRegistry)
ClassRegistry {
'mae': <class 'chemprop.nn.metrics.MAEMetric'>,
'mse': <class 'chemprop.nn.metrics.MSEMetric'>,
'rmse': <class 'chemprop.nn.metrics.RMSEMetric'>,
'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAEMetric'>,
'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSEMetric'>,
'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSEMetric'>,
'r2': <class 'chemprop.nn.metrics.R2Metric'>,
'roc': <class 'chemprop.nn.metrics.AUROCMetric'>,
'prc': <class 'chemprop.nn.metrics.AUPRCMetric'>,
'accuracy': <class 'chemprop.nn.metrics.AccuracyMetric'>,
'f1': <class 'chemprop.nn.metrics.F1Metric'>,
'bce': <class 'chemprop.nn.metrics.BCEMetric'>,
'ce': <class 'chemprop.nn.metrics.CrossEntropyMetric'>,
'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
'sid': <class 'chemprop.nn.metrics.SIDMetric'>,
'wasserstein': <class 'chemprop.nn.metrics.WassersteinMetric'>
}
[38]:
metric_list = [nn.metrics.RMSEMetric(), nn.metrics.MAEMetric()] # Only the first metric is used for training and early stopping
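For reference, the two metrics chosen above reduce to simple formulas. The sketch below computes them in plain Python on made-up numbers. It illustrates the definitions only and is not chemprop's metric classes, which also handle batching, masking, and weighting.

```python
import math


def mae(preds, targets):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def rmse(preds, targets):
    """Root mean squared error."""
    mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
    return math.sqrt(mse)


# Made-up predictions: two exact hits and one miss of 3.0
preds = [1.0, 2.0, 3.0]
targets = [1.0, 2.0, 6.0]
print(mae(preds, targets))   # 1.0
print(rmse(preds, targets))  # sqrt(3), about 1.732
```

RMSE is never smaller than MAE and grows faster when a few errors are large, which is consistent with the test results later in this notebook, where RMSE exceeds MAE.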
Construct MPNN#
[39]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn
[39]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSELoss()
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
)
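The layer shapes in the printout above are enough to work out the parameter counts by hand. The sketch below does that arithmetic, counting weights plus optional biases for each linear layer and a weight and bias per feature for batch norm. `linear_params` is a hypothetical helper, not a chemprop function.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of trainable parameters in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


# BondMessagePassing: W_i and W_h have no bias, W_o has one
message_passing_params = (
    linear_params(86, 300, bias=False)
    + linear_params(300, 300, bias=False)
    + linear_params(372, 300)
)

# BatchNorm1d(300) with affine=True: one weight and one bias per feature
batch_norm_params = 2 * 300

# RegressionFFN: 300 -> 300 hidden layer, then 300 -> 1 output layer
predictor_params = linear_params(300, 300) + linear_params(300, 1)

total_params = message_passing_params + batch_norm_params + predictor_params
print(message_passing_params, batch_norm_params, predictor_params, total_params)
# 227700 600 90601 318901
```

These values line up with the rounded figures in the summary table printed when training starts (227 K, 600, 90.6 K, and 318 K trainable parameters).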
Set up trainer#
[40]:
trainer = pl.Trainer(
logger=False,
enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
enable_progress_bar=True,
accelerator="auto",
devices=1,
max_epochs=20, # number of epochs to train for
)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Start training#
[41]:
trainer.fit(mpnn, train_loader, val_loader)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/hwpang/Projects/chemprop_v2_dev/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
| Name | Type | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K
1 | agg | MeanAggregation | 0
2 | bn | BatchNorm1d | 600
3 | predictor | RegressionFFN | 90.6 K
4 | X_d_transform | Identity | 0
| other params | n/a | 1
-------------------------------------------------------
318 K Trainable params
1 Non-trainable params
318 K Total params
1.276 Total estimated model params size (MB)
Sanity Checking DataLoader 0: 0%| | 0/1 [00:00<?, ?it/s]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 33.35it/s, train_loss=0.026, val_loss=2.160]
`Trainer.fit` stopped: `max_epochs=20` reached.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 29.84it/s, train_loss=0.026, val_loss=2.160]
Test results#
[42]:
results = trainer.test(mpnn, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 238.52it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test/mae 0.6950093507766724
test/rmse 0.9511734247207642
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────