chemprop.nn
===========

.. py:module:: chemprop.nn


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/chemprop/nn/agg/index
   /autoapi/chemprop/nn/ffn/index
   /autoapi/chemprop/nn/hparams/index
   /autoapi/chemprop/nn/message_passing/index
   /autoapi/chemprop/nn/metrics/index
   /autoapi/chemprop/nn/predictors/index
   /autoapi/chemprop/nn/transforms/index
   /autoapi/chemprop/nn/utils/index


Attributes
----------

.. autoapisummary::

   chemprop.nn.AggregationRegistry
   chemprop.nn.LossFunctionRegistry
   chemprop.nn.MetricRegistry
   chemprop.nn.PredictorRegistry


Classes
-------

.. autoapisummary::

   chemprop.nn.Aggregation
   chemprop.nn.AttentiveAggregation
   chemprop.nn.MeanAggregation
   chemprop.nn.NormAggregation
   chemprop.nn.SumAggregation
   chemprop.nn.ConstrainerFFN
   chemprop.nn.AtomMessagePassing
   chemprop.nn.BondMessagePassing
   chemprop.nn.MABAtomMessagePassing
   chemprop.nn.MABBondMessagePassing
   chemprop.nn.MABMessagePassing
   chemprop.nn.MessagePassing
   chemprop.nn.MulticomponentMessagePassing
   chemprop.nn.MAE
   chemprop.nn.MSE
   chemprop.nn.RMSE
   chemprop.nn.SID
   chemprop.nn.BCELoss
   chemprop.nn.BinaryAccuracy
   chemprop.nn.BinaryAUPRC
   chemprop.nn.BinaryAUROC
   chemprop.nn.BinaryF1Score
   chemprop.nn.BinaryMCCLoss
   chemprop.nn.BinaryMCCMetric
   chemprop.nn.BoundedMAE
   chemprop.nn.BoundedMixin
   chemprop.nn.BoundedMSE
   chemprop.nn.BoundedRMSE
   chemprop.nn.ChempropMetric
   chemprop.nn.ClassificationMixin
   chemprop.nn.CrossEntropyLoss
   chemprop.nn.DirichletLoss
   chemprop.nn.EvidentialLoss
   chemprop.nn.MulticlassMCCLoss
   chemprop.nn.MulticlassMCCMetric
   chemprop.nn.MVELoss
   chemprop.nn.QuantileLoss
   chemprop.nn.R2Score
   chemprop.nn.Wasserstein
   chemprop.nn.BinaryClassificationFFN
   chemprop.nn.BinaryClassificationFFNBase
   chemprop.nn.BinaryDirichletFFN
   chemprop.nn.EvidentialFFN
   chemprop.nn.MulticlassClassificationFFN
   chemprop.nn.MveFFN
   chemprop.nn.Predictor
   chemprop.nn.QuantileFFN
   chemprop.nn.RegressionFFN
   chemprop.nn.SpectralFFN
   chemprop.nn.GraphTransform
   chemprop.nn.ScaleTransform
   chemprop.nn.UnscaleTransform
   chemprop.nn.Activation


Package Contents
----------------

.. py:class:: Aggregation(dim = 0, *args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   An :class:`Aggregation` aggregates the node-level representations of a batch of graphs into
   a batch of graph-level representations.

   .. note::
       This class is abstract and cannot be instantiated.

   .. seealso:: :class:`~chemprop.nn.MeanAggregation`, :class:`~chemprop.nn.SumAggregation`, :class:`~chemprop.nn.NormAggregation`


   .. py:attribute:: dim
      :value: 0



   .. py:attribute:: hparams


   .. py:method:: forward(H, batch)
      :abstractmethod:


      Aggregate the node-level representations of a batch of graphs into their respective
      graph-level representations.

      NOTE: a graph may have 0 nodes. In this case, its representation in the final output is a
      zero vector of length ``d``.

      :param H: a tensor of shape ``V x d`` containing the batched node-level representations of ``b``
                graphs
      :type H: Tensor
      :param batch: a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to
      :type batch: Tensor

      :returns: a tensor of shape ``b x d`` containing the graph-level representations
      :rtype: Tensor
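
      The shape contract above can be illustrated with a framework-free sketch. The helper
      ``segment_sum`` is hypothetical (it is not part of chemprop), plain Python lists stand in for
      tensors, and the number of graphs ``b`` is passed explicitly, which is an assumption of the
      sketch. Graph 1 below has no nodes, so its row stays a zero vector.

      .. code-block:: python

         def segment_sum(H: list[list[float]], batch: list[int], b: int) -> list[list[float]]:
             """Sum the rows of H (V x d) into b rows, using batch[i] as the graph of row i."""
             d = len(H[0]) if H else 0
             out = [[0.0] * d for _ in range(b)]  # graphs without nodes keep a zero vector
             for h_v, g in zip(H, batch):
                 for j, x in enumerate(h_v):
                     out[g][j] += x
             return out


         # three nodes: two in graph 0, none in graph 1, one in graph 2
         H = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
         batch = [0, 0, 2]
         print(segment_sum(H, batch, b=3))  # [[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]]
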



.. py:data:: AggregationRegistry

.. py:class:: AttentiveAggregation(dim = 0, *args, output_size, **kwargs)

   Bases: :py:obj:`Aggregation`


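   An :class:`AttentiveAggregation` aggregates the node-level representations of each graph into a
   graph-level representation using a learned, attention-weighted sum:

   .. math::
       \mathbf h = \sum_{v \in V} \alpha_v \mathbf h_v, \qquad
       \alpha_v = \frac{\exp\left(\mathbf W \mathbf h_v\right)}{\sum_{u \in V} \exp\left(\mathbf W \mathbf h_u\right)},

   where :math:`\mathbf W` is the learned linear layer :attr:`W`, which maps each node
   representation of size ``output_size`` to a scalar attention score.

   .. seealso:: :class:`~chemprop.nn.MeanAggregation`, :class:`~chemprop.nn.SumAggregation`, :class:`~chemprop.nn.NormAggregation`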


   .. py:attribute:: W


   .. py:method:: forward(H, batch)

      Aggregate the node-level representations of a batch of graphs into their respective
      graph-level representations.

      NOTE: a graph may have 0 nodes. In this case, its representation in the final output is a
      zero vector of length ``d``.

      :param H: a tensor of shape ``V x d`` containing the batched node-level representations of ``b``
                graphs
      :type H: Tensor
      :param batch: a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to
      :type batch: Tensor

      :returns: a tensor of shape ``b x d`` containing the graph-level representations
      :rtype: Tensor
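
      A framework-free sketch of attention-weighted pooling: each node gets a score from a linear
      function of its representation, the scores are soft-maxed within each graph, and the node
      vectors of a graph are summed with those weights. The weight vector ``w``, the bias ``c``, the
      helper name ``attentive_pool``, and the explicit graph count ``b`` are hypothetical stand-ins,
      not chemprop's API.

      .. code-block:: python

         import math


         def attentive_pool(H, batch, b, w, c=0.0):
             """Return b graph vectors: per-graph sums of node vectors weighted by softmax(w . h_v + c)."""
             d = len(H[0]) if H else 0
             scores = [math.exp(sum(wi * hi for wi, hi in zip(w, h_v)) + c) for h_v in H]
             Z = [0.0] * b  # per-graph softmax normalizers
             for s, g in zip(scores, batch):
                 Z[g] += s
             out = [[0.0] * d for _ in range(b)]  # graphs without nodes keep a zero vector
             for h_v, s, g in zip(H, scores, batch):
                 alpha = s / Z[g]  # attention weight of node v within its graph
                 for j, x in enumerate(h_v):
                     out[g][j] += alpha * x
             return out


         H = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
         batch = [0, 0, 1]
         # with w = 0 all nodes of a graph get equal weight, so this reduces to a per-graph mean
         print(attentive_pool(H, batch, b=2, w=[0.0, 0.0]))  # [[0.5, 0.5], [2.0, 2.0]]

      A bias term does not change the weights, because adding the same constant to every score in a
      graph cancels in the softmax. A production version would also subtract each graph's maximum
      score before exponentiating to avoid overflow.
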



.. py:class:: MeanAggregation(dim = 0, *args, **kwargs)

   Bases: :py:obj:`Aggregation`


   Average the node-level representations:

   .. math::
       \mathbf h = \frac{1}{|V|} \sum_{v \in V} \mathbf h_v


   .. py:method:: forward(H, batch)

      Aggregate the node-level representations of a batch of graphs into their respective
      graph-level representations.

      NOTE: a graph may have 0 nodes. In this case, its representation in the final output is a
      zero vector of length ``d``.

      :param H: a tensor of shape ``V x d`` containing the batched node-level representations of ``b``
                graphs
      :type H: Tensor
      :param batch: a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to
      :type batch: Tensor

      :returns: a tensor of shape ``b x d`` containing the graph-level representations
      :rtype: Tensor
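
      A list-based sketch of the per-graph mean. The helper name ``segment_mean`` and the explicit
      graph count ``b`` are hypothetical. Division is skipped when a graph has no nodes, so such a
      graph keeps a zero vector, consistent with the note above.

      .. code-block:: python

         def segment_mean(H, batch, b):
             """Average the rows of H (V x d) per graph; batch[i] is the graph of row i."""
             d = len(H[0]) if H else 0
             sums = [[0.0] * d for _ in range(b)]
             counts = [0] * b  # |V| for each graph
             for h_v, g in zip(H, batch):
                 counts[g] += 1
                 for j, x in enumerate(h_v):
                     sums[g][j] += x
             # graphs with no nodes are left as zero vectors instead of dividing by zero
             return [[x / n for x in row] if n else row for row, n in zip(sums, counts)]


         print(segment_mean([[1.0, 2.0], [3.0, 6.0], [5.0, 5.0]], [0, 0, 2], b=3))
         # [[2.0, 4.0], [0.0, 0.0], [5.0, 5.0]]
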



.. py:class:: NormAggregation(dim = 0, *args, norm = 100.0, **kwargs)

   Bases: :py:obj:`SumAggregation`


   Sum the node-level representations and divide by a normalization constant :math:`c` (the
   :attr:`norm` attribute, 100 by default):

   .. math::
       \mathbf h = \frac{1}{c} \sum_{v \in V} \mathbf h_v


   .. py:attribute:: norm
      :value: 100.0



   .. py:method:: forward(H, batch)

      Aggregate the node-level representations of a batch of graphs into their respective
      graph-level representations.

      NOTE: a graph may have 0 nodes. In this case, its representation in the final output is a
      zero vector of length ``d``.

      :param H: a tensor of shape ``V x d`` containing the batched node-level representations of ``b``
                graphs
      :type H: Tensor
      :param batch: a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to
      :type batch: Tensor

      :returns: a tensor of shape ``b x d`` containing the graph-level representations
      :rtype: Tensor
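
      A list-based sketch of the normalized sum with the default constant ``norm = 100.0``. The
      helper name ``segment_norm_sum`` and the explicit graph count ``b`` are hypothetical.

      .. code-block:: python

         def segment_norm_sum(H, batch, b, norm=100.0):
             """Per-graph sum of node vectors, divided by a fixed constant c = norm."""
             d = len(H[0]) if H else 0
             out = [[0.0] * d for _ in range(b)]
             for h_v, g in zip(H, batch):
                 for j, x in enumerate(h_v):
                     out[g][j] += x
             return [[x / norm for x in row] for row in out]  # h = (1/c) * sum_v h_v


         print(segment_norm_sum([[10.0, 20.0], [30.0, 40.0]], [0, 0], b=1))  # [[0.4, 0.6]]

      Unlike a mean, the divisor does not depend on the number of nodes, so larger graphs still
      produce larger-magnitude vectors; the constant only rescales the sum.
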



.. py:class:: SumAggregation(dim = 0, *args, **kwargs)

   Bases: :py:obj:`Aggregation`


   Sum the graph-level representation:

   .. math::
       \mathbf h = \sum_{v \in V} \mathbf h_v



   .. py:method:: forward(H, batch)

      Aggregate the graph-level representations of a batch of graphs into their respective
      global representations

      NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be
      a zero vector of length `d` in the final output.

      :param H: a tensor of shape ``V x d`` containing the batched node-level representations of ``b``
                graphs
      :type H: Tensor
      :param batch: a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to
      :type batch: Tensor

      :returns: a tensor of shape ``b x d`` containing the graph-level representations
      :rtype: Tensor



.. py:class:: ConstrainerFFN(n_constraints = 1, fp_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu')

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`, :py:obj:`lightning.pytorch.core.mixins.HyperparametersMixin`


   A :class:`ConstrainerFFN` adjusts atom- or bond-level property predictions so that they satisfy
   molecule-level constraints. An :class:`MLP` maps the learned atom or bond embeddings to weights
   that determine what share of the total required adjustment is added to each atom or bond
   prediction.


   .. py:attribute:: ffn


   .. py:method:: forward(fp, preds, batch, constraints)

      Adjust the predictions to satisfy the constraints using a weighted correction, where the
      weights are computed from the learned atom or bond fingerprints by an :class:`MLP`.

      :param fp: a tensor of shape ``b x h`` containing the atom- or bond-level fingerprints, where ``b``
                 is the number of atoms or bonds and ``h`` is the length of each fingerprint.
      :type fp: Tensor
      :param preds: a tensor of shape ``b x t`` containing the atom- or bond-level predictions, where ``t``
                    is the number of predictions per atom or bond.
      :type preds: Tensor
      :param batch: a tensor of shape ``b`` containing the index of the molecule to which each atom or bond belongs
      :type batch: Tensor
      :param constraints: a tensor of shape ``m x t`` containing, for each molecule, the values that its atom- or
                          bond-level predictions should sum to, where ``m`` is the number of molecules in
                          the batch.
      :type constraints: Tensor

      :returns: a tensor of shape ``b x t`` containing the atom- or bond-level predictions adjusted to
                satisfy the molecule-level constraints
      :rtype: Tensor
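
      One weighting scheme consistent with this description is sketched below for a single target
      (``t = 1``): per-atom scores are turned into weights with a per-molecule softmax, so the
      weights of each molecule sum to one and the adjusted predictions add up exactly to the
      constraint. The softmax choice, the function name ``constrain``, and passing the scores ``k``
      directly (instead of computing them from ``fp`` with the :class:`MLP`) are assumptions of this
      sketch.

      .. code-block:: python

         import math


         def constrain(preds, k, batch, constraints):
             """Distribute each molecule's deficit (constraint - sum of preds) over its atoms."""
             m = len(constraints)
             exp_k = [math.exp(x) for x in k]  # hypothetical per-atom scores from an MLP
             Z = [0.0] * m       # per-molecule softmax normalizer
             totals = [0.0] * m  # per-molecule sum of the raw predictions
             for p, e, g in zip(preds, exp_k, batch):
                 Z[g] += e
                 totals[g] += p
             return [
                 p + (e / Z[g]) * (constraints[g] - totals[g])
                 for p, e, g in zip(preds, exp_k, batch)
             ]


         # two molecules: atoms 0-2 belong to molecule 0, atoms 3-4 to molecule 1
         preds = [0.1, -0.3, 0.5, 0.2, 0.2]
         k = [0.0, 1.0, 0.0, 2.0, -1.0]
         batch = [0, 0, 0, 1, 1]
         adjusted = constrain(preds, k, batch, constraints=[1.0, 2.0])
         print(round(sum(adjusted[:3]), 6), round(sum(adjusted[3:]), 6))  # 1.0 2.0
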



.. py:class:: AtomMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, V_d_transform = None, graph_transform = None)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._AtomMessagePassingMixin`, :py:obj:`_MessagePassingBase`


   A :class:`AtomMessagePassing` encodes a batch of molecular graphs by passing messages along
   atoms.

   It implements the following operation:

   .. math::

       h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\
       m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\
       h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right)  \right),

   where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and
   :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`;
   :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and
   :math:`t \in \{1, \dots, T-1\}` indexes the message passing iterations, where :math:`T`
   (``depth``) is the total number of iterations.
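
   The message step :math:`m_v^{(t)}` can be sketched on a directed edge list in which each bond
   appears once per direction. The sketch uses plain lists; the edge-list layout
   (``src[i] -> dst[i]`` with one feature vector per directed edge) and the name ``atom_messages``
   are assumptions made for illustration.

   .. code-block:: python

      def atom_messages(H, E, src, dst):
          """m_v = sum over directed edges u -> v of the concatenation (h_u || e_uv)."""
          width = len(H[0]) + len(E[0])
          M = [[0.0] * width for _ in H]
          for u, v, e_uv in zip(src, dst, E):
              for j, x in enumerate(H[u] + e_uv):  # list + list is the concatenation h_u || e_uv
                  M[v][j] += x
          return M


      # a 3-atom chain 0-1-2: two bonds, four directed edges
      src = [0, 1, 1, 2]
      dst = [1, 0, 2, 1]
      H = [[1.0], [2.0], [3.0]]           # d_h = 1
      E = [[0.5], [0.5], [0.25], [0.25]]  # d_e = 1, same features in both directions
      print(atom_messages(H, E, src, dst))  # [[2.0, 0.5], [4.0, 0.75], [2.0, 0.25]]
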


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_o, W_d** -- the input, hidden, output, and descriptor weight matrices, respectively, used in the
                message passing update functions. The descriptor weight matrix is ``None`` if ``d_vd`` is
                not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
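
      The input and output widths implied by the update equations above can be worked out
      directly. The sketch below is derived from those equations only; treating ``W_d`` as a map
      that preserves the width ``d_h + d_vd`` is an assumption based on the documented output shape
      of :meth:`MessagePassing.forward`, and the function name is hypothetical. The dimensions used
      in the example are arbitrary, not chemprop's defaults.

      .. code-block:: python

         def atom_mp_weight_shapes(d_v, d_e, d_h, d_vd=None):
             """(in_features, out_features) for each matrix in the atom update equations."""
             shapes = {
                 "W_i": (d_v, d_h),        # x_v -> h_v^(0)
                 "W_h": (d_h + d_e, d_h),  # m_v = sum of (h_u || e_uv) -> hidden update
                 "W_o": (d_v + d_h, d_h),  # x_v || m_v^(T) -> h_v^(T)
             }
             # assumed: descriptors are concatenated, then mapped to the same width
             shapes["W_d"] = (d_h + d_vd, d_h + d_vd) if d_vd else None
             return shapes


         print(atom_mp_weight_shapes(d_v=10, d_e=4, d_h=8))
         # {'W_i': (10, 8), 'W_h': (12, 8), 'W_o': (18, 8), 'W_d': None}
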



.. py:class:: BondMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, V_d_transform = None, graph_transform = None)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._BondMessagePassingMixin`, :py:obj:`_MessagePassingBase`


   A :class:`BondMessagePassing` encodes a batch of molecular graphs by passing messages along
   directed bonds.

   It implements the following operation:

   .. math::

       h_{vw}^{(0)} &= \tau \left( \mathbf W_i(x_v \mathbin\Vert e_{vw}) \right) \\
       m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\
       h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),

   where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and
   :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at
   iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond
   :math:`v \rightarrow w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` indexes the
   message passing iterations, where :math:`T` (``depth``) is the total number of iterations.
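
   The exclusion :math:`u \in \mathcal N(v) \setminus w` can be implemented by summing all incoming
   bond states of :math:`v` and subtracting the state of the reverse bond :math:`w \rightarrow v`.
   The list-based sketch below assumes each directed edge ``i`` stores the index of its reverse
   edge in ``rev[i]``; the names are illustrative, not chemprop's API.

   .. code-block:: python

      def bond_messages(H, src, dst, rev, n_atoms):
          """m_vw = (sum over u in N(v) of h_uv) - h_wv for each directed edge v -> w."""
          d = len(H[0])
          incoming = [[0.0] * d for _ in range(n_atoms)]  # sum of h_uv over edges u -> v
          for h_uv, v in zip(H, dst):
              for j, x in enumerate(h_uv):
                  incoming[v][j] += x
          # edge i is src[i] -> dst[i]; remove the contribution of its reverse edge
          return [
              [a - b for a, b in zip(incoming[src[i]], H[rev[i]])]
              for i in range(len(H))
          ]


      # chain 0-1-2; edges: 0->1, 1->0, 1->2, 2->1
      src, dst, rev = [0, 1, 1, 2], [1, 0, 2, 1], [1, 0, 3, 2]
      H = [[1.0], [10.0], [100.0], [1000.0]]
      print(bond_messages(H, src, dst, rev, n_atoms=3))  # [[0.0], [1000.0], [1.0], [0.0]]

   Subtracting the reverse edge avoids a separate neighbor loop per edge: the per-atom sums are
   computed once and reused by every outgoing bond.
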


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_o, W_d** -- the input, hidden, output, and descriptor weight matrices, respectively, used in the
                message passing update functions. The descriptor weight matrix is ``None`` if ``d_vd`` is
                not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]



.. py:class:: MABAtomMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, d_ed = None, V_d_transform = None, E_d_transform = None, graph_transform = None, return_vertex_embeddings = True, return_edge_embeddings = True)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._AtomMessagePassingMixin`, :py:obj:`_MABMessagePassingBase`


   A :class:`MABAtomMessagePassing` encodes a batch of molecular graphs by passing messages
   along atoms.

   It implements the following operation:

   .. math::

       h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\
       m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\
       h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right)  \right),

   where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and
   :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`;
   :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and
   :math:`t \in \{1, \dots, T-1\}` indexes the message passing iterations, where :math:`T`
   (``depth``) is the total number of iterations.


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, d_ed = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param d_ed: the dimension of additional edge descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_ed: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_vo, W_vd, W_eo, W_ed** -- the input, hidden, vertex output, vertex descriptor,
                edge output, and edge descriptor weight matrices, respectively, used in the message
                passing update functions. The vertex (edge) descriptor weight matrix is ``None`` if
                ``d_vd`` (``d_ed``) is not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None, nn.Module, nn.Module | None]



.. py:class:: MABBondMessagePassing(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, bias = False, depth = 3, dropout = 0.0, activation = Activation.RELU, undirected = False, d_vd = None, d_ed = None, V_d_transform = None, E_d_transform = None, graph_transform = None, return_vertex_embeddings = True, return_edge_embeddings = True)

   Bases: :py:obj:`chemprop.nn.message_passing.mixins._BondMessagePassingMixin`, :py:obj:`_MABMessagePassingBase`


   A :class:`MABBondMessagePassing` encodes a batch of molecular graphs by passing messages
   along directed bonds.

   It implements the following operation:

   .. math::

       h_{vw}^{(0)} &= \tau \left( \mathbf W_i(x_v \mathbin\Vert e_{vw}) \right) \\
       m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\
       h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\
       m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\
       h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),

   where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and
   :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the
   bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`;
   :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at
   iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond
   :math:`v \rightarrow w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` indexes the
   message passing iterations, where :math:`T` (``depth``) is the total number of iterations.


   .. py:method:: setup(d_v = DEFAULT_ATOM_FDIM, d_e = DEFAULT_BOND_FDIM, d_h = DEFAULT_HIDDEN_DIM, d_vd = None, d_ed = None, bias = False)

      Set up the weight matrices used in the message passing update functions.

      :param d_v: the vertex feature dimension
      :type d_v: int
      :param d_e: the edge feature dimension
      :type d_e: int
      :param d_h: the hidden dimension during message passing
      :type d_h: int, default=300
      :param d_vd: the dimension of additional vertex descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_vd: int | None, default=None
      :param d_ed: the dimension of additional edge descriptors that will be concatenated to the hidden
                   features before readout, if any
      :type d_ed: int | None, default=None
      :param bias: whether to add a learned bias to the matrices
      :type bias: bool, default=False

      :returns: **W_i, W_h, W_vo, W_vd, W_eo, W_ed** -- the input, hidden, vertex output, vertex descriptor,
                edge output, and edge descriptor weight matrices, respectively, used in the message
                passing update functions. The vertex (edge) descriptor weight matrix is ``None`` if
                ``d_vd`` (``d_ed``) is not supplied
      :rtype: tuple[nn.Module, nn.Module, nn.Module, nn.Module | None, nn.Module, nn.Module | None]



.. py:class:: MABMessagePassing(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`MABMessagePassing` module encodes a batch of molecular graphs
   using message passing to learn both vertex-level and edge-level hidden representations.


   .. py:attribute:: output_dims
      :type:  tuple[int | None, int | None]


   .. py:method:: forward(bmg, V_d = None, E_d = None)
      :abstractmethod:


      Encode a batch of molecular graphs.

      :param bmg: the batch of :class:`~chemprop.featurizers.molgraph.MolGraph`\s to encode
      :type bmg: BatchMolGraph
      :param V_d: an optional tensor of shape ``V x d_vd`` containing additional descriptors for each vertex
                  in the batch. These will be concatenated to the learned vertex descriptors and
                  transformed before the readout phase.
      :type V_d: Tensor | None, default=None
      :param E_d: an optional tensor of shape ``E x d_ed`` containing additional descriptors for each
                  directed edge in the batch. These will be concatenated to the learned edge descriptors
                  and transformed before the readout phase. NOTE: each bond corresponds to two directed
                  edges. If the extra descriptors are given per bond, each row must appear twice in the
                  tensor, once for each direction, e.g., via
                  ``E_d = np.repeat(E_d, repeats=2, axis=0)``.
      :type E_d: Tensor | None, default=None

      :returns: Two tensors of shape ``V x d_h`` or ``V x (d_h + d_vd)`` and ``E x d_h`` or ``E x (d_h + d_ed)``
                containing the hidden representation of each vertex and edge in the batch of graphs.
                The feature dimension depends on whether additional atom/bond descriptors were provided.
                If either the vertex or edge hidden representations are not needed, computing the
                corresponding tensor can be suppressed by setting ``return_vertex_embeddings`` or
                ``return_edge_embeddings`` to ``False`` when initializing the module.
      :rtype: tuple[Tensor | None, Tensor | None]
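
      The per-direction layout of ``E_d`` can be illustrated with plain lists. ``np.repeat`` with
      ``axis=0`` repeats each row consecutively (``[a, b] -> [a, a, b, b]``), as does
      ``torch.Tensor.repeat_interleave(2, dim=0)`` for tensors; the sketch mirrors that ordering and
      assumes the two directed edges of each bond are stored next to each other in the batch. The
      helper name ``per_direction`` is hypothetical.

      .. code-block:: python

         def per_direction(bond_descriptors):
             """Repeat each per-bond row twice, keeping the copies adjacent: [a, b] -> [a, a, b, b]."""
             return [list(row) for row in bond_descriptors for _ in range(2)]


         E_bond = [[0.1, 1.0], [0.2, 2.0]]  # 2 bonds x d_ed = 2
         E_d = per_direction(E_bond)         # 4 directed edges x d_ed = 2
         print(E_d)  # [[0.1, 1.0], [0.1, 1.0], [0.2, 2.0], [0.2, 2.0]]
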



.. py:class:: MessagePassing(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`MessagePassing` module encodes a batch of molecular graphs
   using message passing to learn vertex-level hidden representations.


   .. py:attribute:: output_dim
      :type:  int


   .. py:method:: forward(bmg, V_d = None)
      :abstractmethod:


      Encode a batch of molecular graphs.

      :param bmg: the batch of :class:`~chemprop.featurizers.molgraph.MolGraph`\s to encode
      :type bmg: BatchMolGraph
      :param V_d: an optional tensor of shape ``V x d_vd`` containing additional descriptors for each vertex
                  in the batch. These will be concatenated to the learned vertex descriptors and
                  transformed before the readout phase.
      :type V_d: Tensor | None, default=None

      :returns: a tensor of shape ``V x d_h`` or ``V x (d_h + d_vd)`` containing the hidden representation
                of each vertex in the batch of graphs. The feature dimension depends on whether
                additional vertex descriptors were provided.
      :rtype: Tensor
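
      The width of the returned tensor follows from concatenating the optional descriptors to the
      learned hidden features. A list-based sketch with a hypothetical helper; the transform that is
      applied before readout is omitted.

      .. code-block:: python

         def concat_vertex_descriptors(H, V_d=None):
             """Return V x d_h rows, or V x (d_h + d_vd) rows when V_d is given."""
             if V_d is None:
                 return [list(h) for h in H]
             if len(V_d) != len(H):
                 raise ValueError("V_d must have one row per vertex")
             return [list(h) + list(x) for h, x in zip(H, V_d)]


         H = [[0.0] * 300 for _ in range(5)]   # V = 5, d_h = 300
         V_d = [[1.0, 2.0] for _ in range(5)]  # d_vd = 2
         print(len(concat_vertex_descriptors(H, V_d)[0]))  # 302
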



.. py:class:: MulticomponentMessagePassing(blocks, n_components, shared = False)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`MulticomponentMessagePassing` performs message passing on each individual component
   of a multicomponent input, so that the representations of the components can then be
   concatenated to construct a global representation.

   :param blocks: the individual message-passing blocks for each component
   :type blocks: Sequence[MessagePassing]
   :param n_components: the number of components in each input
   :type n_components: int
   :param shared: whether one block is shared among all components of an input. If not, a separate
                  block is learned for each component.
   :type shared: bool, default=False


   .. py:attribute:: hparams


   .. py:attribute:: n_components


   .. py:attribute:: shared
      :value: False



   .. py:attribute:: blocks


   .. py:method:: __len__()


   .. py:property:: output_dim
      :type: int



   .. py:method:: forward(bmgs, V_ds)

      Encode the multicomponent inputs

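      :param bmgs: the batched molecular graphs, one :class:`BatchMolGraph` per component
      :type bmgs: Iterable[BatchMolGraph]
      :param V_ds: the optional additional vertex descriptors for each component (``None`` if absent)
      :type V_ds: Iterable[Tensor | None]

      :returns: a list of tensors of shape ``V x d_i`` containing the respective encodings of the ``i``-th
                component, where ``d_i`` is the output dimension of the ``i``-th encoder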
      :rtype: list[Tensor]
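
   The roles of ``blocks``, ``n_components``, and ``shared`` can be sketched without chemprop.
   The class below is hypothetical: plain callables stand in for :class:`MessagePassing` blocks,
   and reusing the first block for every component when ``shared=True`` is an assumption about how
   sharing is realized.

   .. code-block:: python

      from typing import Callable, Optional, Sequence

      Encoder = Callable[[object, Optional[object]], object]


      class MulticomponentSketch:
          def __init__(self, blocks: Sequence[Encoder], n_components: int, shared: bool = False):
              if not shared and len(blocks) != n_components:
                  raise ValueError("need one block per component when shared=False")
              # with sharing, the same block object encodes every component
              self.blocks = [blocks[0]] * n_components if shared else list(blocks)

          def __len__(self) -> int:
              return len(self.blocks)

          def forward(self, bmgs, V_ds):
              # one encoding per component, in input order
              return [block(bmg, V_d) for block, bmg, V_d in zip(self.blocks, bmgs, V_ds)]


      enc = MulticomponentSketch([lambda g, v: f"enc({g})"], n_components=2, shared=True)
      print(len(enc), enc.forward(["mol_a", "mol_b"], [None, None]))  # 2 ['enc(mol_a)', 'enc(mol_b)']
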



.. py:class:: MAE(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Mean absolute error (MAE) between predictions and targets, with each task's contribution scaled
   by ``task_weights``.

   As a :class:`ChempropMetric`, it is a :class:`torchmetrics.Metric`: moving metric states to the
   correct device, synchronizing them across processes, and the ``update()``/``compute()``/``reset()``
   lifecycle are handled by the torchmetrics base class.
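
   A plain-Python sketch of a task-weighted MAE. It assumes that missing targets are marked by a
   boolean mask and excluded from both the sum and the count; the exact reduction used by chemprop
   may differ, and the function name is hypothetical.

   .. code-block:: python

      def weighted_mae(preds, targets, mask, task_weights):
          """preds, targets, mask are n x t nested lists; task_weights has length t."""
          total, count = 0.0, 0
          for p_row, y_row, m_row in zip(preds, targets, mask):
              for p, y, m, w in zip(p_row, y_row, m_row, task_weights):
                  if m:  # skip missing targets
                      total += w * abs(p - y)
                      count += 1
          return total / count


      preds = [[1.0, 2.0], [3.0, 5.0]]
      targets = [[1.5, 2.0], [2.0, 0.0]]
      mask = [[True, True], [True, False]]
      print(weighted_mae(preds, targets, mask, task_weights=[1.0, 1.0]))  # 0.5
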


.. py:class:: MSE(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Mean squared error (MSE) between predictions and targets, with each task's contribution scaled
   by ``task_weights``. See :class:`ChempropMetric` for the shared metric interface.


.. py:class:: RMSE(task_weights = 1.0)

   Bases: :py:obj:`MSE`


   Root mean squared error (RMSE): the square root of the value computed by :class:`MSE`, with each
   task's contribution scaled by ``task_weights``. See :class:`ChempropMetric` for the shared metric
   interface.
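
   Because :class:`RMSE` derives from :class:`MSE`, its value is the square root of the MSE. A small
   numeric check for the unweighted case with no missing targets; the helper names are
   hypothetical.

   .. code-block:: python

      import math


      def mse(preds, targets):
          return sum((p - y) ** 2 for p, y in zip(preds, targets)) / len(preds)


      def rmse(preds, targets):
          return math.sqrt(mse(preds, targets))


      print(mse([1.0, 2.0], [3.0, 4.0]), rmse([1.0, 2.0], [3.0, 4.0]))  # 4.0 2.0
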


   .. py:method:: compute()


.. py:class:: SID(task_weights = 1.0, threshold = None, **kwargs)

   Bases: :py:obj:`ChempropMetric`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.
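
   Assuming SID denotes spectral information divergence (a symmetric Kullback-Leibler divergence between two spectra,
   each normalized to sum to 1) and that ``threshold`` clamps small predicted intensities from below, a hedged
   plain-Python sketch looks like this. The helper ``spectral_information_divergence`` is hypothetical, both
   spectra are normalized here, and strictly positive target intensities are assumed.

   .. code-block:: python

      import math
      from typing import Optional, Sequence


      def spectral_information_divergence(
          preds: Sequence[float],
          targets: Sequence[float],
          threshold: Optional[float] = None,
      ) -> float:
          """Symmetric KL divergence between two normalized spectra."""
          if threshold is not None:
              # Clamp small predicted intensities so the logarithms stay finite.
              preds = [max(p, threshold) for p in preds]
          p_total, t_total = sum(preds), sum(targets)
          p = [x / p_total for x in preds]
          t = [x / t_total for x in targets]
          # sum_i p_i log(p_i / t_i) + t_i log(t_i / p_i); requires all entries > 0.
          return sum(pi * math.log(pi / ti) + ti * math.log(ti / pi) for pi, ti in zip(p, t))

   Identical spectra give ``0.0``, and swapping the two arguments gives the same value. Without ``threshold``, a
   predicted intensity of exactly zero would make ``math.log`` fail.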


   .. py:attribute:: threshold
      :value: None



   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.



.. py:class:: BCELoss(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


.. py:class:: BinaryAccuracy(task_weights = 1.0, **kwargs)

   Bases: :py:obj:`ClassificationMixin`, :py:obj:`torchmetrics.classification.BinaryAccuracy`


   Compute accuracy for binary tasks.

   .. math::
       \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

   Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

   As input to ``forward`` and ``update`` the metric accepts the following input:

       - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating
         point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
         per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
       - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``

   As output to ``forward`` and ``compute`` the metric returns the following output:

       - ``acc`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value.
         If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
         value per sample.

   If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
   which the reduction will then be applied over instead of the sample dimension ``N``.

   :param threshold: Threshold for transforming probability to binary {0,1} predictions
   :param multidim_average: Defines how additional dimensions ``...`` should be handled. Should be one of the following:

                            - ``global``: Additional dimensions are flattened along the batch dimension
                            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
                              The statistics in this case are calculated over the additional dimensions.
   :param ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
   :param validate_args: bool indicating if input arguments and tensors should be validated for correctness.
                         Set to ``False`` for faster computations.

   Example (preds is int tensor):
       >>> from torch import tensor
       >>> from torchmetrics.classification import BinaryAccuracy
       >>> target = tensor([0, 1, 0, 1, 0, 1])
       >>> preds = tensor([0, 0, 1, 1, 0, 1])
       >>> metric = BinaryAccuracy()
       >>> metric(preds, target)
       tensor(0.6667)

   Example (preds is float tensor):
       >>> from torchmetrics.classification import BinaryAccuracy
       >>> target = tensor([0, 1, 0, 1, 0, 1])
       >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
       >>> metric = BinaryAccuracy()
       >>> metric(preds, target)
       tensor(0.6667)

   Example (multidim tensors):
       >>> from torchmetrics.classification import BinaryAccuracy
       >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
       >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
       ...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
       >>> metric = BinaryAccuracy(multidim_average='samplewise')
       >>> metric(preds, target)
       tensor([0.3333, 0.1667])
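
   The thresholding rule described above can be sketched in plain Python, without torch or torchmetrics. This is a
   hedged illustration, not the torchmetrics implementation: the helpers ``to_binary``, ``binary_accuracy`` and
   ``samplewise_accuracy`` are hypothetical, and each sample's additional dimensions are assumed to be nested lists
   of rows.

   .. code-block:: python

      import math
      from typing import List, Sequence


      def to_binary(preds: Sequence[float], threshold: float = 0.5) -> List[int]:
          """Map probabilities (or logits) to {0, 1} labels."""
          # If any value lies outside [0, 1], treat the input as logits and
          # apply a sigmoid element-wise before thresholding.
          if any(p < 0 or p > 1 for p in preds):
              preds = [1 / (1 + math.exp(-p)) for p in preds]
          return [int(p > threshold) for p in preds]


      def binary_accuracy(preds, target, threshold=0.5) -> float:
          """Fraction of positions where the thresholded prediction equals the target."""
          labels = to_binary(preds, threshold)
          return sum(p == t for p, t in zip(labels, target)) / len(target)


      def samplewise_accuracy(preds, target, threshold=0.5) -> List[float]:
          """One accuracy per sample, like ``multidim_average='samplewise'``."""
          scores = []
          for p_sample, t_sample in zip(preds, target):
              # Flatten this sample's additional dimensions (rows) before reducing.
              flat_p = [v for row in p_sample for v in row]
              flat_t = [v for row in t_sample for v in row]
              scores.append(binary_accuracy(flat_p, flat_t, threshold))
          return scores

   With the data from the float and multidim examples above, ``binary_accuracy`` returns ``4/6`` and
   ``samplewise_accuracy`` returns ``[2/6, 1/6]``, matching ``tensor(0.6667)`` and ``tensor([0.3333, 0.1667])``.
   This sketch applies the logits check to whatever list it receives, so in the samplewise case it is done per
   sample.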



.. py:class:: BinaryAUPRC(task_weights = 1.0, **kwargs)

   Bases: :py:obj:`ClassificationMixin`, :py:obj:`torchmetrics.classification.BinaryPrecisionRecallCurve`


   Compute the precision-recall curve for binary tasks.

   The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the
   tradeoff between the two values can be seen.

   As input to ``forward`` and ``update`` the metric accepts the following input:

   - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
     probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
     to be logits and will auto apply sigmoid per element.
   - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
     ground truth labels, and should therefore only contain {0,1} values (except if `ignore_index` is specified). The value
     1 always encodes the positive class.

   .. tip::
      Additional dimensions ``...`` will be flattened into the batch dimension.

   As output to ``forward`` and ``compute`` the metric returns the following output:

   - ``precision`` (:class:`~torch.Tensor`): a 1d tensor of size ``(n_thresholds+1, )`` with precision values. If
     `thresholds=None`, ``n_thresholds`` is the number of distinct thresholds computed from the data; otherwise it is
     the number of bins given by `thresholds`.
   - ``recall`` (:class:`~torch.Tensor`): a 1d tensor of size ``(n_thresholds+1, )`` with recall values.
   - ``thresholds`` (:class:`~torch.Tensor`): a 1d tensor of size ``(n_thresholds, )`` with increasing threshold
     values.

   .. note::
      The implementation supports both a non-binned but accurate version and a binned version that is less accurate
      but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
      uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting the `thresholds` argument to an integer,
      list or 1d tensor activates the binned version, which uses memory of size :math:`\mathcal{O}(n_{thresholds})`
      (constant memory).

   :param thresholds: Can be one of:

                      - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
                        all the data. This is the most accurate but also the most memory-consuming approach.
                      - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
                        0 to 1 as bins for the calculation.
                      - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
                      - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
                        bins for the calculation.
   :param ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
   :param validate_args: bool indicating if input arguments and tensors should be validated for correctness.
                         Set to ``False`` for faster computations.
   :param normalization: Specifies a normalization method that is used for batch-wise update regarding negative logits.
                         Set to ``None`` if negative logits are desired in evaluation.
   :param kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

   .. rubric:: Example

   >>> import torch
   >>> from torchmetrics.classification import BinaryPrecisionRecallCurve
   >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
   >>> target = torch.tensor([0, 1, 1, 0])
   >>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
   >>> bprc(preds, target)  # doctest: +NORMALIZE_WHITESPACE
   (tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
    tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
    tensor([0.0000, 0.5000, 0.7000, 0.8000]))
   >>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
   >>> bprc(preds, target)  # doctest: +NORMALIZE_WHITESPACE
   (tensor([0.5000, 0.6667, 0.6667, 0.0000,    nan, 1.0000]),
    tensor([1., 1., 1., 0., 0., 0.]),
    tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
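
   The binned mode (``thresholds`` set to an ``int``) can be sketched in plain Python. In this hedged sketch (the
   function name ``binned_pr_curve`` is hypothetical), a sample counts as a predicted positive when its score is
   greater than or equal to the threshold, a precision of 0/0 is reported as ``nan``, and the curve is closed with a
   final (precision=1, recall=0) point. Under these assumptions it reproduces the ``thresholds=5`` output in the
   example above.

   .. code-block:: python

      import math
      from typing import List, Sequence, Tuple


      def binned_pr_curve(
          preds: Sequence[float], target: Sequence[int], n_thresholds: int
      ) -> Tuple[List[float], List[float], List[float]]:
          """Precision and recall at ``n_thresholds`` linearly spaced thresholds in [0, 1]."""
          if n_thresholds < 2:
              raise ValueError("n_thresholds must be larger than 1")
          thresholds = [i / (n_thresholds - 1) for i in range(n_thresholds)]
          n_pos = sum(target)
          precision, recall = [], []
          for t in thresholds:
              # Confusion counts for "predict positive when score >= t".
              tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
              fp = sum(1 for p, y in zip(preds, target) if p >= t and y == 0)
              precision.append(tp / (tp + fp) if tp + fp else math.nan)
              recall.append(tp / n_pos if n_pos else math.nan)
          # Conventional end point of the curve.
          precision.append(1.0)
          recall.append(0.0)
          return precision, recall, thresholds

   A streaming variant would instead accumulate ``tp`` and ``fp`` per threshold across batches, so its memory
   depends only on ``n_thresholds`` and not on the number of samples.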


   .. py:method:: compute()

      Compute metric.



.. py:class:: BinaryAUROC(task_weights = 1.0, **kwargs)

   Bases: :py:obj:`ClassificationMixin`, :py:obj:`torchmetrics.classification.BinaryAUROC`


   Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.

   The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
   multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
   corresponds to random guessing.

   As input to ``forward`` and ``update`` the metric accepts the following input:

   - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
     each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
     sigmoid per element.
   - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
     should therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
     positive class.

   As output to ``forward`` and ``compute`` the metric returns the following output:

   - ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the AUROC score.

   Additional dimensions ``...`` will be flattened into the batch dimension.

   The implementation supports both a non-binned but accurate version and a binned version that is less accurate
   but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
   uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting the `thresholds` argument to an integer,
   list or 1d tensor activates the binned version, which uses memory of size :math:`\mathcal{O}(n_{thresholds})`
   (constant memory).

   :param max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
   :param thresholds: Can be one of:

                      - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
                        all the data. This is the most accurate but also the most memory-consuming approach.
                      - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
                        0 to 1 as bins for the calculation.
                      - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
                      - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
                        bins for the calculation.
   :param validate_args: bool indicating if input arguments and tensors should be validated for correctness.
                         Set to ``False`` for faster computations.
   :param kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

   .. rubric:: Example

   >>> from torch import tensor
   >>> from torchmetrics.classification import BinaryAUROC
   >>> preds = tensor([0, 0.5, 0.7, 0.8])
   >>> target = tensor([0, 1, 1, 0])
   >>> metric = BinaryAUROC(thresholds=None)
   >>> metric(preds, target)
   tensor(0.5000)
   >>> b_auroc = BinaryAUROC(thresholds=5)
   >>> b_auroc(preds, target)
   tensor(0.5000)
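
   For the non-binned case, the AUROC equals the probability that a randomly chosen positive is scored higher than a
   randomly chosen negative, with ties counting one half. The hedged sketch below uses that pairwise formulation with
   a hypothetical ``pairwise_auroc`` helper; it is quadratic in the number of samples and is not meant to mirror the
   torchmetrics implementation, but it gives the same value on the example above.

   .. code-block:: python

      from itertools import product
      from typing import Sequence


      def pairwise_auroc(preds: Sequence[float], target: Sequence[int]) -> float:
          """AUROC as the fraction of (positive, negative) pairs ranked correctly."""
          pos = [p for p, y in zip(preds, target) if y == 1]
          neg = [p for p, y in zip(preds, target) if y == 0]
          if not pos or not neg:
              raise ValueError("need at least one positive and one negative sample")
          # 1 point when the positive outranks the negative, 0.5 for a tie.
          score = sum(
              1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg)
          )
          return score / (len(pos) * len(neg))

   Here ``pairwise_auroc([0, 0.5, 0.7, 0.8], [0, 1, 1, 0])`` returns ``0.5``: each positive beats the negative
   scored 0 and loses to the negative scored 0.8.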


.. py:class:: BinaryF1Score(task_weights = 1.0, **kwargs)

   Bases: :py:obj:`ClassificationMixin`, :py:obj:`torchmetrics.classification.BinaryF1Score`


   Compute F-1 score for binary tasks.

   .. math::
       F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}

   The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`,
   where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
   positives and false negatives, respectively. If this case is encountered, a score of `zero_division`
   (0 or 1, default is 0) is returned.

   As input to ``forward`` and ``update`` the metric accepts the following input:

   - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
     tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
     element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
   - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``

   As output to ``forward`` and ``compute`` the metric returns the following output:

   - ``bf1s`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:

       - If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
       - If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
         value per sample.

   If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
   which the reduction will then be applied over instead of the sample dimension ``N``.

   :param threshold: Threshold for transforming probability to binary {0,1} predictions
   :param multidim_average: Defines how additional dimensions ``...`` should be handled. Should be one of the following:

                            - ``global``: Additional dimensions are flattened along the batch dimension
                            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
                              The statistics in this case are calculated over the additional dimensions.
   :param ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
   :param validate_args: bool indicating if input arguments and tensors should be validated for correctness.
                         Set to ``False`` for faster computations.
   :param zero_division: Should be `0` or `1`. The value returned when
                         :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.

   Example (preds is int tensor):
       >>> from torch import tensor
       >>> from torchmetrics.classification import BinaryF1Score
       >>> target = tensor([0, 1, 0, 1, 0, 1])
       >>> preds = tensor([0, 0, 1, 1, 0, 1])
       >>> metric = BinaryF1Score()
       >>> metric(preds, target)
       tensor(0.6667)

   Example (preds is float tensor):
       >>> from torchmetrics.classification import BinaryF1Score
       >>> target = tensor([0, 1, 0, 1, 0, 1])
       >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
       >>> metric = BinaryF1Score()
       >>> metric(preds, target)
       tensor(0.6667)

   Example (multidim tensors):
       >>> from torchmetrics.classification import BinaryF1Score
       >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
       >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
       ...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
       >>> metric = BinaryF1Score(multidim_average='samplewise')
       >>> metric(preds, target)
       tensor([0.5000, 0.0000])
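
   Expanding precision and recall in terms of counts gives
   :math:`F_1 = \frac{2\,\text{TP}}{2\,\text{TP} + \text{FP} + \text{FN}}`, whose denominator is zero only when
   :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` are all zero, the case covered by ``zero_division``.
   A hedged plain-Python sketch of that count form follows; ``binary_f1`` is a hypothetical helper that assumes
   probability inputs and does no logits handling.

   .. code-block:: python

      from typing import Sequence


      def binary_f1(
          preds: Sequence[float],
          target: Sequence[int],
          threshold: float = 0.5,
          zero_division: float = 0.0,
      ) -> float:
          """F1 from confusion counts, returning ``zero_division`` when TP = FP = FN = 0."""
          labels = [int(p > threshold) for p in preds]
          tp = sum(1 for p, y in zip(labels, target) if p == 1 and y == 1)
          fp = sum(1 for p, y in zip(labels, target) if p == 1 and y == 0)
          fn = sum(1 for p, y in zip(labels, target) if p == 0 and y == 1)
          denom = 2 * tp + fp + fn
          return 2 * tp / denom if denom else zero_division

   On the int and float examples above it returns ``2/3`` (TP = 2, FP = 1, FN = 1).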



.. py:class:: BinaryMCCLoss(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


   .. py:method:: update(preds, targets, mask = None, weights = None, *args)

      Calculate the mean loss function value given predicted and target values.

      :param preds: a tensor of shape `b x t x u` (regression with uncertainty), `b x t` (regression without
                    uncertainty and binary classification, except for binary dirichlet), or `b x t x c`
                    (multiclass classification and binary dirichlet) containing the predictions, where `b`
                    is the batch size, `t` is the number of tasks to predict, `u` is the number of values to
                    predict for each task, and `c` is the number of classes.
      :type preds: Tensor
      :param targets: a float tensor of shape `b x t` containing the target values
      :type targets: Tensor
      :param mask: a boolean tensor of shape `b x t` indicating whether the given prediction should be
                   included in the loss calculation
      :type mask: Tensor
      :param weights: a tensor of shape `b` or `b x 1` containing the per-sample weight
      :type weights: Tensor
      :param lt_mask:
      :type lt_mask: Tensor
      :param gt_mask:
      :type gt_mask: Tensor
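
      As a hedged illustration only, the Matthews correlation coefficient,
      :math:`\frac{\text{TP}\cdot\text{TN} - \text{FP}\cdot\text{FN}}{\sqrt{(\text{TP}+\text{FP})(\text{TP}+\text{FN})(\text{TN}+\text{FP})(\text{TN}+\text{FN})}}`,
      can be computed from "soft" confusion counts so that it is evaluated directly on predicted probabilities,
      skipping entries where ``mask`` is ``False``. The helper ``soft_binary_mcc`` is hypothetical; whether chemprop
      uses exactly these soft counts, and whether the loss is ``1 - MCC``, are assumptions not stated in this
      reference.

      .. code-block:: python

         import math
         from typing import Optional, Sequence


         def soft_binary_mcc(
             preds: Sequence[float],
             targets: Sequence[int],
             mask: Optional[Sequence[bool]] = None,
         ) -> float:
             """MCC from soft counts: a probability p adds p to 'predicted positive'."""
             if mask is None:
                 mask = [True] * len(targets)
             tp = fp = tn = fn = 0.0
             for p, y, keep in zip(preds, targets, mask):
                 if not keep:
                     continue  # masked-out entries contribute nothing
                 tp += p * y
                 fp += p * (1 - y)
                 fn += (1 - p) * y
                 tn += (1 - p) * (1 - y)
             denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
             # Convention assumed here: return 0.0 when any marginal count is zero.
             return (tp * tn - fp * fn) / denom if denom else 0.0


         # Hypothetical loss form: lower is better, 0 for a perfect predictor.
         loss = 1 - soft_binary_mcc([0.9, 0.2, 0.8], [1, 0, 1])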



   .. py:method:: compute()


.. py:class:: BinaryMCCMetric(task_weights = 1.0)

   Bases: :py:obj:`BinaryMCCLoss`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


   .. py:attribute:: higher_is_better
      :value: True



   .. py:method:: compute()


.. py:class:: BoundedMAE(task_weights = 1.0)

   Bases: :py:obj:`BoundedMixin`, :py:obj:`MAE`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


.. py:class:: BoundedMixin
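
   Judging from the ``lt_mask`` and ``gt_mask`` arguments of ``ChempropMetric.update()``, a plausible reading (an
   assumption, not confirmed by this reference) is that this mixin adapts a regression loss to inequality targets
   such as "< 5" or "> 5": a prediction that already satisfies the inequality is not penalized. A minimal sketch
   with a hypothetical ``bounded_mse`` helper:

   .. code-block:: python

      from typing import Sequence


      def bounded_mse(
          preds: Sequence[float],
          targets: Sequence[float],
          lt_mask: Sequence[bool],
          gt_mask: Sequence[bool],
      ) -> float:
          """Mean squared error that ignores errors on satisfied '<' / '>' targets."""
          errors = []
          for p, t, lt, gt in zip(preds, targets, lt_mask, gt_mask):
              # Assumption: lt marks a "< t" target, gt marks a "> t" target.
              if (lt and p < t) or (gt and p > t):
                  p = t  # inequality satisfied: treat as an exact hit
              errors.append((p - t) ** 2)
          return sum(errors) / len(errors)

   For example, ``bounded_mse([1.0, 5.0, 2.0], [3.0, 3.0, 3.0], [True, False, False], [False, True, False])``
   gives ``1/3``: only the third, ordinary target contributes an error.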

.. py:class:: BoundedMSE(task_weights = 1.0)

   Bases: :py:obj:`BoundedMixin`, :py:obj:`MSE`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


.. py:class:: BoundedRMSE(task_weights = 1.0)

   Bases: :py:obj:`BoundedMixin`, :py:obj:`RMSE`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather operation on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


.. py:class:: ChempropMetric(task_weights = 1.0)

   Bases: :py:obj:`torchmetrics.Metric`


   Base class for all metrics present in the Metrics API.

   This class is inherited by all metrics and implements the following functionality:

       1. Handles the transfer of metric states to the correct device.
       2. Handles the synchronization of metric states across processes.
       3. Provides properties and methods to control the overall behavior of the metric and its states.

   The three core methods of the base class are ``add_state()``, ``forward()`` and ``reset()``, which should almost
   never be overridden by child classes. Instead, child classes should override ``update()`` and
   ``compute()``.

   :param kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

                  - **compute_on_cpu**:
                      If metric state should be stored on CPU during computations. Only works for list states.
                  - **dist_sync_on_step**:
                      If metric state should synchronize on ``forward()``. Default is ``False``.
                  - **process_group**:
                      The process group on which the synchronization is called. Default is the world.
                  - **dist_sync_fn**:
                      Function that performs the allgather option on the metric state. Default is a custom
                      implementation that calls ``torch.distributed.all_gather`` internally.
                  - **distributed_available_fn**:
                      Function that checks if the distributed backend is available. Defaults to a
                      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
                  - **sync_on_compute**:
                      If metric state should synchronize when ``compute`` is called. Default is ``True``.
                  - **compute_with_cache**:
                      If results from ``compute`` should be cached. Default is ``True``.


   .. py:attribute:: is_differentiable
      :value: True



   .. py:attribute:: higher_is_better
      :value: False



   .. py:attribute:: full_state_update
      :value: False



   .. py:method:: update(preds, targets, mask = None, weights = None, lt_mask = None, gt_mask = None)

      Calculate the mean loss function value given predicted and target values.

      :param preds: a tensor of shape `b x t x u` (regression with uncertainty), `b x t` (regression without
                    uncertainty and binary classification, except for binary dirichlet), or `b x t x c`
                    (multiclass classification and binary dirichlet) containing the predictions, where `b`
                    is the batch size, `t` is the number of tasks to predict, `u` is the number of values to
                    predict for each task, and `c` is the number of classes.
      :type preds: Tensor
      :param targets: a float tensor of shape `b x t` containing the target values
      :type targets: Tensor
      :param mask: a boolean tensor of shape `b x t` indicating whether the given prediction should be
                   included in the loss calculation
      :type mask: Tensor
      :param weights: a tensor of shape `b` or `b x 1` containing the per-sample weights
      :type weights: Tensor
      :param lt_mask: a boolean tensor of shape `b x t` marking targets that are upper bounds (i.e.,
                      reported as "less than" values); only relevant to the bounded metrics
      :type lt_mask: Tensor
      :param gt_mask: a boolean tensor of shape `b x t` marking targets that are lower bounds (i.e.,
                      reported as "greater than" values); only relevant to the bounded metrics
      :type gt_mask: Tensor
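
      As a rough sketch of how ``mask``, ``weights``, and the ``task_weights`` constructor argument
      could combine into the "mean loss", the pure-Python function below assumes (our reading, not a
      confirmed description of the implementation) that each unmasked element's unreduced loss is
      multiplied by its per-sample and per-task weights, summed, and divided by the number of unmasked
      entries. The helper name ``masked_weighted_mean`` is hypothetical.

      .. code-block:: python

         def masked_weighted_mean(losses, mask, weights, task_weights):
             """Reduce a `b x t` grid of unreduced losses to one mean value (hypothetical helper).

             Assumes: sum(L * w_sample * w_task over unmasked entries) / (number of unmasked entries).
             """
             total, n = 0.0, 0
             for row_loss, row_mask, w_sample in zip(losses, mask, weights):
                 for loss, keep, w_task in zip(row_loss, row_mask, task_weights):
                     if keep:  # masked-out entries add neither loss nor count
                         total += loss * w_sample * w_task
                         n += 1
             return total / n if n else float("nan")

         # two samples, two tasks; task 2 of sample 1 is missing and therefore masked out
         losses = [[1.0, 4.0], [2.0, 3.0]]
         mask = [[True, False], [True, True]]
         print(masked_weighted_mean(losses, mask, weights=[1.0, 2.0], task_weights=[1.0, 1.0]))  # 11 / 3

      Under this assumption, dividing by the count of unmasked entries (rather than by the total weight)
      means the weights rescale each element's contribution without renormalizing the average.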



   .. py:method:: compute()


   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.



.. py:class:: ClassificationMixin(task_weights = 1.0, **kwargs)

   .. py:method:: update(preds, targets, mask, *args, **kwargs)


.. py:class:: CrossEntropyLoss(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`




.. py:class:: DirichletLoss(task_weights = 1.0, v_kl = 0.2)

   Bases: :py:obj:`ChempropMetric`


   Uses the loss function from [sensoy2018]_, based on the implementation at [sensoyGithub]_.

   .. rubric:: References

   .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. "Evidential deep learning to quantify
       classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768
   .. [sensoyGithub] https://muratsensoy.github.io/uncertainty.html#Define-the-loss-function


   .. py:attribute:: v_kl
      :value: 0.2



   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.



.. py:class:: EvidentialLoss(task_weights = 1.0, v_kl = 0.2, eps = 1e-08)

   Bases: :py:obj:`ChempropMetric`


   Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]_. See also [soleimany2021]_.

   .. rubric:: References

   .. [amini2020] Amini, A.; Schwarting, W.; Soleimany, A.; Rus, D. "Deep Evidential Regression."
       Advances in Neural Information Processing Systems, 2020, 33.
       https://proceedings.neurips.cc/paper_files/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf
   .. [soleimany2021] Soleimany, A.P.; Amini, A.; Goldman, S.; Rus, D.; Bhatia, S.N.; Coley, C.W.
       "Evidential Deep Learning for Guided Molecular Property Prediction and Discovery." ACS
       Cent. Sci. 2021, 7, 8, 1356-1367. https://doi.org/10.1021/acscentsci.1c00546


   .. py:attribute:: v_kl
      :value: 0.2



   .. py:attribute:: eps
      :value: 1e-08



   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.
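
   For orientation, the sketch below writes out the Normal-Inverse-Gamma negative log-likelihood and
   evidence regularizer of [amini2020]_ for a single target, combining them as
   ``nll + v_kl * (reg - eps)``. That combination is an assumption inferred from the ``v_kl`` and
   ``eps`` constructor arguments, and the helper name ``evidential_loss`` is hypothetical.

   .. code-block:: python

      import math

      def evidential_loss(y, gamma, nu, alpha, beta, v_kl=0.2, eps=1e-8):
          """Evidential regression loss for one target (hypothetical helper).

          gamma, nu, alpha, beta are the Normal-Inverse-Gamma parameters (nu > 0, alpha > 1, beta > 0).
          """
          omega = 2.0 * beta * (1.0 + nu)
          # negative log-likelihood of y under the NIG evidential distribution
          nll = (
              0.5 * math.log(math.pi / nu)
              - alpha * math.log(omega)
              + (alpha + 0.5) * math.log(nu * (y - gamma) ** 2 + omega)
              + math.lgamma(alpha)
              - math.lgamma(alpha + 0.5)
          )
          # regularizer: evidence (2 * nu + alpha) placed on an erroneous prediction is penalized
          reg = abs(y - gamma) * (2.0 * nu + alpha)
          return nll + v_kl * (reg - eps)

      print(evidential_loss(0.0, gamma=0.0, nu=1.0, alpha=2.0, beta=1.0))
      print(evidential_loss(3.0, gamma=0.0, nu=1.0, alpha=2.0, beta=1.0))  # larger error, larger loss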



.. py:data:: LossFunctionRegistry

.. py:data:: MetricRegistry

.. py:class:: MulticlassMCCLoss(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Calculate a soft Matthews correlation coefficient ([mccWiki]_) loss for multiclass
   classification based on the implementation of [mccSklearn]_.

   .. rubric:: References

   .. [mccWiki] https://en.wikipedia.org/wiki/Phi_coefficient#Multiclass_case
   .. [mccSklearn] https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html


   .. py:method:: update(preds, targets, mask = None, weights = None, *args)

      Calculate the mean loss function value given predicted and target values.

      :param preds: a tensor of shape `b x t x u` (regression with uncertainty), `b x t` (regression without
                    uncertainty and binary classification, except for binary dirichlet), or `b x t x c`
                    (multiclass classification and binary dirichlet) containing the predictions, where `b`
                    is the batch size, `t` is the number of tasks to predict, `u` is the number of values to
                    predict for each task, and `c` is the number of classes.
      :type preds: Tensor
      :param targets: a float tensor of shape `b x t` containing the target values
      :type targets: Tensor
      :param mask: a boolean tensor of shape `b x t` indicating whether the given prediction should be
                   included in the loss calculation
      :type mask: Tensor
      :param weights: a tensor of shape `b` or `b x 1` containing the per-sample weights
      :type weights: Tensor
      :param lt_mask: a boolean tensor of shape `b x t` marking targets that are upper bounds (i.e.,
                      reported as "less than" values); in this signature it is received through ``*args``
      :type lt_mask: Tensor
      :param gt_mask: a boolean tensor of shape `b x t` marking targets that are lower bounds (i.e.,
                      reported as "greater than" values); in this signature it is received through ``*args``
      :type gt_mask: Tensor



   .. py:method:: compute()
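
   To illustrate what "soft" means here, the sketch below applies the multiclass formula used by
   [mccSklearn]_ to predicted class probabilities instead of hard labels, so each count becomes a sum of
   probabilities; a loss can then be taken as ``1 - mcc``. The helper name and the ``1 - mcc``
   convention are assumptions for illustration, and masking and task weights are omitted.

   .. code-block:: python

      import math

      def soft_multiclass_mcc(probs, labels):
          """Soft multiclass MCC (hypothetical helper).

          probs: per-sample lists of class probabilities; labels: true class indices.
          """
          n_classes = len(probs[0])
          s = len(labels)  # number of samples
          # soft count of predictions per class and true count of samples per class
          p = [sum(row[k] for row in probs) for k in range(n_classes)]
          t = [sum(1 for y in labels if y == k) for k in range(n_classes)]
          # soft number of correct predictions: probability mass placed on the true class
          c = sum(row[y] for row, y in zip(probs, labels))
          numerator = c * s - sum(pk * tk for pk, tk in zip(p, t))
          denominator = math.sqrt((s**2 - sum(pk**2 for pk in p)) * (s**2 - sum(tk**2 for tk in t)))
          return numerator / denominator if denominator > 0 else 0.0

      labels = [0, 1, 2]
      perfect = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
      uniform = [[1 / 3, 1 / 3, 1 / 3]] * 3
      print(soft_multiclass_mcc(perfect, labels))      # 1.0
      print(1 - soft_multiclass_mcc(uniform, labels))  # about 1.0: uninformative predictions

   With one-hot probabilities the soft counts equal the hard counts, so the sketch reduces to the
   ordinary multiclass MCC.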


.. py:class:: MulticlassMCCMetric(task_weights = 1.0)

   Bases: :py:obj:`MulticlassMCCLoss`


   Calculate a soft Matthews correlation coefficient ([mccWiki]_) for multiclass classification
   based on the implementation of [mccSklearn]_. Unlike :class:`MulticlassMCCLoss`, this class
   is used as a metric, so higher values are better (``higher_is_better = True``).


   .. py:attribute:: higher_is_better
      :value: True



   .. py:method:: compute()


.. py:class:: MVELoss(task_weights = 1.0)

   Bases: :py:obj:`ChempropMetric`


   Calculate the loss using Eq. 9 from [nix1994]_.

   .. rubric:: References

   .. [nix1994] Nix, D. A.; Weigend, A. S. "Estimating the mean and variance of the target
       probability distribution." Proceedings of 1994 IEEE International Conference on Neural
       Networks, 1994. https://doi.org/10.1109/icnn.1994.374138
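
   Mean-variance estimation treats each prediction as a Gaussian with a predicted mean and variance and
   minimizes its negative log-likelihood. The sketch below assumes each prediction is a ``(mean, variance)``
   pair, consistent with :class:`MveFFN` predicting two targets per task; the helper name ``mve_loss`` is
   hypothetical, and the code shows the standard Gaussian negative log-likelihood rather than Chemprop's
   exact code.

   .. code-block:: python

      import math

      def mve_loss(mean, var, target):
          """Gaussian negative log-likelihood of one target (hypothetical helper; var > 0)."""
          # squared error, down-weighted where the model predicts a large variance ...
          sq_term = (target - mean) ** 2 / (2.0 * var)
          # ... plus a log-variance penalty that keeps the model from inflating var without limit
          log_term = 0.5 * math.log(2.0 * math.pi * var)
          return sq_term + log_term

      print(mve_loss(mean=0.0, var=1.0, target=0.0))  # 0.5 * log(2 * pi), about 0.919

   For a fixed error :math:`e = y - \mu`, this loss is smallest when the predicted variance equals
   :math:`e^2`, which is what pushes the variance output toward the squared residual.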


.. py:class:: QuantileLoss(task_weights = 1.0, alpha = 0.1)

   Bases: :py:obj:`ChempropMetric`




   .. py:attribute:: alpha
      :value: 0.1



   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.



.. py:class:: R2Score(task_weights = 1.0, **kwargs)

   Bases: :py:obj:`torchmetrics.R2Score`


   Compute the :math:`R^2` score, also known as the coefficient of determination.

   .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

   where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the residual sum of squares and
   :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. The metric can also compute
   the adjusted :math:`R^2` score, given by

   .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

   where :math:`n` is the number of samples and the parameter :math:`k` (the number of independent regressors)
   should be provided as the ``adjusted`` argument. The score is only properly defined when
   :math:`SS_{tot} \neq 0`; this condition can fail for near-constant targets, in which case a score of 0 is
   returned. By definition, the score is bounded between :math:`-\infty` and 1.0: 1.0 indicates perfect
   prediction, 0 indicates constant prediction of the target mean, and negative values indicate
   worse-than-constant prediction.
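
   The definitions above can be checked by hand. The pure-Python sketch below uses a hypothetical helper,
   ``r2_score``, that is independent of both TorchMetrics and Chemprop; unlike the library, it treats only
   exactly constant targets (rather than near-constant ones) as the undefined case.

   .. code-block:: python

      def r2_score(preds, target, adjusted=0):
          """(Adjusted) coefficient of determination for a single output (hypothetical helper)."""
          n = len(target)
          y_bar = sum(target) / n
          ss_res = sum((y - f) ** 2 for y, f in zip(target, preds))
          ss_tot = sum((y - y_bar) ** 2 for y in target)
          if ss_tot == 0:  # undefined for constant targets; fall back to 0 as described above
              return 0.0
          r2 = 1 - ss_res / ss_tot
          if adjusted:  # adjusted = k, the number of independent regressors
              r2 = 1 - (1 - r2) * (n - 1) / (n - adjusted - 1)
          return r2

      print(round(r2_score([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))  # 0.9486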

   As input to ``forward`` and ``update`` the metric accepts the following input:

   - ``preds`` (:class:`~torch.Tensor`): Predictions from model in float tensor with shape ``(N,)``
     or ``(N, M)`` (multioutput)
   - ``target`` (:class:`~torch.Tensor`): Ground truth values in float tensor with shape ``(N,)``
     or ``(N, M)`` (multioutput)

   As output of ``forward`` and ``compute`` the metric returns the following output:

   - ``r2score`` (:class:`~torch.Tensor`): A tensor with the r2 score(s)

   For multioutput inputs, the scores are by default uniformly averaged over the additional dimensions.
   See the ``multioutput`` argument to change this behavior.

   :param num_outputs: Number of outputs in multioutput setting
   :param adjusted: number of independent regressors for calculating adjusted r2 score.
   :param multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings:

                       * ``'raw_values'`` returns full set of scores
                       * ``'uniform_average'`` scores are uniformly averaged
                       * ``'variance_weighted'`` scores are weighted by their individual variances
   :param kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

   .. warning::
       Argument ``num_outputs`` in ``R2Score`` has been deprecated because it is no longer necessary and will be
       removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape
       of the input tensors.

   :raises ValueError: If ``adjusted`` parameter is not an integer larger or equal to 0.
   :raises ValueError: If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.

   Example (single output):
       >>> from torch import tensor
       >>> from torchmetrics.regression import R2Score
       >>> target = tensor([3, -0.5, 2, 7])
       >>> preds = tensor([2.5, 0.0, 2, 8])
       >>> r2score = R2Score()
       >>> r2score(preds, target)
       tensor(0.9486)

   Example (multioutput):
       >>> from torch import tensor
       >>> from torchmetrics.regression import R2Score
       >>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
       >>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
       >>> r2score = R2Score(multioutput='raw_values')
       >>> r2score(preds, target)
       tensor([0.9654, 0.9082])



   .. py:method:: update(preds, targets, mask, *args, **kwargs)

      Update state with predictions and targets.



.. py:class:: Wasserstein(task_weights = 1.0, threshold = None)

   Bases: :py:obj:`ChempropMetric`




   .. py:attribute:: threshold
      :value: None



   .. py:method:: extra_repr()

      Return the extra representation of the module.

      To print customized extra information, you should re-implement
      this method in your own modules. Both single-line and multi-line
      strings are acceptable.



.. py:class:: BinaryClassificationFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`BinaryClassificationFFNBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 1


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:method:: train_step(Z)


.. py:class:: BinaryClassificationFFNBase(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`_FFNPredictorBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


.. py:class:: BinaryDirichletFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`BinaryClassificationFFNBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 2


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:method:: train_step(Z)


.. py:class:: EvidentialFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`RegressionFFN`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 4


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:attribute:: train_step


.. py:class:: MulticlassClassificationFFN(n_classes, n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`_FFNPredictorBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 1


      the number of targets `s` to predict for each task `t`


   .. py:attribute:: n_classes


   .. py:property:: n_tasks
      :type: int


      the number of tasks `t` to predict for each input


   .. py:method:: forward(Z)


   .. py:method:: train_step(Z)


.. py:class:: MveFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`RegressionFFN`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 2


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:attribute:: train_step


.. py:class:: Predictor(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`, :py:obj:`chemprop.nn.hparams.HasHParams`


   A :class:`Predictor` is a protocol that defines a differentiable function
   :math:`f : \mathbb{R}^d \to \mathbb{R}^o`.


   .. py:attribute:: input_dim
      :type:  int

      the input dimension


   .. py:attribute:: output_dim
      :type:  int

      the output dimension


   .. py:attribute:: n_tasks
      :type:  int

      the number of tasks `t` to predict for each input


   .. py:attribute:: n_targets
      :type:  int

      the number of targets `s` to predict for each task `t`


   .. py:attribute:: criterion
      :type:  chemprop.nn.metrics.ChempropMetric

      the loss function to use for training


   .. py:attribute:: task_weights
      :type:  torch.Tensor

      the weights to apply to each task when calculating the loss


   .. py:attribute:: output_transform
      :type:  chemprop.nn.transforms.UnscaleTransform

      the transform to apply to the output of the predictor


   .. py:method:: forward(Z)
      :abstractmethod:



   .. py:method:: train_step(Z)
      :abstractmethod:



   .. py:method:: encode(Z, i)
      :abstractmethod:


      Calculate the :attr:`i`-th hidden representation

      :param Z: a tensor of shape ``n x d`` containing the input data to encode, where ``d`` is the
                input dimensionality.
      :type Z: Tensor
      :param i: The stop index of the slice of the MLP used to encode the input. That is, use all
                layers in the MLP *up to* :attr:`i` (i.e., ``MLP[:i]``). This can be any integer
                value; the behavior of this function follows the underlying list-slicing
                semantics. For example:

                * ``i=0``: use a 0-layer MLP (i.e., a no-op)
                * ``i=1``: use only the first block
                * ``i=-1``: use every block *except* the final one
      :type i: int

      :returns: a tensor of shape ``n x h`` containing the :attr:`i`-th hidden representation, where
                ``h`` is the number of neurons in the :attr:`i`-th hidden layer.
      :rtype: Tensor
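
      The slicing rule for :attr:`i` can be illustrated without Chemprop. In the sketch below, a hypothetical
      list of three plain functions stands in for the MLP blocks, and ``blocks[:i]`` is applied in order,
      mirroring the ``MLP[:i]`` description above; it is not the library's ``encode`` implementation.

      .. code-block:: python

         from functools import reduce

         # hypothetical stand-ins for three MLP blocks acting on a number
         blocks = [lambda z: z + 1, lambda z: z * 10, lambda z: z - 3]

         def encode(z, i):
             """Apply blocks[:i] in order, following Python list-slicing semantics."""
             return reduce(lambda acc, block: block(acc), blocks[:i], z)

         print(encode(2, 0))   # 2: no blocks, i.e., a no-op
         print(encode(2, 1))   # 3: first block only
         print(encode(2, -1))  # 30: every block except the last
         print(encode(2, 3))   # 27: all three blocks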



.. py:data:: PredictorRegistry

.. py:class:: QuantileFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`RegressionFFN`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 2


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:attribute:: train_step


.. py:class:: RegressionFFN(n_tasks = 1, input_dim = DEFAULT_HIDDEN_DIM, hidden_dim = 300, n_layers = 1, dropout = 0.0, activation = 'relu', criterion = None, task_weights = None, threshold = None, output_transform = None)

   Bases: :py:obj:`_FFNPredictorBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 1


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:attribute:: train_step


.. py:class:: SpectralFFN(*args, spectral_activation = 'softplus', **kwargs)

   Bases: :py:obj:`_FFNPredictorBase`


   A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an
   underlying :class:`MLP` to map the learned fingerprint to the desired output.


   .. py:attribute:: n_targets
      :value: 1


      the number of targets `s` to predict for each task `t`


   .. py:method:: forward(Z)


   .. py:attribute:: train_step


.. py:class:: GraphTransform(V_transform, E_transform)

   Bases: :py:obj:`torch.nn.Module`


   Base class for all neural network modules.

   Your models should also subclass this class.

   Modules can also contain other Modules, allowing them to be nested in
   a tree structure. You can assign the submodules as regular attributes::

       import torch.nn as nn
       import torch.nn.functional as F


       class Model(nn.Module):
           def __init__(self) -> None:
               super().__init__()
               self.conv1 = nn.Conv2d(1, 20, 5)
               self.conv2 = nn.Conv2d(20, 20, 5)

           def forward(self, x):
               x = F.relu(self.conv1(x))
               return F.relu(self.conv2(x))

   Submodules assigned in this way will be registered, and will also have their
   parameters converted when you call :meth:`to`, etc.

   .. note::
       As per the example above, an ``__init__()`` call to the parent class
       must be made before assignment on the child.

   :ivar training: Boolean represents whether this module is in training or
                   evaluation mode.
   :vartype training: bool


   .. py:attribute:: V_transform


   .. py:attribute:: E_transform


   .. py:method:: forward(bmg)


.. py:class:: ScaleTransform(mean, scale, pad = 0)

   Bases: :py:obj:`_ScaleTransformMixin`




   .. py:method:: forward(X)


.. py:class:: UnscaleTransform(mean, scale, pad = 0)

   Bases: :py:obj:`_ScaleTransformMixin`




   .. py:method:: forward(X)


   .. py:method:: transform_variance(var)


.. py:class:: Activation

   Bases: :py:obj:`chemprop.utils.utils.EnumMapping`


   Enum where members are also (and must be) strings


   .. py:attribute:: RELU


   .. py:attribute:: LEAKYRELU


   .. py:attribute:: PRELU


   .. py:attribute:: TANH


   .. py:attribute:: ELU


