chemprop.nn.message_passing.base
Classes
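BondMessagePassing | A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.
AtomMessagePassing | An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.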
Module Contents
- class chemprop.nn.message_passing.base.BondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)
Bases:
chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MessagePassingBase
A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds. It implements the following operation:
\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \to w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \to w\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message-passing iterations; and \(T\) is the message-passing depth.
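To make the directed-bond bookkeeping concrete, here is a minimal pure-Python sketch of the equations above. It is not chemprop's implementation: the helper names (`matvec`, `bond_message_passing`), the dict-based graph layout, and the toy weights are hypothetical; \(\tau\) is taken to be ReLU (the constructor's default `activation=Activation.RELU`); and biases and dropout are omitted.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output unit


def matvec(W: Matrix, x: Vector) -> Vector:
    """Bias-free dense layer: returns W @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def add(a: Vector, b: Vector) -> Vector:
    return [ai + bi for ai, bi in zip(a, b)]


def relu(x: Vector) -> Vector:
    """ReLU, standing in for the activation tau."""
    return [max(0.0, xi) for xi in x]


def bond_message_passing(
    x: Dict[int, Vector],              # atom features x_v
    e: Dict[Tuple[int, int], Vector],  # e[(v, w)] = features of directed bond v -> w
    W_i: Matrix,
    W_h: Matrix,
    W_o: Matrix,
    depth: int,                        # T
) -> Dict[int, Vector]:
    bonds = list(e)
    d_h = len(W_h)
    # h_vw^(0) = tau(W_i e_vw)
    h0 = {b: relu(matvec(W_i, e[b])) for b in bonds}
    h = dict(h0)
    for _ in range(1, depth):  # t = 1, ..., T - 1
        h_next = {}
        for v, w in bonds:
            # m_vw^(t): sum of h_uv^(t-1) over bonds u -> v, skipping the reverse bond w -> v
            m = [0.0] * d_h
            for u, dst in bonds:
                if dst == v and u != w:
                    m = add(m, h[(u, dst)])
            h_next[(v, w)] = relu(add(h0[(v, w)], matvec(W_h, m)))
        h = h_next
    h_atoms = {}
    for v in x:
        # m_v^(T): sum of the final states h_wv^(T-1) of all bonds pointing into v
        m_v = [0.0] * d_h
        for w, dst in bonds:
            if dst == v:
                m_v = add(m_v, h[(w, dst)])
        # x_v || m_v is plain list concatenation here
        h_atoms[v] = relu(matvec(W_o, x[v] + m_v))
    return h_atoms


# Hypothetical toy graph: the chain 0-1-2, each bond stored in both directions.
x = {0: [1.0], 1: [0.0], 2: [2.0]}
e = {(0, 1): [1.0], (1, 0): [1.0], (1, 2): [2.0], (2, 1): [2.0]}
W_i = [[1.0], [0.5]]                      # d_e = 1 -> d_h = 2
W_h = [[1.0, 0.0], [0.0, 1.0]]            # d_h -> d_h (identity for readability)
W_o = [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # d_v + d_h = 3 -> d_h = 2

result = bond_message_passing(x, e, W_i, W_h, W_o, depth=2)
print(result)  # {0: [4.0, 1.5], 1: [3.0, 1.5], 2: [5.0, 1.5]}
```

Excluding the reverse bond \(w \to v\) when forming \(m_{vw}^{(t)}\) keeps a bond's own state from flowing straight back along the same bond on the next iteration; with `depth=2` only one update step runs, matching \(t \in \{1, \dots, T-1\}\).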
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | torch.nn.Module | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)
Set up the weight matrices used in the message-passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message-passing update functions. The descriptor weight matrix is None if no vertex descriptor dimension (d_vd) is supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
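The shapes that `setup` has to produce follow from the update equations. The hypothetical helper below (not part of chemprop) reads off the input width each matrix would need if every W is a dense layer applied to exactly the vector shown in the bond equations. This is an assumption: the layers chemprop actually builds may differ, for example if W_i is also given the source-atom features x_v.

```python
from typing import Optional, Tuple


def bond_setup_input_widths(
    d_v: int, d_e: int, d_h: int = 300, d_vd: Optional[int] = None
) -> Tuple[int, int, int, Optional[int]]:
    """Input widths of (W_i, W_h, W_o, W_d) as read off the bond update equations (assumption)."""
    w_i = d_e        # W_i acts on a bond feature vector e_vw
    w_h = d_h        # W_h acts on m_vw, a sum of d_h-dimensional bond states
    w_o = d_v + d_h  # W_o acts on the concatenation x_v || m_v
    # W_d only exists when extra vertex descriptors are concatenated before readout
    w_d = None if d_vd is None else d_h + d_vd
    return w_i, w_h, w_o, w_d


print(bond_setup_input_widths(d_v=10, d_e=4, d_h=8))          # (4, 8, 18, None)
print(bond_setup_input_widths(d_v=10, d_e=4, d_h=8, d_vd=3))  # (4, 8, 18, 11)
```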
- class chemprop.nn.message_passing.base.AtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)
Bases:
chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MessagePassingBase
An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms. It implements the following operation:
\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} \left( h_u^{(t-1)} \mathbin\Vert e_{uv} \right) \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]
where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message-passing iterations; and \(T\) is the message-passing depth.
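For comparison with the bond-centred variant, here is a minimal pure-Python sketch of the atom update equations above. It is not chemprop's implementation: the helper names (`matvec`, `atom_message_passing`), the dict-based graph layout, and the toy weights are hypothetical; \(\tau\) is taken to be ReLU (the constructor's default); and biases and dropout are omitted. Note that bond features enter only through the messages \(m_v^{(t)}\), not the final readout.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output unit


def matvec(W: Matrix, x: Vector) -> Vector:
    """Bias-free dense layer: returns W @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def add(a: Vector, b: Vector) -> Vector:
    return [ai + bi for ai, bi in zip(a, b)]


def relu(x: Vector) -> Vector:
    """ReLU, standing in for the activation tau."""
    return [max(0.0, xi) for xi in x]


def atom_message_passing(
    x: Dict[int, Vector],              # atom features x_v
    e: Dict[Tuple[int, int], Vector],  # e[(u, v)] = bond features, both directions stored
    W_i: Matrix,
    W_h: Matrix,
    W_o: Matrix,
    depth: int,                        # T
) -> Dict[int, Vector]:
    d_h = len(W_i)
    d_msg = len(W_h[0])  # width of h_u || e_uv, i.e. d_h + d_e
    # h_v^(0) = tau(W_i x_v)
    h0 = {v: relu(matvec(W_i, x[v])) for v in x}
    h = dict(h0)
    for _ in range(1, depth):  # t = 1, ..., T - 1
        h_next = {}
        for v in x:
            # m_v^(t): sum over neighbors u of (h_u^(t-1) || e_uv)
            m = [0.0] * d_msg
            for (u, dst), e_uv in e.items():
                if dst == v:
                    m = add(m, h[u] + e_uv)
            h_next[v] = relu(add(h0[v], matvec(W_h, m)))
        h = h_next
    h_out = {}
    for v in x:
        # m_v^(T): sum of neighboring atom states h_w^(T-1), without bond features
        m_v = [0.0] * d_h
        for w, dst in e:
            if dst == v:
                m_v = add(m_v, h[w])
        # x_v || m_v is plain list concatenation here
        h_out[v] = relu(matvec(W_o, x[v] + m_v))
    return h_out


# Hypothetical toy graph: the chain 0-1-2, each bond stored in both directions.
x = {0: [1.0], 1: [2.0], 2: [3.0]}
e = {(0, 1): [1.0], (1, 0): [1.0], (1, 2): [1.0], (2, 1): [1.0]}
W_i = [[1.0]]       # d_v = 1 -> d_h = 1
W_h = [[1.0, 1.0]]  # d_h + d_e = 2 -> d_h = 1
W_o = [[1.0, 1.0]]  # d_v + d_h = 2 -> d_h = 1

result = atom_message_passing(x, e, W_i, W_h, W_o, depth=2)
print(result)  # {0: [9.0], 1: [12.0], 2: [11.0]}
```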
- Parameters:
d_v (int)
d_e (int)
d_h (int)
bias (bool)
depth (int)
dropout (float)
activation (str | torch.nn.Module | chemprop.nn.utils.Activation)
undirected (bool)
d_vd (int | None)
V_d_transform (chemprop.nn.transforms.ScaleTransform | None)
graph_transform (chemprop.nn.transforms.GraphTransform | None)
- setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)
Set up the weight matrices used in the message-passing update functions.
- Parameters:
d_v (int) – the vertex feature dimension
d_e (int) – the edge feature dimension
d_h (int, default=300) – the hidden dimension during message passing
d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any
bias (bool, default=False) – whether to add a learned bias to the matrices
- Returns:
W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message-passing update functions. The descriptor weight matrix is None if no vertex descriptor dimension (d_vd) is supplied.
- Return type:
tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]
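The input widths `setup` must account for can be read off the atom update equations. The hypothetical helper below (not part of chemprop) assumes each W is a dense layer applied to exactly the vector shown in those equations; the main difference from the bond-centred variant is that W_h here consumes \(h_u \Vert e_{uv}\), so bond features widen the hidden update rather than the input layer.

```python
from typing import Optional, Tuple


def atom_setup_input_widths(
    d_v: int, d_e: int, d_h: int = 300, d_vd: Optional[int] = None
) -> Tuple[int, int, int, Optional[int]]:
    """Input widths of (W_i, W_h, W_o, W_d) as read off the atom update equations (assumption)."""
    w_i = d_v        # W_i acts on an atom feature vector x_v
    w_h = d_h + d_e  # W_h acts on m_v, a sum of concatenations h_u || e_uv
    w_o = d_v + d_h  # W_o acts on the concatenation x_v || m_v
    # W_d only exists when extra vertex descriptors are concatenated before readout
    w_d = None if d_vd is None else d_h + d_vd
    return w_i, w_h, w_o, w_d


print(atom_setup_input_widths(d_v=10, d_e=4, d_h=8))          # (10, 12, 18, None)
print(atom_setup_input_widths(d_v=10, d_e=4, d_h=8, d_vd=3))  # (10, 12, 18, 11)
```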