Source code for chemprop.train.run_training

import json
from logging import Logger
import os
from typing import Dict, List

import numpy as np
import warnings

# Silence numpy's VisibleDeprecationWarning (raised e.g. when building ragged,
# object-dtype arrays such as atom/bond targets) to keep training logs readable.
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
import pandas as pd
from tensorboardX import SummaryWriter
import torch
from tqdm import trange
from torch.optim.lr_scheduler import ExponentialLR

from .evaluate import evaluate, evaluate_predictions
from .predict import predict
from .train import train
from .loss_functions import get_loss_func
from chemprop.spectra_utils import normalize_spectra, load_phase_mask
from chemprop.args import TrainArgs
from chemprop.constants import MODEL_FILE_NAME
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data
from chemprop.models import MoleculeModel
from chemprop.nn_utils import param_count, param_count_all
from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \
    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean


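# Pipeline overview (summary of the function below): split the data, fit any
# feature/descriptor/target scalers on the training split, train an ensemble of
# MoleculeModels with checkpointing on the best validation score, then evaluate
# each model and the prediction-averaged ensemble on the test split.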
def run_training(args: TrainArgs,
                 data: MoleculeDataset,
                 logger: Logger = None) -> Dict[str, List[float]]:
    """
    Loads data, trains a Chemprop model, and returns test scores for the model checkpoint
    with the highest validation score.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
    :param logger: A logger to record output.
    :return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(path=args.separate_test_path,
                             args=args,
                             features_path=args.separate_test_features_path,
                             atom_descriptors_path=args.separate_test_atom_descriptors_path,
                             bond_descriptors_path=args.separate_test_bond_descriptors_path,
                             phase_features_path=args.separate_test_phase_features_path,
                             constraints_path=args.separate_test_constraints_path,
                             smiles_columns=args.smiles_columns,
                             loss_function=args.loss_function,
                             logger=logger)
    if args.separate_val_path:
        val_data = get_data(path=args.separate_val_path,
                            args=args,
                            features_path=args.separate_val_features_path,
                            atom_descriptors_path=args.separate_val_atom_descriptors_path,
                            bond_descriptors_path=args.separate_val_bond_descriptors_path,
                            phase_features_path=args.separate_val_phase_features_path,
                            constraints_path=args.separate_val_constraints_path,
                            smiles_columns=args.smiles_columns,
                            loss_function=args.loss_function,
                            logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data,
                                              split_type=args.split_type,
                                              sizes=args.split_sizes,
                                              key_molecule_index=args.split_key_molecule,
                                              seed=args.seed,
                                              num_folds=args.num_folds,
                                              args=args,
                                              logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data,
                                             split_type=args.split_type,
                                             sizes=args.split_sizes,
                                             key_molecule_index=args.split_key_molecule,
                                             seed=args.seed,
                                             num_folds=args.num_folds,
                                             args=args,
                                             logger=logger)
    else:
        train_data, val_data, test_data = split_data(data=data,
                                                     split_type=args.split_type,
                                                     sizes=args.split_sizes,
                                                     key_molecule_index=args.split_key_molecule,
                                                     seed=args.seed,
                                                     num_folds=args.num_folds,
                                                     args=args,
                                                     logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(f'{args.task_names[i]} '
                  f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')
        train_class_sizes = get_class_sizes(train_data, proportion=False)
        args.train_class_sizes = train_class_sizes

    if args.save_smiles_splits:
        save_smiles_splits(
            data_path=args.data_path,
            save_dir=args.save_dir,
            task_names=args.task_names,
            features_path=args.features_path,
            constraints_path=args.constraints_path,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            smiles_columns=args.smiles_columns,
            loss_function=args.loss_function,
            logger=logger,
        )

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

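    # Note: each scaler here is fit on the training split only and then applied
    # unchanged to val/test, roughly x_scaled = (x - mean) / std, with NaNs
    # replaced by the given replace_nan_token.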
    if args.atom_descriptor_scaling and args.atom_descriptors is not None:
        atom_descriptor_scaler = train_data.normalize_features(replace_nan_token=0, scale_atom_descriptors=True)
        val_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
        test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
    else:
        atom_descriptor_scaler = None

    if args.bond_descriptor_scaling and args.bond_descriptors is not None:
        bond_descriptor_scaler = train_data.normalize_features(replace_nan_token=0, scale_bond_descriptors=True)
        val_data.normalize_features(bond_descriptor_scaler, scale_bond_descriptors=True)
        test_data.normalize_features(bond_descriptor_scaler, scale_bond_descriptors=True)
    else:
        bond_descriptor_scaler = None

    args.train_data_size = len(train_data)

    debug(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')

    if len(val_data) == 0:
        raise ValueError('The validation data split is empty. During normal chemprop training (non-sklearn functions), '
                         'a validation set is required to conduct early stopping according to the selected evaluation metric. This '
                         'may have occurred because validation data provided with `--separate_val_path` was empty or contained only invalid molecules.')

    if len(test_data) == 0:
        debug('The test data split is empty. This may be either because splitting with no test set was selected, '
              'such as with `cv-no-test`, or because test data provided with `--separate_test_path` was empty or contained only invalid molecules. '
              'Performance on the test set will not be evaluated and metric scores will return `nan` for each task.')
        empty_test_set = True
    else:
        empty_test_set = False

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        if args.is_atom_bond_targets:
            scaler = None
            atom_bond_scaler = train_data.normalize_atom_bond_targets()
        else:
            scaler = train_data.normalize_targets()
            atom_bond_scaler = None
        args.spectra_phase_mask = None
    elif args.dataset_type == 'spectra':
        debug('Normalizing spectra and excluding spectra regions based on phase')
        args.spectra_phase_mask = load_phase_mask(args.spectra_phase_mask_path)
        for dataset in [train_data, test_data, val_data]:
            data_targets = normalize_spectra(
                spectra=dataset.targets(),
                phase_features=dataset.phase_features(),
                phase_mask=args.spectra_phase_mask,
                excluded_sub_value=None,
                threshold=args.spectra_target_floor,
            )
            dataset.set_targets(data_targets)
        scaler = None
        atom_bond_scaler = None
    else:
        args.spectra_phase_mask = None
        scaler = None
        atom_bond_scaler = None

    # Get loss function
    loss_func = get_loss_func(args)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == 'multiclass':
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    elif args.is_atom_bond_targets:
        sum_test_preds = []
        for tb in zip(*test_data.targets()):
            tb = np.concatenate(tb)
            sum_test_preds.append(np.zeros((tb.shape[0], 1)))
        sum_test_preds = np.array(sum_test_preds, dtype=object)
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        set_cache_graph(True)
        num_workers = 0
    else:
        set_cache_graph(False)
        num_workers = args.num_workers

    # Create data loaders
    train_data_loader = MoleculeDataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        num_workers=num_workers,
        class_balance=args.class_balance,
        shuffle=True,
        seed=args.seed
    )

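    # Only the training loader shuffles and (optionally) class-balances;
    # the validation and test loaders below iterate in a fixed order.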
    val_data_loader = MoleculeDataLoader(
        dataset=val_data,
        batch_size=args.batch_size,
        num_workers=num_workers
    )
    test_data_loader = MoleculeDataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        num_workers=num_workers
    )

    if args.class_balance:
        debug(f'With class_balance, effective train size = {train_data_loader.iter_size:,}')

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
            model = load_checkpoint(args.checkpoint_paths[model_idx], logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = MoleculeModel(args)

        # Optionally, overwrite weights:
        if args.checkpoint_frzn is not None:
            debug(f'Loading and freezing parameters from {args.checkpoint_frzn}.')
            model = load_frzn_model(model=model, path=args.checkpoint_frzn, current_args=args, logger=logger)

        debug(model)

        if args.checkpoint_frzn is not None:
            debug(f'Number of unfrozen parameters = {param_count(model):,}')
            debug(f'Total number of parameters = {param_count_all(model):,}')
        else:
            debug(f'Number of parameters = {param_count_all(model):,}')

        if args.cuda:
            debug('Moving model to cuda')
        model = model.to(args.device)

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler, features_scaler,
                        atom_descriptor_scaler, bond_descriptor_scaler, atom_bond_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')
            n_iter = train(
                model=model,
                data_loader=train_data_loader,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                atom_bond_scaler=atom_bond_scaler,
                logger=logger,
                writer=writer
            )
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(
                model=model,
                data_loader=val_data_loader,
                num_tasks=args.num_tasks,
                metrics=args.metrics,
                dataset_type=args.dataset_type,
                scaler=scaler,
                quantiles=args.quantiles,
                atom_bond_scaler=atom_bond_scaler,
                logger=logger
            )

            for metric, scores in val_scores.items():
                # Average validation score
                mean_val_score = multitask_mean(
                    scores=scores, metric=metric, ignore_nan_metrics=args.ignore_nan_metrics
                )
                debug(f'Validation {metric} = {mean_val_score:.6f}')
                writer.add_scalar(f'validation_{metric}', mean_val_score, n_iter)

                if args.show_individual_scores:
                    if args.loss_function == "quantile_interval" and metric == "quantile":
                        num_tasks = len(args.task_names) // 2
                        task_names = args.task_names[:num_tasks]
                        task_names = [f"{task_name} lower" for task_name in task_names] + [
                            f"{task_name} upper" for task_name in task_names]
                    else:
                        task_names = args.task_names
                    # Individual validation scores
                    for task_name, val_score in zip(task_names, scores):
                        debug(f'Validation {task_name} {metric} = {val_score:.6f}')
                        writer.add_scalar(f'validation_{task_name}_{metric}', val_score, n_iter)

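            # Checkpoint selection: keep the parameters from the epoch whose mean
            # validation score on args.metric is best (lower is better when
            # args.minimize_score, e.g. rmse; higher otherwise, e.g. auc).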
            # Save model checkpoint if improved validation score
            mean_val_score = multitask_mean(
                scores=val_scores[args.metric], metric=args.metric, ignore_nan_metrics=args.ignore_nan_metrics
            )
            if args.minimize_score and mean_val_score < best_score or \
                    not args.minimize_score and mean_val_score > best_score:
                best_score, best_epoch = mean_val_score, epoch
                save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler, features_scaler,
                                atom_descriptor_scaler, bond_descriptor_scaler, atom_bond_scaler, args)

        # Evaluate on test set using model with best validation score
        info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
        model = load_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), device=args.device, logger=logger)

        if empty_test_set:
            info(f'Model {model_idx} provided with no test set, no metric evaluation will be performed.')
        else:
            test_preds = predict(
                model=model,
                data_loader=test_data_loader,
                scaler=scaler,
                atom_bond_scaler=atom_bond_scaler
            )
            test_scores = evaluate_predictions(
                preds=test_preds,
                targets=test_targets,
                num_tasks=args.num_tasks,
                metrics=args.metrics,
                dataset_type=args.dataset_type,
                is_atom_bond_targets=args.is_atom_bond_targets,
                gt_targets=test_data.gt_targets(),
                lt_targets=test_data.lt_targets(),
                quantiles=args.quantiles,
                logger=logger
            )

            if len(test_preds) != 0:
                if args.is_atom_bond_targets:
                    sum_test_preds += np.array(test_preds, dtype=object)
                else:
                    sum_test_preds += np.array(test_preds)

            # Average test score
            for metric, scores in test_scores.items():
                avg_test_score = np.nanmean(scores)
                info(f'Model {model_idx} test {metric} = {avg_test_score:.6f}')
                writer.add_scalar(f'test_{metric}', avg_test_score, 0)

                if args.show_individual_scores and args.dataset_type != 'spectra':
                    # Individual test scores
                    for task_name, test_score in zip(task_names, scores):
                        info(f'Model {model_idx} test {task_name} {metric} = {test_score:.6f}')
                        writer.add_scalar(f'test_{task_name}_{metric}', test_score, n_iter)
        writer.close()

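    # Ensemble prediction: sum_test_preds accumulated each model's test
    # predictions above, so dividing by ensemble_size below gives a simple
    # unweighted mean over the ensemble members.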
    # Evaluate ensemble on test set
    if empty_test_set:
        ensemble_scores = {
            metric: [np.nan for task in args.task_names] for metric in args.metrics
        }
    else:
        avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

        ensemble_scores = evaluate_predictions(
            preds=avg_test_preds,
            targets=test_targets,
            num_tasks=args.num_tasks,
            metrics=args.metrics,
            dataset_type=args.dataset_type,
            is_atom_bond_targets=args.is_atom_bond_targets,
            gt_targets=test_data.gt_targets(),
            lt_targets=test_data.lt_targets(),
            quantiles=args.quantiles,
            logger=logger,
        )

    for metric, scores in ensemble_scores.items():
        # Average ensemble score
        mean_ensemble_test_score = multitask_mean(
            scores=scores, metric=metric, ignore_nan_metrics=args.ignore_nan_metrics
        )
        info(f'Ensemble test {metric} = {mean_ensemble_test_score:.6f}')

        # Individual ensemble scores
        if args.show_individual_scores:
            for task_name, ensemble_score in zip(task_names, scores):
                info(f'Ensemble test {task_name} {metric} = {ensemble_score:.6f}')

    # Save scores
    with open(os.path.join(args.save_dir, 'test_scores.json'), 'w') as f:
        json.dump(ensemble_scores, f, indent=4, sort_keys=True)

    # Optionally save test preds
    if args.save_preds and not empty_test_set:
        test_preds_dataframe = pd.DataFrame(data={'smiles': test_data.smiles()})

        if args.is_atom_bond_targets:
            n_atoms, n_bonds = test_data.number_of_atoms, test_data.number_of_bonds
            for i, atom_target in enumerate(args.atom_targets):
                values = np.split(np.array(avg_test_preds[i]).flatten(), np.cumsum(np.array(n_atoms)))[:-1]
                values = [list(v) for v in values]
                test_preds_dataframe[atom_target] = values
            for i, bond_target in enumerate(args.bond_targets):
                values = np.split(np.array(avg_test_preds[i + len(args.atom_targets)]).flatten(),
                                  np.cumsum(np.array(n_bonds)))[:-1]
                values = [list(v) for v in values]
                test_preds_dataframe[bond_target] = values
        else:
            if args.loss_function == "quantile_interval" and metric == "quantile":
                num_tasks = len(args.task_names) // 2
                task_names = args.task_names[:num_tasks]
                avg_test_preds = np.array(avg_test_preds)
                num_data = avg_test_preds.shape[0]
                preds = avg_test_preds.reshape(num_data, 2, num_tasks).mean(axis=1)
                intervals = abs(np.diff(avg_test_preds.reshape(num_data, 2, num_tasks), axis=1) / 2)
                intervals = intervals.reshape(num_data, num_tasks)
                for i, task_name in enumerate(task_names):
                    test_preds_dataframe[task_name] = [pred[i] for pred in preds]
                for i, task_name in enumerate(task_names):
                    task_name = f"{task_name}_{args.quantile_loss_alpha}_half_interval"
                    test_preds_dataframe[task_name] = [interval[i] for interval in intervals]
            else:
                for i, task_name in enumerate(args.task_names):
                    test_preds_dataframe[task_name] = [pred[i] for pred in avg_test_preds]

        test_preds_dataframe.to_csv(os.path.join(args.save_dir, 'test_preds.csv'), index=False)

    return ensemble_scores
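
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): run_training is
# normally invoked once per fold by chemprop.train.cross_validate, which also
# sets derived fields such as args.task_names before calling it. A minimal
# direct call, assuming a hypothetical input file 'regression.csv', might look
# like:
#
#     from chemprop.args import TrainArgs
#     from chemprop.data import get_data
#     from chemprop.train.run_training import run_training
#
#     args = TrainArgs().parse_args([
#         '--data_path', 'regression.csv',   # hypothetical input file
#         '--dataset_type', 'regression',
#         '--save_dir', 'checkpoints',
#     ])
#     data = get_data(path=args.data_path, args=args)
#     scores = run_training(args, data)      # {metric: [score per task]}
# ---------------------------------------------------------------------------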