Source code for brainpy._src.math.op_register.ad_support
import functools
from functools import partial
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad
__all__ = [
'defjvp',
]
[docs]
def defjvp(primitive, *jvp_rules):
"""Define JVP rules for any JAX primitive.
This function is similar to ``jax.interpreters.ad.defjvp``.
However, the JAX one only supports primitive with ``multiple_results=False``.
``brainpy.math.defjvp`` enables to define the independent JVP rule for
each input parameter no matter ``multiple_results=False/True``.
For examples, please see ``test_ad_support.py``.
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
else:
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
assert primitive.multiple_results
val_out = tuple(primitive.bind(*primals, **params))
tree = tree_util.tree_structure(val_out)
tangents_out = []
for rule, t in zip(jvp_rules, tangents):
if rule is not None and type(t) is not ad.Zero:
r = tuple(rule(t, *primals, **params))
tangents_out.append(r)
assert tree_util.tree_structure(r) == tree
return val_out, functools.reduce(_add_tangents,
tangents_out,
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
def _add_tangents(xs, ys):
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))