# -*- coding: utf-8 -*-
import inspect
from functools import partial, wraps
from typing import Union, Callable, Dict, Sequence, Any, Optional
import jax
import numpy as np
if jax.__version__ >= '0.4.16':
from jax.extend import linear_util
else:
from jax import linear_util
from jax import dtypes, vmap, numpy as jnp, core
from jax._src.api import (_vjp, _jvp)
from jax.api_util import argnums_partial
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_unflatten,
tree_map, tree_transpose,
tree_structure)
from jax.util import safe_map
from brainpy import tools, check
from brainpy._src.math.ndarray import Array, _as_jax_array_
from .tools import (dynvar_deprecation,
node_deprecation,
get_stack_cache,
cache_stack)
from .base import (BrainPyObject, ObjectTransform)
from .variables import (Variable, VariableStack)
from .tools import eval_shape
__all__ = [
'grad', # gradient of scalar function
'vector_grad', # gradient of vector/matrix/...
'functional_vector_grad',
'jacobian', 'jacrev', 'jacfwd', # Jacobian of vector-valued functions
'hessian', # Hessian matrix of scalar-valued functions
]
class GradientTransform(ObjectTransform):
"""Object-oriented Automatic Differentiation Transformation in BrainPy.
"""
def __init__(
self,
target: Callable,
transform: Callable,
# variables and nodes
grad_vars: Any,
dyn_vars: Dict[str, Variable],
child_objs: Dict[str, BrainPyObject],
# gradient setting
argnums: Optional[Union[int, Sequence[int]]],
return_value: bool,
has_aux: bool,
transform_setting: Optional[Dict[str, Any]] = None,
# other
name: str = None,
):
super().__init__(name=name)
# gradient variables
self._grad_vars, self._grad_tree = tree_flatten(grad_vars, is_leaf=lambda a: isinstance(a, Array))
# register variables and nodes
self.register_implicit_vars(dyn_vars, self._grad_vars)
self.register_implicit_nodes(child_objs)
# parameters
if argnums is None and len(self._grad_vars) == 0:
argnums = 0
if argnums is None:
assert len(self._grad_vars) > 0
_argnums = 0
elif isinstance(argnums, int):
_argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
else:
_argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
_argnums = tuple(a + 2 for a in _argnums)
if len(self._grad_vars) > 0:
_argnums = (0,) + _argnums
self._nonvar_argnums = argnums
self._argnums = _argnums
self._return_value = return_value
self._has_aux = has_aux
# target
self.target = target
# transform
self._eval_dyn_vars = False
self._grad_transform = transform
self._dyn_vars = VariableStack()
self._transform = None
self._grad_setting = dict() if transform_setting is None else transform_setting
if self._has_aux:
self._transform = self._grad_transform(
self._f_grad_with_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)
else:
self._transform = self._grad_transform(
self._f_grad_without_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)
def _f_grad_with_aux_to_transform(self,
grad_values: tuple,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k]._value = dyn_values[k]
for v, d in zip(self._grad_vars, grad_values):
v._value = d
# Users should return the auxiliary data like::
# >>> # 1. example of return one data
# >>> return scalar_loss, data
# >>> # 2. example of return multiple data
# >>> return scalar_loss, (data1, data2, ...)
outputs = self.target(*args, **kwargs)
# outputs: [0] is the value for gradient,
# [1] is other values for return
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
return output0, (outputs, [v.value for v in self._grad_vars], self._dyn_vars.dict_data())
def _f_grad_without_aux_to_transform(self,
grad_values: tuple,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k]._value = dyn_values[k]
for v, d in zip(self._grad_vars, grad_values):
v._value = d
# Users should return the scalar value like this::
# >>> return scalar_loss
output = self.target(*args, **kwargs)
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
return output0, (output, [v.value for v in self._grad_vars], self._dyn_vars.dict_data())
def __repr__(self):
name = self.__class__.__name__
f = tools.repr_object(self.target)
f = tools.repr_context(f, " " * (len(name) + 6))
format_ref = (f'{name}({self.name}, target={f}, \n' +
f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
return format_ref
def _return(self, rets):
grads, (outputs, new_grad_vs, new_dyn_vs) = rets
for v, d in zip(self._grad_vars, new_grad_vs):
v._value = d
for k in new_dyn_vs.keys():
self._dyn_vars[k]._value = new_dyn_vs[k]
# check returned grads
if len(self._grad_vars) > 0:
if self._nonvar_argnums is None:
grads = self._grad_tree.unflatten(grads)
else:
var_grads = self._grad_tree.unflatten(grads[0])
arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
grads = (var_grads, arg_grads)
# check returned value
if self._return_value:
# check aux
if self._has_aux:
return grads, outputs[0], outputs[1]
else:
return grads, outputs
else:
# check aux
if self._has_aux:
return grads, outputs[1]
else:
return grads
def __call__(self, *args, **kwargs):
if jax.config.jax_disable_jit: # disable JIT
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)
elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with VariableStack() as stack:
rets = eval_shape(self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs)
cache_stack(self.target, stack)
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)
rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)
def _make_grad(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
holomorphic: Optional[bool] = False,
allow_int: Optional[bool] = False,
reduce_axes: Optional[Sequence[str]] = (),
has_aux: Optional[bool] = None,
return_value: Optional[bool] = False,
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
return GradientTransform(target=func,
transform=jax.grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes))
def grad(
func: Optional[Callable] = None,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
holomorphic: Optional[bool] = False,
allow_int: Optional[bool] = False,
reduce_axes: Optional[Sequence[str]] = (),
has_aux: Optional[bool] = None,
return_value: Optional[bool] = False,
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> Union[Callable, GradientTransform]:
"""Automatic gradient computation for functions or class objects.
This gradient function only supports scalar returns. It creates a function
which evaluates the gradient of ``func``.
Note that the returned values differ for different argument settings (where ``arg_grads`` refers
to the gradients of "argnums", and ``var_grads`` refers to the gradients of "grad_vars").
1. When "grad_vars" is None
- "has_aux=False" + "return_value=False" => ``arg_grads``.
- "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
2. When "grad_vars" is not None and "argnums" is None
- "has_aux=False" + "return_value=False" => ``var_grads``.
- "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
3. When "grad_vars" is not None and "argnums" is not None
- "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
- "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
- "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
- "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
Let's see some examples below.
Before we start, let's figure out what should be provided as ``grad_vars``,
and what should be labeled in ``argnums``.
Take the following code as an example:
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> class Example(bp.BrainPyObject):
>>> def __init__(self):
>>> super(Example, self).__init__()
>>> self.x = bm.TrainVar(bm.zeros(1))
>>> self.y = bm.random.rand(10)
>>> def __call__(self, z, v):
>>> t1 = self.x * self.y.sum()
>>> t2 = bm.tanh(z * v + t1)
>>> return t2.mean()
>>>
>>> # This code is equivalent to the following function:
>>>
>>> x = bm.TrainVar(bm.zeros(1))
>>> y = bm.random.rand(10)
>>> def f(z, v):
>>> t1 = x * y.sum()
>>> t2 = bm.tanh(z * v + t1)
>>> return t2.mean()
Generally speaking, all gradient variables that are not provided as function arguments should be
labeled as ``grad_vars``, while all gradient variables that are provided as function arguments
should be declared in ``argnums``.
In the code above, to take gradients of ``self.x`` and of the arguments ``z`` and ``v``, we should
call ``brainpy.math.grad`` as:
>>> f = Example()
>>> f_grad = bm.grad(f, grad_vars=f.x, argnums=(0, 1))
Examples
--------
Grad for a pure function:
>>> import brainpy as bp
>>> grad_tanh = grad(bp.math.tanh)
>>> print(grad_tanh(0.2))
0.961043
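``grad`` can be used either on a plain function, as a decorator, or together with
``grad_vars``. A minimal sketch, assuming the ``Example`` class defined above
(outputs are omitted since they depend on the random data):
>>> import brainpy.math as bm
>>>
>>> # used as a decorator
>>> @bm.grad
>>> def loss(z):
>>>   return bm.tanh(z).mean()
>>> loss(0.3)  # gradient of the loss with respect to z
>>>
>>> # gradients of a variable and of the positional arguments, plus the loss value
>>> f = Example()
>>> f_grad = bm.grad(f, grad_vars=f.x, argnums=(0, 1), return_value=True)
>>> (x_grad, (z_grad, v_grad)), loss_value = f_grad(bm.ones(10), bm.ones(10))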
Parameters
----------
func : callable, function, BrainPyObject
Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers.
Argument arrays in the positions specified by ``argnums`` must be of
inexact (i.e., floating-point or complex) type. It should return a scalar
(which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.)
grad_vars : optional, ArrayType, sequence of ArrayType, dict
The variables in ``func`` to take their gradients.
argnums : optional, integer or sequence of integers
Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux: optional, bool
Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
return_value : bool
Whether to return the loss value.
holomorphic: optional, bool
Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: optional, bool
Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: optional, tuple of str
Tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
dyn_vars : optional, ArrayType, sequence of ArrayType, dict
The dynamically changed variables used in ``func``.
.. deprecated:: 2.4.0
No longer need to provide ``dyn_vars``. This function is capable of automatically
collecting the dynamical variables used in the target ``func``.
child_objs: optional, BrainPyObject, sequence, dict
.. versionadded:: 2.3.1
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
Returns
-------
func : GradientTransform
A function with the same arguments as ``fun``, that evaluates the gradient
of ``fun``. If ``argnums`` is an integer then the gradient has the same
shape and type as the positional argument indicated by that integer. If
argnums is a tuple of integers, the gradient is a tuple of values with the
same shapes and types as the corresponding arguments. If ``has_aux`` is True
then a pair of (gradient, auxiliary_data) is returned.
"""
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
if func is None:
return lambda f: _make_grad(f,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
has_aux=has_aux,
return_value=return_value)
else:
return _make_grad(func=func,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
has_aux=has_aux,
return_value=return_value)
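# Helper mirroring the one used inside ``jax.jacrev``/``jax.jacfwd``: split ``arr`` along
# ``axis`` into blocks sized like the leaves of ``pytree`` and rebuild the pytree structure.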
def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None):
leaves, treedef = tree_flatten(pytree, is_leaf=is_leaf)
axis = axis % arr.ndim
shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1:] for l in leaves]
parts = jnp.split(_as_jax_array_(arr), np.cumsum(safe_map(np.size, leaves[:-1])), axis)
reshaped_parts = [x.reshape(shape) for x, shape in zip(parts, shapes)]
return tree_unflatten(treedef, reshaped_parts, )
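# Build the standard basis (rows of an identity matrix) for all leaves of ``pytree``; each
# basis vector seeds one VJP/JVP when a full Jacobian is evaluated with ``vmap``.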
def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(safe_map(np.size, leaves))
dtype = dtypes.result_type(*leaves)
flat_basis = jax.numpy.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, flat_basis)
def _isleaf(x):
return isinstance(x, Array)
def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
_check_callable(fun)
@wraps(fun)
def jacfun(*args, **kwargs):
f = linear_util.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if has_aux:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
else:
y, pullback = _vjp(f_partial, *dyn_args, has_aux=False)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf), jac, is_leaf=_isleaf)
jac = tree_transpose(tree_structure(example_args), tree_flatten(y, is_leaf=_isleaf)[1], jac_tree)
if return_value:
return (jac, y, aux) if has_aux else (jac, y)
else:
return (jac, aux) if has_aux else jac
return jacfun
def jacrev(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: Optional[bool] = None,
return_value: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
"""Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
This function extends the official JAX ``jacrev`` so that automatic Jacobian
computation also works on functions and class methods. Moreover, it supports returning
the function value ("return_value") and returning auxiliary data ("has_aux").
As with `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returned values
differ for different argument settings in ``brainpy.math.jacrev``.
1. When "grad_vars" is None
- "has_aux=False" + "return_value=False" => ``arg_grads``.
- "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
2. When "grad_vars" is not None and "argnums" is None
- "has_aux=False" + "return_value=False" => ``var_grads``.
- "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
3. When "grad_vars" is not None and "argnums" is not None
- "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
- "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
- "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
- "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
Parameters
----------
func: Function whose Jacobian is to be computed.
grad_vars : optional, ArrayType, sequence of ArrayType, dict
The variables in ``func`` to take their gradients.
has_aux: optional, bool
Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
return_value : bool
Whether to return the loss value.
argnums: Optional, integer or sequence of integers.
Specifies which
positional argument(s) to differentiate with respect to (default ``0``).
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be
holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
dyn_vars : optional, ArrayType, sequence of ArrayType, dict
The dynamically changed variables used in ``func``.
.. deprecated:: 2.4.0
No longer need to provide ``dyn_vars``. This function is capable of automatically
collecting the dynamical variables used in the target ``func``.
child_objs: optional, BrainPyObject, sequence, dict
.. versionadded:: 2.3.1
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
Returns
-------
fun: GradientTransform
The transformed object.
"""
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
return GradientTransform(target=func,
transform=_jacrev,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic,
allow_int=allow_int))
jacobian = jacrev
def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
_check_callable(fun)
@wraps(fun)
def jacfun(*args, **kwargs):
f = linear_util.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if has_aux:
pushfwd = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args))
else:
pushfwd = partial(_jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac = tree_map(partial(_unravel_array_into_pytree, example_args, -1, is_leaf=_isleaf), jac, is_leaf=_isleaf)
if return_value:
return (jac, y, aux) if has_aux else (jac, y)
else:
return (jac, aux) if has_aux else jac
return jacfun
def jacfwd(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: Optional[bool] = None,
return_value: bool = False,
holomorphic: bool = False,
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
"""Extending automatic Jacobian (forward-mode) of ``func`` to classes.
This function extends the official JAX ``jacfwd`` so that automatic Jacobian
computation also works on functions and class methods. Moreover, it supports returning
the function value ("return_value") and returning auxiliary data ("has_aux").
As with `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returned values
differ for different argument settings in ``brainpy.math.jacfwd``.
1. When "grad_vars" is None
- "has_aux=False" + "return_value=False" => ``arg_grads``.
- "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
2. When "grad_vars" is not None and "argnums" is None
- "has_aux=False" + "return_value=False" => ``var_grads``.
- "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
3. When "grad_vars" is not None and "argnums" is not None
- "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
- "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
- "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
- "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
Parameters
----------
func: Function whose Jacobian is to be computed.
grad_vars : optional, ArrayType, sequence of ArrayType, dict
The variables in ``func`` to take their gradients.
has_aux: optional, bool
Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
return_value : bool
Whether to return the loss value.
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default ``0``).
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. Default False.
dyn_vars : optional, ArrayType, sequence of ArrayType, dict
The dynamically changed variables used in ``func``.
.. deprecated:: 2.4.0
No longer need to provide ``dyn_vars``. This function is capable of automatically
collecting the dynamical variables used in the target ``func``.
child_objs: optional, BrainPyObject, sequence, dict
.. versionadded:: 2.3.1
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
Returns
-------
obj: GradientTransform
The transformed object.
"""
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
return GradientTransform(target=func,
transform=_jacfwd,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic))
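# Hessian computed as forward-over-reverse differentiation: ``jacfwd`` applied to ``jacrev``.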
def _functional_hessian(
fun: Callable,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: bool = False,
holomorphic: bool = False,
):
return _jacfwd(
_jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
argnums, has_aux=has_aux, holomorphic=holomorphic
)
class GradientTransformPreserveTree(ObjectTransform):
"""
Object-oriented Automatic Differentiation Transformation in BrainPy.
"""
def __init__(
self,
target: Callable,
transform: Callable,
# variables and nodes
grad_vars: Dict[str, Variable],
# gradient setting
argnums: Optional[Union[int, Sequence[int]]],
return_value: bool,
has_aux: bool,
transform_setting: Optional[Dict[str, Any]] = None,
# other
name: str = None,
):
super().__init__(name=name)
# gradient variables
if grad_vars is None:
grad_vars = dict()
assert isinstance(grad_vars, dict), 'grad_vars should be a dict'
new_grad_vars = {}
for k, v in grad_vars.items():
assert isinstance(v, Variable)
new_grad_vars[k] = v
self._grad_vars = new_grad_vars
# parameters
if argnums is None and len(self._grad_vars) == 0:
argnums = 0
if argnums is None:
assert len(self._grad_vars) > 0
_argnums = 0
elif isinstance(argnums, int):
_argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
else:
_argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
_argnums = tuple(a + 2 for a in _argnums)
if len(self._grad_vars) > 0:
_argnums = (0,) + _argnums
self._nonvar_argnums = argnums
self._argnums = _argnums
self._return_value = return_value
self._has_aux = has_aux
# target
self.target = target
# transform
self._eval_dyn_vars = False
self._grad_transform = transform
self._dyn_vars = VariableStack()
self._transform = None
self._grad_setting = dict() if transform_setting is None else transform_setting
if self._has_aux:
self._transform = self._grad_transform(
self._f_grad_with_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)
else:
self._transform = self._grad_transform(
self._f_grad_without_aux_to_transform,
argnums=self._argnums,
has_aux=True,
**self._grad_setting
)
def _f_grad_with_aux_to_transform(self,
grad_values: dict,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k]._value = dyn_values[k]
for k, v in grad_values.items():
self._grad_vars[k]._value = v
# Users should return the auxiliary data like::
# >>> # 1. example of return one data
# >>> return scalar_loss, data
# >>> # 2. example of return multiple data
# >>> return scalar_loss, (data1, data2, ...)
outputs = self.target(*args, **kwargs)
# outputs: [0] is the value for gradient,
# [1] is other values for return
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
return output0, (outputs, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())
def _f_grad_without_aux_to_transform(self,
grad_values: dict,
dyn_values: dict,
*args,
**kwargs):
for k in dyn_values.keys():
self._dyn_vars[k].value = dyn_values[k]
for k, v in grad_values.items():
self._grad_vars[k].value = v
# Users should return the scalar value like this::
# >>> return scalar_loss
output = self.target(*args, **kwargs)
output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
return output0, (output, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())
def __repr__(self):
name = self.__class__.__name__
f = tools.repr_object(self.target)
f = tools.repr_context(f, " " * (len(name) + 6))
format_ref = (f'{name}({self.name}, target={f}, \n' +
f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
return format_ref
def _return(self, rets):
grads, (outputs, new_grad_vs, new_dyn_vs) = rets
for k, v in new_grad_vs.items():
self._grad_vars[k].value = v
for k in new_dyn_vs.keys():
self._dyn_vars[k].value = new_dyn_vs[k]
# check returned grads
if len(self._grad_vars) > 0:
if self._nonvar_argnums is None:
pass
else:
arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
grads = (grads[0], arg_grads)
# check returned value
if self._return_value:
# check aux
if self._has_aux:
return grads, outputs[0], outputs[1]
else:
return grads, outputs
else:
# check aux
if self._has_aux:
return grads, outputs[1]
else:
return grads
def __call__(self, *args, **kwargs):
if jax.config.jax_disable_jit: # disable JIT
rets = self._transform(
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)
elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with VariableStack() as stack:
rets = eval_shape(
self._transform,
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
{}, # dynamical variables
*args,
**kwargs
)
cache_stack(self.target, stack)
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars.values()])
self._eval_dyn_vars = True
# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)
rets = self._transform(
{k: v.value for k, v in self._grad_vars.items()}, # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)
return self._return(rets)
def hessian(
func: Callable,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
has_aux: Optional[bool] = None,
holomorphic=False,
) -> ObjectTransform:
"""Hessian of ``func`` as a dense array.
Parameters
----------
func : callable, function
Function whose Hessian is to be computed. Its arguments at positions
specified by ``argnums`` should be arrays, scalars, or standard Python
containers thereof. It should return arrays, scalars, or standard Python
containers thereof.
grad_vars : optional, dict of Variable
The variables in ``func`` to take their gradients. If provided, it must be a dict of Variables.
argnums: Optional, integer or sequence of integers
Specifies which positional argument(s) to differentiate with respect to (default ``0``).
holomorphic : bool
Indicates whether ``fun`` is promised to be holomorphic. Default False.
has_aux : bool, optional
Indicates whether ``fun`` returns a pair where the first element is
considered the output of the mathematical function to be differentiated
and the second element is auxiliary data. Default False.
Returns
-------
obj: ObjectTransform
The transformed object.
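Examples
--------
A minimal sketch for a scalar function of a vector argument:
>>> import jax.numpy as jnp
>>> import brainpy.math as bm
>>>
>>> def loss(x):
>>>   return (x ** 2).sum()
>>>
>>> h = bm.hessian(loss)(jnp.ones(3))  # dense Hessian of shape (3, 3), here 2 * identity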
"""
return GradientTransformPreserveTree(target=func,
transform=jax.hessian,
grad_vars=grad_vars,
argnums=argnums,
has_aux=False if has_aux is None else has_aux,
transform_setting=dict(holomorphic=holomorphic),
return_value=False)
def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
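"""Functional counterpart of :func:`vector_grad` (no ``Variable``/``BrainPyObject`` handling).
The gradient of the vector-valued ``func`` is obtained from a single VJP seeded with an
all-ones cotangent for every output leaf, i.e. it equals the gradient of ``func(*args).sum()``
with respect to the arguments selected by ``argnums``.
"""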
_check_callable(func)
@wraps(func)
def grad_fun(*args, **kwargs):
f = linear_util.wrap_init(func, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
if has_aux:
y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
else:
y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
leaves, tree = tree_flatten(y)
tangents = tree_unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
grads = vjp_fn(tangents)
if isinstance(argnums, int):
grads = grads[0]
if has_aux:
return (grads, y, aux) if return_value else (grads, aux)
else:
return (grads, y) if return_value else grads
return grad_fun
def vector_grad(
func: Optional[Callable] = None,
grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
argnums: Optional[Union[int, Sequence[int]]] = None,
return_value: bool = False,
has_aux: Optional[bool] = None,
# deprecated
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> Union[Callable, ObjectTransform]:
"""Take vector-valued gradients for function ``func``.
As with `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_,
`brainpy.math.jacrev <./brainpy.math.autograd.jacrev.html>`_ and
`brainpy.math.jacfwd <./brainpy.math.autograd.jacfwd.html>`_,
the returned values of this function differ for different argument settings.
1. When "grad_vars" is None
- "has_aux=False" + "return_value=False" => ``arg_grads``.
- "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
2. When "grad_vars" is not None and "argnums" is None
- "has_aux=False" + "return_value=False" => ``var_grads``.
- "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
- "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
- "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
3. When "grad_vars" is not None and "argnums" is not None
- "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
- "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
- "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
- "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
Parameters
----------
func: Callable
Function whose gradient is to be computed.
grad_vars : optional, ArrayType, sequence of ArrayType, dict
The variables in ``func`` to take their gradients.
has_aux: optional, bool
Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
return_value : bool
Whether to return the loss value.
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default ``0``).
dyn_vars : optional, ArrayType, sequence of ArrayType, dict
The dynamically changed variables used in ``func``.
.. deprecated:: 2.4.0
No longer need to provide ``dyn_vars``. This function is capable of automatically
collecting the dynamical variables used in the target ``func``.
child_objs: optional, BrainPyObject, sequence, dict
.. versionadded:: 2.3.1
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
Returns
-------
func : GradientTransform
The vector gradient function.
"""
child_objs = check.is_all_objs(child_objs, out_as='dict')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
if func is None:
return lambda f: GradientTransform(target=f,
transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux)
else:
return GradientTransform(target=func,
transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
argnums=argnums,
return_value=return_value,
has_aux=False if has_aux is None else has_aux)
def _check_callable(fun):
# In Python 3.10+, the only thing stopping us from supporting staticmethods
# is that we can't take weak references to them, which the C++ JIT requires.
if isinstance(fun, staticmethod):
raise TypeError(f"staticmethod arguments are not supported, got {fun}")
if not callable(fun):
raise TypeError(f"Expected a callable value, got {fun}")
if _isgeneratorfunction(fun):
raise TypeError(f"Expected a function, got a generator function: {fun}")
def _isgeneratorfunction(fun):
# re-implemented here because of https://bugs.python.org/issue33261
while inspect.ismethod(fun):
fun = fun.__func__
while isinstance(fun, partial):
fun = fun.func
return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
def _check_arg(arg):
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
def _valid_jaxtype(arg):
try:
xla.abstractify(arg) # faster than core.get_aval
except TypeError:
return core.valid_jaxtype(arg)
else:
return True
def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with output element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
elif dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
"For differentiation of non-holomorphic functions involving complex "
"outputs, use jax.vjp directly.")
elif not dtypes.issubdtype(aval.dtype, np.floating):
raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For differentiation of functions with integer outputs, use "
"jax.vjp directly.")
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
if (dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
"If you want to use Boolean- or integer-valued inputs, use vjp "
"or set allow_int to True.")
elif not dtypes.issubdtype(aval.dtype, np.inexact):
raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
def _check_output_dtype_jacfwd(holomorphic, x):
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
f"dtype, but got {aval.dtype.name}.")
elif not dtypes.issubdtype(aval.dtype, np.floating):
raise TypeError("jacfwd requires real-valued inputs (input dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
"For differentiation of non-holomorphic functions involving "
"complex inputs or integer inputs, use jax.jvp directly.")