Training Regression - Reaction#
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
try:
import chemprop
except ImportError:
!git clone https://github.com/chemprop/chemprop.git
%cd chemprop
!pip install .
%cd examples
Import packages#
[2]:
import inspect
from pathlib import Path

import pandas as pd
from lightning import pytorch as pl

from chemprop import data, featurizers, models, nn
Change data inputs here#
[3]:
# Repository root: the parent of the directory the notebook is run from.
chemprop_dir = Path.cwd().parent
# Example reaction dataset shipped with the repository's test data.
input_path = chemprop_dir.joinpath("tests", "data", "regression", "rxn", "rxn.csv")
# Number of workers for the dataloaders; 0 means data is loaded in the main process.
num_workers = 0
# Name of the CSV column holding the reaction SMILES.
smiles_column = "smiles"
# Names of the CSV columns holding the regression targets.
target_columns = ["ea"]
Load data#
[4]:
# Read the reaction CSV located at input_path into a DataFrame.
df_input = pd.read_csv(input_path)
# Bare last expression: the notebook renders the frame as this cell's output.
df_input
[4]:
| | smiles | ea |
|---|---|---|
| 0 | [O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1... | 8.898934 |
| 1 | [C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:... | 5.464328 |
| 2 | [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... | 5.270552 |
| 3 | [C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])... | 8.473006 |
| 4 | [C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H... | 5.579037 |
| ... | ... | ... |
| 95 | [C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]... | 9.295665 |
| 96 | [O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11... | 7.753442 |
| 97 | [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... | 10.650215 |
| 98 | [C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4... | 10.138945 |
| 99 | [C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]... | 6.979934 |
100 rows × 2 columns
Load smiles and targets#
[5]:
# Selecting a single column name yields a 1-D array of reaction SMILES strings.
smis = df_input[smiles_column].values
# Selecting with a list of column names yields a 2-D array (one column per target).
ys = df_input[target_columns].values
# Preview the first five entries of each array.
(smis[:5], ys[:5])
[5]:
(array(['[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:15])([H:13])[H:14])([H:11])[H:12])([H:9])[H:10])[H:8]>>[C:3](=[C:4]=[O:5])([H:11])[H:12].[C:6]([O:7][H:15])([H:8])([H:13])[H:14].[O:1]=[C:2]([H:9])[H:10]',
'[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:4]3([H:11])[O:5][C@:6]1([H:12])[C@@:7]23[H:13]>>[C:1]1([H:8])([H:9])[O:2][C:3]([H:10])=[C:7]([H:13])[C@:6]1([O+:5]=[C-:4][H:11])[H:12]',
'[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H:13])([H:14])[C:5]([H:15])=[C:6]([H:16])[C@@:7]12[H:17])([H:8])([H:9])[H:10]>>[C:1]([C@@:2]1([H:11])[C:3]([H:12])([H:13])[C:4]([H:14])=[C:5]([H:15])[C:6]([H:16])=[C:7]1[H:17])([H:8])([H:9])[H:10]',
'[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])([H:11])[H:12])([H:8])([H:9])[H:10]>>[C-:1]([O+:2]=[C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])[H:12])([H:8])[H:10].[H:9][H:11]',
'[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H:10])[H:11])([H:7])([H:8])[H:9]>>[C:1]([C:2](=[C:3]=[C:4]([H:10])[H:11])[C:5](=[O:6])[H:12])([H:7])([H:8])[H:9]'],
dtype=object),
array([[8.8989335 ],
[5.46432769],
[5.27055228],
[8.47300569],
[5.57903696]]))
Get datapoints#
[6]:
# Build one ReactionDatapoint per (reaction SMILES, target row) pair; smis and ys
# are consumed in parallel, exactly as zip() would pair them.
all_data = list(map(data.ReactionDatapoint.from_smi, smis, ys))
Perform data splitting for training, validation, and testing#
[7]:
# Split on the reactant molecules; the product side (.pdt) could be used instead of .rct.
mols = [datapoint.rct for datapoint in all_data]
# "random" split with (0.8, 0.1, 0.1) sizes; the three results are unpacked as
# train / validation / test indices.
split_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_indices, val_indices, test_indices = split_indices
# Partition the datapoints according to those indices.
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
Defining the featurizer#
Reactions can be featurized using the CondensedGraphOfReactionFeaturizer (also labeled CGRFeaturizer).
Use the mode_ keyword to set the mode by which a reaction should be featurized into a MolGraph.
The available options can be listed with featurizers.RxnMode.keys()
[8]:
# Print every available reaction featurization mode, one name per line.
for mode_name in featurizers.RxnMode.keys():
    print(mode_name)
REAC_PROD
REAC_PROD_BALANCE
REAC_DIFF
REAC_DIFF_BALANCE
PROD_DIFF
PROD_DIFF_BALANCE
[9]:
# Build the condensed-graph-of-reaction featurizer in "PROD_DIFF" mode
# (one of the mode names listed by featurizers.RxnMode.keys()).
featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_="PROD_DIFF")
Get ReactionDatasets#
[10]:
# Each split result is indexed with [0] before being wrapped in a ReactionDataset.
# NOTE(review): presumably the splits are returned as a list of replicates since
# v2.1 and [0] picks the first (only) one — confirm via help(data.make_split_indices).
train_dset = data.ReactionDataset(train_data[0], featurizer)
# Normalize the training targets in place; the returned scaler is reused for the
# validation set below and for the model's output transform.
scaler = train_dset.normalize_targets()
val_dset = data.ReactionDataset(val_data[0], featurizer)
# Apply the training-set scaler to the validation targets.
val_dset.normalize_targets(scaler)
# The test targets are not normalized in this cell.
test_dset = data.ReactionDataset(test_data[0], featurizer)
Get dataloaders#
[11]:
# One dataloader per split, all using the configured num_workers.
# The training loader keeps build_dataloader's default shuffle setting;
# validation and test loaders explicitly disable shuffling.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
Change Message-Passing Neural Network (MPNN) inputs here#
Message passing#
Message passing blocks must be given the shape of the featurizer's outputs.
Options are mp = nn.BondMessagePassing() or mp = nn.AtomMessagePassing()
[12]:
fdims = featurizer.shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
# Unpack the featurizer dimensions as the positional arguments of the bond-based message passing block.
mp = nn.BondMessagePassing(*fdims)
Aggregation#
[13]:
# List the registered aggregation classes and the keys they are registered under.
print(nn.agg.AggregationRegistry)
ClassRegistry {
'mean': <class 'chemprop.nn.agg.MeanAggregation'>,
'sum': <class 'chemprop.nn.agg.SumAggregation'>,
'norm': <class 'chemprop.nn.agg.NormAggregation'>
}
[14]:
# Use mean aggregation (registered under the 'mean' key) with its default arguments.
agg = nn.MeanAggregation()
Feed-Forward Network (FFN)#
[15]:
# List the registered predictor (FFN) classes and the keys they are registered under.
print(nn.PredictorRegistry)
ClassRegistry {
'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,
'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}
[16]:
# Build an UnscaleTransform from the scaler returned by the training set's normalize_targets().
# NOTE(review): presumably this maps model outputs back to the original target scale — confirm.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
[17]:
# Regression head (registered under the 'regression' key) with the unscaling output transform attached.
ffn = nn.RegressionFFN(output_transform=output_transform)
Batch norm#
[18]:
# Flag handed to models.MPNN when the model is constructed; with True the model
# summary shows a BatchNorm1d layer.
batch_norm = True
Metrics#
[19]:
# List the registered metric classes and the keys they are registered under.
print(nn.metrics.MetricRegistry)
ClassRegistry {
'mse': <class 'chemprop.nn.metrics.MSE'>,
'mae': <class 'chemprop.nn.metrics.MAE'>,
'rmse': <class 'chemprop.nn.metrics.RMSE'>,
'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
'r2': <class 'chemprop.nn.metrics.R2Score'>,
'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}
[20]:
# Metrics reported for the model: RMSE first, then MAE.
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]
# NOTE(review): the first metric in this list is presumably the one tracked during
# validation (and hence by early stopping / checkpoint selection). The training loss
# itself appears to come from the predictor's criterion (MSE in the model summary),
# not from this list — confirm against the chemprop documentation.
Construct MPNN#
[21]:
# Assemble the model from the message passing block, the aggregation, the
# predictor, the batch-norm flag and the metric list (passed positionally in that order).
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
# Bare last expression: display the module summary as this cell's output.
mpnn
[21]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=134, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=406, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0): RMSE(task_weights=[[1.0]])
(1): MAE(task_weights=[[1.0]])
(2): MSE(task_weights=[[1.0]])
)
)
Training and testing#
Set up trainer#
[22]:
# Configure the Lightning trainer: no logger, checkpointing and progress bar
# enabled, a single device on an automatically selected accelerator, 20 epochs at most.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Start training#
[23]:
# Train the model on the training loader, using the validation loader for validation.
trainer.fit(mpnn, train_loader, val_loader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type               ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  252 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ BatchNorm1d        │    600 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │ 90.6 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 343 K
Non-trainable params: 0
Total params: 343 K
Total estimated model params size (MB): 1
Modules in train mode: 25
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.
Test results#
[24]:
# Evaluate the trained model on the test loader and keep the returned metrics.
# `weights_only=False` is only required with PyTorch Lightning 2.6.0 or newer, and
# older releases do not accept the keyword at all. Pass it only when the installed
# `Trainer.test` exposes the parameter, so the cell runs on both old and new versions.
test_kwargs = {}
if "weights_only" in inspect.signature(trainer.test).parameters:
    test_kwargs["weights_only"] = False
results = trainer.test(mpnn, test_loader, **test_kwargs)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test/mae          │     1.133582592010498     │
│         test/rmse         │    1.4866628646850586     │
└───────────────────────────┴───────────────────────────┘
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.