Uncertainty Quantification#
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
try:
import chemprop
except ImportError:
!git clone https://github.com/chemprop/chemprop.git
%cd chemprop
!pip install .
%cd examples
Import packages#
[2]:
import pandas as pd
import numpy as np
import torch
import pandas as pd
from pathlib import Path
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from chemprop import data, models, nn, uncertainty
from chemprop.models import save_model, load_model
from chemprop.cli.conf import NOW
from chemprop.cli.predict import find_models
%load_ext autoreload
%autoreload 2
Training#
Load data#
[3]:
# Locate the example regression dataset relative to the current working directory.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir.joinpath(
    "tests", "data", "regression", "mol", "mol.csv"
)  # path to your data .csv file
df_input = pd.read_csv(input_path)

# SMILES strings come from the "smiles" column; "lipo" is selected as a
# one-column frame, so the target array keeps one row per molecule.
smis = df_input["smiles"].values
ys = df_input[["lipo"]].values

# One datapoint per (SMILES, target row) pair.
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, target) for smiles, target in zip(smis, ys)
]
[4]:
# RDKit Mol objects are used for structure-based splits, so collect them up front.
mols = [datapoint.mol for datapoint in all_data]

# Random 80/10/10 split into train/validation/test index sets.
split_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_indices, val_indices, test_indices = split_indices

train_data, val_data, test_data = data.split_data_by_indices(all_data, *split_indices)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
[5]:
# Wrap the first entry of each split in a dataset.
train_dset = data.MoleculeDataset(train_data[0])
val_dset = data.MoleculeDataset(val_data[0])
test_dset = data.MoleculeDataset(test_data[0])

# The scaler is obtained from the training targets and reused for validation.
# Test targets are not normalized here.
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)
[6]:
# Training uses the default loader settings; validation and test loaders are
# built with shuffling disabled so their order stays fixed.
train_loader = data.build_dataloader(train_dset)
val_loader, test_loader = (
    data.build_dataloader(dset, shuffle=False) for dset in (val_dset, test_dset)
)
Constructs MPNN#
- A `Message passing` module constructs molecular graphs using message passing to learn node-level hidden representations.
- An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.
- An `FFN` takes the aggregated representations and makes target predictions. To obtain uncertainty predictions, the `FFN` must be modified accordingly.

For regression:
ffn = nn.RegressionFFN()
ffn = nn.MveFFN()
ffn = nn.EvidentialFFN()
For classification:
ffn = nn.BinaryClassificationFFN()
ffn = nn.BinaryDirichletFFN()
ffn = nn.MulticlassClassificationFFN()
ffn = nn.MulticlassDirichletFFN()
For spectral:
ffn = nn.SpectralFFN()  # will be available in a future version
[7]:
# Output transform built from the target scaler fitted on the training set.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
# Change to other predictor if needed.
ffn = nn.MveFFN(output_transform=output_transform)

mpnn = models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=False)
mpnn
[7]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): Identity()
)
(agg): MeanAggregation()
(bn): Identity()
(predictor): MveFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=2, bias=True)
)
)
(criterion): MVELoss(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0): MSE(task_weights=[[1.0]])
(1): MVELoss(task_weights=[[1.0]])
)
)
Set up trainer#
[8]:
# Each run writes to its own directory keyed by NOW (imported from
# chemprop.cli.conf; presumably a timestamp string — confirm).
model_output_dir = Path(f"chemprop_training/{NOW}")

# Direction in which the monitored "val_loss" counts as an improvement:
# maximize only when the model's first metric says higher is better,
# otherwise minimize. Inverting this would keep the worst checkpoint as "best".
monitor_mode = "max" if mpnn.metrics[0].higher_is_better else "min"

checkpointing = ModelCheckpoint(
    model_output_dir / "checkpoints",  # directory for checkpoint files
    "best-{epoch}-{val_loss:.2f}",  # checkpoint filename pattern
    "val_loss",  # quantity monitored to select the best checkpoint
    mode=monitor_mode,
    save_last=True,  # also keep the most recent checkpoint
)
[9]:
# CPU-only trainer for a short 20-epoch run, with no logger or progress bar.
# The checkpoint callback defined above is attached.
trainer = pl.Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=20,
    callbacks=[checkpointing],
    enable_checkpointing=True,
    enable_progress_bar=False,
    logger=False,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Start training#
[10]:
# Train the MPNN on the training loader, validating on the validation loader.
trainer.fit(model=mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓ ┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃ ┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩ │ 0 │ message_passing │ BondMessagePassing │ 227 K │ train │ 0 │ │ 1 │ agg │ MeanAggregation │ 0 │ train │ 0 │ │ 2 │ bn │ Identity │ 0 │ train │ 0 │ │ 3 │ predictor │ MveFFN │ 90.9 K │ train │ 0 │ │ 4 │ X_d_transform │ Identity │ 0 │ train │ 0 │ │ 5 │ metrics │ ModuleList │ 0 │ train │ 0 │ └───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 318 K Non-trainable params: 0 Total params: 318 K Total estimated model params size (MB): 1 Modules in train mode: 24 Modules in eval mode: 0 Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.
Save the best model#
[11]:
# Reload the best checkpoint recorded by the checkpoint callback and export it
# as a standalone model file inside the run's output directory.
best_model_path = checkpointing.best_model_path
p_model = model_output_dir / "best.pt"

model = type(mpnn).load_from_checkpoint(best_model_path)
save_model(p_model, model)
Predicting#
Change model input here#
[12]:
# The CSV is re-read into df_test for display only; the dataset and loader used
# for prediction are built from the held-out test split created during training.
chemprop_dir = Path.cwd().parent
test_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")
df_test = pd.read_csv(test_path)

test_dset = data.MoleculeDataset(test_data[0])
test_loader = data.build_dataloader(test_dset, shuffle=False)

df_test
[12]:
| smiles | lipo | |
|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
| ... | ... | ... |
| 95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
| 96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
| 97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
| 98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
| 99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
[13]:
# use the validation set from the training as the calibration set as an example
# A new dataset is built from val_data instead of reusing val_dset, on which
# normalize_targets(scaler) was called earlier (presumably so the calibration
# targets stay on the original scale — confirm).
cal_dset = data.MoleculeDataset(val_data[0])
# Shuffling is disabled so the loader order is fixed.
cal_loader = data.build_dataloader(cal_dset, shuffle=False)
Constructs uncertainty estimator#
An uncertainty estimator can make model predictions and associated uncertainty predictions.
Available options can be found in uncertainty.UncertaintyEstimatorRegistry.
[14]:
# Show the registered estimator aliases and the classes they map to.
print(uncertainty.UncertaintyEstimatorRegistry)
ClassRegistry {
'none': <class 'chemprop.uncertainty.estimator.NoUncertaintyEstimator'>,
'mve': <class 'chemprop.uncertainty.estimator.MVEEstimator'>,
'ensemble': <class 'chemprop.uncertainty.estimator.EnsembleEstimator'>,
'classification': <class 'chemprop.uncertainty.estimator.ClassEstimator'>,
'evidential-total': <class 'chemprop.uncertainty.estimator.EvidentialTotalEstimator'>,
'evidential-epistemic': <class 'chemprop.uncertainty.estimator.EvidentialEpistemicEstimator'>,
'evidential-aleatoric': <class 'chemprop.uncertainty.estimator.EvidentialAleatoricEstimator'>,
'dropout': <class 'chemprop.uncertainty.estimator.DropoutEstimator'>,
'classification-dirichlet': <class 'chemprop.uncertainty.estimator.ClassificationDirichletEstimator'>,
'multiclass-dirichlet': <class 'chemprop.uncertainty.estimator.MulticlassDirichletEstimator'>,
'quantile-regression': <class 'chemprop.uncertainty.estimator.QuantileRegressionEstimator'>
}
[15]:
# MVE estimator, chosen to go with the MveFFN predictor the model was built with.
unc_estimator = uncertainty.MVEEstimator()
Constructs uncertainty calibrator#
An uncertainty calibrator can calibrate the predicted uncertainties.
Available options can be found in uncertainty.UncertaintyCalibratorRegistry.
For regression:
ZScalingCalibrator
ZelikmanCalibrator
MVEWeightingCalibrator
RegressionConformalCalibrator
For binary classification:
PlattCalibrator
IsotonicCalibrator
MultilabelConformalCalibrator
For multiclass classification:
MulticlassConformalCalibrator
AdaptiveMulticlassConformalCalibrator
IsotonicMulticlassCalibrator
[16]:
# Show the registered calibrator aliases and the classes they map to.
print(uncertainty.UncertaintyCalibratorRegistry)
ClassRegistry {
'zscaling': <class 'chemprop.uncertainty.calibrator.ZScalingCalibrator'>,
'zelikman-interval': <class 'chemprop.uncertainty.calibrator.ZelikmanCalibrator'>,
'mve-weighting': <class 'chemprop.uncertainty.calibrator.MVEWeightingCalibrator'>,
'conformal-regression': <class 'chemprop.uncertainty.calibrator.RegressionConformalCalibrator'>,
'platt': <class 'chemprop.uncertainty.calibrator.PlattCalibrator'>,
'isotonic': <class 'chemprop.uncertainty.calibrator.IsotonicCalibrator'>,
'conformal-multilabel': <class 'chemprop.uncertainty.calibrator.MultilabelConformalCalibrator'>,
'conformal-multiclass': <class 'chemprop.uncertainty.calibrator.MulticlassConformalCalibrator'>,
'conformal-adaptive': <class 'chemprop.uncertainty.calibrator.AdaptiveMulticlassConformalCalibrator'>,
'isotonic-multiclass': <class 'chemprop.uncertainty.calibrator.IsotonicMulticlassCalibrator'>
}
[17]:
# Z-scaling calibrator (one of the regression options); it is fitted on the
# calibration set in a later cell.
unc_calibrator = uncertainty.ZScalingCalibrator()
Constructs uncertainty evaluator#
An uncertainty evaluator can evaluate the quality of uncertainty estimates.
Available options can be found in uncertainty.UncertaintyEvaluatorRegistry.
For regression:
NLLRegressionEvaluator
CalibrationAreaEvaluator
ExpectedNormalizedErrorEvaluator
SpearmanEvaluator
RegressionConformalEvaluator
For binary classification:
NLLClassEvaluator
MultilabelConformalEvaluator
For multiclass classification:
NLLMulticlassEvaluator
MulticlassConformalEvaluator
[18]:
# Show the registered evaluator aliases and the classes they map to.
print(uncertainty.UncertaintyEvaluatorRegistry)
ClassRegistry {
'nll-regression': <class 'chemprop.uncertainty.evaluator.NLLRegressionEvaluator'>,
'miscalibration_area': <class 'chemprop.uncertainty.evaluator.CalibrationAreaEvaluator'>,
'ence': <class 'chemprop.uncertainty.evaluator.ExpectedNormalizedErrorEvaluator'>,
'spearman': <class 'chemprop.uncertainty.evaluator.SpearmanEvaluator'>,
'conformal-coverage-regression': <class 'chemprop.uncertainty.evaluator.RegressionConformalEvaluator'>,
'nll-classification': <class 'chemprop.uncertainty.evaluator.NLLClassEvaluator'>,
'conformal-coverage-classification': <class 'chemprop.uncertainty.evaluator.MultilabelConformalEvaluator'>,
'nll-multiclass': <class 'chemprop.uncertainty.evaluator.NLLMulticlassEvaluator'>,
'conformal-coverage-multiclass': <class 'chemprop.uncertainty.evaluator.MulticlassConformalEvaluator'>
}
[19]:
# Regression evaluators used to score the uncertainty estimates; each is
# instantiated with its default arguments.
evaluator_classes = (
    uncertainty.NLLRegressionEvaluator,
    uncertainty.CalibrationAreaEvaluator,
    uncertainty.ExpectedNormalizedErrorEvaluator,
    uncertainty.SpearmanEvaluator,
)
unc_evaluators = [evaluator_cls() for evaluator_cls in evaluator_classes]
Load model#
[20]:
# Collect the saved model files found under the training output directory.
model_paths = find_models([model_output_dir])

# NOTE(review): this rebinds `models`, which the import cell bound to the
# `chemprop.models` module. After this cell runs, re-running the cell that
# calls `models.MPNN(...)` fails until the imports are re-executed.
models = []
for model_path in model_paths:
    models.append(load_model(model_path, multicomponent=False))
Setup trainer#
[21]:
# Fresh CPU trainer for prediction: no logger, progress bar enabled.
# This rebinds `trainer`, replacing the one used for training.
trainer = pl.Trainer(
    accelerator="cpu",
    devices=1,
    enable_progress_bar=True,
    logger=False,
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Make uncertainty estimation#
[22]:
# Run the estimator over the test loader with the loaded model(s).
test_predss, test_uncss = unc_estimator(test_loader, models, trainer)

# Average over the leading axis of the estimator outputs
# (presumably one entry per loaded model — confirm).
test_preds, test_uncs = test_predss.mean(0), test_uncss.mean(0)

# Flatten everything to one value per molecule for a tabular view.
df_test = pd.DataFrame(
    dict(
        smiles=test_dset.smiles,
        target=test_dset.Y.reshape(-1),
        pred=test_preds.reshape(-1),
        unc=test_uncs.reshape(-1),
    )
)
df_test
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[22]:
| smiles | target | pred | unc | |
|---|---|---|---|---|
| 0 | Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... | 2.06 | 1.813618 | 1.817690 |
| 1 | O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... | 1.92 | 1.825420 | 1.775566 |
| 2 | CNCCCC12CCC(c3ccccc31)c1ccccc12 | 0.89 | 1.778415 | 1.795096 |
| 3 | Oc1ncnc2scc(-c3ccsc3)c12 | 2.25 | 1.838990 | 1.866848 |
| 4 | C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... | 2.04 | 1.818925 | 1.750113 |
| 5 | COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 | 3.13 | 1.822321 | 1.767510 |
| 6 | O=C(COc1ccccc1)c1ccccc1 | 2.87 | 1.859192 | 1.750126 |
| 7 | CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 | 1.10 | 1.815018 | 1.842683 |
| 8 | N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 | -0.16 | 1.815731 | 1.770754 |
| 9 | COc1cnc(-c2ccccn2)nc1N(C)C | 1.90 | 1.856991 | 1.782075 |
Apply uncertainty calibration#
[23]:
# Predictions and uncertainties on the calibration set, averaged over the
# leading axis in the same way as the test-set estimates.
cal_predss, cal_uncss = unc_estimator(cal_loader, models, trainer)
average_cal_preds, average_cal_uncs = cal_predss.mean(0), cal_uncss.mean(0)

# The mask marks finite targets; NaN targets are replaced with 0.0 so the
# array can be turned into a tensor.
cal_mask = torch.from_numpy(np.isfinite(cal_dset.Y))
cal_targets = torch.from_numpy(np.nan_to_num(cal_dset.Y, nan=0.0))

# Fit the calibrator on the calibration set, then apply it to the test uncertainties.
unc_calibrator.fit(average_cal_preds, average_cal_uncs, cal_targets, cal_mask)
cal_test_uncs = unc_calibrator.apply(test_uncs)

df_test["cal_unc"] = cal_test_uncs
df_test
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[23]:
| smiles | target | pred | unc | cal_unc | |
|---|---|---|---|---|---|
| 0 | Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... | 2.06 | 1.813618 | 1.817690 | 1.920422 |
| 1 | O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... | 1.92 | 1.825420 | 1.775566 | 1.875917 |
| 2 | CNCCCC12CCC(c3ccccc31)c1ccccc12 | 0.89 | 1.778415 | 1.795096 | 1.896550 |
| 3 | Oc1ncnc2scc(-c3ccsc3)c12 | 2.25 | 1.838990 | 1.866848 | 1.972358 |
| 4 | C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... | 2.04 | 1.818925 | 1.750113 | 1.849025 |
| 5 | COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 | 3.13 | 1.822321 | 1.767510 | 1.867406 |
| 6 | O=C(COc1ccccc1)c1ccccc1 | 2.87 | 1.859192 | 1.750126 | 1.849039 |
| 7 | CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 | 1.10 | 1.815018 | 1.842683 | 1.946827 |
| 8 | N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 | -0.16 | 1.815731 | 1.770754 | 1.870833 |
| 9 | COc1cnc(-c2ccccn2)nc1N(C)C | 1.90 | 1.856991 | 1.782075 | 1.882794 |
Evaluate predicted uncertainty#
[24]:
# The mask marks finite test targets; NaN targets are replaced with 0.0
# before conversion to a tensor.
test_mask = torch.from_numpy(np.isfinite(test_dset.Y))
test_targets = torch.from_numpy(np.nan_to_num(test_dset.Y, nan=0.0))

# Score the calibrated test uncertainties with every configured evaluator.
for evaluator in unc_evaluators:
    evaluation = evaluator.evaluate(test_preds, cal_test_uncs, test_targets, test_mask)
    print(f"{evaluator.alias}: {evaluation.tolist()}")
nll-regression: [1.4570959966142754]
miscalibration_area: [0.147599995136261]
ence: [0.5853198323227037]
spearman: [-0.18787875771522522]
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs)