Source code for chemprop.train.train

import logging
from typing import Callable

import numpy as np
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from chemprop.args import TrainArgs
from chemprop.data import MoleculeDataLoader, MoleculeDataset, AtomBondScaler
from chemprop.models import MoleculeModel
from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR


def train(
    model: MoleculeModel,
    data_loader: MoleculeDataLoader,
    loss_func: Callable,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    args: TrainArgs,
    n_iter: int = 0,
    atom_bond_scaler: AtomBondScaler = None,
    logger: logging.Logger = None,
    writer: SummaryWriter = None,
) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param atom_bond_scaler: A :class:`~chemprop.data.scaler.AtomBondScaler` fitted on the atomic/bond targets.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    if model.is_atom_bond_targets:
        loss_sum, iter_count = [0]*(len(args.atom_targets) + len(args.bond_targets)), 0
    else:
        loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, atom_features_batch, bond_descriptors_batch, bond_features_batch, constraints_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_descriptors(), batch.bond_features(), batch.constraints(), batch.data_weights()

        if model.is_atom_bond_targets:
            targets = []
            for dt in zip(*target_batch):
                dt = np.concatenate(dt)
                targets.append(torch.tensor([0 if x is None else x for x in dt], dtype=torch.float))
            masks = [torch.tensor(mask, dtype=torch.bool) for mask in mask_batch]

            if args.target_weights is not None:
                target_weights = [torch.ones(1, 1) * i for i in args.target_weights]  # shape(tasks, 1)
            else:
                target_weights = [torch.ones(1, 1) for i in targets]
            data_weights = batch.atom_bond_data_weights()
            data_weights = [torch.tensor(x).unsqueeze(1) for x in data_weights]

            natoms, nbonds = batch.number_of_atoms, batch.number_of_bonds
            natoms, nbonds = np.array(natoms).flatten(), np.array(nbonds).flatten()
            constraints_batch = np.transpose(constraints_batch).tolist()
            ind = 0
            for i in range(len(args.atom_targets)):
                if not args.atom_constraints[i]:
                    constraints_batch[ind] = None
                else:
                    mean, std = atom_bond_scaler.means[ind][0], atom_bond_scaler.stds[ind][0]
                    for j, natom in enumerate(natoms):
                        constraints_batch[ind][j] = (constraints_batch[ind][j] - natom * mean) / std
                    constraints_batch[ind] = torch.tensor(constraints_batch[ind]).to(args.device)
                ind += 1
            for i in range(len(args.bond_targets)):
                if not args.bond_constraints[i]:
                    constraints_batch[ind] = None
                else:
                    mean, std = atom_bond_scaler.means[ind][0], atom_bond_scaler.stds[ind][0]
                    for j, nbond in enumerate(nbonds):
                        constraints_batch[ind][j] = (constraints_batch[ind][j] - nbond * mean) / std
                    constraints_batch[ind] = torch.tensor(constraints_batch[ind]).to(args.device)
                ind += 1

            bond_types_batch = []
            for i in range(len(args.atom_targets)):
                bond_types_batch.append(None)
            for i in range(len(args.bond_targets)):
                if args.adding_bond_types and atom_bond_scaler is not None:
                    mean, std = atom_bond_scaler.means[i+len(args.atom_targets)][0], atom_bond_scaler.stds[i+len(args.atom_targets)][0]
                    bond_types = [(b.GetBondTypeAsDouble() - mean) / std for d in batch for b in d.mol[0].GetBonds()]
                    bond_types = torch.FloatTensor(bond_types).to(args.device)
                    bond_types_batch.append(bond_types)
                else:
                    bond_types_batch.append(None)
        else:
            mask_batch = np.transpose(mask_batch).tolist()
            masks = torch.tensor(mask_batch, dtype=torch.bool)  # shape(batch, tasks)
            targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch])  # shape(batch, tasks)

            if args.target_weights is not None:
                target_weights = torch.tensor(args.target_weights).unsqueeze(0)  # shape(1, tasks)
            else:
                target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
            data_weights = torch.tensor(data_weights_batch).unsqueeze(1)  # shape(batch, 1)

            constraints_batch = None
            bond_types_batch = None

            if args.loss_function == "bounded_mse":
                lt_target_batch = batch.lt_targets()  # shape(batch, tasks)
                gt_target_batch = batch.gt_targets()  # shape(batch, tasks)
                lt_target_batch = torch.tensor(lt_target_batch)
                gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(
            mol_batch,
            features_batch,
            atom_descriptors_batch,
            atom_features_batch,
            bond_descriptors_batch,
            bond_features_batch,
            constraints_batch,
            bond_types_batch,
        )

        # Move tensors to correct device
        torch_device = args.device
        if model.is_atom_bond_targets:
            masks = [x.to(torch_device) for x in masks]
            masks = [x.reshape([-1, 1]) for x in masks]
            targets = [x.to(torch_device) for x in targets]
            targets = [x.reshape([-1, 1]) for x in targets]
            target_weights = [x.to(torch_device) for x in target_weights]
            data_weights = [x.to(torch_device) for x in data_weights]
        else:
            masks = masks.to(torch_device)
            targets = targets.to(torch_device)
            target_weights = target_weights.to(torch_device)
            data_weights = data_weights.to(torch_device)
            if args.loss_function == "bounded_mse":
                lt_target_batch = lt_target_batch.to(torch_device)
                gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses
        if model.is_atom_bond_targets:
            loss_multi_task = []
            for target, pred, target_weight, data_weight, mask in zip(targets, preds, target_weights, data_weights, masks):
                if args.loss_function == "mcc" and args.dataset_type == "classification":
                    loss = loss_func(pred, target, data_weight, mask) * target_weight.squeeze(0)
                elif args.loss_function == "bounded_mse":
                    raise ValueError(f'Loss function "{args.loss_function}" is not supported with dataset type {args.dataset_type} in atomic/bond properties prediction.')
                elif args.loss_function in ["binary_cross_entropy", "mse", "mve"]:
                    loss = loss_func(pred, target) * target_weight * data_weight * mask
                elif args.loss_function == "evidential":
                    loss = loss_func(pred, target, args.evidential_regularization) * target_weight * data_weight * mask
                elif args.loss_function == "dirichlet" and args.dataset_type == "classification":
                    loss = loss_func(pred, target, args.evidential_regularization) * target_weight * data_weight * mask
                elif args.loss_function == "quantile_interval":
                    quantiles_tensor = torch.tensor(args.quantiles, device=torch_device)
                    loss = loss_func(pred, target, quantiles_tensor) * target_weight * data_weight * mask
                else:
                    raise ValueError(f'Dataset type "{args.dataset_type}" is not supported.')
                loss = loss.sum() / mask.sum()
                loss_multi_task.append(loss)

            loss_sum = [x + y for x, y in zip(loss_sum, loss_multi_task)]
            iter_count += 1

            sum(loss_multi_task).backward()
        else:
            if args.loss_function == "mcc" and args.dataset_type == "classification":
                loss = loss_func(preds, targets, data_weights, masks) * target_weights.squeeze(0)
            elif args.loss_function == "mcc":  # multiclass dataset type
                targets = targets.long()
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(preds[:, target_index, :], targets[:, target_index], data_weights, masks[:, target_index]).unsqueeze(0)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses) * target_weights.squeeze(0)
            elif args.dataset_type == "multiclass":
                targets = targets.long()
                if args.loss_function == "dirichlet":
                    loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * masks
                else:
                    target_losses = []
                    for target_index in range(preds.size(1)):
                        target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                        target_losses.append(target_loss)
                    loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * masks
            elif args.dataset_type == "spectra":
                loss = loss_func(preds, targets, masks) * target_weights * data_weights * masks
            elif args.loss_function == "bounded_mse":
                loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) * target_weights * data_weights * masks
            elif args.loss_function == "evidential":
                loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * masks
            elif args.loss_function == "dirichlet":  # classification
                loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * masks
            elif args.loss_function == "quantile_interval":
                quantiles_tensor = torch.tensor(args.quantiles, device=torch_device)
                loss = loss_func(preds, targets, quantiles_tensor) * target_weights * data_weights * masks
            else:
                loss = loss_func(preds, targets) * target_weights * data_weights * masks

            if args.loss_function == "mcc":
                loss = loss.mean()
            else:
                loss = loss.sum() / masks.sum()

            loss_sum += loss.item()
            iter_count += 1

            loss.backward()

        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            if model.is_atom_bond_targets:
                loss_avg = sum(loss_sum) / iter_count
                loss_sum, iter_count = [0]*(len(args.atom_targets) + len(args.bond_targets)), 0
            else:
                loss_avg = loss_sum / iter_count
                loss_sum = iter_count = 0

            lrs_str = ", ".join(f"lr_{i} = {lr:.4e}" for i, lr in enumerate(lrs))
            debug(f"Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}")

            if writer is not None:
                writer.add_scalar("train_loss", loss_avg, n_iter)
                writer.add_scalar("param_norm", pnorm, n_iter)
                writer.add_scalar("gradient_norm", gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f"learning_rate_{i}", lr, n_iter)

    return n_iter
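
For orientation, below is a minimal sketch of how train() is typically driven from an outer epoch loop (roughly the pattern used by chemprop's run_training). The construction of model, loss_func, optimizer, scheduler, args, logger, and writer is assumed to have happened elsewhere and is not shown; train_data_loader is a hypothetical name for a MoleculeDataLoader over the training split.

    from torch.optim.lr_scheduler import ExponentialLR

    # Sketch only: assumes model, train_data_loader, loss_func, optimizer,
    # scheduler, args, logger, and writer already exist.
    n_iter = 0
    for epoch in range(args.epochs):
        # train() runs one pass over the data loader and returns the updated
        # count of training examples seen so far.
        n_iter = train(
            model=model,
            data_loader=train_data_loader,
            loss_func=loss_func,
            optimizer=optimizer,
            scheduler=scheduler,
            args=args,
            n_iter=n_iter,
            logger=logger,
            writer=writer,
        )
        # Per-epoch schedulers are stepped here; NoamLR is already stepped
        # per batch inside train().
        if isinstance(scheduler, ExponentialLR):
            scheduler.step()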