# Predicting
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
try:
import chemprop
except ImportError:
!git clone https://github.com/chemprop/chemprop.git
%cd chemprop
!pip install .
%cd examples
## Import packages
[2]:
import pandas as pd
import numpy as np
import torch
from lightning import pytorch as pl
from pathlib import Path
from chemprop import data, featurizers, models
## Change model input here
[3]:
# Repository root; assumes the notebook's working directory is `examples/`.
chemprop_dir = Path.cwd().parent
# Checkpoint to load. A checkpoint produced by the training notebook is written to the
# `checkpoints` folder with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol.ckpt")
## Load model
[4]:
# Load the trained MPNN from `checkpoint_path`. `map_location="cpu"` remaps tensors that
# were saved on a GPU, so checkpoints trained elsewhere also load on CPU-only machines.
# This matches the CPU-only trainer used for prediction below.
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path, map_location="cpu")
mpnn
[4]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
## Change predict input here
[5]:
# Repository root again, so this cell can be edited and re-run on its own.
chemprop_dir = Path.cwd().parent
# CSV with the molecules to predict on, plus the name of its SMILES column.
test_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")
smiles_column = 'smiles'
## Load test smiles
[6]:
# Read the molecules to predict on; the frame must contain the `smiles_column` column.
df_test = pd.read_csv(filepath_or_buffer=test_path)
df_test
[6]:
| | smiles | lipo |
|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
| ... | ... | ... |
| 95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
| 96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
| 97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
| 98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
| 99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
## Get smiles
[7]:
# SMILES strings as a Series, keeping df_test's row index and order.
smis = df_test.loc[:, smiles_column]
smis
[7]:
0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14
1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...
2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl
3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...
4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...
...
95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...
96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...
97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...
98 COc1ccc(Cc2c(N)n[nH]c2N)cc1
99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...
Name: smiles, Length: 100, dtype: object
## Get molecule datapoints
[8]:
# One MoleculeDatapoint per SMILES string, in the same order as `smis`.
test_data = list(map(data.MoleculeDatapoint.from_smi, smis))
## Get molecule dataset
[9]:
# NOTE(review): the featurizer must match the one used to train the checkpoint
# (the loaded model's `W_i` expects 86 input features); confirm if swapping it.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)
test_loader = data.build_dataloader(
    test_dset,
    shuffle=False,  # keep input order so predictions line up row-for-row with df_test
)
## Set up trainer
[10]:
# Prediction only, so no autograd bookkeeping is needed; run everything in inference mode.
with torch.inference_mode():
    trainer = pl.Trainer(logger=None, enable_progress_bar=True, accelerator="cpu", devices=1)
    # Returns a list with one prediction tensor per batch of `test_loader`.
    test_preds = trainer.predict(model=mpnn, dataloaders=test_loader)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
.../site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[11]:
test_preds = np.concatenate(
    test_preds,  # list of per-batch prediction tensors from `trainer.predict`
    axis=0,      # stack batches row-wise: one row per molecule
)
# Attach the predictions next to the inputs (here a single-task model, so one column).
df_test["pred"] = test_preds
df_test
[11]:
| | smiles | lipo | pred |
|---|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 | 2.253542 |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 | 2.235016 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 | 2.245891 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 | 2.249847 |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 | 2.228097 |
| ... | ... | ... | ... |
| 95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 | 2.233408 |
| 96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 | 2.236931 |
| 97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 | 2.237789 |
| 98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 | 2.252625 |
| 99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 | 2.235702 |
100 rows × 3 columns