chemprop.nn.message_passing

Submodules

Package Contents

Classes

MessagePassing

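A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

AtomMessagePassing

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

BondMessagePassing

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

MulticomponentMessagePassing

A MulticomponentMessagePassing performs message passing on each individual component of a multicomponent input, then concatenates the representations of the components to construct a global representation.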

class chemprop.nn.message_passing.MessagePassing(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

input_dim: int
output_dim: int
abstract forward(bmg, V_d=None)

Encode a batch of molecular graphs.

Parameters:
  • bmg (BatchMolGraph) – the batch of MolGraphs to encode

  • V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each atom in the batch. These will be concatenated to the learned atomic descriptors and transformed before the readout phase.

Returns:

a tensor of shape V x d_h or V x (d_h + d_vd) containing the hidden representation of each vertex in the batch of graphs. The feature dimension is d_h + d_vd if V_d is provided and d_h otherwise.

Return type:

Tensor
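The shape contract of forward can be illustrated with a small NumPy sketch. It is not chemprop code: hidden stands in for the learned V x d_h atom representations, encode_output is a hypothetical helper, W_d is assumed to be a square matrix (consistent with the V x (d_h + d_vd) output shape stated above), and applying ReLU after it is an assumption.

```python
import numpy as np

def encode_output(hidden, V_d=None, W_d=None):
    """Return V x d_h, or V x (d_h + d_vd) when per-atom descriptors are given.

    hidden: (V, d_h) learned atom representations (stand-in for the encoder output)
    V_d:    optional (V, d_vd) extra atom descriptors
    W_d:    hypothetical (d_h + d_vd, d_h + d_vd) descriptor transform
    """
    if V_d is None:
        return hidden
    if V_d.shape[0] != hidden.shape[0]:
        raise ValueError("V_d must have one row per atom in the batch")
    H = np.concatenate([hidden, V_d], axis=1)  # (V, d_h + d_vd)
    return np.maximum(H @ W_d, 0.0)            # square map keeps the width

rng = np.random.default_rng(0)
V, d_h, d_vd = 5, 8, 3
hidden = rng.normal(size=(V, d_h))
V_d = rng.normal(size=(V, d_vd))
W_d = rng.normal(size=(d_h + d_vd, d_h + d_vd))

print(encode_output(hidden).shape)            # (5, 8)
print(encode_output(hidden, V_d, W_d).shape)  # (5, 11)
```

The row check mirrors the requirement that V_d holds exactly one descriptor vector per atom in the batch.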

class chemprop.nn.message_passing.AtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: _MessagePassingBase

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

It implements the following operation:

\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message-passing iterations; and \(T\) is the final step, at which the atom representations are read out.
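To make the equations concrete, here is a NumPy sketch that follows them literally at the atom level. It is an illustration, not chemprop's implementation (which operates on batched PyTorch tensors). It assumes ReLU for \(\tau\) and row-vector features, so \(\mathbf{W} x\) is written x @ W; the function name atom_message_passing and the argument T are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def atom_message_passing(X, edges, E, W_i, W_h, W_o, T=3):
    """Literal, atom-level reading of the atom-based update equations.

    X:     (V, d_v) atom features x_v
    edges: directed pairs (u, v); each bond appears once in each direction
    E:     (len(edges), d_e) bond features e_uv, one row per directed pair
    W_i: (d_v, d_h), W_h: (d_h + d_e, d_h), W_o: (d_v + d_h, d_h)
    """
    V, d_h, d_e = X.shape[0], W_i.shape[1], E.shape[1]
    h0 = relu(X @ W_i)                        # h_v^(0) = tau(W_i x_v)
    h = h0
    for _ in range(1, T):                     # iterations t = 1, ..., T-1
        m = np.zeros((V, d_h + d_e))
        for k, (u, v) in enumerate(edges):    # m_v^(t) = sum_u h_u^(t-1) || e_uv
            m[v] += np.concatenate([h[u], E[k]])
        h = relu(h0 + m @ W_h)                # h_v^(t) = tau(h_v^(0) + W_h m_v^(t))
    m_T = np.zeros((V, d_h))
    for u, v in edges:                        # m_v^(T) = sum_w h_w^(T-1)
        m_T[v] += h[u]
    return relu(np.concatenate([X, m_T], axis=1) @ W_o)  # h_v^(T)

# Three heavy atoms in a chain (0-1-2), e.g. the C-C-O skeleton of ethanol
rng = np.random.default_rng(0)
d_v, d_e, d_h = 4, 2, 6
X = rng.normal(size=(3, d_v))
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
E = np.repeat(rng.normal(size=(2, d_e)), 2, axis=0)  # same features in both directions
W_i = rng.normal(size=(d_v, d_h))
W_h = rng.normal(size=(d_h + d_e, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))
out = atom_message_passing(X, edges, E, W_i, W_h, W_o)
print(out.shape)  # (3, 6)
```

The explicit Python loops keep the correspondence with each summation visible; a practical implementation would replace them with a vectorized scatter-sum.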

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message-passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message-passing update functions. W_d is None if d_vd is not supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
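The shapes of the returned matrices are not listed above. The sketch below reads them off the atom-based equations as (in_features, out_features) pairs, as for a linear layer. This is an inference from the equations, not a statement of chemprop's exact layer definitions, and atom_mp_weight_shapes is a hypothetical helper.

```python
def atom_mp_weight_shapes(d_v, d_e, d_h=300, d_vd=None):
    """(in_features, out_features) of each matrix, inferred from the atom-based equations."""
    shapes = {
        "W_i": (d_v, d_h),        # h_v^(0) = tau(W_i x_v)
        "W_h": (d_h + d_e, d_h),  # the message is h_u || e_uv
        "W_o": (d_v + d_h, d_h),  # the readout input is x_v || m_v
        "W_d": None,              # no descriptor transform without d_vd
    }
    if d_vd is not None:
        shapes["W_d"] = (d_h + d_vd, d_h + d_vd)  # keeps the V x (d_h + d_vd) width
    return shapes

print(atom_mp_weight_shapes(d_v=10, d_e=4, d_h=16, d_vd=3))
# {'W_i': (10, 16), 'W_h': (20, 16), 'W_o': (26, 16), 'W_d': (19, 19)}
```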

initialize(bmg)

Initialize the message-passing scheme by calculating the initial matrix of hidden features.

Parameters:

bmg (chemprop.data.BatchMolGraph)

Return type:

torch.Tensor

message(H, bmg)

Calculate the message matrix.
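For the atom-based update, computing the message matrix amounts to a scatter-sum over directed edges. A NumPy sketch of that aggregation, assuming an edge_index layout with sources in row 0 and destinations in row 1 (a common convention, assumed here); aggregate_messages is a hypothetical helper, not chemprop's method.

```python
import numpy as np

def aggregate_messages(H, E, edge_index, n_atoms):
    """Scatter-sum h_u || e_uv over every directed edge u -> v into row v.

    H:          (V, d_h) hidden features, one row per atom
    E:          (n_edges, d_e) features of each directed edge
    edge_index: (2, n_edges) ints; row 0 holds sources u, row 1 destinations v
    """
    src, dst = edge_index
    per_edge = np.concatenate([H[src], E], axis=1)  # (n_edges, d_h + d_e)
    M = np.zeros((n_atoms, per_edge.shape[1]))
    np.add.at(M, dst, per_edge)  # accumulates correctly even when dst repeats
    return M

H = np.arange(6.0).reshape(3, 2)
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
E = np.ones((4, 1))
print(aggregate_messages(H, E, edge_index, n_atoms=3))
# rows: [2, 3, 1], [4, 6, 2], [2, 3, 1]
```

Writing M[dst] += per_edge instead would silently keep only one contribution per repeated destination (atom 1 receives two messages here), which is why np.add.at is used.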

class chemprop.nn.message_passing.BondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: _MessagePassingBase

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

It implements the following operation:

\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \rightarrow w\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message-passing iterations; and \(T\) is the final step, at which the bond representations are aggregated onto atoms.
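A NumPy sketch of the directed-bond scheme follows. It uses a common trick: sum all hidden states arriving at v, then subtract the one coming from w, which equals the sum over all neighbours u of v except w. Assumptions: ReLU for \(\tau\), row-vector features, initialization from bond features only (as the first equation is written), and a precomputed reverse-edge index rev. The function name is hypothetical and this is not chemprop's code.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def bond_message_passing(X, E, edge_index, rev, W_i, W_h, W_o, T=3):
    """Directed-bond message passing following the bond-based equations.

    X:          (V, d_v) atom features x_v
    E:          (n_edges, d_e) directed-edge features e_vw
    edge_index: (2, n_edges); edge k goes from edge_index[0, k] to edge_index[1, k]
    rev:        (n_edges,) position of the reverse edge w -> v for each edge v -> w
    W_i: (d_e, d_h), W_h: (d_h, d_h), W_o: (d_v + d_h, d_h)
    """
    src, dst = edge_index
    V = X.shape[0]
    h0 = relu(E @ W_i)                      # h_vw^(0) = tau(W_i e_vw)
    h = h0
    for _ in range(1, T):                   # iterations t = 1, ..., T-1
        incoming = np.zeros((V, h.shape[1]))
        np.add.at(incoming, dst, h)         # sum_u h_uv for every atom v
        m = incoming[src] - h[rev]          # drop w -> v from the sum for edge v -> w
        h = relu(h0 + m @ W_h)              # h_vw^(t)
    m_T = np.zeros((V, h.shape[1]))
    np.add.at(m_T, dst, h)                  # m_v^(T) = sum_w h_wv^(T-1)
    return relu(np.concatenate([X, m_T], axis=1) @ W_o)  # h_v^(T)

rng = np.random.default_rng(1)
d_v, d_e, d_h = 4, 3, 5
X = rng.normal(size=(3, d_v))
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
rev = np.array([1, 0, 3, 2])
E = rng.normal(size=(4, d_e))
W_i = rng.normal(size=(d_e, d_h))
W_h = rng.normal(size=(d_h, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))
out = bond_message_passing(X, E, edge_index, rev, W_i, W_h, W_o)
print(out.shape)  # (3, 5)
```

Excluding the reverse edge keeps a bond from immediately echoing its own previous state back to itself, which is the motivation usually given for passing messages on directed bonds.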

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message-passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message-passing update functions. W_d is None if d_vd is not supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

initialize(bmg)

Initialize the message-passing scheme by calculating the initial matrix of hidden features.

Parameters:

bmg (chemprop.data.BatchMolGraph)

Return type:

torch.Tensor

message(H, bmg)

Calculate the message matrix.

Return type:

torch.Tensor
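Excluding the reverse bond when computing messages requires knowing, for each directed edge v → w, the position of w → v. One way to build that index from an edge list is sketched below in NumPy. reverse_edge_index is a hypothetical helper, independent of how chemprop stores this information in its batch graphs.

```python
import numpy as np

def reverse_edge_index(edge_index):
    """For each directed edge k = (v, w), return the position of (w, v)."""
    src, dst = edge_index
    position = {(int(v), int(w)): k for k, (v, w) in enumerate(zip(src, dst))}
    try:
        return np.array([position[(int(w), int(v))] for v, w in zip(src, dst)])
    except KeyError as err:
        raise ValueError(f"missing reverse edge {err.args[0]}") from None

print(reverse_edge_index(np.array([[0, 1, 2, 1], [1, 2, 1, 0]])))  # [3 2 1 0]
```

When both directions of each bond are stored consecutively, the index reduces to swapping each pair of positions; the dictionary version above makes no ordering assumption.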

class chemprop.nn.message_passing.MulticomponentMessagePassing(blocks, n_components, shared=False)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MulticomponentMessagePassing performs message passing on each individual component of a multicomponent input, then concatenates the representations of the components to construct a global representation.

Parameters:
  • blocks (Sequence[MessagePassing]) – the individual message-passing blocks for each component

  • n_components (int) – the number of components in each input

  • shared (bool, default=False) – whether one block will be shared among all components in an input. If not, a separate block will be learned for each component.
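The shared option decides whether all components are encoded with one set of weights or with a separate set per component. A framework-free sketch of that choice; LinearBlock is a hypothetical stand-in for a message-passing block and make_blocks is a hypothetical helper, neither is a chemprop class.

```python
import numpy as np

class LinearBlock:
    """Hypothetical stand-in for one block: a random (V, d_in) -> (V, d_out) map."""

    def __init__(self, d_in, d_out, seed):
        self.W = np.random.default_rng(seed).normal(size=(d_in, d_out))
        self.output_dim = d_out

    def __call__(self, X):
        return np.maximum(X @ self.W, 0.0)

def make_blocks(block_factory, n_components, shared=False):
    """One block per component; with shared=True every entry is the same object."""
    if shared:
        block = block_factory(0)
        return [block] * n_components  # same weights reused for every component
    return [block_factory(i) for i in range(n_components)]

shared = make_blocks(lambda i: LinearBlock(4, 8, seed=i), n_components=2, shared=True)
separate = make_blocks(lambda i: LinearBlock(4, 8, seed=i), n_components=2, shared=False)
print(shared[0] is shared[1], separate[0] is separate[1])  # True False
```

Sharing ties the parameters across components, which reduces the parameter count and suits components of the same kind; separate blocks let each component learn its own encoder.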

property output_dim: int
Return type:

int

__len__()
Return type:

int

forward(bmgs, V_ds)

Encode the multicomponent inputs.

Returns:

a list of tensors of shape V x d_i containing the respective encodings of the i-th component, where d_i is the output dimension of the i-th encoder

Return type:

list[Tensor]
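forward returns one atom-level tensor per component rather than a single tensor, which fits the fact that components can have different numbers of atoms. The NumPy sketch below shows this for a single multicomponent datapoint and then forms a global representation. It assumes each block is any callable mapping a (V_i, d_in) array to (V_i, d_i), and it assumes a sum readout per component before concatenation; encode_components and global_representation are hypothetical helpers, not chemprop functions.

```python
import numpy as np

def encode_components(blocks, component_features):
    """Apply block i to component i; returns a list of (V_i, d_i) arrays."""
    if len(blocks) != len(component_features):
        raise ValueError("need exactly one block per component")
    return [block(X) for block, X in zip(blocks, component_features)]

def global_representation(encodings):
    """Hypothetical readout: sum over the atoms of each component, then concatenate."""
    return np.concatenate([H.sum(axis=0) for H in encodings])

rng = np.random.default_rng(0)
W_a, W_b = rng.normal(size=(4, 6)), rng.normal(size=(4, 5))
blocks = [lambda X: X @ W_a, lambda X: X @ W_b]
comp0 = rng.normal(size=(3, 4))  # component 0: 3 atoms
comp1 = rng.normal(size=(7, 4))  # component 1: 7 atoms
encodings = encode_components(blocks, [comp0, comp1])
print([H.shape for H in encodings])             # [(3, 6), (7, 5)]
print(global_representation(encodings).shape)   # (11,)
```

Concatenation only becomes possible after each component is reduced to a fixed-size vector, so the global representation has width d_0 + d_1 regardless of the atom counts.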