Training Regression - Reaction#

Import packages#

[16]:
import pandas as pd
from lightning import pytorch as pl
from pathlib import Path

from chemprop import data, featurizers, models, nn

Change data inputs here#

[17]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "rxn" / "rxn.csv"
num_workers = 0  # number of workers for the dataloader; 0 means data are loaded in the main process
smiles_column = 'smiles'
target_columns = ['ea']

Load data#

[18]:
df_input = pd.read_csv(input_path)
df_input
[18]:
smiles ea
0 [O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1... 8.898934
1 [C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:... 5.464328
2 [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... 5.270552
3 [C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])... 8.473006
4 [C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H... 5.579037
... ... ...
95 [C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]... 9.295665
96 [O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11... 7.753442
97 [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... 10.650215
98 [C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4... 10.138945
99 [C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]... 6.979934

100 rows × 2 columns

Load SMILES and targets#

[19]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

smis[:5], ys[:5]
[19]:
(array(['[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:15])([H:13])[H:14])([H:11])[H:12])([H:9])[H:10])[H:8]>>[C:3](=[C:4]=[O:5])([H:11])[H:12].[C:6]([O:7][H:15])([H:8])([H:13])[H:14].[O:1]=[C:2]([H:9])[H:10]',
        '[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:4]3([H:11])[O:5][C@:6]1([H:12])[C@@:7]23[H:13]>>[C:1]1([H:8])([H:9])[O:2][C:3]([H:10])=[C:7]([H:13])[C@:6]1([O+:5]=[C-:4][H:11])[H:12]',
        '[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H:13])([H:14])[C:5]([H:15])=[C:6]([H:16])[C@@:7]12[H:17])([H:8])([H:9])[H:10]>>[C:1]([C@@:2]1([H:11])[C:3]([H:12])([H:13])[C:4]([H:14])=[C:5]([H:15])[C:6]([H:16])=[C:7]1[H:17])([H:8])([H:9])[H:10]',
        '[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])([H:11])[H:12])([H:8])([H:9])[H:10]>>[C-:1]([O+:2]=[C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])[H:12])([H:8])[H:10].[H:9][H:11]',
        '[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H:10])[H:11])([H:7])([H:8])[H:9]>>[C:1]([C:2](=[C:3]=[C:4]([H:10])[H:11])[C:5](=[O:6])[H:12])([H:7])([H:8])[H:9]'],
       dtype=object),
 array([[8.8989335 ],
        [5.46432769],
        [5.27055228],
        [8.47300569],
        [5.57903696]]))

Get datapoints#

[20]:
all_data = [data.ReactionDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
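
Each entry is an atom-mapped reaction SMILES of the form reactants>>products, which from_smi parses into separate reactant and product molecules. A quick sanity check, as a minimal sketch:

[ ]:
# Sketch: each ReactionDatapoint stores the parsed reactant (.rct) and
# product (.pdt) molecules; these attributes are used for splitting below.
d0 = all_data[0]
print(type(d0.rct).__name__, type(d0.pdt).__name__)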

Perform data splitting for training, validation, and testing#

[21]:
mols = [d.rct for d in all_data]  # you can split by either reactants (.rct) or products (.pdt)
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

Define the featurizer#

Reactions can be featurized using the CondensedGraphOfReactionFeaturizer (also available under the alias CGRFeaturizer).

Use the mode_ keyword argument to set how a reaction should be featurized into a MolGraph.

The available options can be found with featurizers.RxnMode.keys().

[22]:
for key in featurizers.RxnMode.keys():
    print(key)
REAC_PROD
REAC_PROD_BALANCE
REAC_DIFF
REAC_DIFF_BALANCE
PROD_DIFF
PROD_DIFF_BALANCE
[23]:
featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_="PROD_DIFF")
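
As a quick sanity check (a sketch), the featurizer reports its output dimensions, which the message-passing block constructed below must match:

[ ]:
# (atom_dim, bond_dim) of the CGR featurization
print(featurizer.shape)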

Get ReactionDatasets#

[24]:
train_dset = data.ReactionDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()

val_dset = data.ReactionDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.ReactionDataset(test_data, featurizer)
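
normalize_targets fits a scaler on the training targets and applies it in place; the same scaler is then reused for the validation set, while the test targets stay in their original units. A sketch, assuming the returned scaler follows the scikit-learn StandardScaler API (consistent with UnscaleTransform.from_standard_scaler below):

[ ]:
# Hypothetical inspection, assuming scikit-learn StandardScaler attributes:
print(scaler.mean_, scaler.scale_)  # per-target mean and standard deviation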

Get dataloaders#

[25]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
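
A minimal sketch to confirm the loaders iterate; the exact batch structure is chemprop-internal, so this only checks that a batch comes back:

[ ]:
# Pull one training batch (collated by chemprop's dataloader).
batch = next(iter(train_loader))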

Change Message-Passing Neural Network (MPNN) inputs here#

Message passing#

Message-passing blocks must be given the shape of the featurizer’s outputs.

Options are mp = nn.BondMessagePassing() or mp = nn.AtomMessagePassing()

[26]:
fdims = featurizer.shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
mp = nn.BondMessagePassing(*fdims)
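
The atom-based variant is constructed the same way (a sketch, left commented out so the bond-based block above remains the one used):

[ ]:
# Sketch: atom-centered message passing also takes the featurizer's
# (atom_dim, bond_dim) output shape.
# mp = nn.AtomMessagePassing(*fdims)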

Aggregation#

[27]:
print(nn.agg.AggregationRegistry)
ClassRegistry {
    'mean': <class 'chemprop.nn.agg.MeanAggregation'>,
    'sum': <class 'chemprop.nn.agg.SumAggregation'>,
    'norm': <class 'chemprop.nn.agg.NormAggregation'>
}
[28]:
agg = nn.MeanAggregation()
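
Alternatively (a sketch, assuming the registry supports dict-style lookup as used by the chemprop CLI), aggregations can be constructed by string key:

[ ]:
# agg = nn.agg.AggregationRegistry["mean"]()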

Feed-Forward Network (FFN)#

[29]:
print(nn.PredictorRegistry)
ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}
[30]:
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer("mean", torch.tensor(mean, dtype=torch.float).unsqueeze(0))
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:22: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer("scale", torch.tensor(scale, dtype=torch.float).unsqueeze(0))
[31]:
ffn = nn.RegressionFFN(output_transform=output_transform)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
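
For multi-target regression, the predictor takes the number of targets (a sketch, assuming the n_tasks argument of RegressionFFN; here len(target_columns) == 1, so this matches the default):

[ ]:
# ffn = nn.RegressionFFN(n_tasks=len(target_columns), output_transform=output_transform)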

Batch norm#

[32]:
batch_norm = True

Metrics#

[33]:
print(nn.metrics.MetricRegistry)
ClassRegistry {
    'mae': <class 'chemprop.nn.metrics.MAEMetric'>,
    'mse': <class 'chemprop.nn.metrics.MSEMetric'>,
    'rmse': <class 'chemprop.nn.metrics.RMSEMetric'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAEMetric'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSEMetric'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSEMetric'>,
    'r2': <class 'chemprop.nn.metrics.R2Metric'>,
    'roc': <class 'chemprop.nn.metrics.AUROCMetric'>,
    'prc': <class 'chemprop.nn.metrics.AUPRCMetric'>,
    'accuracy': <class 'chemprop.nn.metrics.AccuracyMetric'>,
    'f1': <class 'chemprop.nn.metrics.F1Metric'>,
    'bce': <class 'chemprop.nn.metrics.BCEMetric'>,
    'ce': <class 'chemprop.nn.metrics.CrossEntropyMetric'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'sid': <class 'chemprop.nn.metrics.SIDMetric'>,
    'wasserstein': <class 'chemprop.nn.metrics.WassersteinMetric'>
}
[34]:
metric_list = [nn.metrics.RMSEMetric(), nn.metrics.MAEMetric()]
# Only the first metric is used for training and early stopping

Construct MPNN#

[35]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn
[35]:
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=134, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=406, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss()
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
)

Training and testing#

Set up trainer#

[36]:
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,  # number of epochs to train for
)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
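
Optionally (a sketch), the Tensor Cores warning emitted during training below can be addressed beforehand by lowering float32 matmul precision, as the warning itself suggests:

[ ]:
import torch

# Trade some float32 precision for speed on Tensor Core GPUs.
torch.set_float32_matmul_precision("medium")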

Start training#

[37]:
trainer.fit(mpnn, train_loader, val_loader)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/hwpang/Projects/chemprop_v2_dev/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 252 K
1 | agg             | MeanAggregation    | 0
2 | bn              | BatchNorm1d        | 600
3 | predictor       | RegressionFFN      | 90.6 K
4 | X_d_transform   | Identity           | 0
  | other params    | n/a                | 1
-------------------------------------------------------
343 K     Trainable params
1         Non-trainable params
343 K     Total params
1.374     Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 16.36it/s, train_loss=0.797, val_loss=8.790]
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 43.65it/s, train_loss=0.0333, val_loss=8.480]
`Trainer.fit` stopped: `max_epochs=20` reached.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 37.26it/s, train_loss=0.0333, val_loss=8.480]

Test results#

[38]:
results = trainer.test(mpnn, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 237.23it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/mae            1.0371202230453491
        test/rmse           1.3453567028045654
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
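
trainer.test returns a list with one metrics dict per test dataloader; the keys match the table above. For example:

[ ]:
# Access individual test metrics programmatically.
print(results[0]["test/rmse"], results[0]["test/mae"])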