Training with NLogProbEnrichment

This notebook demonstrates how to use the loss function described in Lim et al. (2022) JCIM on Poisson-distributed (or negative-binomial-distributed) count data, e.g. DNA-encoded library screening data.

Notable differences in how this type of model is set up compared to typical Chemprop training:

  • this loss function requires two target columns, “positive” and “negative”. Both must hold count (int) values

  • do not use scaling, as the loss function takes the raw counts

  • the output transform for the FFN must be set to Softplus

  • the NLogProbEnrichment metric must be used
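
Since the loss expects raw integer counts in exactly two target columns, it can help to check the targets before building any datapoints. The sketch below is one way to do that; `validate_count_targets` is a hypothetical helper written for this notebook, not part of Chemprop.

```python
from numbers import Integral


def validate_count_targets(rows):
    """Check that every row holds exactly two non-negative integer counts.

    `rows` is an iterable of (positive_count, negative_count) pairs.
    Hypothetical helper for illustration; not part of Chemprop.
    """
    for i, row in enumerate(rows):
        if len(row) != 2:
            raise ValueError(f"row {i}: expected 2 targets, got {len(row)}")
        for value in row:
            # bool counts as an Integral in Python, so reject it explicitly
            if isinstance(value, bool) or not isinstance(value, Integral):
                raise ValueError(f"row {i}: {value!r} is not an integer count")
            if value < 0:
                raise ValueError(f"row {i}: count {value} is negative")
    return True


ok = validate_count_targets([(8, 5), (3, 3), (11, 1)])
print(ok)
```

A float such as `1.5` or a negative value raises a `ValueError`, which is usually a sign that the targets were scaled or otherwise transformed upstream.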

This notebook is adapted from the regular training demo notebook, which may be helpful as a reference.

[1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

Initial Setup#

We’ll follow the typical procedure for importing the necessary packages and defining some overall settings related to the data.

[2]:
from pathlib import Path

import torch
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

from chemprop import data, featurizers, models, nn
[3]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv"
num_workers = 0
smiles_column = 'smiles'

The first big difference between NLogProbEnrichment training and conventional Chemprop training is that there must be exactly two target columns: the count of positive samples and the count of negative samples. These columns aren’t actually in the demo dataset, so we’ll randomly generate them for demonstration purposes.

[4]:
target_columns = ['counts_pos', 'counts_neg']
[5]:
df_input = pd.read_csv(input_path)


# creating some random count data for the NLogProbEnrichment metric
df_input['counts_pos'] = np.random.poisson(lam=6, size=df_input.shape[0])
df_input['counts_neg'] = np.random.poisson(lam=4, size=df_input.shape[0])

total_counts_pos = int(df_input['counts_pos'].sum())  # total number of positive samples
total_counts_neg = int(df_input['counts_neg'].sum())  # total number of negative samples

print(f"Total counts (positive samples): {total_counts_pos}")
print(f"Total counts (negative samples): {total_counts_neg}")

df_input
Total counts (positive samples): 624
Total counts (negative samples): 417
[5]:
    smiles                                              lipo  counts_pos  counts_neg
0   Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14             3.54           8           5
1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...  -1.18           3           3
2   COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl              3.69           8           2
3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...   3.37           6           3
4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...   3.10           6           8
..  ...                                                  ...         ...         ...
95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...   2.20           3           2
96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...   2.04           4           3
97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...   4.49           6           9
98  COc1ccc(Cc2c(N)n[nH]c2N)cc1                         0.20           3           1
99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...   2.00          11           1

100 rows × 4 columns
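
To give a feel for what the two count columns and their totals encode, the sketch below computes a naive enrichment ratio: each count is normalized by its sample total, and the two rates are divided. This is only an illustration of the quantities involved, not the computation performed by NLogProbEnrichment. The function name and the pseudocount are assumptions of this sketch. Because the counts here are random, the resulting ratios carry no real signal.

```python
def naive_enrichment(k_pos, k_neg, n_pos, n_neg, pseudocount=0.5):
    """Ratio of normalized count rates: (k_pos / n_pos) / (k_neg / n_neg).

    The pseudocount (an assumption of this sketch) avoids division by zero
    when a compound has no counts in the negative sample.
    """
    rate_pos = (k_pos + pseudocount) / n_pos
    rate_neg = (k_neg + pseudocount) / n_neg
    return rate_pos / rate_neg


# Totals as printed above; counts taken from rows 0 and 99 of the table
n_pos, n_neg = 624, 417
for k_pos, k_neg in [(8, 5), (11, 1)]:
    ratio = naive_enrichment(k_pos, k_neg, n_pos, n_neg)
    print(f"k_pos={k_pos}, k_neg={k_neg}, naive enrichment={ratio:.3f}")
```

Normalizing by the totals matters because the positive and negative samples generally have different sequencing depths, which is why the totals are passed to the loss function later on.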

[6]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

Training Preparation#

Now we follow the typical procedure to set up our data and neural network, with just a few small changes to facilitate the NLogProbEnrichment loss.

[7]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
[8]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
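
For readers unfamiliar with what a random 80/10/10 split does, here is a minimal standard-library sketch. `random_split_indices` is a hypothetical helper and not Chemprop's implementation; in particular, it ignores the molecules entirely and returns flat lists, which may differ from the return type mentioned in the message above.

```python
import random


def random_split_indices(n, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle the indices 0..n-1 and cut them into train/val/test lists.

    Hypothetical helper for illustration; not Chemprop's make_split_indices.
    """
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = int(sizes[0] * n)
    n_val = int(sizes[1] * n)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = random_split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))
```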

When creating the MoleculeDataset objects, one would often rescale the target variables like this:

scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)

We do NOT do this here, since the loss function operates on the counts directly.
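
To see why scaling is incompatible with count targets, the sketch below applies standard (z-score) scaling to the ten counts_pos values visible in the table preview. Standard scaling is used here as a stand-in for what target normalization typically does; that is an assumption of this sketch. The scaled values are no longer integers and some become negative, so they can no longer be interpreted as counts.

```python
import statistics

# The ten counts_pos values visible in the table preview above
counts = [8, 3, 8, 6, 6, 3, 4, 6, 3, 11]

mean = statistics.fmean(counts)
std = statistics.pstdev(counts)
scaled = [(c - mean) / std for c in counts]  # zero mean, unit variance

has_negative = any(s < 0 for s in scaled)
all_integer = all(float(s).is_integer() for s in scaled)
print(f"negative values after scaling: {has_negative}")
print(f"still integer-valued: {all_integer}")
```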

[9]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
test_dset = data.MoleculeDataset(test_data[0], featurizer)
[10]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
[11]:
mp = nn.BondMessagePassing()
[12]:
agg = nn.MeanAggregation()

The output of our FFN must go through the Softplus activation function, which we will use as our output_transform.

[13]:
output_transform = torch.nn.Softplus()
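
As a reminder of what this transform does, here is a plain-Python sketch of the softplus function, log(1 + exp(beta * x)) / beta. It maps any real input to a strictly positive output. The `beta` and `threshold` defaults mirror the values shown in the model printout further down; the switch to a linear function above the threshold is for numerical stability. This is a scalar illustration, not the PyTorch implementation.

```python
import math


def softplus(x, beta=1.0, threshold=20.0):
    """Scalar softplus: log(1 + exp(beta * x)) / beta, linear above threshold."""
    if beta * x > threshold:
        # exp() would be huge here and softplus(x) is ~x, so return x directly
        return x
    return math.log1p(math.exp(beta * x)) / beta


inputs = (-10.0, -1.0, 0.0, 1.0, 30.0)
values = [softplus(x) for x in inputs]
for x, v in zip(inputs, values):
    print(f"softplus({x}) = {v:.6f}")
```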

We’ll then initialize the actual loss function:

[14]:
criterion = nn.metrics.NLogProbEnrichment(
    n1=total_counts_pos,
    n2=total_counts_neg,
    method="sqrt",
    zscale=1.0,
    zinterval=5.0,
)

And finally build the FFN (note that we pass our Softplus as the output_transform; training with the NLogProbEnrichment loss requires it):

[15]:
ffn = nn.predictors.RegressionFFN(criterion=criterion, output_transform=output_transform)
[16]:
batch_norm = True
[17]:
metric_list = [nn.metrics.NLogProbEnrichment(n1=total_counts_pos, n2=total_counts_neg)]
[18]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn
[18]:
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): NLogProbEnrichment(n1=624, n2=417, method='sqrt', zscale=1.0, zinterval=5)
    (output_transform): Softplus(beta=1.0, threshold=20.0)
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): NLogProbEnrichment(n1=624, n2=417, method='sqrt', zscale=1.0, zinterval=5.0)
    (1): NLogProbEnrichment(n1=624, n2=417, method='sqrt', zscale=1.0, zinterval=5)
  )
)

Training and Inference#

[19]:
checkpointing = ModelCheckpoint(
    "checkpoints",
    "best-{epoch}-{val_loss:.2f}",
    "val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=[checkpointing],
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[20]:
trainer.fit(mpnn, train_loader, val_loader)
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /home/jackson/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | train
3 | predictor       | RegressionFFN      | 90.6 K | train
4 | X_d_transform   | Identity           | 0      | train
5 | metrics         | ModuleList         | 0      | train
---------------------------------------------------------------
318 K     Trainable params
0         Non-trainable params
318 K     Total params
1.276     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 32.95it/s, train_loss_step=0.449, val_loss=0.632, train_loss_epoch=0.282]
`Trainer.fit` stopped: `max_epochs=20` reached.
Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 18.67it/s, train_loss_step=0.449, val_loss=0.632, train_loss_epoch=0.282]
[ ]:
results = trainer.test(dataloaders=test_loader, weights_only=False)  # weights_only=False is only required for Lightning 2.6+
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at /home/jackson/chemprop/examples/checkpoints/best-epoch=19-val_loss=0.63.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jackson/chemprop/examples/checkpoints/best-epoch=19-val_loss=0.63.ckpt
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 262.47it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test/nlogprob_enrichment      0.5011885762214661     │
└───────────────────────────┴───────────────────────────┘