XLACustomOp

class brainpy.math.XLACustomOp(cpu_kernel=None, gpu_kernel=None, batching_translation=None, jvp_translation=None, transpose_translation=None, outs=None, name=None)

Create an XLA custom call operator.

For more information, please refer to the following tutorials:

  • Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html

  • Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html

  • CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html

Parameters:
  • cpu_kernel (Optional[Callable]) – The function that defines the computation on the CPU backend.

  • gpu_kernel (Union[Callable, str, None]) – The kernel that defines the computation on the GPU backend, given as a callable or a string.

  • batching_translation (Optional[Callable]) – The JAX batching rule, used by jax.vmap.

  • jvp_translation (Optional[Callable]) – The JAX JVP rule, used for forward-mode differentiation.

  • transpose_translation (Optional[Callable]) – The JAX transpose rule, used for reverse-mode differentiation.

  • outs (Optional[Callable]) – Information about the operator's outputs (e.g., their shapes and dtypes).

  • name (Optional[str]) – The name of the primitive.
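As a rough illustration of what a `cpu_kernel` computes, the sketch below is a hypothetical element-wise addition kernel (`add_vectors_kernel` is an illustrative name, not part of brainpy). It assumes the convention used in the Numba tutorial: inputs come first, followed by preallocated output arrays that the kernel fills in place. It is written in plain Python so it runs without Numba.

```python
import numpy as np

# Hypothetical CPU kernel (illustrative only). Assumed convention: inputs first,
# then a preallocated output array that the kernel writes into in place.
# In practice such a kernel would typically be decorated with @numba.njit.
def add_vectors_kernel(x, y, out):
    for i in range(x.shape[0]):
        out[i] = x[i] + y[i]

x = np.arange(4, dtype=np.float32)
y = np.ones(4, dtype=np.float32)
out = np.empty_like(x)
add_vectors_kernel(x, y, out)
print(out)  # [1. 2. 3. 4.]
```

Such a function would then be passed as `cpu_kernel`; the Numba tutorial linked above shows how the output shapes and dtypes are declared when the operator is called.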

def_abstract_eval(fun)

Define the abstract evaluation function, which computes the shapes and dtypes of the outputs from those of the inputs without executing the kernel.

Parameters:

fun – The abstract evaluation function.
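A minimal sketch of an abstract evaluation function for the hypothetical vector-addition operator: it looks only at `.shape` and `.dtype`, never at array values. For illustration it returns `jax.ShapeDtypeStruct`; a JAX primitive's abstract evaluation rule conventionally returns `jax.core.ShapedArray`, so the exact return type expected here is an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical abstract evaluation rule for z = x + y (illustrative only).
# It inspects shapes and dtypes only; no computation on values happens here.
def add_vectors_abstract_eval(x, y):
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    dtype = jnp.result_type(x.dtype, y.dtype)
    return jax.ShapeDtypeStruct(x.shape, dtype)

x = jax.ShapeDtypeStruct((4,), jnp.float32)
y = jax.ShapeDtypeStruct((4,), jnp.float32)
result = add_vectors_abstract_eval(x, y)
print(result)
```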

def_batching_rule(fun)

Define the batching rule, which tells jax.vmap how to evaluate the operator on inputs that carry an extra batch dimension.

Parameters:

fun – The batching rule.
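The sketch below assumes the rule follows JAX's batching-rule convention, `rule(batched_args, batch_dims, **params) -> (out, out_batch_dim)`. The operator here is a hypothetical row sum (`rowsum`, illustrative only): the rule moves the mapped axis to the front and, because the op reduces only the last axis, simply applies it to the batched array.

```python
import jax
import jax.numpy as jnp

# Reference computation of a hypothetical operator: sum over the last axis.
def rowsum(x):
    return jnp.sum(x, axis=-1)

# Hypothetical batching rule in the JAX style:
# (batched_args, batch_dims) -> (batched_output, output_batch_dim)
def rowsum_batching_rule(batched_args, batch_dims, **params):
    (x,), (bdim,) = batched_args, batch_dims
    # Move the mapped axis to position 0 so it becomes a leading batch axis.
    x = jnp.moveaxis(x, bdim, 0)
    # rowsum only reduces the last axis, so leading axes pass through untouched.
    return rowsum(x), 0

x = jnp.arange(24.0).reshape(2, 3, 4)
out, out_dim = rowsum_batching_rule((x,), (1,))
print(out.shape, out_dim)  # (3, 2) 0
```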

def_jvp_rule(fun)

Define the JVP (Jacobian-vector product) rule, used for forward-mode differentiation.

Parameters:

fun – The JVP rule.
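A minimal sketch, assuming the rule follows JAX's convention `rule(primals, tangents, **params) -> (primal_out, tangent_out)`. The operator is a hypothetical element-wise product. The sketch assumes both tangents are dense arrays; JAX may pass symbolic zeros for inputs that are not being differentiated, which a production rule must handle.

```python
import jax
import jax.numpy as jnp

# Hypothetical JVP rule for z = x * y (illustrative only).
# Assumes both tangents are dense arrays (no symbolic zeros).
def mul_jvp_rule(primals, tangents, **params):
    x, y = primals
    tx, ty = tangents
    primal_out = x * y
    # Product rule: d(x * y) = dx * y + x * dy
    tangent_out = tx * y + x * ty
    return primal_out, tangent_out

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
tx = jnp.ones(3)
ty = jnp.array([0.5, 0.0, -1.0])
out, t_out = mul_jvp_rule((x, y), (tx, ty))
print(out, t_out)
```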

def_mlir_lowering(platform, fun)

Define the MLIR lowering rule for the given platform.

Parameters:
  • platform – str. The computing platform.

  • fun – The lowering rule.
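A lowering rule emits MLIR (StableHLO, in recent JAX versions) for the target platform. Writing one depends on JAX internals, so the sketch below only shows the kind of module a lowering produces: it lowers a hypothetical reference computation with the public `jax.jit(...).lower(...)` API and prints the resulting text. It does not use XLACustomOp.

```python
import jax
import jax.numpy as jnp

# Reference computation standing in for a custom operator (hypothetical).
def scale_and_shift(x):
    return 2.0 * x + 1.0

# Lower (without compiling) for the default backend and get the MLIR module text.
lowered = jax.jit(scale_and_shift).lower(jnp.ones(3, dtype=jnp.float32))
mlir_text = lowered.as_text()
print(jax.default_backend())  # the platform the module was lowered for
print(mlir_text)
```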

def_transpose_rule(fun)

Define the transpose rule, which JAX uses to turn the operator's linear (JVP) form into reverse-mode derivatives.

Parameters:

fun – The transpose rule.
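A minimal sketch, assuming JAX's transpose-rule convention: the rule receives the output cotangent plus the operator's arguments and returns one cotangent per argument, with `None` for arguments the operator is not linear in. In JAX the linear argument arrives as an `ad.UndefinedPrimal` placeholder; this sketch ignores its value. The operator is a hypothetical matrix-vector product `y = A @ x`, linear in `x`.

```python
import jax
import jax.numpy as jnp

# Hypothetical transpose rule for y = A @ x, which is linear in x (illustrative only).
# Returns one cotangent per argument: None for A, A.T @ ct for x.
def matvec_transpose_rule(ct, A, x, **params):
    # The value of the linear argument x is not needed; only A is used.
    return None, A.T @ ct

A = jnp.arange(6.0).reshape(2, 3)
ct = jnp.array([1.0, -1.0])
_, ct_x = matvec_transpose_rule(ct, A, None)
print(ct_x)
```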

def_xla_translation(platform, fun)

Define the XLA translation rule for the given platform.

Parameters:
  • platform – str. The computing platform.

  • fun – The XLA translation rule.

defjvp(*jvp_rules)

Define per-argument JVP rules. Similar to jax.interpreters.ad.defjvp, but supports primitives with multiple results.

Parameters:

jvp_rules – The JVP rules.
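The sketch below illustrates the idea under an assumption modeled on `jax.interpreters.ad.defjvp`: one rule per input, each receiving that input's tangent `g` and all primals. For a multi-result operator, each rule here returns its contribution to every result as a tuple, and contributions are summed result by result. The operator and the helper `jvp_from_rules` are hypothetical and only demonstrate the combination step.

```python
import jax
import jax.numpy as jnp

# Reference operator with two inputs and two results (hypothetical).
def mul_and_add(x, y):
    return x * y, x + y

# One rule per input: tangent contribution of that input to each result.
def jvp_wrt_x(g, x, y):
    return g * y, g

def jvp_wrt_y(g, x, y):
    return x * g, g

# Hypothetical helper: sum per-argument contributions, result by result.
def jvp_from_rules(fun, rules, primals, tangents):
    outs = fun(*primals)
    totals = [jnp.zeros_like(o) for o in outs]
    for rule, g in zip(rules, tangents):
        if g is None:  # treat a missing tangent as zero
            continue
        contributions = rule(g, *primals)
        totals = [t + c for t, c in zip(totals, contributions)]
    return outs, tuple(totals)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, -1.0])
tx = jnp.array([0.5, 1.0])
ty = jnp.array([2.0, 0.0])
outs, touts = jvp_from_rules(mul_and_add, (jvp_wrt_x, jvp_wrt_y), (x, y), (tx, ty))
print(outs, touts)
```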