# Source code for chemprop.nn.transforms

from numpy.typing import ArrayLike
from sklearn.preprocessing import StandardScaler
import torch
from torch import Tensor, nn

from chemprop.data.collate import BatchMolGraph


class _ScaleTransformMixin(nn.Module):
    """Base module that stores a per-feature ``mean`` and ``scale`` as buffers.

    Both statistics are converted to ``float`` tensors, optionally left-padded, and
    registered with a leading singleton dimension (``unsqueeze(0)``). For 1-D inputs
    the buffers therefore have shape ``(1, pad + n_features)``.

    Parameters
    ----------
    mean : ArrayLike
        The per-feature offsets.
    scale : ArrayLike
        The per-feature scale factors.
    pad : int, default=0
        Number of identity entries (``mean = 0``, ``scale = 1``) to prepend to both
        statistics.

    Raises
    ------
    ValueError
        If the padded ``mean`` and ``scale`` do not have the same shape.
    """

    def __init__(self, mean: ArrayLike, scale: ArrayLike, pad: int = 0):
        super().__init__()

        # Prepend `pad` identity entries so the leading features pass through unchanged.
        mean = torch.cat([torch.zeros(pad), torch.tensor(mean, dtype=torch.float)])
        scale = torch.cat([torch.ones(pad), torch.tensor(scale, dtype=torch.float)])

        if mean.shape != scale.shape:
            raise ValueError(
                f"uneven shapes for 'mean' and 'scale'! got: mean={mean.shape}, scale={scale.shape}"
            )

        # Buffers (not parameters): they follow the module across devices and are saved in
        # the state dict. The extra leading dimension lets them broadcast over rows.
        self.register_buffer("mean", mean.unsqueeze(0))
        self.register_buffer("scale", scale.unsqueeze(0))

    @classmethod
    def from_standard_scaler(cls, scaler: StandardScaler, pad: int = 0):
        """Build a transform from the ``mean_`` and ``scale_`` of a ``StandardScaler``.

        Parameters
        ----------
        scaler : StandardScaler
            The scaler whose ``mean_`` and ``scale_`` attributes are read.
        pad : int, default=0
            Forwarded to the constructor.
        """
        return cls(scaler.mean_, scaler.scale_, pad=pad)

    def to_standard_scaler(self, anti_pad: int = 0) -> StandardScaler:
        """Export the stored statistics into a new ``StandardScaler``.

        Parameters
        ----------
        anti_pad : int, default=0
            Number of leading feature entries to drop from ``mean`` and ``scale``
            (e.g. the entries added by ``pad`` at construction time).

        Returns
        -------
        StandardScaler
            A scaler whose ``mean_`` and ``scale_`` are 1-D NumPy arrays. Only these two
            attributes are set.
        """
        scaler = StandardScaler()
        # The buffers carry a leading singleton dimension, so the row must be indexed
        # explicitly. Slicing the first axis (`self.mean[anti_pad:]`) would return an
        # empty array for any `anti_pad > 0` and a 2-D array for `anti_pad == 0`.
        # `.cpu()` is needed because `Tensor.numpy()` only works on CPU tensors.
        scaler.mean_ = self.mean[0, anti_pad:].detach().cpu().numpy()
        scaler.scale_ = self.scale[0, anti_pad:].detach().cpu().numpy()
        return scaler


class ScaleTransform(_ScaleTransformMixin):
    """Standardize an input tensor with the stored ``mean`` and ``scale``."""

    def forward(self, X: Tensor) -> Tensor:
        """Return ``(X - mean) / scale`` in eval mode; return ``X`` as-is in training mode.

        NOTE(review): the training-mode pass-through presumably exists because training
        inputs are already scaled upstream -- confirm against the data pipeline.
        """
        if self.training:
            return X

        return (X - self.mean) / self.scale


class UnscaleTransform(_ScaleTransformMixin):
    """Invert a standardization using the stored ``mean`` and ``scale``."""

    def forward(self, X: Tensor) -> Tensor:
        """Return ``X * scale + mean`` in eval mode; return ``X`` as-is in training mode."""
        if self.training:
            return X

        return X * self.scale + self.mean

    def transform_variance(self, var: Tensor) -> Tensor:
        """Return ``var * scale**2`` in eval mode; return ``var`` as-is in training mode.

        Only the squared scale is applied; the stored ``mean`` is not used.
        """
        if self.training:
            return var

        return var * (self.scale**2)


class GraphTransform(nn.Module):
    """Apply separate scale transforms to the ``V`` and ``E`` attributes of a batch graph.

    Parameters
    ----------
    V_transform : ScaleTransform
        The transform applied to ``bmg.V``.
    E_transform : ScaleTransform
        The transform applied to ``bmg.E``.
    """

    def __init__(self, V_transform: ScaleTransform, E_transform: ScaleTransform):
        super().__init__()

        self.V_transform = V_transform
        self.E_transform = E_transform

    def forward(self, bmg: BatchMolGraph) -> BatchMolGraph:
        """Transform ``bmg.V`` and ``bmg.E`` in eval mode; do nothing in training mode.

        The input object is modified in place (its ``V`` and ``E`` attributes are
        reassigned) and the same object is returned.
        """
        if self.training:
            return bmg

        bmg.V = self.V_transform(bmg.V)
        bmg.E = self.E_transform(bmg.E)

        return bmg