Source code for brainpy.losses.regularization
# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import braintools.metric as _bt_metric
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map
import brainpy.math as bm
from .utils import _is_leaf, _multi_return
__all__ = [
'l2_norm',
'mean_absolute',
'mean_square',
'log_cosh',
'smooth_labels',
]
[docs]
def l2_norm(x, axis=None):
    """Computes the L2 (Euclidean) norm of a tensor or a pytree of tensors.

    Each leaf of ``x`` contributes its squared norm ``vdot(leaf, leaf)``
    (``vdot`` flattens its inputs). The result is the square root of the sum
    of these per-leaf values. That equals the norm of all leaves concatenated
    into a single vector.

    Parameters
    ----------
    x
        An n-dimensional tensor of floats, or any pytree whose leaves are
        tensors.
    axis
        Forwarded to ``jnp.sum`` over the 1-D vector of per-leaf squared
        norms. Because that vector is 1-D, only ``None`` (or ``0``/``-1``) is
        a valid value.

    Returns
    -------
    A scalar tensor containing the L2 norm of ``x``. An empty pytree yields 0.
    """
    leaves, _ = tree_flatten(x)
    # Use a distinct loop name so the argument ``x`` is not shadowed visually.
    squared_norms = jnp.asarray([jnp.vdot(leaf, leaf) for leaf in leaves])
    return jnp.sqrt(jnp.sum(squared_norms, axis=axis))
[docs]
def mean_absolute(outputs, axis=None):
    r"""Computes the mean absolute error of ``outputs``, for use as a regularizer.

    Each leaf of ``outputs`` (leaves are decided by ``_is_leaf``) is passed to
    ``braintools.metric.absolute_error`` with no target (``None``) and
    ``reduction='mean'``. ``_multi_return`` then combines the per-leaf
    results. NOTE(review): with a ``None`` target the error is presumably
    measured against zero, i.e. ``mean(|outputs|)``. Confirm against
    braintools.

    Parameters
    ----------
    outputs
        A tensor, or a pytree of tensors, to be regularized.
    axis
        Axis (or axes) forwarded to ``absolute_error`` for the mean
        reduction. ``None`` presumably reduces over all axes (TODO confirm).

    Returns
    -------
    The mean absolute error of each leaf, combined by ``_multi_return``.
    """
    per_leaf = tree_map(
        lambda a: _bt_metric.absolute_error(a, None, axis=axis, reduction='mean'),
        outputs,
        is_leaf=_is_leaf,
    )
    return _multi_return(per_leaf)
[docs]
def mean_square(predicts, axis=None):
    r"""Computes the mean squared error of ``predicts``, for use as a regularizer.

    Each leaf of ``predicts`` (leaves are decided by ``_is_leaf``) is passed to
    ``braintools.metric.squared_error`` with no target (``None``) and
    ``reduction='mean'``. ``_multi_return`` then combines the per-leaf
    results.

    Parameters
    ----------
    predicts
        A tensor, or a pytree of tensors, to be regularized.
    axis
        Axis (or axes) forwarded to ``squared_error`` for the mean reduction.

    Returns
    -------
    The mean squared error of each leaf, combined by ``_multi_return``.
    """
    def _leaf_mean_square(leaf):
        # No target is given; presumably the error is measured against zero.
        return _bt_metric.squared_error(leaf, None, axis=axis, reduction='mean')

    per_leaf = tree_map(_leaf_mean_square, predicts, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
[docs]
def log_cosh(errors):
    r"""Computes the log-cosh loss of a tensor or a pytree of tensors.

    ``log(cosh(x))`` is approximately ``(x**2) / 2`` for small ``x`` and
    ``abs(x) - log(2)`` for large ``x``. That makes it a twice-differentiable
    alternative to the Huber loss. Each leaf (leaves are decided by
    ``_is_leaf``) is handled by ``braintools.metric.log_cosh``, and
    ``_multi_return`` combines the results.

    Parameters
    ----------
    errors
        A tensor of arbitrary shape, or a pytree of such tensors.

    Returns
    -------
    The log-cosh loss of each leaf, combined by ``_multi_return``.

    References
    ----------
    [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)
    """
    # Pass the metric directly; tree_map calls it once per leaf.
    per_leaf = tree_map(_bt_metric.log_cosh, errors, is_leaf=_is_leaf)
    return _multi_return(per_leaf)
[docs]
def smooth_labels(labels, alpha: float) -> jnp.ndarray:
    r"""Applies label smoothing to one-hot labels.

    Label smoothing is often used together with a cross-entropy loss.
    Smoothed labels favour small logit gaps, which has been shown to improve
    model calibration by preventing overconfident predictions.

    ``labels`` may be a single array or a pytree. Every ``bm.Array`` instance
    is treated as a leaf and is not traversed further. Each leaf is smoothed
    by ``braintools.metric.smooth_labels``, and ``_multi_return`` combines
    the results.

    Parameters
    ----------
    labels
        One-hot labels to be smoothed.
    alpha : float
        The smoothing factor. The greedy category will be assigned probability
        ``(1 - alpha) + alpha / num_categories``.

    Returns
    -------
    A smoothed version of the one-hot input labels.

    References
    ----------
    [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)
    """
    def _is_brainpy_array(node):
        # Keep brainpy arrays whole instead of flattening into their contents.
        return isinstance(node, bm.Array)

    def _smooth(target):
        return _bt_metric.smooth_labels(target, alpha)

    smoothed = tree_map(_smooth, labels, is_leaf=_is_brainpy_array)
    return _multi_return(smoothed)