# Constrained Atom and Bond Prediction
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
try:
import chemprop
except ImportError:
!git clone https://github.com/chemprop/chemprop.git
%cd chemprop
!pip install .
%cd examples
[2]:
import ast
from pathlib import Path
from lightning import pytorch as pl
import numpy as np
import pandas as pd
import torch
from chemprop import data, featurizers, models, nn
# Paths assume the current working directory is chemprop/examples.
chemprop_dir = Path.cwd().parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

# Weight initialization and the shuffled training loader are random; seed everything
# before any model is built so that reruns reproduce the outputs below.
SEED = 0
pl.seed_everything(SEED);
If any atom or bond property should sum to a known molecule-level value, we can constrain the atom or bond predictions for that property to sum to that value. For example, atomic partial charges should sum to the total charge of the molecule.
## Make datapoints
[3]:
# Per-atom and per-bond targets are stored in the CSV as stringified Python lists.
df_input = pd.read_csv(data_dir.joinpath("constrained_regression.csv"))
df_input
[3]:
| | smiles | mol_y | atom_y1 | atom_y2 | bond_y1 | bond_y2 |
|---|---|---|---|---|---|---|
| 0 | [H][H] | 0 | [0, 0] | [1.008, 1.008] | [2] | [2] |
| 1 | C | 0 | [0] | [12.011] | [] | [] |
| 2 | CN | 0 | [0, 0] | [12.011, 14.007] | [13] | [2] |
| 3 | CN | 0 | [0, 0] | [12.011, 14.007] | [13] | [2] |
| 4 | CC | 0 | [0, 0] | [12.011, 12.011] | [12] | [2] |
| 5 | [CH2:3]=[N+:1]([H:4])[H:2] | 1 | [1, 0, 0, 0] | [14.007, 1.008, 12.011, 1.008] | [13, 8, 8] | [4, 2, 2] |
| 6 | CCCC | 0 | [0, 0, 0, 0] | [12.011, 12.011, 12.011, 12.011] | [12, 12, 12] | [2, 2, 2] |
| 7 | CO | 0 | [0, 0] | [12.011, 15.999] | [14] | [2] |
| 8 | CC#N | 0 | [0, 0, 0] | [12.011, 12.011, 14.007] | [12, 13] | [2, 6] |
| 9 | C1NN1 | 0 | [0, 0, 0] | [12.011, 14.007, 14.007] | [13, 14, 13] | [2, 2, 2] |
| 10 | c1cc[n-]c1 | -1 | [0, 0, 0, -1, 0] | [12.011, 12.011, 12.011, 14.007, 12.011] | [12, 12, 13, 13, 12] | [3, 3, 3, 3, 3] |
[4]:
columns = ["smiles", "mol_y", "atom_y1", "atom_y2", "bond_y1", "bond_y2"]


def _parse_per_row_lists(rows):
    """Convert each row of stringified lists into a float array with one column per target.

    Each cell of a row is a list literal such as "[1.008, 1.008]". The cells are parsed
    safely with ``ast.literal_eval``, stacked as one row per target, and transposed so
    that the result has one row per atom (or bond) and one column per target.
    """
    parsed_rows = []
    for row in rows:
        per_target_values = [ast.literal_eval(cell) for cell in row]
        parsed_rows.append(np.array(per_target_values, dtype=float).T)
    return parsed_rows


smis = df_input[columns[0]].to_numpy()
mol_ys = df_input[columns[1:2]].to_numpy()
atoms_ys = _parse_per_row_lists(df_input[columns[2:4]].to_numpy())
bonds_ys = _parse_per_row_lists(df_input[columns[4:6]].to_numpy())
## Load constraints
Not all atom and bond predictions need to be constrained. Here, both atom targets (`atom_y1` and `atom_y2`) are constrained, while only one of the two bond targets (`bond_y2`) is constrained.
[5]:
# One row per molecule and one column per constrained target.
df_constraints = pd.read_csv(data_dir.joinpath("constrained_regression_constraints.csv"))
df_constraints
[5]:
| | atom_y1_constraint | atom_y2_constraint | bond_y2_constraint |
|---|---|---|---|
| 0 | 0 | 2.016 | 2 |
| 1 | 0 | 12.011 | 0 |
| 2 | 0 | 26.018 | 2 |
| 3 | 0 | 26.018 | 2 |
| 4 | 0 | 24.022 | 2 |
| 5 | 1 | 28.034 | 8 |
| 6 | 0 | 48.044 | 6 |
| 7 | 0 | 28.010 | 2 |
| 8 | 0 | 38.029 | 8 |
| 9 | 0 | 40.025 | 6 |
| 10 | -1 | 62.051 | 15 |
[6]:
n_mols = len(df_constraints)

# Maps each constrained atom/bond target column name to the position of its constraint
# column in df_constraints. Target columns missing from this dict are unconstrained.
constraints_cols_to_target_cols = {
    "atom_y1": 0,
    "atom_y2": 1,
    "bond_y2": 2,
}
# A misspelled key would otherwise silently leave its target unconstrained.
assert set(constraints_cols_to_target_cols) <= set(columns[2:]), "unknown target column"


def stack_constraints(constraints_df, constraint_cols):
    """Build a (n_rows, len(constraint_cols)) float array of per-molecule constraints.

    Parameters
    ----------
    constraints_df : pd.DataFrame
        One row per molecule, one column per constrained target.
    constraint_cols : list of (int or None)
        For each target column, the position of its constraint column in
        ``constraints_df``. ``None`` marks an unconstrained target, whose column is
        filled with ``np.nan``.
    """
    n_rows = len(constraints_df)
    return np.column_stack(
        [
            constraints_df.iloc[:, col].to_numpy(dtype=float)
            if col is not None
            else np.full(n_rows, np.nan)
            for col in constraint_cols
        ]
    )


# Target columns without constraints have their constraints set to np.nan
atom_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[2:4]]
atom_constraints = stack_constraints(df_constraints, atom_constraint_cols)

bond_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[4:6]]
bond_constraints = stack_constraints(df_constraints, bond_constraint_cols)

assert atom_constraints.shape == (n_mols, len(atom_constraint_cols))
assert bond_constraints.shape == (n_mols, len(bond_constraint_cols))
[7]:
# One datapoint per molecule, built with keep_h=True / add_h=False (use the H atoms
# written in the SMILES) and reorder_atoms=True.
datapoints = []
for i in range(len(smis)):
    datapoints.append(
        data.MolAtomBondDatapoint.from_smi(
            smis[i],
            keep_h=True,
            add_h=False,
            reorder_atoms=True,
            y=mol_ys[i],
            atom_y=atoms_ys[i],
            bond_y=bonds_ys[i],
            atom_constraint=atom_constraints[i],
            bond_constraint=bond_constraints[i],
        )
    )

# For this demo, every split is built from the same datapoints.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints) for _ in range(4)
)

# If the atom/bond targets are scaled, the corresponding constraints are also scaled automatically.
# Only the train and val atom targets are normalized; test/predict keep their raw values.
atom_target_scaler = train_dataset.normalize_targets("atom")
val_dataset.normalize_targets("atom", atom_target_scaler)
atom_target_transform = nn.UnscaleTransform.from_standard_scaler(atom_target_scaler)

# Only the training loader shuffles.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dset, shuffle=False)
    for dset in (val_dataset, test_dataset, predict_dataset)
)
## Set up model
[8]:
mp = nn.MABBondMessagePassing()
agg = nn.NormAggregation()

# One regression output per target column at each level. The atom head gets the
# UnscaleTransform built from the atom target scaler. The bond head's input width is
# twice the message-passing bond output width.
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])
atom_predictor = nn.RegressionFFN(
    n_tasks=atoms_ys[0].shape[1],
    output_transform=atom_target_transform,
)
bond_predictor = nn.RegressionFFN(
    n_tasks=bonds_ys[0].shape[1],
    input_dim=2 * mp.output_dims[1],
)
The atom/bond predictions for each constrained target are adjusted so that, within each molecule, they sum to the constraint. How much each individual prediction is adjusted is determined from the node/edge fingerprints by a separate feed-forward network.
[9]:
# One constrainer output per target column whose constraint is not NaN. The constraint
# arrays were built column by column, so every molecule shares the same NaN pattern and
# row 0 is representative.
atom_constrainer = nn.ConstrainerFFN(
    n_constraints=int((~np.isnan(atom_constraints[0])).sum()),
)
# Match bond_predictor's input width instead of hard-coding 600, so changing the
# message-passing size does not silently break the bond constrainer.
bond_constrainer = nn.ConstrainerFFN(
    n_constraints=int((~np.isnan(bond_constraints[0])).sum()),
    fp_dim=mp.output_dims[1] * 2,
)
[10]:
# Assemble the full model: the message-passing encoder and aggregation, one prediction
# head per level (molecule, atom, bond), and the atom/bond constrainers.
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    atom_constrainer=atom_constrainer,
    bond_predictor=bond_predictor,
    bond_constrainer=bond_constrainer,
)
[11]:
# Display the module tree to check the layer sizes, e.g. that the bond predictor and
# bond constrainer both take 600-dimensional inputs.
model
[11]:
MolAtomBondMPNN(
(message_passing): MABBondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_vo): Linear(in_features=372, out_features=300, bias=True)
(W_eo): Linear(in_features=314, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(E_d_transform): Identity()
(graph_transform): Identity()
)
(agg): NormAggregation()
(mol_predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): Identity()
)
(atom_predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=2, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0, 1.0]])
(output_transform): UnscaleTransform()
)
(atom_constrainer): ConstrainerFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=2, bias=True)
)
)
)
(bond_predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=2, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0, 1.0]])
(output_transform): Identity()
)
(bond_constrainer): ConstrainerFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
)
(bns): ModuleList(
(0-2): 3 x Identity()
)
(X_d_transform): Identity()
(metricss): ModuleList(
(0): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
(1-2): 2 x ModuleList(
(0): MSE(task_weights=[[1.0]])
(1): MSE(task_weights=[[1.0, 1.0]])
)
)
)
## The atom and bond predictions obey the constraints
[12]:
# Pull the first prediction batch and keep the graph, descriptor, and constraint inputs;
# the constraints are the last element of the batch.
batch = next(iter(predict_dataloader))
bmg, V_d, E_d, X_d, *_, constraints = batch
model_inputs = (bmg, V_d, E_d, X_d, constraints)

# NOTE(review): the model has not been switched to eval mode here; confirm whether output
# transforms such as UnscaleTransform behave differently in training mode.
with torch.no_grad():
    mol_preds, atom_preds_tensor, bond_preds_tensor = model(*model_inputs)
[13]:
# Regroup the flat per-atom predictions by molecule, sum each group, and compare the sums
# with the atom constraints. NaN entries (unconstrained targets) are excluded from the check.
# Assumes the first batch holds every molecule in predict_dataset (default batch size) — confirm.
atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]
atom_preds = torch.split(atom_preds_tensor, atoms_per_mol)
atom_pred_sums = torch.vstack([mol_atom_preds.sum(dim=0) for mol_atom_preds in atom_preds])
errors = predict_dataset.atom_constraints - atom_pred_sums.numpy()
print(errors)
is_constrained = ~np.isnan(errors)
assert np.isclose(errors[is_constrained], 0.0, atol=1e-5).all()
[[ 0.00000000e+00 -3.24249267e-08]
[ 0.00000000e+00 3.20434570e-07]
[ 7.45058060e-09 3.50952149e-07]
[ 7.45058060e-09 3.50952149e-07]
[ 0.00000000e+00 6.40869139e-07]
[ 0.00000000e+00 -3.96728517e-07]
[-1.49011612e-08 1.28173828e-06]
[-7.45058060e-09 -2.13623047e-06]
[ 7.45058060e-09 6.71386722e-07]
[ 0.00000000e+00 -1.52587891e-06]
[ 0.00000000e+00 -2.50244140e-06]]
[14]:
# Regroup the flat per-bond predictions by molecule, sum each group, and compare the sums
# with the bond constraints. NaN entries (unconstrained targets) are excluded from the check.
# Assumes the first batch holds every molecule in predict_dataset (default batch size) — confirm.
bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]
bond_preds = torch.split(bond_preds_tensor, bonds_per_mol)
bond_pred_sums = torch.vstack([mol_bond_preds.sum(dim=0) for mol_bond_preds in bond_preds])
errors = predict_dataset.bond_constraints - bond_pred_sums.numpy()
print(errors)
is_constrained = ~np.isnan(errors)
assert np.isclose(errors[is_constrained], 0.0, atol=1e-5).all()
[[ nan 0.00000000e+00]
[ nan 0.00000000e+00]
[ nan 1.19209290e-07]
[ nan 1.19209290e-07]
[ nan 0.00000000e+00]
[ nan 0.00000000e+00]
[ nan 0.00000000e+00]
[ nan 0.00000000e+00]
[ nan -1.90734863e-06]
[ nan 0.00000000e+00]
[ nan -1.90734863e-06]]
## Fit the model
[15]:
trainer = pl.Trainer(
    accelerator="auto",  # let Lightning pick the available hardware
    devices=1,
    max_epochs=20,
    enable_progress_bar=True,
    logger=False,  # no experiment logger for this short demo
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[16]:
# Train on train_dataloader, validating with val_dataloader.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
|   | Name             | Type                  | Params | Mode  | FLOPs |
|---|------------------|-----------------------|--------|-------|-------|
| 0 | message_passing  | MABBondMessagePassing | 322 K  | train | 0     |
| 1 | agg              | NormAggregation       | 0      | train | 0     |
| 2 | mol_predictor    | RegressionFFN         | 90.6 K | train | 0     |
| 3 | atom_predictor   | RegressionFFN         | 90.9 K | train | 0     |
| 4 | atom_constrainer | ConstrainerFFN        | 90.9 K | train | 0     |
| 5 | bond_predictor   | RegressionFFN         | 180 K  | train | 0     |
| 6 | bond_constrainer | ConstrainerFFN        | 180 K  | train | 0     |
| 7 | bns              | ModuleList            | 0      | train | 0     |
| 8 | X_d_transform    | Identity              | 0      | train | 0     |
| 9 | metricss         | ModuleList            | 0      | train | 0     |
Trainable params: 956 K Non-trainable params: 0 Total params: 956 K Total estimated model params size (MB): 3 Modules in train mode: 72 Modules in eval mode: 0 Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.
[17]:
# ckpt_path="best" evaluates the best checkpoint from the preceding fit. Lightning already
# falls back to it when no model is passed, but naming it explicitly silences the warning.
# weights_only=False is only required with pytorch lightning version 2.6.0 or newer
results = trainer.test(dataloaders=test_dataloader, ckpt_path="best", weights_only=False)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20-v2.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20-v2.ckpt
| Test metric   | DataLoader 0        |
|---------------|---------------------|
| atom_test/mse | 3.116415500640869   |
| bond_test/mse | 12.054455757141113  |
| mol_test/mse  | 0.16347543895244598 |
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[18]:
# trainer.predict yields one (mol, atom, bond) prediction tuple per batch; concatenate
# each kind across batches.
predss = trainer.predict(model, predict_dataloader)
mol_preds, atom_preds, bond_preds = map(torch.concat, zip(*predss))
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[19]:
# Per-molecule atom and bond counts, used to split the concatenated predictions back
# into one tensor per molecule.
atoms_per_mol, bonds_per_mol = [], []
for mol in predict_dataset.mols:
    atoms_per_mol.append(mol.GetNumAtoms())
    bonds_per_mol.append(mol.GetNumBonds())
atom_preds = torch.split(atom_preds, atoms_per_mol)
bond_preds = torch.split(bond_preds, bonds_per_mol)