import json
from logging import Logger
import os
from typing import Dict, List
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
import pandas as pd
from tensorboardX import SummaryWriter
import torch
from tqdm import trange
from torch.optim.lr_scheduler import ExponentialLR
from .evaluate import evaluate, evaluate_predictions
from .predict import predict
from .train import train
from .loss_functions import get_loss_func
from chemprop.spectra_utils import normalize_spectra, load_phase_mask
from chemprop.args import TrainArgs
from chemprop.constants import MODEL_FILE_NAME
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data
from chemprop.models import MoleculeModel
from chemprop.nn_utils import param_count, param_count_all
from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \
save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean
def run_training(args: TrainArgs,
data: MoleculeDataset,
logger: Logger = None) -> Dict[str, List[float]]:
"""
    Loads data, trains a Chemprop model, and returns test scores for the model checkpoint with the best validation score.
:param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
loading data and training the Chemprop model.
:param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
:param logger: A logger to record output.
:return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.
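
    This function is typically driven by :func:`~chemprop.train.cross_validate`, which
    constructs the :class:`~chemprop.args.TrainArgs`, loads the dataset, and calls
    :func:`run_training` once per cross-validation fold. A minimal sketch of that usage
    (file paths and argument values are illustrative)::

        import chemprop

        arguments = [
            '--data_path', 'data.csv',
            '--dataset_type', 'regression',
            '--save_dir', 'checkpoints',
        ]
        args = chemprop.args.TrainArgs().parse_args(arguments)
        mean_score, std_score = chemprop.train.cross_validate(
            args=args, train_func=chemprop.train.run_training
        )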
"""
if logger is not None:
debug, info = logger.debug, logger.info
else:
debug = info = print
# Set pytorch seed for random initial weights
torch.manual_seed(args.pytorch_seed)
# Split data
debug(f'Splitting data with seed {args.seed}')
if args.separate_test_path:
test_data = get_data(path=args.separate_test_path,
args=args,
features_path=args.separate_test_features_path,
atom_descriptors_path=args.separate_test_atom_descriptors_path,
bond_descriptors_path=args.separate_test_bond_descriptors_path,
phase_features_path=args.separate_test_phase_features_path,
constraints_path=args.separate_test_constraints_path,
smiles_columns=args.smiles_columns,
loss_function=args.loss_function,
logger=logger)
if args.separate_val_path:
val_data = get_data(path=args.separate_val_path,
args=args,
features_path=args.separate_val_features_path,
atom_descriptors_path=args.separate_val_atom_descriptors_path,
bond_descriptors_path=args.separate_val_bond_descriptors_path,
phase_features_path=args.separate_val_phase_features_path,
constraints_path=args.separate_val_constraints_path,
smiles_columns=args.smiles_columns,
loss_function=args.loss_function,
logger=logger)
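    # Use separately supplied files for the val/test splits when given; otherwise
    # carve the missing split(s) out of `data` with the configured split type.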
if args.separate_val_path and args.separate_test_path:
train_data = data
elif args.separate_val_path:
train_data, _, test_data = split_data(data=data,
split_type=args.split_type,
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
logger=logger)
elif args.separate_test_path:
train_data, val_data, _ = split_data(data=data,
split_type=args.split_type,
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
logger=logger)
else:
train_data, val_data, test_data = split_data(data=data,
split_type=args.split_type,
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
logger=logger)
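    # For classification, report per-task class balance and record the raw
    # training-set class counts on `args`.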
if args.dataset_type == 'classification':
class_sizes = get_class_sizes(data)
debug('Class sizes')
for i, task_class_sizes in enumerate(class_sizes):
debug(f'{args.task_names[i]} '
f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')
train_class_sizes = get_class_sizes(train_data, proportion=False)
args.train_class_sizes = train_class_sizes
if args.save_smiles_splits:
save_smiles_splits(
data_path=args.data_path,
save_dir=args.save_dir,
task_names=args.task_names,
features_path=args.features_path,
constraints_path=args.constraints_path,
train_data=train_data,
val_data=val_data,
test_data=test_data,
smiles_columns=args.smiles_columns,
loss_function=args.loss_function,
logger=logger,
)
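    # Fit any feature/descriptor scalers on the training split only, then apply
    # the same scalers to the validation and test splits.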
if args.features_scaling:
features_scaler = train_data.normalize_features(replace_nan_token=0)
val_data.normalize_features(features_scaler)
test_data.normalize_features(features_scaler)
else:
features_scaler = None
if args.atom_descriptor_scaling and args.atom_descriptors is not None:
atom_descriptor_scaler = train_data.normalize_features(replace_nan_token=0, scale_atom_descriptors=True)
val_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
else:
atom_descriptor_scaler = None
if args.bond_descriptor_scaling and args.bond_descriptors is not None:
bond_descriptor_scaler = train_data.normalize_features(replace_nan_token=0, scale_bond_descriptors=True)
val_data.normalize_features(bond_descriptor_scaler, scale_bond_descriptors=True)
test_data.normalize_features(bond_descriptor_scaler, scale_bond_descriptors=True)
else:
bond_descriptor_scaler = None
args.train_data_size = len(train_data)
debug(f'Total size = {len(data):,} | '
f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')
if len(val_data) == 0:
        raise ValueError(
            'The validation data split is empty. During normal chemprop training (non-sklearn functions), '
            'a validation set is required to conduct early stopping according to the selected evaluation '
            'metric. This may have occurred because validation data provided with `--separate_val_path` was '
            'empty or contained only invalid molecules.'
        )
if len(test_data) == 0:
        debug(
            'The test data split is empty. This may be either because splitting with no test set was selected, '
            'such as with `cv-no-test`, or because test data provided with `--separate_test_path` was empty or '
            'contained only invalid molecules. Performance on the test set will not be evaluated and metric '
            'scores will return `nan` for each task.'
        )
empty_test_set = True
else:
empty_test_set = False
    # Initialize scaler and scale training targets by subtracting the mean and dividing by the standard deviation (regression only)
if args.dataset_type == 'regression':
debug('Fitting scaler')
if args.is_atom_bond_targets:
scaler = None
atom_bond_scaler = train_data.normalize_atom_bond_targets()
else:
scaler = train_data.normalize_targets()
atom_bond_scaler = None
args.spectra_phase_mask = None
elif args.dataset_type == 'spectra':
debug('Normalizing spectra and excluding spectra regions based on phase')
args.spectra_phase_mask = load_phase_mask(args.spectra_phase_mask_path)
for dataset in [train_data, test_data, val_data]:
data_targets = normalize_spectra(
spectra=dataset.targets(),
phase_features=dataset.phase_features(),
phase_mask=args.spectra_phase_mask,
excluded_sub_value=None,
threshold=args.spectra_target_floor,
)
dataset.set_targets(data_targets)
scaler = None
atom_bond_scaler = None
else:
args.spectra_phase_mask = None
scaler = None
atom_bond_scaler = None
# Get loss function
loss_func = get_loss_func(args)
# Set up test set evaluation
test_smiles, test_targets = test_data.smiles(), test_data.targets()
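    # `sum_test_preds` accumulates each ensemble member's test predictions so the
    # ensemble average can be computed after all models are trained.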
if args.dataset_type == 'multiclass':
sum_test_preds = np.zeros((len(test_smiles), args.num_tasks, args.multiclass_num_classes))
elif args.is_atom_bond_targets:
sum_test_preds = []
for tb in zip(*test_data.targets()):
tb = np.concatenate(tb)
sum_test_preds.append(np.zeros((tb.shape[0], 1)))
sum_test_preds = np.array(sum_test_preds, dtype=object)
else:
sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))
# Automatically determine whether to cache
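    # Caching featurized graphs in memory speeds up small datasets but is not
    # shared across worker processes, so the cached path uses num_workers = 0.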
if len(data) <= args.cache_cutoff:
set_cache_graph(True)
num_workers = 0
else:
set_cache_graph(False)
num_workers = args.num_workers
# Create data loaders
train_data_loader = MoleculeDataLoader(
dataset=train_data,
batch_size=args.batch_size,
num_workers=num_workers,
class_balance=args.class_balance,
shuffle=True,
seed=args.seed
)
val_data_loader = MoleculeDataLoader(
dataset=val_data,
batch_size=args.batch_size,
num_workers=num_workers
)
test_data_loader = MoleculeDataLoader(
dataset=test_data,
batch_size=args.batch_size,
num_workers=num_workers
)
if args.class_balance:
debug(f'With class_balance, effective train size = {train_data_loader.iter_size:,}')
# Train ensemble of models
for model_idx in range(args.ensemble_size):
# Tensorboard writer
save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
makedirs(save_dir)
try:
writer = SummaryWriter(log_dir=save_dir)
        # Fall back for tensorboardX versions that expect `logdir` instead of `log_dir`.
        except TypeError:
writer = SummaryWriter(logdir=save_dir)
# Load/build model
if args.checkpoint_paths is not None:
debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
model = load_checkpoint(args.checkpoint_paths[model_idx], logger=logger)
else:
debug(f'Building model {model_idx}')
model = MoleculeModel(args)
# Optionally, overwrite weights:
if args.checkpoint_frzn is not None:
debug(f'Loading and freezing parameters from {args.checkpoint_frzn}.')
model = load_frzn_model(model=model, path=args.checkpoint_frzn, current_args=args, logger=logger)
debug(model)
if args.checkpoint_frzn is not None:
debug(f'Number of unfrozen parameters = {param_count(model):,}')
debug(f'Total number of parameters = {param_count_all(model):,}')
else:
debug(f'Number of parameters = {param_count_all(model):,}')
if args.cuda:
debug('Moving model to cuda')
model = model.to(args.device)
        # Ensure the model is saved in the correct location for evaluation, even when training for 0 epochs
save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler,
features_scaler, atom_descriptor_scaler, bond_descriptor_scaler,
atom_bond_scaler, args)
# Optimizers
optimizer = build_optimizer(model, args)
# Learning rate schedulers
scheduler = build_lr_scheduler(optimizer, args)
# Run training
best_score = float('inf') if args.minimize_score else -float('inf')
best_epoch, n_iter = 0, 0
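        # Train for one epoch at a time, evaluating on the validation set after
        # each epoch and checkpointing whenever the mean validation score improves.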
for epoch in trange(args.epochs):
debug(f'Epoch {epoch}')
n_iter = train(
model=model,
data_loader=train_data_loader,
loss_func=loss_func,
optimizer=optimizer,
scheduler=scheduler,
args=args,
n_iter=n_iter,
atom_bond_scaler=atom_bond_scaler,
logger=logger,
writer=writer
)
if isinstance(scheduler, ExponentialLR):
scheduler.step()
val_scores = evaluate(
model=model,
data_loader=val_data_loader,
num_tasks=args.num_tasks,
metrics=args.metrics,
dataset_type=args.dataset_type,
scaler=scaler,
quantiles=args.quantiles,
atom_bond_scaler=atom_bond_scaler,
logger=logger
)
for metric, scores in val_scores.items():
                # Average validation score
mean_val_score = multitask_mean(
scores=scores,
metric=metric,
ignore_nan_metrics=args.ignore_nan_metrics
)
debug(f'Validation {metric} = {mean_val_score:.6f}')
writer.add_scalar(f'validation_{metric}', mean_val_score, n_iter)
if args.show_individual_scores:
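                    # With the interval quantile loss, the task list is doubled:
                    # the first half of the scores are the lower-quantile heads
                    # and the second half are the matching upper-quantile heads.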
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
task_names = [f"{task_name} lower" for task_name in task_names] + [
f"{task_name} upper" for task_name in task_names]
else:
task_names = args.task_names
# Individual validation scores
for task_name, val_score in zip(task_names, scores):
debug(f'Validation {task_name} {metric} = {val_score:.6f}')
writer.add_scalar(f'validation_{task_name}_{metric}', val_score, n_iter)
# Save model checkpoint if improved validation score
mean_val_score = multitask_mean(
scores=val_scores[args.metric],
metric=args.metric,
ignore_nan_metrics=args.ignore_nan_metrics
)
            if (args.minimize_score and mean_val_score < best_score) or \
                    (not args.minimize_score and mean_val_score > best_score):
best_score, best_epoch = mean_val_score, epoch
save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler, features_scaler,
atom_descriptor_scaler, bond_descriptor_scaler, atom_bond_scaler, args)
# Evaluate on test set using model with best validation score
info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
model = load_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), device=args.device, logger=logger)
if empty_test_set:
            info(f'Model {model_idx} provided with no test set; no metric evaluation will be performed.')
else:
test_preds = predict(
model=model,
data_loader=test_data_loader,
scaler=scaler,
atom_bond_scaler=atom_bond_scaler
)
test_scores = evaluate_predictions(
preds=test_preds,
targets=test_targets,
num_tasks=args.num_tasks,
metrics=args.metrics,
dataset_type=args.dataset_type,
is_atom_bond_targets=args.is_atom_bond_targets,
gt_targets=test_data.gt_targets(),
lt_targets=test_data.lt_targets(),
quantiles=args.quantiles,
logger=logger
)
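            # Accumulate this model's test predictions toward the ensemble average.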
if len(test_preds) != 0:
if args.is_atom_bond_targets:
sum_test_preds += np.array(test_preds, dtype=object)
else:
sum_test_preds += np.array(test_preds)
# Average test score
for metric, scores in test_scores.items():
avg_test_score = np.nanmean(scores)
info(f'Model {model_idx} test {metric} = {avg_test_score:.6f}')
writer.add_scalar(f'test_{metric}', avg_test_score, 0)
if args.show_individual_scores and args.dataset_type != 'spectra':
# Individual test scores
for task_name, test_score in zip(task_names, scores):
info(f'Model {model_idx} test {task_name} {metric} = {test_score:.6f}')
writer.add_scalar(f'test_{task_name}_{metric}', test_score, n_iter)
writer.close()
# Evaluate ensemble on test set
if empty_test_set:
ensemble_scores = {
            metric: [np.nan for _ in args.task_names] for metric in args.metrics
}
else:
avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()
ensemble_scores = evaluate_predictions(
preds=avg_test_preds,
targets=test_targets,
num_tasks=args.num_tasks,
metrics=args.metrics,
dataset_type=args.dataset_type,
is_atom_bond_targets=args.is_atom_bond_targets,
gt_targets=test_data.gt_targets(),
lt_targets=test_data.lt_targets(),
quantiles=args.quantiles,
logger=logger,
)
for metric, scores in ensemble_scores.items():
# Average ensemble score
mean_ensemble_test_score = multitask_mean(
scores=scores,
metric=metric,
ignore_nan_metrics=args.ignore_nan_metrics
)
info(f'Ensemble test {metric} = {mean_ensemble_test_score:.6f}')
# Individual ensemble scores
if args.show_individual_scores:
for task_name, ensemble_score in zip(task_names, scores):
info(f'Ensemble test {task_name} {metric} = {ensemble_score:.6f}')
# Save scores
with open(os.path.join(args.save_dir, 'test_scores.json'), 'w') as f:
json.dump(ensemble_scores, f, indent=4, sort_keys=True)
# Optionally save test preds
if args.save_preds and not empty_test_set:
test_preds_dataframe = pd.DataFrame(data={'smiles': test_data.smiles()})
if args.is_atom_bond_targets:
n_atoms, n_bonds = test_data.number_of_atoms, test_data.number_of_bonds
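            # Predictions for atom/bond targets come back flattened; split them
            # back into per-molecule lists using the cumulative atom/bond counts.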
for i, atom_target in enumerate(args.atom_targets):
values = np.split(np.array(avg_test_preds[i]).flatten(), np.cumsum(np.array(n_atoms)))[:-1]
values = [list(v) for v in values]
test_preds_dataframe[atom_target] = values
for i, bond_target in enumerate(args.bond_targets):
values = np.split(np.array(avg_test_preds[i+len(args.atom_targets)]).flatten(), np.cumsum(np.array(n_bonds)))[:-1]
values = [list(v) for v in values]
test_preds_dataframe[bond_target] = values
else:
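            # NOTE: `metric` below is the leftover loop variable from the
            # ensemble-scores loop above. For interval quantile predictions,
            # write out the midpoint and the half-interval width per task rather
            # than the raw lower/upper quantile pairs.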
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
avg_test_preds = np.array(avg_test_preds)
num_data = avg_test_preds.shape[0]
preds = avg_test_preds.reshape(num_data, 2, num_tasks).mean(axis=1)
intervals = abs(np.diff(avg_test_preds.reshape(num_data, 2, num_tasks), axis=1) / 2)
intervals = intervals.reshape(num_data, num_tasks)
for i, task_name in enumerate(task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in preds]
for i, task_name in enumerate(task_names):
task_name = f"{task_name}_{args.quantile_loss_alpha}_half_interval"
test_preds_dataframe[task_name] = [interval[i] for interval in intervals]
else:
for i, task_name in enumerate(args.task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in avg_test_preds]
test_preds_dataframe.to_csv(os.path.join(args.save_dir, 'test_preds.csv'), index=False)
return ensemble_scores