# Predicting Regression - Multicomponent
## Import packages
[13]:
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from pathlib import Path
from chemprop import data, featurizers
from chemprop.models import multi
## Change model input here
[14]:
# Repository root: the notebook's working directory is assumed to sit one level below it.
chemprop_dir = Path.cwd().parent
# Model checkpoint to load. A model trained with the training notebook is written to the
# `checkpoints` folder, with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol+mol.ckpt")
## Load model
[15]:
# Restore the trained multicomponent MPNN (weights and hyperparameters) from the checkpoint.
# The two Lightning warnings printed while loading concern how the model class saves its
# hyperparameters (`graph_transform`, `output_transform`); they do not come from this notebook.
mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)
# Display the architecture. For this checkpoint the message-passing stage holds one block
# per component (shown as `2 x BondMessagePassing`), followed by aggregation, batch norm and the FFN.
mcmpnn
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'graph_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['graph_transform'])`.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
[15]:
MulticomponentMPNN(
(message_passing): MulticomponentMessagePassing(
(blocks): ModuleList(
(0-1): 2 x BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(W_d): Linear(in_features=300, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSELoss()
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
)
## Change prediction input here
[16]:
# Recomputed here so this input cell can be edited and re-run on its own.
chemprop_dir = Path.cwd().parent
# CSV file containing the SMILES strings to make predictions for.
test_path = chemprop_dir / "tests" / "data" / "regression" / "mol+mol" / "mol+mol.csv"
# Names of the columns holding the SMILES strings: one column per component of the
# multicomponent model (here: the molecule and its solvent).
# NOTE(review): presumably these must be listed in the same component order the model was
# trained with, since each component has its own message-passing block — confirm.
smiles_columns = ['smiles', 'solvent']
## Load test SMILES
[17]:
# Read the input CSV. Only the columns listed in `smiles_columns` are used for prediction;
# any other columns (here the target `peakwavs_max`) are carried along for comparison.
df_test = pd.read_csv(test_path)
df_test
[17]:
smiles | solvent | peakwavs_max | |
---|---|---|---|
0 | CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | ClCCl | 642.0 |
1 | C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | ClCCl | 420.0 |
2 | CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | O | 544.0 |
3 | c1ccc2[nH]ccc2c1 | O | 290.0 |
4 | CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | ClC(Cl)Cl | 736.0 |
... | ... | ... | ... |
95 | COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | C1CCOC1 | 359.0 |
96 | COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | C1CCCCC1 | 386.0 |
97 | CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | CCO | 425.0 |
98 | Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | c1ccccc1 | 324.0 |
99 | Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | ClCCl | 391.0 |
100 rows × 3 columns
## Get SMILES
[18]:
# 2-D object array of SMILES strings: one row per input row, one column per component,
# in the order given by `smiles_columns`.
smiss = df_test.loc[:, smiles_columns].to_numpy()
# Peek at the first few rows.
smiss[:5]
[18]:
array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
'ClCCl'],
['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
'ClCCl'],
['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
'O'],
['c1ccc2[nH]ccc2c1', 'O'],
['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
'ClC(Cl)Cl']], dtype=object)
## Get molecule datapoints
[19]:
# Number of molecular components per input row (one per SMILES column).
n_components = len(smiles_columns)
assert smiss.shape[1] == n_components, "expected one SMILES column per component"
# One list of MoleculeDatapoints per component. Iterating over the columns of `smiss`
# (the rows of its transpose) keeps component i paired with smiles_columns[i].
test_datapointss = [
    [data.MoleculeDatapoint.from_smi(smi) for smi in component_smis]
    for component_smis in smiss.T
]
## Get molecule datasets
[20]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# Build one MoleculeDataset per component; all of them share the same featurizer instance.
test_dsets = []
for component_datapoints in test_datapointss:
    test_dsets.append(data.MoleculeDataset(component_datapoints, featurizer))
## Get multicomponent dataset and data loader
[21]:
# Combine the per-component datasets into a single multicomponent dataset for the loader.
test_mcdset = data.MulticomponentDataset(test_dsets)
# shuffle=False keeps the prediction order aligned with the rows of df_test.
# NOTE(review): the predict log reports 100 batches for these 100 rows (an effective batch
# size of 1) — confirm build_dataloader's default batch size for this chemprop version.
test_loader = data.build_dataloader(test_mcdset, shuffle=False)
## Set up trainer and predict
[22]:
# Lightning's predict loop already disables gradient tracking by default (Trainer's
# `inference_mode=True`); the explicit context is kept as an additional guard.
with torch.inference_mode():
    trainer = pl.Trainer(
        # A prediction run needs no experiment logs. `False` is the documented way to disable
        # logging; `None` is the argument's default and falls back to the default logger.
        logger=False,
        enable_progress_bar=True,
        accelerator="auto",  # use a GPU when one is available, otherwise the CPU
        devices=1,
    )
    # A list with one entry per batch of `test_loader`, in loader order.
    test_preds = trainer.predict(mcmpnn, test_loader)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Predicting DataLoader 0: 100%|██████████| 100/100 [00:00<00:00, 399.94it/s]
[23]:
# Stack the per-batch outputs of `trainer.predict` into one array with one row per input row.
# The result gets a new name so `test_preds` stays the raw list of batches: re-running this
# cell then gives the same result instead of re-concatenating an already-stacked array.
test_preds_arr = np.concatenate(test_preds, axis=0)
assert len(test_preds_arr) == len(df_test), "number of predictions does not match number of input rows"
# Normalize to a 2-D (n_rows, n_tasks) layout before splitting into DataFrame columns.
test_preds_arr = test_preds_arr.reshape(len(df_test), -1)
if test_preds_arr.shape[1] == 1:
    # Single-task model (the checkpoint above has out_features=1): one 1-D `pred` column.
    df_test["pred"] = test_preds_arr[:, 0]
else:
    # Multi-task model: one `pred_<i>` column per task instead of a 2-D column assignment.
    for task_idx in range(test_preds_arr.shape[1]):
        df_test[f"pred_{task_idx}"] = test_preds_arr[:, task_idx]
df_test
[23]:
smiles | solvent | peakwavs_max | pred | |
---|---|---|---|---|
0 | CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | ClCCl | 642.0 | 458.408508 |
1 | C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | ClCCl | 420.0 | 457.399109 |
2 | CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | O | 544.0 | 453.458466 |
3 | c1ccc2[nH]ccc2c1 | O | 290.0 | 453.070251 |
4 | CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | ClC(Cl)Cl | 736.0 | 461.637939 |
... | ... | ... | ... | ... |
95 | COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | C1CCOC1 | 359.0 | 459.446198 |
96 | COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | C1CCCCC1 | 386.0 | 462.069153 |
97 | CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | CCO | 425.0 | 458.131134 |
98 | Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | c1ccccc1 | 324.0 | 459.271179 |
99 | Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | ClCCl | 391.0 | 458.653809 |
100 rows × 4 columns