from abc import abstractmethod
from lightning.pytorch.core.mixins import HyperparametersMixin
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from chemprop.conf import DEFAULT_HIDDEN_DIM
from chemprop.nn.ffn import MLP
from chemprop.nn.hparams import HasHParams
from chemprop.nn.metrics import (
MSE,
SID,
BCELoss,
BinaryAUROC,
ChempropMetric,
CrossEntropyLoss,
DirichletLoss,
EvidentialLoss,
MulticlassMCCMetric,
MVELoss,
QuantileLoss,
)
from chemprop.nn.transforms import UnscaleTransform
from chemprop.utils import ClassRegistry, Factory
# Public API of this module. ``QuantileFFN`` is listed alongside the other registered
# predictors so that ``from ... import *`` exposes every concrete predictor defined here.
__all__ = [
    "Predictor",
    "PredictorRegistry",
    "RegressionFFN",
    "MveFFN",
    "EvidentialFFN",
    "QuantileFFN",
    "BinaryClassificationFFNBase",
    "BinaryClassificationFFN",
    "BinaryDirichletFFN",
    "MulticlassClassificationFFN",
    "MulticlassDirichletFFN",
    "SpectralFFN",
]
[docs]
class Predictor(nn.Module, HasHParams):
    r"""A :class:`Predictor` is a protocol that defines a differentiable function
    :math:`f : \mathbb R^d \mapsto \mathbb R^o`

    Implementations provide :meth:`forward`, :meth:`train_step` and :meth:`encode`.
    """

    input_dim: int
    """the input dimension"""
    output_dim: int
    """the output dimension"""
    n_tasks: int
    """the number of tasks `t` to predict for each input"""
    n_targets: int
    """the number of targets `s` to predict for each task `t`"""
    criterion: ChempropMetric
    """the loss function to use for training"""
    task_weights: Tensor
    """the weights to apply to each task when calculating the loss"""
    output_transform: UnscaleTransform
    """the transform to apply to the output of the predictor"""

    @abstractmethod
    def forward(self, Z: Tensor) -> Tensor:
        """Calculate the predictor's output for the input ``Z``."""

    @abstractmethod
    def train_step(self, Z: Tensor) -> Tensor:
        """Calculate the output used at training time for the input ``Z``.

        The predictors in this module either alias this method to :meth:`forward` or return
        the values *before* the final squashing/normalization that :meth:`forward` applies.
        """

    @abstractmethod
    def encode(self, Z: Tensor, i: int) -> Tensor:
        """Calculate the :attr:`i`-th hidden representation

        Parameters
        ----------
        Z : Tensor
            a tensor of shape ``n x d`` containing the input data to encode, where ``d`` is the
            input dimensionality.
        i : int
            The stop index of slice of the MLP used to encode the input. That is, use all
            layers in the MLP *up to* :attr:`i` (i.e., ``MLP[:i]``). This can be any integer
            value, and the behavior of this function is dependent on the underlying list
            slicing behavior. For example:

            * ``i=0``: use a 0-layer MLP (i.e., a no-op)
            * ``i=1``: use only the first block
            * ``i=-1``: use *up to* the final block

        Returns
        -------
        Tensor
            a tensor of shape ``n x h`` containing the :attr:`i`-th hidden representation, where
            ``h`` is the number of neurons in the :attr:`i`-th hidden layer.
        """
# Registry of :class:`Predictor` subclasses. The concrete predictors below add themselves
# under a string key through the ``@PredictorRegistry.register(...)`` decorator.
PredictorRegistry = ClassRegistry[Predictor]()
class _FFNPredictorBase(Predictor, HyperparametersMixin):
    r"""A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
    underlying :class:`MLP` to map the learned fingerprint to the desired output.

    Subclasses set the class attributes ``n_targets``, ``_T_default_criterion`` and
    ``_T_default_metric``; ``n_targets`` is read while the MLP is built in ``__init__``.

    Parameters
    ----------
    n_tasks : int, default=1
        the number of tasks; the MLP is built with ``n_tasks * n_targets`` outputs
    input_dim : int, default=DEFAULT_HIDDEN_DIM
        the input dimension handed to :meth:`MLP.build`
    hidden_dim : int, default=300
        the hidden dimension handed to :meth:`MLP.build`
    n_layers : int, default=1
        the number of layers handed to :meth:`MLP.build`
    dropout : float, default=0.0
        the dropout value handed to :meth:`MLP.build`
    activation : str | nn.Module, default="relu"
        the activation handed to :meth:`MLP.build`
    criterion : ChempropMetric | None, default=None
        the loss function. If ``None``, one is built from ``_T_default_criterion`` using
        ``task_weights`` and ``threshold``.
    task_weights : Tensor | None, default=None
        the task weights used when the default criterion is built. If ``None``,
        ``torch.ones(n_tasks)`` is used.
    threshold : float | None, default=None
        passed to the default criterion when it is built
    output_transform : UnscaleTransform | None, default=None
        the transform stored as :attr:`output_transform`. If ``None``, ``nn.Identity()`` is used.
    """

    _T_default_criterion: type[ChempropMetric]
    _T_default_metric: type[ChempropMetric]

    def __init__(
        self,
        n_tasks: int = 1,
        input_dim: int = DEFAULT_HIDDEN_DIM,
        hidden_dim: int = 300,
        n_layers: int = 1,
        dropout: float = 0.0,
        activation: str | nn.Module = "relu",
        criterion: ChempropMetric | None = None,
        task_weights: Tensor | None = None,
        threshold: float | None = None,
        output_transform: UnscaleTransform | None = None,
    ):
        super().__init__()
        # ``criterion``, ``output_transform`` and ``activation`` are left out of the automatic
        # capture and added to hparams manually to suppress lightning's warning about double
        # saving their state_dict values.
        ignore_list = ["criterion", "output_transform", "activation"]
        self.save_hyperparameters(ignore=ignore_list)
        self.hparams["criterion"] = criterion
        self.hparams["output_transform"] = output_transform
        self.hparams["activation"] = activation
        self.hparams["cls"] = self.__class__

        self.ffn = MLP.build(
            input_dim, n_tasks * self.n_targets, hidden_dim, n_layers, dropout, activation
        )

        task_weights = torch.ones(n_tasks) if task_weights is None else task_weights
        # Test for ``None`` explicitly (as done for ``output_transform`` below) rather than
        # relying on the truthiness of the supplied criterion object: a user-provided criterion
        # must never be swapped for the default just because it evaluates falsy.
        if criterion is None:
            criterion = Factory.build(
                self._T_default_criterion, task_weights=task_weights, threshold=threshold
            )
        self.criterion = criterion
        self.output_transform = output_transform if output_transform is not None else nn.Identity()

    @property
    def input_dim(self) -> int:
        """the input dimension of the underlying MLP"""
        return self.ffn.input_dim

    @property
    def output_dim(self) -> int:
        """the output dimension of the underlying MLP"""
        return self.ffn.output_dim

    @property
    def n_tasks(self) -> int:
        """the number of tasks, i.e. ``output_dim // n_targets``"""
        return self.output_dim // self.n_targets

    def forward(self, Z: Tensor) -> Tensor:
        """Return the output of the underlying MLP applied to ``Z``."""
        return self.ffn(Z)

    def encode(self, Z: Tensor, i: int) -> Tensor:
        """Return the output of the MLP slice ``self.ffn[:i]`` applied to ``Z``."""
        return self.ffn[:i](Z)
[docs]
@PredictorRegistry.register("regression")
class RegressionFFN(_FFNPredictorBase):
    """Predictor with one target per task that passes the MLP output through
    :attr:`output_transform`. Training and inference share the same computation."""

    n_targets = 1
    _T_default_criterion = MSE
    _T_default_metric = MSE

    def forward(self, Z: Tensor) -> Tensor:
        """Return ``output_transform(ffn(Z))``."""
        raw = self.ffn(Z)
        return self.output_transform(raw)

    # the training-time output is the same as the inference-time output
    train_step = forward
[docs]
@PredictorRegistry.register("regression-mve")
class MveFFN(RegressionFFN):
    """Regression predictor that emits two targets per task: a mean and a variance."""

    n_targets = 2
    _T_default_criterion = MVELoss

    def forward(self, Z: Tensor) -> Tensor:
        """Return the mean and variance stacked along a new dimension 2.

        The MLP output is split in two along dimension 1; the first half is used as the mean
        and the second half, after a softplus, as the variance.
        """
        transform = self.output_transform
        mean, var = torch.chunk(self.ffn(Z), self.n_targets, 1)
        var = F.softplus(var)

        mean = transform(mean)
        # an identity transform leaves the variance untouched; any other transform is asked to
        # rescale it through ``transform_variance``
        if not isinstance(transform, nn.Identity):
            var = transform.transform_variance(var)

        return torch.stack((mean, var), dim=2)

    # rebind so that training uses this class's ``forward`` rather than the parent's
    train_step = forward
[docs]
@PredictorRegistry.register("regression-evidential")
class EvidentialFFN(RegressionFFN):
    """Regression predictor that emits four targets per task: ``mean``, ``v``, ``alpha`` and
    ``beta``."""

    n_targets = 4
    _T_default_criterion = EvidentialLoss

    def forward(self, Z: Tensor) -> Tensor:
        """Return ``(mean, v, alpha, beta)`` stacked along a new dimension 2.

        The MLP output is split in four along dimension 1. ``v`` and ``beta`` are passed through
        a softplus and ``alpha`` through a softplus plus one.
        """
        transform = self.output_transform
        mean, v, alpha, beta = torch.chunk(self.ffn(Z), self.n_targets, 1)

        v = F.softplus(v)
        alpha = F.softplus(alpha) + 1
        beta = F.softplus(beta)

        mean = transform(mean)
        # only ``beta`` is rescaled, and only when a non-identity transform is configured
        if not isinstance(transform, nn.Identity):
            beta = transform.transform_variance(beta)

        return torch.stack((mean, v, alpha, beta), dim=2)

    # rebind so that training uses this class's ``forward`` rather than the parent's
    train_step = forward
[docs]
@PredictorRegistry.register("regression-quantile")
class QuantileFFN(RegressionFFN):
    """Regression predictor that emits two targets per task: a lower and an upper bound."""

    n_targets = 2
    _T_default_criterion = QuantileLoss

    def forward(self, Z: Tensor) -> Tensor:
        """Return the midpoint and the width of the predicted interval, stacked along a new
        dimension 2.

        The MLP output is split in two along dimension 1 into the lower and upper bound, each
        of which is passed through :attr:`output_transform` before being combined.
        """
        raw_lower, raw_upper = torch.chunk(self.ffn(Z), self.n_targets, 1)
        lo = self.output_transform(raw_lower)
        hi = self.output_transform(raw_upper)

        midpoint = (lo + hi) / 2
        width = hi - lo

        return torch.stack((midpoint, width), dim=2)

    # rebind so that training uses this class's ``forward`` rather than the parent's
    train_step = forward
[docs]
class BinaryClassificationFFNBase(_FFNPredictorBase):
    """Common parent of the binary classification predictors; adds nothing of its own."""
[docs]
@PredictorRegistry.register("classification")
class BinaryClassificationFFN(BinaryClassificationFFNBase):
    """Binary classifier with one target per task."""

    n_targets = 1
    _T_default_criterion = BCELoss
    _T_default_metric = BinaryAUROC

    def forward(self, Z: Tensor) -> Tensor:
        """Return the sigmoid of the parent's output."""
        logits = super().forward(Z)
        probs = logits.sigmoid()
        return probs

    def train_step(self, Z: Tensor) -> Tensor:
        """Return the parent's output without the sigmoid."""
        logits = super().forward(Z)
        return logits
[docs]
@PredictorRegistry.register("classification-dirichlet")
class BinaryDirichletFFN(BinaryClassificationFFNBase):
    """Binary classifier that emits two targets per task and turns them into Dirichlet
    parameters ``alpha = softplus(.) + 1``."""

    n_targets = 2
    _T_default_criterion = DirichletLoss
    _T_default_metric = BinaryAUROC

    def forward(self, Z: Tensor) -> Tensor:
        """Return, for each task, the normalized ``alpha`` at index 1 and ``2 / sum(alpha)``,
        stacked along a new dimension 2."""
        n_rows = len(Z)
        raw = super().forward(Z).reshape(n_rows, -1, 2)
        alpha = F.softplus(raw) + 1

        uncertainty = 2 / alpha.sum(-1)
        normalized = alpha / alpha.sum(-1, keepdim=True)
        # keep only the entry at index 1 of the last dimension
        second_class = normalized[..., 1]

        return torch.stack((second_class, uncertainty), dim=2)

    def train_step(self, Z: Tensor) -> Tensor:
        """Return ``alpha``, reshaped so that the last dimension has size 2."""
        n_rows = len(Z)
        raw = super().forward(Z).reshape(n_rows, -1, 2)
        return F.softplus(raw) + 1
[docs]
@PredictorRegistry.register("multiclass")
class MulticlassClassificationFFN(_FFNPredictorBase):
    """Multiclass classifier whose MLP emits ``n_classes`` values for each task.

    Parameters
    ----------
    n_classes : int
        the number of classes predicted for each task
    n_tasks : int, default=1
        the number of tasks; the parent is initialized with ``n_tasks * n_classes``
    task_weights : Tensor | None, default=None
        if ``None``, replaced by ``torch.ones(n_tasks)`` before the parent is initialized
    input_dim, hidden_dim, n_layers, dropout, activation, criterion, threshold, output_transform
        forwarded unchanged to the parent's ``__init__``
    """

    n_targets = 1
    _T_default_criterion = CrossEntropyLoss
    _T_default_metric = MulticlassMCCMetric

    def __init__(
        self,
        n_classes: int,
        n_tasks: int = 1,
        input_dim: int = DEFAULT_HIDDEN_DIM,
        hidden_dim: int = 300,
        n_layers: int = 1,
        dropout: float = 0.0,
        activation: str | nn.Module = "relu",
        criterion: ChempropMetric | None = None,
        task_weights: Tensor | None = None,
        threshold: float | None = None,
        output_transform: UnscaleTransform | None = None,
    ):
        # The default weights are sized by the real number of tasks here, because the parent
        # receives ``n_tasks * n_classes`` in place of ``n_tasks``.
        task_weights = torch.ones(n_tasks) if task_weights is None else task_weights

        super().__init__(
            n_tasks=n_tasks * n_classes,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            activation=activation,
            criterion=criterion,
            task_weights=task_weights,
            threshold=threshold,
            output_transform=output_transform,
        )

        self.n_classes = n_classes

    @property
    def n_tasks(self) -> int:
        """the number of tasks, i.e. ``output_dim // (n_targets * n_classes)``"""
        per_task = self.n_targets * self.n_classes
        return self.output_dim // per_task

    def forward(self, Z: Tensor) -> Tensor:
        """Return the parent's output reshaped so the last dimension has size ``n_classes``,
        with a softmax over that dimension."""
        logits = super().forward(Z).reshape(Z.shape[0], -1, self.n_classes)
        return logits.softmax(-1)

    def train_step(self, Z: Tensor) -> Tensor:
        """Return the parent's output reshaped so the last dimension has size ``n_classes``,
        without the softmax."""
        logits = super().forward(Z).reshape(Z.shape[0], -1, self.n_classes)
        return logits
[docs]
@PredictorRegistry.register("multiclass-dirichlet")
class MulticlassDirichletFFN(MulticlassClassificationFFN):
    """Multiclass classifier that turns the parent's training-time output into Dirichlet
    parameters ``alpha = softplus(.) + 1``."""

    _T_default_criterion = DirichletLoss
    _T_default_metric = MulticlassMCCMetric

    def forward(self, Z: Tensor) -> Tensor:
        """Return ``alpha`` normalized to sum to one over the last dimension."""
        raw = super().train_step(Z)
        alpha = F.softplus(raw) + 1
        total = alpha.sum(-1, keepdim=True)
        return alpha / total

    def train_step(self, Z: Tensor) -> Tensor:
        """Return the unnormalized ``alpha``."""
        raw = super().train_step(Z)
        alpha = F.softplus(raw) + 1
        return alpha
class _Exp(nn.Module):
    """Module wrapper around the element-wise exponential."""

    def forward(self, X: Tensor) -> Tensor:
        """Return ``X.exp()``."""
        result = X.exp()
        return result
[docs]
@PredictorRegistry.register("spectral")
class SpectralFFN(_FFNPredictorBase):
    """Predictor whose output is passed through a positive activation and then normalized
    to sum to one along dimension 1.

    Parameters
    ----------
    *args, **kwargs
        forwarded unchanged to the parent's ``__init__``
    spectral_activation : str | None, default="softplus"
        ``"exp"`` selects an exponential, ``"softplus"`` or ``None`` selects ``nn.Softplus``.

    Raises
    ------
    ValueError
        if ``spectral_activation`` is none of the accepted values
    """

    n_targets = 1
    _T_default_criterion = SID
    _T_default_metric = SID

    def __init__(self, *args, spectral_activation: str | None = "softplus", **kwargs):
        super().__init__(*args, **kwargs)

        if spectral_activation == "exp":
            activation_module = _Exp()
        elif spectral_activation == "softplus" or spectral_activation is None:
            activation_module = nn.Softplus()
        else:
            raise ValueError(
                f"Unknown spectral activation: {spectral_activation}. "
                "Expected one of 'exp', 'softplus' or None."
            )

        # NOTE(review): the activation is registered as a child of ``self.ffn`` and is also
        # called explicitly in ``forward``. If calling ``self.ffn`` already runs every
        # registered child in order (as ``nn.Sequential`` does), the activation is applied
        # twice -- confirm against the ``MLP`` implementation.
        self.ffn.add_module("spectral_activation", activation_module)

    def forward(self, Z: Tensor) -> Tensor:
        """Return the activated output divided by its sum over dimension 1."""
        activated = self.ffn.spectral_activation(super().forward(Z))
        row_totals = activated.sum(1, keepdim=True)
        return activated / row_totals

    # the training-time output is the same as the inference-time output
    train_step = forward