Predicting Regression - Multicomponent

Import packages

[13]:
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from pathlib import Path

from chemprop import data, featurizers
from chemprop.models import multi

Change model input here

[14]:
# Root of the chemprop repository; assumes the notebook is run from a direct
# subdirectory of it (e.g. `examples/`).
chemprop_dir = Path.cwd().parent
# Multicomponent regression checkpoint shipped with the test data.
# A checkpoint written by the training notebook lands in its `checkpoints` folder
# instead, with a name like `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol+mol.ckpt")

Load model

[15]:
# Rebuild the trained multicomponent MPNN (architecture and weights) from the
# checkpoint file configured above.
mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path=checkpoint_path)
mcmpnn  # show the module tree as the cell output
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'graph_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['graph_transform'])`.
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'output_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['output_transform'])`.
[15]:
MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (W_d): Linear(in_features=300, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): GraphTransform(
          (V_transform): Identity()
          (E_transform): Identity()
        )
      )
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss()
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
)

Change prediction input here

[16]:
# Root of the chemprop repository; redefined here so this cell can be edited and
# re-run on its own. Assumes the notebook runs from a direct subdirectory of it.
chemprop_dir = Path.cwd().parent
# Path to the .csv file containing the SMILES strings to make predictions for.
test_path = chemprop_dir / "tests" / "data" / "regression" / "mol+mol" / "mol+mol.csv"
# Names of the columns containing SMILES strings, one column per component.
# Column i feeds component i of the model, so keep the order used in training.
smiles_columns = ['smiles', 'solvent']

Load test SMILES

[17]:
# Read the test set; it needs every column listed in `smiles_columns`.
df_test = pd.read_csv(filepath_or_buffer=test_path)
df_test
[17]:
smiles solvent peakwavs_max
0 CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... ClCCl 642.0
1 C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... ClCCl 420.0
2 CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... O 544.0
3 c1ccc2[nH]ccc2c1 O 290.0
4 CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... ClC(Cl)Cl 736.0
... ... ... ...
95 COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... C1CCOC1 359.0
96 COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... C1CCCCC1 386.0
97 CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O CCO 425.0
98 Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... c1ccccc1 324.0
99 Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... ClCCl 391.0

100 rows × 3 columns

Get SMILES

[18]:
# 2-D object array: one row per datapoint, one column per component, in the
# order given by `smiles_columns`. `to_numpy()` is pandas' recommended
# replacement for `.values`.
smiss = df_test[smiles_columns].to_numpy()
smiss[:5]  # preview the first five rows
[18]:
array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
        'ClCCl'],
       ['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
        'ClCCl'],
       ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
        'O'],
       ['c1ccc2[nH]ccc2c1', 'O'],
       ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
        'ClC(Cl)Cl']], dtype=object)

Get molecule datapoints

[19]:
# Number of molecular components: one per SMILES column.
n_components = len(smiles_columns)
# One list of datapoints per component: column i of `smiss` becomes the
# datapoints for component i, keeping the row order of `df_test`.
test_datapointss = [
    [data.MoleculeDatapoint.from_smi(smi) for smi in smiss[:, i]]
    for i in range(n_components)
]

Get molecule datasets

[20]:
# A single featurizer instance, shared by every component dataset below.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# Wrap each component's datapoints in its own dataset (same order as `smiles_columns`).
test_dsets = [
    data.MoleculeDataset(component_points, featurizer)
    for component_points in test_datapointss
]

Get multicomponent dataset and data loader

[21]:
# Combine the per-component datasets into a single multicomponent dataset
# (presumably each item bundles the i-th datapoint of every component — confirm
# against chemprop's MulticomponentDataset).
test_mcdset = data.MulticomponentDataset(test_dsets)
# shuffle=False keeps predictions in the same row order as `df_test`; the later
# `df_test['pred'] = ...` assignment depends on that alignment.
test_loader = data.build_dataloader(test_mcdset, shuffle=False)

Set up trainer and run prediction

[22]:
# Predict on one device; "auto" lets Lightning pick the accelerator (the output
# below shows it selecting CUDA when a GPU is present).
with torch.inference_mode():
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        enable_progress_bar=True,
        logger=None,
    )
    # Collect the model's outputs for every batch of `test_loader`.
    test_preds = trainer.predict(mcmpnn, test_loader)
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/hwpang/miniforge3/envs/chemprop_v2_dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Predicting DataLoader 0: 100%|██████████| 100/100 [00:00<00:00, 399.94it/s]
[23]:
# Join the per-batch predictions along the sample axis, then store the result
# both as `test_preds` and as a new `pred` column next to the measured target.
df_test['pred'] = test_preds = np.concatenate(test_preds, axis=0)
df_test
[23]:
smiles solvent peakwavs_max pred
0 CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... ClCCl 642.0 458.408508
1 C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... ClCCl 420.0 457.399109
2 CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... O 544.0 453.458466
3 c1ccc2[nH]ccc2c1 O 290.0 453.070251
4 CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... ClC(Cl)Cl 736.0 461.637939
... ... ... ... ...
95 COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... C1CCOC1 359.0 459.446198
96 COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... C1CCCCC1 386.0 462.069153
97 CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O CCO 425.0 458.131134
98 Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... c1ccccc1 324.0 459.271179
99 Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... ClCCl 391.0 458.653809

100 rows × 4 columns