Message passing#

[1]:
from chemprop.nn.message_passing.base import BondMessagePassing, AtomMessagePassing

The cell below builds an example dataloader that produces inputs for the message passing layer.

[2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset, build_dataloader

smis = ["C" * i for i in range(1, 4)]
ys = np.random.rand(len(smis), 1)
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
dataloader = build_dataloader(dataset)

Message passing schemes#

Chemprop provides two message passing schemes. It prefers the D-MPNN scheme (BondMessagePassing), in which messages are passed between directed edges (bonds), over the traditional MPNN scheme (AtomMessagePassing), in which messages are passed between nodes (atoms).

[3]:
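mp = AtomMessagePassing()  # traditional MPNN: messages pass between atoms
mp = BondMessagePassing()  # D-MPNN: messages pass between directed bonds (replaces the layer above)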
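To make the difference between the two schemes concrete, here is a minimal pure-Python sketch of a single aggregation step on a toy 3-atom chain. It is not chemprop's implementation: the function names `atom_messages` and `bond_messages` are hypothetical, hidden states are scalars, and there are no learned weights or activations.

```python
# Toy molecule: a 3-atom chain 0 - 1 - 2.
# Each undirected bond is stored as two directed edges (source, destination).
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]


def atom_messages(h_atom, edges):
    """Node-based scheme: each atom sums the hidden states of its neighbors."""
    m = [0.0] * len(h_atom)
    for u, v in edges:
        m[v] += h_atom[u]
    return m


def bond_messages(h_edge, edges):
    """Edge-based scheme: the directed edge u->v sums the hidden states of the
    edges w->u that arrive at u, skipping its own reverse edge v->u."""
    m = [0.0] * len(edges)
    for i, (u, v) in enumerate(edges):
        for j, (w, x) in enumerate(edges):
            if x == u and w != v:
                m[i] += h_edge[j]
    return m


h_atom = [1.0, 2.0, 3.0]       # one scalar hidden state per atom
h_edge = [1.0, 2.0, 3.0, 4.0]  # one scalar hidden state per directed edge

atom_m = atom_messages(h_atom, edges)
bond_m = bond_messages(h_edge, edges)
```

In the edge-based scheme, the edges leaving the terminal atoms (0->1 and 2->1) receive no message, because the only edge arriving at their source atom is their own reverse. Excluding the reverse edge keeps a message from being sent straight back along the bond it just arrived on.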

Input dimensions#

By default, the input dimension of the bond message passing layer is the sum of the atom and bond feature lengths produced by the default atom and bond featurizers. If you use a custom featurizer, pass its feature dimensions to the message passing layer when you create it.

Note that the default input dimension of an atom message passing layer is only the length of the atom features from the default atom featurizer.

[4]:
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer

n_atom_features, n_bond_features = SimpleMoleculeMolGraphFeaturizer().shape
(n_atom_features + n_bond_features) == mp.W_i.in_features
[4]:
True
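The relationship checked above can be written as a small helper. This is only a sketch of the rule stated in the text: `input_dim` is a hypothetical function, not part of chemprop, and the feature lengths are placeholders rather than the real default values.

```python
def input_dim(scheme, d_v, d_e):
    """Input size of the first weight matrix, following the rule in the text."""
    if scheme == "bond":
        return d_v + d_e  # atom and bond features both enter the input layer
    if scheme == "atom":
        return d_v        # only atom features enter the input layer
    raise ValueError(f"unknown scheme: {scheme!r}")


# Placeholder feature lengths, chosen only for illustration.
d_v, d_e = 10, 4
bond_in = input_dim("bond", d_v, d_e)
atom_in = input_dim("atom", d_v, d_e)
```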
[5]:
from chemprop.featurizers import MultiHotAtomFeaturizer

n_extra_bond_features = 12
featurizer = SimpleMoleculeMolGraphFeaturizer(
    atom_featurizer=MultiHotAtomFeaturizer.organic(), extra_bond_fdim=n_extra_bond_features
)

mp = BondMessagePassing(d_v=featurizer.atom_fdim, d_e=featurizer.bond_fdim)

If extra atom descriptors are used, the message passing layer must also be given their number when it is created. A separate weight matrix is created and applied to the concatenation of the hidden representation and the extra descriptors after message passing is complete. The output dimension of the message passing layer is then the sum of the hidden size and the number of extra atom descriptors.

[6]:
n_extra_atom_descriptors = 28
mp = BondMessagePassing(d_vd=n_extra_atom_descriptors)
mp.output_dim
[6]:
328
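The value 328 is the default hidden size (300) plus the 28 extra descriptors. The sketch below shows the concatenate-then-transform step in pure Python with toy sizes. The names `H`, `V_d`, and `W` are illustrative, and the identity weight matrix is a stand-in for a learned one, so this is not how chemprop stores or initializes the layer.

```python
d_h, d_vd = 3, 2  # toy sizes; the notebook cell uses 300 and 28

# Two atoms: hidden representations after message passing, and extra descriptors.
H = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
V_d = [[1.0, 0.0], [0.0, 1.0]]

# Concatenate each atom's hidden representation with its extra descriptors.
concat = [h + v for h, v in zip(H, V_d)]

# Stand-in square weight matrix of size (d_h + d_vd) x (d_h + d_vd).
d_out = d_h + d_vd
W = [[1.0 if i == j else 0.0 for j in range(d_out)] for i in range(d_out)]

# Apply the weight matrix to every atom's concatenated vector.
out = [[sum(row[k] * W[k][j] for k in range(d_out)) for j in range(d_out)] for row in concat]

output_dim = len(out[0])
```

Because the weight matrix acts on the concatenated vector and preserves its length, the per-atom output has `d_h + d_vd` entries.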

Customization#

The following hyperparameters of the message passing layer are customizable:

  • the hidden dimension during message passing, default: 300

  • whether a bias term is used, default: False

  • the number of message passing iterations, default: 3

  • whether to pass messages on undirected edges, default: False

  • the dropout probability, default: 0.0 (i.e. no dropout)

  • the activation function, default: ReLU

[7]:
mp = BondMessagePassing(
    d_h=600, bias=True, depth=5, undirected=True, dropout=0.5, activation="tanh"
)
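As a rough illustration of the `undirected` option, one common way to pass messages on undirected edges is to average the hidden state of each directed edge with that of its reverse before aggregating. Whether chemprop does exactly this is an assumption here, and the names below are illustrative.

```python
# Directed edges of a 3-atom chain 0 - 1 - 2.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

# For each directed edge u->v, find the position of its reverse v->u.
rev_index = [edges.index((v, u)) for u, v in edges]

h_edge = [1.0, 2.0, 3.0, 4.0]  # one scalar hidden state per directed edge

# Symmetrize: both directions of a bond end up with the same hidden state.
h_sym = [(h_edge[i] + h_edge[rev_index[i]]) / 2 for i in range(len(edges))]
```

After this step the two directions of each bond carry identical states, so the direction of an edge no longer affects the messages it sends.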

The output of message passing is a torch tensor of shape (number of atoms in the batch) x (length of the hidden representation).

[8]:
batch_molgraph, extra_atom_descriptors, *_ = next(iter(dataloader))
hidden_atom_representations = mp(batch_molgraph, extra_atom_descriptors)
hidden_atom_representations.shape
[8]:
torch.Size([6, 600])
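The shape above can be checked by hand. The example SMILES are "C", "CC", and "CCC", which contain 1 + 2 + 3 = 6 heavy atoms, and the layer was created with a hidden size of 600. The sketch below reproduces that count. It assumes that hydrogens are implicit and are not graph nodes, and it counts characters, which works only because each character in these particular SMILES is one carbon atom.

```python
smis = ["C" * i for i in range(1, 4)]  # ["C", "CC", "CCC"]

# Character counting is valid only for these all-carbon SMILES strings.
n_atoms = sum(len(smi) for smi in smis)

d_h = 600  # hidden size passed to the message passing layer
expected_shape = (n_atoms, d_h)
```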