chemprop.nn.message_passing
===========================

.. py:module:: chemprop.nn.message_passing


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/chemprop/nn/message_passing/base/index
   /autoapi/chemprop/nn/message_passing/mixins/index
   /autoapi/chemprop/nn/message_passing/mol_atom_bond/index
   /autoapi/chemprop/nn/message_passing/multi/index
   /autoapi/chemprop/nn/message_passing/proto/index


Classes
-------

.. autoapisummary::

   chemprop.nn.message_passing.AtomMessagePassing
   chemprop.nn.message_passing.BondMessagePassing
   chemprop.nn.message_passing.MABAtomMessagePassing
   chemprop.nn.message_passing.MABBondMessagePassing
   chemprop.nn.message_passing.MulticomponentMessagePassing
   chemprop.nn.message_passing.MABMessagePassing
   chemprop.nn.message_passing.MessagePassing


Package Contents
----------------

.. py:class:: AtomMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, V_d_transform = None, graph_transform = None)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._AtomMessagePassingMixin`, :py:obj:`_MessagePassingBase`


   An :class:`AtomMessagePassing` encodes a batch of molecular graphs by passing messages along
   atoms.

   It implements the following operation:

   .. math::

       h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\
       m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\
       h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right)  \right),

   where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and
   :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`;
   :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and
   :math:`t \in \{1, \dots, T\}` indexes the message passing iterations.
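
   The recurrence above can be sketched in plain Python on a toy graph. This is only an
   illustration of the equations, not the library's implementation: the function name, the
   list-based data layout, and the toy weights are all assumptions made for the example, and
   the update step uses the message computed in the same iteration.

   .. code-block:: python

      def relu(vec):
          return [max(0.0, x) for x in vec]


      def matvec(W, x):
          # W is a list of rows; returns the matrix-vector product W @ x.
          return [sum(w * xi for w, xi in zip(row, x)) for row in W]


      def vec_add(a, b):
          return [x + y for x, y in zip(a, b)]


      def atom_message_passing(X, E, W_i, W_h, W_o, T):
          """X: atom feature vectors; E: {(u, v): bond features}, both directions present."""
          d_h = len(W_i)
          d_e = len(next(iter(E.values())))
          neighbors = {v: [u for (u, w) in E if w == v] for v in range(len(X))}

          H_0 = [relu(matvec(W_i, x)) for x in X]  # h_v^(0)
          H = H_0
          for _ in range(1, T):
              M = []
              for v in range(len(X)):
                  m = [0.0] * (d_h + d_e)
                  for u in neighbors[v]:
                      # List concatenation plays the role of h_u || e_uv.
                      m = vec_add(m, H[u] + E[(u, v)])
                  M.append(m)
              H = [relu(vec_add(h0, matvec(W_h, m))) for h0, m in zip(H_0, M)]

          out = []
          for v in range(len(X)):
              m = [0.0] * d_h
              for w in neighbors[v]:
                  m = vec_add(m, H[w])
              out.append(relu(matvec(W_o, X[v] + m)))  # W_o (x_v || m_v)
          return out


      # Toy molecule: two atoms joined by one bond, with d_v = d_e = d_h = 1.
      X = [[1.0], [2.0]]
      E = {(0, 1): [1.0], (1, 0): [1.0]}
      out = atom_message_passing(X, E, W_i=[[1.0]], W_h=[[0.5, 0.5]], W_o=[[1.0, 1.0]], T=2)
      print(out)  # [[4.0], [4.5]]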


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_o, W_d** -- the input, hidden, output, and descriptor weight matrices, respectively, used in the
                message passing update functions. The descriptor weight matrix is `None` if no vertex
                descriptor dimension (``d_vd``) is supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]



.. py:class:: BondMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, V_d_transform = None, graph_transform = None)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._BondMessagePassingMixin`, :py:obj:`_MessagePassingBase`


   A :class:`BondMessagePassing` encodes a batch of molecular graphs by passing messages along
   directed bonds.

   It implements the following operation:

   .. math::

       h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\
       m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\
       h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),

   where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and
   :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at
   iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond :math:`v
   \to w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` indexes the
   message passing iterations.
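
   As a rough illustration of the directed-bond recurrence, the following plain-Python sketch
   keeps one hidden state per directed bond and excludes the reverse bond when gathering
   messages. It is not the library's implementation: the function name, the dictionary-based
   layout, and the toy weights are assumptions made for the example, and hidden states are
   indexed by directed bonds throughout.

   .. code-block:: python

      def relu(vec):
          return [max(0.0, x) for x in vec]


      def matvec(W, x):
          # W is a list of rows; returns the matrix-vector product W @ x.
          return [sum(w * xi for w, xi in zip(row, x)) for row in W]


      def vec_add(a, b):
          return [x + y for x, y in zip(a, b)]


      def bond_message_passing(X, E, W_i, W_h, W_o, T):
          """X: atom feature vectors; E: {(v, w): bond features}, both directions present."""
          d_h = len(W_i)
          H_0 = {vw: relu(matvec(W_i, e)) for vw, e in E.items()}  # h_vw^(0)
          H = H_0
          for _ in range(1, T):
              M = {}
              for (v, w) in E:
                  m = [0.0] * d_h
                  for (u, target) in E:
                      # Sum bonds u -> v entering v, skipping the reverse bond w -> v.
                      if target == v and u != w:
                          m = vec_add(m, H[(u, v)])
                  M[(v, w)] = m
              H = {vw: relu(vec_add(H_0[vw], matvec(W_h, M[vw]))) for vw in E}

          out = []
          for v, x in enumerate(X):
              m = [0.0] * d_h
              for (w, target) in E:
                  if target == v:
                      m = vec_add(m, H[(w, v)])  # aggregate incoming bond states
              out.append(relu(matvec(W_o, x + m)))  # W_o (x_v || m_v)
          return out


      # Toy path graph 0 - 1 - 2 with d_v = d_e = d_h = 1.
      X = [[1.0], [2.0], [3.0]]
      E = {(0, 1): [1.0], (1, 0): [1.0], (1, 2): [2.0], (2, 1): [2.0]}
      out = bond_message_passing(X, E, W_i=[[1.0]], W_h=[[1.0]], W_o=[[1.0, 1.0]], T=2)
      print(out)  # [[4.0], [5.0], [6.0]]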


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_o, W_d** -- the input, hidden, output, and descriptor weight matrices, respectively, used in the
                message passing update functions. The descriptor weight matrix is `None` if no vertex
                descriptor dimension (``d_vd``) is supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]



.. py:class:: MABAtomMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, d_ed = None, V_d_transform = None, E_d_transform = None, graph_transform = None, return_vertex_embeddings = True, return_edge_embeddings = True)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._AtomMessagePassingMixin`, :py:obj:`_MABMessagePassingBase`


   A :class:`MABAtomMessagePassing` encodes a batch of molecular graphs by passing messages
   along atoms.

   It implements the following operation:

   .. math::

       h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\
       m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\
       h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right)  \right),

   where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and
   :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`;
   :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and
   :math:`t \in \{1, \dots, T\}` indexes the message passing iterations.


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, d_ed = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param d_ed: the dimension of additional edge descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_ed: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_vo, W_vd, W_eo, W_ed** -- the input, hidden, vertex output, vertex descriptor, edge output, and
                edge descriptor weight matrices, respectively, used in the message passing update functions.
                A descriptor weight matrix is `None` if the corresponding descriptor dimension is not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]



.. py:class:: MABBondMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, d_ed = None, V_d_transform = None, E_d_transform = None, graph_transform = None, return_vertex_embeddings = True, return_edge_embeddings = True)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._BondMessagePassingMixin`, :py:obj:`_MABMessagePassingBase`


   A :class:`MABBondMessagePassing` encodes a batch of molecular graphs by passing messages
   along directed bonds.

   It implements the following operation:

   .. math::

       h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\
       m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\
       h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),

   where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and
   :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at
   iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond :math:`v
   \to w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` indexes the
   message passing iterations.


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, d_ed = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param d_ed: the dimension of additional edge descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_ed: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_vo, W_vd, W_eo, W_ed** -- the input, hidden, vertex output, vertex descriptor, edge output, and
                edge descriptor weight matrices, respectively, used in the message passing update functions.
                A descriptor weight matrix is `None` if the corresponding descriptor dimension is not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]



.. py:class:: MulticomponentMessagePassing(blocks, n_components, shared = False)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A `MulticomponentMessagePassing` performs message passing on each individual input in a
   multicomponent input, then concatenates the representation of each input to construct a
   global representation.

   :param blocks: the individual message-passing blocks for each input
   :type blocks: Sequence[MessagePassing]
   :param n_components: the number of components in each input
   :type n_components: int
   :param shared: whether one block will be shared among all components in an input. If not, a separate
                  block will be learned for each component.
   :type shared: bool, default=False


   .. py:attribute:: hparams


   .. py:attribute:: n_components


   .. py:attribute:: shared
      :value: False



   .. py:attribute:: blocks


   .. py:method:: __len__()


   .. py:property:: output_dim
      :type: int



   .. py:method:: forward(bmgs, V_ds)

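      Encode the multicomponent inputs.

      :param bmgs: the batches of molecular graphs to encode, one per component
      :type bmgs: Iterable[BatchMolGraph]
      :param V_ds: the optional additional vertex descriptors for each component's batch, one entry per component
      :type V_ds: Iterable[Tensor | None]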

      :returns: a list of tensors of shape `V x d_i` containing the respective encodings of the `i`-th
                component, where `d_i` is the output dimension of the `i`-th encoder
      :rtype: list[Tensor]
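
      The per-component dispatch described above can be sketched with stand-in objects. Everything
      here is hypothetical and only mirrors the documented signature: ``ToyBlock`` and
      ``ToyMulticomponent`` are not chemprop classes, reusing the first block when ``shared`` is set
      is an assumption, and so is treating ``output_dim`` as the sum of the block output widths.

      .. code-block:: python

         class ToyBlock:
             """Hypothetical stand-in for a message-passing block (not a chemprop class)."""

             def __init__(self, output_dim):
                 self.output_dim = output_dim

             def __call__(self, graph, V_d=None):
                 # Pretend encoding: one row per vertex, filled with the vertex count.
                 n_vertices = len(graph)
                 return [[float(n_vertices)] * self.output_dim for _ in range(n_vertices)]


         class ToyMulticomponent:
             """Illustrative wrapper: one block per component, or a single shared block."""

             def __init__(self, blocks, n_components, shared=False):
                 blocks = list(blocks)
                 if shared:
                     # Assumed behaviour: reuse the first block for every component.
                     blocks = [blocks[0]] * n_components
                 if len(blocks) != n_components:
                     raise ValueError("expected one block per component")
                 self.blocks = blocks

             def __len__(self):
                 return len(self.blocks)

             @property
             def output_dim(self):
                 # Assumed: width of the concatenated global representation.
                 return sum(block.output_dim for block in self.blocks)

             def forward(self, graphs, V_ds):
                 # One encoding per component, returned in component order.
                 return [block(g, V_d) for block, g, V_d in zip(self.blocks, graphs, V_ds)]


         encoder = ToyMulticomponent([ToyBlock(2), ToyBlock(3)], n_components=2)
         encodings = encoder.forward([["C", "O"], ["N"]], [None, None])
         shared_encoder = ToyMulticomponent([ToyBlock(4)], n_components=3, shared=True)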



.. py:class:: MABMessagePassing(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`MABMessagePassing` module encodes a batch of molecular graphs
   using message passing to learn both vertex-level and edge-level hidden representations.


   .. py:attribute:: output_dims
      :type:  tuple[int | None, int | None]


   .. py:method:: forward(bmg, V_d = None, E_d = None)
      :abstractmethod:


      Encode a batch of molecular graphs.

      :param bmg: the batch of :class:`~chemprop.featurizers.molgraph.MolGraph`\s to encode
      :type bmg: BatchMolGraph
      :param V_d: an optional tensor of shape `V x d_vd` containing additional descriptors for each vertex
                  in the batch. These will be concatenated to the learned vertex descriptors and
                  transformed before the readout phase.
      :type V_d: Tensor | None, default=None
      :param E_d: an optional tensor of shape `E x d_ed` containing additional descriptors for each
                  directed edge in the batch. These will be concatenated to the learned edge descriptors
                  and transformed before the readout phase. NOTE: There are two directed edges per graph
                  connection. If the extra descriptors are for the connections, each row should be
                  repeated twice in the tensor, once for each direction, potentially using
                  ``E_d = np.repeat(E_d, repeats=2, axis=0)``.
      :type E_d: Tensor | None, default=None

      :returns: Two tensors of shape `V x d_h` or `V x (d_h + d_vd)` and `E x d_h` or `E x (d_h + d_ed)`
                containing the hidden representation of each vertex and edge in the batch of graphs.
                The feature dimension depends on whether additional atom/bond descriptors were provided.
                If either the vertex or edge hidden representations are not needed, computing the
                corresponding tensor can be suppressed by setting either ``return_vertex_embeddings`` or
                ``return_edge_embeddings`` to `False` when initializing the module.
      :rtype: tuple[Tensor | None, Tensor | None]
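
      The row duplication mentioned for ``E_d`` can be illustrated without any dependencies. The
      helper below is a hypothetical plain-Python equivalent of ``np.repeat(E_d, repeats=2, axis=0)``
      for a list of rows; it assumes, as that suggestion implies, that the two directed edges of a
      bond occupy adjacent rows.

      .. code-block:: python

         def repeat_rows(rows, repeats=2):
             """Repeat each row consecutively, like np.repeat(..., axis=0) on a 2-D array."""
             return [list(row) for row in rows for _ in range(repeats)]


         # One descriptor row per undirected bond (two bonds, two descriptors each).
         bond_descriptors = [[0.1, 1.0], [0.2, 2.0]]

         # One row per directed edge: rows 0-1 belong to the first bond, rows 2-3 to the second.
         E_d = repeat_rows(bond_descriptors)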



.. py:class:: MessagePassing(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`MessagePassing` module encodes a batch of molecular graphs
   using message passing to learn vertex-level hidden representations.


   .. py:attribute:: output_dim
      :type:  int


   .. py:method:: forward(bmg, V_d = None)
      :abstractmethod:


      Encode a batch of molecular graphs.

      :param bmg: the batch of :class:`~chemprop.featurizers.molgraph.MolGraph`\s to encode
      :type bmg: BatchMolGraph
      :param V_d: an optional tensor of shape `V x d_vd` containing additional descriptors for each vertex
                  in the batch. These will be concatenated to the learned vertex descriptors and
                  transformed before the readout phase.
      :type V_d: Tensor | None, default=None

      :returns: a tensor of shape `V x d_h` or `V x (d_h + d_vd)` containing the hidden representation
                of each vertex in the batch of graphs. The feature dimension depends on whether
                additional vertex descriptors were provided.
      :rtype: Tensor
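
      The two possible output shapes can be written down as a small helper. This is only a
      restatement of the shapes given above; ``encoding_shape`` is a hypothetical name and the
      dimensions used in the example are arbitrary.

      .. code-block:: python

         def encoding_shape(n_vertices, d_h, d_vd=None):
             """Shape of the vertex encodings: V x d_h, or V x (d_h + d_vd) with descriptors."""
             width = d_h if d_vd is None else d_h + d_vd
             return (n_vertices, width)


         shape_plain = encoding_shape(n_vertices=12, d_h=300)
         shape_with_descriptors = encoding_shape(n_vertices=12, d_h=300, d_vd=10)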



