chemprop.models.mol_atom_bond
=============================

.. py:module:: chemprop.models.mol_atom_bond


Attributes
----------

.. autoapisummary::

   chemprop.models.mol_atom_bond.logger


Classes
-------

.. autoapisummary::

   chemprop.models.mol_atom_bond.MolAtomBondMPNN


Module Contents
---------------

.. py:data:: logger

.. py:class:: MolAtomBondMPNN(message_passing, agg = None, mol_predictor = None, atom_predictor = None, bond_predictor = None, atom_constrainer = None, bond_constrainer = None, batch_norm = False, metrics = None, warmup_epochs = 2, init_lr = 0.0001, max_lr = 0.001, final_lr = 0.0001, X_d_transform = None)

   Bases: :py:obj:`lightning.pytorch.LightningModule`


   A :class:`MolAtomBondMPNN` is a sequence of message passing layers, an aggregation routine,
   and up to three predictor routines for molecule-level, atom-level, and bond-level predictions.

   Message passing and aggregation calculate learned node, edge, and graph embeddings from an input
   molecule graph, and the predictors take these learned fingerprints as input to calculate the
   final predictions. For the molecule-level prediction, this is the following operation:

   .. math::
       \mathtt{MPNN}(\mathcal{G}) =
           \mathtt{predictor}(\mathtt{agg}(\mathtt{message\_passing}(\mathcal{G})))

   The full model is trained end-to-end.
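
   A minimal sketch of this composition, using toy pure-Python stand-ins rather than chemprop
   modules. The functions ``message_passing``, ``agg``, and ``predictor`` below are hypothetical and
   only echo the names in the formula; one readout is reused for all three levels purely for brevity.

   .. code-block:: python

       from statistics import fmean


       # Hypothetical stand-ins; none of these are chemprop functions.
       def message_passing(graph):
           """Return toy per-atom and per-bond embeddings (one float each)."""
           atom_h = [float(z) for z in graph["atoms"]]
           bond_h = [atom_h[i] + atom_h[j] for i, j in graph["bonds"]]
           return atom_h, bond_h


       def agg(atom_h):
           """Pool the atom embeddings into one molecule embedding (mean pooling)."""
           return fmean(atom_h)


       def predictor(h):
           """Toy readout: an affine map of the embedding."""
           return 2.0 * h + 1.0


       graph = {"atoms": [6, 6, 8], "bonds": [(0, 1), (1, 2)]}
       atom_h, bond_h = message_passing(graph)

       mol_pred = predictor(agg(atom_h))            # molecule level: aggregate, then predict
       atom_preds = [predictor(h) for h in atom_h]  # atom level: predict per node
       bond_preds = [predictor(h) for h in bond_h]  # bond level: predict per edge

   Only the molecule-level path goes through ``agg``; the atom-level and bond-level paths read the
   node and edge embeddings directly.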

   :param message_passing: the message passing block to use to calculate learned fingerprints
   :type message_passing: MABMessagePassing
   :param agg: the aggregation operation to use during molecule-level prediction
   :type agg: Aggregation | None, default=None
   :param mol_predictor: the function to use to calculate the final molecule-level prediction
   :type mol_predictor: Predictor | None, default=None
   :param atom_predictor: the function to use to calculate the final atom-level prediction
   :type atom_predictor: Predictor | None, default=None
   :param bond_predictor: the function to use to calculate the final bond-level prediction
   :type bond_predictor: Predictor | None, default=None
   :param atom_constrainer: the constrainer to use to constrain the atom-level predictions
   :type atom_constrainer: ConstrainerFFN | None, default=None
   :param bond_constrainer: the constrainer to use to constrain the bond-level predictions
   :type bond_constrainer: ConstrainerFFN | None, default=None
   :param batch_norm: if `True`, apply batch normalization to the learned fingerprints before passing them to the
                      predictors
   :type batch_norm: bool, default=False
   :param metrics: the metrics to use to evaluate the model during training and evaluation
   :type metrics: Iterable[Metric] | None, default=None
   :param warmup_epochs: the number of epochs to use for the learning rate warmup
   :type warmup_epochs: int, default=2
   :param init_lr: the initial learning rate
   :type init_lr: float, default=1e-4
   :param max_lr: the maximum learning rate
   :type max_lr: float, default=1e-3
   :param final_lr: the final learning rate
   :type final_lr: float, default=1e-4

   :raises ValueError: if the output dimension of the message passing block does not match the input dimension of
       the predictor functions
   :raises ValueError: if none of the predictor functions are provided
   :raises ValueError: if `mol_predictor` is provided but `agg` is not


   .. py:attribute:: message_passing


   .. py:attribute:: agg
      :value: None



   .. py:attribute:: mol_predictor
      :value: None



   .. py:attribute:: atom_predictor
      :value: None



   .. py:attribute:: atom_constrainer
      :value: None



   .. py:attribute:: bond_predictor
      :value: None



   .. py:attribute:: bond_constrainer
      :value: None



   .. py:attribute:: predictors


   .. py:attribute:: bns


   .. py:attribute:: X_d_transform


   .. py:attribute:: metricss


   .. py:attribute:: warmup_epochs
      :value: 2



   .. py:attribute:: init_lr
      :value: 0.0001



   .. py:attribute:: max_lr
      :value: 0.001



   .. py:attribute:: final_lr
      :value: 0.0001



   .. py:property:: output_dimss
      :type: tuple[int | None, int | None, int | None]



   .. py:property:: n_taskss
      :type: tuple[int | None, int | None, int | None]



   .. py:property:: n_targetss
      :type: tuple[int | None, int | None, int | None]



   .. py:property:: criterions
      :type: tuple[chemprop.nn.ChempropMetric | None, chemprop.nn.ChempropMetric | None, chemprop.nn.ChempropMetric | None]
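
   Each of the properties above is a 3-tuple. The sketch below assumes that the entries follow the
   order molecule, atom, bond and that ``None`` marks a level with no predictor; neither assumption
   is stated on this page. ``LEVELS`` and ``active_levels`` are hypothetical names, and the values
   are made up.

   .. code-block:: python

       LEVELS = ("mol", "atom", "bond")  # assumed ordering of the three entries


       def active_levels(values):
           """Map level names to values, dropping levels whose entry is None."""
           return {level: v for level, v in zip(LEVELS, values) if v is not None}


       n_taskss = (1, None, 2)  # made-up example: no atom-level predictor
       tasks_by_level = active_levels(n_taskss)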



   .. py:method:: fingerprint(bmg, V_d = None, E_d = None, X_d = None)

      Calculate the learned fingerprints for the input molecules



   .. py:method:: encoding(bmg, V_d = None, E_d = None, X_d = None, i = -1)

      Calculate the :attr:`i`-th hidden representation



   .. py:method:: forward(bmg, V_d = None, E_d = None, X_d = None, constraints = None)

      Generate predictions for the input molecules/reactions



   .. py:method:: training_step(batch, batch_idx)

      Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
      logger.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
                  automatic optimization.
                - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
                  multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
                  the loss is not required.

      In this step you'd normally do the forward pass and calculate the loss for a batch.
      You can also do fancier things like multiple forward passes or something model specific.

      Example::

          def training_step(self, batch, batch_idx):
              x, y, z = batch
              out = self.encoder(x)
              loss = self.loss(out, x)
              return loss

      To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

      .. code-block:: python

          def __init__(self):
              super().__init__()
              self.automatic_optimization = False


          # Multiple optimizers (e.g.: GANs)
          def training_step(self, batch, batch_idx):
              opt1, opt2 = self.optimizers()

              # do training_step with encoder
              ...
              opt1.step()
              # do training_step with decoder
              ...
              opt2.step()

      .. note::

         When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
         normalized by ``accumulate_grad_batches`` internally.
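
      As a rough illustration of the normalization described in the note, the arithmetic below
      divides each returned loss by ``accumulate_grad_batches``. The loss values are made up, and
      this is plain arithmetic, not Lightning code.

      .. code-block:: python

          accumulate_grad_batches = 4
          batch_losses = [0.8, 1.2, 1.0, 0.6]  # made-up losses returned by four training steps

          # Each loss is scaled before its gradients are accumulated.
          scaled = [loss / accumulate_grad_batches for loss in batch_losses]

          # The accumulated quantity therefore equals the mean loss over the four batches.
          effective_loss = sum(scaled)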



   .. py:method:: on_validation_model_eval()

      Called when the validation loop starts.

      The validation loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
      to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_train`.




   .. py:method:: validation_step(batch, batch_idx = 0)

      Operates on a single batch of data from the validation set. In this step you might generate examples or
      calculate anything of interest like accuracy.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
                - ``None`` - Skip to the next batch.

      .. code-block:: python

          # if you have one val dataloader:
          def validation_step(self, batch, batch_idx): ...


          # if you have multiple val dataloaders:
          def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

      Examples::

          # CASE 1: A single validation dataset
          def validation_step(self, batch, batch_idx):
              x, y = batch

              # implement your own
              out = self(x)
              loss = self.loss(out, y)

              # log 6 example images
              # or generated text... or whatever
              sample_imgs = x[:6]
              grid = torchvision.utils.make_grid(sample_imgs)
              self.logger.experiment.add_image('example_images', grid, 0)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs!
              self.log_dict({'val_loss': loss, 'val_acc': val_acc})

      If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
      setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

      .. code-block:: python

          # CASE 2: multiple validation dataloaders
          def validation_step(self, batch, batch_idx, dataloader_idx=0):
              # dataloader_idx tells you which dataset this is.
              x, y = batch

              # implement your own
              out = self(x)

              if dataloader_idx == 0:
                  loss = self.loss0(out, y)
              else:
                  loss = self.loss1(out, y)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs separately for each dataloader
              self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

      .. note:: If you don't need to validate you don't need to implement this method.

      .. note::

         When the :meth:`validation_step` is called, the model has been put in eval mode
         and PyTorch gradients have been disabled. At the end of validation,
         the model goes back to training mode and gradients are enabled.



   .. py:method:: test_step(batch, batch_idx = 0)

      Operates on a single batch of data from the test set. In this step you'd normally generate examples or
      calculate anything of interest such as accuracy.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
                - ``None`` - Skip to the next batch.

      .. code-block:: python

          # if you have one test dataloader:
          def test_step(self, batch, batch_idx): ...


          # if you have multiple test dataloaders:
          def test_step(self, batch, batch_idx, dataloader_idx=0): ...

      Examples::

          # CASE 1: A single test dataset
          def test_step(self, batch, batch_idx):
              x, y = batch

              # implement your own
              out = self(x)
              loss = self.loss(out, y)

              # log 6 example images
              # or generated text... or whatever
              sample_imgs = x[:6]
              grid = torchvision.utils.make_grid(sample_imgs)
              self.logger.experiment.add_image('example_images', grid, 0)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs!
              self.log_dict({'test_loss': loss, 'test_acc': test_acc})

      If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
      setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

      .. code-block:: python

          # CASE 2: multiple test dataloaders
          def test_step(self, batch, batch_idx, dataloader_idx=0):
              # dataloader_idx tells you which dataset this is.
              x, y = batch

              # implement your own
              out = self(x)

              if dataloader_idx == 0:
                  loss = self.loss0(out, y)
              else:
                  loss = self.loss1(out, y)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs separately for each dataloader
              self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

      .. note:: If you don't need to test you don't need to implement this method.

      .. note::

         When the :meth:`test_step` is called, the model has been put in eval mode and
         PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
         to training mode and gradients are enabled.



   .. py:method:: predict_step(batch, batch_idx, dataloader_idx = 0)

      Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
      :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.

      The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
      to scale inference on multi-devices.

      To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
      callback to write the predictions to disk or database after each batch or on epoch end.

      The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
      based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
      or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns: Predicted output (optional).

      Example ::

          class MyModel(LightningModule):

              def predict_step(self, batch, batch_idx, dataloader_idx=0):
                  return self(batch)

          dm = ...
          model = MyModel()
          trainer = Trainer(accelerator="gpu", devices=2)
          predictions = trainer.predict(model, dm)




   .. py:method:: configure_optimizers()

      Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one.
      But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in
      the manual optimization mode.

      :returns: Any of these 5 options.

                - **Single optimizer**.
                - **List or Tuple** of optimizers.
                - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
                  (or multiple ``lr_scheduler_config``).
                - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
                  key whose value is a single LR scheduler or ``lr_scheduler_config``.
                - **None** - Fit will run without any optimizer.

      The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
      The default configuration is shown below.

      .. code-block:: python

          lr_scheduler_config = {
              # REQUIRED: The scheduler instance
              "scheduler": lr_scheduler,
              # The unit of the scheduler's step size, could also be 'step'.
              # 'epoch' updates the scheduler on epoch end whereas 'step'
              # updates it after an optimizer update.
              "interval": "epoch",
              # How many epochs/steps should pass between calls to
              # `scheduler.step()`. 1 corresponds to updating the learning
              # rate after every epoch/step.
              "frequency": 1,
              # Metric to monitor for schedulers like `ReduceLROnPlateau`
              "monitor": "val_loss",
              # If set to `True`, will enforce that the value specified in 'monitor'
              # is available when the scheduler is updated, thus stopping
              # training if not found. If set to `False`, it will only produce a warning
              "strict": True,
              # If using the `LearningRateMonitor` callback to monitor the
              # learning rate progress, this keyword can be used to specify
              # a custom logged name
              "name": None,
          }

      When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the
      :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the
      ``lr_scheduler_config`` contains the keyword ``"monitor"`` set to the metric name that the scheduler
      should be conditioned on.

      .. testcode::

          # The ReduceLROnPlateau scheduler requires a monitor
          def configure_optimizers(self):
              optimizer = Adam(...)
              return {
                  "optimizer": optimizer,
                  "lr_scheduler": {
                      "scheduler": ReduceLROnPlateau(optimizer, ...),
                      "monitor": "metric_to_track",
                      "frequency": "indicates how often the metric is updated",
                      # If "monitor" references validation metrics, then "frequency" should be set to a
                      # multiple of "trainer.check_val_every_n_epoch".
                  },
              }


          # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
          def configure_optimizers(self):
              optimizer1 = Adam(...)
              optimizer2 = SGD(...)
              scheduler1 = ReduceLROnPlateau(optimizer1, ...)
              scheduler2 = LambdaLR(optimizer2, ...)
              return (
                  {
                      "optimizer": optimizer1,
                      "lr_scheduler": {
                          "scheduler": scheduler1,
                          "monitor": "metric_to_track",
                      },
                  },
                  {"optimizer": optimizer2, "lr_scheduler": scheduler2},
              )

      Metrics can be made available to monitor by simply logging them using
      ``self.log('metric_to_track', metric_val)`` in your :class:`~lightning.pytorch.core.LightningModule`.

      .. note::

         Some things to know:
         
         - Lightning calls ``.backward()`` and ``.step()`` automatically in case of automatic optimization.
         - If a learning rate scheduler is specified in ``configure_optimizers()`` with key
           ``"interval"`` (default "epoch") in the scheduler configuration, Lightning will call
           the scheduler's ``.step()`` method automatically in case of automatic optimization.
         - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizer.
         - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
         - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them
           yourself.
         - If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook.
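
      The constructor parameters ``warmup_epochs``, ``init_lr``, ``max_lr``, and ``final_lr`` suggest
      a schedule that ramps up and then decays. This page does not state its exact shape, so the
      sketch below assumes a linear warmup followed by an exponential decay. ``lr_at_step``,
      ``warmup_steps``, and ``total_steps`` are hypothetical names introduced for illustration.

      .. code-block:: python

          def lr_at_step(step, warmup_steps, total_steps, init_lr=1e-4, max_lr=1e-3, final_lr=1e-4):
              """Hypothetical warmup-then-decay schedule; the shape is an assumption."""
              if step < warmup_steps:
                  # Linear ramp from init_lr toward max_lr.
                  return init_lr + (max_lr - init_lr) * step / warmup_steps
              # Exponential decay from max_lr down to final_lr over the remaining steps.
              decay_steps = max(total_steps - warmup_steps, 1)
              gamma = (final_lr / max_lr) ** (1 / decay_steps)
              return max_lr * gamma ** (step - warmup_steps)


          # Learning rate at every step of a made-up run: 10 warmup steps, 110 steps in total.
          schedule = [lr_at_step(s, warmup_steps=10, total_steps=110) for s in range(111)]

      Under these assumptions the rate starts at ``init_lr``, peaks at ``max_lr`` when the warmup
      ends, and reaches ``final_lr`` at the last step.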



   .. py:method:: load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)
      :classmethod:


      Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
      passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.

      Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
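
      The override rule can be pictured as a dictionary merge in which later keys win. This is only
      an illustration with made-up values; ``saved`` stands in for the ``"hyper_parameters"`` entry of
      a checkpoint, and Lightning's actual loading logic involves more steps.

      .. code-block:: python

          # Stand-in for the arguments stored under "hyper_parameters" in a checkpoint.
          saved = {"warmup_epochs": 2, "max_lr": 1e-3}

          # Stand-in for the **kwargs passed to load_from_checkpoint.
          overrides = {"max_lr": 5e-4}

          # Keys in overrides replace the saved values; all other saved values are kept.
          init_kwargs = {**saved, **overrides}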

      :param checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
      :param map_location: If your checkpoint saved a GPU model and you now load on CPUs
                           or a different number of GPUs, use this to map to the new setup.
                           The behaviour is the same as in :func:`torch.load`.
      :param hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
                           as in this example::

                               drop_prob: 0.2
                               dataloader:
                                   batch_size: 32

                           You most likely won't need this since Lightning will always save the hyperparameters
                           to the checkpoint.
                           However, if your checkpoint weights don't have the hyperparameters saved,
                           use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
                           These will be converted into a :class:`~dict` and passed into your
                           :class:`LightningModule` for use.

                           If your model's ``hparams`` argument is :class:`~argparse.Namespace`
                           and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
                           ``hparams`` as :class:`~dict`.
      :param strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
                     returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is
                     set, in which case it defaults to the value of ``LightningModule.strict_loading``.
      :param weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
                           primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
                           ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
                           ``weights_only=True``. For more information, please refer to the
                           `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
      :param \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
                        hyperparameter values.

      :returns: :class:`LightningModule` instance with loaded weights and hyperparameters (if available).

      .. note::

         ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
         **class** to call it instead of the :class:`LightningModule` instance, or a
         ``TypeError`` will be raised.

      .. note::

         To ensure all layers can be loaded from the checkpoint, this function will call
         :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
         model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
         not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
         case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.

      Example::

          # load weights without mapping ...
          model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

          # or load weights mapping all weights from GPU 1 to GPU 0 ...
          map_location = {'cuda:1':'cuda:0'}
          model = MyLightningModule.load_from_checkpoint(
              'path/to/checkpoint.ckpt',
              map_location=map_location
          )

          # or load weights and hyperparameters from separate files.
          model = MyLightningModule.load_from_checkpoint(
              'path/to/checkpoint.ckpt',
              hparams_file='/path/to/hparams_file.yaml'
          )

          # override some of the params with new values
          model = MyLightningModule.load_from_checkpoint(
              PATH,
              num_layers=128,
              pretrained_ckpt_path=NEW_PATH,
          )

          # predict
          model.eval()
          model.freeze()
          y_hat = model(x)




   .. py:method:: load_from_file(model_path, map_location=None, strict=True, **submodules)
      :classmethod:



