brainpy.math.defjvp

brainpy.math.defjvp(primitive, *jvp_rules)

Define JVP rules for any JAX primitive.

This function is similar to jax.interpreters.ad.defjvp. However, the JAX version only supports primitives with multiple_results=False. brainpy.math.defjvp lets you define an independent JVP rule for each input parameter, whether multiple_results is False or True.

For examples, please see test_ad_support.py.

Parameters:
  • primitive – A JAX Primitive or an XLACustomOp for which the JVP rules are defined.

  • *jvp_rules – The JVP rules, one per primal input, given in the order of the primitive's inputs.
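
The idea of one JVP rule per input, combined across several outputs, can be sketched in plain Python. This is an illustration only, not BrainPy's or JAX's implementation: sincos_mul, jvp_wrt_x, jvp_wrt_y and combined_jvp are hypothetical names, and the rule signature rule(tangent, *primals) returning one tangent per output is an assumption modeled on the convention used by jax.interpreters.ad.defjvp.

```python
import math


def sincos_mul(x, y):
    """Hypothetical two-output operation: (sin(x) * y, cos(x) * y)."""
    return math.sin(x) * y, math.cos(x) * y


# One rule per input. Each rule receives the tangent of "its" input plus all
# primals, and returns one tangent contribution per output.
def jvp_wrt_x(t, x, y):
    return math.cos(x) * y * t, -math.sin(x) * y * t


def jvp_wrt_y(t, x, y):
    return math.sin(x) * t, math.cos(x) * t


def combined_jvp(fn, rules, primals, tangents):
    """Evaluate fn and sum the per-input tangent contributions per output."""
    outs = fn(*primals)
    total = [0.0] * len(outs)
    for rule, t in zip(rules, tangents):
        if rule is None:  # no rule registered for this input: skip it
            continue
        for i, contribution in enumerate(rule(t, *primals)):
            total[i] += contribution
    return outs, tuple(total)


# Push the tangent (1.0, 0.5) through the operation at the point (0.3, 2.0).
outs, tangents_out = combined_jvp(
    sincos_mul, (jvp_wrt_x, jvp_wrt_y), (0.3, 2.0), (1.0, 0.5)
)
```

Because each rule only describes the contribution of a single input, the rules stay independent of each other; the combining step adds them up output by output, which is the part that a single-result helper does not need to handle.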