# Source code for brainpy.math.object_transform.autograd

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Callable, Dict, Sequence, Optional

import brainstate.transform

from ._utils import warp_to_no_state_input_output
from .variables import Variable

# Public API of this module. Every function listed here forwards to the
# corresponding ``brainstate.transform`` transform (see the bodies below).
__all__ = [
    'grad',  # gradient of a scalar-valued function
    'vector_grad',  # gradient of a vector/matrix/...-valued function
    'functional_vector_grad',  # alias of ``vector_grad``
    'jacobian', 'jacrev', 'jacfwd',  # Jacobians; ``jacobian`` is an alias of ``jacrev``
    'hessian',  # Hessian as a dense array
]


def grad(
    func: Optional[Callable] = None,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    holomorphic: Optional[bool] = False,
    allow_int: Optional[bool] = False,
    has_aux: Optional[bool] = None,
    return_value: Optional[bool] = False,
) -> Union[Callable, Callable[..., Callable]]:
    """Build a function that evaluates the gradient of a scalar-valued ``func``.

    ``func`` may be a plain function or a callable BrainPy object, and it must
    return a scalar. The structure of the result depends on the arguments.
    Below, ``arg_grads`` denotes the gradients with respect to the positional
    arguments selected by ``argnums``, and ``var_grads`` denotes the gradients
    with respect to ``grad_vars``.

    1. ``grad_vars`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``arg_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(arg_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(arg_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(arg_grads, loss_value, aux_data)``

    2. ``grad_vars`` is not None and ``argnums`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``var_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(var_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(var_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(var_grads, loss_value, aux_data)``

    3. ``grad_vars`` is not None and ``argnums`` is not None:

       - ``has_aux=False``, ``return_value=False`` => ``(var_grads, arg_grads)``
       - ``has_aux=True``,  ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``((var_grads, arg_grads), loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``((var_grads, arg_grads), loss_value, aux_data)``

    Which quantities go into ``grad_vars`` and which into ``argnums``? Take the
    following class as an example::

        >>> import brainpy as bp
        >>> import brainpy.math as bm
        >>>
        >>> class Example(bp.BrainPyObject):
        ...     def __init__(self):
        ...         super(Example, self).__init__()
        ...         self.x = bm.TrainVar(bm.zeros(1))
        ...         self.y = bm.random.rand(10)
        ...     def __call__(self, z, v):
        ...         t1 = self.x * self.y.sum()
        ...         t2 = bm.tanh(z * v + t1)
        ...         return t2.mean()

    It is equivalent to this plain function::

        >>> x = bm.TrainVar(bm.zeros(1))
        >>> y = bm.random.rand(10)
        >>> def f(z, v):
        ...     t1 = x * y.sum()
        ...     t2 = bm.tanh(z * v + t1)
        ...     return t2.mean()

    As a rule of thumb:

    - Variables that ``func`` reads but which are *not* passed in as
      arguments belong in ``grad_vars``.
    - Differentiable function arguments are selected with ``argnums``.

    To differentiate with respect to ``self.x`` and both arguments ``z`` and
    ``v``, call::

        >>> f = Example()
        >>> f_grad = bm.grad(f, grad_vars=f.x, argnums=(0, 1))

    Examples
    --------
    Gradient of a pure function::

        >>> import brainpy as bp
        >>> grad_tanh = bp.math.grad(bp.math.tanh)
        >>> print(grad_tanh(0.2))
        0.961043

    Parameters
    ----------
    func : callable, function, BrainPyObject, optional
        Function to be differentiated.

        - Its arguments at the positions given by ``argnums`` should be
          arrays, scalars, or standard Python containers of them. They must
          be of inexact (floating-point or complex) type.
        - It must return a scalar. This includes arrays with shape ``()``,
          but not arrays with shape ``(1,)`` or similar.
        - When None, a decorator is returned that applies ``grad`` with the
          remaining settings to the function it decorates.
    grad_vars : optional, Variable, sequence of Variable, dict of Variable
        Variables used inside ``func`` whose gradients should be taken.
    argnums : optional, int or sequence of int
        Positional argument(s) to differentiate with respect to.

        - When ``grad_vars`` is also None, the gradient is taken with respect
          to the first argument (``0``).
        - When ``grad_vars`` is given and ``argnums`` is None, only
          ``var_grads`` are produced (case 2 above).
    holomorphic : bool, optional
        Whether ``func`` is promised to be holomorphic. If True, inputs and
        outputs must be complex. Default False.
    allow_int : bool, optional
        Whether to allow differentiating with respect to integer-valued
        inputs. The gradient of an integer input has the trivial vector-space
        dtype (float0). Default False.
    has_aux : bool, optional
        Whether ``func`` returns a pair whose first element is the output to
        differentiate and whose second element is auxiliary data.
    return_value : bool, optional
        Whether to also return the value of ``func`` (the loss value).

    Returns
    -------
    func : GradientTransform, or a decorator when ``func`` is None
        A function with the same arguments as ``func`` that evaluates its
        gradient.

        - If ``argnums`` is an int, the gradient has the same shape and type
          as the selected positional argument.
        - If ``argnums`` is a tuple of ints, the gradient is a tuple of values
          with the shapes and types of the corresponding arguments.
        - Auxiliary data and the loss value are added as described in the
          cases above.
    """
    if func is not None:
        return brainstate.transform.grad(
            warp_to_no_state_input_output(func),
            grad_states=grad_vars,
            argnums=argnums,
            holomorphic=holomorphic,
            allow_int=allow_int,
            has_aux=has_aux,
            return_value=return_value,
            check_states=False,
        )

    # No function yet: this is the ``@grad(...)`` usage. Hand back a decorator
    # that re-enters ``grad`` with the same configuration once ``func`` arrives.
    def _decorator(f):
        return grad(
            f,
            grad_vars=grad_vars,
            argnums=argnums,
            holomorphic=holomorphic,
            allow_int=allow_int,
            has_aux=has_aux,
            return_value=return_value,
        )

    return _decorator
def jacrev(
    func: Callable,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
):
    """Reverse-mode Jacobian of ``func``, extended to functions and class objects.

    This extends JAX's ``jacrev`` so it can be applied to plain functions as
    well as methods of BrainPy objects. It also optionally returns the
    function value (``return_value``) and auxiliary data (``has_aux``).

    As with ``brainpy.math.grad``, the structure of the result depends on the
    arguments. ``arg_grads`` are the Jacobians with respect to the arguments
    selected by ``argnums``, ``var_grads`` are those with respect to
    ``grad_vars``, and ``loss_value`` is the output of ``func``.

    1. ``grad_vars`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``arg_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(arg_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(arg_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(arg_grads, loss_value, aux_data)``

    2. ``grad_vars`` is not None and ``argnums`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``var_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(var_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(var_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(var_grads, loss_value, aux_data)``

    3. ``grad_vars`` is not None and ``argnums`` is not None:

       - ``has_aux=False``, ``return_value=False`` => ``(var_grads, arg_grads)``
       - ``has_aux=True``,  ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``((var_grads, arg_grads), loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``((var_grads, arg_grads), loss_value, aux_data)``

    Parameters
    ----------
    func : callable
        Function whose Jacobian is to be computed.
    grad_vars : optional, Variable, sequence of Variable, dict of Variable
        Variables used inside ``func`` whose Jacobians should be taken.
    argnums : optional, int or sequence of int
        Positional argument(s) to differentiate with respect to (see the cases
        above for how None is treated).
    has_aux : bool, optional
        Whether ``func`` returns a pair whose first element is the output to
        differentiate and whose second element is auxiliary data.
    return_value : bool
        Whether to also return the value of ``func``.
    holomorphic : bool
        Whether ``func`` is promised to be holomorphic. Default False.
    allow_int : bool
        Whether to allow differentiating with respect to integer-valued
        inputs. The gradient of an integer input has the trivial vector-space
        dtype (float0). Default False.

    Returns
    -------
    fun : GradientTransform
        The transformed function.
    """
    # Adapt ``func`` so it can be handed to brainstate's transform.
    wrapped = warp_to_no_state_input_output(func)
    return brainstate.transform.jacrev(
        wrapped,
        grad_states=grad_vars,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        holomorphic=holomorphic,
        allow_int=allow_int,
        check_states=False,
    )


# ``jacobian`` is simply the reverse-mode Jacobian.
jacobian = jacrev
def jacfwd(
    func: Callable,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    return_value: bool = False,
    holomorphic: bool = False,
):
    """Forward-mode Jacobian of ``func``, extended to functions and class objects.

    This extends JAX's ``jacfwd`` so it can be applied to plain functions as
    well as methods of BrainPy objects. It also optionally returns the
    function value (``return_value``) and auxiliary data (``has_aux``).

    As with ``brainpy.math.grad``, the structure of the result depends on the
    arguments. ``arg_grads`` are the Jacobians with respect to the arguments
    selected by ``argnums``, ``var_grads`` are those with respect to
    ``grad_vars``, and ``loss_value`` is the output of ``func``.

    1. ``grad_vars`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``arg_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(arg_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(arg_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(arg_grads, loss_value, aux_data)``

    2. ``grad_vars`` is not None and ``argnums`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``var_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(var_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(var_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(var_grads, loss_value, aux_data)``

    3. ``grad_vars`` is not None and ``argnums`` is not None:

       - ``has_aux=False``, ``return_value=False`` => ``(var_grads, arg_grads)``
       - ``has_aux=True``,  ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``((var_grads, arg_grads), loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``((var_grads, arg_grads), loss_value, aux_data)``

    Parameters
    ----------
    func : callable
        Function whose Jacobian is to be computed.
    grad_vars : optional, Variable, sequence of Variable, dict of Variable
        Variables used inside ``func`` whose Jacobians should be taken.
    argnums : optional, int or sequence of int
        Positional argument(s) to differentiate with respect to (see the cases
        above for how None is treated).
    has_aux : bool, optional
        Whether ``func`` returns a pair whose first element is the output to
        differentiate and whose second element is auxiliary data.
    return_value : bool
        Whether to also return the value of ``func``.
    holomorphic : bool
        Whether ``func`` is promised to be holomorphic. Default False.

    Returns
    -------
    obj : GradientTransform
        The transformed function.
    """
    # Adapt ``func`` so it can be handed to brainstate's transform.
    wrapped = warp_to_no_state_input_output(func)
    return brainstate.transform.jacfwd(
        wrapped,
        grad_states=grad_vars,
        argnums=argnums,
        has_aux=has_aux,
        return_value=return_value,
        holomorphic=holomorphic,
        check_states=False,
    )
def hessian(
    func: Callable,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: Optional[bool] = None,
    holomorphic: bool = False,
):
    """Compute the Hessian of ``func`` as a dense array.

    Parameters
    ----------
    func : callable
        Function whose Hessian is to be computed.

        - Its arguments at the positions given by ``argnums`` should be
          arrays, scalars, or standard Python containers of them.
        - It should return arrays, scalars, or standard Python containers of
          them.
    grad_vars : optional, Variable, sequence of Variable, dict of Variable
        Variables used inside ``func`` whose second derivatives are required.
    argnums : optional, int or sequence of int
        Positional argument(s) to differentiate with respect to.
    has_aux : bool, optional
        Whether ``func`` returns a pair whose first element is the output to
        differentiate and whose second element is auxiliary data.
    holomorphic : bool
        Whether ``func`` is promised to be holomorphic. Default False.

    Returns
    -------
    obj : ObjectTransform
        The transformed function.
    """
    # Adapt ``func`` so it can be handed to brainstate's transform.
    wrapped = warp_to_no_state_input_output(func)
    return brainstate.transform.hessian(
        wrapped,
        grad_states=grad_vars,
        argnums=argnums,
        has_aux=has_aux,
        holomorphic=holomorphic,
        check_states=False,
    )
def vector_grad(
    func: Optional[Callable] = None,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    has_aux: Optional[bool] = None,
) -> Callable:
    """Take vector-valued gradients of ``func``.

    As with ``brainpy.math.grad``, ``brainpy.math.jacrev`` and
    ``brainpy.math.jacfwd``, the structure of the result depends on the
    arguments. ``arg_grads`` are the gradients with respect to the arguments
    selected by ``argnums``, ``var_grads`` are those with respect to
    ``grad_vars``, and ``loss_value`` is the output of ``func``.

    1. ``grad_vars`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``arg_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(arg_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(arg_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(arg_grads, loss_value, aux_data)``

    2. ``grad_vars`` is not None and ``argnums`` is None:

       - ``has_aux=False``, ``return_value=False`` => ``var_grads``
       - ``has_aux=True``,  ``return_value=False`` => ``(var_grads, aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``(var_grads, loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``(var_grads, loss_value, aux_data)``

    3. ``grad_vars`` is not None and ``argnums`` is not None:

       - ``has_aux=False``, ``return_value=False`` => ``(var_grads, arg_grads)``
       - ``has_aux=True``,  ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``
       - ``has_aux=False``, ``return_value=True``  => ``((var_grads, arg_grads), loss_value)``
       - ``has_aux=True``,  ``return_value=True``  => ``((var_grads, arg_grads), loss_value, aux_data)``

    Parameters
    ----------
    func : callable, optional
        Function whose gradient is to be computed. When None, a decorator is
        returned that applies ``vector_grad`` with the remaining settings to
        the function it decorates. This mirrors ``brainpy.math.grad``.
    grad_vars : optional, Variable, sequence of Variable, dict of Variable
        Variables used inside ``func`` whose gradients should be taken.
    argnums : optional, int or sequence of int
        Positional argument(s) to differentiate with respect to (see the cases
        above for how None is treated).
    return_value : bool
        Whether to also return the value of ``func``.
    has_aux : bool, optional
        Whether ``func`` returns a pair whose first element is the output to
        differentiate and whose second element is auxiliary data.

    Returns
    -------
    func : GradientTransform, or a decorator when ``func`` is None
        The vector gradient function.
    """
    if func is None:
        # ``func`` defaults to None. Support the ``@vector_grad(...)`` usage
        # the same way ``grad`` does, rather than wrapping None itself.
        return lambda f: vector_grad(
            f,
            grad_vars=grad_vars,
            argnums=argnums,
            return_value=return_value,
            has_aux=has_aux,
        )
    return brainstate.transform.vector_grad(
        warp_to_no_state_input_output(func),
        grad_states=grad_vars,
        argnums=argnums,
        return_value=return_value,
        has_aux=has_aux,
        check_states=False,
    )


# Backward-compatible alias of ``vector_grad``.
functional_vector_grad = vector_grad