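Multitask model

This example trains a single Chemprop MPNN that predicts all 12 regression targets in `mol_multitask.csv` at once.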
[1]:
# Install chemprop from GitHub if running in Google Colab.
# Outside Colab (COLAB_RELEASE_TAG unset) this cell does nothing. The rest of the
# notebook then expects chemprop to be importable and the working directory to be
# the repository's `examples` folder.
import os
import subprocess
import sys

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        # check=True stops the notebook right here if the clone or the install
        # fails, instead of failing later with a confusing import/path error.
        subprocess.run(["git", "clone", "https://github.com/chemprop/chemprop.git"], check=True)
        os.chdir("chemprop")
        # `python -m pip` with the kernel's own interpreter installs into the
        # environment this notebook is actually running in.
        subprocess.run([sys.executable, "-m", "pip", "install", "."], check=True)
        os.chdir("examples")
[2]:
from lightning import pytorch as pl
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from chemprop import data, models, nn
Step 1: Make datapoints
[3]:
# Locate the multitask regression CSV that ships with chemprop's test data.
# NOTE: assumes the notebook is run from the repository's `examples` directory.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol_multitask.csv")

# One SMILES column and twelve regression target columns.
smiles_column = "smiles"
target_columns = [
    "mu", "alpha", "homo", "lumo", "gap", "r2",
    "zpve", "cv", "u0", "u298", "h298", "g298",
]

df_input = pd.read_csv(input_path)
smis = df_input[smiles_column].to_numpy()
ys = df_input[target_columns].to_numpy()

# Each row becomes one MoleculeDatapoint holding its SMILES and its target vector.
datapoints = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
Step 2: Split data and make datasets
[4]:
# Split with make_split_indices' default settings. Since chemprop v2.1 the return
# type changed (see the message printed below). Each split returned by
# split_data_by_indices is a sequence, and [0] takes its first entry
# (presumably the first replicate — confirm with help(make_split_indices)).
split_indices = data.make_split_indices(datapoints)
train_data, val_data, test_data = data.split_data_by_indices(datapoints, *split_indices)
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0]) for split in (train_data, val_data, test_data)
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
Step 3: Scale targets and make dataloaders
[5]:
# Standardize targets using statistics fitted on the training set only, and apply
# that same scaler to the validation set. The test set is not normalized here;
# the model's output transform (Step 4) maps predictions back to original units.
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)

# Only the training loader should shuffle. build_dataloader shuffles unless told
# otherwise, so disable it for val/test. Evaluation order is then stable, and test
# predictions line up row-for-row with test_dset.
train_loader = data.build_dataloader(train_dset)
val_loader = data.build_dataloader(val_dset, shuffle=False)
test_loader = data.build_dataloader(test_dset, shuffle=False)
Step 4: Define the model
[6]:
# Build the unscaling transform from the training-set scaler so the model's outputs
# can be reported in the original target units (presumably applied at prediction
# time only — confirm in the UnscaleTransform docs).
output_transform = nn.transforms.UnscaleTransform.from_standard_scaler(output_scaler)

# One regression output per target column.
ffn = nn.RegressionFFN(
    n_tasks=len(target_columns),
    output_transform=output_transform,
)

# Message passing -> aggregation -> multitask regression head.
chemprop_model = models.MPNN(
    nn.BondMessagePassing(),
    nn.MeanAggregation(),
    ffn,
)
Step 5: Set up the trainer
[7]:
# One epoch keeps the example quick; logging and checkpointing are disabled.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    max_epochs=1,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Step 6: Train the model
[8]:
# Train on the training loader, validating on the validation loader each epoch.
trainer.fit(
    chemprop_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type               ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  227 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ Identity           │      0 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │ 93.9 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 321 K
Non-trainable params: 0
Total params: 321 K
Total estimated model params size (MB): 1
Modules in train mode: 24
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
Step 7: Use the model to make predictions
[9]:
# trainer.predict returns a list with one prediction tensor per test batch, each
# with one row per molecule and one column per task (see the shape shown below).
# Stack the batches along the row dimension (dim=0). Concatenating along dim=1
# only works by accident when the whole test set fits in a single batch.
preds = trainer.predict(chemprop_model, test_loader)
preds = torch.cat(preds, dim=0)

# Sanity check: one column per target.
assert preds.shape[1] == len(target_columns)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[10]:
# Expected: (number of test molecules, number of target columns)
preds.size()
[10]:
torch.Size([51, 12])