Training Regression - Multicomponent#
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
try:
import chemprop
except ImportError:
!git clone https://github.com/chemprop/chemprop.git
%cd chemprop
!pip install .
%cd examples
[2]:
import pandas as pd
from lightning import pytorch as pl
from pathlib import Path
from chemprop import data, featurizers, models, nn
from chemprop.nn import metrics
from chemprop.models import multi
Load data#
Change your data inputs here#
[3]:
# The repository root is taken to be the parent of the current working directory.
chemprop_dir = Path.cwd().parent
# Path to your data .csv file containing SMILES strings and target values.
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol+mol", "mol+mol.csv")
# Names of the columns containing SMILES strings (one column per component).
smiles_columns = ["smiles", "solvent"]
# Names of the columns containing targets.
target_columns = ["peakwavs_max"]
Read data#
[4]:
# Read the CSV at input_path into a DataFrame.
df_input = pd.read_csv(input_path)
# Bare expression as the cell's last line so the notebook renders the frame.
df_input
[4]:
| | smiles | solvent | peakwavs_max |
|---|---|---|---|
| 0 | CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | ClCCl | 642.0 |
| 1 | C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | ClCCl | 420.0 |
| 2 | CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | O | 544.0 |
| 3 | c1ccc2[nH]ccc2c1 | O | 290.0 |
| 4 | CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | ClC(Cl)Cl | 736.0 |
| ... | ... | ... | ... |
| 95 | COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | C1CCOC1 | 359.0 |
| 96 | COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | C1CCCCC1 | 386.0 |
| 97 | CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | CCO | 425.0 |
| 98 | Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | c1ccccc1 | 324.0 |
| 99 | Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | ClCCl | 391.0 |
100 rows × 3 columns
Get SMILES and targets#
[5]:
# Extract the SMILES columns and the target columns as NumPy arrays
# (one row per datapoint, one column per requested column name).
smiss = df_input[smiles_columns].to_numpy()
ys = df_input[target_columns].to_numpy()
[6]:
# Take a look at the first 5 SMILES strings and targets
# (the cell's value is a tuple of the two array slices, which the notebook displays).
smiss[:5], ys[:5]
[6]:
(array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
'ClCCl'],
['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
'ClCCl'],
['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
'O'],
['c1ccc2[nH]ccc2c1', 'O'],
['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
'ClC(Cl)Cl']], dtype=object),
array([[642.],
[420.],
[544.],
[290.],
[736.]]))
Make molecule datapoints#
Create a list of lists containing the molecule datapoints for each component. The targets are stored in the datapoints of the 0th component.
[7]:
# Component 0: one datapoint per row, built from the first SMILES column and
# carrying that row's target values.
all_data = [
    [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smiss[:, 0], ys)]
]
# Remaining components: datapoints built from the SMILES alone (no targets).
all_data.extend(
    [data.MoleculeDatapoint.from_smi(smi) for smi in smiss[:, i]]
    for i in range(1, len(smiles_columns))
)
Split data#
Perform data splitting for training, validation, and testing#
[8]:
# Index of the component to use for structure-based splits.
component_to_split_by = 0
mols = [datapoint.mol for datapoint in all_data[component_to_split_by]]

# "random" split with sizes (0.8, 0.1, 0.1) for train / validation / test.
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "random", (0.8, 0.1, 0.1)
)

# The full multicomponent data is split with this single set of indices.
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
Get MoleculeDataset for each component#
[9]:
# One featurizer instance is shared by every dataset created in this cell.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# Build one MoleculeDataset per component (index i runs over the SMILES columns).
# NOTE(review): the leading [0] selects the first element of what split_data_by_indices
# returned — presumably the first (only) split replicate, cf. the v2.1 return-type
# notice printed by the splitting cell; confirm against the chemprop documentation.
train_datasets = [data.MoleculeDataset(train_data[0][i], featurizer) for i in range(len(smiles_columns))]
val_datasets = [data.MoleculeDataset(val_data[0][i], featurizer) for i in range(len(smiles_columns))]
test_datasets = [data.MoleculeDataset(test_data[0][i], featurizer) for i in range(len(smiles_columns))]
Construct multicomponent dataset and scale the targets#
[10]:
# Wrap the per-component datasets of the training and validation splits.
train_mcdset = data.MulticomponentDataset(train_datasets)
val_mcdset = data.MulticomponentDataset(val_datasets)

# normalize_targets() on the training set returns the scaler it used; the same
# scaler is then applied to the validation targets.
scaler = train_mcdset.normalize_targets()
val_mcdset.normalize_targets(scaler)

# The test targets are deliberately not normalized here.
# NOTE(review): presumably because the model's output_transform (built from this
# scaler) maps predictions back to the original scale — confirm.
test_mcdset = data.MulticomponentDataset(test_datasets)
Construct data loader#
[11]:
# One data loader per split. The training loader uses build_dataloader's defaults;
# the validation and test loaders explicitly disable shuffling.
train_loader = data.build_dataloader(train_mcdset)
val_loader = data.build_dataloader(val_mcdset, shuffle=False)
test_loader = data.build_dataloader(test_mcdset, shuffle=False)
Construct multicomponent MPNN#
MulticomponentMessagePassing#
- blocks: a list of message passing blocks, one for each component
- n_components: number of components
[12]:
# A separate BondMessagePassing instance is created for every SMILES column,
# and the component count is tied to the same list.
mcmp = nn.MulticomponentMessagePassing(
    blocks=[nn.BondMessagePassing() for _ in smiles_columns],
    n_components=len(smiles_columns),
)
Aggregation#
[13]:
# Aggregation module passed to the MulticomponentMPNN constructor later in this notebook.
agg = nn.MeanAggregation()
RegressionFFN#
[14]:
# Build an UnscaleTransform from the scaler returned by train_mcdset.normalize_targets();
# it is handed to the regression FFN as its output_transform.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
[15]:
# Regression head. Its input width is read from the multicomponent message-passing
# module (the printed model shows in_features=600 on the first Linear layer), and it
# carries the UnscaleTransform built from the training scaler as output_transform.
ffn = nn.RegressionFFN(
input_dim=mcmp.output_dim,
output_transform=output_transform,
)
Metrics#
[16]:
# Metrics handed to the model; only the first metric is used for early stopping.
# NOTE(review): the original note also said "used for training", but the printed model
# lists a separate criterion (MSE) on the predictor, so the training loss presumably
# comes from that criterion rather than from this list — confirm against the chemprop docs.
metric_list = [metrics.RMSE(), metrics.MAE()]
MulticomponentMPNN#
[17]:
# Assemble the full model from the message-passing module, the aggregation module,
# the regression head and the metric list defined in the previous cells.
mcmpnn = multi.MulticomponentMPNN(
mcmp,
agg,
ffn,
metrics=metric_list,
)
# Bare expression as the cell's last line so the notebook prints the module tree.
mcmpnn
[17]:
MulticomponentMPNN(
(message_passing): MulticomponentMessagePassing(
(blocks): ModuleList(
(0-1): 2 x BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
)
)
(agg): MeanAggregation()
(bn): Identity()
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0): RMSE(task_weights=[[1.0]])
(1): MAE(task_weights=[[1.0]])
(2): MSE(task_weights=[[1.0]])
)
)
Set up trainer#
[18]:
# Configure the Lightning trainer. No callbacks are passed in this cell, so nothing here
# sets up early stopping; the fit log in this notebook ends with "max_epochs=20 reached".
trainer = pl.Trainer(
logger=False,
enable_checkpointing=True,
enable_progress_bar=True,
accelerator="auto",
devices=1,
max_epochs=20, # number of epochs to train for
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Start training#
[19]:
# Train the model on train_loader, using val_loader for validation.
trainer.fit(mcmpnn, train_loader, val_loader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type                         ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ MulticomponentMessagePassing │  455 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation              │      0 │ train │     0 │
│ 2 │ bn              │ Identity                     │      0 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN                │  180 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity                     │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList                   │      0 │ train │     0 │
└───┴─────────────────┴──────────────────────────────┴────────┴───────┴───────┘
Trainable params: 636 K
Non-trainable params: 0
Total params: 636 K
Total estimated model params size (MB): 2
Modules in train mode: 35
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.
Test results#
[20]:
# Evaluate the trained model on the test loader and keep the returned metrics in `results`.
# NOTE(review): per the inline note, weights_only=False is only needed on newer Lightning
# releases; whether older releases accept the keyword at all is not shown here — confirm.
results = trainer.test(mcmpnn, test_loader, weights_only=False) # weights_only=False is only required with pytorch lightning version 2.6.0 or newer
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test/mae          │     92.9779052734375      │
│         test/rmse         │     98.97355651855469     │
└───────────────────────────┴───────────────────────────┘
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.