Source code for brainpy._src.math.op_register.base

from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

import jax
import numpy as np
from jax.interpreters import xla, batching, ad, mlir

from brainpy._src.dependency_check import import_numba, import_cupy_jit
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject

# NOTE: compare version components numerically; a plain string comparison would
# wrongly treat e.g. '0.4.9' as newer than '0.4.16'.
if tuple(int(x) for x in jax.__version__.split('.')[:3]) >= (0, 4, 16):
  from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                            register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
  from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                            register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

numba = import_numba(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)

__all__ = [
  'XLACustomOp',
]


class ShapeDtype(Protocol):

  @property
  def shape(self) -> Tuple[int, ...]:
    ...

  @property
  def dtype(self) -> np.dtype:
    ...
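
Any object exposing ``shape`` and ``dtype`` attributes satisfies this protocol; for instance (an illustration only, not part of this module), ``jax.ShapeDtypeStruct`` describes an output without allocating memory:

out_info = jax.ShapeDtypeStruct((128,), np.float32)  # satisfies ShapeDtype structurally
assert out_info.shape == (128,) and out_info.dtype == np.float32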


class XLACustomOp(BrainPyObject):
  """Create an XLA custom-call operator.

  For more information, please refer to the tutorials below:

  - Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
  - Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
  - CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html

  Args:
    cpu_kernel: Callable. The function that defines the computation on the CPU backend.
    gpu_kernel: Callable. The function that defines the computation on the GPU backend.
    batching_translation: Callable. The batching translation rule of JAX.
    jvp_translation: Callable. The JVP translation rule of JAX.
    transpose_translation: Callable. The transpose translation rule of JAX.
    outs: optional, Callable. A function that computes the output information
      (a sequence of objects with ``shape`` and ``dtype``) from the call inputs.
    name: str. The primitive name.
  """

  def __init__(
      self,
      cpu_kernel: Callable = None,
      gpu_kernel: Union[Callable, str] = None,
      batching_translation: Callable = None,
      jvp_translation: Callable = None,
      transpose_translation: Callable = None,
      outs: Optional[Callable] = None,
      name: str = None,
  ):
    super().__init__(name)

    # set cpu_kernel and gpu_kernel
    self.cpu_kernel = cpu_kernel
    self.gpu_kernel = gpu_kernel

    # primitive
    self.primitive = jax.core.Primitive(self.name)
    self.primitive.multiple_results = True

    # abstract evaluation
    self.outs = outs
    self.primitive.def_abstract_eval(_abstract_eval)
    self.primitive.def_impl(partial(xla.apply_primitive, self.primitive))

    # cpu function
    cpu_checked = False
    if cpu_kernel is None:
      cpu_checked = True
    if numba is not None:  # numba
      from numba.core.dispatcher import Dispatcher
      if isinstance(cpu_kernel, Dispatcher):
        register_numba_cpu_translation_rule(self.primitive, cpu_kernel)
        cpu_checked = True
    if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel:  # taichi
      register_taichi_cpu_translation_rule(self.primitive, cpu_kernel)
      cpu_checked = True
    if not cpu_checked:
      raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. '
                       f'But we got {cpu_kernel}')

    # gpu function
    gpu_checked = False
    if gpu_kernel is None:
      gpu_checked = True
    elif hasattr(gpu_kernel, 'kernel'):  # cupy RawModule
      register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_mode'):  # cupy JIT Kernel
      register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
      register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    if not gpu_checked:
      raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. '
                       f'But we got {gpu_kernel}')

    # batching rule
    if batching_translation is None:
      register_general_batching(self.primitive)
    else:
      batching.primitive_batchers[self.primitive] = batching_translation

    # jvp rule
    if jvp_translation is not None:
      ad.primitive_jvps[self.primitive] = jvp_translation

    # transpose rule
    if transpose_translation is not None:
      ad.primitive_transposes[self.primitive] = transpose_translation

  def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
    if outs is None:
      if self.outs is None:
        raise ValueError('The output information is not defined.')
      outs = self.outs(*ins, **kwargs)
    assert outs is not None
    outs = tuple([_transform_to_shapedarray(o) for o in outs])
    ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
    return self.primitive.bind(*ins, outs=outs, **kwargs)
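A minimal usage sketch (not part of the module source), assuming Taichi is installed and that Taichi kernels are accepted as in the constructor checks above; the kernel, sizes, and variable names below are purely illustrative:

import jax.numpy as jnp
import taichi as ti

@ti.kernel
def add_vectors(x: ti.types.ndarray(ndim=1),
                y: ti.types.ndarray(ndim=1),
                out: ti.types.ndarray(ndim=1)):
  # Element-wise addition written as a Taichi kernel.
  for i in range(out.shape[0]):
    out[i] = x[i] + y[i]

# The same Taichi kernel can serve both backends; outputs are described at
# call time with jax.ShapeDtypeStruct objects (which satisfy ShapeDtype).
add_op = XLACustomOp(cpu_kernel=add_vectors, gpu_kernel=add_vectors)
x = jnp.ones(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
z, = add_op(x, y, outs=[jax.ShapeDtypeStruct((8,), jnp.float32)])

Because the primitive declares ``multiple_results``, the call always returns a sequence of outputs, even when there is only one.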
  def def_abstract_eval(self, fun):
    """Define the abstract evaluation function.

    Args:
      fun: The abstract evaluation function.
    """
    self.primitive.def_abstract_eval(fun)

  def def_batching_rule(self, fun):
    """Define the batching rule.

    Args:
      fun: The batching rule.
    """
    batching.primitive_batchers[self.primitive] = fun

  def def_jvp_rule(self, fun):
    """Define the JVP rule.

    Args:
      fun: The JVP rule.
    """
    ad.primitive_jvps[self.primitive] = fun
  def defjvp(self, *jvp_rules):
    """Define the JVP rules.

    Similar to ``jax.interpreters.ad.defjvp``, but supports primitives with multiple results.

    Args:
      jvp_rules: The JVP rules.
    """
    defjvp(self.primitive, *jvp_rules)
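As a sketch of the signatures involved (hypothetical functions, reusing ``add_op`` from the sketch above): ``def_jvp_rule`` expects a standard JAX JVP rule mapping primals and tangents to output primals and output tangents, while ``defjvp`` takes one rule per positional input in the style of ``jax.interpreters.ad.defjvp``, extended to multi-result primitives.

def add_jvp(primals, tangents, **params):
  # Standard JAX convention: return (output primals, output tangents).
  # Note: a production rule should also handle symbolic-zero tangents (ad.Zero).
  dx, dy = tangents
  out_primals = add_op.primitive.bind(*primals, **params)
  return out_primals, [dx + dy]

add_op.def_jvp_rule(add_jvp)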
  def def_transpose_rule(self, fun):
    """Define the transpose rule.

    Args:
      fun: The transpose rule.
    """
    ad.primitive_transposes[self.primitive] = fun

  def def_xla_translation(self, platform, fun):
    """Define the XLA translation rule.

    Args:
      platform: str. The computing platform.
      fun: The XLA translation rule.
    """
    xla.backend_specific_translations[platform][self.primitive] = fun

  def def_mlir_lowering(self, platform, fun):
    """Define the MLIR lowering rule.

    Args:
      platform: str. The computing platform.
      fun: The lowering rule.
    """
    mlir.register_lowering(self.primitive, fun, platform)
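For example (a sketch; ``add_reference`` is a hypothetical pure-JAX reference implementation), ``jax.interpreters.mlir.lower_fun`` can turn an ordinary function into a lowering rule for a specific platform:

def add_reference(x, y, *, outs):
  # Pure-JAX fallback used only to generate the CPU lowering.
  return [x + y]

add_op.def_mlir_lowering('cpu', mlir.lower_fun(add_reference, multiple_results=True))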
def _abstract_eval(*args, **kwargs):
  return [jax.core.ShapedArray(out_shape.shape, out_shape.dtype)
          for out_shape in kwargs['outs']]


def _is_bp_array(a):
  return isinstance(a, Array)


def _transform_to_array(a):
  if isinstance(a, Array):
    return a.value
  elif isinstance(a, jax.Array):
    return a
  else:
    return jax.numpy.asarray(a)


def _transform_to_shapedarray(a):
  return jax.core.ShapedArray(a.shape, a.dtype)
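When the output layout can be derived from the inputs, the constructor's ``outs`` callable can carry that logic so callers need not pass ``outs`` on every call (a sketch reusing the hypothetical ``add_vectors`` kernel and inputs from above):

def infer_outs(x, y, **kwargs):
  # The output mirrors the shape and dtype of the first operand.
  return [jax.ShapeDtypeStruct(x.shape, x.dtype)]

add_op2 = XLACustomOp(cpu_kernel=add_vectors, gpu_kernel=add_vectors, outs=infer_outs)
z, = add_op2(x, y)  # no per-call `outs` needed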