Multitask model

Multitask model

Open In Colab

[1]:
# Install chemprop from GitHub if running in Google Colab
import os

# The presence of COLAB_RELEASE_TAG is used as the signal that this notebook is
# running on Google Colab; locally the variable is unset and nothing happens.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        # Skip the install entirely when chemprop is already importable.
        import chemprop
    except ImportError:
        # Clone and install chemprop from source, then cd into examples/ so that
        # the relative data path used below (Path.cwd().parent / "tests" / ...)
        # resolves to the cloned repository.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples
[2]:
from lightning import pytorch as pl
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from chemprop import data, models, nn

Step 1: Make datapoints

[3]:
# The example CSV ships with the chemprop repo; the notebook runs from examples/,
# so the repo root is the parent of the current working directory.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol_multitask.csv")
smiles_column = "smiles"
# Twelve regression targets, all predicted jointly by one multitask model.
target_columns = [
    "mu", "alpha", "homo", "lumo", "gap", "r2",
    "zpve", "cv", "u0", "u298", "h298", "g298",
]

df_input = pd.read_csv(input_path)
smis = df_input[smiles_column].to_numpy()
ys = df_input[target_columns].to_numpy()

# One datapoint per CSV row: a SMILES string paired with its row of target values.
datapoints = [
    data.MoleculeDatapoint.from_smi(smi, y)
    for smi, y in zip(smis, ys)
]

Step 2: Split data and make datasets

[4]:
# Since chemprop v2.1, make_split_indices returns, for each split, a sequence of
# index lists (see help(make_split_indices)); split_data_by_indices mirrors that.
split_indices = data.make_split_indices(datapoints)
train_data, val_data, test_data = data.split_data_by_indices(datapoints, *split_indices)

# Each split is indexable; build datasets from the first entry of each
# (presumably the first split replicate — see help(make_split_indices)).
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0]) for split in (train_data, val_data, test_data)
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)

Step 3: Scale targets and make dataloaders

[5]:
# Fit the target scaler on the training set only and reuse those statistics for the
# validation set, so no information from held-out data leaks into scaling.
# Test targets are intentionally left unscaled: the model's UnscaleTransform
# (built from output_scaler in Step 4) maps predictions back to the original units.
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)

# Only the training data should be shuffled. Validation/test loaders keep a fixed
# order so evaluation is deterministic and predictions line up row-for-row with
# test_dset (Lightning also warns when val/predict samplers shuffle).
train_loader = data.build_dataloader(train_dset)
val_loader = data.build_dataloader(val_dset, shuffle=False)
test_loader = data.build_dataloader(test_dset, shuffle=False)

Step 4: Define the model

[6]:
# Map the network's outputs back to the original target scale using the
# training-set scaler fitted in Step 3.
output_transform = nn.transforms.UnscaleTransform.from_standard_scaler(output_scaler)

# One output head per target column; construction order is kept as
# FFN -> message passing -> aggregation.
ffn = nn.RegressionFFN(n_tasks=len(target_columns), output_transform=output_transform)
message_passing = nn.BondMessagePassing()
aggregation = nn.MeanAggregation()
chemprop_model = models.MPNN(message_passing, aggregation, ffn)

Step 5: Set up the trainer

[7]:
# Lightweight demo configuration: no logger, no checkpoint files, a single epoch.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    max_epochs=1,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Step 6: Train the model

[8]:
# Train on the (shuffled) training loader and evaluate on the validation loader.
trainer.fit(
    chemprop_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃    Name             Type                Params  Mode   FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  227 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ Identity           │      0 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │ 93.9 K │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 321 K
Non-trainable params: 0
Total params: 321 K
Total estimated model params size (MB): 1
Modules in train mode: 24
Modules in eval mode: 0
Total FLOPs: 0
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.

Step 7: Use the model to make predictions

[9]:
# trainer.predict returns a list with one prediction tensor per batch, each shaped
# (batch_size, n_tasks). Stack them along dim 0 (molecules) so the result is
# (n_test_molecules, n_tasks) regardless of how many batches the test set spans.
# (Concatenating along dim 1 only works when the whole test set fits in one batch.)
preds = trainer.predict(chemprop_model, test_loader)
preds = torch.cat(preds, dim=0)
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:485: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[10]:
# Sanity check: one row per test molecule, one column per target column.
assert preds.shape == (len(test_dset), len(target_columns)), preds.shape
preds.shape
[10]:
torch.Size([51, 12])