Running interpretation

[15]:
# MCTS options
rollout = 10  # number of MCTS rollouts to perform (passed to mcts as n_rollout). If mol.GetNumAtoms() > 50, consider setting rollout = 1 to avoid long computation time

c_puct = 10.0  # constant that controls the level of exploration

max_atoms = 20  # maximum number of atoms allowed in an extracted rationale

min_atoms = 8  # minimum number of atoms in an extracted rationale

prop_delta = 0.5  # minimum score to count as positive
# In this algorithm, if the predicted property of the substructure is larger than prop_delta, the substructure is considered satisfactory.
# The same threshold is also applied to the whole molecule: the search is only run for molecules whose own score exceeds prop_delta.
# This value depends on the property you want to interpret. 0.5 is a dummy value for demonstration purposes.

num_rationales_to_keep = 5  # number of rationales to keep for each molecule
[16]:
# Scoring setup. The "score" of a substructure is the model's predicted value
# of the chosen property for that substructure.

property_for_interpretation = "lipo"

models = [mpnn]

# Index of the chosen property in the dataset: its column position in df_test,
# reduced by one to discount the SMILES column.
# NOTE(review): this presumes exactly one non-target column (SMILES) precedes
# the target columns in df_test -- confirm against how df_test was loaded.
property_id = df_test.columns.get_loc(property_for_interpretation) - 1


def scoring_function(smiles: list[str]) -> list[float]:
    """Score SMILES strings by the predicted value of the property of interest.

    Runs ``make_prediction`` with the notebook-level ``models`` and ``trainer``
    and keeps only column ``property_id`` of the result.

    Args:
        smiles: SMILES strings to score.

    Returns:
        The ``property_id`` column of the prediction output.
        NOTE(review): the ``[:, property_id]`` indexing implies a 2-D
        array-like rather than a plain ``list`` -- confirm the annotation.
    """
    predictions = make_prediction(models=models, trainer=trainer, smiles=smiles)
    return predictions[:, property_id]
[17]:
# Only use the first 5 SMILES for demonstration purposes. The column is sliced
# before conversion so that a Python list is built for just those rows rather
# than for the entire test set.
all_smiles = df_test[smiles_column].iloc[:5].tolist()
# Bare last expression so the notebook displays the selected SMILES.
all_smiles
[17]:
['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
 'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
 'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
 'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
 'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1']
[18]:
%%time

# Column-oriented accumulator: one list per output column. Each molecule
# appends exactly one value to every list, so all columns stay aligned.
results_df = {"smiles": [], property_for_interpretation: []}
for rank in range(num_rationales_to_keep):
    results_df[f"rationale_{rank}"] = []
    results_df[f"rationale_{rank}_score"] = []

for smiles in all_smiles:
    # Score of the whole molecule; the search below only runs when it exceeds
    # the threshold.
    score = scoring_function([smiles])[0]

    rationales = (
        mcts(
            smiles=smiles,
            scoring_function=scoring_function,
            n_rollout=rollout,
            max_atoms=max_atoms,
            prop_delta=prop_delta,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )
        if score > prop_delta
        else []
    )

    results_df["smiles"].append(smiles)
    results_df[property_for_interpretation].append(score)

    if len(rationales) > 0:
        # Keep only the rationales with the fewest atoms, highest score first.
        fewest_atoms = min(len(node.atoms) for node in rationales)
        rats = sorted(
            (node for node in rationales if len(node.atoms) == fewest_atoms),
            key=lambda node: node.P,
            reverse=True,
        )
    else:
        rats = []

    # Fill every rationale slot; slots beyond the available rationales get None.
    for rank in range(num_rationales_to_keep):
        if rank < len(rats):
            rationale_smiles, rationale_score = rats[rank].smiles, rats[rank].P
        else:
            rationale_smiles, rationale_score = None, None
        results_df[f"rationale_{rank}"].append(rationale_smiles)
        results_df[f"rationale_{rank}_score"].append(rationale_score)
/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 10 11 12 13 14
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 11 12 13 14 15
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 8 9 10 11 12
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 7 8 9 10 11
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 1 2 3 4 5
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 0 1 3 4 5
[18:30:57] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4
[18:30:58] Can't kekulize mol.  Unkekulized atoms: 11 12 13 14 15
[18:30:58] Can't kekulize mol.  Unkekulized atoms: 8 9 10 11 12
[18:30:58] Can't kekulize mol.  Unkekulized atoms: 7 8 9 10 11
[18:30:58] Can't kekulize mol.  Unkekulized atoms: 10 11 12 13 14
CPU times: user 26min 27s, sys: 901 ms, total: 26min 28s
Wall time: 1min 43s
[19]:
# Convert the column-oriented dict accumulated above into a DataFrame. Note
# that this rebinds `results_df` from a dict to a DataFrame.
results_df = pd.DataFrame(results_df)
# Bare last expression so the notebook renders the table as the cell output.
results_df
[19]:
smiles lipo rationale_0 rationale_0_score rationale_1 rationale_1_score rationale_2 rationale_2_score rationale_3 rationale_3_score rationale_4 rationale_4_score
0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 2.253542 c1ccc2c(c1)n[cH:1][nH:1]2 2.275024 None NaN None NaN None NaN None NaN
1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... 2.235016 O=[SH:1]c1c[cH:1][cH:1]cc1[OH:1] 2.252582 c1c([OH:1])c([S:1][NH2:1])c[cH:1][cH:1]1 2.252185 c1c(N[CH3:1])[cH:1]c[cH:1]c1[SH:1] 2.251067 c1c([S:1][NH2:1])[cH:1]cc([OH:1])[cH:1]1 2.250288 c1c([NH2:1])[cH:1]c[cH:1]c1[S:1][NH2:1] 2.249267
2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 2.245891 c1cc[cH:1]c([CH2:1][CH2:1][OH:1])c1 2.249289 O=[CH:1][CH2:1]c1cccc[cH:1]1 2.249207 c1cc[cH:1]c([C@@H]([CH3:1])[NH2:1])c1 2.247827 Clc1ccccc1[CH2:1][NH2:1] 2.245391 Clc1ccccc1[CH2:1][CH3:1] 2.243280
3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 2.249847 c1c([CH3:1])[nH]c2s[cH:1]cc12 2.267990 Clc1cc2c[cH:1][nH]c2s1 2.267004 O=C1N(C[CH3:1])[CH:1]=[CH:1]C[CH2:1]1 2.211323 None NaN None NaN
4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 2.228097 c1cc(C[CH2:1][NH2:1])c[cH:1]c1 2.247070 c1cc(C[CH2:1][CH3:1])c[cH:1]c1 2.245314 Cn1nc([CH3:1])cc1[CH2:1][NH2:1] 2.225729 C[CH2:1]c1cc([CH2:1][NH2:1])[nH:1]n1 2.223793 c1c([CH3:1])n[nH:1]c1[CH2:1]N[CH3:1] 2.223478