Source code for chemprop.uncertainty.estimator

from abc import ABC, abstractmethod
import types
from typing import Iterable

from lightning import pytorch as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from chemprop.conf import LIGHTNING_26_COMPAT_ARGS
from chemprop.models import MPNN, MolAtomBondMPNN
from chemprop.utils.registry import ClassRegistry


[docs] class UncertaintyEstimator(ABC): """A helper class for making model predictions and associated uncertainty predictions."""
[docs] @abstractmethod def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): """ Calculate the uncalibrated predictions and uncertainties for the dataloader. dataloader: DataLoader the dataloader used for model predictions and uncertainty predictions models: Iterable[MPNN] | Iterable[MolAtomBondMPNN] the models used for model predictions and uncertainty predictions. If using MolAtomBondMPNN models, the uncertainty estimator will return preds and uncs for each of the mole, atom, and bond predictions and uncertainties. trainer: pl.Trainer an instance of the :class:`~lightning.pytorch.trainer.trainer.Trainer` used to manage model inference Returns ------- preds : Tensor the model predictions, with shape varying by task type: * regression/binary classification: ``m x n x t`` * multiclass classification: ``m x n x t x c``, where ``m`` is the number of models, ``n`` is the number of inputs, ``t`` is the number of tasks, and ``c`` is the number of classes. uncs : Tensor the predicted uncertainties, with shapes of ``m' x n x t``. .. note:: The ``m`` and ``m'`` are different by definition. The ``m`` is the number of models, while the ``m'`` is the number of uncertainty estimations. For example, if two MVE or evidential models are provided, both ``m`` and ``m'`` are two. However, for an ensemble of two models, ``m'`` would be one (even though ``m = 2``). """
UncertaintyEstimatorRegistry = ClassRegistry[UncertaintyEstimator]()
[docs] @UncertaintyEstimatorRegistry.register("none") class NoUncertaintyEstimator(UncertaintyEstimator):
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: predss = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) predss.append(preds) else: mol_predss = [] atom_predss = [] bond_predss = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_predss.append(mol_preds) if atom_preds is not None: atom_predss.append(atom_preds) if bond_preds is not None: bond_predss.append(bond_preds) preds_tuple = (predss,) if not_mol_atom_bond else (mol_predss, atom_predss, bond_predss) processed_preds = [] for raw_preds in preds_tuple: if raw_preds: processed_preds.append(torch.stack(raw_preds)) else: processed_preds.append(None) if not_mol_atom_bond: return processed_preds[0], None return processed_preds, (None, None, None)
[docs] @UncertaintyEstimatorRegistry.register("mve") class MVEEstimator(UncertaintyEstimator): """ Class that estimates prediction means and variances (MVE). [nix1994]_ References ---------- .. [nix1994] Nix, D. A.; Weigend, A. S. "Estimating the mean and variance of the target probability distribution." Proceedings of 1994 IEEE International Conference on Neural Networks, 1994 https://doi.org/10.1109/icnn.1994.374138 """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: mves = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) mves.append(preds) else: mol_mves = [] atom_mves = [] bond_mves = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_mves.append(mol_preds) if atom_preds is not None: atom_mves.append(atom_preds) if bond_preds is not None: bond_mves.append(bond_preds) mves_tuple = (mves,) if not_mol_atom_bond else (mol_mves, atom_mves, bond_mves) means = [] vars = [] for raw_mves in mves_tuple: if raw_mves: mves = torch.stack(raw_mves, dim=0) mean, var = mves.unbind(dim=-1) means.append(mean) vars.append(var) else: means.append(None) vars.append(None) if not_mol_atom_bond: return means[0], vars[0] return means, vars
[docs] @UncertaintyEstimatorRegistry.register("ensemble") class EnsembleEstimator(UncertaintyEstimator): """ Class that predicts the uncertainty of predictions based on the variance in predictions among an ensemble's submodels. """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): if len(models) <= 1: raise ValueError( "Ensemble method for uncertainty is only available when multiple models are provided." ) not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: ensemble_preds = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) ensemble_preds.append(preds) else: mol_ensemble_preds = [] atom_ensemble_preds = [] bond_ensemble_preds = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_ensemble_preds.append(mol_preds) if atom_preds is not None: atom_ensemble_preds.append(atom_preds) if bond_preds is not None: bond_ensemble_preds.append(bond_preds) ensemble_preds_tuple = ( (ensemble_preds,) if not_mol_atom_bond else (mol_ensemble_preds, atom_ensemble_preds, bond_ensemble_preds) ) stacked_predss = [] varss = [] for ensemble_preds in ensemble_preds_tuple: if ensemble_preds: stacked_preds = torch.stack(ensemble_preds).float() vars = torch.var(stacked_preds, dim=0, correction=0).unsqueeze(0) stacked_predss.append(stacked_preds) varss.append(vars) else: stacked_predss.append(None) varss.append(None) if not_mol_atom_bond: return stacked_predss[0], varss[0] return stacked_predss, varss
[docs] @UncertaintyEstimatorRegistry.register("classification") class ClassEstimator(UncertaintyEstimator):
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: predss = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) predss.append(preds) else: mol_predss = [] atom_predss = [] bond_predss = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_predss.append(mol_preds) if atom_preds is not None: atom_predss.append(atom_preds) if bond_preds is not None: bond_predss.append(bond_preds) predss_tuple = (predss,) if not_mol_atom_bond else (mol_predss, atom_predss, bond_predss) processed_predss = [] for raw_preds in predss_tuple: if raw_preds: processed_predss.append(torch.stack(raw_preds)) else: processed_predss.append(None) if not_mol_atom_bond: return processed_predss[0], processed_predss[0] return processed_predss, processed_predss
[docs] @UncertaintyEstimatorRegistry.register("evidential-total") class EvidentialTotalEstimator(UncertaintyEstimator): """ Class that predicts the total evidential uncertainty based on hyperparameters of the evidential distribution [amini2020]_. References ----------- .. [amini2020] Amini, A.; Schwarting, W.; Soleimany, A.; Rus, D. "Deep Evidential Regression". NeurIPS, 2020. https://arxiv.org/abs/1910.02600 """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: uncs = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) uncs.append(preds) else: mol_uncs = [] atom_uncs = [] bond_uncs = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_uncs.append(mol_preds) if atom_preds is not None: atom_uncs.append(atom_preds) if bond_preds is not None: bond_uncs.append(bond_preds) uncs_tuple = (uncs,) if not_mol_atom_bond else (mol_uncs, atom_uncs, bond_uncs) means = [] total_uncss = [] for raw_uncs in uncs_tuple: if raw_uncs: uncs = torch.stack(raw_uncs) mean, v, alpha, beta = uncs.unbind(-1) total_uncs = (1 + 1 / v) * (beta / (alpha - 1)) means.append(mean) total_uncss.append(total_uncs) else: means.append(None) total_uncss.append(None) if not_mol_atom_bond: return means[0], total_uncss[0] return means, total_uncss
[docs] @UncertaintyEstimatorRegistry.register("evidential-epistemic") class EvidentialEpistemicEstimator(UncertaintyEstimator): """ Class that predicts the epistemic evidential uncertainty based on hyperparameters of the evidential distribution. """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: uncs = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) uncs.append(preds) else: mol_uncs = [] atom_uncs = [] bond_uncs = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_uncs.append(mol_preds) if atom_preds is not None: atom_uncs.append(atom_preds) if bond_preds is not None: bond_uncs.append(bond_preds) uncs_tuple = (uncs,) if not_mol_atom_bond else (mol_uncs, atom_uncs, bond_uncs) means = [] epistemic_uncss = [] for raw_uncs in uncs_tuple: if raw_uncs: uncs = torch.stack(raw_uncs) mean, v, alpha, beta = uncs.unbind(-1) epistemic_uncs = (1 / v) * (beta / (alpha - 1)) means.append(mean) epistemic_uncss.append(epistemic_uncs) else: means.append(None) epistemic_uncss.append(None) if not_mol_atom_bond: return means[0], epistemic_uncss[0] return means, epistemic_uncss
[docs] @UncertaintyEstimatorRegistry.register("evidential-aleatoric") class EvidentialAleatoricEstimator(UncertaintyEstimator): """ Class that predicts the aleatoric evidential uncertainty based on hyperparameters of the evidential distribution. """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: uncs = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) uncs.append(preds) else: mol_uncs = [] atom_uncs = [] bond_uncs = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_uncs.append(mol_preds) if atom_preds is not None: atom_uncs.append(atom_preds) if bond_preds is not None: bond_uncs.append(bond_preds) uncs_tuple = (uncs,) if not_mol_atom_bond else (mol_uncs, atom_uncs, bond_uncs) means = [] aleatoric_uncss = [] for raw_uncs in uncs_tuple: if raw_uncs: uncs = torch.stack(raw_uncs) mean, v, alpha, beta = uncs.unbind(-1) aleatoric_uncs = beta / (alpha - 1) means.append(mean) aleatoric_uncss.append(aleatoric_uncs) else: means.append(None) aleatoric_uncss.append(None) if not_mol_atom_bond: return means[0], aleatoric_uncss[0] return means, aleatoric_uncss
[docs] @UncertaintyEstimatorRegistry.register("dropout") class DropoutEstimator(UncertaintyEstimator): """ A :class:`DropoutEstimator` creates a virtual ensemble of models via Monte Carlo dropout with the provided model [gal2016]_. Parameters ---------- ensemble_size: int The number of samples to draw for the ensemble. dropout: float | None The probability of dropping out units in the dropout layers. If unspecified, the training probability is used, which is prefered but not possible if the model was not trained with dropout (i.e. p=0). References ----------- .. [gal2016] Gal, Y.; Ghahramani, Z. "Dropout as a bayesian approximation: Representing model uncertainty in deep learning." International conference on machine learning. PMLR, 2016. https://arxiv.org/abs/1506.02142 """ def __init__(self, ensemble_size: int, dropout: None | float = None): self.ensemble_size = ensemble_size self.dropout = dropout
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: meanss, varss = [], [] for model in models: self._setup_model(model) individual_preds = [] for _ in range(self.ensemble_size): predss = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) preds = torch.concat(predss, 0) individual_preds.append(preds) stacked_preds = torch.stack(individual_preds, dim=0).float() means = torch.mean(stacked_preds, dim=0) vars = torch.var(stacked_preds, dim=0, correction=0) self._restore_model(model) meanss.append(means) varss.append(vars) return torch.stack(meanss), torch.stack(varss) else: mol_meanss, mol_varss = [], [] atom_meanss, atom_varss = [], [] bond_meanss, bond_varss = [], [] for model in models: self._setup_model(model) mol_individual_preds = [] atom_individual_preds = [] bond_individual_preds = [] for _ in range(self.ensemble_size): MAB_predss = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_predss) ) if mol_preds is not None: mol_individual_preds.append(mol_preds) if atom_preds is not None: atom_individual_preds.append(atom_preds) if bond_preds is not None: bond_individual_preds.append(bond_preds) if mol_individual_preds: stacked_mol_preds = torch.stack(mol_individual_preds, dim=0).float() mol_means = torch.mean(stacked_mol_preds, dim=0) mol_vars = torch.var(stacked_mol_preds, dim=0, correction=0) mol_meanss.append(mol_means) mol_varss.append(mol_vars) if atom_individual_preds: stacked_atom_preds = torch.stack(atom_individual_preds, dim=0).float() atom_means = torch.mean(stacked_atom_preds, dim=0) atom_vars = torch.var(stacked_atom_preds, dim=0, correction=0) atom_meanss.append(atom_means) atom_varss.append(atom_vars) if bond_individual_preds: stacked_bond_preds = torch.stack(bond_individual_preds, dim=0).float() bond_means = torch.mean(stacked_bond_preds, dim=0) bond_vars = torch.var(stacked_bond_preds, dim=0, correction=0) bond_meanss.append(bond_means) bond_varss.append(bond_vars) self._restore_model(model) return ( ( torch.stack(mol_meanss) if mol_meanss else None, torch.stack(atom_meanss) if atom_meanss else None, torch.stack(bond_meanss) if bond_meanss else None, ), ( torch.stack(mol_varss) if mol_varss else None, torch.stack(atom_varss) if atom_varss else None, torch.stack(bond_varss) if bond_varss else None, ), )
def _setup_model(self, model): model._predict_step = model.predict_step model.predict_step = self._predict_step(model) model.apply(self._change_dropout) def _restore_model(self, model): model.predict_step = model._predict_step del model._predict_step model.apply(self._restore_dropout) def _predict_step(self, model): def _wrapped_predict_step(*args, **kwargs): model.apply(self._activate_dropout) return model._predict_step(*args, **kwargs) return _wrapped_predict_step def _activate_dropout(self, module): if isinstance(module, torch.nn.Dropout): module.train() def _change_dropout(self, module): if isinstance(module, torch.nn.Dropout): module._p = module.p if self.dropout: module.p = self.dropout def _restore_dropout(self, module): if isinstance(module, torch.nn.Dropout): if hasattr(module, "_p"): module.p = module._p del module._p
# TODO: Add in v2.1.x # @UncertaintyEstimatorRegistry.register("spectra-roundrobin") # class RoundRobinSpectraEstimator(UncertaintyEstimator): # def __call__( # self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer # ) -> tuple[Tensor, Tensor]: # return
[docs] @UncertaintyEstimatorRegistry.register("classification-dirichlet") class ClassificationDirichletEstimator(UncertaintyEstimator): r""" A :class:`ClassificationDirichletEstimator` predicts an amount of 'evidence' for both the negative class and the positive class as described in [sensoy2018]_. The class probabilities and the uncertainty are calculated based on the evidence. .. math:: S = \sum_{i=1}^K \alpha_i p_i = \alpha_i / S u = K / S where :math:`K` is the number of classes, :math:`\alpha_i` is the evidence for class :math:`i`, :math:`p_i` is the probability of class :math:`i`, and :math:`u` is the uncertainty. References ---------- .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. "Evidential deep learning to quantify classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768 """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: uncs = [] for model in models: preds = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) uncs.append(preds) else: mol_uncs = [] atom_uncs = [] bond_uncs = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_uncs.append(mol_preds) if atom_preds is not None: atom_uncs.append(atom_preds) if bond_preds is not None: bond_uncs.append(bond_preds) uncs_tuple = (uncs,) if not_mol_atom_bond else (mol_uncs, atom_uncs, bond_uncs) ys = [] us = [] for raw_uncs in uncs_tuple: if raw_uncs: uncs = torch.stack(raw_uncs, dim=0) y, u = uncs.unbind(dim=-1) ys.append(y) us.append(u) else: ys.append(None) us.append(None) if not_mol_atom_bond: return ys[0], us[0] return ys, us
[docs] @UncertaintyEstimatorRegistry.register("multiclass-dirichlet") class MulticlassDirichletEstimator(UncertaintyEstimator): r""" A :class:`MulticlassDirichletEstimator` predicts an amount of 'evidence' for each class as described in [sensoy2018]_. The class probabilities and the uncertainty are calculated based on the evidence. .. math:: S = \sum_{i=1}^K \alpha_i p_i = \alpha_i / S u = K / S where :math:`K` is the number of classes, :math:`\alpha_i` is the evidence for class :math:`i`, :math:`p_i` is the probability of class :math:`i`, and :math:`u` is the uncertainty. References ---------- .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. "Evidential deep learning to quantify classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768 """
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: preds = [] uncs = [] for model in models: self._setup_model(model) output = torch.concat( trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS), 0 ) self._restore_model(model) preds.append(output[..., :-1]) uncs.append(output[..., -1]) preds = torch.stack(preds, 0) uncs = torch.stack(uncs, 0) return preds, uncs else: mol_preds = [] atom_preds = [] bond_preds = [] mol_uncs = [] atom_uncs = [] bond_uncs = [] for model in models: self._setup_model(model) MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_output, atom_output, bond_output = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) self._restore_model(model) if mol_output is not None: mol_preds.append(mol_output[..., :-1]) mol_uncs.append(mol_output[..., -1]) if atom_output is not None: atom_preds.append(atom_output[..., :-1]) atom_uncs.append(atom_output[..., -1]) if bond_output is not None: bond_preds.append(bond_output[..., :-1]) bond_uncs.append(bond_output[..., -1]) mol_preds = torch.stack(mol_preds, 0) if mol_preds else None atom_preds = torch.stack(atom_preds, 0) if atom_preds else None bond_preds = torch.stack(bond_preds, 0) if bond_preds else None mol_uncs = torch.stack(mol_uncs, 0) if mol_uncs else None atom_uncs = torch.stack(atom_uncs, 0) if atom_uncs else None bond_uncs = torch.stack(bond_uncs, 0) if bond_uncs else None return (mol_preds, atom_preds, bond_preds), (mol_uncs, atom_uncs, bond_uncs)
def _setup_model(self, model): model.predictor._forward = model.predictor.forward model.predictor.forward = types.MethodType(self._forward.__func__, model.predictor) def _restore_model(self, model): model.predictor.forward = model.predictor._forward del model.predictor._forward def _forward(self, Z: Tensor) -> Tensor: alpha = self.train_step(Z) u = alpha.shape[2] / alpha.sum(-1, keepdim=True) Y = alpha / alpha.sum(-1, keepdim=True) return torch.concat([Y, u], -1)
[docs] @UncertaintyEstimatorRegistry.register("quantile-regression") class QuantileRegressionEstimator(UncertaintyEstimator):
[docs] def __call__( self, dataloader: DataLoader, models: Iterable[MPNN] | Iterable[MolAtomBondMPNN], trainer: pl.Trainer, ) -> ( tuple[Tensor, Tensor] | tuple[ tuple[Tensor | None, Tensor | None, Tensor | None], tuple[Tensor | None, Tensor | None, Tensor | None], ] ): not_mol_atom_bond = isinstance(models[0], MPNN) if not_mol_atom_bond: individual_preds = [] for model in models: predss = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) individual_preds.append(torch.concat(predss, 0)) else: mol_individual_preds = [] atom_individual_preds = [] bond_individual_preds = [] for model in models: MAB_preds = trainer.predict(model, dataloader, **LIGHTNING_26_COMPAT_ARGS) mol_preds, atom_preds, bond_preds = ( torch.concat(preds, 0) if preds[0] is not None else None for preds in zip(*MAB_preds) ) if mol_preds is not None: mol_individual_preds.append(mol_preds) if atom_preds is not None: atom_individual_preds.append(atom_preds) if bond_preds is not None: bond_individual_preds.append(bond_preds) individual_preds_tuple = ( (individual_preds,) if not_mol_atom_bond else (mol_individual_preds, atom_individual_preds, bond_individual_preds) ) means = [] half_intervals = [] for individual_preds in individual_preds_tuple: if individual_preds: stacked_preds = torch.stack(individual_preds).float() mean, interval = stacked_preds.unbind(-1) means.append(mean) half_intervals.append(interval / 2) else: means.append(None) half_intervals.append(None) if not_mol_atom_bond: return means[0], half_intervals[0] return means, half_intervals