Training Regression - Multicomponent#
[1]:
import pandas as pd
from lightning import pytorch as pl
from pathlib import Path
from chemprop import data, featurizers, models, nn
from chemprop.nn import metrics
from chemprop.models import multi
Load data#
Change your data inputs here#
[2]:
# Paths are resolved relative to the parent of the current working directory; this assumes the
# notebook is run from a subdirectory of the chemprop repository (e.g. examples/) — adjust otherwise.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol+mol" / "mol+mol.csv"  # path to your data .csv file containing SMILES strings and target values
smiles_columns = ['smiles', 'solvent']  # names of the columns containing SMILES strings, one column per component
target_columns = ['peakwavs_max']  # list of names of the columns containing targets

# Seed the Python, NumPy and PyTorch RNGs so weight initialization and data-loader shuffling
# are reproducible across Restart & Run All.
pl.seed_everything(42);
Read data#
[3]:
df_input = pd.read_csv(input_path)
# Fail fast with a clear message if a configured SMILES/target column is absent from the file,
# instead of a bare KeyError when the columns are selected further below.
assert set(smiles_columns + target_columns) <= set(df_input.columns), (
    f"columns missing from {input_path}: "
    f"{sorted(set(smiles_columns + target_columns) - set(df_input.columns))}"
)
df_input
[3]:
smiles | solvent | peakwavs_max | |
---|---|---|---|
0 | CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | ClCCl | 642.0 |
1 | C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | ClCCl | 420.0 |
2 | CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | O | 544.0 |
3 | c1ccc2[nH]ccc2c1 | O | 290.0 |
4 | CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | ClC(Cl)Cl | 736.0 |
... | ... | ... | ... |
95 | COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | C1CCOC1 | 359.0 |
96 | COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | C1CCCCC1 | 386.0 |
97 | CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | CCO | 425.0 |
98 | Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | c1ccccc1 | 324.0 |
99 | Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | ClCCl | 391.0 |
100 rows × 3 columns
Get SMILES and targets#
[4]:
# Object array of SMILES: one row per datapoint, one column per component
smiss = df_input[smiles_columns].to_numpy()
# Float array of targets: one row per datapoint, one column per target
ys = df_input[target_columns].to_numpy()
[5]:
# Preview the first five datapoints: each SMILES row (one entry per component) and its targets
smiss[0:5], ys[0:5]
[5]:
(array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
'ClCCl'],
['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
'ClCCl'],
['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
'O'],
['c1ccc2[nH]ccc2c1', 'O'],
['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
'ClC(Cl)Cl']], dtype=object),
array([[642.],
[420.],
[544.],
[290.],
[736.]]))
Make molecule datapoints#
Create a list of lists containing the molecule datapoints for each component. The targets are stored only in the datapoints of the first (index 0) component.
[6]:
# Outer list: one entry per component; inner list: one datapoint per row of the CSV.
# Only the first component's datapoints carry the targets; the remaining components hold molecules only.
all_data = [
    [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smiss[:, 0], ys)],
    *(
        [data.MoleculeDatapoint.from_smi(smi) for smi in smiss[:, i]]
        for i in range(1, len(smiles_columns))
    ),
]
Split data#
Perform data splitting for training, validation, and testing#
[7]:
# Splitting is driven by the molecules of a single component; pick which one here.
component_to_split_by = 0  # index of the component to use for structure based splits
mols = [datapoint.mol for datapoint in all_data[component_to_split_by]]
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "random", (0.8, 0.1, 0.1)  # 80/10/10 train/val/test fractions
)
# The same index sets are applied to the per-component lists in all_data
train_data, val_data, test_data = data.split_data_by_indices(all_data, train_indices, val_indices, test_indices)
Get a MoleculeDataset for each component#
[8]:
# A single featurizer is shared by every component and every split.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# For each split, build one MoleculeDataset per component (in component order).
train_datasets, val_datasets, test_datasets = (
    [data.MoleculeDataset(split[i], featurizer) for i in range(len(smiles_columns))]
    for split in (train_data, val_data, test_data)
)
Construct multicomponent dataset and scale the targets#
[9]:
# Normalize the training targets and keep the fitted scaler; it is reused to build the
# UnscaleTransform that maps the model's outputs back to the original target units.
train_mcdset = data.MulticomponentDataset(train_datasets)
scaler = train_mcdset.normalize_targets()
# Validation and test targets are kept in their original units.
# NOTE(review): in the run recorded in this notebook, normalizing the validation targets with
# `scaler` gave val_loss≈480 (on the order of the raw target values) while train_loss was 0.033
# and test RMSE ≈78. That pattern indicates validation metrics were computed on unscaled
# predictions against scaled targets. If your chemprop version computes val_loss on scaled
# predictions instead, restore `val_mcdset.normalize_targets(scaler)`.
val_mcdset = data.MulticomponentDataset(val_datasets)
test_mcdset = data.MulticomponentDataset(test_datasets)
Construct data loaders#
[10]:
# The training loader uses build_dataloader's default shuffling; the evaluation loaders are
# built with shuffle=False so their order is fixed.
train_loader = data.build_dataloader(train_mcdset)
val_loader, test_loader = (
    data.build_dataloader(mcdset, shuffle=False) for mcdset in (val_mcdset, test_mcdset)
)
Construct multicomponent MPNN#
MulticomponentMessagePassing#
blocks
: a list of message passing blocks, one for each component
n_components
: the number of components
[11]:
# One separately constructed bond message-passing block per SMILES column (component).
mcmp = nn.MulticomponentMessagePassing(
    blocks=[nn.BondMessagePassing() for _ in smiles_columns],
    n_components=len(smiles_columns),
)
Aggregation#
[12]:
# Aggregation applied to the message-passing output before the batch norm and regression head
# (presumably a mean over each molecule's atom embeddings, as the name suggests — confirm in the chemprop docs).
agg = nn.MeanAggregation()
RegressionFFN#
[13]:
# Transform for the regression head, built from the scaler fitted on the training targets, that
# converts normalized outputs back to the original target units.
# NOTE(review): presumably applied only at evaluation/inference time — confirm in the chemprop docs.
# The UserWarnings printed below are raised inside chemprop's nn/transforms.py, not by this cell.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
self.register_buffer("mean", torch.tensor(mean, dtype=torch.float).unsqueeze(0))
/home/hwpang/Projects/chemprop_v2_dev/chemprop/chemprop/nn/transforms.py:22: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
self.register_buffer("scale", torch.tensor(scale, dtype=torch.float).unsqueeze(0))
[14]:
# Regression head: its input width must equal the combined fingerprint size of all components
# (mcmp.output_dim), and its outputs pass through the training-set unscaling transform.
ffn = nn.RegressionFFN(input_dim=mcmp.output_dim, output_transform=output_transform)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
Metrics#
[15]:
# Evaluation metrics, reported on the validation/test sets. The training loss is NOT one of these:
# it is the predictor's criterion (MSELoss in the model summary below). The first metric is
# presumably the one reported as `val_loss` (used for checkpointing/early stopping) — confirm
# for your chemprop version.
metric_list = [
    metrics.RMSEMetric(),
    metrics.MAEMetric(),
]
MulticomponentMPNN#
[16]:
# Assemble message passing, aggregation and the regression head into the Lightning model passed
# to the trainer, then display its module tree.
mcmpnn = multi.MulticomponentMPNN(mcmp, agg, ffn, metrics=metric_list)
mcmpnn
[16]:
MulticomponentMPNN(
(message_passing): MulticomponentMessagePassing(
(blocks): ModuleList(
(0-1): 2 x BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSELoss()
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
)
Set up trainer#
[24]:
trainer = pl.Trainer(
    max_epochs=20,  # number of epochs to train for
    accelerator="auto",  # let Lightning pick the available accelerator (GPU if present)
    devices=1,
    enable_progress_bar=True,
    enable_checkpointing=True,
    logger=False,  # no experiment logger
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Start training#
[25]:
# Train on the training loader and evaluate on the validation loader after each epoch.
trainer.fit(mcmpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
| Name | Type | Params
-----------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 455 K
1 | agg | MeanAggregation | 0
2 | bn | BatchNorm1d | 1.2 K
3 | predictor | RegressionFFN | 180 K
4 | X_d_transform | Identity | 0
| other params | n/a | 1
-----------------------------------------------------------------
637 K Trainable params
1 Non-trainable params
637 K Total params
2.549 Total estimated model params size (MB)
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 25.37it/s, train_loss=0.033, val_loss=480.0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 21.88it/s, train_loss=0.033, val_loss=480.0]
Test results#
[26]:
# Evaluate on the held-out test set. Because the model is passed without a ckpt_path,
# Lightning uses the weights as they are after the final training epoch.
results = trainer.test(mcmpnn, dataloaders=test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 239.66it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test/mae 66.03467559814453
test/rmse 78.33834838867188
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[ ]: