Source code for chemprop.nn.message_passing.multi

import logging
from typing import Iterable, Sequence

from torch import Tensor, nn

from chemprop.data import BatchMolGraph
from chemprop.nn.hparams import HasHParams
from chemprop.nn.message_passing.proto import MessagePassing

# Module-level logger, registered under this module's dotted name (`__name__`).
logger = logging.getLogger(__name__)


class MulticomponentMessagePassing(nn.Module, HasHParams):
    """Apply a message-passing block to every component of a multicomponent input.

    Each component is encoded by the block at the same position in :attr:`blocks`.
    :meth:`forward` returns the per-component encodings as a list; it does *not*
    concatenate them. :attr:`output_dim` reports the summed width of all blocks, i.e.
    the width a caller obtains by concatenating the encodings along the feature axis.

    Parameters
    ----------
    blocks : Sequence[MessagePassing]
        the individual message-passing blocks for each input
    n_components : int
        the number of components in each input
    shared : bool, default=False
        whether one block will be shared among all components in an input. If not, a
        separate block will be learned for each component.

    Raises
    ------
    ValueError
        if ``blocks`` is empty, or if ``shared`` is False and ``len(blocks)`` differs
        from ``n_components``
    """

    def __init__(self, blocks: Sequence[MessagePassing], n_components: int, shared: bool = False):
        super().__init__()
        # NOTE: the hyperparameters record *every* supplied block, even when `shared` is
        # True and only the 0th block ends up being used below.
        self.hparams = {
            "cls": self.__class__,
            "blocks": [block.hparams for block in blocks],
            "n_components": n_components,
            "shared": shared,
        }

        if len(blocks) == 0:
            raise ValueError("arg 'blocks' was empty!")

        if shared and len(blocks) > 1:
            # Extra blocks are tolerated in shared mode, but ignored.
            logger.warning(
                "More than 1 block was supplied but 'shared' was True! Using only the 0th block..."
            )
        elif not shared and len(blocks) != n_components:
            raise ValueError(
                "arg 'n_components' must be equal to `len(blocks)` if 'shared' is False! "
                f"got: {n_components} and {len(blocks)}, respectively."
            )

        self.n_components = n_components
        self.shared = shared
        # In shared mode the *same* module object is listed once per component, so all
        # components are encoded by (and update) a single block.
        self.blocks = nn.ModuleList([blocks[0]] * self.n_components if shared else blocks)

    def __len__(self) -> int:
        """Return the number of entries in :attr:`blocks` (one per component)."""
        return len(self.blocks)

    @property
    def output_dim(self) -> int:
        """The sum of the output dimensions of all entries in :attr:`blocks`.

        A shared block is counted once per component because it appears that many
        times in :attr:`blocks`.
        """
        d_o = sum(block.output_dim for block in self.blocks)

        return d_o

    def forward(
        self, bmgs: Iterable[BatchMolGraph], V_ds: Iterable[Tensor | None] | None
    ) -> list[Tensor]:
        """Encode the multicomponent inputs

        Parameters
        ----------
        bmgs : Iterable[BatchMolGraph]
            the batched graph of each component, paired positionally with
            :attr:`blocks`
        V_ds : Iterable[Tensor | None] | None
            the per-component second argument forwarded to each block (presumably
            additional atom descriptors; confirm against ``MessagePassing.forward``).
            If ``None``, every block is called with its graph only.

        Returns
        -------
        list[Tensor]
            a list of tensors of shape `V x d_i` containing the respective encodings of
            the `i`-th component, where `d_i` is the output dimension of the `i`-th
            encoder

        Notes
        -----
        The inputs are paired with the blocks via :func:`zip`, which stops at the
        shortest iterable: surplus graphs, descriptors or blocks are silently ignored.
        """
        if V_ds is None:
            return [block(bmg) for block, bmg in zip(self.blocks, bmgs)]
        else:
            return [block(bmg, V_d) for block, bmg, V_d in zip(self.blocks, bmgs, V_ds)]