# Source code for brainpy.losses.comparison

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements several loss functions.
"""

from typing import Tuple, Optional

import braintools.metric as _bt_metric
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map

import brainpy.math as bm
from brainpy.types import ArrayType
from .base import Loss, WeightedLoss
from .utils import _reduce, _multi_return, _is_leaf

# Public API of this module, grouped by loss family. The concatenation keeps
# the exact export order of the names.
__all__ = (
    # Classification: cross-entropy family and negative log-likelihood.
    [
        'CrossEntropyLoss', 'cross_entropy_loss',
        'cross_entropy_sparse', 'cross_entropy_sigmoid',
        'NLLLoss', 'nll_loss',
    ]
    # Regression: absolute, squared and robust errors.
    + [
        'L1Loss', 'l1_loss',
        'l2_loss', 'huber_loss',
        'MAELoss', 'mean_absolute_error',
        'MSELoss', 'mean_squared_error',
        'mean_squared_log_error',
    ]
    # Logistic and smooth losses.
    + [
        'binary_logistic_loss', 'multiclass_logistic_loss',
        'sigmoid_binary_cross_entropy', 'softmax_cross_entropy',
        'log_cosh_loss',
    ]
    # Sequence (CTC) and margin losses.
    + [
        'ctc_loss_with_forward_probs', 'ctc_loss',
        'multi_margin_loss',
    ]
)


class CrossEntropyLoss(WeightedLoss):
    r"""Cross entropy loss between unnormalized logits and targets.

    It is useful when training a classification problem with `C` classes.
    If provided, the optional argument :attr:`weight` should be a 1D array
    assigning a weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The `input` is expected to contain the unnormalized logits for each class
    (which do `not` need to be positive or sum to 1, in general). The class
    dimension is the **last** axis: `input` has shape :math:`(C)` for
    unbatched input, :math:`(N, C)` for a minibatch, or
    :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` for the
    `K`-dimensional case (useful e.g. for per-pixel losses on 2D images).

    The `target` that this criterion expects should contain either:

    - Class indices in the range :math:`[0, C)` where :math:`C` is the number
      of classes. If `ignore_index` is specified, this loss also accepts this
      class index (which need not be in the class range). The unreduced (i.e.
      with :attr:`reduction` set to ``'none'``) loss for this case is

      .. math::
          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
          l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
          \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}

      where :math:`x` is the input, :math:`y` is the target, :math:`w` is the
      weight, :math:`C` is the number of classes, and :math:`N` spans the
      minibatch dimension as well as :math:`d_1, ..., d_K`. If
      :attr:`reduction` is not ``'none'`` (default ``'mean'``), then

      .. math::
          \ell(x, y) = \begin{cases}
              \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot
              \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
              \text{if reduction} = \text{`mean';}\\
              \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.}
          \end{cases}

      This case is equivalent to applying a log-softmax over the class axis
      followed by a (weighted) negative log-likelihood.

    - Probabilities for each class, with the same shape as the input. This is
      useful when labels beyond a single class per sample are required, such
      as blended labels or label smoothing. The unreduced loss for this case is

      .. math::
          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
          l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}

      and if :attr:`reduction` is not ``'none'`` (default ``'mean'``), then

      .. math::
          \ell(x, y) = \begin{cases}
              \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{`mean';}\\
              \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.}
          \end{cases}

    .. note::
        Prefer class-index targets when a single label per sample suffices;
        pass probabilities only when that is too restrictive.

    Parameters
    ----------
    weight : ArrayType, optional
        A manual rescaling weight given to each class. If given, it has to be
        an array of size `C`.
    ignore_index : int, optional
        A target value that is ignored and does not contribute to the loss.
        With ``reduction='mean'`` the loss is averaged over non-ignored
        targets only. Only applicable when the target contains class indices.
        Default: ``-100``.
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
        applied, ``'mean'``: the weighted mean of the output is taken,
        ``'sum'``: the output will be summed. Default: ``'mean'``.
    label_smoothing : float, optional
        A float in [0.0, 1.0]. Specifies the amount of smoothing when
        computing the loss, where 0.0 means no smoothing. The targets become a
        mixture of the original ground truth and a uniform distribution, as
        described in `Rethinking the Inception Architecture for Computer Vision
        <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

    Shape:
        - Input: :math:`(C)`, :math:`(N, C)` or :math:`(d_1, ..., d_K, N, C)`
          with :math:`K \geq 1`.
        - Target: if containing class indices, shape :math:`()`, :math:`(N)`
          or :math:`(d_1, ..., d_K, N)`, with each value in :math:`[0, C)`.
          If containing class probabilities, the same shape as the input, with
          each value in :math:`[0, 1]`.
        - Output: if reduction is ``'none'``, the input shape without its last
          (class) axis. Otherwise, scalar.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy as bp
        >>> import brainpy.math as bm
        >>> loss = bp.losses.CrossEntropyLoss()
        >>> logits = bm.random.randn(3, 5)
        >>> # Targets given as class indices
        >>> target = bm.random.randint(0, 5, (3,))
        >>> output = loss.update(logits, target)
        >>> # Targets given as class probabilities (here: one-hot vectors)
        >>> target = bm.one_hot(bm.random.randint(0, 5, (3,)), 5)
        >>> output = loss.update(logits, target)
    """
    __constants__ = ['ignore_index', 'reduction', 'label_smoothing']
    ignore_index: int
    label_smoothing: float

    def __init__(
        self,
        weight: Optional[ArrayType] = None,
        ignore_index: int = -100,
        reduction: str = 'mean',
        label_smoothing: float = 0.0
    ) -> None:
        super().__init__(weight, reduction)
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing

    def update(self, input: ArrayType, target: ArrayType) -> ArrayType:
        """Compute the cross entropy of ``input`` logits against ``target``.

        Delegates to :func:`cross_entropy_loss` with this module's ``weight``,
        ``reduction``, ``ignore_index`` and ``label_smoothing`` settings.
        """
        return cross_entropy_loss(input,
                                  target,
                                  weight=self.weight,
                                  reduction=self.reduction,
                                  ignore_index=self.ignore_index,
                                  label_smoothing=self.label_smoothing)
def cross_entropy_loss(predicts, targets, weight=None, reduction='mean', ignore_index=-100, label_smoothing=0.0):
    r"""Cross entropy between unnormalized logits and targets.

    This criterion combines a log-softmax and a negative log-likelihood loss
    in one function. It is useful when training a classification problem with
    `C` classes. If provided, the optional argument :attr:`weight` should be a
    1D array assigning a weight to each of the classes, which is particularly
    useful for an unbalanced training set.

    ``predicts`` is expected to contain raw, unnormalized scores for each
    class, with the class dimension on the **last** axis: :math:`(N, C)` or
    :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1`.

    ``targets`` holds either class indices (one dimension fewer than
    ``predicts``) or class probabilities (same shape as ``predicts``). For a
    class-index target the loss is

    .. math::
        \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
                       = -x[class] + \log\left(\sum_j \exp(x[j])\right)

    or, when the :attr:`weight` argument is specified,

    .. math::
        \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)

    For a probability target :math:`y` the loss is

    .. math::
        \text{loss}(x, y) = -\sum_c weight[c] \, y[c] \log\left(\frac{\exp(x[c])}{\sum_j \exp(x[j])}\right)

    with :math:`weight[c] = 1` when no weight is given. Label smoothing
    replaces the target distribution by
    ``(1 - label_smoothing) * target + label_smoothing / C`` before the
    weighted sum is taken.

    ``predicts`` and ``targets`` may also be matching pytrees. Each pair of
    leaves is processed independently, and the per-leaf results are combined
    by ``_multi_return``.

    Parameters
    ----------
    predicts : ArrayType
        :math:`(N, C)` where `C = number of classes`, or
        :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` in the case of
        `K`-dimensional loss.
    targets : ArrayType
        Class indices of shape :math:`(N)` or :math:`(d_1, ..., d_K, N)` with
        :math:`0 \leq \text{targets}[i] \leq C-1` (or equal to
        ``ignore_index``). Alternatively, class probabilities of shape
        :math:`(N, C)` or :math:`(d_1, ..., d_K, N, C)`.
    weight : ArrayType, optional
        A manual rescaling weight given to each class. If given, it has to be
        an array of size `C`.
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

        - ``'none'``: no reduction will be applied.
        - ``'mean'``: for class-index targets, the sum of the losses over
          non-ignored samples divided by the total weight of their target
          classes (or by the number of non-ignored samples when ``weight`` is
          None). For probability targets, the plain mean over samples.
        - ``'sum'``: the output will be summed.
    ignore_index : int, optional
        A class-index target value that is ignored and contributes nothing to
        the loss. It has no effect for probability targets. Default: ``-100``.
    label_smoothing : float, optional
        Amount of smoothing in [0.0, 1.0]; 0.0 means no smoothing.
        Default: ``0.0``.

    Returns
    -------
    output : scalar, ArrayType
        If :attr:`reduction` is ``'none'``, an array with the shape of
        ``predicts`` without its last (class) axis. Otherwise, a scalar.
    """

    def _cel(_pred, _tar):
        _pred = bm.as_jax(_pred)
        num_classes = _pred.shape[-1]
        # Numerically stable log-softmax over the class (last) axis.
        log_probs = _pred - logsumexp(_pred, axis=-1, keepdims=True)
        class_weight = None if weight is None else bm.as_jax(weight)

        if bm.ndim(_tar) + 1 == bm.ndim(_pred):
            # ``_tar`` holds integer class indices.
            tar_idx = bm.as_jax(_tar)
            valid_mask = tar_idx != ignore_index
            # Clamp ignored indices to a valid class so that the one-hot
            # encoding and the weight lookup stay in range. These samples are
            # masked out of the loss below.
            safe_idx = jnp.where(valid_mask, tar_idx, 0)
            soft = bm.as_jax(bm.one_hot(safe_idx, num_classes))
        else:
            # ``_tar`` holds class probabilities (or one-hot vectors).
            valid_mask = None
            safe_idx = None
            soft = bm.as_jax(_tar)

        if label_smoothing > 0.0:
            soft = soft * (1.0 - label_smoothing) + label_smoothing / num_classes
        if class_weight is not None:
            # Weight every class term separately: l_n = -sum_c w_c y_{n,c} log p_{n,c}.
            # For label-smoothed index targets this equals
            # (1 - eps) * w_y * nll + eps / C * sum_c w_c * (-log p_c),
            # which is the PyTorch definition.
            soft = soft * class_weight
        loss = -(soft * log_probs).sum(axis=-1)

        if valid_mask is None:
            # Probability targets: 'mean' averages over samples (divides by N),
            # regardless of the class weights.
            if reduction == 'mean':
                return loss.mean()
            return _reduce(outputs=loss, reduction=reduction)

        # Zero-out ignored samples so they contribute nothing to sum/mean.
        loss = jnp.where(valid_mask, loss, 0.0)
        if reduction == 'mean':
            if class_weight is None:
                return loss.sum() / jnp.maximum(valid_mask.sum(), 1)
            # Weighted mean: normalise by the total weight of the (non-ignored)
            # target classes.
            denom = jnp.where(valid_mask, class_weight[safe_idx], 0.0).sum()
            return loss.sum() / denom
        return _reduce(outputs=loss, reduction=reduction)

    r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(r)
def cross_entropy_sparse(predicts, targets):
    r"""Computes the softmax cross-entropy loss from integer class labels.

    Parameters
    ----------
    predicts
        (batch, ..., #class) tensor of logits.
    targets
        (batch, ...) integer tensor of label indexes in {0, ..., #class-1},
        or just a single integer. Labels with a trailing singleton axis, i.e.
        shape (batch, ..., 1), are also accepted.

    Returns
    -------
    (batch, ...) tensor of the cross-entropy for each entry.
    """

    def crs(_prd, _tar):
        _prd = bm.as_jax(_prd)
        if isinstance(_tar, int):
            logits = _prd[..., _tar]
        else:
            _tar = bm.as_jax(_tar)
            if _tar.ndim + 1 == _prd.ndim:
                # ``take_along_axis`` needs indices with the same rank as the
                # source, so put the class axis back on the label tensor.
                _tar = jnp.expand_dims(_tar, -1)
            logits = jnp.take_along_axis(_prd, _tar, -1).squeeze(-1)
        # -log softmax(x)[label] = logsumexp(x) - x[label]
        return logsumexp(_prd, axis=-1) - logits

    r = tree_map(crs, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(r)
def cross_entropy_sigmoid(predicts, targets):
    """Computes the element-wise sigmoid cross-entropy loss from logits.

    Each class is treated as an independent binary problem, so the labels do
    not need to sum to one along the class axis. The loss is evaluated in the
    numerically stable form ``max(x, 0) - x * z + log(1 + exp(-|x|))``, which
    equals ``-z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x))`` for
    logits ``x`` and labels ``z``.

    Parameters
    ----------
    predicts
        (batch, ..., #class) tensor of logits.
    targets
        (batch, ..., #class) tensor of label probabilities, each in [0, 1].

    Returns
    -------
    (batch, ..., #class) tensor of the cross-entropies for each entry. No
    reduction over the class axis is applied.
    """

    def _sigmoid_ce(pred, tar):
        # Stable evaluation of -tar*log(sigmoid(pred)) - (1-tar)*log(1-sigmoid(pred)).
        loss = bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred)))
        return bm.as_jax(loss)

    r = tree_map(_sigmoid_ce, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(r)
class NLLLoss(Loss):
    r"""The negative log likelihood loss.

    It is useful to train a classification problem with `C` classes. The
    `input` is expected to contain log-probabilities of each class, e.g. the
    output of a log-softmax layer. Use :class:`CrossEntropyLoss` instead if
    you prefer to pass unnormalized logits. The `target` should be a class
    index in the range :math:`[0, C-1]` where `C = number of classes`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - x_{n,y_n},

    where :math:`x` is the input, :math:`y` is the target, and :math:`N` is
    the batch size. If :attr:`reduction` is not ``'none'`` (default
    ``'mean'``), then

    .. math::
        \ell(x, y) = \begin{cases}
            \frac{1}{N} \sum_{n=1}^N l_n, & \text{if reduction} = \text{`mean';}\\
            \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.}
        \end{cases}

    Class weights and ``ignore_index`` are not supported by this criterion.

    Parameters
    ----------
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
        applied, ``'mean'``: the mean of the output is taken, ``'sum'``: the
        output will be summed. Default: ``'mean'``.

    Shape:
        - Input: :math:`(N, C)` log-probabilities, where `C = number of classes`.
        - Target: :math:`(N)`, where each value is
          :math:`0 \leq \text{targets}[i] \leq C-1`.
        - Output: if :attr:`reduction` is ``'none'``, shape :math:`(N)`.
          Otherwise, scalar.

        See :func:`nll_loss` for the full description of the accepted shapes.
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__(reduction=reduction)

    def update(self, input, target):
        """Return the negative log-likelihood of ``target`` under ``input``.

        Delegates to :func:`nll_loss` with this module's ``reduction``.
        """
        return nll_loss(input, target, reduction=self.reduction)
def nll_loss(input, target, reduction: str = 'mean'):
    r"""The negative log likelihood loss.

    It is useful to train a classification problem with `C` classes. The
    `input` is expected to contain log-probabilities of each class, e.g. the
    output of a log-softmax layer. Use `cross_entropy_loss` instead if you
    prefer to pass unnormalized logits. The `target` should be a class index
    in the range :math:`[0, C-1]` where `C = number of classes`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - x_{n,y_n},

    where :math:`x` is the input, :math:`y` is the target, and :math:`N`
    spans the batch dimension (and :math:`d_1, ..., d_K` in the
    `K`-dimensional case). If :attr:`reduction` is not ``'none'`` (default
    ``'mean'``), then

    .. math::
        \ell(x, y) = \begin{cases}
            \frac{1}{N} \sum_{n=1}^N l_n, & \text{if reduction} = \text{`mean';}\\
            \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.}
        \end{cases}

    Class weights and ``ignore_index`` are not supported.

    Parameters
    ----------
    input : ArrayType
        Log-probabilities.
    target : ArrayType
        Integer class indices; must have exactly one dimension fewer than
        ``input``.
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'`` (``None`` behaves like
        ``'none'``). ``'none'``: no reduction will be applied, ``'mean'``: the
        mean of the output is taken, ``'sum'``: the output will be summed.
        Default: ``'mean'``.

    Shape:
        - Input: :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, ..., d_K)`
          with :math:`K \geq 1`. The class axis is axis 0 for unbatched input
          and axis 1 otherwise.
        - Target: :math:`()`, :math:`(N)` or :math:`(N, d_1, ..., d_K)`,
          where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
        - Output: if :attr:`reduction` is ``'none'``, the target's shape.
          Otherwise, scalar.

    Raises
    ------
    ValueError
        If ``target`` does not have exactly one dimension fewer than
        ``input``, or if ``reduction`` is not recognised.
    """
    input = bm.as_jax(input)
    target = bm.as_jax(target)
    if target.ndim + 1 != input.ndim:
        raise ValueError(
            f'"target" must have exactly one dimension fewer than "input", '
            f'but got target.ndim={target.ndim} and input.ndim={input.ndim}.'
        )
    # Gather x_{n, y_n} along the class axis: axis 0 for unbatched input,
    # axis 1 for the (N, C, d_1, ..., d_K) layout.
    class_axis = 0 if input.ndim == 1 else 1
    picked = jnp.take_along_axis(input, jnp.expand_dims(target, class_axis), axis=class_axis)
    # The leading minus sign turns log-probabilities (<= 0) into a loss to minimize.
    loss = -jnp.squeeze(picked, axis=class_axis)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none' or reduction is None:
        return loss
    raise ValueError(f'Unknown reduction {reduction!r}; expected "none", "mean" or "sum".')
class L1Loss(Loss):
    r"""Creates a criterion that measures the mean absolute error (MAE)
    between each element in the input :math:`x` and target :math:`y`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \left| x_n - y_n \right|,

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
        \end{cases}

    :math:`x` and :math:`y` are arrays of arbitrary shapes with a total of
    :math:`n` elements each. The mean operation operates over all the
    elements and divides by :math:`n`. The division by :math:`n` can be
    avoided by setting ``reduction = 'sum'``.

    The computation is delegated to :func:`l1_loss`.

    Parameters
    ----------
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
        applied, ``'mean'``: the sum of the output will be divided by the
        number of elements in the output, ``'sum'``: the output will be
        summed. Default: ``'mean'``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Target: :math:`(*)`, same shape as the input.
        - Output: scalar. If :attr:`reduction` is ``'none'``, then
          :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy as bp
        >>> import brainpy.math as bm
        >>> loss = bp.losses.L1Loss()
        >>> input = bm.random.randn(3, 5)
        >>> target = bm.random.randn(3, 5)
        >>> output = loss.update(input, target)
    """
    __constants__ = ['reduction']

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction=reduction)

    def update(self, input: ArrayType, target: ArrayType) -> ArrayType:
        """Return the L1 loss between ``input`` and ``target``.

        Delegates to :func:`l1_loss` with this module's ``reduction``.
        """
        return l1_loss(input, target, reduction=self.reduction)
def l1_loss(logits, targets, reduction='mean'):
    r"""Creates a criterion that measures the mean absolute error (MAE)
    between each element in the logits :math:`x` and targets :math:`y`. It is
    useful in regression problems.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \left| x_n - y_n \right|,

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
        \end{cases}

    :math:`x` and :math:`y` are arrays of arbitrary shapes with a total of
    :math:`n` elements each. The mean operation operates over all the
    elements and divides by :math:`n`. The division by :math:`n` can be
    avoided by setting ``reduction = 'sum'``.

    The per-leaf computation is delegated to ``braintools.metric.l1_loss``.
    Matching pytrees of arrays are processed leaf by leaf.

    Parameters
    ----------
    logits : ArrayType
        :math:`(N, *)` where :math:`*` means any number of additional
        dimensions.
    targets : ArrayType
        :math:`(N, *)`, same shape as the input.
    reduction : str
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.

        - ``'none'``: no reduction will be applied,
        - ``'mean'``: the sum of the output will be divided by the number of
          elements in the output,
        - ``'sum'``: the output will be summed.

    Returns
    -------
    output : scalar
        If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape
        as the input.
    """

    def _leaf_l1(pred, tar):
        return _bt_metric.l1_loss(pred, tar, reduction=reduction)

    r = tree_map(_leaf_l1, logits, targets, is_leaf=_is_leaf)
    return _multi_return(r)
def l2_loss(predicts, targets):
    r"""Computes the L2 loss between predictions and targets.

    The per-leaf computation is delegated to ``braintools.metric.l2_loss``.
    Matching pytrees of arrays are processed leaf by leaf. Regarding the 0.5
    factor: it is standard in "Pattern Recognition and Machine Learning" by
    Bishop [1]_, but not in "The Elements of Statistical Learning" by
    Tibshirani.

    Parameters
    ----------
    predicts : ArrayType
        A vector of arbitrary shape.
    targets : ArrayType
        A vector of shape compatible with predictions.

    Returns
    -------
    loss : float
        The l2 loss as computed by ``braintools.metric.l2_loss``.

    References
    ----------
    .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning.
    """

    def _leaf_l2(pred_leaf, target_leaf):
        return _bt_metric.l2_loss(pred_leaf, target_leaf)

    per_leaf = tree_map(_leaf_l2, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
class MAELoss(Loss):
    """Mean absolute error loss module.

    A thin wrapper around :func:`mean_absolute_error` that stores the
    ``axis`` and ``reduction`` settings.

    Parameters
    ----------
    axis : optional
        Forwarded unchanged as ``axis`` to :func:`mean_absolute_error`.
    reduction : str, optional
        Forwarded as ``reduction`` to :func:`mean_absolute_error`.
        Default: ``'mean'``.
    """

    def __init__(self, axis=None, reduction: str = 'mean'):
        super().__init__(reduction=reduction)
        self.axis = axis

    def update(self, input, target):
        """Return the mean absolute error between ``input`` and ``target``."""
        keep_axis = self.axis
        return mean_absolute_error(input, target, keep_axis, reduction=self.reduction)
def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'):
    r"""Computes the mean absolute error between x and y.

    The per-leaf computation is delegated to
    ``braintools.metric.absolute_error``. Matching pytrees are processed leaf
    by leaf.

    Parameters
    ----------
    x
        a tensor of shape (d0, .. dN-1).
    y
        a tensor of shape (d0, .. dN-1).
    axis
        a sequence of the dimensions to keep, use `None` to return a scalar
        value. Forwarded as ``axis`` to ``braintools.metric.absolute_error``.
    reduction : str, optional
        Forwarded as ``reduction`` to ``braintools.metric.absolute_error``.
        Default: ``'mean'``.

    Returns
    -------
    tensor of shape (d_i, ..., for i in keep_axis) containing the mean
    absolute error.
    """

    def _leaf_abs_error(lhs, rhs):
        return _bt_metric.absolute_error(lhs, rhs, axis=axis, reduction=reduction)

    return _multi_return(tree_map(_leaf_abs_error, x, y, is_leaf=_is_leaf))
class MSELoss(Loss):
    r"""Creates a criterion that measures the mean squared error (squared L2
    norm) between each element in the input :math:`x` and target :math:`y`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \left( x_n - y_n \right)^2,

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
        \end{cases}

    :math:`x` and :math:`y` are arrays of arbitrary shapes with a total of
    :math:`n` elements each. The mean operation operates over all the
    elements and divides by :math:`n`. The division by :math:`n` can be
    avoided by setting ``reduction = 'sum'``.

    The computation is delegated to :func:`mean_squared_error`.

    Parameters
    ----------
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
        applied, ``'mean'``: the sum of the output will be divided by the
        number of elements in the output, ``'sum'``: the output will be
        summed. Default: ``'mean'``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Target: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy as bp
        >>> import brainpy.math as bm
        >>> loss = bp.losses.MSELoss()
        >>> input = bm.random.randn(3, 5)
        >>> target = bm.random.randn(3, 5)
        >>> output = loss.update(input, target)
    """
    __constants__ = ['reduction']

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction=reduction)

    def update(self, input: ArrayType, target: ArrayType) -> ArrayType:
        """Return the mean squared error between ``input`` and ``target``.

        Delegates to :func:`mean_squared_error` with this module's
        ``reduction``.
        """
        return mean_squared_error(input, target, reduction=self.reduction)
def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'):
    r"""Computes the mean squared error between predictions and targets.

    The per-leaf computation is delegated to
    ``braintools.metric.squared_error``. Matching pytrees are processed leaf
    by leaf.

    Parameters
    ----------
    predicts
        a tensor of shape (d0, .. dN-1).
    targets
        a tensor of shape (d0, .. dN-1).
    axis
        a sequence of the dimensions to keep, use `None` to return a scalar
        value. Forwarded as ``axis`` to ``braintools.metric.squared_error``.
    reduction : str, optional
        Forwarded as ``reduction`` to ``braintools.metric.squared_error``.
        Default: ``'mean'``.

    Returns
    -------
    tensor of shape (d_i, ..., for i in keep_axis) containing the mean
    squared error.
    """

    def _leaf_sq_error(pred_leaf, target_leaf):
        return _bt_metric.squared_error(pred_leaf, target_leaf, axis=axis, reduction=reduction)

    per_leaf = tree_map(_leaf_sq_error, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean'):
    r"""Computes the mean squared logarithmic error between predictions and targets.

    Element-wise, the error is :math:`(\log(1 + p) - \log(1 + t))^2`, which is
    then reduced with ``reduction`` over ``axis``. Matching pytrees are
    processed leaf by leaf.

    Parameters
    ----------
    predicts
        a tensor of shape (d0, .. dN-1).
    targets
        a tensor of shape (d0, .. dN-1).
    axis
        a sequence of the dimensions to keep, use `None` to return a scalar
        value. Forwarded as ``axis`` to the module's reduction helper.
    reduction : str, optional
        ``'none'`` | ``'mean'`` | ``'sum'``, forwarded to the module's
        reduction helper. Default: ``'mean'``.

    Returns
    -------
    tensor of shape (d_i, ..., for i in keep_axis) containing the mean
    squared logarithmic error.
    """

    def _leaf_msle(pred, tar):
        # log1p is accurate for values near zero, where log(1 + x) loses precision.
        squared_log_diff = (jnp.log1p(pred) - jnp.log1p(tar)) ** 2
        return _reduce(squared_log_diff, reduction, axis=axis)

    r = tree_map(_leaf_msle, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(r)
def huber_loss(predicts, targets, delta: float = 1.0):
    r"""Huber loss.

    The Huber loss behaves like the L2 loss close to zero and like the L1
    loss away from zero. Applying gradient descent to it is equivalent to
    clipping the gradients of an `l2_loss` to `[-delta, delta]` in the
    backward pass. The per-leaf computation is delegated to
    ``braintools.metric.huber_loss``.

    Parameters
    ----------
    predicts : ArrayType
        predictions
    targets : ArrayType
        ground truth
    delta : float
        radius of quadratic behavior

    Returns
    -------
    loss : float
        The loss value.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Huber_loss
    """

    def _leaf_huber(pred_leaf, target_leaf):
        return _bt_metric.huber_loss(pred_leaf, target_leaf, delta=delta)

    return _multi_return(tree_map(_leaf_huber, predicts, targets, is_leaf=_is_leaf))
def binary_logistic_loss(predicts: float, targets: int) -> float:
    """Binary logistic loss.

    Computes ``softplus(predicts) - targets * predicts``, which is the
    negative log-likelihood of a 0/1 label under a Bernoulli distribution
    whose logit is ``predicts``.

    Parameters
    ----------
    predicts
        score produced by the model (float).
    targets
        ground-truth integer label (0 or 1).

    Returns
    -------
    loss value
    """

    def _logistic(score, label):
        # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on
        # [0, 1]: softplus = p * logit - xlogx(p) - xlogx(1 - p), with
        # xlogx(p) = p * log(p).
        return bm.softplus(score) - label * score

    per_leaf = tree_map(_logistic, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
    """Multiclass logistic loss.

    Computes ``logsumexp(logits) - logits[label]`` along the last (class)
    axis.

    Parameters
    ----------
    label : int
        ground-truth integer label, between 0 and n_classes - 1. For batched
        logits, an integer array of shape ``logits.shape[:-1]``.
    logits : jnp.ndarray
        scores produced by the model, shape = (n_classes, ), or
        (..., n_classes) for a batch.

    Returns
    -------
    loss value: a scalar for 1-D logits, otherwise an array of shape
    ``logits.shape[:-1]``.
    """

    def loss(pred, tar):
        pred = bm.as_jax(pred)
        # The class axis is the last one. Sizing the one-hot by ``shape[-1]``
        # (not ``shape[0]``) keeps batched logits from being mis-sized; for
        # 1-D logits both are identical.
        one_hot = bm.as_jax(bm.one_hot(tar, pred.shape[-1]))
        return logsumexp(pred, axis=-1) - (pred * one_hot).sum(axis=-1)

    r = tree_map(loss, logits, label, is_leaf=_is_leaf)
    return _multi_return(r)
def sigmoid_binary_cross_entropy(logits, labels):
    """Computes sigmoid cross entropy given logits and multiple class labels.

    Measures the probability error in discrete classification tasks where each
    class is an independent binary prediction and different classes are not
    mutually exclusive. For example, in multilabel image classification a
    model may predict that an image contains both a cat and a dog. The
    per-leaf computation is delegated to
    ``braintools.metric.sigmoid_binary_cross_entropy``.

    Parameters
    ----------
    logits
        unnormalized log probabilities.
    labels
        the probability for that class.

    Returns
    -------
    a sigmoid cross entropy loss.

    References
    ----------
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
    """

    def _leaf_sigmoid_bce(logit_leaf, label_leaf):
        return _bt_metric.sigmoid_binary_cross_entropy(logit_leaf, label_leaf)

    per_leaf = tree_map(_leaf_sigmoid_bce, logits, labels, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
def softmax_cross_entropy(logits, labels):
    """Computes the softmax cross entropy between sets of logits and labels.

    Measures the probability error in discrete classification tasks where the
    classes are mutually exclusive (each entry is in exactly one class). For
    example, each CIFAR-10 image is labeled with one and only one label: an
    image can be a dog or a truck, but not both. The per-leaf computation is
    delegated to ``braintools.metric.softmax_cross_entropy``.

    Parameters
    ----------
    logits
        unnormalized log probabilities.
    labels
        a valid probability distribution (non-negative, sum to 1), e.g. a one
        hot encoding of which class is the correct one for each input.

    Returns
    -------
    the cross entropy loss.

    References
    ----------
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
    """

    def _leaf_softmax_ce(logit_leaf, label_leaf):
        return _bt_metric.softmax_cross_entropy(logit_leaf, label_leaf)

    return _multi_return(tree_map(_leaf_softmax_ce, logits, labels, is_leaf=_is_leaf))
def log_cosh_loss(predicts, targets=None):
    r"""Calculates the log-cosh loss for a set of predictions.

    log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
    for large x. It is a twice differentiable alternative to the Huber loss.

    Parameters
    ----------
    predicts
        A vector of arbitrary shape (or a pytree of such vectors).
    targets
        A vector of shape compatible with predictions (or a pytree with the
        same structure as ``predicts``). If ``None`` (the default), every
        target is assumed to be zero.

    Returns
    -------
    The log-cosh loss, computed per leaf and passed through ``_multi_return``.

    References
    ----------
    [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)
    """

    def _leaf_loss(pred, tar=None):
        # Unwrap ``bm.Array`` leaves so braintools receives plain JAX arrays:
        # implicit ``__jax_array__`` coercion is not available under JAX >= 0.9.
        pred = bm.as_jax(pred)
        # A missing target means "compare against zero", as documented above.
        tar = jnp.zeros_like(pred) if tar is None else bm.as_jax(tar)
        return _bt_metric.log_cosh(pred, tar)

    if targets is None:
        r = tree_map(_leaf_loss, predicts, is_leaf=_is_leaf)
    else:
        r = tree_map(_leaf_loss, predicts, targets, is_leaf=_is_leaf)
    return _multi_return(r)
def ctc_loss_with_forward_probs(
    logits: ArrayType,
    logit_paddings: ArrayType,
    labels: ArrayType,
    label_paddings: ArrayType,
    blank_id: int = 0,
    log_epsilon: float = -1e5
) -> Tuple[ArrayType, ArrayType, ArrayType]:
    r"""Computes CTC loss and CTC forward-probabilities.

    The CTC loss is a loss function based on log-likelihoods of the model that
    introduces a special blank symbol :math:`\phi` to represent variable-length
    output sequences.

    Forward probabilities returned by this function, as auxiliary results, are
    grouped into two parts: blank alpha-probability and non-blank alpha
    probability. Those are defined as follows:

    .. math::
        \alpha_{\mathrm{BLANK}}(t, n) =
        \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\
        \alpha_{\mathrm{LABEL}}(t, n) =
        \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots).

    Here, :math:`\pi` denotes the alignment sequence in the reference
    [Graves et al, 2006] that is blank-inserted representations of ``labels``.
    The return values are the logarithms of the above probabilities.

    Parameters
    ----------
    logits : ArrayType
        (B, T, K)-array containing logits of each class where B denotes
        the batch size, T denotes the max time frames in ``logits``, and K
        denotes the number of classes including a class for blanks.
    logit_paddings : ArrayType
        (B, T)-array. Padding indicators for ``logits``. Each element must be
        either 1.0 or 0.0, and ``logit_paddings[b, t] == 1.0`` denotes that
        ``logits[b, t, :]`` are padded values.
    labels : ArrayType
        (B, N)-array containing reference integer labels where N denotes
        the max time frames in the label sequence.
    label_paddings : ArrayType
        (B, N)-array. Padding indicators for ``labels``. Each element must be
        either 1.0 or 0.0, and ``label_paddings[b, n] == 1.0`` denotes that
        ``labels[b, n]`` is a padded label. In the current implementation,
        ``labels`` must be right-padded, i.e. each row ``label_paddings[b, :]``
        must be repetition of zeroes, followed by repetition of ones.
    blank_id : int
        Id for blank token. ``logits[b, :, blank_id]`` are used as
        probabilities of blank symbols.
    log_epsilon : float
        Numerically-stable approximation of log(+0).

    Returns
    -------
    A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here,
    ``loss_value`` is a (B,)-array containing the loss values for each sequence
    in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are
    (T, B, N+1)-arrays where the (t, b, n)-th element denotes
    \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th
    sequence in the batch.

    References
    ----------
    [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)
    """
    # Unwrap any ``bm.Array`` inputs so braintools receives plain JAX arrays:
    # implicit ``__jax_array__`` coercion is not available under JAX >= 0.9.
    return _bt_metric.ctc_loss_with_forward_probs(
        bm.as_jax(logits),
        bm.as_jax(logit_paddings),
        bm.as_jax(labels),
        bm.as_jax(label_paddings),
        blank_id=blank_id,
        log_epsilon=log_epsilon,
    )
def ctc_loss(logits: ArrayType,
             logit_paddings: ArrayType,
             labels: ArrayType,
             label_paddings: ArrayType,
             blank_id: int = 0,
             log_epsilon: float = -1e5) -> ArrayType:
    """Computes CTC loss.

    See docstring for ``ctc_loss_with_forward_probs`` for details.

    Parameters
    ----------
    logits : ArrayType
        (B, T, K)-array containing logits of each class where B denotes
        the batch size, T denotes the max time frames in ``logits``, and K
        denotes the number of classes including a class for blanks.
    logit_paddings : ArrayType
        (B, T)-array. Padding indicators for ``logits``. Each element must be
        either 1.0 or 0.0, and ``logit_paddings[b, t] == 1.0`` denotes that
        ``logits[b, t, :]`` are padded values.
    labels : ArrayType
        (B, N)-array containing reference integer labels where N denotes
        the max time frames in the label sequence.
    label_paddings : ArrayType
        (B, N)-array. Padding indicators for ``labels``. Each element must be
        either 1.0 or 0.0, and ``label_paddings[b, n] == 1.0`` denotes that
        ``labels[b, n]`` is a padded label. In the current implementation,
        ``labels`` must be right-padded, i.e. each row ``label_paddings[b, :]``
        must be repetition of zeroes, followed by repetition of ones.
    blank_id : int
        Id for blank token. ``logits[b, :, blank_id]`` are used as
        probabilities of blank symbols.
    log_epsilon : float
        Numerically-stable approximation of log(+0).

    Returns
    -------
    (B,)-array containing loss values for each sequence in the batch.
    """
    # Unwrap any ``bm.Array`` inputs so braintools receives plain JAX arrays:
    # implicit ``__jax_array__`` coercion is not available under JAX >= 0.9.
    return _bt_metric.ctc_loss(
        bm.as_jax(logits),
        bm.as_jax(logit_paddings),
        bm.as_jax(labels),
        bm.as_jax(label_paddings),
        blank_id=blank_id,
        log_epsilon=log_epsilon,
    )
def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'):
    r"""Computes multi-class margin loss, also called multi-class hinge loss.

    This loss function is often used in multi-class classification problems.
    It is a type of hinge loss that tries to ensure the score of the correct
    class exceeds the score of every other class by at least ``margin``.

    The loss for sample :math:`i` is:

    .. math::
        \ell_i(x, y) = \sum_{j \neq y_i} \max(0, x_{i,j} - x_{i,y_i} + \text{margin})^p

    where :math:`x` is the input, :math:`y` is the target, :math:`y_i` is the
    index of the true class of sample :math:`i`, and :math:`j` ranges over the
    :math:`C` classes.

    Parameters
    ----------
    predicts
        :math:`(N, C)` where `C = number of classes`.
    targets
        :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
    margin : float, optional
        Has a default value of :math:`1`.
    p : int, optional
        Exponent applied to each hinge term; must be 1 or 2. Default: :math:`1`.
    reduction : str, optional
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.
        ``'none'``: the per-sample losses :math:`\ell_i` are returned,
        ``'mean'``: the per-sample losses are averaged over the batch,
        ``'sum'``: the per-sample losses are summed. Default: ``'mean'``

    Returns
    -------
    A scalar representing the multi-class margin loss. If `reduction` is
    ``'none'``, an array of shape :math:`(N)`.

    Raises
    ------
    AssertionError
        If ``p`` is not 1 or 2.
    ValueError
        If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.
    """
    assert p == 1 or p == 2, 'p should be 1 or 2'
    # Previously an unknown reduction fell through every branch and silently
    # returned None; fail loudly instead.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            f"reduction should be 'none', 'mean' or 'sum', but got {reduction!r}")
    # Convert to plain JAX arrays: under JAX >= 0.9 implicit __jax_array__
    # coercion was removed, so advanced-indexing a ``bm.Array`` would raise.
    predicts = bm.as_jax(predicts)
    targets = bm.as_jax(targets)
    batch_size = predicts.shape[0]
    rows = jnp.arange(batch_size)
    correct_scores = predicts[rows, targets]
    margins = jnp.power(
        jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p)
    # At j == y_i the hinge term equals margin**p; the formula excludes it.
    margins = margins.at[rows, targets].set(0)
    # Per-sample loss l_i: sum over the non-target classes, shape (N,).
    per_sample = jnp.sum(margins, axis=1)
    if reduction == 'mean':
        return jnp.sum(per_sample) / batch_size
    if reduction == 'sum':
        return jnp.sum(per_sample)
    return per_sample