Encoding fingerprint latent representation#
Import packages#
[1]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA
from pathlib import Path
from chemprop import data, featurizers, models
Change model input here#
[2]:
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir / "tests/data/example_model_v2_regression_mol.ckpt" # path to the checkpoint file.
# If the checkpoint file is generated using the training notebook,
# it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`.
Load model#
[3]:
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
mpnn.eval()
mpnn
[3]:
MPNN(
(message_passing): BondMessagePassing(
(W_i): Linear(in_features=86, out_features=300, bias=False)
(W_h): Linear(in_features=300, out_features=300, bias=False)
(W_o): Linear(in_features=372, out_features=300, bias=True)
(W_d): Linear(in_features=300, out_features=300, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(tau): ReLU()
)
(agg): MeanAggregation()
(bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(predictor): RegressionFFN(
(ffn): MLP(
(0): Sequential(
(0): Linear(in_features=300, out_features=300, bias=True)
)
(1): Sequential(
(0): ReLU()
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=300, out_features=1, bias=True)
)
)
(criterion): MSELoss()
)
)
Change data input here#
[4]:
test_path = '../tests/data/smis.csv'
smiles_column = 'smiles'
Load data#
[5]:
df_test = pd.read_csv(test_path)
smis = df_test[smiles_column]
test_data = [data.MoleculeDatapoint.from_smi(smi) for smi in smis]
test_data[:5]
[5]:
[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x2879f03c0>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14', V_f=None, E_f=None, V_d=None),
MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x2879f04a0>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23', V_f=None, E_f=None, V_d=None),
MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x2879f0580>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl', V_f=None, E_f=None, V_d=None),
MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x2879f0660>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3', V_f=None, E_f=None, V_d=None),
MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x2879f0740>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1', V_f=None, E_f=None, V_d=None)]
Get featurizer#
[6]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
Get datasets#
[7]:
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)
test_loader = data.build_dataloader(test_dset, shuffle=False)
Calculate fingerprints#
models.MPNN.encoding(inputs : BatchMolGraph, i : int)
calculate the i-th hidden representation.
i
ia the stop index of slice of the MLP used to encode the input. That is, use all layers in the MLP up to :attr:i
(i.e., MLP[:i]
). This can be any integer value, and the behavior of this function is dependent on the underlying list slicing behavior. For example:
i=0
: use a 0-layer MLP (i.e., a no-op)i=1
: use only the first blocki=-1
: use up to the second-to-last block
[8]:
with torch.no_grad():
fingerprints = [
mpnn.encoding(batch.bmg, batch.V_d, batch.X_d, i=0)
for batch in test_loader
]
fingerprints = torch.cat(fingerprints, 0)
fingerprints.shape
[8]:
torch.Size([100, 300])
[9]:
with torch.no_grad():
encodings = [
mpnn.encoding(batch.bmg, batch.V_d, batch.X_d, i=1)
for batch in test_loader
]
encodings = torch.cat(encodings, 0)
encodings.shape
[9]:
torch.Size([100, 300])
Using fingerprints#
[10]:
fingerprints = fingerprints.detach()
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(fingerprints)
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(1, 1, 1)
ax.set_title("Fingerprints")
ax.set_xlabel('PCA1'); ax.set_ylabel('PCA2')
ax.scatter(principalComponents[:, 0], principalComponents[:, 1])
plt.show()
[11]:
encodings = encodings.detach()
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(encodings)
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(1, 1, 1)
ax.set_title("Encodings")
ax.set_xlabel('PCA1'); ax.set_ylabel('PCA2')
ax.scatter(principalComponents[:, 0], principalComponents[:, 1])
plt.show()