brainpy.math.operators.register_op

brainpy.math.operators.register_op(op_name, cpu_func, gpu_func=None, out_shapes=None, apply_cpu_func_to_gpu=False)

Convert a numba-jitted function into a JAX/XLA-compatible primitive.

Parameters
  • op_name (str) – Name of the operator.

  • cpu_func (Callable) – A numba-jitted function or a pure Python function (a lambda is allowed) that runs on the CPU.

  • gpu_func (Callable, default = None) – A CUDA-jitted kernel that runs on the GPU.

  • out_shapes (Callable, ShapedArray, Sequence[ShapedArray], default = None) – Output shapes of the target function. out_shapes can be a ShapedArray or a sequence of ShapedArray. If it is a callable, it takes the argument shapes and dtypes as input and must return the corresponding output ShapedArray(s).

  • apply_cpu_func_to_gpu (bool, default = False) – Set to True when the GPU version of the operator is implemented on the CPU, while the remaining logic (data transfer) is handled on the GPU.

Return type

A jittable JAX function.
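As a hedged illustration of the callable form of out_shapes, the sketch below assumes the callable receives one ShapedArray (shape and dtype only, no data) per operator argument and returns a ShapedArray describing the output. The row-sum operator and the names row_sum_out_shapes and reference_row_sum are hypothetical and used only for this example. The sketch does not call register_op itself; it only checks that the declared output shape agrees with what an equivalent plain JAX computation produces.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray


def row_sum_out_shapes(x):
    """Hypothetical out_shapes callable for an operator that sums each row of x.

    Assumption: `x` is the abstract value (a ShapedArray) of the operator's
    single argument, so only its shape and dtype are available here.
    """
    # One output element per row, with the same dtype as the input.
    return ShapedArray(x.shape[:1], x.dtype)


def reference_row_sum(x):
    # Plain JAX reference, used only to cross-check the declared shape.
    return jnp.sum(x, axis=1)


if __name__ == "__main__":
    arg = ShapedArray((4, 3), jnp.float32)
    declared = row_sum_out_shapes(arg)
    # jax.eval_shape traces the reference without computing any data.
    actual = jax.eval_shape(
        reference_row_sum, jax.ShapeDtypeStruct((4, 3), jnp.float32)
    )
    print(declared.shape, declared.dtype)
    assert declared.shape == actual.shape
    assert declared.dtype == actual.dtype
```

Declaring shapes this way lets JAX trace and jit the operator without running cpu_func, which is why the callable must never depend on the argument values.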