chemprop.nn.metrics
Attributes
Classes
| ChempropMetric | Base class for all metrics present in the Metrics API. |
| MSE | Base class for all metrics present in the Metrics API. |
| MAE | Base class for all metrics present in the Metrics API. |
| RMSE | Base class for all metrics present in the Metrics API. |
| BoundedMSE | Base class for all metrics present in the Metrics API. |
| BoundedMAE | Base class for all metrics present in the Metrics API. |
| BoundedRMSE | Base class for all metrics present in the Metrics API. |
| R2Score | Compute the R2 score, also known as the coefficient of determination. |
| MVELoss | Calculate the loss using Eq. 9 from [nix1994]. |
| EvidentialLoss | Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]. See also [soleimany2021]. |
| PointQuantileLoss | Point-based pinball (quantile) loss operating on one prediction per task. |
| BCELoss | Base class for all metrics present in the Metrics API. |
| CrossEntropyLoss | Base class for all metrics present in the Metrics API. |
| BinaryMCCLoss | Base class for all metrics present in the Metrics API. |
| BinaryMCCMetric | Base class for all metrics present in the Metrics API. |
| MulticlassMCCLoss | Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass classification. |
| MulticlassMCCMetric | Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass classification. |
| BinaryAUROC | Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks. |
| BinaryAUPRC | Compute the precision-recall curve for binary tasks. |
| | Compute Accuracy for binary tasks. |
| | Compute F-1 score for binary tasks. |
| | Uses the loss function from [sensoy2018] based on the implementation at [sensoyGithub]. |
| | Base class for all metrics present in the Metrics API. |
| | Base class for all metrics present in the Metrics API. |
| | Base class for all metrics present in the Metrics API. |
| | Negative log probability enrichment loss function. |
Module Contents
- class chemprop.nn.metrics.ChempropMetric(task_weights=1.0)
Bases: torchmetrics.Metric
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
Handles the transfer of metric states to the correct device.
Handles the synchronization of metric states across processes.
Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
kwargs – additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on forward(). Default is False.
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute: If metric state should synchronize when compute is called. Default is True.
- compute_with_cache: If results from compute should be cached. Default is True.
task_weights (numpy.typing.ArrayLike)
- is_differentiable = True
- higher_is_better = False
- full_state_update = False
- update(preds, targets, mask=None, weights=None, lt_mask=None, gt_mask=None)
Calculate the mean loss function value given predicted and target values
- Parameters:
preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.
targets (Tensor) – a float tensor of shape b x t containing the target values
mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation
weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight
lt_mask (Tensor)
gt_mask (Tensor)
- Return type:
None
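The reduction that update() describes (a mean loss over a b x t batch, with a boolean mask and per-sample weights) can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not chemprop's implementation: it assumes a squared-error per-element loss, that masked-out entries contribute nothing, and that the mean is taken over the number of unmasked entries. The function name masked_weighted_mean_loss is hypothetical.

```python
def masked_weighted_mean_loss(preds, targets, mask, weights, task_weights):
    """Mean per-element loss over a b x t batch, skipping masked-out entries.

    preds, targets: b x t nested lists of floats
    mask:           b x t nested lists of bools (True = include in the loss)
    weights:        length-b list of per-sample weights
    task_weights:   length-t list of per-task weights
    """
    total, count = 0.0, 0
    for i, (p_row, y_row, m_row) in enumerate(zip(preds, targets, mask)):
        for j, (p, y, m) in enumerate(zip(p_row, y_row, m_row)):
            if not m:
                continue  # no usable target for this entry
            # Assumed per-element loss: squared error.
            total += weights[i] * task_weights[j] * (p - y) ** 2
            count += 1
    # Assumed normalisation: number of unmasked entries.
    return total / count if count else 0.0


preds = [[1.0, 2.0], [3.0, 4.0]]
targets = [[1.5, 0.0], [2.0, 4.0]]
mask = [[True, False], [True, True]]  # second task of the first sample is missing
loss = masked_weighted_mean_loss(preds, targets, mask, [1.0, 2.0], [1.0, 1.0])
print(loss)  # 0.75
```

Only three of the four entries are counted, so the masked-out prediction (2.0 against a placeholder target of 0.0) has no effect on the result.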
- chemprop.nn.metrics.LossFunctionRegistry
- chemprop.nn.metrics.MetricRegistry
- class chemprop.nn.metrics.MSE(task_weights=1.0)
Bases: ChempropMetric
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.MAE(task_weights=1.0)
Bases: ChempropMetric
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.RMSE(task_weights=1.0)
Bases: MSE
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.BoundedMSE(task_weights=1.0)
Bases: BoundedMixin, MSE
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.BoundedMAE(task_weights=1.0)
Bases: BoundedMixin, MAE
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.BoundedRMSE(task_weights=1.0)
Bases: BoundedMixin, RMSE
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.R2Score(task_weights=1.0, **kwargs)
Bases: torchmetrics.R2Score
Compute the R2 score, also known as the coefficient of determination.
\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]
where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the sum of residual squares, and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. Can also calculate the adjusted R2 score given by
\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]
where the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument. The score is only properly defined when \(SS_{tot}\neq 0\); \(SS_{tot}\) can be zero for near-constant targets, in which case a score of 0 is returned. By definition the score is bounded between \(-\infty\) and 1.0, with 1.0 indicating perfect prediction, 0 indicating constant prediction and negative values indicating worse than constant prediction.
As input to forward and update the metric accepts the following input:
- preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)
- target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)
As output of forward and compute the metric returns the following output:
- r2score (Tensor): A tensor with the r2 score(s)
In the case of multioutput, by default the variances will be uniformly averaged over the additional dimensions. Please see argument multioutput for changing this behavior.
- Parameters:
num_outputs – Number of outputs in multioutput setting
adjusted – number of independent regressors for calculating adjusted r2 score.
multioutput – Defines aggregation in the case of multiple output scores. Can be one of the following strings:
- 'raw_values' returns full set of scores
- 'uniform_average' scores are uniformly averaged
- 'variance_weighted' scores are weighted by their individual variances
kwargs – Additional keyword arguments, see Metric kwargs for more info.
task_weights (numpy.typing.ArrayLike)
Warning
Argument num_outputs in R2Score has been deprecated because it is no longer necessary and will be removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape of the input tensors.
- Raises:
ValueError – If adjusted parameter is not an integer larger or equal to 0.
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
- Example (single output):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
- Example (multioutput):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
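The two R2 formulas can be checked by hand against the single-output example. The sketch below is a plain-Python restatement of the definitions given for R2Score, not the torchmetrics implementation; the helper names r2_score and adjusted_r2 are hypothetical.

```python
def r2_score(preds, targets):
    """R^2 = 1 - SS_res / SS_tot for a single output."""
    mean_y = sum(targets) / len(targets)
    ss_res = sum((y - p) ** 2 for p, y in zip(preds, targets))
    ss_tot = sum((y - mean_y) ** 2 for y in targets)
    if ss_tot == 0:
        return 0.0  # documented fallback when SS_tot is zero
    return 1.0 - ss_res / ss_tot


def adjusted_r2(r2, n, k):
    """Adjusted R^2 for n samples and k independent regressors."""
    return 1.0 - (1.0 - r2) * (n - 1) / (n - k - 1)


target = [3, -0.5, 2, 7]
preds = [2.5, 0.0, 2, 8]
r2 = r2_score(preds, target)
r2_adj = adjusted_r2(r2, n=len(target), k=1)
print(round(r2, 4))      # 0.9486, matching the single-output example
print(round(r2_adj, 4))  # 0.9229
```

Here SS_res = 1.5 and SS_tot = 29.1875, so R2 = 1 - 1.5 / 29.1875, which rounds to 0.9486.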
- class chemprop.nn.metrics.MVELoss(task_weights=1.0)
Bases: ChempropMetric
Calculate the loss using Eq. 9 from [nix1994].
References
[nix1994] Nix, D. A.; Weigend, A. S. "Estimating the mean and variance of the target probability distribution." Proceedings of 1994 IEEE International Conference on Neural Networks, 1994. https://doi.org/10.1109/icnn.1994.374138
- Parameters:
task_weights (numpy.typing.ArrayLike)
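Eq. 9 of [nix1994] is not reproduced on this page. A mean-variance estimation loss is commonly written as the negative log-likelihood of a Gaussian with predicted mean and variance, and the sketch below assumes that form (including the constant 2π term, which some implementations drop). The function name gaussian_nll is hypothetical, and this is an illustration rather than the MVELoss implementation.

```python
import math


def gaussian_nll(mean, var, target):
    """Negative log-likelihood of `target` under N(mean, var); var must be > 0."""
    return 0.5 * math.log(2.0 * math.pi * var) + (target - mean) ** 2 / (2.0 * var)


# A confident wrong prediction is penalised more than an uncertain wrong one.
confident_wrong = gaussian_nll(mean=0.0, var=0.1, target=2.0)
uncertain_wrong = gaussian_nll(mean=0.0, var=4.0, target=2.0)
print(confident_wrong > uncertain_wrong)  # True
```

The first term grows with the predicted variance, while the second shrinks with it, so the loss rewards variances that match the actual squared error.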
- class chemprop.nn.metrics.EvidentialLoss(task_weights=1.0, v_kl=0.2, eps=1e-08)
Bases: ChempropMetric
Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]. See also [soleimany2021].
References
[amini2020] Amini, A.; Schwarting, W.; Soleimany, A.; Rus, D. "Deep Evidential Regression." Advances in Neural Information Processing Systems, 2020, Vol. 33. https://proceedings.neurips.cc/paper_files/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf
[soleimany2021] Soleimany, A. P.; Amini, A.; Goldman, S.; Rus, D.; Bhatia, S. N.; Coley, C. W. "Evidential Deep Learning for Guided Molecular Property Prediction and Discovery." ACS Cent. Sci. 2021, 7, 8, 1356-1367. https://doi.org/10.1021/acscentsci.1c00546
- Parameters:
task_weights (numpy.typing.ArrayLike)
v_kl (float)
eps (float)
- v_kl = 0.2
- eps = 1e-08
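The equations cited from [amini2020] are not reproduced on this page. As a rough sketch, assuming the Normal-Inverse-Gamma parameters (gamma, nu, alpha, beta) used in that paper, the loss combines a Student-t negative log-likelihood with an evidence regulariser scaled by a coefficient. Treating v_kl as that coefficient is an assumption, and how eps enters is not stated here, so the sketch leaves it out. All function names are hypothetical.

```python
import math


def nig_nll(y, gamma, nu, alpha, beta):
    """Negative log marginal likelihood of y under a Normal-Inverse-Gamma prior."""
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(nu * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def evidence_regularizer(y, gamma, nu, alpha):
    """Penalise large evidence (2*nu + alpha) when the prediction is wrong."""
    return abs(y - gamma) * (2.0 * nu + alpha)


def evidential_loss(y, gamma, nu, alpha, beta, v_kl=0.2):
    # Assumption: v_kl weights the regulariser against the likelihood term.
    return nig_nll(y, gamma, nu, alpha, beta) + v_kl * evidence_regularizer(y, gamma, nu, alpha)


loss_far = evidential_loss(y=1.3, gamma=0.5, nu=2.0, alpha=1.5, beta=0.7)
loss_hit = evidential_loss(y=0.5, gamma=0.5, nu=2.0, alpha=1.5, beta=0.7)
print(loss_far > loss_hit)  # True
```

When the prediction equals the target, the regulariser vanishes and only the likelihood term remains.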
- class chemprop.nn.metrics.PointQuantileLoss(task_weights=1.0, alpha=0.1)
Bases: ChempropMetric
Point-based pinball (quantile) loss operating on one prediction per task. Expects preds and targets shaped [batch, num_tasks]. See [efimov2023].
This is distinct from QuantileLoss, which uses interval-based predictions (mean + interval, shape [batch, num_tasks, 2]).
References
- Parameters:
task_weights (numpy.typing.ArrayLike)
alpha (float)
- alpha = 0.1
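For a single prediction, the pinball loss can be written in a few lines. The sketch below uses the textbook definition and assumes that alpha is the target quantile level; whether PointQuantileLoss interprets alpha exactly this way is not stated on this page. The function name pinball_loss is hypothetical.

```python
def pinball_loss(pred, target, alpha=0.1):
    """Pinball (quantile) loss for one prediction at quantile level alpha."""
    diff = target - pred
    # Under-prediction (diff > 0) costs alpha per unit,
    # over-prediction (diff < 0) costs (1 - alpha) per unit.
    return max(alpha * diff, (alpha - 1.0) * diff)


print(pinball_loss(pred=0.0, target=1.0, alpha=0.1))  # 0.1
print(pinball_loss(pred=1.0, target=0.0, alpha=0.1))  # 0.9
```

With alpha = 0.1, over-predicting is nine times as costly as under-predicting by the same amount, which pushes the prediction toward the 10th percentile of the target distribution.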
- class chemprop.nn.metrics.BCELoss(task_weights=1.0)
Bases: ChempropMetric
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.CrossEntropyLoss(task_weights=1.0)
Bases: ChempropMetric
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.BinaryMCCLoss(task_weights=1.0)
Bases: ChempropMetric
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- update(preds, targets, mask=None, weights=None, *args)
Calculate the mean loss function value given predicted and target values
- Parameters:
preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.
targets (Tensor) – a float tensor of shape b x t containing the target values
mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation
weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight
lt_mask (Tensor)
gt_mask (Tensor)
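This page gives no formula for BinaryMCCLoss. One common way to turn the Matthews correlation coefficient into a differentiable loss is to build "soft" confusion-matrix counts from predicted probabilities and return 1 - MCC. The sketch below assumes that construction and ignores masks and weights; it is an illustration, not the chemprop implementation, and the function names are hypothetical.

```python
import math


def soft_binary_mcc(probs, targets):
    """MCC computed from probability-weighted confusion counts."""
    tp = sum(p * y for p, y in zip(probs, targets))
    fp = sum(p * (1 - y) for p, y in zip(probs, targets))
    fn = sum((1 - p) * y for p, y in zip(probs, targets))
    tn = sum((1 - p) * (1 - y) for p, y in zip(probs, targets))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom > 0 else 0.0


def soft_binary_mcc_loss(probs, targets):
    # Assumed loss form: 0 for perfect agreement, 2 for perfect disagreement.
    return 1.0 - soft_binary_mcc(probs, targets)


targets = [1, 0, 1, 0]
print(soft_binary_mcc_loss([1.0, 0.0, 1.0, 0.0], targets))  # 0.0
print(soft_binary_mcc_loss([0.0, 1.0, 0.0, 1.0], targets))  # 2.0
```

Because MCC lies in [-1, 1] with higher values being better, a loss built this way decreases as the metric improves, which fits BinaryMCCMetric setting higher_is_better = True for the metric form.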
- class chemprop.nn.metrics.BinaryMCCMetric(task_weights=1.0)
Bases: BinaryMCCLoss
Base class for all metrics present in the Metrics API.
- Parameters:
task_weights (numpy.typing.ArrayLike)
- higher_is_better = True
- class chemprop.nn.metrics.MulticlassMCCLoss(task_weights=1.0)
Bases: ChempropMetric
Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass classification based on the implementation of [mccSklearn].
References
- Parameters:
task_weights (numpy.typing.ArrayLike)
- update(preds, targets, mask=None, weights=None, *args)
Calculate the mean loss function value given predicted and target values
- Parameters:
preds (Tensor) – a tensor of shape b x t x u (regression with uncertainty), b x t (regression without uncertainty and binary classification, except for binary dirichlet), or b x t x c (multiclass classification and binary dirichlet) containing the predictions, where b is the batch size, t is the number of tasks to predict, u is the number of values to predict for each task, and c is the number of classes.
targets (Tensor) – a float tensor of shape b x t containing the target values
mask (Tensor) – a boolean tensor of shape b x t indicating whether the given prediction should be included in the loss calculation
weights (Tensor) – a tensor of shape b or b x 1 containing the per-sample weight
lt_mask (Tensor)
gt_mask (Tensor)
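The multiclass Matthews correlation coefficient can be computed from a confusion matrix, and a "soft" variant fills that matrix with predicted class probabilities instead of hard counts. The sketch below follows the confusion-matrix formula used for the multiclass MCC in scikit-learn's matthews_corrcoef and assumes the loss is 1 - MCC; it ignores masks and weights, and it is not the chemprop implementation. All names are hypothetical.

```python
import math


def soft_confusion(probs, targets, n_classes):
    """conf[i][j]: total probability assigned to class j over samples of true class i."""
    conf = [[0.0] * n_classes for _ in range(n_classes)]
    for p_row, y in zip(probs, targets):
        for j in range(n_classes):
            conf[y][j] += p_row[j]
    return conf


def multiclass_mcc(conf):
    k = len(conf)
    s = sum(sum(row) for row in conf)                           # total samples
    c = sum(conf[i][i] for i in range(k))                       # correctly assigned mass
    t = [sum(conf[i]) for i in range(k)]                        # true count per class
    p = [sum(conf[i][j] for i in range(k)) for j in range(k)]   # predicted mass per class
    num = c * s - sum(pk * tk for pk, tk in zip(p, t))
    denom = math.sqrt((s * s - sum(pk * pk for pk in p)) * (s * s - sum(tk * tk for tk in t)))
    return num / denom if denom > 0 else 0.0


targets = [0, 0, 1, 2]
one_hot = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
mcc = multiclass_mcc(soft_confusion(one_hot, targets, 3))
print(1.0 - mcc)  # 0.0 for perfect predictions
```

Uniform probabilities carry no information about the class, and the same formula then yields an MCC of approximately 0.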
- class chemprop.nn.metrics.MulticlassMCCMetric(task_weights=1.0)
Bases: MulticlassMCCLoss
Calculate a soft Matthews correlation coefficient ([mccWiki]) loss for multiclass classification based on the implementation of [mccSklearn].
References
- Parameters:
task_weights (numpy.typing.ArrayLike)
- higher_is_better = True
- class chemprop.nn.metrics.ClassificationMixin(task_weights=1.0, **kwargs)
- Parameters:
task_weights (numpy.typing.ArrayLike)
- class chemprop.nn.metrics.BinaryAUROC(task_weights=1.0, **kwargs)
Bases: ClassificationMixin, torchmetrics.classification.BinaryAUROC
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and a sigmoid is applied per element.
- target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, which therefore contains only {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
As output to forward and compute the metric returns the following output:
- b_auroc (Tensor): A single scalar with the auroc score.
Additional dimension ... will be flattened into the batch dimension.
The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
max_fpr – If not None, calculates standardized partial AUC over the range [0, max_fpr].
thresholds – Can be one of:
- If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
- If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
- If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
- If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs – Additional keyword arguments, see Metric kwargs for more info.
task_weights (numpy.typing.ArrayLike)
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAUROC(thresholds=None)
>>> metric(preds, target)
tensor(0.5000)
>>> b_auroc = BinaryAUROC(thresholds=5)
>>> b_auroc(preds, target)
tensor(0.5000)
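The value 0.5 in the example can be reproduced by hand. AUROC equals the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counted as one half. The sketch below uses that pairwise definition in plain Python; it is not how torchmetrics computes the score, and the function name binary_auroc is hypothetical.

```python
def binary_auroc(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


preds = [0, 0.5, 0.7, 0.8]
target = [0, 1, 1, 0]
print(binary_auroc(preds, target))  # 0.5
```

Both positives (0.5 and 0.7) outrank the negative scored 0 but fall below the negative scored 0.8, so 2 of the 4 pairs are ordered correctly.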
- class chemprop.nn.metrics.BinaryAUPRC(task_weights=1.0, **kwargs)[source]#
Bases:
ClassificationMixin,torchmetrics.classification.BinaryPrecisionRecallCurveCompute the precision-recall curve for binary tasks.
The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen.
As input to
forwardandupdatethe metric accepts the following input:preds(Tensor): A float tensor of shape(N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.target(Tensor): An int tensor of shape(N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Tip
Additional dimension
...will be flattened into the batch dimension.As output to
forwardandcomputethe metric returns the following output:precision(Tensor): if thresholds=None a list for each class is returned with an 1d tensor of size(n_thresholds+1, )with precision values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size(n_classes, n_thresholds+1)with precision values is returned.recall(Tensor): if thresholds=None a list for each class is returned with an 1d tensor of size(n_thresholds+1, )with recall values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size(n_classes, n_thresholds+1)with recall values is returned.thresholds(Tensor): if thresholds=None a list for each class is returned with an 1d tensor of size(n_thresholds, )with increasing threshold values (length may differ between classes). If threshold is set to something else, then a single 1d tensor of size(n_thresholds, )is returned with shared threshold values for all classes.
Note
The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
thresholds –
Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
normalization – Specifies a normalization method that is used for batch-wise update regarding negative logits. Set to None if negative logits are desired in evaluation.
kwargs – Additional keyword arguments, see Metric kwargs for more info.
task_weights (numpy.typing.ArrayLike)
Example
>>> import torch
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.6667, 0.0000, nan, 1.0000]),
 tensor([1., 1., 1., 0., 0., 0.]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
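To make the binned variant concrete, the sketch below recomputes the thresholds=5 output of the example in plain Python. It is an illustration under stated assumptions, not the torchmetrics implementation: `binned_pr_curve` is a hypothetical helper, a prediction is treated as positive when it is greater than or equal to the threshold (chosen because it reproduces the printed values), and the final (precision=1, recall=0) point is appended by hand to match the n_thresholds+1 output length described above.

```python
import math


def binned_pr_curve(preds, target, n_thresholds):
    """Precision/recall at n_thresholds values linearly spaced in [0, 1]."""
    thresholds = [i / (n_thresholds - 1) for i in range(n_thresholds)]
    n_pos = sum(target)
    precision, recall = [], []
    for thr in thresholds:
        tp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 1)
        fp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 0)
        # Precision is undefined (nan) when nothing is predicted positive.
        precision.append(tp / (tp + fp) if tp + fp > 0 else math.nan)
        recall.append(tp / n_pos if n_pos > 0 else math.nan)
    # Closing point of the curve, giving n_thresholds + 1 entries.
    precision.append(1.0)
    recall.append(0.0)
    return precision, recall, thresholds


precision, recall, thresholds = binned_pr_curve(
    [0.0, 0.5, 0.7, 0.8], [0, 1, 1, 0], n_thresholds=5
)
print([round(p, 4) for p in precision])  # [0.5, 0.6667, 0.6667, 0.0, nan, 1.0]
print(recall)                            # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(thresholds)                        # [0.0, 0.25, 0.5, 0.75, 1.0]
```

The nan at threshold 1.0 arises because no prediction reaches that threshold, so precision has a zero denominator.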
- class chemprop.nn.metrics.BinaryAccuracy(task_weights=1.0, **kwargs)[source]#
Bases:
ClassificationMixin, torchmetrics.classification.BinaryAccuracy
Compute Accuracy for binary tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]
Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
acc (Tensor): If multidim_average is set to global, metric returns a scalar value. If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.
If multidim_average is set to samplewise we expect at least one additional dimension ... to be present, which the reduction will then be applied over instead of the sample dimension N.
- Parameters:
threshold – Threshold for transforming probability to binary {0,1} predictions
multidim_average –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
task_weights (numpy.typing.ArrayLike)
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryAccuracy(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.1667])
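The accuracy formula and the two multidim_average modes can be traced by hand with a small pure-Python sketch. `binary_accuracy` is a hypothetical helper, not a chemprop or torchmetrics function; it assumes inputs are already probabilities or hard labels (no sigmoid is applied), uses a default threshold of 0.5, and treats a prediction as positive when it is strictly greater than the threshold.

```python
def binary_accuracy(preds, target, threshold=0.5):
    """Share of observations whose thresholded prediction equals the target."""
    hard = [1 if p > threshold else 0 for p in preds]
    correct = sum(1 for h, t in zip(hard, target) if h == t)
    return correct / len(target)


target = [0, 1, 0, 1, 0, 1]
acc_int = binary_accuracy([0, 0, 1, 1, 0, 1], target)
acc_float = binary_accuracy([0.11, 0.22, 0.84, 0.73, 0.33, 0.92], target)
print(round(acc_int, 4), round(acc_float, 4))  # 0.6667 0.6667

# "samplewise": each sample's extra dimensions are flattened and scored separately.
sample_target = [[0, 1, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0]]
sample_preds = [
    [0.59, 0.91, 0.91, 0.99, 0.63, 0.04],
    [0.38, 0.04, 0.86, 0.780, 0.45, 0.37],
]
samplewise = [binary_accuracy(p, t) for p, t in zip(sample_preds, sample_target)]
print([round(a, 4) for a in samplewise])  # [0.3333, 0.1667]
```

With "global" averaging the same twelve values would instead be pooled into a single score.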
- class chemprop.nn.metrics.BinaryF1Score(task_weights=1.0, **kwargs)[source]#
Bases:
ClassificationMixin, torchmetrics.classification.BinaryF1Score
Compute F-1 score for binary tasks.
\[F_{1} = 2 \cdot \frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered a score of zero_division (0 or 1, default is 0) is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
bf1s (Tensor): A tensor whose returned shape depends on the multidim_average argument:
If multidim_average is set to global, the metric returns a scalar value.
If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.
If multidim_average is set to samplewise we expect at least one additional dimension ... to be present, which the reduction will then be applied over instead of the sample dimension N.
- Parameters:
threshold – Threshold for transforming probability to binary {0,1} predictions
multidim_average –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
zero_division – Should be 0 or 1. The value returned when \(\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0\).
task_weights (numpy.typing.ArrayLike)
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryF1Score(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.0000])
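The F-1 values in these examples can be checked from raw confusion counts, using the algebraically equivalent form 2·TP / (2·TP + FP + FN). The sketch below is illustrative only: `binary_f1` is a hypothetical helper, it assumes probabilities or hard labels as input with a 0.5 threshold, and it returns zero_division when the denominator is zero, which corresponds to the condition TP + FP = 0 and TP + FN = 0 given for the zero_division parameter.

```python
def binary_f1(preds, target, threshold=0.5, zero_division=0.0):
    """F1 from confusion counts: 2*TP / (2*TP + FP + FN)."""
    hard = [1 if p > threshold else 0 for p in preds]
    tp = sum(1 for h, t in zip(hard, target) if h == 1 and t == 1)
    fp = sum(1 for h, t in zip(hard, target) if h == 1 and t == 0)
    fn = sum(1 for h, t in zip(hard, target) if h == 0 and t == 1)
    denom = 2 * tp + fp + fn
    if denom == 0:
        # No positives in either predictions or targets.
        return float(zero_division)
    return 2 * tp / denom


f1 = binary_f1([0, 0, 1, 1, 0, 1], [0, 1, 0, 1, 0, 1])
print(round(f1, 4))  # 0.6667

# Samplewise variant: one score per (flattened) sample.
sample_target = [[0, 1, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0]]
sample_preds = [
    [0.59, 0.91, 0.91, 0.99, 0.63, 0.04],
    [0.38, 0.04, 0.86, 0.780, 0.45, 0.37],
]
f1_samplewise = [binary_f1(p, t) for p, t in zip(sample_preds, sample_target)]
print(f1_samplewise)  # [0.5, 0.0]
```

In the first example TP = 2, FP = 1 and FN = 1, so the score is 4 / 6.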
- class chemprop.nn.metrics.DirichletLoss(task_weights=1.0, v_kl=0.2)[source]#
Bases:
ChempropMetric
Uses the loss function from [sensoy2018] based on the implementation at [sensoyGithub]
References
[sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. “Evidential deep learning to quantify classification uncertainty.” NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768
- Parameters:
task_weights (numpy.typing.ArrayLike)
v_kl (float)
- v_kl = 0.2#
- class chemprop.nn.metrics.SID(task_weights=1.0, threshold=None, **kwargs)[source]#
Bases:
ChempropMetric
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
Handles the transfer of metric states to the correct device.
Handles the synchronization of metric states across processes.
Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are: add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
kwargs –
additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu:
If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
If metric state should synchronize on forward(). Default is False.
- process_group:
The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
Function that performs the allgather option on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn:
Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
If results from compute should be cached. Default is True.
task_weights (numpy.typing.ArrayLike)
threshold (float | None)
- threshold = None#
- class chemprop.nn.metrics.Wasserstein(task_weights=1.0, threshold=None)[source]#
Bases:
ChempropMetric
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
Handles the transfer of metric states to the correct device.
Handles the synchronization of metric states across processes.
Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are: add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
kwargs –
additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu:
If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
If metric state should synchronize on forward(). Default is False.
- process_group:
The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
Function that performs the allgather option on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn:
Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
If results from compute should be cached. Default is True.
task_weights (numpy.typing.ArrayLike)
threshold (float | None)
- threshold = None#
- class chemprop.nn.metrics.QuantileLoss(task_weights=1.0, alpha=0.1)[source]#
Bases:
ChempropMetric
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
Handles the transfer of metric states to the correct device.
Handles the synchronization of metric states across processes.
Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are: add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
kwargs –
additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu:
If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
If metric state should synchronize on forward(). Default is False.
- process_group:
The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
Function that performs the allgather option on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn:
Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
If results from compute should be cached. Default is True.
task_weights (numpy.typing.ArrayLike)
alpha (float)
- alpha = 0.1#
- class chemprop.nn.metrics.NLogProbEnrichment(task_weights=1.0, n1=1, n2=1, method='sqrt', zscale=1.0, zinterval=5.0)[source]#
Bases:
ChempropMetric
Negative log probability enrichment loss function. Originally implemented by [lim2022] for DNA-encoded library screening data, but can be applied to any count-based data that can be assumed to follow a Poisson distribution. This code is adapted from [coleyGithub].
Additional arguments, k1, k2, n1 and n2, are needed for the loss function:
k1: counts for a specific observation in the positive sample
n1: total counts across observations in the positive sample
k2: counts for a specific observation in the counter (negative) sample
n2: total counts across observations in the counter (negative) sample
zinterval: the range of z-scores (+/-) that are used for calculating the confidence interval. Defaults to 5 due to its application to DNA-encoded library screening data.
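For orientation only, the sketch below shows how the four count arguments relate in a naive point estimate of enrichment: the observation's share of counts in the positive sample divided by its share in the counter sample. This is not the loss computed by NLogProbEnrichment, which per the description above works with z-score intervals; `naive_enrichment` and the example counts are hypothetical and chosen purely for illustration.

```python
def naive_enrichment(k1, n1, k2, n2):
    """Ratio of normalized counts: (k1 / n1) / (k2 / n2)."""
    if n1 <= 0 or n2 <= 0:
        raise ValueError("total counts n1 and n2 must be positive")
    if k2 == 0:
        raise ValueError("ratio is undefined when the counter-sample count k2 is zero")
    return (k1 / n1) / (k2 / n2)


# Hypothetical counts: 30 of 1000 reads in the positive sample,
# 5 of 2000 reads in the counter (negative) sample.
ratio = naive_enrichment(k1=30, n1=1000, k2=5, n2=2000)
print(ratio)  # 12.0
```

A point estimate like this ignores count uncertainty, which is especially large for small k1 or k2.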
References
[lim2022] Lim, Katherine S.; Reidenbach, Andrew G.; Hua, Bruce K.; Mason, Jeremy W.; Gerry, Christopher J.; Clemons, Paul A.; Coley, Connor W. “Machine Learning on DNA-Encoded Library Count Data Using an Uncertainty-Aware Probabilistic Loss Function” JCIM, 2022, 62. https://doi.org/10.1021/acs.jcim.2c00041
- Parameters:
task_weights (numpy.typing.ArrayLike)
n1 (int)
n2 (int)
method (Literal['sqrt', 'score', 'wald'])
zscale (float)
zinterval (float)
- n1 = 1#
- n2 = 1#
- method = 'sqrt'#
- zscale = 1.0#
- zinterval = 5.0#