Multicomponent models#
[1]:
from chemprop.nn.message_passing import MulticomponentMessagePassing
from chemprop.models import MulticomponentMPNN
Overview#
The basic Chemprop model is designed for a single molecule or reaction as input. A multicomponent Chemprop model combines these basic building blocks so that it can take multiple molecules and/or reactions as input. This is useful for properties that depend on more than one component, such as a property of a solute in a particular solvent.
Message passing#
MulticomponentMessagePassing organizes the single-component message passing modules, one for each component in the multicomponent dataset. The individual message passing modules can be unique to each component, shared between some components, or shared between all components. If all components share the same message passing module, the shared flag can be set to True. Note that components that use different featurizers (e.g. molecules and reactions) should not share a message passing module, since their featurized inputs generally differ.
[2]:
from chemprop.nn import BondMessagePassing

# Option 1: a separate message passing block for each component
mp1 = BondMessagePassing(d_h=100)
mp2 = BondMessagePassing(d_h=600)
blocks = [mp1, mp2]
mcmp = MulticomponentMessagePassing(blocks=blocks, n_components=len(blocks))

# Option 2: one block shared by both components (this overwrites mcmp from Option 1)
mp = BondMessagePassing()
mcmp = MulticomponentMessagePassing(blocks=[mp], n_components=2, shared=True)
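The two configurations above differ only in how blocks are assigned to components. The following pure-Python sketch illustrates that assignment under a simple assumption: with shared set to True, the first block is reused for every component. The helper assign_blocks and the Block stand-in class are hypothetical and are not part of the Chemprop API; the library's own argument checks may differ.

```python
def assign_blocks(blocks, n_components, shared=False):
    """Return one block per component (hypothetical helper, not Chemprop API)."""
    if not blocks:
        raise ValueError("at least one block is required")
    if shared:
        # The same object is reused for every component, so weights are shared.
        return [blocks[0]] * n_components
    if len(blocks) != n_components:
        raise ValueError("need exactly one block per component when not shared")
    return list(blocks)


class Block:
    """Stand-in for a message passing module."""


unique = assign_blocks([Block(), Block()], n_components=2)
shared = assign_blocks([Block()], n_components=2, shared=True)

print(unique[0] is unique[1])  # False: each component has its own block
print(shared[0] is shared[1])  # True: both components use the same block
```

Reusing a single object means that every component updates the same parameters during training, which is what distinguishes sharing from building two blocks with identical settings.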
During the forward pass of the model, the output of each message passing block is aggregated, and the aggregated outputs are then concatenated to form the input to the predictor.
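To make the order of operations concrete, here is a small pure-Python sketch for a single datapoint with two components: each component's per-atom vectors are pooled first, and only then are the pooled vectors joined. The numbers and the helper names mean_pool and concat_components are made up for illustration; mean pooling is used only because it matches the aggregation chosen later in this notebook.

```python
def mean_pool(atom_vectors):
    """Average a list of per-atom vectors into one molecule-level vector."""
    n_atoms = len(atom_vectors)
    return [sum(column) / n_atoms for column in zip(*atom_vectors)]


def concat_components(component_vectors):
    """Join one pooled vector per component into a single flat vector."""
    return [x for vector in component_vectors for x in vector]


# Hypothetical message passing outputs for one datapoint.
solute_atoms = [[1.0, 3.0], [3.0, 5.0]]  # 2 atoms, hidden size 2
solvent_atoms = [[0.0, 0.0, 6.0], [2.0, 4.0, 0.0], [4.0, 2.0, 0.0]]  # 3 atoms, hidden size 3

pooled = [mean_pool(solute_atoms), mean_pool(solvent_atoms)]
predictor_input = concat_components(pooled)

print(pooled)           # [[2.0, 4.0], [2.0, 2.0, 2.0]]
print(predictor_input)  # [2.0, 4.0, 2.0, 2.0, 2.0]
```

Because pooling happens before concatenation, the components may have different numbers of atoms and different hidden sizes; the predictor only ever sees one fixed-length vector per datapoint.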
Aggregation#
A single aggregation module is applied to the output of every message passing block.
[3]:
from chemprop.nn import MeanAggregation
agg = MeanAggregation()
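As a rough illustration of what mean aggregation computes over a batch, the sketch below averages atom vectors per molecule in plain Python. It assumes that each atom carries an integer index naming the molecule it belongs to; the function mean_aggregate and its arguments are hypothetical and do not reproduce Chemprop's tensor-based implementation.

```python
def mean_aggregate(H, batch):
    """Average the rows of H that share the same molecule index in batch."""
    sums = {}
    counts = {}
    for h, b in zip(H, batch):
        if b not in sums:
            sums[b] = [0.0] * len(h)
            counts[b] = 0
        sums[b] = [s + x for s, x in zip(sums[b], h)]
        counts[b] += 1
    # One averaged vector per molecule, ordered by molecule index.
    return [[s / counts[b] for s in sums[b]] for b in sorted(sums)]


# Three atoms in total: the first two belong to molecule 0, the last to molecule 1.
H = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]

print(mean_aggregate(H, batch))  # [[2.0, 3.0], [10.0, 20.0]]
```

The aggregation has no learned parameters in this sketch, which is one reason a single module can serve every component regardless of its hidden size.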
Predictor#
The predictor's input_dim must match the output dimension of the multicomponent message passing module, which is available as mcmp.output_dim.
[4]:
from chemprop.nn import RegressionFFN
ffn = RegressionFFN(input_dim=mcmp.output_dim)
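The arithmetic behind this dimension can be sketched as follows, assuming that the aggregated block outputs are simply concatenated and that each block's output width equals its hidden size d_h. The function multicomponent_output_dim is a hypothetical helper; the value 300 is the hidden size that appears in the model summary printed at the end of this notebook.

```python
def multicomponent_output_dim(block_dims):
    """Width of the concatenated vector, assuming plain concatenation."""
    return sum(block_dims)


unique_dim = multicomponent_output_dim([100, 600])  # mp1 and mp2 from the earlier cell
shared_dim = multicomponent_output_dim([300, 300])  # one d_h=300 block used for 2 components

print(unique_dim)  # 700
print(shared_dim)  # 600
```

The shared case gives 600, which agrees with the in_features=600 of the predictor's first Linear layer in the printed model summary.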
Multicomponent MPNN#
The submodules are composed together in a MulticomponentMPNN model.
[ ]:
mc_model = MulticomponentMPNN(mcmp, agg, ffn)
mc_model
MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): Identity()
      )
    )
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0-1): 2 x MSE(task_weights=[[1.0]])
  )
)
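The layer sizes in this summary can be cross-checked with a little arithmetic. The sketch below assumes that W_i acts on concatenated atom and bond features and that W_o acts on atom features concatenated with the hidden state. Under that assumption, the atom and bond feature sizes of 72 and 14 are inferred from the printed shapes; they are not stated in this notebook, and the variable names are illustrative only.

```python
d_h = 300         # hidden size, shown as out_features in the summary
n_components = 2  # two blocks in the ModuleList

# Assumed feature sizes, chosen so that the printed shapes are reproduced.
d_v = 72  # atom feature size (assumption)
d_e = 14  # bond feature size (assumption)

w_i_in = d_v + d_e           # input width of W_i
w_o_in = d_v + d_h           # input width of W_o
ffn_in = n_components * d_h  # input width of the predictor's first layer

print(w_i_in, w_o_in, ffn_in)  # 86 372 600
```

If a different featurizer or hidden size is used, the same three sums indicate which in_features values to expect in the printed model.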