XLACustomOp
- class brainpy.math.XLACustomOp(cpu_kernel=None, gpu_kernel=None, batching_translation=None, jvp_translation=None, transpose_translation=None, outs=None, name=None)
Creates an XLA custom call operator.
For more information, please refer to the following tutorials:
- Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
- Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
- CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html
- Parameters:
cpu_kernel (Optional[Callable]) – Callable. The function that defines the computation on the CPU backend.
gpu_kernel (Union[Callable, str, None]) – Callable. The function that defines the computation on the GPU backend.
batching_translation (Optional[Callable]) – Callable. The batching translation rule of JAX.
jvp_translation (Optional[Callable]) – Callable. The JVP translation rule of JAX.
transpose_translation (Optional[Callable]) – Callable. The transpose translation rule of JAX.
outs (Optional[Callable]) – optional. The output information.
- def_abstract_eval(fun)
Define the abstract evaluation function.
- Parameters:
fun – The abstract evaluation function.
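To illustrate what an abstract evaluation function typically does, here is a minimal pure-Python sketch. It does not use BrainPy or JAX: `AbstractArray` and `matvec_abstract_eval` are hypothetical names invented for this example, and the exact arguments BrainPy passes to `fun` are not documented here. The idea shown is only that such a function infers output shapes and dtypes from input metadata, without touching any data.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AbstractArray:
    """Hypothetical stand-in for an abstract value: shape and dtype, no data."""
    shape: Tuple[int, ...]
    dtype: str


def matvec_abstract_eval(matrix: AbstractArray, vector: AbstractArray) -> AbstractArray:
    """Infer the output of a matrix-vector product from input metadata only."""
    if len(matrix.shape) != 2 or len(vector.shape) != 1:
        raise ValueError("expected a 2-D matrix and a 1-D vector")
    if matrix.shape[1] != vector.shape[0]:
        raise ValueError("inner dimensions do not match")
    if matrix.dtype != vector.dtype:
        raise ValueError("dtypes must match in this sketch")
    # The result has one entry per matrix row and keeps the input dtype.
    return AbstractArray(shape=(matrix.shape[0],), dtype=matrix.dtype)


out = matvec_abstract_eval(
    AbstractArray((3, 4), "float32"),
    AbstractArray((4,), "float32"),
)
print(out)
```

In a real operator, a function of this kind would be the one handed to `def_abstract_eval`, with the framework's own abstract array type in place of the hypothetical `AbstractArray`.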
- def_mlir_lowering(platform, fun)
Define the MLIR lowering rule.
- Parameters:
platform – str. The computing platform.
fun – The lowering rule.
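The pairing of a platform name with a lowering rule can be pictured as a per-platform registry. The sketch below is a pure-Python illustration under that assumption, not BrainPy's implementation: `LoweringRegistry`, `def_lowering`, and `lower` are hypothetical names, the platform strings `"cpu"` and `"gpu"` are assumed values, and the rules return strings instead of real MLIR operations.

```python
from typing import Callable, Dict


class LoweringRegistry:
    """Hypothetical sketch: stores one lowering rule per platform name."""

    def __init__(self) -> None:
        self._rules: Dict[str, Callable] = {}

    def def_lowering(self, platform: str, fun: Callable) -> None:
        """Register (or overwrite) the rule used for the given platform."""
        if not callable(fun):
            raise TypeError("fun must be callable")
        self._rules[platform] = fun

    def lower(self, platform: str, *args):
        """Dispatch to the rule registered for the platform."""
        try:
            rule = self._rules[platform]
        except KeyError:
            raise NotImplementedError(
                f"no lowering rule for platform {platform!r}"
            ) from None
        return rule(*args)


registry = LoweringRegistry()
# Placeholder rules: real lowering rules would emit MLIR, not strings.
registry.def_lowering("cpu", lambda x: f"cpu_call({x})")
registry.def_lowering("gpu", lambda x: f"gpu_call({x})")

result_cpu = registry.lower("cpu", "operand")
print(result_cpu)
```

Keeping one rule per platform means a missing backend fails with a clear error at dispatch time instead of silently falling back to another device.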