brainpy.math.defjvp
- brainpy.math.defjvp(primitive, *jvp_rules)
Define JVP rules for any JAX primitive.
This function is similar to jax.interpreters.ad.defjvp. However, the JAX version only supports primitives with multiple_results=False. brainpy.math.defjvp lets you define an independent JVP rule for each input parameter, whether the primitive has multiple_results=False or multiple_results=True.
For examples, see test_ad_support.py.
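- Parameters:
primitive – The primitive to define JVP rules for: a JAX Primitive or an XLACustomOp.
*jvp_rules – The JVP translation rules, one per primal (input) of the primitive, given in the same order as the primitive's inputs.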
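The idea of one JVP rule per primal can be illustrated with a plain-Python sketch. This is not brainpy's implementation. The helper combine_jvp and the toy functions f, jvp_x and jvp_y are hypothetical. The rule signature rule(tangent_i, *primals) follows the convention of jax.interpreters.ad.defjvp, minus the keyword parameters. The element-wise summation for multiple results is an assumption about how the per-input contributions are merged.

```python
# Plain-Python sketch (hypothetical helper, not brainpy's implementation) of
# combining one JVP rule per primal input. Rule signature mirrors
# jax.interpreters.ad.defjvp: rule(tangent_i, *primals) -> contribution of input i.
# ASSUMPTION: for multiple_results=True, each rule returns one contribution per
# output, and contributions are summed output by output.


def combine_jvp(fun, jvp_rules, primals, tangents, multiple_results):
    """Return (primal_out, tangent_out); a tangent of None means 'zero'."""
    primal_out = fun(*primals)
    # Each rule only sees the tangent of its own input, plus all primal values.
    contributions = [
        rule(t, *primals)
        for rule, t in zip(jvp_rules, tangents)
        if rule is not None and t is not None
    ]
    if multiple_results:
        if not contributions:
            return primal_out, [0.0] * len(primal_out)
        # Sum the i-th output's contributions across all inputs.
        return primal_out, [sum(parts) for parts in zip(*contributions)]
    # Single result: every rule returns one value; add them up.
    return primal_out, sum(contributions, 0.0)


# Toy two-output function f(x, y) = (x * y, x + y) with one rule per input.
def f(x, y):
    return (x * y, x + y)


def jvp_x(tx, x, y):
    # d/dx of (x * y, x + y), scaled by the tangent tx
    return (tx * y, tx)


def jvp_y(ty, x, y):
    # d/dy of (x * y, x + y), scaled by the tangent ty
    return (x * ty, ty)


out, tan = combine_jvp(f, [jvp_x, jvp_y], (2.0, 3.0), (1.0, 1.0), multiple_results=True)
print(out, tan)  # (6.0, 5.0) [5.0, 2.0]
```

Keeping the rules independent means each one only has to describe how its own input perturbs every output. With multiple_results=False, each rule returns a single contribution, which is the case jax.interpreters.ad.defjvp already covers.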