jacfwd


class brainpy.math.jacfwd(func, grad_vars=None, argnums=None, has_aux=None, return_value=False, holomorphic=False, dyn_vars=None, child_objs=None)

Extending automatic Jacobian (forward-mode) of func to classes.

This function extends the official JAX jacfwd to support automatic Jacobian computation on both functions and class methods. In addition, it can return the function value (“return_value”) and auxiliary data (“has_aux”).
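Since brainpy.math.jacfwd builds on the official jax.jacfwd, the plain-JAX call below sketches the underlying forward-mode semantics (a minimal, brainpy-free illustration):

```python
import jax
import jax.numpy as jnp

# Forward-mode Jacobian of an elementwise square.
def f(x):
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])
jac = jax.jacfwd(f)(x)
# The Jacobian of an elementwise op is diagonal: diag(2x) = diag([2, 4, 6]).
```

Forward mode pushes one tangent vector per input dimension through the computation, so it is efficient when the input dimension is small relative to the output dimension.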

As with brainpy.math.grad, the return structure of brainpy.math.jacfwd depends on the argument settings.

  1. When “grad_vars” is None

  • “has_aux=False” + “return_value=False” => arg_grads.

  • “has_aux=True” + “return_value=False” => (arg_grads, aux_data).

  • “has_aux=False” + “return_value=True” => (arg_grads, loss_value).

  • “has_aux=True” + “return_value=True” => (arg_grads, loss_value, aux_data).

  2. When “grad_vars” is not None and “argnums” is None

  • “has_aux=False” + “return_value=False” => var_grads.

  • “has_aux=True” + “return_value=False” => (var_grads, aux_data).

  • “has_aux=False” + “return_value=True” => (var_grads, loss_value).

  • “has_aux=True” + “return_value=True” => (var_grads, loss_value, aux_data).

  3. When “grad_vars” is not None and “argnums” is not None

  • “has_aux=False” + “return_value=False” => (var_grads, arg_grads).

  • “has_aux=True” + “return_value=False” => ((var_grads, arg_grads), aux_data).

  • “has_aux=False” + “return_value=True” => ((var_grads, arg_grads), loss_value).

  • “has_aux=True” + “return_value=True” => ((var_grads, arg_grads), loss_value, aux_data).
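The “has_aux” convention above follows jax.jacfwd, which brainpy.math.jacfwd extends: the wrapped function returns a pair (output, aux), and the transform returns the Jacobian together with the auxiliary data. A minimal plain-JAX sketch:

```python
import jax
import jax.numpy as jnp

# With has_aux=True, f must return (output, auxiliary_data);
# the transform then returns (jacobian, auxiliary_data).
def f(x):
    y = jnp.sin(x)
    return y, {"input_norm": jnp.linalg.norm(x)}

x = jnp.array([0.0, 1.0])
jac, aux = jax.jacfwd(f, has_aux=True)(x)
# jac is the 2x2 Jacobian of sin (diag(cos(x))); aux is the dict, untouched.
```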

Parameters:
  • func (callable) – The function whose Jacobian is to be computed.

  • grad_vars (optional, ArrayType, sequence of ArrayType, dict) – The variables in func to take their gradients.

  • has_aux (optional, bool) – Indicates whether func returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • return_value (bool) – Whether to return the loss value. Default False.

  • argnums (optional, int or sequence of int) – Specifies which positional argument(s) to differentiate with respect to (default 0).

  • holomorphic (optional, bool) – Indicates whether func is promised to be holomorphic. Default False.

  • dyn_vars (optional, ArrayType, sequence of ArrayType, dict) –

    The dynamically changed variables used in func.

    Deprecated since version 2.4.0: No longer need to provide dyn_vars. This function is capable of automatically collecting the dynamical variables used in the target func.

  • child_objs (optional, BrainPyObject, sequence of BrainPyObject, dict) –

    Added in version 2.3.1.

    Deprecated since version 2.4.0: No longer need to provide child_objs. This function is capable of automatically collecting the children objects used in the target func.

Returns:

obj – The transformed object.

Return type:

GradientTransform
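The argnums parameter keeps the meaning it has in jax.jacfwd: passing a sequence of argument indices yields one Jacobian per selected argument. A plain-JAX sketch of that behavior:

```python
import jax
import jax.numpy as jnp

# Differentiate with respect to both positional arguments of f.
def f(x, y):
    return x * y

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
jx, jy = jax.jacfwd(f, argnums=(0, 1))(x, y)
# jx = d(x*y)/dx = diag(y); jy = d(x*y)/dy = diag(x).
```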