import logging
from os import PathLike
from lightning.pytorch import __version__
from lightning.pytorch.utilities.parsing import AttributeDict
import torch
from chemprop.nn.agg import AggregationRegistry
from chemprop.nn.message_passing import (
AtomMessagePassing,
BondMessagePassing,
MulticomponentMessagePassing,
)
from chemprop.nn.metrics import LossFunctionRegistry, MetricRegistry
from chemprop.nn.predictors import PredictorRegistry
from chemprop.nn.transforms import UnscaleTransform
from chemprop.utils import Factory
logger = logging.getLogger(__name__)
[docs]
def max_encoder_index(d: dict) -> int:
    """Return the largest encoder index among the keys of a v1 state dictionary.

    Only keys of the form ``encoder.encoder.<i>.W_i.weight`` with an integer ``<i>``
    are considered.

    Parameters
    ----------
    d : dict
        a mapping whose keys are parameter names (the values are not inspected)

    Returns
    -------
    int
        the largest ``<i>`` found

    Raises
    ------
    ValueError
        if no key of the expected form is present
    """
    prefix = "encoder.encoder."
    suffix = ".W_i.weight"

    indices = []
    for key in d:
        if not (key.startswith(prefix) and key.endswith(suffix)):
            continue
        # The prefix and suffix may share their dot, so a non-indexed key such as
        # "encoder.encoder.W_i.weight" also gets here; it has no index to read.
        index = key.split(".")[2]
        if index.isdecimal():
            indices.append(int(index))

    if not indices:
        raise ValueError(f"no '{prefix}<index>{suffix}' key found in the state dictionary")

    return max(indices)
[docs]
def _copy_encoder_params(state_dict_v1: dict, state_dict_v2: dict, src: str, dst: str) -> None:
    """Copy the parameters of one encoder from ``{src}.*`` in v1 to ``{dst}.*`` in v2."""
    for name in ("W_i.weight", "W_h.weight", "W_o.weight", "W_o.bias"):
        state_dict_v2[f"{dst}.{name}"] = state_dict_v1[f"{src}.{name}"]

    # These two entries are not present in every checkpoint, so they are copied only
    # when they exist instead of being silently dropped.
    # NOTE(review): presumably they exist exactly when the v1 model used ``bias=True``
    # - confirm against the v1 encoder definition.
    for name in ("W_i.bias", "W_h.bias"):
        if f"{src}.{name}" in state_dict_v1:
            state_dict_v2[f"{dst}.{name}"] = state_dict_v1[f"{src}.{name}"]


def convert_state_dict_v1_to_v2(model_v1_dict: dict) -> dict:
    """Converts v1 model dictionary to a v2 state dictionary

    Parameters
    ----------
    model_v1_dict : dict
        the loaded v1 checkpoint. The keys ``"args"`` and ``"state_dict"`` are always read;
        ``"data_scaler"`` (with ``"means"`` and ``"stds"``) is read for regression models.

    Returns
    -------
    dict
        a new dictionary mapping v2 parameter names to the corresponding v1 values

    Raises
    ------
    KeyError
        if a required parameter is missing from the v1 state dictionary
    """
    args_v1 = model_v1_dict["args"]
    state_dict_v1 = model_v1_dict["state_dict"]
    state_dict_v2 = {}

    # convert the message passing parameters
    if "encoder.encoder_solvent.W_i.weight" in state_dict_v1:
        # two non-indexed encoders: block 0 <- "encoder", block 1 <- "encoder_solvent"
        _copy_encoder_params(
            state_dict_v1, state_dict_v2, "encoder.encoder", "message_passing.blocks.0"
        )
        _copy_encoder_params(
            state_dict_v1, state_dict_v2, "encoder.encoder_solvent", "message_passing.blocks.1"
        )
    elif "encoder.encoder.1.W_i.weight" in state_dict_v1:
        logger.warning(
            "This conversion is untested - please validate your model predictions are consistent after conversion!"
        )
        # several indexed encoders: one v2 block per v1 encoder index
        for i in range(max_encoder_index(state_dict_v1) + 1):
            _copy_encoder_params(
                state_dict_v1, state_dict_v2, f"encoder.encoder.{i}", f"message_passing.blocks.{i}"
            )
    else:
        _copy_encoder_params(state_dict_v1, state_dict_v2, "encoder.encoder.0", "message_passing")

    # v1.6 renamed ffn to readout
    ffn_prefix = "readout" if "readout.1.weight" in state_dict_v1 else "ffn"
    for i in range(args_v1.ffn_num_layers):
        suffix = 0 if i == 0 else 2
        for param in ("weight", "bias"):
            state_dict_v2[f"predictor.ffn.{i}.{suffix}.{param}"] = state_dict_v1[
                f"{ffn_prefix}.{i * 3 + 1}.{param}"
            ]

    if args_v1.dataset_type == "regression":
        data_scaler = model_v1_dict["data_scaler"]
        state_dict_v2["predictor.output_transform.mean"] = torch.tensor(
            data_scaler["means"], dtype=torch.float32
        ).unsqueeze(0)
        state_dict_v2["predictor.output_transform.scale"] = torch.tensor(
            data_scaler["stds"], dtype=torch.float32
        ).unsqueeze(0)

    # target_weights was added in #183
    if getattr(args_v1, "target_weights", None) is not None:
        task_weights = torch.tensor(args_v1.target_weights).unsqueeze(0)
    else:
        task_weights = torch.ones(args_v1.num_tasks).unsqueeze(0)
    state_dict_v2["predictor.criterion.task_weights"] = task_weights

    return state_dict_v2
[docs]
def _message_passing_block_hparams(
    state_dict_v1: dict, prefix: str, args_v1, *, bias, d_h, depth
) -> AttributeDict:
    """Build the v2 hyper-parameters of the message passing block stored under ``prefix``.

    The feature dimensions are recovered from the shapes of the stored weight matrices;
    ``bias``, ``d_h`` and ``depth`` are passed in because they come from different v1
    arguments for different encoders.
    """
    W_i_shape = state_dict_v1[f"{prefix}.W_i.weight"].shape
    W_h_shape = state_dict_v1[f"{prefix}.W_h.weight"].shape
    W_o_shape = state_dict_v1[f"{prefix}.W_o.weight"].shape

    hidden = W_i_shape[0]
    d_v = W_o_shape[1] - hidden
    d_e = W_h_shape[1] - hidden if args_v1.atom_messages else W_i_shape[1] - d_v

    return AttributeDict(
        {
            "activation": args_v1.activation,
            "bias": bias,
            "cls": BondMessagePassing if not args_v1.atom_messages else AtomMessagePassing,
            "d_e": d_e,  # the feature dimension of the edges
            "d_h": d_h,  # dimension of the hidden layer
            "d_v": d_v,  # the feature dimension of the vertices
            "d_vd": args_v1.atom_descriptors_size,
            "depth": depth,
            "dropout": args_v1.dropout,
            "undirected": args_v1.undirected,
        }
    )


def convert_hyper_parameters_v1_to_v2(model_v1_dict: dict) -> dict:
    """Converts v1 model dictionary to v2 hyper_parameters dictionary

    Parameters
    ----------
    model_v1_dict : dict
        the loaded v1 checkpoint. The keys ``"args"`` and ``"state_dict"`` are always read;
        ``"data_scaler"`` (with ``"means"`` and ``"stds"``) is read for regression models.

    Returns
    -------
    dict
        the v2 hyper-parameters, with the keys ``"batch_norm"``, ``"metrics"``,
        ``"warmup_epochs"``, ``"init_lr"``, ``"max_lr"``, ``"final_lr"``,
        ``"message_passing"``, ``"agg"`` and ``"predictor"``

    Raises
    ------
    KeyError
        if an expected weight is missing from the v1 state dictionary, or if
        ``args.dataset_type`` has no default loss function and ``args.loss_function``
        is not set. Lookups in the metric, loss function, aggregation and predictor
        registries may also fail for names they do not know.
    """
    hyper_parameters_v2 = {}

    # Names that differ between v1 and v2. The "... is not in v2" values are not valid
    # names; they make the subsequent registry lookup fail with a descriptive key.
    renamed_metrics = {
        "auc": "roc",
        "prc-auc": "prc",
        "cross_entropy": "ce",
        "binary_cross_entropy": "bce",
        "mcc": "binary-mcc",
        "recall": "recall is not in v2",
        "precision": "precision is not in v2",
        "balanced_accuracy": "balanced_accuracy is not in v2",
    }

    args_v1 = model_v1_dict["args"]
    state_dict_v1 = model_v1_dict["state_dict"]
    is_reaction_solvent = getattr(args_v1, "reaction_solvent", False)
    # read with a default everywhere, so args without this attribute behave consistently
    n_molecules = getattr(args_v1, "number_of_molecules", 1)

    hyper_parameters_v2["batch_norm"] = False
    hyper_parameters_v2["metrics"] = [
        Factory.build(MetricRegistry[renamed_metrics.get(args_v1.metric, args_v1.metric)])
    ]
    hyper_parameters_v2["warmup_epochs"] = args_v1.warmup_epochs
    hyper_parameters_v2["init_lr"] = args_v1.init_lr
    hyper_parameters_v2["max_lr"] = args_v1.max_lr
    hyper_parameters_v2["final_lr"] = args_v1.final_lr

    # convert the message passing block
    if is_reaction_solvent:
        reaction = _message_passing_block_hparams(
            state_dict_v1,
            "encoder.encoder",
            args_v1,
            bias=args_v1.bias,
            d_h=args_v1.hidden_size,
            depth=args_v1.depth,
        )
        solvent = _message_passing_block_hparams(
            state_dict_v1,
            "encoder.encoder_solvent",
            args_v1,
            bias=args_v1.bias_solvent,
            d_h=args_v1.hidden_size_solvent,
            depth=args_v1.depth_solvent,
        )
        hyper_parameters_v2["message_passing"] = AttributeDict(
            {
                "cls": MulticomponentMessagePassing,
                "blocks": [reaction, solvent],
                "n_components": 2,
                "shared": False,
            }
        )
    elif n_molecules > 1:
        logger.warning(
            "This conversion is untested - please validate your model predictions are consistent after conversion!"
        )
        blocks = [
            _message_passing_block_hparams(
                state_dict_v1,
                f"encoder.encoder.{i}",
                args_v1,
                bias=args_v1.bias,
                d_h=args_v1.hidden_size,
                depth=args_v1.depth,
            )
            for i in range(max_encoder_index(state_dict_v1) + 1)
        ]
        hyper_parameters_v2["message_passing"] = AttributeDict(
            {
                "cls": MulticomponentMessagePassing,
                "blocks": blocks,
                "n_components": n_molecules,
                "shared": args_v1.mpn_shared,
            }
        )
    else:
        hyper_parameters_v2["message_passing"] = _message_passing_block_hparams(
            state_dict_v1,
            "encoder.encoder.0",
            args_v1,
            bias=args_v1.bias,
            d_h=args_v1.hidden_size,
            depth=args_v1.depth,
        )

    # convert the aggregation block
    hyper_parameters_v2["agg"] = {
        "dim": 0,  # in v1, the aggregation is always done on the atom features
        "cls": AggregationRegistry[args_v1.aggregation],
    }
    if args_v1.aggregation == "norm":
        hyper_parameters_v2["agg"]["norm"] = args_v1.aggregation_norm

    # convert the predictor block
    fgs = args_v1.features_generator or []
    # each generator adds 200 features if its name contains "rdkit" and 2048 if it
    # contains "morgan"
    d_xd = sum((200 if "rdkit" in fg else 0) + (2048 if "morgan" in fg else 0) for fg in fgs)

    if getattr(args_v1, "target_weights", None) is not None:
        task_weights = torch.tensor(args_v1.target_weights).unsqueeze(0)
    else:
        task_weights = torch.ones(args_v1.num_tasks).unsqueeze(0)

    # loss_function was added in #238
    loss_fn_defaults = {
        "classification": "bce",
        "regression": "mse",
        "multiclass": "ce",
        "spectra": "sid",
    }
    # Only consult the defaults when the args carry no loss function, so that an
    # explicit loss function never depends on the defaults table.
    str_loss_fn = getattr(args_v1, "loss_function", None)
    if str_loss_fn is None:
        str_loss_fn = loss_fn_defaults[args_v1.dataset_type]
    T_loss_fn = LossFunctionRegistry[renamed_metrics.get(str_loss_fn, str_loss_fn)]

    if is_reaction_solvent:
        d_encodings = args_v1.hidden_size + args_v1.hidden_size_solvent
    else:
        d_encodings = args_v1.hidden_size * n_molecules

    hyper_parameters_v2["predictor"] = AttributeDict(
        {
            "activation": args_v1.activation,
            "cls": PredictorRegistry[args_v1.dataset_type],
            "criterion": Factory.build(T_loss_fn, task_weights=task_weights),
            "task_weights": None,
            "dropout": args_v1.dropout,
            "hidden_dim": args_v1.ffn_hidden_size,
            "input_dim": d_encodings + args_v1.atom_descriptors_size + d_xd,
            "n_layers": args_v1.ffn_num_layers - 1,
            "n_tasks": args_v1.num_tasks,
        }
    )

    if args_v1.dataset_type == "regression":
        hyper_parameters_v2["predictor"]["output_transform"] = UnscaleTransform(
            model_v1_dict["data_scaler"]["means"], model_v1_dict["data_scaler"]["stds"]
        )

    return hyper_parameters_v2
[docs]
def convert_model_dict_v1_to_v2(model_v1_dict: dict) -> dict:
    """Converts a v1 model dictionary from a loaded .pt file to a v2 model dictionary

    Parameters
    ----------
    model_v1_dict : dict
        the v1 checkpoint, passed unchanged to both the state dictionary converter and
        the hyper-parameter converter

    Returns
    -------
    dict
        the v2 checkpoint. Only ``"pytorch-lightning_version"``, ``"state_dict"``,
        ``"hparams_name"`` and ``"hyper_parameters"`` carry values; every other entry
        is ``None``.
    """
    # Values are evaluated top to bottom, so the state dictionary is converted before
    # the hyper-parameters.
    return {
        "epoch": None,
        "global_step": None,
        "pytorch-lightning_version": __version__,
        "state_dict": convert_state_dict_v1_to_v2(model_v1_dict),
        "loops": None,
        "callbacks": None,
        "optimizer_states": None,
        "lr_schedulers": None,
        "hparams_name": "kwargs",
        "hyper_parameters": convert_hyper_parameters_v1_to_v2(model_v1_dict),
    }
[docs]
def convert_model_file_v1_to_v2(model_v1_file: PathLike, model_v2_file: PathLike) -> None:
    """Converts a v1 model .pt file to a v2 model .pt file

    Parameters
    ----------
    model_v1_file : PathLike
        path of the v1 checkpoint to read
    model_v2_file : PathLike
        path the converted v2 checkpoint is written to
    """
    cpu = torch.device("cpu")
    # NOTE(review): ``weights_only=False`` opts out of torch's restricted loading, so
    # only convert checkpoints from a trusted source.
    legacy_checkpoint = torch.load(model_v1_file, map_location=cpu, weights_only=False)
    converted_checkpoint = convert_model_dict_v1_to_v2(legacy_checkpoint)

    featurizer_reminder = (
        "Remember to use the same featurizers which were used when training the model. The default "
        "v1 atom featurizer is `chemprop.featurizers.atom.MultiHotAtomFeaturizer.v1()` and can be "
        "specified from the command line with `--multi-hot-atom-featurizer-mode v1`."
    )
    logger.warning(featurizer_reminder)

    torch.save(converted_checkpoint, model_v2_file)