Atom and Bond Prediction

Open In Colab

[1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples
[2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# Paths are resolved relative to the chemprop/examples directory this notebook is run from
chemprop_dir = Path.cwd().parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

# Seed Python, NumPy, and PyTorch so weight initialization, shuffling, and dropout are reproducible
pl.seed_everything(42)

This notebook shows how to use Chemprop to fit models on atom and bond property data. One model can predict molecule-, atom-, and bond-level properties at the same time.

Make datapoints

The atom and bond targets are saved as strings that look like lists, with missing values written as None. This example uses regression targets, but classification (including multiclass) is also supported.

[3]:
# Read the example data: one row per molecule, with list-valued atom and bond target columns
df_input = pd.read_csv(data_dir.joinpath("regression.csv"))
df_input
[3]:
smiles mol_y1 mol_y2 atom_y1 atom_y2 bond_y1 bond_y2 weight
0 [H][H] 2.016 2.0 [1, 1] [1.008, 1.008] [2] [-2] 0.090909
1 C 16.043 1.0 [6] [12.011] [] [] 0.181818
2 CN 31.058 2.0 [6, 7] [12.011, 14.007] [13] [-13] 0.272727
3 CN 31.058 NaN [6, 7] [None, 14.007] [13] [None] 0.363636
4 CC 30.070 2.0 [6, 6] [12.011, 12.011] [12] [-12] 0.454545
5 [CH2:3]=[N+:1]([H:4])[H:2] 30.050 4.0 [7, 1, 6, 1] [14.007, 1.008, 12.011, 1.008] [13, 8, 8] [-13, -8, -8] 0.545455
6 CCCC 58.124 4.0 [6, 6, 6, 6] [12.011, 12.011, 12.011, 12.011] [12, 12, 12] [-12, -12, -12] 0.636364
7 CO 32.042 2.0 [6, 8] [12.011, 15.999] [14] [-14] 0.727273
8 CC#N 41.053 3.0 [6, 6, 7] [12.011, 12.011, 14.007] [12, 13] [-12, -13] 0.818182
9 C1NN1 44.057 3.0 [6, 7, 7] [12.011, 14.007, 14.007] [13, 14, 13] [-13, -14, -13] 0.909091
10 c1cc[n-]c1 66.083 5.0 [6, 6, 6, 7, 6] [12.011, 12.011, 12.011, 14.007, 12.011] [12, 12, 13, 13, 12] [-12, -12, -13, -13, -12] 1.000000

Load optional extra features and descriptors

Extra bond descriptors can be used when making bond property predictions, analogous to extra atom descriptors.

[4]:
# Optional extra inputs stored in .npz archives. Each archive is opened in a `with` block so its
# file handle is closed as soon as the arrays have been read into memory.
with np.load(data_dir / "descriptors.npz") as npz:
    x_ds = npz["arr_0"]  # extra molecule-level descriptors, one row per molecule

# One array per molecule, stored as arr_0, arr_1, ... in molecule order
with np.load(data_dir / "atom_features_descriptors.npz") as npz:
    V_fs = [npz[f"arr_{i}"] for i in range(len(npz))]
# The same atom arrays are reused as both extra atom features (V_f) and atom descriptors (V_d)
V_ds = V_fs

with np.load(data_dir / "bond_features_descriptors.npz") as npz:
    E_fs = [npz[f"arr_{i}"] for i in range(len(npz))]
# Each bond row is repeated twice consecutively; presumably the bond descriptors are expected
# per directed edge (two per bond) — TODO confirm against MABBondMessagePassing
E_ds = [np.repeat(E_f, repeats=2, axis=0) for E_f in E_fs]
[5]:
columns = ["smiles", "mol_y1", "mol_y2", "atom_y1", "atom_y2", "bond_y1", "bond_y2", "weight"]


def _literal_lists_to_arrays(rows):
    """Convert rows of stringified lists into one float array per row.

    Each row holds one string per task (e.g. "[1.008, None]"). The parsed lists are stacked
    and transposed, so every array has shape (n_items, n_tasks); None entries become NaN.
    """
    arrays = []
    for row in rows:
        parsed = [ast.literal_eval(cell) for cell in row]
        arrays.append(np.array(parsed, dtype=float).T)
    return arrays


smis = df_input.loc[:, columns[0]].to_numpy()
mol_ys = df_input.loc[:, columns[1:3]].to_numpy()
weights = df_input.loc[:, columns[7]].to_numpy()
# Atom and bond targets are strings that look like lists; parse them with ast.literal_eval
atoms_ys = _literal_lists_to_arrays(df_input.loc[:, columns[3:5]].to_numpy())
bonds_ys = _literal_lists_to_arrays(df_input.loc[:, columns[5:7]].to_numpy())

datapoints = []
for i, smi in enumerate(smis):
    datapoints.append(
        data.MolAtomBondDatapoint.from_smi(
            smi,
            # Hydrogens written in the SMILES (e.g. [H][H]) are kept as atoms; none are added
            keep_h=True,
            add_h=False,
            # The atom targets in this file follow the atom-map numbers in the SMILES
            # (e.g. in [CH2:3]=[N+:1]([H:4])[H:2] the N comes first), so the atoms are
            # reordered by map number. Use reorder_atoms=False when the targets follow the
            # order in which the atoms are written.
            reorder_atoms=True,
            y=mol_ys[i],
            atom_y=atoms_ys[i],
            bond_y=bonds_ys[i],
            weight=weights[i],
            # Optional extra inputs loaded in the previous cell
            x_d=x_ds[i],
            V_f=V_fs[i],
            V_d=V_ds[i],
            E_f=E_fs[i],
            E_d=E_ds[i],
        )
    )

If the regression targets are bounded (e.g. they look like “<3” or “>0.1”), parsing the atom and bond targets is a bit more complicated; set bounded = True in the cell below to use that path. Note that BoundedMSE should be used as the loss function (RegressionFFN(criterion=nn.BoundedMSE())) and the less-than and greater-than masks should be given to the datapoints.

[6]:
bounded = False
if bounded:
    mol_ys = mol_ys.astype(str)
    lt_mask = mol_ys.map(lambda x: "<" in x).to_numpy()
    gt_mask = mol_ys.map(lambda x: ">" in x).to_numpy()
    mol_ys = mol_ys.map(lambda x: x.strip("<").strip(">")).to_numpy(np.single)

    atoms_ys = atoms_ys.map(ast.literal_eval)
    atom_lt_masks = atoms_ys.map(lambda L: ["<" in v if v else False for v in L])
    atom_gt_masks = atoms_ys.map(lambda L: [">" in v if v else False for v in L])

    atom_lt_masks = atom_lt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    atom_gt_masks = atom_gt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    atoms_ys = atoms_ys.map(
        lambda L: np.array([v.strip("<").strip(">") if v else "nan" for v in L], dtype=np.single)
    )
    atoms_ys = atoms_ys.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    bonds_ys = bonds_ys.map(ast.literal_eval)
    bond_lt_masks = bonds_ys.map(lambda L: ["<" in v if v else False for v in L])
    bond_gt_masks = bonds_ys.map(lambda L: [">" in v if v else False for v in L])

    bond_lt_masks = bond_lt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    bond_gt_masks = bond_gt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    bond_lt_masks = [bond_lt_mask.astype(bool) for bond_lt_mask in bond_lt_masks]
    bond_gt_masks = [bond_gt_mask.astype(bool) for bond_gt_mask in bond_gt_masks]

    bonds_ys = bonds_ys.map(
        lambda L: np.array([v.strip("<").strip(">") if v else "nan" for v in L], dtype=np.single)
    )
    bonds_ys = bonds_ys.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    datapoints = [
        data.MolAtomBondDatapoint.from_smi(
            smi,
            keep_h=True,
            add_h=False,
            reorder_atoms=True,
            y=mol_ys[i],
            atom_y=atoms_ys[i],
            bond_y=bonds_ys[i],
            weight=weights[i],
            lt_mask=lt_mask[i],
            gt_mask=gt_mask[i],
            atom_lt_mask=atom_lt_masks[i],
            atom_gt_mask=atom_gt_masks[i],
            bond_lt_mask=bond_lt_masks[i],
            bond_gt_mask=bond_gt_masks[i],
        )
        for i, smi in enumerate(smis)
    ]

Make datasets

[7]:
# Reserve room in the atom/bond feature vectors for the extra per-atom and per-bond features
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(
    extra_atom_fdim=V_fs[0].shape[1],
    extra_bond_fdim=E_fs[0].shape[1],
)

# For brevity every split reuses the full set of datapoints; in practice, split the data
# into train/validation/test sets before building the datasets.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

Scale the extra features and descriptors

If extra features and descriptors are used, they can be scaled to make training easier. The scalers are turned into “transforms” which are given to the model to use at inference time.

[8]:
# Fit a scaler on the training set for each extra input, then apply the same scaler to validation
input_keys = ("V_f", "E_f", "V_d", "E_d")
V_f_scaler, E_f_scaler, V_d_scaler, E_d_scaler = [
    train_dataset.normalize_inputs(key) for key in input_keys
]
for key, scaler in zip(input_keys, (V_f_scaler, E_f_scaler, V_d_scaler, E_d_scaler)):
    val_dataset.normalize_inputs(key, scaler)

# The graph feature transforms are padded by the number of non-extra feature columns
# (full featurizer dim minus extra dim); presumably those leading columns stay unscaled
V_f_transform = nn.ScaleTransform.from_standard_scaler(
    V_f_scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
)
E_f_transform = nn.ScaleTransform.from_standard_scaler(
    E_f_scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
)
graph_transform = nn.GraphTransform(V_f_transform, E_f_transform)
V_d_transform, E_d_transform = [
    nn.ScaleTransform.from_standard_scaler(desc_scaler) for desc_scaler in (V_d_scaler, E_d_scaler)
]

# Molecule-level extra descriptors get the same fit-on-train, apply-to-validation treatment
X_d_scaler = train_dataset.normalize_inputs("X_d")
val_dataset.normalize_inputs("X_d", X_d_scaler)
X_d_transform = nn.ScaleTransform.from_standard_scaler(X_d_scaler)

Scale the regression targets

[9]:
# Scale each level of targets using statistics from the training set only
target_levels = ("mol", "atom", "bond")
mol_target_scaler, atom_target_scaler, bond_target_scaler = [
    train_dataset.normalize_targets(level) for level in target_levels
]
for level, scaler in zip(
    target_levels, (mol_target_scaler, atom_target_scaler, bond_target_scaler)
):
    val_dataset.normalize_targets(level, scaler)

# Given to the predictors as output_transform (presumably to undo the scaling at inference)
mol_target_transform, atom_target_transform, bond_target_transform = [
    nn.UnscaleTransform.from_standard_scaler(target_scaler)
    for target_scaler in (mol_target_scaler, atom_target_scaler, bond_target_scaler)
]

Make dataloaders

[10]:
batch_size = 4

# Only the training data is shuffled; evaluation and prediction keep dataset order
train_dataloader = data.build_dataloader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False, batch_size=batch_size)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

The MAB (molecule, atom, bond) message passing returns both learned node embeddings and learned edge embeddings

MABBondMessagePassing takes the same customization arguments as the usual BondMessagePassing class.

[11]:
mp = nn.MABBondMessagePassing(
    # Input sizes: featurized atoms/bonds plus the extra atom/bond descriptors
    d_v=featurizer.atom_fdim,
    d_e=featurizer.bond_fdim,
    d_vd=V_ds[0].shape[1],
    d_ed=E_ds[0].shape[1],
    # Architecture
    d_h=100,
    depth=4,
    dropout=0.1,
    activation="tanh",
    # Scaling of the extra inputs, applied by the model at inference time
    graph_transform=graph_transform,
    V_d_transform=V_d_transform,
    E_d_transform=E_d_transform,
)

A separate predictor is used for each of the molecule, atom, and bond predictions

[12]:
agg = nn.MeanAggregation()

# Each predictor has its own input width: the molecule predictor also receives the extra
# molecule descriptors, and the bond predictor receives twice the edge-embedding width
atom_embedding_dim = mp.output_dims[0]
edge_embedding_dim = mp.output_dims[1]

mol_predictor = nn.RegressionFFN(
    n_tasks=mol_ys.shape[1],
    input_dim=atom_embedding_dim + x_ds.shape[1],
    output_transform=mol_target_transform,
)
atom_predictor = nn.RegressionFFN(
    n_tasks=atoms_ys[0].shape[1],
    input_dim=atom_embedding_dim,
    output_transform=atom_target_transform,
)
bond_predictor = nn.RegressionFFN(
    n_tasks=bonds_ys[0].shape[1],
    input_dim=edge_embedding_dim * 2,
    output_transform=bond_target_transform,
)

Different predictors can be used for different types of tasks including but not limited to MveFFN, BinaryClassificationFFN, MulticlassClassificationFFN.

Combine the layers into a single model

[13]:
metrics = [nn.MAE(), nn.RMSE()]
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    # One predictor per level of targets; a level without targets simply omits its predictor
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    bond_predictor=bond_predictor,
    X_d_transform=X_d_transform,
    batch_norm=True,
    metrics=metrics,
)
[14]:
# Display the assembled model's module tree
model
[14]:
MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=102, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=204, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bns): ModuleList(
    (0-2): 3 x BatchNorm1d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (X_d_transform): ScaleTransform()
  (metricss): ModuleList(
    (0-2): 3 x ModuleList(
      (0): MAE(task_weights=[[1.0]])
      (1): RMSE(task_weights=[[1.0]])
      (2): MSE(task_weights=[[1.0, 1.0]])
    )
  )
)

If any of the molecule, atom, or bond targets are not used, the corresponding predictor isn’t added to the model. If bond targets are not used, the message passing layer should be told to not return the bond embeddings to avoid initializing weight matrices that won’t be used. If molecule targets are not used, the aggregation layer isn’t added to the model. If both molecule and atom targets are not used, the message passing layer should be told not to return the node embeddings.

[15]:
# Set one of these flags to True to see how to build a model that skips a level of targets
no_bond = False  # molecule and atom targets only
no_mol = False  # atom and bond targets only
no_mol_atom = False  # bond targets only

if no_bond:
    # No bond targets: don't return edge embeddings, so their unused weight matrices aren't created
    mp = nn.MABBondMessagePassing(return_edge_embeddings=False)
    agg = nn.NormAggregation()
    mol_predictor, atom_predictor = nn.RegressionFFN(), nn.RegressionFFN()
    model = models.MolAtomBondMPNN(
        message_passing=mp,
        agg=agg,
        mol_predictor=mol_predictor,
        atom_predictor=atom_predictor,
    )

if no_mol:
    # No molecule targets: no aggregation layer and no molecule predictor
    mp = nn.MABBondMessagePassing()
    atom_predictor = nn.RegressionFFN()
    bond_predictor = nn.RegressionFFN(input_dim=2 * mp.output_dims[1])
    model = models.MolAtomBondMPNN(
        message_passing=mp,
        atom_predictor=atom_predictor,
        bond_predictor=bond_predictor,
    )

if no_mol_atom:
    # Bond targets only: the node embeddings aren't needed either
    mp = nn.MABBondMessagePassing(return_vertex_embeddings=False)
    bond_predictor = nn.RegressionFFN(input_dim=2 * mp.output_dims[1])
    model = models.MolAtomBondMPNN(message_passing=mp, bond_predictor=bond_predictor)

Set up trainer with checkpointing

[16]:
# Keep the checkpoint with the lowest validation loss, plus the most recent one as last.ckpt
checkpointing = ModelCheckpoint(
    dirpath="MABcheckpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    callbacks=[checkpointing],
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=False,  # no experiment logger for this small example
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[17]:
# Train, validating after each epoch; the checkpoint callback tracks the best val_loss
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/MABcheckpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃    Name             Type                   Params  Mode   FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ MABBondMessagePassing │ 69.2 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation       │      0 │ train │     0 │
│ 2 │ mol_predictor   │ RegressionFFN         │ 32.1 K │ train │     0 │
│ 3 │ atom_predictor  │ RegressionFFN         │ 31.5 K │ train │     0 │
│ 4 │ bond_predictor  │ RegressionFFN         │ 62.1 K │ train │     0 │
│ 5 │ bns             │ ModuleList            │    612 │ train │     0 │
│ 6 │ X_d_transform   │ ScaleTransform        │      0 │ train │     0 │
│ 7 │ metricss        │ ModuleList            │      0 │ train │     0 │
└───┴─────────────────┴───────────────────────┴────────┴───────┴───────┘
Trainable params: 195 K
Non-trainable params: 0
Total params: 195 K
Total estimated model params size (MB): 0
Modules in train mode: 63
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connec
tor.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the
value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=20` reached.
[18]:
# Evaluate the best checkpoint from the fit above on the test set. ckpt_path="best" is what
# Lightning already uses when no model is passed; stating it explicitly silences its warning.
# weights_only=False is only required with PyTorch Lightning 2.6.0 or newer.
results = trainer.test(dataloaders=test_dataloader, ckpt_path="best", weights_only=False)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at /home/knathan/chemprop/examples/MABcheckpoints/best-epoch=17-val_loss=0.13.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/MABcheckpoints/best-epoch=17-val_loss=0.13.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       atom_test/mae           0.40444326400756836    │
│      atom_test/rmse           0.8327591419219971     │
│       bond_test/mae           0.5269983410835266     │
│      bond_test/rmse            1.082360863685608     │
│       mol_test/mae             3.099299907684326     │
│       mol_test/rmse            5.333563804626465     │
└───────────────────────────┴───────────────────────────┘
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[19]:
predss = trainer.predict(model, predict_dataloader)
# Each batch yields (mol, atom, bond) predictions: regroup them by level and concatenate over batches
mol_preds, atom_preds, bond_preds = map(torch.concat, zip(*predss))
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

Split the atom and bond predictions into a list of tensors, one for each molecule

[20]:
# Atom and bond counts per molecule, in dataset order, used to cut the flat prediction tensors
atoms_per_mol = []
bonds_per_mol = []
for mol in predict_dataset.mols:
    atoms_per_mol.append(mol.GetNumAtoms())
    bonds_per_mol.append(mol.GetNumBonds())

atom_preds = torch.split(atom_preds, atoms_per_mol)
bond_preds = torch.split(bond_preds, bonds_per_mol)

Save and load the model

[21]:
# Round-trip the trained model through a file to check that it can be reloaded
model_path = "temp.pt"
models.utils.save_model(model_path, model)
models.MolAtomBondMPNN.load_from_file(model_path)
[21]:
MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=102, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=204, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bns): ModuleList(
    (0-2): 3 x BatchNorm1d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (X_d_transform): ScaleTransform()
  (metricss): ModuleList(
    (0-2): 3 x ModuleList(
      (0): MAE(task_weights=[[1.0]])
      (1): RMSE(task_weights=[[1.0]])
      (2): MSE(task_weights=[[1.0, 1.0]])
    )
  )
)
[22]:
# Reload the model from last.ckpt, which the ModelCheckpoint callback writes because save_last=True
models.MolAtomBondMPNN.load_from_checkpoint("MABcheckpoints/last.ckpt")
[22]:
MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=102, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=204, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): UnscaleTransform()
  )
  (bns): ModuleList(
    (0-2): 3 x BatchNorm1d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (X_d_transform): ScaleTransform()
  (metricss): ModuleList(
    (0-2): 3 x ModuleList(
      (0): MAE(task_weights=[[1.0]])
      (1): RMSE(task_weights=[[1.0]])
      (2): MSE(task_weights=[[1.0, 1.0]])
    )
  )
)