chemprop.nn


Submodules

Attributes

Classes

Aggregation

An Aggregation aggregates the node-level representations of a batch of graphs into a batch of graph-level representations.

AttentiveAggregation

An AttentiveAggregation aggregates the node-level representations of a batch of graphs into a batch of graph-level representations using learned attention weights.

MeanAggregation

Average the node-level representations of each graph to obtain its graph-level representation.

NormAggregation

Sum the node-level representations of each graph and divide by a normalization constant.

SumAggregation

Sum the node-level representations of each graph to obtain its graph-level representation.

ConstrainerFFN

A ConstrainerFFN adjusts atom or bond property predictions to satisfy molecular constraints.

AtomMessagePassing

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

BondMessagePassing

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

MABAtomMessagePassing

A MABAtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

MABBondMessagePassing

A MABBondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

MABMessagePassing

A MABMessagePassing module encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.

MessagePassing

A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

MulticomponentMessagePassing

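A MulticomponentMessagePassing performs message passing on each individual input in a multicomponent input and then concatenates the representations of the inputs to construct a global representation.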

MAE

Compute the mean absolute error (MAE).

MSE

Compute the mean squared error (MSE).

RMSE

Compute the root-mean-squared error (RMSE).

SID

Compute the spectral information divergence (SID) between predicted and target spectra.

BCELoss

Compute the binary cross-entropy loss.

BinaryAccuracy

Compute accuracy for binary tasks.

BinaryAUPRC

Compute the area under the precision-recall curve (AUPRC) for binary tasks.

BinaryAUROC

Compute the area under the receiver operating characteristic curve (ROC AUC) for binary tasks.

BinaryF1Score

Compute the F1 score for binary tasks.

BinaryMCCLoss

Base class for all metrics present in the Metrics API.

BinaryMCCMetric

Base class for all metrics present in the Metrics API.

BoundedMAE

A bounded variant of MAE for targets that may be given as inequalities (lower or upper bounds) rather than exact values.

BoundedMixin

BoundedMSE

A bounded variant of MSE for targets that may be given as inequalities (lower or upper bounds) rather than exact values.

BoundedRMSE

A bounded variant of RMSE for targets that may be given as inequalities (lower or upper bounds) rather than exact values.

ChempropMetric

Base class for the Chemprop metrics and losses, built on the torchmetrics Metric base class.

ClassificationMixin

CrossEntropyLoss

Compute the cross-entropy loss for multiclass classification.

DirichletLoss

Uses the loss function from [sensoy2018], based on the implementation at [sensoyGithub].

EvidentialLoss

Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]. See also [soleimany2021].

MulticlassMCCLoss

Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass

MulticlassMCCMetric

Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass

MVELoss

Calculate the loss using Eq. 9 from [nix1994].

QuantileLoss

Compute the quantile (pinball) loss.

R2Score

Compute the R² score, also known as the coefficient of determination.

Wasserstein

Compute the Wasserstein (earth mover's) distance between predicted and target spectra.

BinaryClassificationFFN

A _FFNPredictorBase is the base class for all Predictors that use an

BinaryClassificationFFNBase

A _FFNPredictorBase is the base class for all Predictors that use an

BinaryDirichletFFN

A _FFNPredictorBase is the base class for all Predictors that use an

EvidentialFFN

A _FFNPredictorBase is the base class for all Predictors that use an

MulticlassClassificationFFN

A _FFNPredictorBase is the base class for all Predictors that use an

MveFFN

A _FFNPredictorBase is the base class for all Predictors that use an

Predictor

A Predictor is a protocol that defines a differentiable function

QuantileFFN

A _FFNPredictorBase is the base class for all Predictors that use an

RegressionFFN

A _FFNPredictorBase is the base class for all Predictors that use an

SpectralFFN

A _FFNPredictorBase is the base class for all Predictors that use an

GraphTransform

Base class for all neural network modules.

ScaleTransform

Base class for all neural network modules.

UnscaleTransform

Base class for all neural network modules.

Activation

An enum of the supported activation functions; its members are also (and must be) strings.

Package Contents#

class chemprop.nn.Aggregation(dim=0, *args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

An Aggregation aggregates the node-level representations of a batch of graphs into a batch of graph-level representations

Note

This class is abstract and cannot be instantiated.

See also

MeanAggregation, SumAggregation, NormAggregation

Parameters:

dim (int)

dim = 0
hparams
abstractmethod forward(H, batch)

Aggregate the node-level representations of a batch of graphs into their respective graph-level representations.

NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be a zero vector of length d in the final output.

Parameters:
  • H (Tensor) – a tensor of shape V x d containing the batched node-level representations of b graphs

  • batch (Tensor) – a tensor of shape V containing the index of the graph a given vertex corresponds to

Returns:

a tensor of shape b x d containing the graph-level representations

Return type:

Tensor
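As an illustration of this contract (not chemprop's own implementation), the following plain-PyTorch sketch scatter-sums the V x d node representations into one row per graph using the batch index. The helper name sum_aggregate and its explicit n_graphs argument are assumptions made for the example; passing the number of graphs explicitly is what lets a graph with 0 nodes still receive its zero row.

```python
import torch


def sum_aggregate(H: torch.Tensor, batch: torch.Tensor, n_graphs: int) -> torch.Tensor:
    """Sum node-level rows of H (V x d) into graph-level rows (n_graphs x d).

    batch[i] is the index of the graph that node i belongs to.
    Graphs that own no nodes keep an all-zero row.
    """
    out = torch.zeros(n_graphs, H.shape[1], dtype=H.dtype, device=H.device)
    # index_add_ accumulates H[i] into out[batch[i]] along dim 0
    return out.index_add_(0, batch, H)


# three graphs: graph 0 has two nodes, graph 1 has none, graph 2 has one
H = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 2])
print(sum_aggregate(H, batch, n_graphs=3))  # rows: [4, 6], [0, 0], [5, 6]
```

Because index_add_ accumulates repeated indices, a single call produces every per-graph sum at once without a Python loop over graphs.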

chemprop.nn.AggregationRegistry
class chemprop.nn.AttentiveAggregation(dim=0, *args, output_size, **kwargs)

Bases: Aggregation

An AttentiveAggregation aggregates the node-level representations of a batch of graphs into a batch of graph-level representations by weighting each node with a learned attention score, normalized over the nodes of its graph, before summing.

See also

MeanAggregation, SumAggregation, NormAggregation

Parameters:
  • dim (int)

  • output_size (int)

W
forward(H, batch)

Aggregate the node-level representations of a batch of graphs into their respective graph-level representations.

NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be a zero vector of length d in the final output.

Parameters:
  • H (Tensor) – a tensor of shape V x d containing the batched node-level representations of b graphs

  • batch (Tensor) – a tensor of shape V containing the index of the graph a given vertex corresponds to

Returns:

a tensor of shape b x d containing the graph-level representations

Return type:

Tensor
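The attention weighting can be sketched as a per-graph softmax over one learned score per node, followed by a weighted sum. This is a hypothetical reimplementation of the idea under that assumption, not chemprop's code: the scoring layer score stands in for the learned W attribute, and attentive_aggregate is an illustrative name.

```python
import torch
from torch import nn


def attentive_aggregate(H, batch, score: nn.Module, n_graphs: int):
    """Attention-weighted sum of node rows H (V x d) into n_graphs x d rows."""
    logits = score(H).squeeze(-1)  # one scalar score per node, shape (V,)
    # subtract each graph's maximum score before exp() to avoid overflow
    max_per_graph = torch.full((n_graphs,), float("-inf"), dtype=H.dtype, device=H.device)
    max_per_graph = max_per_graph.scatter_reduce(0, batch, logits, reduce="amax")
    exp = (logits - max_per_graph[batch]).exp()
    # normalize so that the weights of the nodes in each graph sum to 1
    Z = torch.zeros(n_graphs, dtype=H.dtype, device=H.device).index_add_(0, batch, exp)
    alpha = exp / Z[batch]
    out = torch.zeros(n_graphs, H.shape[1], dtype=H.dtype, device=H.device)
    return out.index_add_(0, batch, alpha.unsqueeze(-1) * H)


torch.manual_seed(0)
H = torch.randn(5, 4)
batch = torch.tensor([0, 0, 0, 1, 1])
score = nn.Linear(4, 1)
print(attentive_aggregate(H, batch, score, n_graphs=2).shape)  # torch.Size([2, 4])
```

Subtracting the per-graph maximum leaves the softmax weights unchanged; it only keeps exp() in a safe numerical range.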

class chemprop.nn.MeanAggregation(dim=0, *args, **kwargs)

Bases: Aggregation

Average the node-level representations of each graph to obtain its graph-level representation:

\[\mathbf h = \frac{1}{|V|} \sum_{v \in V} \mathbf h_v\]
Parameters:

dim (int)

forward(H, batch)

Aggregate the node-level representations of a batch of graphs into their respective graph-level representations.

NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be a zero vector of length d in the final output.

Parameters:
  • H (Tensor) – a tensor of shape V x d containing the batched node-level representations of b graphs

  • batch (Tensor) – a tensor of shape V containing the index of the graph a given vertex corresponds to

Returns:

a tensor of shape b x d containing the graph-level representations

Return type:

Tensor
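Mean aggregation can be sketched by dividing each graph's summed node representations by its node count. The helper mean_aggregate is illustrative, and clamping the count to at least 1 is an assumption made here so that a graph with 0 nodes yields the zero vector described in the NOTE rather than a division by zero.

```python
import torch


def mean_aggregate(H, batch, n_graphs: int):
    """Average node rows H (V x d) per graph; graphs without nodes give zero rows."""
    sums = torch.zeros(n_graphs, H.shape[1], dtype=H.dtype).index_add_(0, batch, H)
    # number of nodes per graph, with empty graphs counted as 1 to avoid 0 / 0
    counts = torch.bincount(batch, minlength=n_graphs).clamp(min=1)
    return sums / counts.unsqueeze(-1).to(H.dtype)


H = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 2])
print(mean_aggregate(H, batch, n_graphs=3))  # rows: [2, 3], [0, 0], [5, 6]
```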

class chemprop.nn.NormAggregation(dim=0, *args, norm=100.0, **kwargs)

Bases: SumAggregation

Sum the node-level representations of each graph and divide by a normalization constant \(c\), given by norm:

\[\mathbf h = \frac{1}{c} \sum_{v \in V} \mathbf h_v\]
Parameters:
  • dim (int)

  • norm (float)

norm = 100.0
forward(H, batch)

Aggregate the node-level representations of a batch of graphs into their respective graph-level representations.

NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be a zero vector of length d in the final output.

Parameters:
  • H (Tensor) – a tensor of shape V x d containing the batched node-level representations of b graphs

  • batch (Tensor) – a tensor of shape V containing the index of the graph a given vertex corresponds to

Returns:

a tensor of shape b x d containing the graph-level representations

Return type:

Tensor
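With \(c\) set to norm, the operation amounts to a per-graph sum followed by a division by a fixed constant. The sketch below is a plain-PyTorch illustration with a hypothetical helper name, norm_aggregate; unlike a mean, the divisor does not depend on the number of nodes, so the result still scales with graph size.

```python
import torch


def norm_aggregate(H, batch, n_graphs: int, norm: float = 100.0):
    """Sum node rows per graph, then divide by the fixed constant norm."""
    sums = torch.zeros(n_graphs, H.shape[1], dtype=H.dtype).index_add_(0, batch, H)
    return sums / norm


H = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
batch = torch.tensor([0, 0])
print(norm_aggregate(H, batch, n_graphs=1))  # rows: [0.4, 0.6]
```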

class chemprop.nn.SumAggregation(dim=0, *args, **kwargs)

Bases: Aggregation

Sum the node-level representations of each graph to obtain its graph-level representation:

\[\mathbf h = \sum_{v \in V} \mathbf h_v\]
Parameters:

dim (int)

forward(H, batch)

Aggregate the node-level representations of a batch of graphs into their respective graph-level representations.

NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be a zero vector of length d in the final output.

Parameters:
  • H (Tensor) – a tensor of shape V x d containing the batched node-level representations of b graphs

  • batch (Tensor) – a tensor of shape V containing the index of the graph a given vertex corresponds to

Returns:

a tensor of shape b x d containing the graph-level representations

Return type:

Tensor

class chemprop.nn.ConstrainerFFN(n_constraints=1, fp_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu')

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams, lightning.pytorch.core.mixins.HyperparametersMixin

A ConstrainerFFN adjusts atom or bond property predictions to satisfy molecular constraints. An MLP maps the learned atom or bond embeddings to weights that determine how much of the total required adjustment is added to each atom or bond prediction.

Parameters:
  • n_constraints (int)

  • fp_dim (int)

  • hidden_dim (int)

  • n_layers (int)

  • dropout (float)

  • activation (str)

ffn
forward(fp, preds, batch, constraints)

Performs a weighted adjustment to the predictions to satisfy the constraints, with the weights being determined from the learned atom or bond fingerprints via an MLP.

Parameters:
  • fp (Tensor) – a tensor of shape b x h containing the atom- or bond-level fingerprints, where b is the number of atoms or bonds in the batch and h is the length of each fingerprint.

  • preds (Tensor) – a tensor of shape b x t containing the atom- or bond-level predictions, where t is the number of predictions per atom or bond.

  • batch (Tensor) – a tensor of shape b containing the index of the molecule that each atom or bond belongs to

  • constraints (Tensor) – a tensor of shape m x t containing, for each molecule, the values that its atom- or bond-level predictions should sum to, where m is the number of molecules in the batch.

Returns:

a tensor of shape b x t containing the atom- or bond-level predictions adjusted to satisfy the molecule-level constraints

Return type:

Tensor
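The weighted adjustment can be sketched as follows, assuming (consistent with the description, but without reproducing chemprop's code) that each molecule's raw weights are normalized with a per-molecule softmax and that each atom or bond receives its weight's share of the molecule's shortfall, constraint minus current sum. The scores tensor stands in for the MLP output on the fingerprints, and constrain is a hypothetical helper name.

```python
import torch


def constrain(preds, scores, batch, constraints):
    """Shift per-atom (or per-bond) predictions so each molecule's sum meets its constraint.

    preds: (b, t) predictions; scores: (b, t) raw weights (an MLP output in practice);
    batch: (b,) molecule index of each row; constraints: (m, t) target sums per molecule.
    """
    m = constraints.shape[0]
    # per-molecule softmax over rows, separately for each of the t targets
    exp = scores.exp()
    Z = torch.zeros(m, scores.shape[1], dtype=scores.dtype).index_add_(0, batch, exp)
    weights = exp / Z[batch]
    # how far each molecule's current sum is from its constraint
    totals = torch.zeros_like(constraints).index_add_(0, batch, preds)
    shortfall = constraints - totals
    # each row receives its weighted share of its molecule's shortfall
    return preds + weights * shortfall[batch]


torch.manual_seed(0)
preds = torch.tensor([[0.1], [0.2], [0.3], [-0.5], [0.4]])
scores = torch.randn(5, 1)
batch = torch.tensor([0, 0, 0, 1, 1])
constraints = torch.tensor([[0.0], [1.0]])  # e.g. target totals of 0 and +1
adjusted = constrain(preds, scores, batch, constraints)
```

Because the weights within each molecule sum to 1, the full shortfall is distributed and the adjusted predictions sum exactly to the constraint, up to floating-point error.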

class chemprop.nn.AtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MessagePassingBase

An AtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

It implements the following operation:

\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} \left( h_u^{(t-1)} \mathbin\Vert e_{uv} \right) \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{uv}\) is the feature vector of the bond between atoms \(u\) and \(v\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations; and \(T\) is the message passing depth.
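A compact PyTorch sketch of the atom-based update, for a single small graph stored as a directed edge_index (row 0 holds the source atom u, row 1 the destination atom v). It runs depth - 1 update iterations, each using messages built from the previous iteration's hidden states, followed by the readout. This is an illustration, not chemprop's implementation: dropout, extra descriptors, and batching are omitted, and the helper name atom_mp and the layer sizes are hypothetical.

```python
import torch
from torch import nn


def atom_mp(x, e, edge_index, W_i, W_h, W_o, depth=3, tau=torch.relu):
    """Atom-based message passing on a single graph.

    x: (n_atoms, d_v) atom features
    e: (n_edges, d_e) features of the directed edges u -> v
    edge_index: (2, n_edges); row 0 holds the source u, row 1 the destination v
    """
    src, dst = edge_index
    n_atoms = x.shape[0]
    h0 = tau(W_i(x))  # initial atom states h_v^(0)
    h = h0
    for _ in range(1, depth):  # depth - 1 update iterations
        # message into v: sum over incoming edges u -> v of (h_u || e_uv)
        msg = torch.cat([h[src], e], dim=1)
        m = torch.zeros(n_atoms, msg.shape[1], dtype=msg.dtype).index_add_(0, dst, msg)
        h = tau(h0 + W_h(m))
    # readout: sum neighbor states, concatenate with the raw atom features, project
    m = torch.zeros_like(h).index_add_(0, dst, h[src])
    return tau(W_o(torch.cat([x, m], dim=1)))


torch.manual_seed(0)
d_v, d_e, d_h = 4, 3, 8
W_i = nn.Linear(d_v, d_h)
W_h = nn.Linear(d_h + d_e, d_h)
W_o = nn.Linear(d_v + d_h, d_h)
# a three-atom chain 0-1-2 with each bond stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x, e = torch.randn(3, d_v), torch.randn(4, d_e)
out = atom_mp(x, e, edge_index, W_i, W_h, W_o)
print(out.shape)  # torch.Size([3, 8])
```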

Parameters:
setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix W_d is None if no vertex descriptor dimension d_vd is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.BondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, V_d_transform=None, graph_transform=None)

Bases: chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MessagePassingBase

A BondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

It implements the following operation:

\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus \{w\}} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \to w\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations; and \(T\) is the message passing depth.
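The directed-bond update can be sketched in PyTorch for a single graph. The key step is the exclusion \(u \in \mathcal N(v)\setminus \{w\}\): the message for bond v -> w is the total over all bonds entering v minus the state of the reverse bond w -> v. This is an illustrative sketch, not chemprop's code; it assumes a precomputed rev_index mapping each directed bond to its reverse, builds the initial states from bond features alone as in the first equation, and uses the hypothetical helper name bond_mp.

```python
import torch
from torch import nn


def bond_mp(x, e, edge_index, rev_index, W_i, W_h, W_o, depth=3, tau=torch.relu):
    """Directed-bond message passing on a single graph.

    edge_index[:, k] = (v, w) for the directed bond k: v -> w
    rev_index[k] is the index of the reverse bond w -> v
    """
    src, dst = edge_index
    n_atoms = x.shape[0]
    h0 = tau(W_i(e))  # initial bond states h_vw^(0), from bond features only
    h = h0
    for _ in range(1, depth):  # depth - 1 update iterations
        # for every atom v, the total of h_uv over all bonds u -> v entering it
        incoming = torch.zeros(n_atoms, h.shape[1], dtype=h.dtype).index_add_(0, dst, h)
        # bond v -> w collects from u in N(v) except w: subtract the reverse bond w -> v
        m = incoming[src] - h[rev_index]
        h = tau(h0 + W_h(m))
    # readout: each atom sums the states of the bonds pointing into it
    m_v = torch.zeros(n_atoms, h.shape[1], dtype=h.dtype).index_add_(0, dst, h)
    return tau(W_o(torch.cat([x, m_v], dim=1)))


torch.manual_seed(0)
d_v, d_e, d_h = 4, 3, 8
W_i = nn.Linear(d_e, d_h)
W_h = nn.Linear(d_h, d_h)
W_o = nn.Linear(d_v + d_h, d_h)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
rev_index = torch.tensor([1, 0, 3, 2])  # bonds stored as adjacent (v -> w, w -> v) pairs
x, e = torch.randn(3, d_v), torch.randn(4, d_e)
out = bond_mp(x, e, edge_index, rev_index, W_i, W_h, W_o)
print(out.shape)  # torch.Size([3, 8])
```

Excluding the reverse bond keeps a message from flowing straight back to the atom it came from, which is the motivation for passing messages along directed bonds instead of atoms.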

Parameters:
setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_o, W_d – the input, hidden, output, and descriptor weight matrices, respectively, used in the message passing update functions. The descriptor weight matrix W_d is None if no vertex descriptor dimension d_vd is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]

class chemprop.nn.MABAtomMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)

Bases: chemprop.nn.message_passing.mixins._AtomMessagePassingMixin, _MABMessagePassingBase

A MABAtomMessagePassing encodes a batch of molecular graphs by passing messages along atoms.

It implements the following operation:

\[\begin{split}h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} \left( h_u^{(t-1)} \mathbin\Vert e_{uv} \right) \\ h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t)}\right) \\ m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf{W}_i\), \(\mathbf{W}_h\), and \(\mathbf{W}_o\) are learned weight matrices; \(e_{uv}\) is the feature vector of the bond between atoms \(u\) and \(v\); \(x_v\) is the feature vector of atom \(v\); \(h_v^{(t)}\) is the hidden representation of atom \(v\) at iteration \(t\); \(m_v^{(t)}\) is the message received by atom \(v\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations; and \(T\) is the message passing depth.

Parameters:
setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. W_vd is None if no vertex descriptor dimension d_vd is supplied, and W_ed is None if no edge descriptor dimension d_ed is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None, nn.Module, nn.Module | None]

class chemprop.nn.MABBondMessagePassing(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, bias=False, depth=3, dropout=0.0, activation=Activation.RELU, undirected=False, d_vd=None, d_ed=None, V_d_transform=None, E_d_transform=None, graph_transform=None, return_vertex_embeddings=True, return_edge_embeddings=True)

Bases: chemprop.nn.message_passing.mixins._BondMessagePassingMixin, _MABMessagePassingBase

A MABBondMessagePassing encodes a batch of molecular graphs by passing messages along directed bonds.

It implements the following operation:

\[\begin{split}h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus \{w\}} h_{uv}^{(t-1)} \\ h_{vw}^{(t)} &= \tau \left(h_{vw}^{(0)} + \mathbf W_h m_{vw}^{(t)} \right) \\ m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_{wv}^{(T-1)} \\ h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right),\end{split}\]

where \(\tau\) is the activation function; \(\mathbf W_i\), \(\mathbf W_h\), and \(\mathbf W_o\) are learned weight matrices; \(e_{vw}\) is the feature vector of the bond between atoms \(v\) and \(w\); \(x_v\) is the feature vector of atom \(v\); \(h_{vw}^{(t)}\) is the hidden representation of the bond \(v \rightarrow w\) at iteration \(t\); \(m_{vw}^{(t)}\) is the message received by the bond \(v \to w\) at iteration \(t\); \(t \in \{1, \dots, T-1\}\) indexes the message passing iterations; and \(T\) is the message passing depth.

Parameters:
setup(d_v=DEFAULT_ATOM_FDIM, d_e=DEFAULT_BOND_FDIM, d_h=DEFAULT_HIDDEN_DIM, d_vd=None, d_ed=None, bias=False)

Set up the weight matrices used in the message passing update functions.

Parameters:
  • d_v (int) – the vertex feature dimension

  • d_e (int) – the edge feature dimension

  • d_h (int, default=300) – the hidden dimension during message passing

  • d_vd (int | None, default=None) – the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout, if any

  • d_ed (int | None, default=None) – the dimension of additional edge descriptors that will be concatenated to the hidden features before readout, if any

  • bias (bool, default=False) – whether to add a learned bias to the matrices

Returns:

W_i, W_h, W_vo, W_vd, W_eo, W_ed – the input, hidden, vertex output, vertex descriptor, edge output, and edge descriptor weight matrices, respectively, used in the message passing update functions. W_vd is None if no vertex descriptor dimension d_vd is supplied, and W_ed is None if no edge descriptor dimension d_ed is supplied.

Return type:

tuple[nn.Module, nn.Module, nn.Module, nn.Module | None, nn.Module, nn.Module | None]

class chemprop.nn.MABMessagePassing(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MABMessagePassing module encodes a batch of molecular graphs using message passing to learn both vertex-level and edge-level hidden representations.

Parameters:
  • args (Any)

  • kwargs (Any)

output_dims: tuple[int | None, int | None]
abstractmethod forward(bmg, V_d=None, E_d=None)

Encode a batch of molecular graphs.

Parameters:
  • bmg (BatchMolGraph) – the batch of MolGraphs to encode

  • V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.

  • E_d (Tensor | None, default=None) – an optional tensor of shape E x d_ed containing additional descriptors for each directed edge in the batch. These will be concatenated to the learned edge descriptors and transformed before the readout phase. NOTE: There are two directed edges per graph connection. If the extra descriptors are for the connections, each row should be repeated twice in the tensor, once for each direction, potentially using E_d = np.repeat(E_d, repeats=2, axis=0).

Returns:

Two tensors, of shape V x d_h or V x (d_h + d_vd) and E x d_h or E x (d_h + d_ed) respectively, containing the hidden representation of each vertex and edge in the batch of graphs. The feature dimensions depend on whether additional atom/bond descriptors were provided. If either the vertex or the edge hidden representations are not needed, computing the corresponding tensor can be suppressed by setting return_vertex_embeddings or return_edge_embeddings to False when initializing the module.

Return type:

tuple[Tensor | None, Tensor | None]
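The np.repeat call suggested for E_d duplicates each row in place, giving the order [bond 0, bond 0, bond 1, bond 1, ...]. It therefore assumes that the two directed edges of each connection are stored next to each other in the batch; treat that ordering as an assumption to check against how your batch orders its edges. The descriptor values below are made up for illustration.

```python
import numpy as np

# one row of extra descriptors per connection: 2 bonds, 3 descriptors each
E_d_bonds = np.array([[1.0, 0.0, 0.5],
                      [0.0, 1.0, 0.2]])

# one row per directed edge: each connection's row is repeated twice, back to back
E_d = np.repeat(E_d_bonds, repeats=2, axis=0)
print(E_d.shape)  # (4, 3)
```

np.tile would instead produce [bond 0, bond 1, bond 0, bond 1], which does not match a pairwise edge layout. Since E_d is documented as a Tensor, the result can be converted with torch.from_numpy(E_d).float().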

class chemprop.nn.MessagePassing(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MessagePassing module encodes a batch of molecular graphs using message passing to learn vertex-level hidden representations.

Parameters:
  • args (Any)

  • kwargs (Any)

output_dim: int
abstractmethod forward(bmg, V_d=None)

Encode a batch of molecular graphs.

Parameters:
  • bmg (BatchMolGraph) – the batch of MolGraphs to encode

  • V_d (Tensor | None, default=None) – an optional tensor of shape V x d_vd containing additional descriptors for each vertex in the batch. These will be concatenated to the learned vertex descriptors and transformed before the readout phase.

Returns:

a tensor of shape V x d_h or V x (d_h + d_vd) containing the hidden representation of each vertex in the batch of graphs. The feature dimension depends on whether additional vertex descriptors were provided.

Return type:

Tensor

class chemprop.nn.MulticomponentMessagePassing(blocks, n_components, shared=False)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A MulticomponentMessagePassing performs message passing on each individual input in a multicomponent input and then concatenates the representations of the inputs to construct a global representation.

Parameters:
  • blocks (Sequence[MessagePassing]) – the individual message-passing blocks for each input

  • n_components (int) – the number of components in each input

  • shared (bool, default=False) – whether one block will be shared among all components in an input. If not, a separate block will be learned for each component.

hparams
n_components
shared = False
blocks
__len__()
Return type:

int

property output_dim: int
Return type:

int

forward(bmgs, V_ds)

Encode the multicomponent inputs.

Parameters:
Returns:

a list of tensors of shape V x d_i containing the respective encodings of the i-th component, where d_i is the output dimension of the i-th encoder

Return type:

list[Tensor]

class chemprop.nn.MAE(task_weights=1.0)

Bases: ChempropMetric

Compute the mean absolute error (MAE).

As a ChempropMetric, MAE builds on the torchmetrics Metric base class, which is inherited by all metrics in the Metrics API and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward(), and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Parameters:
  • kwargs

    additional keyword arguments, see Metric kwargs for more info.

    • compute_on_cpu:

      Whether the metric state should be stored on CPU during computations. Only works for list states.

    • dist_sync_on_step:

      Whether the metric state should be synchronized on forward(). Default is False.

    • process_group:

      The process group on which the synchronization is called. Default is the world.

    • dist_sync_fn:

      Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

    • distributed_available_fn:

      Function that checks whether the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

    • sync_on_compute:

      Whether the metric state should be synchronized when compute is called. Default is True.

    • compute_with_cache:

      Whether results from compute should be cached. Default is True.

  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.MSE(task_weights=1.0)

Bases: ChempropMetric

Compute the mean squared error (MSE). The inherited torchmetrics Metric behavior is the same as described for MAE.

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.RMSE(task_weights=1.0)

Bases: MSE

Compute the root-mean-squared error (RMSE), i.e., the square root of the MSE. The inherited torchmetrics Metric behavior is the same as described for MAE.

Parameters:
  • task_weights (numpy.typing.ArrayLike)

compute()
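For reference, MAE, MSE, and RMSE reduce to the following for a single unweighted task with no missing values; RMSE is the square root of MSE, consistent with RMSE subclassing MSE. This is a plain-PyTorch sketch with made-up numbers, and it does not model how chemprop applies task_weights or masks missing targets.

```python
import torch

preds = torch.tensor([1.0, 2.0, 3.0])
targets = torch.tensor([2.0, 2.0, 5.0])

err = preds - targets        # [-1, 0, -2]
mae = err.abs().mean()       # (1 + 0 + 2) / 3 = 1.0
mse = err.pow(2).mean()      # (1 + 0 + 4) / 3 = 5/3
rmse = mse.sqrt()            # sqrt(5/3), about 1.291
```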
class chemprop.nn.SID(task_weights=1.0, threshold=None, **kwargs)

Bases: ChempropMetric

Compute the spectral information divergence (SID) between predicted and target spectra. The inherited torchmetrics Metric behavior is the same as described for MAE.

Parameters:
  • kwargs

    additional keyword arguments, see the Metric kwargs listed for MAE for more info.

  • task_weights (numpy.typing.ArrayLike)

  • threshold (float | None)

threshold = None
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str
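SID compares two spectra after normalizing each to unit total intensity: \(\mathrm{SID}(p, q) = \sum_i p_i \ln(p_i / q_i) + q_i \ln(q_i / p_i)\). The sketch below implements that formula in plain PyTorch. The helper name sid and the way threshold is applied (raising predicted intensities below it to the threshold before normalizing, to avoid taking the log of zero) are assumptions, not chemprop's code.

```python
import torch


def sid(pred, target, threshold=None):
    """Spectral information divergence between two positive 1-D spectra."""
    if threshold is not None:
        # assumption: predicted intensities below the threshold are raised to it
        pred = pred.clamp(min=threshold)
    p = pred / pred.sum()      # normalize each spectrum to unit total intensity
    q = target / target.sum()
    return (p * torch.log(p / q) + q * torch.log(q / p)).sum()


a = torch.tensor([0.2, 0.5, 0.3])
b = torch.tensor([0.1, 0.6, 0.3])
print(sid(a, a), sid(a, b))  # zero for identical spectra, positive otherwise
```

The divergence is symmetric in its two arguments, which is one reason it is used to compare spectra rather than a one-sided KL divergence.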

class chemprop.nn.BCELoss(task_weights=1.0)

Bases: ChempropMetric

Compute the binary cross-entropy loss. The inherited torchmetrics Metric behavior is the same as described for MAE.

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.BinaryAccuracy(task_weights=1.0, **kwargs)

Bases: ClassificationMixin, torchmetrics.classification.BinaryAccuracy

Compute accuracy for binary tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating-point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. The predictions are then converted to an int tensor by thresholding at threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • acc (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of one scalar value per sample.

If multidim_average is set to samplewise, at least one additional dimension ... is expected to be present; the reduction is then applied over it instead of over the sample dimension N.

Parameters:
  • threshold – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • task_weights (numpy.typing.ArrayLike)

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryAccuracy(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.1667])
class chemprop.nn.BinaryAUPRC(task_weights=1.0, **kwargs)

Bases: ClassificationMixin, torchmetrics.classification.BinaryPrecisionRecallCurve

Compute the area under the precision-recall curve (AUPRC) for binary tasks.

The underlying curve, computed by BinaryPrecisionRecallCurve, consists of multiple pairs of precision and recall values evaluated at different thresholds, so that the tradeoff between the two values can be seen.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise.

  • target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels, which therefore contains only {0,1} values (unless ignore_index is specified). The value 1 always encodes the positive class.

Tip

Additional dimensions ... are flattened into the batch dimension.

The parent class BinaryPrecisionRecallCurve returns the following curve outputs from forward and compute; BinaryAUPRC overrides compute() to return a single torch.Tensor that summarizes the curve as an area:

  • precision (Tensor): a 1d tensor of size (n_thresholds+1,) with precision values. When thresholds=None, n_thresholds is the number of distinct values in preds.

  • recall (Tensor): a 1d tensor of size (n_thresholds+1,) with recall values.

  • thresholds (Tensor): a 1d tensor of size (n_thresholds,) with increasing threshold values.

Note

The implementation supports both a non-binned version, which is accurate, and a binned version, which is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\); setting the thresholds argument to an integer, a list, or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • thresholds

    Can be one of:

    • If set to None, uses a non-binned approach where thresholds are calculated dynamically from all the data. This is the most accurate but also the most memory-consuming approach.

    • If set to an int (larger than 1), uses that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, uses the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, uses the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • normalization – Specifies a normalization method that is used for batch-wise update regarding negative logits. Set to None if negative logits are desired in evaluation.

  • kwargs – Additional keyword arguments, see Metric kwargs for more info.

  • task_weights (numpy.typing.ArrayLike)

Example

>>> import torch
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.6667, 0.0000,    nan, 1.0000]),
 tensor([1., 1., 1., 0., 0., 0.]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
compute()

Compute metric.

Return type:

torch.Tensor
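The thresholding rules listed above can be sketched in plain Python. The sketch below builds the curve for the non-binned (thresholds=None) and integer-binned cases, counting a sample as predicted positive when pred >= threshold, and then integrates precision against recall with the trapezoidal rule. The function names are hypothetical, and the trapezoidal rule is an assumption about how BinaryAUPRC.compute() reduces the curve. On the example inputs it reproduces the documented curve values for both thresholds=None and thresholds=5.

```python
import math


def precision_recall_curve(preds, target, thresholds=None):
    """Binary precision-recall curve for flat lists.

    thresholds=None -> non-binned: every distinct prediction value is a threshold.
    thresholds=int  -> binned: that many thresholds spaced linearly in [0, 1].
    A final (precision=1, recall=0) point is appended, as in the example above.
    """
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    if thresholds is None:
        ts = sorted(set(preds))
    else:
        n = int(thresholds)
        if n < 2:
            raise ValueError("an integer thresholds argument must be larger than 1")
        ts = [i / (n - 1) for i in range(n)]
    n_pos = sum(target)
    precision, recall = [], []
    for t in ts:
        tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= t and y == 0)
        # no predicted positives: precision is 0/0 and is reported as NaN here
        precision.append(tp / (tp + fp) if tp + fp > 0 else math.nan)
        # assumption: recall is reported as 0.0 when there are no positive targets
        recall.append(tp / n_pos if n_pos > 0 else 0.0)
    precision.append(1.0)
    recall.append(0.0)
    return precision, recall, ts


def trapezoid_auc(x, y):
    """Area under y(x) by the trapezoidal rule; the absolute value makes the
    result independent of whether x increases or decreases."""
    area = 0.0
    for i in range(len(x) - 1):
        area += (x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2
    return abs(area)


preds = [0.0, 0.5, 0.7, 0.8]
target = [0, 1, 1, 0]
precision, recall, thresholds = precision_recall_curve(preds, target)
print([round(v, 4) for v in precision])            # [0.5, 0.6667, 0.5, 0.0, 1.0]
print(recall)                                      # [1.0, 1.0, 0.5, 0.0, 0.0]
print(round(trapezoid_auc(recall, precision), 4))  # 0.4167
```

The binned variant only evaluates a fixed set of thresholds, which is why its memory use does not grow with the number of samples. A NaN precision value (as at threshold 1.0 in the binned example) would propagate through the trapezoidal sum, so such points need handling before integration.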

class chemprop.nn.BinaryAUROC(task_weights=1.0, **kwargs)

Bases: ClassificationMixin, torchmetrics.classification.BinaryAUROC

Compute the Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.

The AUROC score summarizes the ROC curve in a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise.

  • target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels, which therefore contains only {0,1} values (unless ignore_index is specified). The value 1 always encodes the positive class.

As output to forward and compute the metric returns the following output:

  • b_auroc (Tensor): A single scalar with the auroc score.

Additional dimensions ... are flattened into the batch dimension.

The implementation supports both a non-binned version, which is accurate, and a binned version, which is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\); setting the thresholds argument to an integer, a list, or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • max_fpr – If not None, calculates standardized partial AUC over the range [0, max_fpr].

  • thresholds

    Can be one of:

    • If set to None, uses a non-binned approach where thresholds are calculated dynamically from all the data. This is the most accurate but also the most memory-consuming approach.

    • If set to an int (larger than 1), uses that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, uses the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, uses the indicated thresholds in the tensor as bins for the calculation.

  • validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Metric kwargs for more info.

  • task_weights (numpy.typing.ArrayLike)

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAUROC(thresholds=None)
>>> metric(preds, target)
tensor(0.5000)
>>> b_auroc = BinaryAUROC(thresholds=5)
>>> b_auroc(preds, target)
tensor(0.5000)
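For the non-binned case (thresholds=None), AUROC equals the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. The pure-Python sketch below uses that pairwise form rather than building the ROC curve explicitly; binary_auroc is a hypothetical name, and neither max_fpr nor binning is covered.

```python
def binary_auroc(preds, target):
    """Probability that a random positive is scored above a random negative.

    Ties count as 1/2. This pairwise (Mann-Whitney) form equals the area under
    the non-binned ROC curve, but costs O(n_pos * n_neg) time.
    """
    pos = [p for p, y in zip(preds, target) if y == 1]
    neg = [p for p, y in zip(preds, target) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative sample")
    score = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                score += 1.0
            elif sp == sn:
                score += 0.5
    return score / (len(pos) * len(neg))


print(binary_auroc([0.0, 0.5, 0.7, 0.8], [0, 1, 1, 0]))  # 0.5, as in the example above
print(binary_auroc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

The quadratic pair loop keeps the sketch obviously correct; a sort-based rank computation gives the same value in O(n log n).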
class chemprop.nn.BinaryF1Score(task_weights=1.0, **kwargs)

Bases: ClassificationMixin, torchmetrics.classification.BinaryF1Score

Compute the F1 score for binary tasks.

\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively. If this condition is not met, the value of zero_division (0 or 1, default 0) is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating-point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. The predictions are then converted to an int tensor by thresholding at threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • bf1s (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global, the metric returns a scalar value.

    • If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of one scalar value per sample.

If multidim_average is set to samplewise, at least one additional dimension ... is expected to be present; the reduction is then applied over it instead of over the sample dimension N.

Parameters:
  • threshold – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division – Should be 0 or 1. The value returned when \(\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0\).

  • task_weights (numpy.typing.ArrayLike)

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryF1Score(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.0000])
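Whenever precision and recall are both defined, the F1 formula above simplifies to 2·TP / (2·TP + FP + FN), which is undefined only when TP = FP = FN = 0, the case covered by zero_division. A pure-Python sketch of that count form follows. binary_f1 is a hypothetical name, the strict '>' threshold comparison is an assumption, and the samplewise reduction is emulated by scoring each flattened sample separately.

```python
import math


def binary_f1(preds, target, threshold=0.5, zero_division=0.0):
    """F1 score in count form, 2*TP / (2*TP + FP + FN), for flat lists."""
    values = [float(p) for p in preds]
    # out-of-range floats are treated as logits, as described above
    if any(isinstance(p, float) for p in preds) and any(v < 0.0 or v > 1.0 for v in values):
        values = [1.0 / (1.0 + math.exp(-v)) for v in values]
    labels = [1 if v > threshold else 0 for v in values]  # assumed strict '>'
    tp = sum(1 for p, y in zip(labels, target) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(labels, target) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(labels, target) if p == 0 and y == 1)
    denom = 2 * tp + fp + fn
    if denom == 0:  # TP + FP = 0 and TP + FN = 0
        return float(zero_division)
    return 2 * tp / denom


# "samplewise": flatten each sample's extra dimensions and score samples separately
preds = [[0.59, 0.91, 0.91, 0.99, 0.63, 0.04], [0.38, 0.04, 0.86, 0.780, 0.45, 0.37]]
target = [[0, 1, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0]]
print([binary_f1(p, t) for p, t in zip(preds, target)])  # [0.5, 0.0]
```

The per-sample values match the multidim example above, and the int and float examples both give 2/3 (about 0.6667).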
class chemprop.nn.BinaryMCCLoss(task_weights=1.0)

Bases: ChempropMetric

Calculate a soft Matthews correlation coefficient (MCC) loss for binary classification.

Parameters:
  • task_weights (numpy.typing.ArrayLike)

update(preds, targets, mask=None, weights=None, *args)

Calculate the mean loss function value given predicted and target values

Parameters:
  • preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.

  • targets (Tensor) – a float tensor of shape b x t containing the target values

  • mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation

  • weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight

  • lt_mask (Tensor)

  • gt_mask (Tensor)

compute()
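A soft MCC can be sketched by replacing the hard confusion-matrix counts with expected counts computed from predicted probabilities, which keeps the quantity differentiable. This is an assumption about the approach, not the BinaryMCCLoss source: the loss is assumed to be 1 - MCC, masking and per-sample weights are omitted, and the function names are hypothetical.

```python
import math


def soft_binary_mcc(probs, targets):
    """Soft Matthews correlation coefficient for binary labels.

    TP, FP, FN and TN are expected counts computed from probabilities p in [0, 1]
    rather than from hard 0/1 predictions, so the value varies smoothly with p.
    """
    tp = sum(p * y for p, y in zip(probs, targets))
    fp = sum(p * (1 - y) for p, y in zip(probs, targets))
    fn = sum((1 - p) * y for p, y in zip(probs, targets))
    tn = sum((1 - p) * (1 - y) for p, y in zip(probs, targets))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0.0:
        return 0.0  # degenerate case (e.g. a single class present); scored as 0 by assumption
    return (tp * tn - fp * fn) / denom


def soft_binary_mcc_loss(probs, targets):
    """Assumed loss form: 1 - MCC, which is 0 for a perfect classifier."""
    return 1.0 - soft_binary_mcc(probs, targets)


print(soft_binary_mcc([1, 0, 1, 0], [1, 0, 1, 0]))       # 1.0: perfect agreement
print(soft_binary_mcc([0.9, 0.1], [1, 0]))               # about 0.8: confident but soft
print(soft_binary_mcc_loss([1, 0, 1, 0], [1, 0, 0, 1]))  # 1.0: MCC of 0
```

With hard 0/1 inputs the expected counts reduce to ordinary counts, so the sketch agrees with the usual MCC in that limit.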
class chemprop.nn.BinaryMCCMetric(task_weights=1.0)

Bases: BinaryMCCLoss

Soft Matthews correlation coefficient (MCC) metric for binary classification. Unlike BinaryMCCLoss, higher values are better (higher_is_better = True).

Parameters:
  • task_weights (numpy.typing.ArrayLike)

higher_is_better = True
compute()
class chemprop.nn.BoundedMAE(task_weights=1.0)

Bases: BoundedMixin, MAE

Mean absolute error (MAE) variant that, through BoundedMixin, supports inequality targets marked by the lt_mask and gt_mask arguments of update().

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.BoundedMixin
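BoundedMixin carries no description here. One plausible reading, given the lt_mask and gt_mask arguments of ChempropMetric.update(), is that a prediction on the acceptable side of a less-than or greater-than target should not be penalized. The sketch below illustrates that idea for a bounded mean squared error; it is an assumption, not the BoundedMixin source, and the helper names are hypothetical.

```python
def bound_predictions(preds, targets, lt_mask, gt_mask):
    """Hypothetical helper: move predictions that already satisfy an inequality
    target onto that target, so they contribute zero error."""
    bounded = []
    for p, t, lt, gt in zip(preds, targets, lt_mask, gt_mask):
        if lt and p < t:    # target recorded as "< t": any prediction below t is acceptable
            p = t
        elif gt and p > t:  # target recorded as "> t": any prediction above t is acceptable
            p = t
        bounded.append(p)
    return bounded


def bounded_mse(preds, targets, lt_mask, gt_mask):
    """Mean squared error computed on the bounded predictions."""
    bounded = bound_predictions(preds, targets, lt_mask, gt_mask)
    return sum((p - t) ** 2 for p, t in zip(bounded, targets)) / len(targets)


preds = [1.0, 5.0, 3.5, 2.0]
targets = [2.0, 4.0, 3.0, 2.0]
lt_mask = [True, False, True, False]   # first and third targets are "<" values
gt_mask = [False, True, False, False]  # second target is a ">" value
print(bound_predictions(preds, targets, lt_mask, gt_mask))  # [2.0, 4.0, 3.5, 2.0]
print(bounded_mse(preds, targets, lt_mask, gt_mask))        # 0.0625
```

Only the third prediction (3.5 against a "< 3.0" target) violates its bound, so it is the only one that contributes error.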
class chemprop.nn.BoundedMSE(task_weights=1.0)

Bases: BoundedMixin, MSE

Mean squared error (MSE) variant that, through BoundedMixin, supports inequality targets marked by the lt_mask and gt_mask arguments of update().

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.BoundedRMSE(task_weights=1.0)

Bases: BoundedMixin, RMSE

Root-mean-square error (RMSE) variant that, through BoundedMixin, supports inequality targets marked by the lt_mask and gt_mask arguments of update().

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.ChempropMetric(task_weights=1.0)

Bases: torchmetrics.Metric

Base class for Chemprop's loss functions and custom metrics (e.g., BCELoss, CrossEntropyLoss, and BinaryMCCMetric). It extends torchmetrics.Metric, the base class for all metrics in the TorchMetrics API, which implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class, add_state(), forward() and reset(), should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Parameters:
  • kwargs

    additional keyword arguments, see Metric kwargs for more info.

    • compute_on_cpu:

      If metric state should be stored on CPU during computations. Only works for list states.

    • dist_sync_on_step:

      If metric state should synchronize on forward(). Default is False.

    • process_group:

      The process group on which the synchronization is called. Default is the world.

    • dist_sync_fn:

      Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

    • distributed_available_fn:

      Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

    • sync_on_compute:

      If metric state should synchronize when compute is called. Default is True.

    • compute_with_cache:

      If results from compute should be cached. Default is True.

  • task_weights (numpy.typing.ArrayLike)

is_differentiable = True
higher_is_better = False
full_state_update = False
update(preds, targets, mask=None, weights=None, lt_mask=None, gt_mask=None)

Calculate the mean loss function value given predicted and target values

Parameters:
  • preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.

  • targets (Tensor) – a float tensor of shape b x t containing the target values

  • mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation

  • weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight

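  • lt_mask (Tensor) – a boolean tensor of shape b x t marking targets that are less-than (upper-bound) values; used by the bounded losses and metrics

  • gt_mask (Tensor) – a boolean tensor of shape b x t marking targets that are greater-than (lower-bound) values; used by the bounded losses and metrics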

Return type:

None
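The update()/compute() contract above (a mask selecting which entries contribute, per-sample weights, and a mean taken at compute time) can be sketched without tensors. This is a minimal sketch under stated assumptions: the class name MaskedMeanLoss is hypothetical, squared error stands in for the per-element loss a subclass would define, and the mean is assumed to be taken over the number of unmasked entries. Chemprop's actual reduction may differ.

```python
class MaskedMeanLoss:
    """Hypothetical stand-in for the update()/compute() pattern of a loss metric."""

    def __init__(self, task_weights=None):
        self.task_weights = task_weights  # one weight per task, or None for all ones
        self.reset()

    def reset(self):
        self.total_loss = 0.0  # running sum of weighted per-element losses
        self.num_samples = 0   # running count of unmasked (included) entries

    def update(self, preds, targets, mask=None, weights=None):
        """preds, targets, mask: b x t nested lists; weights: one value per sample."""
        b, t = len(targets), len(targets[0])
        mask = mask if mask is not None else [[True] * t for _ in range(b)]
        weights = weights if weights is not None else [1.0] * b
        task_w = self.task_weights if self.task_weights is not None else [1.0] * t
        for i in range(b):
            for j in range(t):
                if mask[i][j]:
                    err = (preds[i][j] - targets[i][j]) ** 2  # squared error as the example loss
                    self.total_loss += err * weights[i] * task_w[j]
                    self.num_samples += 1

    def compute(self):
        """Mean loss over every entry accumulated since the last reset()."""
        return self.total_loss / self.num_samples if self.num_samples else float("nan")


metric = MaskedMeanLoss()
metric.update([[1.0, 2.0], [3.0, 0.0]], [[1.0, 0.0], [1.0, 5.0]], mask=[[True, True], [True, False]])
print(metric.compute())  # (0 + 4 + 4) / 3; the masked entry (error 25) is ignored
```

Keeping running sums rather than per-batch means lets several update() calls (for example, one per mini-batch) combine into a single correct mean, which is also what makes the states easy to sum across processes.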

compute()
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str

class chemprop.nn.ClassificationMixin(task_weights=1.0, **kwargs)
Parameters:

task_weights (numpy.typing.ArrayLike)

update(preds, targets, mask, *args, **kwargs)
Parameters:
  • preds (torch.Tensor)

  • targets (torch.Tensor)

  • mask (torch.Tensor)

class chemprop.nn.CrossEntropyLoss(task_weights=1.0)

Bases: ChempropMetric

Cross-entropy loss for multiclass classification tasks.

Parameters:
  • task_weights (numpy.typing.ArrayLike)

class chemprop.nn.DirichletLoss(task_weights=1.0, v_kl=0.2)

Bases: ChempropMetric

Uses the loss function from [sensoy2018], based on the implementation at [sensoyGithub].

References

[sensoy2018]

Sensoy, M.; Kaplan, L.; Kandemir, M. “Evidential deep learning to quantify classification uncertainty.” NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768

Parameters:
  • task_weights (numpy.typing.ArrayLike)

  • v_kl (float)

v_kl = 0.2
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str

class chemprop.nn.EvidentialLoss(task_weights=1.0, v_kl=0.2, eps=1e-08)

Bases: ChempropMetric

Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]. See also [soleimany2021].

References

[amini2020]

Amini, A.; Schwarting, W.; Soleimany, A.; Rus, D. “Deep Evidential Regression.” Advances in Neural Information Processing Systems, 2020, 33. https://proceedings.neurips.cc/paper_files/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf

[soleimany2021]

Soleimany, A.P.; Amini, A.; Goldman, S.; Rus, D.; Bhatia, S.N.; Coley, C.W. “Evidential Deep Learning for Guided Molecular Property Prediction and Discovery.” ACS Cent. Sci. 2021, 7, 8, 1356-1367. https://doi.org/10.1021/acscentsci.1c00546

Parameters:
  • task_weights (numpy.typing.ArrayLike)

  • v_kl (float)

  • eps (float)

v_kl = 0.2
eps = 1e-08
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str
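Eqs. 8 and 9 of [amini2020] can be written out for a single scalar target as below. This is a sketch, not Chemprop's implementation: it works on Python floats, the function names are hypothetical, and it assumes the total loss combines the two terms as NLL + v_kl * (regularizer - eps), with v_kl and eps defaulting to the values in the class signature.

```python
import math


def evidential_nll(y, gamma, nu, alpha, beta):
    """Eq. 8 of [amini2020]: NLL of y under a Normal-Inverse-Gamma evidential
    distribution with mean gamma and nu > 0, alpha > 1, beta > 0."""
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(nu * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def evidential_regularizer(y, gamma, nu, alpha):
    """Eq. 9 of [amini2020]: absolute error scaled by the total evidence 2*nu + alpha."""
    return abs(y - gamma) * (2.0 * nu + alpha)


def evidential_loss(y, gamma, nu, alpha, beta, v_kl=0.2, eps=1e-8):
    """Total loss, ASSUMING the form NLL + v_kl * (regularizer - eps)."""
    reg = evidential_regularizer(y, gamma, nu, alpha)
    return evidential_nll(y, gamma, nu, alpha, beta) + v_kl * (reg - eps)


print(evidential_loss(y=1.0, gamma=1.0, nu=1.0, alpha=2.0, beta=1.0))
print(evidential_loss(y=1.0, gamma=3.0, nu=1.0, alpha=2.0, beta=1.0))  # larger: worse mean
```

The NLL term is the negative log-density of the Student-t marginal of the Normal-Inverse-Gamma distribution, and the regularizer penalizes large evidence (2*nu + alpha) on samples with large errors.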

chemprop.nn.LossFunctionRegistry
chemprop.nn.MetricRegistry
class chemprop.nn.MulticlassMCCLoss(task_weights=1.0)

Bases: ChempropMetric

Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass classification, based on the implementation of [mccSklearn].

Parameters:

task_weights (numpy.typing.ArrayLike)

update(preds, targets, mask=None, weights=None, *args)

Calculate the mean loss function value given predicted and target values

Parameters:
  • preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.

  • targets (Tensor) – a float tensor of shape b x t containing the target values

  • mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation

  • weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight

  • lt_mask (Tensor)

  • gt_mask (Tensor)

compute()
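The [mccSklearn] citation presumably refers to scikit-learn's matthews_corrcoef, which computes multiclass MCC from covariances derived from the confusion matrix. Assuming the soft variant feeds predicted class probabilities into that same formula and uses 1 - MCC as the loss (both assumptions; the Chemprop source is not reproduced here), the computation can be sketched as follows. The function names are hypothetical, and masking and weights are omitted.

```python
import math


def soft_multiclass_mcc(probs, targets):
    """Soft multiclass MCC.

    probs:   per-sample lists of class probabilities (each row sums to 1)
    targets: integer class labels
    The soft confusion matrix C[true][pred] accumulates probabilities instead of
    hard counts; the covariance formula follows scikit-learn's matthews_corrcoef.
    """
    k = len(probs[0])
    C = [[0.0] * k for _ in range(k)]
    for p, y in zip(probs, targets):
        for j in range(k):
            C[y][j] += p[j]
    t_sum = [sum(row) for row in C]                             # true-class totals
    p_sum = [sum(C[i][j] for i in range(k)) for j in range(k)]  # predicted-class totals
    n_correct = sum(C[j][j] for j in range(k))
    n_samples = sum(t_sum)
    cov_ytyp = n_correct * n_samples - sum(a * b for a, b in zip(t_sum, p_sum))
    cov_ypyp = n_samples ** 2 - sum(b * b for b in p_sum)
    cov_ytyt = n_samples ** 2 - sum(a * a for a in t_sum)
    if cov_ypyp * cov_ytyt == 0:
        return 0.0
    return cov_ytyp / math.sqrt(cov_ytyt * cov_ypyp)


def soft_multiclass_mcc_loss(probs, targets):
    return 1.0 - soft_multiclass_mcc(probs, targets)  # assumed loss form


one_hot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(soft_multiclass_mcc(one_hot, [0, 1, 2]))  # 1.0
print(soft_multiclass_mcc(one_hot, [1, 2, 0]))  # -0.5: every prediction is wrong
```

For two classes this reduces to the binary soft MCC: probabilities [[0.1, 0.9], [0.9, 0.1]] with labels [1, 0] give 0.8.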
class chemprop.nn.MulticlassMCCMetric(task_weights=1.0)

Bases: MulticlassMCCLoss

Calculate a soft Matthews correlation coefficient ([mccWiki]) metric for multiclass classification, based on the implementation of [mccSklearn]. Unlike MulticlassMCCLoss, higher values are better (higher_is_better = True).

Parameters:

task_weights (numpy.typing.ArrayLike)

higher_is_better = True
compute()
class chemprop.nn.MVELoss(task_weights=1.0)

Bases: ChempropMetric

Calculate the loss using Eq. 9 from [nix1994].

References

[nix1994]

Nix, D. A.; Weigend, A. S. “Estimating the mean and variance of the target probability distribution.” Proceedings of 1994 IEEE International Conference on Neural Networks, 1994. https://doi.org/10.1109/icnn.1994.374138

Parameters:

task_weights (numpy.typing.ArrayLike)
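A minimal sketch of this objective, assuming it is the standard Gaussian negative log-likelihood \(\frac{1}{2}\log(2\pi\sigma^2) + \frac{(y-\mu)^2}{2\sigma^2}\) and that each prediction is a (mean, variance) pair (MveFFN predicts n_targets = 2 values per task). mve_loss and mean_mve_loss are hypothetical helpers; chemprop's implementation may handle constants, masking, and reduction differently.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def mve_loss(mean, var, target):
    """Gaussian negative log-likelihood for a single (mean, variance) prediction."""
    if var <= 0:
        raise ValueError("variance must be positive")
    # squared-error term, down-weighted where the model predicts a high variance
    sq_term = (target - mean) ** 2 / (2 * var)
    # log-variance term, which penalizes predicting large variances everywhere
    log_term = 0.5 * math.log(2 * math.pi * var)
    return sq_term + log_term


def mean_mve_loss(preds, targets):
    # preds: list of (mean, var) pairs; targets: list of floats
    losses = [mve_loss(m, v, y) for (m, v), y in zip(preds, targets)]
    return sum(losses) / len(losses)
```

The log-variance term stops the model from shrinking the loss simply by predicting a huge variance, while the squared-error term lets it express lower confidence on hard samples.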

class chemprop.nn.QuantileLoss(task_weights=1.0, alpha=0.1)

Bases: ChempropMetric

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Parameters:
  • kwargs

    Additional keyword arguments; see Metric kwargs for more info.

    • compute_on_cpu:

      Whether the metric state should be stored on CPU during computations. Only works for list states.

    • dist_sync_on_step:

      Whether the metric state should synchronize on forward(). Default is False.

    • process_group:

      The process group on which the synchronization is called. Default is the world.

    • dist_sync_fn:

      Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

    • distributed_available_fn:

      Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

    • sync_on_compute:

      Whether the metric state should synchronize when compute is called. Default is True.

    • compute_with_cache:

      Whether results from compute should be cached. Default is True.

  • task_weights (numpy.typing.ArrayLike)

  • alpha (float)

alpha = 0.1
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str
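QuantileLoss takes an alpha argument (default 0.1). One common reading, assumed here rather than stated on this page, is that the model predicts the alpha/2 and 1 - alpha/2 quantiles, giving a (1 - alpha) prediction interval, and is trained with the pinball (quantile) loss. A minimal sketch of that loss; pinball_loss and interval_loss are hypothetical helpers:

```python
# Hypothetical helpers for illustration; not chemprop's implementation.
def pinball_loss(target, pred, tau):
    """Pinball loss for a single prediction of the tau-quantile, tau in (0, 1)."""
    diff = target - pred
    # under-prediction (diff > 0) costs tau per unit; over-prediction costs (1 - tau)
    return max(tau * diff, (tau - 1) * diff)


def interval_loss(target, lower, upper, alpha=0.1):
    # assumed: lower/upper estimate the alpha/2 and 1 - alpha/2 quantiles
    return pinball_loss(target, lower, alpha / 2) + pinball_loss(target, upper, 1 - alpha / 2)
```

The asymmetric weights are what push each output toward its quantile: with tau = 0.95, under-predicting costs 19 times as much as over-predicting, so the upper bound settles near the 95th percentile.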

class chemprop.nn.R2Score(task_weights=1.0, **kwargs)

Bases: torchmetrics.R2Score

Compute the \(R^2\) score, also known as the coefficient of determination:

\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]

where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the residual sum of squares and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. The metric can also compute the adjusted \(R^2\) score, given by

\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]

where \(n\) is the number of samples and the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument. The score is only properly defined when \(SS_{tot}\neq 0\), a condition that can fail for near-constant targets; in that case a score of 0 is returned. By definition, the score is bounded between \(-\infty\) and 1.0, with 1.0 indicating perfect prediction, 0 indicating constant prediction, and negative values indicating worse than constant prediction.
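The formulas above translate directly into a short pure-Python sketch. r2_score and adjusted_r2 are hypothetical helpers, not chemprop or TorchMetrics functions; the exact-zero check on \(SS_{tot}\) is a simplification, and a library implementation may use a tolerance for near-constant targets.

```python
# Hypothetical helpers for illustration; not chemprop or TorchMetrics code.
def r2_score(preds, targets):
    """Single-output R^2 = 1 - SS_res / SS_tot."""
    mean_y = sum(targets) / len(targets)
    ss_res = sum((y - f) ** 2 for f, y in zip(preds, targets))
    ss_tot = sum((y - mean_y) ** 2 for y in targets)
    # mirror the documented convention: return 0 when SS_tot vanishes
    if ss_tot == 0:
        return 0.0
    return 1 - ss_res / ss_tot


def adjusted_r2(r2, n, k):
    """Adjusted R^2 for n samples and k independent regressors (requires n - k - 1 > 0)."""
    return 1 - (1 - r2) * (n - 1) / (n - k - 1)
```

For the single-output example further down (preds [2.5, 0.0, 2, 8], targets [3, -0.5, 2, 7]), r2_score returns about 0.9486, matching that example, and adjusted_r2 with n = 4 and k = 1 gives about 0.9229.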

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)

  • target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)

As output of forward and compute the metric returns the following output:

  • r2score (Tensor): A tensor with the r2 score(s)

In the multioutput case, the scores are by default uniformly averaged across outputs. See the multioutput argument to change this behavior.

Parameters:
  • num_outputs – Number of outputs in multioutput setting

  • adjusted – number of independent regressors for calculating the adjusted r2 score.

  • multioutput

    Defines aggregation in the case of multiple output scores. Can be one of the following strings:

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

  • kwargs – Additional keyword arguments, see Metric kwargs for more info.

  • task_weights (numpy.typing.ArrayLike)

Warning

Argument num_outputs in R2Score has been deprecated because it is no longer necessary and will be removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape of the input tensors.

Raises:
  • ValueError – If adjusted parameter is not an integer larger or equal to 0.

  • ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".

Example (single output):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
Example (multioutput):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
update(preds, targets, mask, *args, **kwargs)

Update state with predictions and targets.

Parameters:
  • preds (torch.Tensor)

  • targets (torch.Tensor)

  • mask (torch.Tensor)

class chemprop.nn.Wasserstein(task_weights=1.0, threshold=None)

Bases: ChempropMetric

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Parameters:
  • kwargs

    Additional keyword arguments; see Metric kwargs for more info.

    • compute_on_cpu:

      Whether the metric state should be stored on CPU during computations. Only works for list states.

    • dist_sync_on_step:

      Whether the metric state should synchronize on forward(). Default is False.

    • process_group:

      The process group on which the synchronization is called. Default is the world.

    • dist_sync_fn:

      Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

    • distributed_available_fn:

      Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

    • sync_on_compute:

      Whether the metric state should synchronize when compute is called. Default is True.

    • compute_with_cache:

      Whether results from compute should be cached. Default is True.

  • task_weights (numpy.typing.ArrayLike)

  • threshold (float | None)

threshold = None
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str
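For spectra, a Wasserstein loss compares predicted and target intensities as distributions over the same bins. Assuming both spectra are non-negative, share a uniform grid, and are normalized to sum to 1, the 1-D Wasserstein (earth mover's) distance reduces to the summed absolute difference between their cumulative sums. The sketch also assumes that threshold acts as a lower clamp on predicted intensities before normalization; wasserstein_1d is a hypothetical helper, not chemprop's implementation.

```python
from itertools import accumulate


# Hypothetical helper for illustration; not chemprop's implementation.
def wasserstein_1d(pred, target, threshold=None):
    """1-D Wasserstein distance between two spectra on the same uniform grid (bin width 1)."""
    if threshold is not None:
        # assumed behavior: raise small predicted intensities up to the threshold
        pred = [max(p, threshold) for p in pred]
    # normalize both spectra so that each sums to 1
    pred_sum, target_sum = sum(pred), sum(target)
    pred = [p / pred_sum for p in pred]
    target = [t / target_sum for t in target]
    # L1 distance between the two cumulative distribution functions
    return sum(abs(a - b) for a, b in zip(accumulate(pred), accumulate(target)))
```

Unlike a bin-by-bin error, this distance grows with how far intensity must move: a peak shifted by two bins costs 2, while one shifted by a single bin costs 1.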

class chemprop.nn.BinaryClassificationFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: BinaryClassificationFFNBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 1

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor
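BinaryClassificationFFN documents both forward and train_step. A common design, assumed here rather than stated on this page, is that train_step returns raw logits for the loss while forward returns probabilities through a sigmoid. Computing binary cross-entropy directly from logits avoids taking log(0) for confident predictions, as this stdlib sketch shows; sigmoid and bce_with_logits are hypothetical helpers.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def sigmoid(x):
    # numerically stable logistic function for large positive or negative x
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logit, target):
    """Binary cross-entropy computed directly from a logit.

    Uses -[y*log(s) + (1-y)*log(1-s)] = max(x, 0) - x*y + log(1 + exp(-|x|)),
    where s = sigmoid(x), which never evaluates log(0).
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
```

For a logit of 100 and a target of 0, the naive route computes log(1 - sigmoid(100)) = log(0) and fails, whereas bce_with_logits returns about 100.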

class chemprop.nn.BinaryClassificationFFNBase(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: _FFNPredictorBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

class chemprop.nn.BinaryDirichletFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: BinaryClassificationFFNBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 2

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor
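BinaryDirichletFFN predicts n_targets = 2 values per task. Under an evidential (Dirichlet) reading of binary classification, which is an assumption here, the two outputs become non-negative evidence for each class; adding 1 gives Dirichlet concentrations, the positive-class probability is alpha_pos / S with S = alpha_neg + alpha_pos, and K / S (K = 2 classes) measures uncertainty. The softplus mapping and the helper names softplus and dirichlet_binary are illustrative, not chemprop's code.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def softplus(x):
    # numerically stable log(1 + exp(x)), never negative
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dirichlet_binary(raw_neg, raw_pos):
    """Map two raw outputs to (positive-class probability, uncertainty)."""
    alpha_neg = softplus(raw_neg) + 1.0  # concentration for the negative class
    alpha_pos = softplus(raw_pos) + 1.0  # concentration for the positive class
    strength = alpha_neg + alpha_pos     # total Dirichlet strength S
    prob_pos = alpha_pos / strength
    uncertainty = 2.0 / strength         # K / S with K = 2; equals 1.0 when there is no evidence
    return prob_pos, uncertainty
```

With no evidence for either class the sketch returns a probability of 0.5 and an uncertainty of 1.0; strong evidence for one class pushes the probability toward 0 or 1 and the uncertainty toward 0.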

class chemprop.nn.EvidentialFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: RegressionFFN

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 4

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step
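EvidentialFFN predicts n_targets = 4 values per task. In deep evidential regression as used in [soleimany2021], these are the parameters (gamma, nu, alpha, beta) of a Normal-Inverse-Gamma distribution with nu > 0, alpha > 1 and beta > 0. The sketch below assumes a softplus mapping to enforce those constraints and derives the prediction together with the usual aleatoric variance beta / (alpha - 1) and epistemic variance beta / (nu (alpha - 1)); softplus and evidential_outputs are hypothetical helpers, not chemprop's implementation.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def evidential_outputs(raw_mean, raw_v, raw_alpha, raw_beta):
    """Map four raw outputs to (prediction, aleatoric variance, epistemic variance)."""
    gamma = raw_mean                   # predicted mean, left unconstrained
    v = softplus(raw_v)                # nu > 0
    alpha = softplus(raw_alpha) + 1.0  # alpha > 1 keeps the variances below finite
    beta = softplus(raw_beta)          # beta > 0
    aleatoric = beta / (alpha - 1.0)        # expected data noise, E[sigma^2]
    epistemic = beta / (v * (alpha - 1.0))  # uncertainty in the mean, Var[mu]
    return gamma, aleatoric, epistemic
```

Increasing the nu output (more "virtual observations" of the mean) lowers the epistemic term while leaving the aleatoric term unchanged, which is how the model separates model uncertainty from data noise.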
class chemprop.nn.MulticlassClassificationFFN(n_classes, n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: _FFNPredictorBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 1

the number of targets s to predict for each task t

n_classes
property n_tasks: int

the number of tasks t to predict for each input

Return type:

int

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor
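MulticlassClassificationFFN takes n_classes and reports n_targets = 1, so its MLP presumably emits n_tasks x n_classes raw scores per input. A plausible post-processing step, assumed here, reshapes those scores (task-major order) into t x c and applies a softmax over the class axis, so each task's class probabilities sum to 1. to_class_probs is a hypothetical helper; it subtracts each row's maximum for numerical stability.

```python
import math


# Hypothetical helper for illustration; not chemprop's implementation.
def to_class_probs(flat_scores, n_tasks, n_classes):
    """Reshape n_tasks * n_classes scores (task-major) and softmax each task's row."""
    if len(flat_scores) != n_tasks * n_classes:
        raise ValueError("expected n_tasks * n_classes scores")
    probs = []
    for t in range(n_tasks):
        row = flat_scores[t * n_classes:(t + 1) * n_classes]
        m = max(row)  # subtract the max so exp() cannot overflow
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs
```

Equal scores give a uniform distribution over classes, while a dominant score gives that class essentially all of the probability mass, even for very large raw values.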

class chemprop.nn.MveFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: RegressionFFN

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 2

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step
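MveFFN predicts n_targets = 2 values per task: a mean and a variance. Because a variance must be positive, a typical approach, assumed here, passes the variance output through a softplus while leaving the mean unconstrained. The output layout (all means first, then all raw variances) is also an assumption, and softplus and split_mean_var are hypothetical helpers.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def softplus(x):
    # numerically stable log(1 + exp(x)), positive for any finite x
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def split_mean_var(raw_outputs, n_tasks):
    """Split 2 * n_tasks raw outputs into one (mean, variance) pair per task.

    Assumed layout: the first n_tasks values are means and the next n_tasks
    are unconstrained variance outputs.
    """
    if len(raw_outputs) != 2 * n_tasks:
        raise ValueError("expected 2 * n_tasks raw outputs")
    means = raw_outputs[:n_tasks]                 # means are left unconstrained
    raw_vars = raw_outputs[n_tasks:]
    variances = [softplus(r) for r in raw_vars]   # softplus keeps every variance positive
    return list(zip(means, variances))
```

The resulting pairs have the (mean, variance) form that a Gaussian negative log-likelihood such as MVELoss consumes.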
class chemprop.nn.Predictor(*args, **kwargs)

Bases: torch.nn.Module, chemprop.nn.hparams.HasHParams

A Predictor is a protocol that defines a differentiable function \(f : \mathbb{R}^d \mapsto \mathbb{R}^o\), where \(d\) is the input dimension (input_dim) and \(o\) is the output dimension (output_dim).

Parameters:
  • args (Any)

  • kwargs (Any)

input_dim: int

the input dimension

output_dim: int

the output dimension

n_tasks: int

the number of tasks t to predict for each input

n_targets: int

the number of targets s to predict for each task t

criterion: chemprop.nn.metrics.ChempropMetric

the loss function to use for training

task_weights: torch.Tensor

the weights to apply to each task when calculating the loss

output_transform: chemprop.nn.transforms.UnscaleTransform

the transform to apply to the output of the predictor

abstractmethod forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

abstractmethod train_step(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

abstractmethod encode(Z, i)

Calculate the i-th hidden representation.

Parameters:
  • Z (Tensor) – a tensor of shape n x d containing the input data to encode, where d is the input dimensionality.

  • i (int) –

    The stop index of the slice of the MLP used to encode the input. That is, use all layers in the MLP up to, but not including, index i (i.e., MLP[:i]). This can be any integer value, and the behavior of this function follows Python's list slicing semantics. For example:

    • i=0: use a 0-layer MLP (i.e., a no-op)

    • i=1: use only the first block

    • i=-1: use all blocks except the final one

Returns:

a tensor of shape n x h containing the i-th hidden representation, where h is the number of neurons in the i-th hidden layer.

Return type:

Tensor
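The slicing rule for i can be checked without a real network. The sketch below stands in for the MLP with a plain list of callables (a hypothetical stand-in, not chemprop's MLP class) and applies blocks[:i] in order; each toy block appends its name, so the result shows exactly which blocks ran.

```python
from functools import reduce


# Hypothetical stand-in for illustration; not chemprop's encode().
def encode(blocks, x, i):
    """Apply blocks[:i] to x in order, mimicking the documented MLP[:i] semantics."""
    return reduce(lambda h, block: block(h), blocks[:i], x)


# three toy "blocks" that record which layers were applied
blocks = [
    lambda h: h + ["block0"],
    lambda h: h + ["block1"],
    lambda h: h + ["block2"],
]
```

Here encode(blocks, [], 0) returns [] (a no-op), i=1 applies only block0, i=-1 applies every block except the last, and any i at or beyond the number of blocks applies all of them, just as list slicing would.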

chemprop.nn.PredictorRegistry
class chemprop.nn.QuantileFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: RegressionFFN

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 2

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step
class chemprop.nn.RegressionFFN(n_tasks=1, input_dim=DEFAULT_HIDDEN_DIM, hidden_dim=300, n_layers=1, dropout=0.0, activation='relu', criterion=None, task_weights=None, threshold=None, output_transform=None)

Bases: _FFNPredictorBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

n_targets = 1

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step
class chemprop.nn.SpectralFFN(*args, spectral_activation='softplus', **kwargs)

Bases: _FFNPredictorBase

A _FFNPredictorBase is the base class for all Predictors that use an underlying MLP to map the learned fingerprint to the desired output.

Parameters:

spectral_activation (str | None)

n_targets = 1

the number of targets s to predict for each task t

forward(Z)
Parameters:

Z (torch.Tensor)

Return type:

torch.Tensor

train_step
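SpectralFFN takes a spectral_activation argument that defaults to 'softplus' and may be None. A plausible reading, assumed here, is that the activation makes every predicted intensity positive and the spectrum is then normalized to sum to 1 so it can be compared as a distribution. softplus and spectral_head are hypothetical helpers, and this sketch supports only the 'softplus' and None options.

```python
import math


# Hypothetical helpers for illustration; not chemprop's implementation.
def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def spectral_head(raw, activation="softplus"):
    """Make raw outputs positive (optionally) and normalize them to sum to 1."""
    if activation == "softplus":
        values = [softplus(x) for x in raw]
    elif activation is None:
        values = list(raw)  # no activation: inputs are assumed to be positive already
    else:
        raise ValueError(f"unsupported activation in this sketch: {activation!r}")
    total = sum(values)
    return [v / total for v in values]
```

Because softplus is strictly positive, even strongly negative raw outputs yield small but nonzero intensities, so the normalization never divides by zero.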
class chemprop.nn.GraphTransform(V_transform, E_transform)

Bases: torch.nn.Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:

training (bool) – whether this module is in training or evaluation mode.

V_transform
E_transform
forward(bmg)
Parameters:

bmg (chemprop.data.collate.BatchMolGraph)

Return type:

chemprop.data.collate.BatchMolGraph

class chemprop.nn.ScaleTransform(mean, scale, pad=0)

Bases: _ScaleTransformMixin

Parameters:
  • mean (numpy.typing.ArrayLike)

  • scale (numpy.typing.ArrayLike)

  • pad (int)

forward(X)
Parameters:

X (torch.Tensor)

Return type:

torch.Tensor
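ScaleTransform takes per-feature mean and scale arrays plus a pad count. A minimal sketch, assuming the usual standardization (X - mean) / scale and assuming that the first pad features pass through unchanged (implemented by padding mean with zeros and scale with ones); scale_transform is a hypothetical helper operating on a single feature vector.

```python
# Hypothetical helper for illustration; not chemprop's implementation.
def scale_transform(x, mean, scale, pad=0):
    """Standardize the last len(mean) features of x; leave the first pad untouched."""
    # assumed padding scheme: identity (mean 0, scale 1) for the first pad features
    full_mean = [0.0] * pad + list(mean)
    full_scale = [1.0] * pad + list(scale)
    if len(x) != len(full_mean):
        raise ValueError("x must have pad + len(mean) features")
    return [(xi - m) / s for xi, m, s in zip(x, full_mean, full_scale)]
```

For example, scale_transform([5.0, 10.0, 4.0], mean=[8.0, 2.0], scale=[2.0, 1.0], pad=1) keeps the first feature at 5.0 and standardizes the other two to 1.0 and 2.0.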

class chemprop.nn.UnscaleTransform(mean, scale, pad=0)

Bases: _ScaleTransformMixin

Parameters:
  • mean (numpy.typing.ArrayLike)

  • scale (numpy.typing.ArrayLike)

  • pad (int)

forward(X)
Parameters:

X (torch.Tensor)

Return type:

torch.Tensor

transform_variance(var)
Parameters:

var (torch.Tensor)

Return type:

torch.Tensor
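UnscaleTransform maps standardized predictions back to the original units, and transform_variance does the same for predicted variances. Assuming the standardization (X - mean) / scale, the inverse is X * scale + mean, and a variance transforms as var * scale**2 because Var(aY + b) = a**2 Var(Y); the shift by the mean drops out. unscale and unscale_variance are hypothetical helpers, and pad handling is omitted.

```python
# Hypothetical helpers for illustration; not chemprop's implementation.
def unscale(x, mean, scale):
    """Map standardized values back to the original units: x * scale + mean."""
    return [xi * s + m for xi, m, s in zip(x, mean, scale)]


def unscale_variance(var, scale):
    """Map variances from standardized to original units: Var(a*Y + b) = a**2 * Var(Y)."""
    return [v * s**2 for v, s in zip(var, scale)]
```

Unscaling [1.0, 2.0] with mean [8.0, 2.0] and scale [2.0, 1.0] gives [10.0, 4.0], undoing the standardization of those values, while a standardized variance of 0.25 with scale 2.0 becomes 1.0.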

class chemprop.nn.Activation

Bases: chemprop.utils.utils.EnumMapping

Enum where members are also (and must be) strings.

RELU
LEAKYRELU
PRELU
TANH
ELU