chemprop.nn.message_passing

Submodules

Classes

AtomMessagePassing

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

BondMessagePassing

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

MABAtomMessagePassing

A MABAtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

MABBondMessagePassing

A MABBondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

MulticomponentMessagePassing

A MulticomponentMessagePassing performs message passing on each individual input in a multicomponent input, then concatenates the representation of each input to construct a global representation.

MABMessagePassing

A MABMessagePassing module encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.

MessagePassing

A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

Package Contents

class chemprop.nn.message_passing.AtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MessagePassingBase

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

It implements the following operation:

\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} \left( h_u^{(t-1)} \mathbin\Vert e_{uv} \right) \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations, followed by a final aggregation and output step at \(t = T\).
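The update rule above can be traced by hand on a tiny graph. The following is a minimal NumPy sketch, not chemprop's implementation: the function name atom_message_passing, the edge-list representation, and the random weights are all assumptions made for illustration, and dropout and bias are omitted.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def atom_message_passing(X, E, edges, W_i, W_h, W_o, depth):
    """Toy atom-level message passing (illustrative only).

    X: (n_atoms, d_v) atom features; E: (n_edges, d_e) features of each
    directed edge; edges: list of (u, v) pairs meaning u -> v.
    """
    src = np.array([u for u, _ in edges])
    dst = np.array([v for _, v in edges])
    n_atoms = X.shape[0]

    H0 = relu(X @ W_i)                                   # h_v^(0)
    H = H0
    for _ in range(1, depth):                            # t = 1, ..., T-1
        per_edge = np.concatenate([H[src], E], axis=1)   # h_u^(t-1) || e_uv
        M = np.zeros((n_atoms, per_edge.shape[1]))
        np.add.at(M, dst, per_edge)                      # sum over neighbours u of v
        H = relu(H0 + M @ W_h)                           # h_v^(t)

    M_T = np.zeros_like(H)
    np.add.at(M_T, dst, H[src])                          # m_v^(T): hidden states only
    return relu(np.concatenate([X, M_T], axis=1) @ W_o)  # h_v^(T)


rng = np.random.default_rng(0)
d_v, d_e, d_h = 4, 3, 5
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]   # chain 0-1-2, both directions
X = rng.normal(size=(3, d_v))
E = rng.normal(size=(len(edges), d_e))
W_i = rng.normal(size=(d_v, d_h))
W_h = rng.normal(size=(d_h + d_e, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))

H_out = atom_message_passing(X, E, edges, W_i, W_h, W_o, depth=3)
print(H_out.shape)  # (3, 5)
```

Note that the hidden weight matrix in this sketch acts on a vector of size d_h + d_e, because each message concatenates a neighbour's hidden state with the connecting bond's features.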

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix W_d is None if no vertex descriptor dimension (d_vd) is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.message_passing.BondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MessagePassingBase

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

It implements the following operation:

\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \to w\) at iteration \(t\); \(m_v^{(T)}\) is the final message aggregated at atom \(v\) from its incoming bonds; \(h_v^{(T)}\) is the output representation of atom \(v\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations.
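The key difference from atom-level passing is that the message to a directed bond v → w excludes the reverse bond w → v. Below is a minimal NumPy sketch of that idea, following the formula as written rather than chemprop's internals: the function name, the rev array of reverse-edge indices, and the random weights are assumptions for illustration only.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def bond_message_passing(X, E, edges, rev, W_i, W_h, W_o, depth):
    """Toy directed-bond message passing (illustrative only).

    edges[k] = (v, w) is the directed bond v -> w; rev[k] is the index of
    the reverse bond w -> v in the same list.
    """
    src = np.array([v for v, _ in edges])
    dst = np.array([w for _, w in edges])
    rev = np.asarray(rev)
    n_atoms = X.shape[0]

    H0 = relu(E @ W_i)                          # h_vw^(0)
    H = H0
    for _ in range(1, depth):                   # t = 1, ..., T-1
        into_atom = np.zeros((n_atoms, H.shape[1]))
        np.add.at(into_atom, dst, H)            # sum of all bonds arriving at each atom
        # Message to v -> w: every bond u -> v into v, minus the reverse bond w -> v.
        M = into_atom[src] - H[rev]
        H = relu(H0 + M @ W_h)                  # h_vw^(t)

    M_T = np.zeros((n_atoms, H.shape[1]))
    np.add.at(M_T, dst, H)                      # m_v^(T): sum over incoming bonds
    return relu(np.concatenate([X, M_T], axis=1) @ W_o)   # h_v^(T)


rng = np.random.default_rng(0)
d_v, d_e, d_h = 4, 3, 5
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]   # chain 0-1-2, both directions
rev = [1, 0, 3, 2]
X = rng.normal(size=(3, d_v))
E = rng.normal(size=(len(edges), d_e))
W_i = rng.normal(size=(d_e, d_h))
W_h = rng.normal(size=(d_h, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))

H_bond = bond_message_passing(X, E, edges, rev, W_i, W_h, W_o, depth=3)
print(H_bond.shape)  # (3, 5)
```

Summing all incoming bond states per atom and then subtracting the reverse bond avoids an explicit loop over neighbours; it is one common way to implement the set difference in the message formula.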

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix W_d is None if no vertex descriptor dimension (d_vd) is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.message_passing.MABAtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)

Bases: chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MABMessagePassingBase

A MABAtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

It implements the following operation:

\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} \left( h_u^{(t-1)} \mathbin\Vert e_{uv} \right) \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations, followed by a final aggregation and output step at \(t = T\).

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. W_vd is None if no vertex descriptor dimension (d_vd) is supplied, and W_ed is None if no edge descriptor dimension (d_ed) is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.message_passing.MABBondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)

Bases: chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MABMessagePassingBase

A MABBondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

It implements the following operation:

\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \to w\) at iteration \(t\); \(m_v^{(T)}\) is the final message aggregated at atom \(v\) from its incoming bonds; \(h_v^{(T)}\) is the output representation of atom \(v\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations.

setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. W_vd is None if no vertex descriptor dimension (d_vd) is supplied, and W_ed is None if no edge descriptor dimension (d_ed) is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.message_passing.MulticomponentMessagePassing(blocks, n_components, shared=False)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MulticomponentMessagePassing performs message passing on each individual input in a multicomponent input, then concatenates the representation of each input to construct a global representation.

Parameters:
  • blocks (Sequence[MessagePassing]) – the individual message-passing blocks for each input

  • n_components (int) – the number of components in each input

  • shared (bool, default=False) – whether one block will be shared among all components in an input. If not, a separate block will be learned for each component.

hparams
n_components
shared = False
blocks
__len__()
Return type:

int

property output_dim: int
Return type:

int

forward(bmgs, V_ds)

Encode the multicomponent inputs.

Returns:

a list of tensors of shape V x d_i containing the respective encodings of the i-th component, where d_i is the output dimension of the i-th encoder

Return type:

list[Tensor]
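The relationship between blocks, n_components, and shared can be pictured with a small stand-in. This is a hedged sketch only: ToyEncoder and encode_components are hypothetical names, the encoder is a single linear map rather than a real message-passing block, and reusing the first block when shared=True is an assumption of the sketch, not a statement about chemprop's internals.

```python
import numpy as np


class ToyEncoder:
    """Hypothetical stand-in for a message-passing block (one linear map)."""

    def __init__(self, d_in, d_out, seed):
        self.W = np.random.default_rng(seed).normal(size=(d_in, d_out))
        self.output_dim = d_out

    def __call__(self, X):
        return X @ self.W


def encode_components(blocks, n_components, shared, inputs):
    """Encode each component separately and return one array per component."""
    if shared:
        # Assumption of this sketch: a shared setup reuses the first block.
        blocks = [blocks[0]] * n_components
    if not (len(blocks) == len(inputs) == n_components):
        raise ValueError("need one block and one input per component")
    return [block(X) for block, X in zip(blocks, inputs)]


# Two components with 7 and 5 vertices, each with 4 input features.
inputs = [np.ones((7, 4)), np.ones((5, 4))]
separate = encode_components([ToyEncoder(4, 8, 0), ToyEncoder(4, 6, 1)], 2, False, inputs)
shared = encode_components([ToyEncoder(4, 8, 0)], 2, True, inputs)
print([h.shape for h in separate])  # [(7, 8), (5, 6)]
print([h.shape for h in shared])    # [(7, 8), (5, 8)]
```

With separate blocks each component can have its own output dimension d_i; with a shared block every component is encoded by the same weights and therefore has the same output dimension.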

class chemprop.nn.message_passing.MABMessagePassing(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MABMessagePassing module encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.

Parameters:
  • args (Any)

  • kwargs (Any)

output_dims: tuple[int | None, int | None]
abstractmethod forward(bmg, V_d=None, E_d=None)

Encode a batch of molecular graphs.

Parameters:
  • bmg (BatchMolGraph) – the batch of MolGraphs to encode

  • V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.

  • E_d (Tensor | None, default=None) – an optional tensor of shape E x d_ed containing additional descriptors for each directed edge in the batch. These will be concatenated to the learned edge descriptors and transformed before the readout phase. NOTE: There are two directed edges per graph connection. If the extra descriptors are for the connections, each row should be repeated twice in the tensor, once for each direction, potentially using E_d = np.repeat(E_d, repeats=2, axis=0).

Returns:

Two tensors of shape V x d_h or V x (d_h + d_vd) and E x d_h or E x (d_h + d_ed) containing the hidden representation of each vertex and edge in the batch of graphs. The feature dimension depends on whether additional atom/bond descriptors were provided. If either the vertex or edge hidden representations are not needed, computing the corresponding tensor can be suppressed by setting either return_vertex_embeddings or return_edge_embeddings to False when initializing the module.

Return type:

tuple[Tensor | None, Tensor | None]
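The note on E_d above says that per-bond descriptors must be repeated once per direction. A minimal sketch of that step with NumPy follows; the descriptor values are made up for illustration, and the sketch assumes, as the np.repeat suggestion implies, that the two directions of each bond occupy adjacent rows.

```python
import numpy as np

# Hypothetical descriptors for a molecule with 3 bonds, 2 features per bond.
bond_descriptors = np.array([[0.1, 1.0],
                             [0.2, 2.0],
                             [0.3, 3.0]])

# One row per directed edge: each bond's row appears twice, back to back.
E_d = np.repeat(bond_descriptors, repeats=2, axis=0)
print(E_d.shape)  # (6, 2)
```

If the descriptors are already a torch tensor, torch.repeat_interleave(E_d, 2, dim=0) produces the same row ordering without a round trip through NumPy.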

class chemprop.nn.message_passing.MessagePassing(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

Parameters:
  • args (Any)

  • kwargs (Any)

output_dim: int
abstractmethod forward(bmg, V_d=None)

Encode a batch of molecular graphs.

Parameters:
  • bmg (BatchMolGraph) – the batch of MolGraphs to encode

  • V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.

Returns:

a tensor of shape V x d_h or V x (d_h + d_vd) containing the hidden representation of each vertex in the batch of graphs. The feature dimension depends on whether additional vertex descriptors were provided.

Return type:

Tensor
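The two possible output widths follow directly from the concatenation described for V_d. The sketch below illustrates only the shape arithmetic: append_descriptors is a hypothetical helper, and the square transform W_d of size (d_h + d_vd) x (d_h + d_vd) is an assumption chosen so that the width matches the documented return shape.

```python
import numpy as np


def append_descriptors(H, V_d=None, W_d=None):
    """Concatenate optional vertex descriptors to hidden states, then transform."""
    if V_d is None:
        return H                                  # V x d_h
    H_cat = np.concatenate([H, V_d], axis=1)      # V x (d_h + d_vd)
    return H_cat if W_d is None else H_cat @ W_d


H = np.zeros((6, 300))    # 6 vertices, d_h = 300
V_d = np.ones((6, 10))    # d_vd = 10 extra descriptors per vertex
W_d = np.eye(310)         # identity stands in for a learned transform

print(append_descriptors(H).shape)            # (6, 300)
print(append_descriptors(H, V_d, W_d).shape)  # (6, 310)
```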