Saving and loading models
import torch
from chemprop.models.utils import save_model, load_model
from chemprop.models.model import MPNN
from chemprop.models.multi import MulticomponentMPNN
from chemprop import nn
The cell below creates an in-memory buffer to save to and load from, so that running this notebook does not create new files. In a real use case you would save to and load from a file such as model.pt (see the commented-out lines).
import io
saved_model = io.BytesIO()
# from pathlib import Path
# saved_model = Path("model.pt")
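An in-memory buffer behaves like an open file: after writing, its position sits at the end of the data, so it has to be rewound before it can be read. The sketch below shows this using the standard library's pickle as a stand-in for torch.save and torch.load (torch.save uses pickle internally). The placeholder dictionary contents are hypothetical.

```python
import io
import pickle

# Hypothetical placeholder contents; pickle stands in for torch.save here
buffer = io.BytesIO()
pickle.dump({"hyper_parameters": {}, "state_dict": {}}, buffer)

# After writing, the stream position is at the end of the written bytes
position_after_write = buffer.tell()

# Reading from the current position finds no data left
reading_at_end_fails = False
try:
    pickle.load(buffer)
except EOFError:
    reading_at_end_fails = True

# Rewinding to the start makes the saved object readable again
buffer.seek(0)
restored = pickle.load(buffer)
print(position_after_write == len(buffer.getvalue()), reading_at_end_fails, restored)
```

This is why the loading cell later calls saved_model.seek(0) before MPNN.load_from_file; a file path is opened fresh and starts at the beginning.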
Saving models
A valid model save file is a dictionary containing the model's hyperparameters and its state dict. torch.save is used to serialize (pickle) this dictionary.
model = MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())
save_model(saved_model, model)
# save_model above is equivalent to:
# model_dict = {"hyper_parameters": model.hparams, "state_dict": model.state_dict()}
# torch.save(model_dict, saved_model)
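The save-file layout described above can be checked with a few lines of plain Python. This is a minimal sketch: looks_like_model_file is a hypothetical helper, not part of chemprop, and the example dictionaries use placeholder values instead of real hyperparameters or tensors. It tests whether an already-loaded object has the two expected keys.

```python
from collections.abc import Mapping
from typing import Any

# Keys a valid model save file contains, per the description above
REQUIRED_KEYS = ("hyper_parameters", "state_dict")


def looks_like_model_file(obj: Any) -> bool:
    """Hypothetical helper: True if obj is a mapping with both required keys."""
    return isinstance(obj, Mapping) and all(key in obj for key in REQUIRED_KEYS)


# Placeholder stand-ins for a loaded save file
valid = {"hyper_parameters": {"example_hparam": 1}, "state_dict": {"example.weight": [0.1]}}
missing_state = {"hyper_parameters": {"example_hparam": 1}}

print(looks_like_model_file(valid))          # True
print(looks_like_model_file(missing_state))  # False
```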
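Lightning can also create checkpoint files automatically during training. Like .pt model files, these .ckpt files hold the model's hyperparameters and state dict, but they also contain training information and can be used to resume training. The cell below configures a ModelCheckpoint callback that keeps the checkpoint with the lowest val_loss and, because save_last=True, also saves the most recent checkpoint. See the Lightning documentation for more details.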
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
checkpointing = ModelCheckpoint(
dirpath="mycheckpoints",
filename="best-{epoch}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_last=True,
)
trainer = Trainer(callbacks=[checkpointing])
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
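With the settings above, checkpoint files in mycheckpoints get names built from the filename template. The sketch below approximates how Lightning's ModelCheckpoint fills in that template under its default auto_insert_metric_name=True, which prefixes each placeholder with its metric name. sketch_checkpoint_name is a hypothetical helper, not Lightning's API, and the epoch and val_loss values are made up. With save_last=True, Lightning additionally writes a last.ckpt file.

```python
import re


def sketch_checkpoint_name(template: str, metrics: dict[str, float]) -> str:
    """Hypothetical approximation of ModelCheckpoint's file naming."""
    # Find the opening part of each placeholder, e.g. "{epoch" or "{val_loss"
    for group in re.findall(r"(\{.*?)[:\}]", template):
        name = group[1:]
        # Insert "name=" before the placeholder (auto_insert_metric_name behaviour)
        template = template.replace(group, f"{name}={group}", 1)
    # Fill in the values, honouring format specs such as ":.2f", then add the extension
    return template.format(**metrics) + ".ckpt"


print(sketch_checkpoint_name("best-{epoch}-{val_loss:.2f}", {"epoch": 3, "val_loss": 0.1234}))
# best-epoch=3-val_loss=0.12.ckpt
```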
Loading models
MPNN and MulticomponentMPNN each provide two class methods for loading a model: load_from_file and load_from_checkpoint. load_from_file works with either a model file (.pt) or a checkpoint file (.ckpt), but it does not load the training information stored in a checkpoint file.
# Rewind the buffer to the beginning before loading; not needed when loading from a file path
saved_model.seek(0)
model = MPNN.load_from_file(saved_model)
# Other options, depending on the model class and the file type:
# model = MPNN.load_from_checkpoint(saved_model)
# model = MulticomponentMPNN.load_from_file(saved_model)
# model = MulticomponentMPNN.load_from_checkpoint(saved_model)
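One way to picture the difference between the two file types is as extra keys in the saved dictionary. The sketch below is hypothetical and is not chemprop's implementation: split_saved_dict keeps only the hyperparameters and state dict, which is enough to rebuild a model from either file type, and reports the remaining keys such a loader would ignore. The checkpoint keys shown (epoch, global_step, optimizer_states) are a few of those Lightning writes; all values are placeholders.

```python
from typing import Any

# The two entries needed to rebuild a model
MODEL_KEYS = {"hyper_parameters", "state_dict"}


def split_saved_dict(saved: dict[str, Any]) -> tuple[Any, Any, set[str]]:
    """Hypothetical sketch: return (hparams, state_dict, ignored_keys)."""
    missing = MODEL_KEYS - set(saved)
    if missing:
        raise KeyError(f"not a model or checkpoint file, missing {sorted(missing)}")
    # Everything else (e.g. training state in a checkpoint) is left unused
    ignored = set(saved) - MODEL_KEYS
    return saved["hyper_parameters"], saved["state_dict"], ignored


# Placeholder contents; a real checkpoint holds tensors and more keys
model_file = {"hyper_parameters": {"example_hparam": 1}, "state_dict": {"w": 0.5}}
checkpoint_file = {**model_file, "epoch": 4, "global_step": 100, "optimizer_states": [{}]}

_, _, ignored_from_model = split_saved_dict(model_file)
_, _, ignored_from_ckpt = split_saved_dict(checkpoint_file)
print(sorted(ignored_from_model), sorted(ignored_from_ckpt))
```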