Transfer Learning / Pretraining
Transfer learning (or pretraining) reuses knowledge from a model pre-trained on a related task to improve performance on a new task. In Chemprop, we can initialize a new model from a pre-trained model checkpoint and freeze components of the new model during training, as demonstrated in this notebook.
[1]:
# Install chemprop from GitHub if running in Google Colab
import os
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples
Import packages
[2]:
import pandas as pd
from pathlib import Path
from lightning import pytorch as pl
from sklearn.preprocessing import StandardScaler
import torch
from chemprop import data, featurizers, models
Change data inputs here
[3]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['lipo'] # list of names of the columns containing targets
Load data
[4]:
df_input = pd.read_csv(input_path)
df_input
[4]:
|  | smiles | lipo |
|---|---|---|
| 0 | Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | 3.54 |
| 1 | COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | -1.18 |
| 2 | COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | 3.69 |
| 3 | OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | 3.37 |
| 4 | Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | 3.10 |
| ... | ... | ... |
| 95 | CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | 2.20 |
| 96 | CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | 2.04 |
| 97 | CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | 4.49 |
| 98 | COc1ccc(Cc2c(N)n[nH]c2N)cc1 | 0.20 |
| 99 | CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | 2.00 |
100 rows × 2 columns
Get SMILES and targets
[5]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
[6]:
smis[:5] # show first 5 SMILES strings
[6]:
array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
dtype=object)
[7]:
ys[:5] # show first 5 targets
[7]:
array([[ 3.54],
[-1.18],
[ 3.69],
[ 3.37],
[ 3.1 ]])
Get molecule datapoints
[8]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
Perform data splitting for training, validation, and testing
[9]:
# available split types
list(data.SplitType.keys())
[9]:
['SCAFFOLD_BALANCED',
'RANDOM_WITH_REPEATED_SMILES',
'RANDOM',
'KENNARD_STONE',
'KMEANS']
[10]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
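As we understand the v2.1 change mentioned in the warning, `make_split_indices` returns one list of indices per replicate (see `help(make_split_indices)`), which is why the dataset cells below index `train_data[0]`, `val_data[0]`, and `test_data[0]`. The minimal sketch below illustrates that nested shape for a purely random split with NumPy; `random_split_indices` is a hypothetical helper, not Chemprop's implementation.

```python
import numpy as np


def random_split_indices(n, sizes=(0.8, 0.1, 0.1), seed=0, num_replicates=1):
    """Hypothetical helper: random train/val/test split, one split per replicate.

    It only mirrors the nested *shape* of the v2.1 return value
    (one list of indices per replicate); it is not Chemprop's code.
    """
    if not np.isclose(sum(sizes), 1.0):
        raise ValueError("sizes must sum to 1")
    train, val, test = [], [], []
    for rep in range(num_replicates):
        rng = np.random.default_rng(seed + rep)  # a different shuffle per replicate
        perm = rng.permutation(n)
        n_train = int(round(sizes[0] * n))
        n_val = int(round(sizes[1] * n))
        train.append(perm[:n_train].tolist())
        val.append(perm[n_train:n_train + n_val].tolist())
        test.append(perm[n_train + n_val:].tolist())  # remainder goes to test
    return train, val, test


train_idx, val_idx, test_idx = random_split_indices(100, (0.8, 0.1, 0.1))
# Outer length = number of replicates; inner lists hold the indices
print(len(train_idx), len(train_idx[0]), len(val_idx[0]), len(test_idx[0]))  # 1 80 10 10
```

With `num_replicates=1` there is exactly one split, so taking element `[0]` recovers the single train, validation, and test set.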
Change checkpoint model inputs here
Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file.
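Initializing from a checkpoint amounts to copying saved parameter tensors into a model with the same architecture. The sketch below shows that idea in plain PyTorch with a hypothetical `make_model` architecture and a temporary file; it is only an illustration of the idea, not Chemprop's `load_from_file`.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical stand-in architecture; pre-trained and new models must match.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))


torch.manual_seed(0)
pretrained = make_model()  # pretend this model was trained on a related task

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "pretrained.pt")
    torch.save({"state_dict": pretrained.state_dict()}, ckpt)  # save weights only

    new_model = make_model()  # freshly (randomly) initialized
    state = torch.load(ckpt)["state_dict"]
    new_model.load_state_dict(state)  # overwrite random weights with pre-trained ones

same = all(
    torch.equal(a, b)
    for a, b in zip(pretrained.state_dict().values(), new_model.state_dict().values())
)
print(same)  # True
```

`load_state_dict` is strict by default, so a mismatch in layer names or shapes raises an error instead of silently skipping weights.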
[11]:
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir / "tests" / "data" / "example_model_v2_regression_mol.ckpt" # path to the checkpoint file.
# If the checkpoint file was generated with the training notebook, it will be in the `checkpoints` folder with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
[12]:
mpnn_cls = models.MPNN
[13]:
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn
[13]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
Scale fine-tuning data with the model's target scaler
If the pre-trained model is a regression model, it was most likely trained on scaled targets. The target scaler is saved as part of the model (as `predictor.output_transform`) and is used to unscale predictions. To train further, we must scale the fine-tuning targets with the same scaler.
[14]:
pretraining_scaler = StandardScaler()
pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()
pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()
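The cell above builds a `StandardScaler` without calling `fit`, by assigning the pre-trained statistics directly to `mean_` and `scale_`; `transform` then computes z = (y - mean) / scale and `inverse_transform` undoes it. The sketch below checks this with made-up statistics (the `mean` and `scale` values are hypothetical, not the ones stored in the checkpoint) and the first three `lipo` values from the table above.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical statistics standing in for output_transform.mean / .scale
mean = np.array([2.5])
scale = np.array([1.2])

scaler = StandardScaler()
scaler.mean_ = mean    # set the "fitted" attributes directly instead of calling fit()
scaler.scale_ = scale

y_new = np.array([[3.54], [-1.18], [3.69]])  # fine-tuning targets, shape (n_samples, n_tasks)
y_scaled = scaler.transform(y_new)           # (y - mean) / scale, computed on a copy
y_back = scaler.inverse_transform(y_scaled)  # y_scaled * scale + mean

print(y_scaled.ravel())
print(y_back.ravel())
```

Reusing the pre-trained statistics, rather than refitting on the new data, keeps the new targets in the same units as the outputs the pre-trained predictor learned to produce.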
Get MoleculeDataset
[15]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data[0], featurizer)
train_dset.normalize_targets(pretraining_scaler)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(pretraining_scaler)
test_dset = data.MoleculeDataset(test_data[0], featurizer)
Get DataLoader
[16]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
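Freezing MPNN and FFN layers
Selected layers of a pre-trained model can be kept unchanged while training on a new task. Freezing a layer disables gradients for its parameters (`requires_grad_(False)`), so the optimizer never updates them. Layers that behave differently in training mode, such as batch normalization (running statistics) and dropout, should also be switched to eval mode.
Freezing the MPNN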
[17]:
mpnn.message_passing.apply(lambda module: module.requires_grad_(False))
mpnn.message_passing.eval()
mpnn.bn.apply(lambda module: module.requires_grad_(False))
mpnn.bn.eval() # Set batch norm layers to eval mode to freeze running mean and running var.
[17]:
BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
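`requires_grad_(False)` alone does not fully freeze batch normalization: its running mean and variance are buffers updated during training-mode forward passes, not by the optimizer, which is why the cell above also calls `eval()`. The sketch below checks both effects on hypothetical stand-in layers (`encoder`, `bn`, `head`) in plain PyTorch rather than on the Chemprop model.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-ins: a frozen encoder + batch norm, and a trainable head
encoder = nn.Linear(4, 3)
bn = nn.BatchNorm1d(3)
head = nn.Linear(3, 1)

for module in (encoder, bn):
    module.requires_grad_(False)  # no gradients, so the optimizer never changes these weights
    module.eval()                 # BatchNorm uses (and stops updating) its running statistics

running_mean_before = bn.running_mean.clone()
x, y = torch.randn(16, 4), torch.randn(16, 1)
opt = torch.optim.SGD([p for p in head.parameters() if p.requires_grad], lr=0.1)

loss = nn.functional.mse_loss(head(bn(encoder(x))), y)
loss.backward()  # gradients flow only into `head`
opt.step()

print(encoder.weight.grad is None)                        # True
print(torch.equal(bn.running_mean, running_mean_before))  # True
```

If `bn.eval()` were omitted, the forward pass would still shift `running_mean` toward the new data even though no weight receives a gradient.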
Freezing FFN layers
[18]:
frzn_ffn_layers = 1 # the number of consecutive FFN layers to freeze.
[19]:
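for idx in range(frzn_ffn_layers):
    mpnn.predictor.ffn[idx].requires_grad_(False)  # freeze the weights of FFN block `idx`
    mpnn.predictor.ffn[idx + 1].eval()  # the next block starts with the dropout applied to block `idx`'s output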
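The printed `RegressionFFN` stores its layers as two blocks: block 0 holds the first `Linear(300, 300)`, and block 1 holds ReLU, dropout, and the output `Linear(300, 1)`. Freezing one FFN layer therefore leaves only the output layer trainable. The sketch below rebuilds that layout with plain `nn.Sequential` modules (a hypothetical stand-in, not Chemprop's `MLP` class) and counts parameters; the 301 trainable parameters match the `Trainable params: 301` line in the training log.

```python
from torch import nn

# Hypothetical stand-in with the same layout as the printed RegressionFFN.ffn:
# block 0 = Linear(300, 300); block 1 = ReLU -> Dropout -> Linear(300, 1)
ffn = nn.Sequential(
    nn.Sequential(nn.Linear(300, 300)),
    nn.Sequential(nn.ReLU(), nn.Dropout(0.0), nn.Linear(300, 1)),
)

frzn_ffn_layers = 1
for idx in range(frzn_ffn_layers):
    ffn[idx].requires_grad_(False)  # freeze the Linear in block `idx`
    ffn[idx + 1].eval()             # block `idx + 1` holds the dropout on block `idx`'s output

trainable = sum(p.numel() for p in ffn.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in ffn.parameters() if not p.requires_grad)
print(trainable, frozen)  # 301 90300
```

With this indexing scheme, `frzn_ffn_layers` must be smaller than the number of blocks, since the loop also touches block `idx + 1`.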
Set up trainer
[20]:
trainer = pl.Trainer(
logger=False,
enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
enable_progress_bar=True,
accelerator="auto",
devices=1,
max_epochs=20, # number of epochs to train for
)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Start training
[21]:
trainer.fit(mpnn, train_loader, val_loader)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
|  | Name | Type | Params | Mode | FLOPs |
|---|---|---|---|---|---|
| 0 | message_passing | BondMessagePassing | 227 K | eval | 0 |
| 1 | agg | MeanAggregation | 0 | train | 0 |
| 2 | bn | BatchNorm1d | 600 | eval | 0 |
| 3 | predictor | RegressionFFN | 90.6 K | train | 0 |
| 4 | X_d_transform | Identity | 0 | train | 0 |
| 5 | metrics | ModuleList | 0 | train | 0 |
Trainable params: 301
Non-trainable params: 318 K
Total params: 318 K
Total estimated model params size (MB): 1
Modules in train mode: 11
Modules in eval mode: 15
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:534: Found 15 module(s) in eval mode at the start of training. This may lead to unexpected behavior during training. If this is intentional, you can ignore this warning.
`Trainer.fit` stopped: `max_epochs=20` reached.
Test results
[22]:
results = trainer.test(mpnn, test_loader, weights_only=False)  # weights_only=False is only required with PyTorch Lightning 2.6.0 or newer
| Test metric | DataLoader 0 |
|---|---|
| test/mse | 0.9767233729362488 |
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
Transfer learning with multicomponent models
Multicomponent MPNN models have an individual message-passing block for each molecule in a single input. These blocks can be frozen independently for transfer learning.
Change data inputs here
[23]:
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir / "tests" / "data" / "example_model_v2_regression_mol+mol.ckpt" # path to the checkpoint file.
Change checkpoint model inputs here
[24]:
mpnn_cls = models.MulticomponentMPNN
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
mcmpnn
[24]:
MulticomponentMPNN(
(message_passing): MulticomponentMessagePassing(
(blocks): ModuleList(
(0-1): 2 x BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
(V_d_transform): Identity()
(graph_transform): GraphTransform(
(V_transform): Identity()
(E_transform): Identity()
)
)
)
)
(agg): MeanAggregation()
(bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=600, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSE(task_weights=[[1.0]])
(output_transform): UnscaleTransform()
)
(X_d_transform): Identity()
(metrics): ModuleList(
(0-1): 2 x MSE(task_weights=[[1.0]])
)
)
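The printout shows `BatchNorm1d(600)` and a first FFN layer with `in_features=600` because the multicomponent model concatenates the 300-dimensional aggregated fingerprint of each component along the feature dimension. The sketch below shows that shape arithmetic with random tensors standing in for the two components' fingerprints (hypothetical data, batch size 4).

```python
import torch

torch.manual_seed(0)
batch_size, hidden = 4, 300
# Hypothetical per-component fingerprints (one per message-passing block, after aggregation)
fingerprints = [torch.randn(batch_size, hidden) for _ in range(2)]

H = torch.cat(fingerprints, dim=1)      # concatenate features: 300 + 300 = 600
bn = torch.nn.BatchNorm1d(H.shape[1])   # matches BatchNorm1d(600) in the printout
out = bn(H)
print(H.shape)  # torch.Size([4, 600])
```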
[25]:
blocks_to_freeze = [0, 1] # a list of indices of the individual MPNN blocks to freeze before training.
[26]:
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
for i in blocks_to_freeze:
    mp_block = mcmpnn.message_passing.blocks[i]
    mp_block.apply(lambda module: module.requires_grad_(False))
    mp_block.eval()
mcmpnn.bn.apply(lambda module: module.requires_grad_(False))
mcmpnn.bn.eval()
[26]:
BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
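Because each component has its own block in `message_passing.blocks`, `blocks_to_freeze` can also list a subset, for example only the first component. The sketch below applies the same loop to a hypothetical `nn.ModuleList` of two `nn.Linear` layers standing in for the message-passing blocks and checks which one remains trainable.

```python
from torch import nn

# Hypothetical stand-in for message_passing.blocks: one encoder per component
blocks = nn.ModuleList([nn.Linear(86, 300), nn.Linear(86, 300)])

blocks_to_freeze = [0]  # freeze only the first component's block
for i in blocks_to_freeze:
    blocks[i].requires_grad_(False)
    blocks[i].eval()

# True means every parameter in that block is frozen
status = [all(not p.requires_grad for p in b.parameters()) for b in blocks]
print(status)  # [True, False]
```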