Uncertainty Quantification#

Open In Colab

[1]:
# Install chemprop from GitHub if running in Google Colab
import os

# NOTE(review): COLAB_RELEASE_TAG is presumably only set by Colab runtimes, so local runs skip this block.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop  # already importable: nothing to install
    except ImportError:
        # Clone the repo, install it from source, then move into `examples/` so that
        # `Path.cwd().parent` in later cells resolves to the repository root.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

Import packages#

[2]:
import pandas as pd
import numpy as np
import torch
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from chemprop import data, models, nn, uncertainty
from chemprop.models import save_model, load_model
from chemprop.cli.conf import NOW
from chemprop.cli.predict import find_models

# Re-import edited modules automatically before each cell runs (handy when modifying chemprop itself).
%load_ext autoreload
%autoreload 2

Training#

Load data#

[3]:
# Read the example regression dataset shipped in the chemprop repository's test data.
chemprop_dir = Path.cwd().parent
input_path = Path(chemprop_dir, "tests", "data", "regression", "mol", "mol.csv")  # path to your data .csv file
df_input = pd.read_csv(input_path)

# Build one datapoint per row: the SMILES string plus a 1-element target vector from the "lipo" column.
smis = df_input["smiles"].to_numpy()
ys = df_input[["lipo"]].to_numpy()
all_data = []
for smi, y in zip(smis, ys):
    all_data.append(data.MoleculeDatapoint.from_smi(smi, y))
[4]:
# RDKit Mol objects are used for structure-based splits.
mols = [datapoint.mol for datapoint in all_data]
split_sizes = (0.8, 0.1, 0.1)  # train / validation / test fractions
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", split_sizes)
# NOTE(review): later cells index these with `[0]`; presumably one entry per split replicate
# (see the v2.1 return-type message printed below).
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
[5]:
# One dataset per split, each built from the `[0]` entry of its split.
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0]) for split in (train_data, val_data, test_data)
)

# Fit the target scaler on the training targets and apply the same scaler to validation.
# The test dataset is not normalized, so its targets remain in the original units.
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)
[6]:
# The training loader uses build_dataloader's default shuffling; validation and test loaders
# keep a fixed order so their outputs line up with the datasets.
train_loader = data.build_dataloader(train_dset)
val_loader, test_loader = (
    data.build_dataloader(dset, shuffle=False) for dset in (val_dset, test_dset)
)

Construct the MPNN#

  • A message-passing module operates on the molecular graph, using message passing to learn node-level hidden representations.

  • An Aggregation is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

  • An FFN takes the aggregated representation and makes target predictions. To obtain uncertainty predictions, choose an FFN suited to the uncertainty method (options below).

    For regression:

    • ffn = nn.RegressionFFN()

    • ffn = nn.MveFFN()

    • ffn = nn.EvidentialFFN()

    For classification:

    • ffn = nn.BinaryClassificationFFN()

    • ffn = nn.BinaryDirichletFFN()

    • ffn = nn.MulticlassClassificationFFN()

    • ffn = nn.MulticlassDirichletFFN()

    For spectral:

    • ffn = nn.SpectralFFN() # will be available in a future version

[7]:
# Pipeline: bond-based message passing -> mean aggregation -> MVE predictor head.
# MveFFN emits two outputs per task (see `out_features=2` in the repr below).
mp, agg = nn.BondMessagePassing(), nn.MeanAggregation()
# Map predictions back to the original target units using the scaler fitted on the training targets.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
ffn = nn.MveFFN(output_transform=output_transform)  # swap in another predictor here if needed
mpnn = models.MPNN(mp, agg, ffn, batch_norm=False)
mpnn
[7]:
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MVELoss(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): MSE(task_weights=[[1.0]])
    (1): MVELoss(task_weights=[[1.0]])
  )
)

Set up trainer#

[8]:
model_output_dir = Path(f"chemprop_training/{NOW}")

# ModelCheckpoint must keep the epoch with the *best* `val_loss`. For a loss (lower is better)
# that means mode="min". The previous expression had its branches swapped, so for a loss it
# produced "max" and kept the *worst* epoch as "best".
# `val_loss` is a loss, so follow the direction of the training criterion
# (`(criterion): MVELoss` in the repr above) rather than an arbitrary entry of `mpnn.metrics`.
monitor_mode = "max" if mpnn.predictor.criterion.higher_is_better else "min"

checkpointing = ModelCheckpoint(
    model_output_dir / "checkpoints",  # directory for checkpoint files
    "best-{epoch}-{val_loss:.2f}",  # checkpoint filename template
    "val_loss",  # quantity that decides which checkpoint is "best"
    mode=monitor_mode,
    save_last=True,  # additionally keep the checkpoint from the final epoch
)
[9]:
# CPU-only, single-device training for a fixed number of epochs; the checkpoint callback records the best model.
trainer_kwargs = dict(
    accelerator="cpu",
    devices=1,
    max_epochs=20,
    logger=False,
    enable_progress_bar=False,
    enable_checkpointing=True,
    callbacks=[checkpointing],
)
trainer = pl.Trainer(**trainer_kwargs)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Start training#

[10]:
trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃    Name             Type                Params  Mode   FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  227 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ Identity           │      0 │ train │     0 │
│ 3 │ predictor       │ MveFFN             │ 90.9 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 318 K
Non-trainable params: 0
Total params: 318 K
Total estimated model params size (MB): 1
Modules in train mode: 24
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.

Save the best model#

[11]:
# Reload the weights from the checkpoint chosen by ModelCheckpoint and export them as a standalone model file.
p_model = model_output_dir / "best.pt"
best_model_path = checkpointing.best_model_path
model = type(mpnn).load_from_checkpoint(best_model_path)
save_model(p_model, model)

Predicting#

Change model input here#

[12]:
# Choose the molecules to predict on.
# - test_path = None (default): reuse the held-out test split from training, so the
#   uncertainty metrics below are computed on molecules the model did not train on.
# - test_path = <CSV with "smiles" and "lipo" columns>: predict on that file instead, e.g.
#       test_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv"
#   (that particular file is the training CSV, so metrics on it would be optimistic).
# Previously the CSV was read and displayed but never used; the predictions always came from
# the split, so the table shown did not match the data actually predicted.
test_path = None

if test_path is None:
    test_points = test_data[0]
else:
    df_test_input = pd.read_csv(test_path)
    test_points = [
        data.MoleculeDatapoint.from_smi(smi, y)
        for smi, y in zip(df_test_input["smiles"], df_test_input[["lipo"]].values)
    ]

test_dset = data.MoleculeDataset(test_points)
test_loader = data.build_dataloader(test_dset, shuffle=False)

# Show exactly what will be predicted (targets are unnormalized, original units).
pd.DataFrame({"smiles": test_dset.smiles, "lipo": test_dset.Y.reshape(-1)})
[12]:
smiles lipo
0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 3.54
1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18
2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 3.69
3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 3.37
4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 3.10
... ... ...
95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... 2.20
96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... 2.04
97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... 4.49
98 COc1ccc(Cc2c(N)n[nH]c2N)cc1 0.20
99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... 2.00

100 rows × 2 columns

[13]:
# use the validation set from the training as the calibration set as an example
cal_dset = data.MoleculeDataset(val_data[0])
cal_loader = data.build_dataloader(cal_dset, shuffle=False)

Construct the uncertainty estimator#

An uncertainty estimator makes model predictions together with their associated uncertainty estimates.

Available options can be found in uncertainty.UncertaintyEstimatorRegistry.

[14]:
print(str(uncertainty.UncertaintyEstimatorRegistry))  # alias -> estimator class
ClassRegistry {
    'none': <class 'chemprop.uncertainty.estimator.NoUncertaintyEstimator'>,
    'mve': <class 'chemprop.uncertainty.estimator.MVEEstimator'>,
    'ensemble': <class 'chemprop.uncertainty.estimator.EnsembleEstimator'>,
    'classification': <class 'chemprop.uncertainty.estimator.ClassEstimator'>,
    'evidential-total': <class 'chemprop.uncertainty.estimator.EvidentialTotalEstimator'>,
    'evidential-epistemic': <class 'chemprop.uncertainty.estimator.EvidentialEpistemicEstimator'>,
    'evidential-aleatoric': <class 'chemprop.uncertainty.estimator.EvidentialAleatoricEstimator'>,
    'dropout': <class 'chemprop.uncertainty.estimator.DropoutEstimator'>,
    'classification-dirichlet': <class 'chemprop.uncertainty.estimator.ClassificationDirichletEstimator'>,
    'multiclass-dirichlet': <class 'chemprop.uncertainty.estimator.MulticlassDirichletEstimator'>,
    'quantile-regression': <class 'chemprop.uncertainty.estimator.QuantileRegressionEstimator'>
}
[15]:
unc_estimator = uncertainty.UncertaintyEstimatorRegistry["mve"]()  # MVEEstimator, per the registry above

Construct the uncertainty calibrator#

An uncertainty calibrator calibrates the predicted uncertainties.

Available options can be found in uncertainty.UncertaintyCalibratorRegistry.

For regression:

  • ZScalingCalibrator

  • ZelikmanCalibrator

  • MVEWeightingCalibrator

  • RegressionConformalCalibrator

For binary classification:

  • PlattCalibrator

  • IsotonicCalibrator

  • MultilabelConformalCalibrator

For multiclass classification:

  • MulticlassConformalCalibrator

  • AdaptiveMulticlassConformalCalibrator

  • IsotonicMulticlassCalibrator

[16]:
print(str(uncertainty.UncertaintyCalibratorRegistry))  # alias -> calibrator class
ClassRegistry {
    'zscaling': <class 'chemprop.uncertainty.calibrator.ZScalingCalibrator'>,
    'zelikman-interval': <class 'chemprop.uncertainty.calibrator.ZelikmanCalibrator'>,
    'mve-weighting': <class 'chemprop.uncertainty.calibrator.MVEWeightingCalibrator'>,
    'conformal-regression': <class 'chemprop.uncertainty.calibrator.RegressionConformalCalibrator'>,
    'platt': <class 'chemprop.uncertainty.calibrator.PlattCalibrator'>,
    'isotonic': <class 'chemprop.uncertainty.calibrator.IsotonicCalibrator'>,
    'conformal-multilabel': <class 'chemprop.uncertainty.calibrator.MultilabelConformalCalibrator'>,
    'conformal-multiclass': <class 'chemprop.uncertainty.calibrator.MulticlassConformalCalibrator'>,
    'conformal-adaptive': <class 'chemprop.uncertainty.calibrator.AdaptiveMulticlassConformalCalibrator'>,
    'isotonic-multiclass': <class 'chemprop.uncertainty.calibrator.IsotonicMulticlassCalibrator'>
}
[17]:
unc_calibrator = uncertainty.UncertaintyCalibratorRegistry["zscaling"]()  # ZScalingCalibrator, per the registry above

Construct the uncertainty evaluators#

An uncertainty evaluator evaluates the quality of uncertainty estimates.

Available options can be found in uncertainty.UncertaintyEvaluatorRegistry.

For regression:

  • NLLRegressionEvaluator

  • CalibrationAreaEvaluator

  • ExpectedNormalizedErrorEvaluator

  • SpearmanEvaluator

  • RegressionConformalEvaluator

For binary classification:

  • NLLClassEvaluator

  • MultilabelConformalEvaluator

For multiclass classification:

  • NLLMulticlassEvaluator

  • MulticlassConformalEvaluator

[18]:
print(str(uncertainty.UncertaintyEvaluatorRegistry))  # alias -> evaluator class
ClassRegistry {
    'nll-regression': <class 'chemprop.uncertainty.evaluator.NLLRegressionEvaluator'>,
    'miscalibration_area': <class 'chemprop.uncertainty.evaluator.CalibrationAreaEvaluator'>,
    'ence': <class 'chemprop.uncertainty.evaluator.ExpectedNormalizedErrorEvaluator'>,
    'spearman': <class 'chemprop.uncertainty.evaluator.SpearmanEvaluator'>,
    'conformal-coverage-regression': <class 'chemprop.uncertainty.evaluator.RegressionConformalEvaluator'>,
    'nll-classification': <class 'chemprop.uncertainty.evaluator.NLLClassEvaluator'>,
    'conformal-coverage-classification': <class 'chemprop.uncertainty.evaluator.MultilabelConformalEvaluator'>,
    'nll-multiclass': <class 'chemprop.uncertainty.evaluator.NLLMulticlassEvaluator'>,
    'conformal-coverage-multiclass': <class 'chemprop.uncertainty.evaluator.MulticlassConformalEvaluator'>
}
[19]:
# Regression evaluators, looked up by their registry aliases (same classes and order as listed above):
# NLLRegressionEvaluator, CalibrationAreaEvaluator, ExpectedNormalizedErrorEvaluator, SpearmanEvaluator.
unc_evaluators = [
    uncertainty.UncertaintyEvaluatorRegistry[alias]()
    for alias in ("nll-regression", "miscalibration_area", "ence", "spearman")
]

Load model#

[20]:
# NOTE(review): find_models presumably collects saved model files (e.g. `best.pt` above) from the given directories.
model_paths = find_models([model_output_dir])
# NOTE(review): this rebinds `models`, shadowing the `chemprop.models` module imported at the top;
# re-running the MPNN-construction cell after this one would fail on `models.MPNN`.
models = []
for checkpoint_file in model_paths:
    models.append(load_model(checkpoint_file, multicomponent=False))

Setup trainer#

[21]:
trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False, enable_progress_bar=True)  # inference-only trainer
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Make uncertainty estimation#

[22]:
# NOTE(review): the leading axis of the estimator outputs presumably indexes the models;
# averaging over it gives ensemble values (trivial with a single model).
test_predss, test_uncss = unc_estimator(test_loader, models, trainer)
test_preds, test_uncs = test_predss.mean(0), test_uncss.mean(0)

test_columns = {
    "smiles": test_dset.smiles,
    "target": test_dset.Y.reshape(-1),
    "pred": test_preds.reshape(-1),
    "unc": test_uncs.reshape(-1),
}
df_test = pd.DataFrame(test_columns)

df_test
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[22]:
smiles target pred unc
0 Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... 2.06 1.813618 1.817690
1 O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... 1.92 1.825420 1.775566
2 CNCCCC12CCC(c3ccccc31)c1ccccc12 0.89 1.778415 1.795096
3 Oc1ncnc2scc(-c3ccsc3)c12 2.25 1.838990 1.866848
4 C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... 2.04 1.818925 1.750113
5 COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 3.13 1.822321 1.767510
6 O=C(COc1ccccc1)c1ccccc1 2.87 1.859192 1.750126
7 CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 1.10 1.815018 1.842683
8 N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 -0.16 1.815731 1.770754
9 COc1cnc(-c2ccccn2)nc1N(C)C 1.90 1.856991 1.782075

Apply uncertainty calibration#

[23]:
# Predict on the calibration set and average over the leading (model) axis.
cal_predss, cal_uncss = unc_estimator(cal_loader, models, trainer)
average_cal_preds, average_cal_uncs = cal_predss.mean(0), cal_uncss.mean(0)

# Mark which targets are present (finite) before zero-filling NaNs so they can be passed as a tensor.
cal_targets = cal_dset.Y
cal_mask = torch.from_numpy(np.isfinite(cal_targets))
cal_targets = torch.from_numpy(np.nan_to_num(cal_targets, nan=0.0))

# Fit the calibrator on the calibration set, then rescale the test-set uncertainties with it.
unc_calibrator.fit(average_cal_preds, average_cal_uncs, cal_targets, cal_mask)
cal_test_uncs = unc_calibrator.apply(test_uncs)

df_test["cal_unc"] = cal_test_uncs
df_test
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[23]:
smiles target pred unc cal_unc
0 Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... 2.06 1.813618 1.817690 1.920422
1 O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... 1.92 1.825420 1.775566 1.875917
2 CNCCCC12CCC(c3ccccc31)c1ccccc12 0.89 1.778415 1.795096 1.896550
3 Oc1ncnc2scc(-c3ccsc3)c12 2.25 1.838990 1.866848 1.972358
4 C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... 2.04 1.818925 1.750113 1.849025
5 COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 3.13 1.822321 1.767510 1.867406
6 O=C(COc1ccccc1)c1ccccc1 2.87 1.859192 1.750126 1.849039
7 CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 1.10 1.815018 1.842683 1.946827
8 N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 -0.16 1.815731 1.770754 1.870833
9 COc1cnc(-c2ccccn2)nc1N(C)C 1.90 1.856991 1.782075 1.882794

Evaluate predicted uncertainty#

[24]:
# Same target preparation as for calibration: record the finite mask, then zero-fill NaNs.
raw_test_targets = test_dset.Y
test_mask = torch.from_numpy(np.isfinite(raw_test_targets))
test_targets = torch.from_numpy(np.nan_to_num(raw_test_targets, nan=0.0))

# Score the *calibrated* test uncertainties with every evaluator (one value per task; one task here).
for evaluator in unc_evaluators:
    scores = evaluator.evaluate(test_preds, cal_test_uncs, test_targets, test_mask)
    print(f"{evaluator.alias}: {scores.tolist()}")
nll-regression: [1.4570959966142754]
miscalibration_area: [0.147599995136261]
ence: [0.5853198323227037]
spearman: [-0.18787875771522522]
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)