Predicting#
Import packages#
[1]:
import pandas as pd
import numpy as np
import torch
from lightning import pytorch as pl
from pathlib import Path
from chemprop import data, featurizers, models
Change model input here#
[2]:
# Root of the chemprop repository (parent of this notebook's directory).
chemprop_dir = Path.cwd().parent
# Checkpoint file to load for prediction. If you generated a checkpoint with the
# training notebook, it lives in the `checkpoints` folder under a name similar
# to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol.ckpt")
# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`.
Load model#
[3]:
# Restore the trained MPNN (weights + saved hyperparameters) from the checkpoint.
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
# Bare last expression -> rich display of the model architecture.
mpnn
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'graph_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['graph_transform'])`.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
[3]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(W_d): Linear(in_features=300, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSELoss()
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
)
Change predict input here#
[4]:
# Root of the chemprop repository (parent of this notebook's directory).
chemprop_dir = Path.cwd().parent
# CSV of test molecules; predictions are made for every row.
test_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")
# Name of the column holding the SMILES strings.
smiles_column = 'smiles'
Load test data#
[5]:
# Load the test set; the file has a `smiles` column plus a `lipo` target column
# (the target is unused for prediction).
df_test = pd.read_csv(test_path)
# Display the loaded frame.
df_test
[5]:
smiles | lipo | |
---|---|---|
0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
... | ... | ... |
95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
Get smiles#
[6]:
# Pull the SMILES column out as a pandas Series of strings.
smis = df_test[smiles_column]
smis
[6]:
0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14
1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...
2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl
3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...
4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...
...
95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...
96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...
97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...
98 COc1ccc(Cc2c(N)n[nH]c2N)cc1
99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...
Name: smiles, Length: 100, dtype: object
Get molecule datapoints#
[7]:
# Wrap each SMILES string in a MoleculeDatapoint; no targets are supplied
# since we only need the molecules for prediction.
test_data = list(map(data.MoleculeDatapoint.from_smi, smis))
Get molecule dataset and dataloader#
[8]:
# Featurizer that converts each molecule into a MolGraph using the default
# atom/bond feature set (must match what the model was trained with).
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)
# shuffle=False keeps predictions aligned with the row order of `df_test`.
test_loader = data.build_dataloader(test_dset, shuffle=False)
Set up trainer and predict#
[9]:
# The Trainer only drives batched inference here. Note `logger=False`:
# in Lightning, `logger=None` falls back to the default TensorBoardLogger,
# so `False` is required to actually disable experiment logging.
# Constructing the Trainer does not need to happen under `inference_mode`;
# only the forward passes do.
trainer = pl.Trainer(
    logger=False,
    enable_progress_bar=True,
    accelerator="cpu",
    devices=1,
)

# Run prediction with autograd disabled; returns a list with one tensor of
# predictions per batch.
with torch.inference_mode():
    test_preds = trainer.predict(mpnn, test_loader)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Predicting DataLoader 0: 100%|██████████| 100/100 [00:00<00:00, 396.76it/s]
[10]:
# Stack the per-batch prediction tensors into one (n_samples, n_tasks) array.
# Bind a NEW name instead of rebinding `test_preds`: the original
# `test_preds = np.concatenate(test_preds, ...)` is not idempotent — re-running
# the cell would concatenate the already-concatenated array and change its shape.
preds = np.concatenate(test_preds, axis=0)
df_test['pred'] = preds
df_test
[10]:
smiles | lipo | pred | |
---|---|---|---|
0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 | 2.176904 |
1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 | 2.148450 |
2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 | 2.159459 |
3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 | 2.167359 |
4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 | 2.153605 |
... | ... | ... | ... |
95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 | 2.149804 |
96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 | 2.153695 |
97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 | 2.158461 |
98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 | 2.175282 |
99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 | 2.159477 |
100 rows × 3 columns