from abc import abstractmethod
import torch
from torch import Tensor, nn
from chemprop.nn.hparams import HasHParams
from chemprop.utils import ClassRegistry
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = [
"Aggregation",
"AggregationRegistry",
"MeanAggregation",
"SumAggregation",
"NormAggregation",
"AttentiveAggregation",
]
class Aggregation(nn.Module, HasHParams):
    """Base class for modules that pool the node-level representations of a batch of graphs
    into one representation per graph.

    .. note::
        this class is abstract and is not meant to be instantiated directly; use one of the
        concrete subclasses instead.

    Parameters
    ----------
    dim : int, default=0
        the dimension handed to the underlying scatter operation by the concrete subclasses.
        It is stored as ``self.dim`` and recorded in ``self.hparams``.
    *args, **kwargs
        accepted for signature compatibility with subclasses and otherwise ignored

    See also
    --------
    :class:`MeanAggregation`
    :class:`SumAggregation`
    :class:`NormAggregation`
    :class:`AttentiveAggregation`
    """

    def __init__(self, dim: int = 0, *args, **kwargs):
        super().__init__()

        self.dim = dim
        # hyperparameter record: the reduction dimension and the concrete class
        self.hparams = dict(dim=dim, cls=self.__class__)

    @abstractmethod
    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Aggregate the node-level representations of a batch of graphs into their respective
        graph-level representations

        NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will
        be a zero vector of length ``d`` in the final output.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations of
            ``b`` graphs
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph a given vertex
            corresponds to

        Returns
        -------
        Tensor
            a tensor of shape ``b x d`` containing the graph-level representations
        """
# Registry instance parameterized on :class:`Aggregation`. Concrete aggregations in this module
# are decorated with ``@AggregationRegistry.register("<alias>")``.
AggregationRegistry = ClassRegistry[Aggregation]()
@AggregationRegistry.register("mean")
class MeanAggregation(Aggregation):
    r"""Average the node-level representations of each graph:

    .. math::
        \mathbf h = \frac{1}{|V|} \sum_{v \in V} \mathbf h_v
    """

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Compute the per-graph mean of the node representations.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph each node belongs to

        Returns
        -------
        Tensor
            a tensor of shape ``b x d`` where ``b = batch.max() + 1``, or ``0 x d`` when
            ``batch`` is empty. A graph without nodes whose index is below ``b`` keeps an
            all-zero row; node-less graphs with an index above the largest one present in
            ``batch`` cannot be inferred and are not represented in the output.
        """
        n_feats = H.shape[1]

        if batch.numel() == 0:
            # ``batch.max()`` raises on an empty tensor; there is nothing to aggregate
            return torch.zeros(0, n_feats, dtype=H.dtype, device=H.device)

        n_graphs = int(batch.max()) + 1
        # scatter needs one target index per element of ``H``
        index = batch.unsqueeze(1).repeat(1, n_feats)
        out = torch.zeros(n_graphs, n_feats, dtype=H.dtype, device=H.device)

        # ``include_self=False`` keeps the zero initialization out of the mean
        return out.scatter_reduce_(self.dim, index, H, reduce="mean", include_self=False)
@AggregationRegistry.register("sum")
class SumAggregation(Aggregation):
    r"""Sum the node-level representations of each graph:

    .. math::
        \mathbf h = \sum_{v \in V} \mathbf h_v
    """

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Compute the per-graph sum of the node representations.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph each node belongs to

        Returns
        -------
        Tensor
            a tensor of shape ``b x d`` where ``b = batch.max() + 1``, or ``0 x d`` when
            ``batch`` is empty. A graph without nodes whose index is below ``b`` keeps an
            all-zero row; node-less graphs with an index above the largest one present in
            ``batch`` cannot be inferred and are not represented in the output.
        """
        n_feats = H.shape[1]

        if batch.numel() == 0:
            # ``batch.max()`` raises on an empty tensor; there is nothing to aggregate
            return torch.zeros(0, n_feats, dtype=H.dtype, device=H.device)

        n_graphs = int(batch.max()) + 1
        # scatter needs one target index per element of ``H``
        index = batch.unsqueeze(1).repeat(1, n_feats)
        out = torch.zeros(n_graphs, n_feats, dtype=H.dtype, device=H.device)

        return out.scatter_reduce_(self.dim, index, H, reduce="sum", include_self=False)
@AggregationRegistry.register("norm")
class NormAggregation(SumAggregation):
    r"""Sum the node-level representations of each graph and divide by a normalization
    constant:

    .. math::
        \mathbf h = \frac{1}{c} \sum_{v \in V} \mathbf h_v

    Parameters
    ----------
    dim : int, default=0
        the dimension handed to the underlying scatter operation
    norm : float, default=100.0
        the normalization constant :math:`c`. It is stored as ``self.norm`` and recorded in
        ``self.hparams["norm"]``.
    """

    def __init__(self, dim: int = 0, *args, norm: float = 100.0, **kwargs):
        # forward positional extras as well, so they reach the parent initializer rather than
        # being dropped here
        super().__init__(dim, *args, **kwargs)

        self.norm = norm
        self.hparams["norm"] = norm

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Return the per-graph sum of ``H`` (see :class:`SumAggregation`) divided by
        ``self.norm``.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph each node belongs to

        Returns
        -------
        Tensor
            the output of :meth:`SumAggregation.forward` scaled by ``1 / self.norm``
        """
        summed = super().forward(H, batch)

        return summed / self.norm
class AttentiveAggregation(Aggregation):
    r"""Sum the node-level representations of each graph, weighted by a learned attention
    score that is softmax-normalized over the nodes of the same graph:

    .. math::
        \alpha_v = \frac{\exp(\mathbf w^\top \mathbf h_v + b)}
                        {\sum_{u \in V} \exp(\mathbf w^\top \mathbf h_u + b)},
        \qquad
        \mathbf h = \sum_{v \in V} \alpha_v \mathbf h_v

    Parameters
    ----------
    dim : int, default=0
        the dimension handed to the underlying scatter operations
    output_size : int
        the size ``d`` of the node representations; it is the input size of the linear layer
        ``self.W`` that produces one attention logit per node. Recorded in
        ``self.hparams["output_size"]``.
    """

    def __init__(self, dim: int = 0, *args, output_size: int, **kwargs):
        super().__init__(dim, *args, **kwargs)

        self.hparams["output_size"] = output_size
        self.W = nn.Linear(output_size, 1)

    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        """Compute the attention-weighted per-graph sum of the node representations.

        Parameters
        ----------
        H : Tensor
            a tensor of shape ``V x d`` containing the batched node-level representations
        batch : Tensor
            a tensor of shape ``V`` containing the index of the graph each node belongs to

        Returns
        -------
        Tensor
            a tensor of shape ``b x d`` where ``b = batch.max() + 1``, or ``0 x d`` when
            ``batch`` is empty. A graph without nodes whose index is below ``b`` keeps an
            all-zero row.
        """
        n_feats = H.shape[1]

        if batch.numel() == 0:
            # ``batch.max()`` raises on an empty tensor; there is nothing to aggregate
            return torch.zeros(0, n_feats, dtype=H.dtype, device=H.device)

        n_graphs = int(batch.max()) + 1
        node_to_graph = batch.unsqueeze(1)
        logits = self.W(H)

        # Shift each graph's logits by their maximum before exponentiating. The softmax is
        # invariant to a per-graph shift, so the result is unchanged mathematically, but
        # ``exp`` can no longer overflow to ``inf`` (which would yield NaN weights). The
        # shift is a constant w.r.t. the softmax output, hence it is detached.
        max_logits = logits.new_zeros(n_graphs, 1).scatter_reduce_(
            self.dim, node_to_graph, logits.detach(), reduce="amax", include_self=False
        )
        exp_logits = (logits - max_logits[batch]).exp()

        # per-graph softmax denominator
        Z = logits.new_zeros(n_graphs, 1).scatter_reduce_(
            self.dim, node_to_graph, exp_logits, reduce="sum", include_self=False
        )
        alphas = exp_logits / Z[batch]

        index = node_to_graph.repeat(1, n_feats)
        out = torch.zeros(n_graphs, n_feats, dtype=H.dtype, device=H.device)

        return out.scatter_reduce_(self.dim, index, alphas * H, reduce="sum", include_self=False)