chemprop.nn.message_passing
Submodules
Classes
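- AtomMessagePassing – encodes a batch of molecular graphs by passing messages along atoms.
- BondMessagePassing – encodes a batch of molecular graphs by passing messages along directed bonds.
- MABAtomMessagePassing – encodes a batch of molecular graphs by passing messages along atoms.
- MABBondMessagePassing – encodes a batch of molecular graphs by passing messages along directed bonds.
- MulticomponentMessagePassing – performs message passing on each individual input in a multicomponent input, then concatenates the representation of each input to construct a global representation.
- MABMessagePassing – encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.
- MessagePassing – encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.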
Package Contents
- class chemprop.nn.message_passing.AtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)
Bases:
chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MessagePassingBase
An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms. It implements the following operation:
\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); and \(t \in \{1, \dots, T\}\) indexes the message passing iterations, \(T\) being the total number of iterations.
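The recurrence above can be traced on a toy graph with a minimal NumPy sketch. It is not chemprop's implementation: it assumes ReLU for \(\tau\), plain matrices instead of learned layers (no bias, no dropout), a hypothetical list of directed edges `(u, v)` meaning a message flows from atom u to atom v, and it reads the update step as \(h_v^{(t)} = \tau(h_v^{(0)} + \mathbf{W}_h m_v^{(t)})\), so each iteration uses the message built from the previous hidden states. The function name `atom_message_passing` is hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def atom_message_passing(X, E, edges, W_i, W_h, W_o, T):
    """Hypothetical toy version of atom-based message passing.

    X: (V, d_v) atom features; E: (n_edges, d_e) features of directed edges
    edges: list of (u, v) pairs; edge k carries messages from atom u to atom v
    W_i: (d_v, d_h), W_h: (d_h + d_e, d_h), W_o: (d_v + d_h, d_h)
    """
    n_atoms = X.shape[0]
    H0 = relu(X @ W_i)  # h_v^(0)
    H = H0
    for _ in range(1, T):  # iterations t = 1, ..., T-1
        M = np.zeros((n_atoms, H.shape[1] + E.shape[1]))
        for k, (u, v) in enumerate(edges):
            M[v] += np.concatenate([H[u], E[k]])  # sum of h_u^(t-1) || e_uv
        H = relu(H0 + M @ W_h)  # h_v^(t)
    M_T = np.zeros_like(H)
    for u, v in edges:
        M_T[v] += H[u]  # m_v^(T): sum of neighbour states, no bond features
    return relu(np.concatenate([X, M_T], axis=1) @ W_o)  # h_v^(T)


# Usage on a 3-atom chain 0-1-2, each bond stored in both directions
rng = np.random.default_rng(0)
d_v, d_e, d_h = 4, 3, 5
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
X = rng.normal(size=(3, d_v))
E = rng.normal(size=(len(edges), d_e))
W_i = rng.normal(size=(d_v, d_h))
W_h = rng.normal(size=(d_h + d_e, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))
out = atom_message_passing(X, E, edges, W_i, W_h, W_o, T=3)
print(out.shape)  # (3, 5)
```

Each atom ends with one \(d_h\)-dimensional vector, which is why the documented output has shape V x d_h.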
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | torch.nn.Module | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)
Set up the weight matrices used in the message passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix is None if no vertex descriptor dimension (d_vd) is supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
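Reading the atom-based equations of this class, the input and output sizes of each matrix follow from what it is applied to. The hypothetical helper below only records those (in_features, out_features) pairs as a sketch under that reading. It assumes W_d maps a \(d_h + d_{vd}\) input back to the same width, since MessagePassing.forward documents an output of width d_h + d_vd when vertex descriptors are given; chemprop's actual layer sizes may differ.

```python
def atom_weight_shapes(d_v, d_e, d_h=300, d_vd=None):
    """Hypothetical helper: (in_features, out_features) for W_i, W_h, W_o, W_d."""
    W_i = (d_v, d_h)  # W_i(x_v)
    W_h = (d_h + d_e, d_h)  # applied to m_v = sum of h_u || e_uv
    W_o = (d_v + d_h, d_h)  # applied to x_v || m_v^(T)
    # assumed square; None when no vertex descriptors are supplied
    W_d = (d_h + d_vd, d_h + d_vd) if d_vd else None
    return W_i, W_h, W_o, W_d


print(atom_weight_shapes(d_v=10, d_e=4))
# ((10, 300), (304, 300), (310, 300), None)
```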
- class chemprop.nn.message_passing.BondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)
Bases:
chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MessagePassingBase
A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds. It implements the following operation:
\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \rightarrow w\) at iteration \(t\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations.
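The directed-bond recurrence above can be sketched in NumPy on a toy graph. Its distinguishing step is the exclusion \(u \neq w\): the message into \(v \to w\) equals the sum of all bond states entering \(v\) minus the state of the reverse bond \(w \to v\). This is a hypothetical illustration (`bond_message_passing` is not a chemprop function) that assumes ReLU for \(\tau\), plain matrices without bias or dropout, that every bond is stored in both directions, and that the update reads \(h_{vw}^{(t)} = \tau(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)})\) with a readout summing the incoming bond states \(h_{wv}^{(T-1)}\).

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def bond_message_passing(X, E, edges, W_i, W_h, W_o, T):
    """Hypothetical toy version of directed-bond message passing.

    X: (V, d_v) atom features; E: (n_edges, d_e) features of directed bonds
    edges: list of (v, w) pairs; every bond must appear in both directions
    W_i: (d_e, d_h), W_h: (d_h, d_h), W_o: (d_v + d_h, d_h)
    """
    n_atoms = X.shape[0]
    index = {edge: k for k, edge in enumerate(edges)}  # (v, w) -> row index
    H0 = relu(E @ W_i)  # h_vw^(0)
    H = H0
    for _ in range(1, T):  # iterations t = 1, ..., T-1
        incoming = np.zeros((n_atoms, H.shape[1]))
        for k, (u, v) in enumerate(edges):
            incoming[v] += H[k]  # sum of all h_uv entering atom v
        M = np.empty_like(H)
        for k, (v, w) in enumerate(edges):
            M[k] = incoming[v] - H[index[(w, v)]]  # drop the reverse bond w -> v
        H = relu(H0 + M @ W_h)  # h_vw^(t)
    M_T = np.zeros((n_atoms, H.shape[1]))
    for k, (w, v) in enumerate(edges):
        M_T[v] += H[k]  # m_v^(T): sum of h_wv over incoming bonds
    return relu(np.concatenate([X, M_T], axis=1) @ W_o)  # h_v^(T)


# Usage on a 3-atom chain 0-1-2
rng = np.random.default_rng(0)
d_v, d_e, d_h = 4, 3, 5
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
X = rng.normal(size=(3, d_v))
E = rng.normal(size=(len(edges), d_e))
W_i = rng.normal(size=(d_e, d_h))
W_h = rng.normal(size=(d_h, d_h))
W_o = rng.normal(size=(d_v + d_h, d_h))
out = bond_message_passing(X, E, edges, W_i, W_h, W_o, T=3)
print(out.shape)  # (3, 5)
```

Subtracting the reverse bond keeps a message from echoing straight back along the edge it just travelled; on a two-atom molecule every message is therefore zero and the hidden bond states never change.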
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | torch.nn.Module | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)
Set up the weight matrices used in the message passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix is None if no vertex descriptor dimension (d_vd) is supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
- class chemprop.nn.message_passing.MABAtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)
Bases:
chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MABMessagePassingBase
A MABAtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms. It implements the following operation:
\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); and \(t \in \{1, \dots, T\}\) indexes the message passing iterations, \(T\) being the total number of iterations.
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
d_ed (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
E_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
return_vertex_embeddings (bool)
return_edge_embeddings (bool)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)
Set up the weight matrices used in the message passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. A descriptor weight matrix is None if the corresponding descriptor dimension (d_vd or d_ed) is not supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
- class chemprop.nn.message_passing.MABBondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)
Bases:
chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MABMessagePassingBase
A MABBondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds. It implements the following operation:
\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \rightarrow w\) at iteration \(t\); and \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations.
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
d_ed (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
E_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
return_vertex_embeddings (bool)
return_edge_embeddings (bool)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)
Set up the weight matrices used in the message passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. A descriptor weight matrix is None if the corresponding descriptor dimension (d_vd or d_ed) is not supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
- class chemprop.nn.message_passing.MulticomponentMessagePassing(blocks, n_components, shared=False)
Bases:
torch.nn.Module, chemprop.nn.hparams.HasHParams
A MulticomponentMessagePassing performs message passing on each individual input in a multicomponent input, then concatenates the representation of each input to construct a global representation.
- Parameters:
blocks (Sequence[MessagePassing]) – the individual message-passing blocks for each input
n_components (int) – the number of components in each input
shared (bool, default=False) – whether one block will be shared among all components in an input. If not, a separate block will be learned for each component.
- hparams
- n_components
- blocks
- property output_dim: int
- Return type:
int
- forward(bmgs, V_ds)
Encode the multicomponent inputs
- Parameters:
bmgs (Iterable[BatchMolGraph])
V_ds (Iterable[Tensor | None])
- Returns:
a list of tensors of shape V x d_i containing the respective encodings of the i-th component, where d_i is the output dimension of the i-th encoder
- Return type:
list[Tensor]
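The `shared` flag and the per-component encoding described for this class can be illustrated without any chemprop objects. In the sketch below, the helpers `build_blocks` and `encode_components` are hypothetical, plain Python callables stand in for MessagePassing blocks, and reusing only the first block when `shared=True` is an assumption rather than documented behaviour.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def build_blocks(
    blocks: Sequence[Callable[[T], R]], n_components: int, shared: bool = False
) -> List[Callable[[T], R]]:
    """Hypothetical: return one encoder per component."""
    if len(blocks) == 0:
        raise ValueError("at least one block is required")
    if shared:
        # the same block object encodes every component (assumed: the first one)
        return [blocks[0]] * n_components
    if len(blocks) != n_components:
        raise ValueError("need one block per component when shared=False")
    return list(blocks)


def encode_components(blocks: Sequence[Callable[[T], R]], inputs: Sequence[T]) -> List[R]:
    """Encode the i-th component with the i-th block; one result per component."""
    return [block(x) for block, x in zip(blocks, inputs)]


def double(x: int) -> int:
    return 2 * x


def increment(x: int) -> int:
    return x + 1


print(encode_components(build_blocks([double, increment], 2), [3, 4]))  # [6, 5]
print(encode_components(build_blocks([double, increment], 2, shared=True), [3, 4]))  # [6, 8]
```

Sharing one block ties the weights across components, which halves the parameters for a two-component input but forces both components through the same encoder.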
- class chemprop.nn.message_passing.MABMessagePassing(*args, **kwargs)
Bases:
torch.nn.Module, chemprop.nn.hparams.HasHParams
A MABMessagePassing module encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.
- Parameters:
args (Any)
kwargs (Any)
- output_dims: tuple[int | None, int | None]#
- abstractmethod forward(bmg, V_d=None, E_d=None)
Encode a batch of molecular graphs.
- Parameters:
bmg (BatchMolGraph) – the batch of MolGraphs to encode
V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.
E_d (Tensor | None, default=None) – an optional tensor of shape E x d_ed containing additional descriptors for each directed edge in the batch. These will be concatenated to the learned edge descriptors and transformed before the readout phase. NOTE: There are two directed edges per graph connection. If the extra descriptors are for the connections, each row should appear twice in the tensor, once for each direction, potentially using E_d = np.repeat(E_d, repeats=2, axis=0).
- Returns:
Two tensors of shape V x d_h or V x (d_h + d_vd) and E x d_h or E x (d_h + d_ed) containing the hidden representation of each vertex and edge in the batch of graphs. The feature dimension depends on whether additional atom/bond descriptors were provided. If either the vertex or edge hidden representations are not needed, computing the corresponding tensor can be suppressed by setting either return_vertex_embeddings or return_edge_embeddings to False when initializing the module.
- Return type:
tuple[Tensor | None, Tensor | None]
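The np.repeat call suggested for E_d duplicates each row in place, so the two copies for one bond sit next to each other. The sketch below assumes, as the note on E_d implies, that the two directed edges of each bond occupy consecutive rows of the edge tensor; the descriptor values are made up for illustration. np.tile, by contrast, cycles through the whole array and would misalign rows under that assumption.

```python
import numpy as np

# Made-up descriptors: one row per bond (3 bonds, 2 features each)
E_d_bonds = np.array([[1.0, 0.1],
                      [2.0, 0.2],
                      [3.0, 0.3]])

# Repeat each row in place: bond 0, bond 0, bond 1, bond 1, bond 2, bond 2
E_d = np.repeat(E_d_bonds, repeats=2, axis=0)
print(E_d.shape)  # (6, 2)
print(E_d[:, 0])  # [1. 1. 2. 2. 3. 3.]

# np.tile cycles through the whole array instead: bond 0, 1, 2, 0, 1, 2
E_d_tiled = np.tile(E_d_bonds, (2, 1))
print(E_d_tiled[:, 0])  # [1. 2. 3. 1. 2. 3.]
```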
- class chemprop.nn.message_passing.MessagePassing(*args, **kwargs)
Bases:
torch.nn.Module, chemprop.nn.hparams.HasHParams
A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.
- Parameters:
args (Any)
kwargs (Any)
- output_dim: int#
- abstractmethod forward(bmg, V_d=None)
Encode a batch of molecular graphs.
- Parameters:
bmg (BatchMolGraph) – the batch of MolGraphs to encode
V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.
- Returns:
a tensor of shape V x d_h or V x (d_h + d_vd) containing the hidden representation of each vertex in the batch of graphs. The feature dimension depends on whether additional vertex descriptors were provided.
- Return type:
Tensor
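The two documented output widths, V x d_h without and V x (d_h + d_vd) with vertex descriptors, come from concatenating V_d onto the hidden features. The hypothetical NumPy helper below shows only that concatenation step; the learned transformation that the documentation says is applied afterwards is omitted.

```python
import numpy as np


def append_vertex_descriptors(H, V_d=None):
    """Hypothetical: concatenate optional per-vertex descriptors onto hidden features.

    H: (V, d_h) hidden vertex representations; V_d: (V, d_vd) or None.
    """
    if V_d is None:
        return H  # width stays d_h
    if V_d.shape[0] != H.shape[0]:
        raise ValueError("V_d must have exactly one row per vertex")
    return np.concatenate([H, V_d], axis=1)  # width becomes d_h + d_vd


H = np.zeros((5, 300))
V_d = np.ones((5, 10))
print(append_vertex_descriptors(H).shape)  # (5, 300)
print(append_vertex_descriptors(H, V_d).shape)  # (5, 310)
```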