from abc import abstractmethod
from lightning.pytorch.core.mixins import HyperparametersMixin
import torch
from torch import Tensor, nn
from chemprop.conf import DEFAULT_HIDDEN_DIM
from chemprop.nn.hparams import HasHParams
from chemprop.nn.utils import get_activation_function
[docs]
class FFN(nn.Module):
r"""A :class:`FFN` is a differentiable function
:math:`f_\theta : \mathbb R^i \mapsto \mathbb R^o`"""
input_dim: int
output_dim: int
[docs]
@abstractmethod
def forward(self, X: Tensor) -> Tensor:
pass
[docs]
class MLP(nn.Sequential, FFN):
r"""An :class:`MLP` is an FFN that implements the following function:
.. math::
\mathbf h_0 &= \mathbf W_0 \mathbf x \,+ \mathbf b_{0} \\
\mathbf h_l &= \mathbf W_l \left( \mathtt{dropout} \left( \sigma ( \,\mathbf h_{l-1}\, ) \right) \right) + \mathbf b_l\\
where :math:`\mathbf x` is the input tensor, :math:`\mathbf W_l` and :math:`\mathbf b_l`
are the learned weight matrix and bias, respectively, of the :math:`l`-th layer,
:math:`\mathbf h_l` is the hidden representation after layer :math:`l`, and :math:`\sigma`
is the activation function.
"""
[docs]
@classmethod
def build(
cls,
input_dim: int,
output_dim: int,
hidden_dim: int = 300,
n_layers: int = 1,
dropout: float = 0.0,
activation: str | nn.Module = "relu",
):
dropout = nn.Dropout(dropout)
act = get_activation_function(activation)
dims = [input_dim] + [hidden_dim] * n_layers + [output_dim]
blocks = [nn.Sequential(nn.Linear(dims[0], dims[1]))]
if len(dims) > 2:
blocks.extend(
[
nn.Sequential(act, dropout, nn.Linear(d1, d2))
for d1, d2 in zip(dims[1:-1], dims[2:])
]
)
return cls(*blocks)
@property
def input_dim(self) -> int:
return self[0][-1].in_features
@property
def output_dim(self) -> int:
return self[-1][-1].out_features
[docs]
class ConstrainerFFN(nn.Module, HasHParams, HyperparametersMixin):
"""A :class:`ConstrainerFFN` adjusts atom or bond property predictions to satisfy molecular
constraints by using an :class:`MLP` to map learned atom or bond embeddings to weights that
determine how much of the total adjustment needed is added to each atom or bond prediction.
"""
def __init__(
self,
n_constraints: int = 1,
fp_dim: int = DEFAULT_HIDDEN_DIM,
hidden_dim: int = 300,
n_layers: int = 1,
dropout: float = 0.0,
activation: str = "relu",
):
super().__init__()
self.save_hyperparameters()
self.hparams["cls"] = self.__class__
self.ffn = MLP.build(fp_dim, n_constraints, hidden_dim, n_layers, dropout, activation)
[docs]
def forward(self, fp: Tensor, preds: Tensor, batch: Tensor, constraints: Tensor) -> Tensor:
"""Performs a weighted adjustment to the predictions to satisfy the constraints, with the
weights being determined from the learned atom or bond fingerprints via an :class:`MLP`.
Parameters
----------
fp : Tensor
a tensor of shape ``b x h`` containing the atom or bond-level fingerprints, where ``b``
is the number of atoms or bonds and ``h`` is the length of each fingerprint.
preds : Tensor
a tensor of shape ``b x t`` containing the atom or bond-level predictions, where ``t``
is the number of predictions per atom or bond.
batch : Tensor
a tensor of shape ``b`` containing indices of which molecule each atom or bond belongs to
constraints : Tensor
a tensor of shape ``m x t`` containing the values to which the atom or bond-level
predictions should sum to for each molecule, where ``m`` is the number of molecules in
the batch.
Returns
-------
Tensor
a tensor of shape ``b x t`` containing the atom or bond-level predictions adjusted to
satisfy the molecule-level constraints
"""
k = self.ffn(fp)
expk = k.exp()
n_mols = constraints.shape[0]
index_torch = batch.unsqueeze(1).repeat(1, k.shape[1])
per_mol_sum_expk = torch.zeros(
n_mols, expk.shape[1], dtype=expk.dtype, device=expk.device
).scatter_reduce_(0, index_torch, expk, reduce="sum", include_self=False)
by_atom_or_bond_sum_expk = per_mol_sum_expk[batch]
w = expk / (by_atom_or_bond_sum_expk)
index_torch = batch.unsqueeze(1).repeat(1, preds.shape[1])
per_mol_preds = torch.zeros(
n_mols, preds.shape[1], dtype=preds.dtype, device=preds.device
).scatter_reduce_(0, index_torch, preds, reduce="sum", include_self=False)
pred_has_constraint = ~torch.isnan(constraints)[0]
deviation = constraints[:, pred_has_constraint] - per_mol_preds[:, pred_has_constraint]
corrections = w * deviation[batch]
cor_shape_preds = torch.zeros_like(preds, dtype=preds.dtype, device=preds.device)
cor_shape_preds[:, pred_has_constraint] = corrections
return preds + cor_shape_preds