Source code for chemprop.nn.message_passing.mol_atom_bond

from abc import abstractmethod

from lightning.pytorch.core.mixins import HyperparametersMixin
import torch
from torch import Tensor, nn

from chemprop.conf import DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM, DEFAULT_HIDDEN_DIM
from chemprop.data import BatchMolGraph
from chemprop.exceptions import InvalidShapeError
from chemprop.nn.message_passing.mixins import _AtomMessagePassingMixin, _BondMessagePassingMixin
from chemprop.nn.message_passing.proto import MABMessagePassing
from chemprop.nn.transforms import GraphTransform, ScaleTransform
from chemprop.nn.utils import Activation, get_activation_function


class _MABMessagePassingBase(MABMessagePassing, HyperparametersMixin):
    """The base message-passing block for atom- and bond-based message-passing schemes

    NOTE: this class is an abstract base class and cannot be instantiated

    Parameters
    ----------
    d_v : int, default=DEFAULT_ATOM_FDIM
        the feature dimension of the vertices
    d_e : int, default=DEFAULT_BOND_FDIM
        the feature dimension of the edges
    d_h : int, default=DEFAULT_HIDDEN_DIM
        the hidden dimension during message passing
    bias : bool, defuault=False
        if `True`, add a bias term to the learned weight matrices
    depth : int, default=3
        the number of message passing iterations
    dropout : float, default=0.0
        the dropout probability
    activation : str, default="relu"
        the activation function to use
    undirected : bool, default=False
        if `True`, pass messages on undirected edges
    d_vd : int | None, default=None
        the dimension of additional vertex descriptors that will be concatenated to the hidden
        features before readout
    d_ed : int | None, default=None
        the dimension of additional edge descriptors that will be concatenated to the hidden
        features before readout
    V_d_transform : ScaleTransform | None, default=None
        an optional transformation to apply to the additional vertex descriptors before concatenation
    E_d_transform : ScaleTransform | None, default=None
        an optional transformation to apply to the additional edge descriptors before concatenation
    graph_transform : GraphTransform | None, default=None
        an optional transformation to apply to the :class:`BatchMolGraph` before message passing. It
        is usually used to scale extra vertex and edge features.
    return_vertex_embeddings : bool, default=True
        whether to return the learned vertex embeddings. If `False`, None is returned.
    return_edge_embeddings : bool, default=True
        whether to return the learned edge embeddings. If `False`, None is returned.

    See also
    --------
    * :class:`AtomMessagePassing`

    * :class:`BondMessagePassing`
    """

    def __init__(
        self,
        d_v: int = DEFAULT_ATOM_FDIM,
        d_e: int = DEFAULT_BOND_FDIM,
        d_h: int = DEFAULT_HIDDEN_DIM,
        bias: bool = False,
        depth: int = 3,
        dropout: float = 0.0,
        activation: str | Activation = Activation.RELU,
        undirected: bool = False,
        d_vd: int | None = None,
        d_ed: int | None = None,
        V_d_transform: ScaleTransform | None = None,
        E_d_transform: ScaleTransform | None = None,
        graph_transform: GraphTransform | None = None,
        return_vertex_embeddings: bool = True,
        return_edge_embeddings: bool = True,
    ):
        super().__init__()
        # manually add transforms to hparams to suppress lightning's warning
        # about double saving their state_dict values.
        self.save_hyperparameters(ignore=["V_d_transform", "E_d_transform", "graph_transform"])
        self.hparams["V_d_transform"] = V_d_transform
        self.hparams["E_d_transform"] = E_d_transform
        self.hparams["graph_transform"] = graph_transform
        self.hparams["cls"] = self.__class__

        self.return_vertex_embeddings = return_vertex_embeddings
        self.return_edge_embeddings = return_edge_embeddings

        self.W_i, self.W_h, self.W_vo, self.W_vd, self.W_eo, self.W_ed = self.setup(
            d_v, d_e, d_h, d_vd, d_ed, bias
        )
        self.depth = depth
        self.undirected = undirected
        self.dropout = nn.Dropout(dropout)
        self.tau = get_activation_function(activation)
        self.V_d_transform = V_d_transform if V_d_transform is not None else nn.Identity()
        self.E_d_transform = E_d_transform if E_d_transform is not None else nn.Identity()
        self.graph_transform = graph_transform if graph_transform is not None else nn.Identity()

    @property
    def output_dims(self) -> tuple[int | None, int | None]:
        """Returns the output dimensions of the vertex and edge embeddings."""
        return (
            None
            if not self.return_vertex_embeddings
            else self.W_vd.out_features
            if self.W_vd is not None
            else self.W_vo.out_features,
            None
            if not self.return_edge_embeddings
            else self.W_ed.out_features
            if self.W_ed is not None
            else self.W_eo.out_features,
        )

    @abstractmethod
    def setup(
        self,
        d_v: int = DEFAULT_ATOM_FDIM,
        d_e: int = DEFAULT_BOND_FDIM,
        d_h: int = DEFAULT_HIDDEN_DIM,
        d_vd: int | None = None,
        d_ed: int | None = None,
        bias: bool = False,
    ) -> tuple[
        nn.Module, nn.Module, nn.Module | None, nn.Module | None, nn.Module | None, nn.Module | None
    ]:
        """setup the weight matrices used in the message passing update functions

        Parameters
        ----------
        d_v : int
            the vertex feature dimension
        d_e : int
            the edge feature dimension
        d_h : int, default=300
            the hidden dimension during message passing
        d_vd : int | None, default=None
            the dimension of additional vertex descriptors that will be concatenated to the hidden
            features before readout, if any
        d_ed : int | None, default=None
            the dimension of additional edge descriptors that will be concatenated to the hidden
            features before readout, if any
        bias: bool, default=False
            whether to add a learned bias to the matrices

        Returns
        -------
        W_i, W_h, W_vo, W_vd, W_eo, W_ed : tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
            the input, hidden, output, and descriptor weight matrices, respectively, used in the
            message passing update functions. The descriptor weight matrix is `None` if no vertex
            dimension is supplied
        """

    @abstractmethod
    def initialize(self, bmg: BatchMolGraph) -> Tensor:
        """initialize the message passing scheme by calculating initial matrix of hidden features"""

    @abstractmethod
    def message(self, H_t: Tensor, bmg: BatchMolGraph):
        """Calculate the message matrix"""

    def update(self, M_t, H_0):
        """Calcualte the updated hidden for each edge"""
        H_t = self.W_h(M_t)
        H_t = self.tau(H_0 + H_t)
        H_t = self.dropout(H_t)

        return H_t

    def vertex_finalize(self, M: Tensor, V: Tensor, V_d: Tensor | None) -> Tensor:
        r"""Finalize message passing for vertex embeddings by (1) concatenating the final message
        ``M`` and the original vertex features ``V`` and (2) if provided, further concatenating
        additional vertex descriptors ``V_d``.

        This function implements the following operation:

        .. math::
            H &= \mathtt{dropout} \left( \tau(\mathbf{W}_vo(V \mathbin\Vert M)) \right) \\
            H &= \mathtt{dropout} \left( \tau(\mathbf{W}_vd(H \mathbin\Vert V_d)) \right),

        where :math:`\tau` is the activation function, :math:`\Vert` is the concatenation operator,
        :math:`\mathbf{W}_vo` and :math:`\mathbf{W}_vd` are learned weight matrices, :math:`M` is
        the message matrix, :math:`V` is the original vertex feature matrix, and :math:`V_d` is an
        optional vertex descriptor matrix.

        Parameters
        ----------
        M : Tensor
            a tensor of shape ``V x d_h`` containing the message vector of each vertex
        V : Tensor
            a tensor of shape ``V x d_v`` containing the original vertex features
        V_d : Tensor | None
            an optional tensor of shape ``V x d_vd`` containing additional vertex descriptors

        Returns
        -------
        Tensor
            a tensor of shape ``V x (d_h + d_v [+ d_vd])`` containing the final hidden
            representations

        Raises
        ------
        InvalidShapeError
            if ``V_d`` is not of shape ``b x d_vd``, where ``b`` is the batch size and ``d_vd`` is
            the vertex descriptor dimension
        """
        H = self.W_vo(torch.cat((V, M), dim=1))  # V x d_o
        H = self.tau(H)
        H = self.dropout(H)

        if V_d is not None:
            V_d = self.V_d_transform(V_d)
            try:
                H = self.W_vd(torch.cat((H, V_d), dim=1))  # V x (d_o + d_vd)
                H = self.dropout(H)
            except RuntimeError:
                raise InvalidShapeError(
                    "V_d", V_d.shape, [len(H), self.W_vd.in_features - self.W_vo.out_features]
                )

        return H

    def edge_finalize(self, H: Tensor, E: Tensor, E_d: Tensor | None) -> Tensor:
        r"""Finalize message passing for edge embeddings by (1) concatenating the final hidden
        directed edges ``H`` and the original edge features ``E`` and (2) if provided, further
        concatenating additional edge descriptors ``E_d``.

        This function implements the following operation:

        .. math::
            H &= \mathtt{dropout} \left( \tau(\mathbf{W}_eo(E \mathbin\Vert H)) \right) \\
            H &= \mathtt{dropout} \left( \tau(\mathbf{W}_ed(H \mathbin\Vert E_d)) \right),

        where :math:`\tau` is the activation function, :math:`\Vert` is the concatenation operator,
        :math:`\mathbf{W}_eo` and :math:`\mathbf{W}_ed` are learned weight matrices, :math:`H` is
        the hidden directed edge matrix, :math:`E` is the original edge feature matrix, and
        :math:`E_d` is an optional vertex descriptor matrix.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``E x d_h`` containing the hidden state for each edge
        E : Tensor
            a tensor of shape ``E x d_e`` containing the original edge features
        E_d : Tensor | None
            an optional tensor of shape ``E x d_ed`` containing additional edge descriptors

        Returns
        -------
        Tensor
            a tensor of shape ``E x (d_h + d_e [+ d_ed])`` containing the final hidden
            representations

        Raises
        ------
        InvalidShapeError
            if ``E_d`` is not of shape ``b x d_ed``, where ``b`` is the batch size and ``d_ed`` is
            the edge descriptor dimension
        """
        H = self.W_eo(torch.cat((E, H), dim=1))
        H = self.tau(H)
        H = self.dropout(H)

        if E_d is not None:
            E_d = self.E_d_transform(E_d)
            try:
                H = self.W_ed(torch.cat((H, E_d), dim=1))
                H = self.dropout(H)
            except RuntimeError:
                # Try E_d = np.repeat(E_d, repeats=2, axis=0), see the note in the docstring of
                #  MABMessagePassing.forward
                raise InvalidShapeError(
                    "E_d", E_d.shape, [len(H), self.W_ed.in_features - self.W_eo.out_features]
                )

        return H

    def forward(
        self, bmg: BatchMolGraph, V_d: Tensor | None = None, E_d: Tensor | None = None
    ) -> tuple[Tensor | None, Tensor | None]:
        bmg = self.graph_transform(bmg)
        H_0 = self.initialize(bmg)

        H = self.tau(H_0)
        for _ in range(1, self.depth):
            if self.undirected:
                H = (H + H[bmg.rev_edge_index]) / 2

            M = self.message(H, bmg)
            H = self.update(M, H_0)

        index_torch = bmg.edge_index[1].unsqueeze(1).repeat(1, H.shape[1])
        M = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
            0, index_torch, H, reduce="sum", include_self=False
        )

        H_v = self.vertex_finalize(M, bmg.V, V_d) if self.return_vertex_embeddings else None
        H_e = self.edge_finalize(H, bmg.E, E_d) if self.return_edge_embeddings else None
        return H_v, H_e


[docs] class MABBondMessagePassing(_BondMessagePassingMixin, _MABMessagePassingBase): r"""A :class:`MABBondMessagePassing` encodes a batch of molecular graphs by passing messages along directed bonds. It implements the following operation: .. math:: h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_v^{(0)} + \mathbf W_h m_{vw}^{(t-1)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right), where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`; :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond :math:`v \to w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` is the number of message passing iterations. """
[docs] def setup( self, d_v: int = DEFAULT_ATOM_FDIM, d_e: int = DEFAULT_BOND_FDIM, d_h: int = DEFAULT_HIDDEN_DIM, d_vd: int | None = None, d_ed: int | None = None, bias: bool = False, ): W_i = nn.Linear(d_v + d_e, d_h, bias) W_h = nn.Linear(d_h, d_h, bias) W_vo = nn.Linear(d_v + d_h, d_h) if self.return_vertex_embeddings else None W_eo = nn.Linear(d_e + d_h, d_h) if self.return_edge_embeddings else None W_vd = nn.Linear(d_h + d_vd, d_h + d_vd) if d_vd else None W_ed = nn.Linear(d_h + d_ed, d_h + d_ed) if d_ed else None return W_i, W_h, W_vo, W_vd, W_eo, W_ed
[docs] class MABAtomMessagePassing(_AtomMessagePassingMixin, _MABMessagePassingBase): r"""A :class:`MABAtomMessagePassing` encodes a batch of molecular graphs by passing messages along atoms. It implements the following operation: .. math:: h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t-1)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right), where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`; :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`; :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and :math:`t \in \{1, \dots, T\}` is the number of message passing iterations. """
[docs] def setup( self, d_v: int = DEFAULT_ATOM_FDIM, d_e: int = DEFAULT_BOND_FDIM, d_h: int = DEFAULT_HIDDEN_DIM, d_vd: int | None = None, d_ed: int | None = None, bias: bool = False, ): W_i = nn.Linear(d_v, d_h, bias) W_h = nn.Linear(d_e + d_h, d_h, bias) W_vo = nn.Linear(d_v + d_h, d_h) if self.return_vertex_embeddings else None W_eo = nn.Linear(d_e + d_h, d_h) if self.return_edge_embeddings else None W_vd = nn.Linear(d_h + d_vd, d_h + d_vd) if d_vd else None W_ed = nn.Linear(d_h + d_ed, d_h + d_ed) if d_ed else None return W_i, W_h, W_vo, W_vd, W_eo, W_ed