CheMeleon Foundation Finetuning
This notebook demonstrates how to use the CheMeleon foundation model with Chemprop to achieve accurate predictions on small datasets. The same functionality is also available from the command line interface by passing --from-foundation chemeleon.
[1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples
Retrieving the CheMeleon Model
The CheMeleon model file is stored on Zenodo (https://zenodo.org/records/15460715). Please cite the Zenodo record if you use this model in published work. You can download the file manually for your own use, or simply run the cell below to download it programmatically with Python:
[2]:
from urllib.request import urlretrieve
urlretrieve(
r"https://zenodo.org/records/15460715/files/chemeleon_mp.pt",
"chemeleon_mp.pt",
)
[2]:
('chemeleon_mp.pt', <http.client.HTTPMessage at 0x75ba07670b10>)
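If you rerun the notebook often, a small guard avoids downloading the checkpoint again each time. The sketch below uses only the standard library; fetch_if_missing is a hypothetical helper (not part of Chemprop), and it writes to a temporary .part file first so that an interrupted download does not leave a truncated chemeleon_mp.pt behind:

```python
from pathlib import Path
from urllib.request import urlretrieve

CHEMELEON_URL = "https://zenodo.org/records/15460715/files/chemeleon_mp.pt"


def fetch_if_missing(url: str, dest: str) -> Path:
    """Hypothetical helper: download `url` to `dest` unless a non-empty file is already there."""
    path = Path(dest)
    if path.exists() and path.stat().st_size > 0:
        return path  # reuse the cached copy instead of downloading again
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + ".part")
    urlretrieve(url, tmp)  # write to a temporary name first...
    tmp.replace(path)      # ...then rename, so a failed download never looks complete
    return path
```

In the notebook it could stand in for the urlretrieve call above, e.g. fetch_if_missing(CHEMELEON_URL, "chemeleon_mp.pt").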
Initializing CheMeleon
CheMeleon uses the following classes for featurization, message passing, and aggregation:
[3]:
import torch
from chemprop import featurizers, nn
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
agg = nn.MeanAggregation()
chemeleon_mp = torch.load("chemeleon_mp.pt", weights_only=True)
mp = nn.BondMessagePassing(**chemeleon_mp['hyper_parameters'])
mp.load_state_dict(chemeleon_mp['state_dict'])
[3]:
<All keys matched successfully>
If you have existing Chemprop training code, you can simply replace your agg, featurizer, and mp with these objects to immediately take advantage of CheMeleon!
In general, we suggest continuing to train the CheMeleon weights during finetuning. In some cases, however, freezing the weights (mp.eval(), mp.apply(lambda module: module.requires_grad_(False))) may improve performance.
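To see what freezing actually does, the plain PyTorch sketch below applies the same two calls to a small stand-in network and counts the parameters that still require gradients. The encoder and head modules are hypothetical placeholders for mp and the task FFN, chosen only to keep the example small:

```python
import torch
from torch import nn


def count_trainable(module: nn.Module) -> int:
    """Number of parameters that will receive gradient updates."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Hypothetical stand-ins: `encoder` plays the role of the CheMeleon `mp`,
# `head` the role of the task-specific FFN.
encoder = nn.Sequential(nn.Linear(86, 32, bias=False), nn.ReLU(), nn.Linear(32, 32))
head = nn.Linear(32, 1)
model = nn.Sequential(encoder, head)

before = count_trainable(model)

# Freeze the encoder with the same calls suggested above
encoder.eval()
encoder.apply(lambda module: module.requires_grad_(False))

after = count_trainable(model)  # only the head's weight and bias remain trainable

# A backward pass now produces gradients for the head only
x = torch.randn(4, 86)
loss = model(x).pow(2).mean()
loss.backward()
print(before, after)
```

Since nn.Module.requires_grad_ already recurses over the parameters of all submodules, mp.requires_grad_(False) on its own has the same effect as the apply call.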
Standard Chemprop Preparation
The code below imports the needed modules, sets up the data, and initializes the Chemprop model. It is mostly the same as the training example provided in the Chemprop repository; for a more detailed breakdown, see that example in the Chemprop documentation.
The one important change is that we must set input_dim=mp.output_dim when initializing the FFN. This ensures that the dimension of the learned representation from CheMeleon matches the input size of the regressor. Note also that the CheMeleon checkpoint only provides the message passing weights: to make it useful, you set up your own FFN to regress the target you care about, in this case lipophilicity.
You can also use CheMeleon for classification tasks. See the regular classification demo notebook, which can be modified as shown above to load CheMeleon.
[4]:
from pathlib import Path
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
from chemprop import data, models
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['lipo'] # list of names of the columns containing targets
df_input = pd.read_csv(input_path)
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # each is a list of index lists, one per replicate (hence the [0] below)
train_data, val_data, test_data = data.split_data_by_indices(
all_data, train_indices, val_indices, test_indices
)
train_dset = data.MoleculeDataset(train_data[0], featurizer)
scaler = train_dset.normalize_targets()
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data[0], featurizer)
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
ffn = nn.RegressionFFN(output_transform=output_transform, input_dim=mp.output_dim)
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]
mpnn = models.MPNN(mp, agg, ffn, batch_norm=False, metrics=metric_list)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
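The scaler returned by normalize_targets is fit on the training targets only and then reused for the validation set, while UnscaleTransform maps the model's standardized outputs back to the original units. A standard-library sketch of that round trip, assuming a simple mean/standard-deviation scaler (the helper names and toy values are hypothetical, not Chemprop's implementation):

```python
from statistics import fmean, pstdev


def fit_scaler(train_targets):
    """Estimate mean and (population) standard deviation from the training targets only."""
    return fmean(train_targets), pstdev(train_targets)


def scale(values, mean, std):
    """Standardize targets: z = (y - mean) / std."""
    return [(v - mean) / std for v in values]


def unscale(preds, mean, std):
    """Map standardized predictions back to original units: y = z * std + mean."""
    return [z * std + mean for z in preds]


# Made-up lipophilicity-like values for illustration
train_y = [1.2, 3.4, 0.5, 2.9]
val_y = [2.0, 1.1]

mean, std = fit_scaler(train_y)
train_z = scale(train_y, mean, std)
val_z = scale(val_y, mean, std)         # reuse the training statistics; never refit on validation data
recovered = unscale(val_z, mean, std)   # what an unscaling output transform does to predictions
```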
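Printing the model shows the large message passing block from CheMeleon (2048-dimensional hidden layers) feeding a newly initialized regression FFN. Note that the metrics list also contains MSE, the training criterion, in addition to the RMSE and MAE we requested: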
[5]:
mpnn
[5]:
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=2048, bias=False)
    (W_h): Linear(in_features=2048, out_features=2048, bias=False)
    (W_o): Linear(in_features=2120, out_features=2048, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=2048, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): RMSE(task_weights=[[1.0]])
    (1): MAE(task_weights=[[1.0]])
    (2): MSE(task_weights=[[1.0]])
  )
)
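The layer shapes in this printout can be checked by hand. Assuming Chemprop's BondMessagePassing sizes its layers as W_i: d_v + d_e → d_h, W_h: d_h → d_h and W_o: d_v + d_h → d_h (our reading of the library, worth confirming in its source), the atom and bond feature sizes produced by SimpleMoleculeMolGraphFeaturizer fall out of the printed numbers, and the encoder's output size (what mp.output_dim should report) equals the 2048 inputs of the FFN's first Linear layer:

```python
# (in_features, out_features) copied from the printed model
W_i = (86, 2048)
W_h = (2048, 2048)
W_o = (2120, 2048)
ffn_first_layer = (2048, 300)

# Assumed layout: W_i: d_v + d_e -> d_h, W_h: d_h -> d_h, W_o: d_v + d_h -> d_h
d_h = W_h[1]          # hidden size of the message passing
d_v = W_o[0] - d_h    # atom feature size
d_e = W_i[0] - d_v    # bond feature size
output_dim = W_o[1]   # size of the learned representation (assumed to be mp.output_dim)

# This is the match that input_dim=mp.output_dim guarantees
assert output_dim == ffn_first_layer[0], "FFN input_dim must match mp.output_dim"
print(d_v, d_e, d_h)  # 72 14 2048
```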
Training
The remainder of this notebook again follows the typical training routine. Because CheMeleon adds a large number of parameters, each epoch may take longer to complete, but the model will (hopefully!) perform better, particularly on small datasets, and require fewer epochs to converge!
[6]:
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=2,  # Set to 2 for demonstration. Adjust as needed for actual training.
    callbacks=[checkpointing],  # Use the configured checkpoint callback
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
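With this configuration, Lightning fills the filename template with the monitored values and, by default (auto_insert_metric_name=True), prefixes each value with its name; that is why the best checkpoint in the test log below is named best-epoch=1-val_loss=0.84.ckpt, alongside last.ckpt from save_last=True. The function below is a hypothetical re-creation of that substitution for illustration, not Lightning's own code:

```python
import re


def format_checkpoint_name(template: str, metrics: dict) -> str:
    """Hypothetical sketch: turn "best-{epoch}-{val_loss:.2f}" into a checkpoint filename."""

    def fill(match: re.Match) -> str:
        name, spec = match.group(1), match.group(2) or ""
        # Prefix each field with "name=" and apply the optional format spec
        return f"{name}={format(metrics[name], spec)}"

    return re.sub(r"\{(\w+)(?::([^}]*))?\}", fill, template) + ".ckpt"


print(format_checkpoint_name("best-{epoch}-{val_loss:.2f}", {"epoch": 1, "val_loss": 0.8412}))
```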
[7]:
trainer.fit(mpnn, train_loader, val_loader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type               ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  8.7 M │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ Identity           │      0 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │  615 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 9.3 M
Non-trainable params: 0
Total params: 9.3 M
Total estimated model params size (MB): 37
Modules in train mode: 25
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=2` reached.
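The parameter counts in the summary table follow directly from the layer shapes in the printed model. A quick standard-library check, assuming each Linear layer holds in_features × out_features weights plus out_features biases when bias=True, and 4 bytes per float32 parameter for the size estimate:

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


message_passing = (
    linear_params(86, 2048, bias=False)      # W_i
    + linear_params(2048, 2048, bias=False)  # W_h
    + linear_params(2120, 2048)              # W_o
)
predictor = linear_params(2048, 300) + linear_params(300, 1)  # the two FFN Linear layers
total = message_passing + predictor
size_mb = total * 4 / 1e6  # float32: 4 bytes per parameter

print(message_passing, predictor, total, round(size_mb, 1))  # 8714240 615001 9329241 37.3
```

Roughly 93% of the parameters sit in the pretrained CheMeleon encoder, which is why each epoch is slower than with a default-sized Chemprop model.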
[8]:
results = trainer.test(dataloaders=test_loader, weights_only=False) # weights_only=False is only required with pytorch lightning version 2.6.0 or newer
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at /home/knathan/chemprop/examples/checkpoints/best-epoch=1-val_loss=0.84.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/checkpoints/best-epoch=1-val_loss=0.84.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test/mae          │    0.6668714284896851     │
│         test/rmse         │    0.9082838892936707     │
└───────────────────────────┴───────────────────────────┘
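For reference, the two reported test metrics reduce to the following formulas for a single task. This is a plain-Python sketch that ignores the target masks and task weights that Chemprop's metric classes also handle; the function names and numbers are illustrative only:

```python
from math import sqrt


def mae(preds, targets):
    """Mean absolute error: average of |prediction - target|."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def rmse(preds, targets):
    """Root-mean-square error: square root of the mean squared error."""
    return sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets))


preds = [1.0, 2.0, 3.0]
targets = [1.0, 2.0, 5.0]
print(mae(preds, targets), rmse(preds, targets))
```

RMSE is never smaller than MAE on the same predictions; because it weights large errors more heavily, a wider gap between the two (here 0.91 versus 0.67) points to more variation in the size of the individual errors.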