Source code for chemprop.nn.utils

from enum import auto
import logging

from torch import nn

from chemprop.utils.utils import EnumMapping

logger = logging.getLogger(__name__)


class Activation(EnumMapping):
    """Names of the activation functions supported by ``get_activation_function``.

    Each member maps to one ``torch.nn`` module in ``get_activation_function``:

    * ``RELU``      -> ``nn.ReLU()``
    * ``LEAKYRELU`` -> ``nn.LeakyReLU(0.1)``
    * ``PRELU``     -> ``nn.PReLU()``
    * ``TANH``      -> ``nn.Tanh()``
    * ``ELU``       -> ``nn.ELU()``
    """

    # Values come from ``auto()``; their concrete form is decided by
    # ``EnumMapping`` (presumably the lower-cased member name -- confirm).
    # Keep this order stable: ``auto()`` values may depend on position.
    RELU = auto()
    LEAKYRELU = auto()
    PRELU = auto()
    TANH = auto()
    ELU = auto()
[docs] def get_activation_function(activation: str | nn.Module | Activation) -> nn.Module: """Gets an activation function module given the name of the activation. See :class:`~chemprop.v2.models.utils.Activation` for available activations. Parameters ---------- activation : str | nn.Module | Activation The name of the activation function. Returns ------- nn.Module The activation function module. """ if activation == "selu": logger.warning('Accepting activation="selu" for backward compatibility.') activation = nn.modules.activation.SELU() if isinstance(activation, nn.Module): if isinstance(activation, nn.modules.activation.SELU): logger.warning( "Chemprop does not support self-normalization. Using SELU activation is not enough to achieve it." ) return activation match Activation.get(activation): case Activation.RELU: return nn.ReLU() case Activation.LEAKYRELU: return nn.LeakyReLU(0.1) case Activation.PRELU: return nn.PReLU() case Activation.TANH: return nn.Tanh() case Activation.ELU: return nn.ELU() case _: raise RuntimeError("unreachable code reached!")