Source code for chemprop.utils.v2_0_to_v2_1
from os import PathLike
import pickle
import sys
import torch
[docs]
class Unpickler(pickle.Unpickler):
name_mappings = {
"MSELoss": "MSE",
"MSEMetric": "MSE",
"MAEMetric": "MAE",
"RMSEMetric": "RMSE",
"BoundedMSELoss": "BoundedMSE",
"BoundedMSEMetric": "BoundedMSE",
"BoundedMAEMetric": "BoundedMAE",
"BoundedRMSEMetric": "BoundedRMSE",
"SIDLoss": "SID",
"SIDMetric": "SID",
"WassersteinLoss": "Wasserstein",
"WassersteinMetric": "Wasserstein",
"R2Metric": "R2Score",
"BinaryAUROCMetric": "BinaryAUROC",
"BinaryAUPRCMetric": "BinaryAUPRC",
"BinaryAccuracyMetric": "BinaryAccuracy",
"BinaryF1Metric": "BinaryF1Score",
"BCEMetric": "BCELoss",
}
[docs]
def find_class(self, module, name):
if module == "chemprop.nn.loss":
module = "chemprop.nn.metrics"
name = self.name_mappings.get(name, name)
return super().find_class(module, name)
[docs]
def convert_model_file_v2_0_to_v2_1(model_v1_file: PathLike, model_v2_file: PathLike):
model = torch.load(
model_v1_file, map_location="cpu", pickle_module=sys.modules[__name__], weights_only=False
)
torch.save(model, model_v2_file)
if __name__ == "__main__":
convert_model_file_v2_0_to_v2_1(sys.argv[1], sys.argv[2])