XLACustomOp

class brainpy.math.XLACustomOp(cpu_kernel=None, gpu_kernel=None, batching_translation=None, jvp_translation=None, transpose_translation=None, outs=None, name=None)

Create an XLA custom call operator.

>>> import numba as nb
>>> import taichi as ti
>>> import numpy as np
>>> import jax
>>> from brainpy.math import XLACustomOp
>>>
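>>> @nb.njit
... def numba_cpu_fun(a, b, out_a, out_b):
...     out_a[:] = a
...     out_b[:] = b
>>>
>>> @ti.kernel
... def taichi_gpu_fun(a: ti.types.ndarray(ndim=1),
...                    b: ti.types.ndarray(ndim=1),
...                    out_a: ti.types.ndarray(ndim=1),
...                    out_b: ti.types.ndarray(ndim=1)):
...     for i in range(a.shape[0]):
...         out_a[i] = a[i]
...     for i in range(b.shape[0]):
...         out_b[i] = b[i]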
>>>
>>> # option 1: specify the outputs when calling the operator
>>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun)
>>> a2, b2 = prim(np.random.random(1000), np.random.random(1000),
...               outs=[jax.ShapeDtypeStruct((1000,), dtype=np.float32),
...                     jax.ShapeDtypeStruct((1000,), dtype=np.float32)])
>>>
>>> # option 2: give the constructor a callable that infers the outputs from the inputs
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
...                     outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
...                                                  jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))

Parameters:
  • cpu_kernel (Optional[Callable]) – The function that defines the computation on the CPU backend.

  • gpu_kernel (Optional[Callable]) – The function that defines the computation on the GPU backend.

  • batching_translation (Optional[Callable]) – The JAX batching rule (used when the operator is vectorized, e.g. with jax.vmap).

  • jvp_translation (Optional[Callable]) – The JAX JVP (forward-mode differentiation) rule.

  • transpose_translation (Optional[Callable]) – The JAX transpose rule (used for reverse-mode differentiation).

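  • outs (Optional[Callable]) – Optional. The output information. It can be a callable that receives the operator's inputs (and keyword arguments) and returns the output shapes and dtypes, for example a list of jax.ShapeDtypeStruct (see option 2 above). If it is omitted, outs must be passed when the operator is called (see option 1 above).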

  • name (Optional[str]) – The name of the primitive.

def_abstract_eval(fun)

Define the abstract evaluation function.

Parameters:

fun – The abstract evaluation function.

def_batching_rule(fun)

Define the batching rule.

Parameters:

fun – The batching rule.

def_jvp_rule(fun)

Define the JVP rule.

Parameters:

fun – The JVP rule.

def_mlir_lowering(platform, fun)

Define the MLIR lowering rule.

Parameters:
  • platform – str. The computing platform to register the lowering rule for (e.g. 'cpu' or 'gpu').

  • fun – The lowering rule.

def_transpose_rule(fun)

Define the transpose rule.

Parameters:

fun – The transpose rule.

def_xla_translation(platform, fun)

Define the XLA translation rule.

Parameters:
  • platform – str. The computing platform to register the translation rule for (e.g. 'cpu' or 'gpu').

  • fun – The XLA translation rule.

defjvp(*jvp_rules)

Define the JVP rules. Similar to jax.interpreters.ad.defjvp, but supports primitives that return multiple results.

Parameters:

jvp_rules – The JVP rules.