brainpy.math.operators.register_op
- brainpy.math.operators.register_op(op_name, cpu_func, gpu_func=None, out_shapes=None, apply_cpu_func_to_gpu=False)
Convert a Numba-jitted function into a JAX/XLA-compatible primitive.
- Parameters
op_name (str) – Name of the operator.
cpu_func (Callable) – A Numba-jitted function, or a pure Python function (a lambda is allowed), that runs on the CPU.
gpu_func (Callable, default = None) – A CUDA-jitted kernel that runs on the GPU.
out_shapes (Callable, ShapedArray, Sequence[ShapedArray], default = None) – Output shapes of the target function. out_shapes can be a single ShapedArray or a sequence of ShapedArray. If it is a callable, it receives the shapes and dtypes of the arguments and must return the corresponding output ShapedArray(s).
apply_cpu_func_to_gpu (bool, default = False) – Set to True when the GPU version of the operator reuses cpu_func for the computation itself, while the remaining logic (data transfer) is handled on the GPU side.
- Return type
A jittable JAX function.
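To illustrate the idea of turning a host-side function with declared output shapes into a jittable JAX function, here is a minimal sketch built on jax.pure_callback. This is a swapped-in mechanism, not how register_op is implemented internally, and it does not call BrainPy at all. The names cpu_add_one, out_shapes and make_op are hypothetical. jax.ShapeDtypeStruct stands in for ShapedArray, and only the CPU path is covered (there is no gpu_func). On a GPU backend, pure_callback copies operands to the host before running the callback, which loosely resembles apply_cpu_func_to_gpu=True.

```python
import jax
import jax.numpy as jnp
import numpy as np


def cpu_add_one(x):
    # Hypothetical host-side implementation (plain NumPy, runs on the CPU).
    x = np.asarray(x)
    return (x + 1).astype(x.dtype)


def out_shapes(x):
    # Hypothetical analogue of a callable out_shapes: given the argument's
    # shape and dtype, describe the abstract shape and dtype of the result.
    return jax.ShapeDtypeStruct(x.shape, x.dtype)


def make_op(cpu_func, out_shapes):
    """Hypothetical helper: wrap a host function as a jittable JAX function."""
    def op(*args):
        # At trace time the args are tracers, which expose .shape and .dtype.
        result_spec = out_shapes(*args)
        # pure_callback runs cpu_func on the host and checks its output
        # against result_spec.
        return jax.pure_callback(cpu_func, result_spec, *args)
    return op


add_one = jax.jit(make_op(cpu_add_one, out_shapes))
y = add_one(jnp.arange(4, dtype=jnp.float32))
print(y)
```

Declaring the output shapes up front is what allows JAX to trace and compile the wrapped function without running the host code. register_op's out_shapes parameter plays the same role.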