chemprop.models.multi
=====================

.. py:module:: chemprop.models.multi


Attributes
----------

.. autoapisummary::

   chemprop.models.multi.logger


Classes
-------

.. autoapisummary::

   chemprop.models.multi.MulticomponentMPNN


Module Contents
---------------

.. py:data:: logger

.. py:class:: MulticomponentMPNN(message_passing, agg, predictor, batch_norm = False, metrics = None, warmup_epochs = 2, init_lr = 0.0001, max_lr = 0.001, final_lr = 0.0001, X_d_transform = None)

   Bases: :py:obj:`chemprop.models.model.MPNN`


   An :class:`MPNN` is a sequence of message passing layers, an aggregation routine, and a
   predictor routine.

   The first two modules calculate learned fingerprints from an input molecule or
   reaction graph, and the final module takes these fingerprints as input to calculate the
   final prediction. In other words, the model computes:

   .. math::
       \mathtt{MPNN}(\mathcal{G}) =
           \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))

   The full model is trained end-to-end.
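
   The composition above can be illustrated with a toy, pure-Python sketch. Every function
   here is a hypothetical stand-in rather than chemprop's API: ``message_passing`` performs a
   single round of summing neighbor features, ``agg`` is a sum over atoms, and ``predictor`` is
   a fixed linear layer.

   .. code-block:: python

      from typing import Dict, List, Tuple

      Vector = List[float]


      def message_passing(features: Dict[int, Vector], edges: List[Tuple[int, int]]) -> Dict[int, Vector]:
          """One toy round: each atom's hidden state is its own features plus its neighbors' features."""
          hidden = {v: list(x) for v, x in features.items()}
          for u, v in edges:  # undirected bonds: pass a message in both directions
              hidden[u] = [a + b for a, b in zip(hidden[u], features[v])]
              hidden[v] = [a + b for a, b in zip(hidden[v], features[u])]
          return hidden


      def agg(hidden: Dict[int, Vector]) -> Vector:
          """Sum aggregation: collapse atom-level hidden states into one molecule-level fingerprint."""
          dim = len(next(iter(hidden.values())))
          return [sum(h[i] for h in hidden.values()) for i in range(dim)]


      def predictor(fingerprint: Vector, weights: Vector, bias: float) -> float:
          """A single linear output layer applied to the fingerprint."""
          return sum(w * f for w, f in zip(weights, fingerprint)) + bias


      # A three-atom chain 0-1-2 with 2-dimensional atom features.
      features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
      edges = [(0, 1), (1, 2)]
      fp = agg(message_passing(features, edges))  # fp == [4.0, 5.0]
      y = predictor(fp, weights=[0.5, -0.25], bias=0.1)

   In the real model all three stages are learned modules optimized jointly; here they are
   fixed functions so that the data flow through the composition is easy to follow.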

   :param message_passing: the message passing block to use to calculate learned fingerprints
   :type message_passing: MessagePassing
   :param agg: the aggregation operation to use during molecule-level prediction
   :type agg: Aggregation
   :param predictor: the function to use to calculate the final prediction
   :type predictor: Predictor
   :param batch_norm: if ``True``, apply batch normalization to the output of the aggregation operation
   :type batch_norm: bool, default=False
   :param metrics: the metrics to use to evaluate the model during training and evaluation
   :type metrics: Iterable[Metric] | None, default=None
   :param warmup_epochs: the number of epochs to use for the learning rate warmup
   :type warmup_epochs: int, default=2
   :param init_lr: the initial learning rate
   :type init_lr: float, default=1e-4
   :param max_lr: the maximum learning rate
   :type max_lr: float, default=1e-3
   :param final_lr: the final learning rate
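   :type final_lr: float, default=1e-4
   :param X_d_transform: an optional transform applied to the extra descriptors ``X_d``; if ``None``, the descriptors are used unchanged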
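
   The four learning-rate parameters above suggest a warmup-then-decay schedule. The sketch
   below assumes a Noam-style shape (linear warmup from ``init_lr`` to ``max_lr`` over
   ``warmup_epochs``, then exponential decay to ``final_lr``) and evaluates it per epoch for
   simplicity; the schedule chemprop actually builds may differ in detail (for example, it may
   step per batch). ``lr_at_epoch`` and its ``total_epochs`` argument are hypothetical.

   .. code-block:: python

      def lr_at_epoch(
          epoch: float,
          total_epochs: int,
          warmup_epochs: int = 2,
          init_lr: float = 1e-4,
          max_lr: float = 1e-3,
          final_lr: float = 1e-4,
      ) -> float:
          """Hypothetical Noam-style schedule: linear warmup, then exponential decay."""
          if epoch < warmup_epochs:
              # Warmup: ramp linearly from init_lr up to max_lr.
              return init_lr + (max_lr - init_lr) * epoch / warmup_epochs
          decay_epochs = total_epochs - warmup_epochs
          if decay_epochs <= 0 or epoch >= total_epochs:
              # No room left for decay (or training is over): hold at final_lr.
              return final_lr
          # Decay: multiply by a constant factor per epoch so max_lr reaches final_lr at total_epochs.
          gamma = (final_lr / max_lr) ** (1 / decay_epochs)
          return max_lr * gamma ** (epoch - warmup_epochs)

   With the defaults and ``total_epochs=10``, the rate starts at ``1e-4``, peaks at ``1e-3``
   after two epochs, and decays geometrically back to ``1e-4``.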

   :raises ValueError: if the output dimension of the message passing block does not match the input dimension of
       the predictor function
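
   For a multicomponent model, the dimension check described by the ``ValueError`` above can
   be pictured as follows. This assumes that the per-component fingerprints are concatenated,
   so the message passing output dimension is the sum of the per-component sizes;
   ``check_predictor_input_dim`` is a hypothetical helper, not part of chemprop.

   .. code-block:: python

      from typing import Sequence


      def check_predictor_input_dim(component_dims: Sequence[int], predictor_input_dim: int) -> int:
          """Hypothetical check that the predictor accepts the combined fingerprint.

          Assumes per-component fingerprints are concatenated, so their sizes add up.
          """
          mp_output_dim = sum(component_dims)
          if mp_output_dim != predictor_input_dim:
              raise ValueError(
                  f"message passing output dim ({mp_output_dim}) does not match "
                  f"predictor input dim ({predictor_input_dim})"
              )
          return mp_output_dim

   For example, two components with 300-dimensional fingerprints would need a predictor that
   accepts 600 input features.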


   .. py:attribute:: message_passing
      :type:  chemprop.nn.MulticomponentMessagePassing


   .. py:method:: fingerprint(bmgs, V_ds, X_d = None)

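      Calculate the learned fingerprints for the input molecules.

      :param bmgs: the batched molecular graphs, one per component
      :param V_ds: the extra atom-level descriptors for each component
      :param X_d: optional extra datapoint-level descriptors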
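
      As a hedged illustration of how several components could yield a single fingerprint,
      the sketch below assumes each component's aggregated fingerprint is concatenated in
      component order, with the optional extra descriptors ``X_d`` appended last. It works on
      plain lists for a single datapoint rather than on batched graphs and tensors, and
      ``multicomponent_fingerprint`` is a hypothetical helper.

      .. code-block:: python

         from typing import List, Optional, Sequence


         def multicomponent_fingerprint(
             component_fps: Sequence[Sequence[float]],
             x_d: Optional[Sequence[float]] = None,
         ) -> List[float]:
             """Join per-component fingerprints (and optional extra descriptors) into one vector."""
             fp: List[float] = []
             for comp in component_fps:  # one fingerprint per component, kept in input order
                 fp.extend(comp)
             if x_d is not None:
                 fp.extend(x_d)  # extra datapoint-level descriptors go at the end
             return fp


         # Two components (sizes 2 and 1) plus one extra descriptor.
         combined = multicomponent_fingerprint([[1.0, 2.0], [3.0]], x_d=[0.5])

      Component order matters under this assumption: the predictor's weights are tied to
      positions in the combined vector, so components must be supplied in a consistent order.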



   .. py:method:: on_validation_model_eval()

      Called when the validation loop starts.

      The validation loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
      to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_train`.




   .. py:method:: get_batch_size(batch)


