Source code for chemprop.nn.agg

from abc import abstractmethod

import torch
from torch import Tensor, nn

from chemprop.nn.hparams import HasHParams
from chemprop.utils import ClassRegistry

# Names exported by ``from chemprop.nn.agg import *``.
__all__ = [
    "Aggregation",
    "AggregationRegistry",
    "MeanAggregation",
    "SumAggregation",
    "NormAggregation",
    "AttentiveAggregation",
]


class Aggregation(nn.Module, HasHParams):
    """Pool the node-level representations of a batch of graphs into one graph-level
    representation per graph.

    .. note::
        This class is abstract: concrete subclasses must implement :meth:`forward`.

    Parameters
    ----------
    dim : int, default=0
        the dimension handed to the scatter operation in the concrete subclasses; it is stored
        as ``self.dim`` and recorded in ``self.hparams``
    *args, **kwargs
        accepted for signature compatibility with subclasses; unused by this base class

    See also
    --------
    :class:`~chemprop.nn.agg.MeanAggregation`
    :class:`~chemprop.nn.agg.SumAggregation`
    :class:`~chemprop.nn.agg.NormAggregation`
    :class:`~chemprop.nn.agg.AttentiveAggregation`
    """

    def __init__(self, dim: int = 0, *args, **kwargs):
        super().__init__()

        self.dim = dim
        # hyperparameter record: the reduction dimension plus the concrete class
        self.hparams = dict(dim=dim, cls=type(self))

    @abstractmethod
    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Aggregate the node-level representations of a batch of graphs into their respective
        graph-level representations

        NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will
        be a zero vector of length `d` in the final output.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations of
            ``b`` graphs
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph a given vertex
            corresponds to

        Returns
        -------
        Tensor
            a tensor of shape ``b x d`` containing the graph-level representations
        """
# Registry instance parameterized by :class:`Aggregation`. Classes in this module are decorated
# with ``@AggregationRegistry.register(<key>)`` using the string keys "mean", "sum" and "norm".
AggregationRegistry = ClassRegistry[Aggregation]()
@AggregationRegistry.register("mean")
class MeanAggregation(Aggregation):
    r"""Pool each graph by averaging the representations of its nodes:

    .. math::
        \mathbf h = \frac{1}{|V|} \sum_{v \in V} \mathbf h_v
    """

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Average the rows of ``H`` that belong to the same graph.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the graph index of each node

        Returns
        -------
        Tensor
            a tensor of shape ``b x d``, where ``b = batch.max() + 1``. Rows that receive no
            nodes keep their initial value of zero.
        """
        n_features = H.shape[1]
        n_graphs = batch.max().int() + 1

        # one graph index per element of ``H`` so every feature column is scattered
        node_to_graph = batch.unsqueeze(1).repeat(1, n_features)
        pooled = torch.zeros(n_graphs, n_features, dtype=H.dtype, device=H.device)

        # include_self=False: the zero initialisation does not take part in the mean
        return pooled.scatter_reduce_(
            self.dim, node_to_graph, H, reduce="mean", include_self=False
        )
@AggregationRegistry.register("sum")
class SumAggregation(Aggregation):
    r"""Pool each graph by summing the representations of its nodes:

    .. math::
        \mathbf h = \sum_{v \in V} \mathbf h_v
    """

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Sum the rows of ``H`` that belong to the same graph.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the graph index of each node

        Returns
        -------
        Tensor
            a tensor of shape ``b x d``, where ``b = batch.max() + 1``. Rows that receive no
            nodes keep their initial value of zero.
        """
        width = H.shape[1]
        # broadcast the per-node graph index across all feature columns
        scatter_index = batch.unsqueeze(1).repeat(1, width)
        num_graphs = batch.max().int() + 1

        out = torch.zeros(num_graphs, width, dtype=H.dtype, device=H.device)
        out.scatter_reduce_(self.dim, scatter_index, H, reduce="sum", include_self=False)

        return out
@AggregationRegistry.register("norm")
class NormAggregation(SumAggregation):
    r"""Pool each graph by summing the representations of its nodes and dividing the result by
    a fixed normalization constant:

    .. math::
        \mathbf h = \frac{1}{c} \sum_{v \in V} \mathbf h_v

    Parameters
    ----------
    dim : int, default=0
        the dimension passed on to :class:`SumAggregation`
    norm : float, default=100.0
        the constant :math:`c` by which the summed representation is divided; also recorded
        under the ``"norm"`` key of ``self.hparams``
    """

    def __init__(self, dim: int = 0, *args, norm: float = 100.0, **kwargs):
        super().__init__(dim, **kwargs)

        self.norm = norm
        self.hparams.update(norm=norm)

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Return the per-graph sum of ``H`` (see :meth:`SumAggregation.forward`) divided by
        ``self.norm``."""
        summed = super().forward(H, batch)

        return summed / self.norm
class AttentiveAggregation(Aggregation):
    r"""Pool each graph by an attention-weighted sum of the representations of its nodes. The
    weights are a softmax, taken over the nodes of each graph, of a learned linear score:

    .. math::
        \alpha_v = \frac{\exp(\mathbf w^\top \mathbf h_v + b)}
            {\sum_{u \in V} \exp(\mathbf w^\top \mathbf h_u + b)}, \qquad
        \mathbf h = \sum_{v \in V} \alpha_v \mathbf h_v

    Parameters
    ----------
    dim : int, default=0
        the dimension handed to the scatter operations
    output_size : int
        the size ``d`` of each node representation, i.e. the input width of the scoring layer;
        also recorded under the ``"output_size"`` key of ``self.hparams``
    """

    def __init__(self, dim: int = 0, *args, output_size: int, **kwargs):
        super().__init__(dim, *args, **kwargs)

        self.hparams["output_size"] = output_size
        # maps each node representation to a single scalar attention logit
        self.W = nn.Linear(output_size, 1)

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Compute the attention-weighted sum of the rows of ``H`` for every graph.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the graph index of each node

        Returns
        -------
        Tensor
            a tensor of shape ``b x d``, where ``b = batch.max() + 1``. Rows that receive no
            nodes keep their initial value of zero.
        """
        n_graphs = int(batch.max()) + 1
        node_to_graph = batch.unsqueeze(1)

        logits = self.W(H)

        # Shift every logit by the largest logit of its own graph before exponentiating.
        # Softmax is invariant to this shift, but without it a large logit overflows ``exp``
        # to ``inf`` and the normalisation below yields NaN. The shift is detached because it
        # cancels out analytically and therefore must not contribute to the gradient.
        # Buffers are allocated from the tensor being scattered so dtype and device match it.
        graph_max = logits.new_zeros(n_graphs, 1).scatter_reduce_(
            self.dim, node_to_graph, logits.detach(), reduce="amax", include_self=False
        )
        scores = (logits - graph_max[batch]).exp()

        # per-graph partition function; every graph with nodes has Z >= 1 after the shift
        Z = scores.new_zeros(n_graphs, 1).scatter_reduce_(
            self.dim, node_to_graph, scores, reduce="sum", include_self=False
        )
        alphas = scores / Z[batch]

        weighted = alphas * H
        index_torch = node_to_graph.repeat(1, H.shape[1])

        return weighted.new_zeros(n_graphs, H.shape[1]).scatter_reduce_(
            self.dim, index_torch, weighted, reduce="sum", include_self=False
        )