Source code for chemprop.train.cross_validate

from collections import defaultdict
import csv
import json
from logging import Logger
import os
import sys
from typing import Callable, Dict, List, Tuple
import subprocess

import numpy as np
import pandas as pd

from .run_training import run_training
from chemprop.args import TrainArgs
from chemprop.constants import TEST_SCORES_FILE_NAME, TRAIN_LOGGER_NAME
from chemprop.data import get_data, get_task_names, MoleculeDataset, validate_dataset_type
from chemprop.utils import create_logger, makedirs, timeit, multitask_mean
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_explicit_h, set_adding_hs, set_keeping_atom_map, set_reaction, reset_featurization_parameters


@timeit(logger_name=TRAIN_LOGGER_NAME)
def cross_validate(args: TrainArgs,
                   train_func: Callable[[TrainArgs, MoleculeDataset, Logger], Dict[str, List[float]]]
                   ) -> Tuple[float, float]:
    """
    Runs k-fold cross-validation.

    For each of k splits (folds) of the data, trains and tests a model on that split
    and aggregates the performance across folds.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param train_func: Function which runs training.
    :return: A tuple containing the mean and standard deviation performance across folds.
    """
    logger = create_logger(name=TRAIN_LOGGER_NAME, save_dir=args.save_dir, quiet=args.quiet)
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Initialize relevant variables
    init_seed = args.seed
    save_dir = args.save_dir
    args.task_names = get_task_names(
        path=args.data_path,
        smiles_columns=args.smiles_columns,
        target_columns=args.target_columns,
        ignore_columns=args.ignore_columns,
        loss_function=args.loss_function,
    )
    args.quantiles = [args.quantile_loss_alpha / 2] * (args.num_tasks // 2) + \
                     [1 - args.quantile_loss_alpha / 2] * (args.num_tasks // 2)

    # Print command line
    debug('Command line')
    debug(f'python {" ".join(sys.argv)}')

    # Print args
    debug('Args')
    debug(args)

    # Save args
    makedirs(args.save_dir)
    try:
        args.save(os.path.join(args.save_dir, 'args.json'))
    except subprocess.CalledProcessError:
        debug('Could not write the reproducibility section of the arguments to file, thus omitting this section.')
        args.save(os.path.join(args.save_dir, 'args.json'), with_reproducibility=False)

    # set explicit H option and reaction option
    reset_featurization_parameters(logger=logger)
    set_explicit_h(args.explicit_h)
    set_adding_hs(args.adding_h)
    set_keeping_atom_map(args.keeping_atom_map)
    if args.reaction:
        set_reaction(args.reaction, args.reaction_mode)
    elif args.reaction_solvent:
        set_reaction(True, args.reaction_mode)

    # Get data
    debug('Loading data')
    data = get_data(
        path=args.data_path,
        args=args,
        logger=logger,
        skip_none_targets=True,
        data_weights_path=args.data_weights_path
    )
    validate_dataset_type(data, dataset_type=args.dataset_type)
    args.features_size = data.features_size()

    if args.atom_descriptors == 'descriptor':
        args.atom_descriptors_size = data.atom_descriptors_size()
    elif args.atom_descriptors == 'feature':
        args.atom_features_size = data.atom_features_size()
        set_extra_atom_fdim(args.atom_features_size)
    if args.bond_descriptors == 'descriptor':
        args.bond_descriptors_size = data.bond_descriptors_size()
    elif args.bond_descriptors == 'feature':
        args.bond_features_size = data.bond_features_size()
        set_extra_bond_fdim(args.bond_features_size)

    debug(f'Number of tasks = {args.num_tasks}')

    if args.target_weights is not None and len(args.target_weights) != args.num_tasks:
        raise ValueError('The number of provided target weights must match the number and order of the prediction tasks')

    # Run training on different random seeds for each fold
    all_scores = defaultdict(list)
    for fold_num in range(args.num_folds):
        info(f'Fold {fold_num}')
        args.seed = init_seed + fold_num
        args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
        makedirs(args.save_dir)
        data.reset_features_and_targets()

        # If resuming experiment, load results from trained models
        test_scores_path = os.path.join(args.save_dir, 'test_scores.json')
        if args.resume_experiment and os.path.exists(test_scores_path):
            print('Loading scores')
            with open(test_scores_path) as f:
                model_scores = json.load(f)
        # Otherwise, train the models
        else:
            model_scores = train_func(args, data, logger)

        for metric, scores in model_scores.items():
            all_scores[metric].append(scores)
    all_scores = dict(all_scores)

    # Convert scores to numpy arrays
    for metric, scores in all_scores.items():
        all_scores[metric] = np.array(scores)

    # Report results
    info(f'{args.num_folds}-fold cross validation')

    # Report scores for each fold
    contains_nan_scores = False
    for fold_num in range(args.num_folds):
        for metric, scores in all_scores.items():
            info(f'\tSeed {init_seed + fold_num} ==> test {metric} = '
                 f'{multitask_mean(scores=scores[fold_num], metric=metric, ignore_nan_metrics=args.ignore_nan_metrics):.6f}')

            if args.show_individual_scores:
                if args.loss_function == "quantile_interval" and metric == "quantile":
                    num_tasks = len(args.task_names) // 2
                    task_names = args.task_names[:num_tasks]
                    task_names = [f"{task_name} lower" for task_name in task_names] + \
                                 [f"{task_name} upper" for task_name in task_names]
                else:
                    task_names = args.task_names
                for task_name, score in zip(task_names, scores[fold_num]):
                    info(f'\t\tSeed {init_seed + fold_num} ==> test {task_name} {metric} = {score:.6f}')
                    if np.isnan(score):
                        contains_nan_scores = True

    # Report scores across folds
    for metric, scores in all_scores.items():
        avg_scores = multitask_mean(
            scores=scores, axis=1, metric=metric, ignore_nan_metrics=args.ignore_nan_metrics
        )  # average score for each model across tasks
        mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)
        info(f'Overall test {metric} = {mean_score:.6f} +/- {std_score:.6f}')

        if args.show_individual_scores:
            for task_num, task_name in enumerate(task_names):
                info(f'\tOverall test {task_name} {metric} = '
                     f'{np.mean(scores[:, task_num]):.6f} +/- {np.std(scores[:, task_num]):.6f}')

    if contains_nan_scores:
        info("The metric scores observed for some fold test splits contain 'nan' values. \
            This can occur when the test set does not meet the requirements \
            for a particular metric, such as having no valid instances of one \
            task in the test set or not having positive examples for some classification metrics. \
            Before v1.5.1, the default behavior was to ignore nan values in individual folds or tasks \
            and still return an overall average for the remaining folds or tasks. The behavior now \
            is to include them in the average, converting overall average metrics to 'nan' as well.")

    # Save scores
    with open(os.path.join(save_dir, TEST_SCORES_FILE_NAME), 'w') as f:
        writer = csv.writer(f)

        header = ['Task']
        for metric in args.metrics:
            header += [f'Mean {metric}', f'Standard deviation {metric}'] + \
                      [f'Fold {i} {metric}' for i in range(args.num_folds)]
        writer.writerow(header)

        if args.dataset_type == 'spectra':  # spectra data type has only one score to report
            row = ['spectra']
            for metric, scores in all_scores.items():
                task_scores = scores[:, 0]
                mean, std = np.mean(task_scores), np.std(task_scores)
                row += [mean, std] + task_scores.tolist()
            writer.writerow(row)
        else:  # all other data types, separate scores by task
            if args.loss_function == "quantile_interval" and metric == "quantile":
                num_tasks = len(args.task_names) // 2
                task_names = args.task_names[:num_tasks]
                task_names = [f"{task_name} (lower quantile)" for task_name in task_names] + \
                             [f"{task_name} (upper quantile)" for task_name in task_names]
            else:
                task_names = args.task_names
            for task_num, task_name in enumerate(task_names):
                row = [task_name]
                for metric, scores in all_scores.items():
                    task_scores = scores[:, task_num]
                    mean, std = np.mean(task_scores), np.std(task_scores)
                    row += [mean, std] + task_scores.tolist()
                writer.writerow(row)

    # Determine mean and std score of main metric
    avg_scores = multitask_mean(
        scores=all_scores[args.metric],
        metric=args.metric,
        axis=1,
        ignore_nan_metrics=args.ignore_nan_metrics
    )
    mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)

    # Optionally merge and save test preds
    if args.save_preds:
        all_preds = pd.concat([pd.read_csv(os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
                               for fold_num in range(args.num_folds)])
        all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)

    return mean_score, std_score
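
# Note (added illustration, not part of the original module): inside cross_validate,
# all_scores[metric] ends up as a numpy array of shape (num_folds, num_tasks).
# multitask_mean collapses the task axis (axis=1) to one score per fold, and the
# reported "Overall test <metric>" is the mean +/- standard deviation of those
# per-fold values. For example, hypothetical per-task scores [[0.8, 0.6], [0.9, 0.7]]
# for two folds give per-fold means [0.7, 0.8] and an overall score of 0.75 +/- 0.05.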


def chemprop_train() -> None:
    """Parses Chemprop training arguments and trains (cross-validates) a Chemprop model.

    This is the entry point for the command line command :code:`chemprop_train`.
    """
    cross_validate(args=TrainArgs().parse_args(), train_func=run_training)
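

# --- Usage sketch (added illustration, not part of the original module) ---
# A minimal example of running cross-validation programmatically, mirroring what
# chemprop_train() does above but passing an explicit argument list instead of
# reading sys.argv. The data path, dataset type, fold count, and save directory
# below are placeholder assumptions.
if __name__ == '__main__':
    example_args = TrainArgs().parse_args([
        '--data_path', 'data.csv',       # assumed CSV with SMILES and target column(s)
        '--dataset_type', 'regression',
        '--save_dir', 'checkpoints',
        '--num_folds', '3',
    ])
    example_mean, example_std = cross_validate(args=example_args, train_func=run_training)
    print(f'Overall {example_args.metric} = {example_mean:.6f} +/- {example_std:.6f}')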