Source code for chemprop.train.make_predictions

from collections import OrderedDict
import csv
from typing import List, Optional, Union, Tuple

import numpy as np

from chemprop.args import PredictArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset, StandardScaler, AtomBondScaler
from chemprop.utils import load_args, load_checkpoint, load_scalers, makedirs, timeit, update_prediction_args
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, set_adding_hs, set_keeping_atom_map, reset_featurization_parameters
from chemprop.models import MoleculeModel
from chemprop.uncertainty import UncertaintyCalibrator, build_uncertainty_calibrator, UncertaintyEstimator, build_uncertainty_evaluator
from chemprop.multitask_utils import reshape_values


def load_model(args: PredictArgs, generator: bool = False):
    """Load a trained model or an ensemble of models from checkpoint files.

    When ``generator`` is True, the models and scalers are returned as lazy
    generators (memory efficient, single pass); otherwise they are fully
    materialized lists (required when the objects must be iterated more than
    once, e.g. for preloading or calibration).

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param generator: A boolean to return a generator instead of a list of models and scalers.
    :return: A tuple of updated prediction arguments, training arguments, a list or
             generator object of models, a list or generator object of scalers,
             the number of tasks and their respective names.
    """
    print('Loading training args')
    # All checkpoints in an ensemble share training args; read them off the first one.
    train_args = load_args(args.checkpoint_paths[0])
    num_tasks = train_args.num_tasks
    task_names = train_args.task_names

    # Fill in any prediction arguments that must be inherited from training.
    update_prediction_args(predict_args=args, train_args=train_args)
    args: Union[PredictArgs, TrainArgs]

    # Lazily load each checkpoint and its scalers; materialized below unless
    # the caller explicitly asked for generators.
    models = (
        load_checkpoint(path, device=args.device) for path in args.checkpoint_paths
    )
    scalers = (load_scalers(path) for path in args.checkpoint_paths)
    if not generator:
        models, scalers = list(models), list(scalers)

    return args, train_args, models, scalers, num_tasks, task_names
def load_data(args: PredictArgs, smiles: List[List[str]]):
    """Load prediction data either from an in-memory SMILES list or from file.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: A list of list of smiles, or None if data is to be read from file.
    :return: A tuple of a :class:`~chemprop.data.MoleculeDataset` containing all datapoints,
             a :class:`~chemprop.data.MoleculeDataset` containing only valid datapoints,
             a :class:`~chemprop.data.MoleculeDataLoader` and a dictionary mapping full to
             valid indices.
    """
    print("Loading data")
    if smiles is None:
        # No SMILES provided directly — read the test file instead.
        full_data = get_data(
            path=args.test_path,
            smiles_columns=args.smiles_columns,
            target_columns=[],
            ignore_columns=[],
            skip_invalid_smiles=False,
            args=args,
            store_row=not args.drop_extra_columns,
        )
    else:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator,
        )

    print("Validating SMILES")
    # Keep only datapoints whose every molecule parsed successfully, and
    # remember where each surviving datapoint sat in the full dataset.
    full_to_valid_indices = {}
    valid_index = 0
    for full_index, datapoint in enumerate(full_data):
        if all(mol is not None for mol in datapoint.mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset(
        [full_data[i] for i in sorted(full_to_valid_indices)]
    )

    print(f"Test size = {len(test_data):,}")

    # Batch the valid datapoints for prediction.
    test_data_loader = MoleculeDataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    return full_data, test_data, test_data_loader, full_to_valid_indices
def set_features(args: PredictArgs, train_args: TrainArgs):
    """Configure global featurization options to match the trained model.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                       training the model.
    """
    # Start from a clean featurization state before applying model-specific settings.
    reset_featurization_parameters()

    # Extra atom/bond descriptors used as features change the featurization width.
    if args.atom_descriptors == "feature":
        set_extra_atom_fdim(train_args.atom_features_size)
    if args.bond_descriptors == "feature":
        set_extra_bond_fdim(train_args.bond_features_size)

    # Hydrogen handling and atom-map options must match those used at train time.
    set_explicit_h(train_args.explicit_h)
    set_adding_hs(args.adding_h)
    set_keeping_atom_map(args.keeping_atom_map)

    # Reaction mode: pass the training reaction flag through, or True when only
    # a reaction solvent was used (the original's reaction/reaction_solvent branches).
    if train_args.reaction or train_args.reaction_solvent:
        set_reaction(train_args.reaction or True, train_args.reaction_mode)
def predict_and_save(
    args: PredictArgs,
    train_args: TrainArgs,
    test_data: MoleculeDataset,
    task_names: List[str],
    num_tasks: int,
    test_data_loader: MoleculeDataLoader,
    full_data: MoleculeDataset,
    full_to_valid_indices: dict,
    models: List[MoleculeModel],
    scalers: List[Union[StandardScaler, AtomBondScaler]],
    num_models: int,
    calibrator: UncertaintyCalibrator = None,
    return_invalid_smiles: bool = False,
    save_results: bool = True,
):
    """
    Function to predict with a model and save the predictions to file.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                       training the model.
    :param test_data: A :class:`~chemprop.data.MoleculeDataset` containing valid datapoints.
    :param task_names: A list of task names.
    :param num_tasks: Number of tasks.
    :param test_data_loader: A :class:`~chemprop.data.MoleculeDataLoader` to load the test data.
    :param full_data: A :class:`~chemprop.data.MoleculeDataset` containing all (valid and invalid) datapoints.
    :param full_to_valid_indices: A dictionary mapping full to valid indices.
    :param models: A list or generator object of :class:`~chemprop.models.MoleculeModel`\ s.
    :param scalers: A list or generator object of :class:`~chemprop.features.scaler.StandardScaler` objects.
    :param num_models: The number of models included in the models and scalers input.
    :param calibrator: A :class:`~chemprop.uncertainty.UncertaintyCalibrator` object, for use
                       in calibrating uncertainty predictions.
    :param return_invalid_smiles: Whether to return predictions of "Invalid SMILES" for invalid SMILES,
                                  otherwise will skip them in returned predictions.
    :param save_results: Whether to save the predictions in a csv. Function returns the predictions regardless.
    :return: A list of lists of target predictions.
    """
    estimator = UncertaintyEstimator(
        test_data=test_data,
        test_data_loader=test_data_loader,
        uncertainty_method=args.uncertainty_method,
        models=models,
        scalers=scalers,
        num_models=num_models,
        dataset_type=args.dataset_type,
        loss_function=args.loss_function,
        uncertainty_dropout_p=args.uncertainty_dropout_p,
        conformal_alpha=args.conformal_alpha,
        dropout_sampling_size=args.dropout_sampling_size,
        individual_ensemble_predictions=args.individual_ensemble_predictions,
        spectra_phase_mask=getattr(train_args, "spectra_phase_mask", None),
    )

    preds, unc = estimator.calculate_uncertainty(
        calibrator=calibrator
    )  # preds and unc are lists of shape(data,tasks)

    # Quantile-interval models predict lower/upper quantiles per task, so the
    # task-name list is twice as long as the real task count; keep the first half.
    if args.loss_function == "quantile_interval":
        task_names = task_names[:len(task_names) // 2]

    if calibrator is not None and args.is_atom_bond_targets and args.calibration_method == "isotonic":
        # Isotonic calibration flattens atom/bond uncertainties; restore per-target shape.
        unc = reshape_values(unc, test_data, len(args.atom_targets), len(args.bond_targets))

    if args.individual_ensemble_predictions:
        individual_preds = (
            estimator.individual_predictions()
        )  # shape(data, tasks, ensemble) or (data, tasks, classes, ensemble)

    # Optionally build one evaluator per requested evaluation method; this
    # re-reads the test file to get the target values needed for evaluation.
    if args.evaluation_methods is not None:
        evaluation_data = get_data(
            path=args.test_path,
            smiles_columns=args.smiles_columns,
            target_columns=task_names,
            args=args,
            features_path=args.features_path,
            features_generator=args.features_generator,
            phase_features_path=args.phase_features_path,
            atom_descriptors_path=args.atom_descriptors_path,
            bond_descriptors_path=args.bond_descriptors_path,
            max_data_size=args.max_data_size,
            loss_function=args.loss_function,
        )
        evaluators = []
        for evaluation_method in args.evaluation_methods:
            evaluator = build_uncertainty_evaluator(
                evaluation_method=evaluation_method,
                calibration_method=args.calibration_method,
                uncertainty_method=args.uncertainty_method,
                dataset_type=args.dataset_type,
                loss_function=args.loss_function,
                calibrator=calibrator,
                is_atom_bond_targets=args.is_atom_bond_targets,
            )
            evaluators.append(evaluator)
    else:
        evaluators = None

    if evaluators is not None:
        evaluations = []
        print(f"Evaluating uncertainty for tasks {task_names}")
        for evaluator in evaluators:
            evaluation = evaluator.evaluate(
                targets=evaluation_data.targets(), preds=preds, uncertainties=unc, mask=evaluation_data.mask()
            )
            evaluations.append(evaluation)
            print(
                f"Using evaluation method {evaluator.evaluation_method}: {evaluation}"
            )
    else:
        evaluations = None

    # For multiclass, each (task, class) pair becomes its own output column.
    if args.dataset_type == "multiclass":
        num_tasks = num_tasks * args.multiclass_num_classes

    # Number of uncertainty columns depends on the uncertainty/calibration method.
    if args.uncertainty_method == "spectra_roundrobin":
        num_unc_tasks = 1
    elif args.uncertainty_method == "dirichlet" and args.dataset_type == "multiclass":
        num_unc_tasks = num_tasks // args.multiclass_num_classes  # dirichlet only returns an uncertainty for each task rather than each class
    elif args.calibration_method == "conformal_regression":
        num_unc_tasks = 2 * num_tasks
    elif args.calibration_method == "conformal" and args.dataset_type == "classification":
        num_unc_tasks = 2 * num_tasks
    else:
        num_unc_tasks = num_tasks

    # Save results
    if save_results:
        print(f"Saving predictions to {args.preds_path}")
        assert len(test_data) == len(preds)
        assert len(test_data) == len(unc)
        makedirs(args.preds_path, isfile=True)

        # Set multiclass column names, update num_tasks definitions
        if args.dataset_type == "multiclass":
            original_task_names = task_names
            task_names = [
                f"{name}_class_{i}"
                for name in task_names
                for i in range(args.multiclass_num_classes)
            ]

        # Copy predictions over to full_data; invalid SMILES rows get placeholder values.
        for full_index, datapoint in enumerate(full_data):
            valid_index = full_to_valid_indices.get(full_index, None)
            if valid_index is not None:
                d_preds = preds[valid_index]
                d_unc = unc[valid_index]
                if args.individual_ensemble_predictions:
                    ind_preds = individual_preds[valid_index]
            else:
                d_preds = ["Invalid SMILES"] * num_tasks
                d_unc = ["Invalid SMILES"] * num_unc_tasks
                if args.individual_ensemble_predictions:
                    ind_preds = [["Invalid SMILES"] * len(args.checkpoint_paths)] * num_tasks

            # Reshape multiclass to merge task and class dimension, with updated num_tasks
            if args.dataset_type == "multiclass":
                d_preds = np.array(d_preds).reshape((num_tasks))
                d_unc = np.array(d_unc).reshape((num_unc_tasks))
                if args.individual_ensemble_predictions:
                    ind_preds = ind_preds.reshape(
                        (num_tasks, len(args.checkpoint_paths))
                    )

            # If extra columns have been dropped, add back in SMILES columns
            if args.drop_extra_columns:
                datapoint.row = OrderedDict()

                smiles_columns = args.smiles_columns

                for column, smiles in zip(smiles_columns, datapoint.smiles):
                    datapoint.row[column] = smiles

            # Add predictions columns; uncertainty column names depend on the method used.
            if args.uncertainty_method == "spectra_roundrobin":
                unc_names = [estimator.label]
            elif args.uncertainty_method == "conformal_quantile_regression" and args.calibration_method is None:
                unc_names = [f"{name}_{args.conformal_alpha}_half_interval" for name in task_names]
            elif args.calibration_method == "conformal_regression" and args.calibration_path is None:
                unc_names = []
            elif args.calibration_method == "conformal" and args.dataset_type == "classification":
                unc_names = [f"{name}_{estimator.label}_in_set" for name in task_names] + [
                    f"{name}_{estimator.label}_out_set" for name in task_names
                ]
            else:
                unc_names = [name + f"_{estimator.label}" for name in task_names]

            for pred_name, pred in zip(task_names, d_preds):
                datapoint.row[pred_name] = pred
            for unc_name, un in zip(unc_names, d_unc):
                # Only write uncertainty columns when some uncertainty/calibration was requested.
                if (
                    args.uncertainty_method is not None
                    or args.calibration_method is not None
                ):
                    datapoint.row[unc_name] = un
            if args.individual_ensemble_predictions:
                for pred_name, model_preds in zip(task_names, ind_preds):
                    for idx, pred in enumerate(model_preds):
                        datapoint.row[pred_name + f"_model_{idx}"] = pred

        # Save
        with open(args.preds_path, 'w', newline="") as f:
            writer = csv.DictWriter(f, fieldnames=full_data[0].row.keys())
            writer.writeheader()
            for datapoint in full_data:
                writer.writerow(datapoint.row)

        if evaluations is not None and args.evaluation_scores_path is not None:
            print(f"Saving uncertainty evaluations to {args.evaluation_scores_path}")
            # Evaluations are per original task, so undo the per-class renaming.
            if args.dataset_type == "multiclass":
                task_names = original_task_names
            with open(args.evaluation_scores_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["evaluation_method"] + task_names)
                for i, evaluation_method in enumerate(args.evaluation_methods):
                    writer.writerow([evaluation_method] + evaluations[i])

    if return_invalid_smiles:
        # Re-expand predictions to the full dataset, inserting placeholders
        # for datapoints whose SMILES failed to parse.
        full_preds = []
        full_unc = []
        for full_index in range(len(full_data)):
            valid_index = full_to_valid_indices.get(full_index, None)
            if valid_index is not None:
                pred = preds[valid_index]
                un = unc[valid_index]
            else:
                pred = ["Invalid SMILES"] * num_tasks
                un = ["Invalid SMILES"] * num_unc_tasks
            full_preds.append(pred)
            full_unc.append(un)
        return full_preds, full_unc
    else:
        return preds, unc
@timeit()
def make_predictions(
    args: PredictArgs,
    smiles: List[List[str]] = None,
    model_objects: Tuple[
        PredictArgs,
        TrainArgs,
        List[MoleculeModel],
        List[Union[StandardScaler, AtomBondScaler]],
        int,
        List[str],
    ] = None,
    calibrator: UncertaintyCalibrator = None,
    return_invalid_smiles: bool = True,
    return_index_dict: bool = False,
    return_uncertainty: bool = False,
) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to make predictions on the data.

    If SMILES are provided, then makes predictions on smiles.
    Otherwise makes predictions on :code:`args.test_data`.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :param model_objects: Tuple of output of load_model function which can be called separately outside this function.
                          Preloaded model objects should have used the non-generator option for load_model if the objects
                          are to be used multiple times or are intended to be used for calibration as well.
    :param calibrator: A :class:`~chemprop.uncertainty.UncertaintyCalibrator` object, for use in calibrating uncertainty
                       predictions. Can be preloaded and provided as a function input or constructed within the function
                       from arguments. The models and scalers used to initiate the calibrator must be lists instead of
                       generators if the same calibrator is to be used multiple times or if the same models and scalers
                       objects are also part of the provided model_objects input.
    :param return_invalid_smiles: Whether to return predictions of "Invalid SMILES" for invalid SMILES, otherwise will
                                  skip them in returned predictions.
    :param return_index_dict: Whether to return the prediction results as a dictionary keyed from the initial data indexes.
    :param return_uncertainty: Whether to return uncertainty predictions alongside the model value predictions.
    :return: A list of lists of target predictions. If returning uncertainty, a tuple containing first prediction values
             then uncertainty estimates.
    """
    if model_objects:
        # Caller preloaded the models; unpack them directly.
        (args, train_args, models, scalers, num_tasks, task_names) = model_objects
    else:
        (args, train_args, models, scalers, num_tasks, task_names) = load_model(
            args, generator=True
        )
    num_models = len(args.checkpoint_paths)

    # Apply the training-time featurization settings globally.
    set_features(args, train_args)

    # Note: to get the invalid SMILES for your data, use the get_invalid_smiles_from_file or get_invalid_smiles_from_list functions from data/utils.py
    full_data, test_data, test_data_loader, full_to_valid_indices = load_data(args, smiles)

    if args.uncertainty_method is not None and args.calibration_method in [
        "conformal_regression",
        "conformal_quantile_regression",
    ]:
        raise ValueError("Conformal regression is not compatible with an uncertainty method")

    # If calibration/evaluation was requested without an explicit uncertainty
    # method, infer the appropriate one (or fail when none can be inferred).
    if args.uncertainty_method is None and (
        args.calibration_method is not None or args.evaluation_methods is not None
    ):
        if args.dataset_type in ["classification", "multiclass"]:
            args.uncertainty_method = "classification"
        elif args.calibration_method == "conformal_regression":
            if args.loss_function == "quantile_interval":
                raise ValueError(
                    "For a model trained on the `quantile_interval` loss function, the calibration method should be assigned as `conformal_quantile_regression` instead of `conformal_regression`."
                )
            args.uncertainty_method = "conformal_regression"
        elif args.calibration_method == "conformal_quantile_regression":
            if args.loss_function != "quantile_interval":
                raise ValueError(
                    "The calibration method `conformal_quantile_regression` only supports regression models trained on the `quantile_interval` loss function."
                )
            args.uncertainty_method = "conformal_quantile_regression"
        else:
            raise ValueError(
                "Cannot calibrate or evaluate uncertainty without selection of an uncertainty method."
            )

    if args.calibration_method is None and args.loss_function == "quantile_interval":
        args.uncertainty_method = "conformal_quantile_regression"

    # Build a calibrator from the calibration dataset when one was not supplied.
    if calibrator is None and args.calibration_path is not None:
        calibration_data = get_data(
            path=args.calibration_path,
            smiles_columns=args.smiles_columns,
            target_columns=task_names,
            args=args,
            features_path=args.calibration_features_path,
            features_generator=args.features_generator,
            phase_features_path=args.calibration_phase_features_path,
            atom_descriptors_path=args.calibration_atom_descriptors_path,
            bond_descriptors_path=args.calibration_bond_descriptors_path,
            max_data_size=args.max_data_size,
            loss_function=args.loss_function,
        )

        calibration_data_loader = MoleculeDataLoader(
            dataset=calibration_data,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        # Generators are single-use, so reload models for calibration if needed.
        if isinstance(models, List) and isinstance(scalers, List):
            calibration_models = models
            calibration_scalers = scalers
        else:
            calibration_model_objects = load_model(args, generator=True)
            calibration_models = calibration_model_objects[2]
            calibration_scalers = calibration_model_objects[3]

        calibrator = build_uncertainty_calibrator(
            calibration_method=args.calibration_method,
            uncertainty_method=args.uncertainty_method,
            interval_percentile=args.calibration_interval_percentile,
            regression_calibrator_metric=args.regression_calibrator_metric,
            calibration_data=calibration_data,
            calibration_data_loader=calibration_data_loader,
            models=calibration_models,
            scalers=calibration_scalers,
            num_models=num_models,
            dataset_type=args.dataset_type,
            loss_function=args.loss_function,
            uncertainty_dropout_p=args.uncertainty_dropout_p,
            conformal_alpha=args.conformal_alpha,
            dropout_sampling_size=args.dropout_sampling_size,
            spectra_phase_mask=getattr(train_args, "spectra_phase_mask", None),
        )

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        preds = [None] * len(full_data)
        unc = [None] * len(full_data)
    else:
        preds, unc = predict_and_save(
            args=args,
            train_args=train_args,
            test_data=test_data,
            task_names=task_names,
            num_tasks=num_tasks,
            test_data_loader=test_data_loader,
            full_data=full_data,
            full_to_valid_indices=full_to_valid_indices,
            models=models,
            scalers=scalers,
            num_models=num_models,
            calibrator=calibrator,
            return_invalid_smiles=return_invalid_smiles,
        )

    if return_index_dict:
        # Key the results by the original (pre-validation) data index.
        preds_dict = {}
        unc_dict = {}
        for i in range(len(full_data)):
            if return_invalid_smiles:
                preds_dict[i] = preds[i]
                unc_dict[i] = unc[i]
            else:
                valid_index = full_to_valid_indices.get(i, None)
                if valid_index is not None:
                    preds_dict[i] = preds[valid_index]
                    unc_dict[i] = unc[valid_index]
        if return_uncertainty:
            return preds_dict, unc_dict
        else:
            return preds_dict
    else:
        if return_uncertainty:
            return preds, unc
        else:
            return preds
def chemprop_predict() -> None:
    """Parse Chemprop predicting arguments and run prediction with a trained model.

    This is the entry point for the command line command :code:`chemprop_predict`.
    """
    predict_args = PredictArgs().parse_args()
    make_predictions(args=predict_args)