CPU Operator Customization with Numba#


Brain dynamics is sparse and event-driven; however, dedicated operators for brain dynamics have not been well abstracted and summarized in existing frameworks. As a result, we often need to customize operators ourselves. In this tutorial, we will explore how to customize brain dynamics operators using Numba.

Start by importing the relevant Python packages.

import brainpy as bp
import brainpy.math as bm

import jax
from jax import jit
import jax.numpy as jnp
from jax.core import ShapedArray

import numba

bm.set_platform('cpu')


brainpy.math.CustomOpByNumba#

BrainPy provides brainpy.math.CustomOpByNumba for customizing operators on the CPU device. Two functions must be provided to CustomOpByNumba:

• eval_shape: evaluates the shapes and data types of the output arguments based on the shapes and data types of the input arguments.

• con_compute: receives the actual input parameters and performs the concrete computation on them.

Suppose we want to customize an operator that performs the operation b = a + 1. First, define an eval_shape function. The arguments of this function carry the abstract information (shape and dtype) of all input parameters, and the return value is the abstract information of the output parameters.

from jax.core import ShapedArray

def eval_shape(a):
    b = ShapedArray(a.shape, dtype=a.dtype)
    return b


Since b in b = a + 1 has the same shape and dtype as a, the eval_shape function returns a ShapedArray with the same shape and dtype. Next, we need to define con_compute. con_compute takes exactly two arguments, (outs, ins), where all output arrays are packed in outs and all input arrays are packed in ins.

def con_compute(outs, ins):
    b = outs
    a = ins
    b[:] = a + 1


Unlike the eval_shape function, the con_compute function does not return anything. Instead, all outputs must be updated in place. Also, the con_compute function must follow the specifications of Numba’s just-in-time compilation; see:

• https://numba.pydata.org/numba-doc/latest/reference/pysupported.html

• https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

Moreover, con_compute can be decorated according to Numba’s just-in-time compilation options. For example, to simply turn on JIT compilation, you can use:

@numba.njit
def con_compute(outs, ins):
    b = outs
    a = ins
    b[:] = a + 1


If you want to turn on parallel computation with multiple cores, you can use:

@numba.njit(parallel=True)
def con_compute(outs, ins):
    b = outs
    a = ins
    b[:] = a + 1
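
Note that parallel=True only pays off when the kernel body is written with parallel loops. As a minimal sketch (assuming a 1-D input), an explicitly parallel variant of the same kernel could use numba.prange:

@numba.njit(parallel=True)
def con_compute_parallel(outs, ins):
    b = outs
    a = ins
    # numba.prange marks this loop as parallelizable across CPU cores
    for i in numba.prange(a.shape[0]):
        b[i] = a[i] + 1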


Finally, this customized operator can be registered and used as:

>>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
>>> op(bm.zeros(10))
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


Returning multiple values with multiple_results=True#

If our computation needs to return multiple arrays, we must pass multiple_results=True when registering the operator. In this case, outs will be a sequence containing multiple arrays rather than a single array.

def eval_shape2(a, b):
    c = ShapedArray(a.shape, dtype=a.dtype)
    d = ShapedArray(b.shape, dtype=b.dtype)
    return c, d  # return the abstract information of multiple outputs

def con_compute2(outs, ins):
    c = outs[0]  # take out all the outputs
    d = outs[1]
    a = ins[0]  # take out all the inputs
    b = ins[1]
    c[:] = a + 1
    d[:] = b * 2

op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)

>>> op2(bm.zeros(10), bm.ones(10))
([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])


Non-Tracer parameters#

In the eval_shape function, all arguments that can be traced by jax.jit carry only abstract information (shape and dtype). However, if inferring the output’s shape or dtype requires additional information beyond this, we need to define non-Tracer parameters.

For an operator defined by brainpy.math.CustomOpByNumba, non-Tracer parameters are the parameters passed in as key-value pairs, such as key=value. For example:

op2(a, b, c, d=d, e=e)


a, b, c are all jax.jit-traceable parameters, while d and e are concrete, non-Tracer parameters. Therefore, in the eval_shape(a, b, c, d, e) function, a, b, and c will be ShapedArray instances, while d and e will be concrete values.

For another example,


def eval_shape3(a, *, b):
    # the shape of the return value is determined by the input b
    return ShapedArray((b,), a.dtype)

def con_compute3(outs, ins):
    c = outs    # take out all the outputs
    a = ins[0]  # take out all the inputs
    b = ins[1]
    c[:] = 2.

op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)

>>> op3(bm.zeros(4), b=5)
[2. 2. 2. 2. 2.]


Note: all arguments will be converted to arrays inside con_compute. Both Tracer and non-Tracer parameters arrive there as arrays. For example, if 1 is passed in, con_compute receives a 0-dimensional array holding 1; if (1, 2) is passed in, con_compute receives the 1-dimensional array array([1, 2]).
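
To make this concrete, here is a minimal sketch (with a hypothetical operator op4 and keyword parameter factor, which are not part of the examples above) showing a non-Tracer scalar arriving as a 0-dimensional array:

def eval_shape4(a, *, factor):
    return ShapedArray(a.shape, a.dtype)

def con_compute4(outs, ins):
    out = outs
    a = ins[0]
    factor = ins[1]          # the Python scalar 3 arrives here as a 0-d array
    out[:] = a * factor[()]  # factor[()] extracts the scalar from the 0-d array

op4 = bm.CustomOpByNumba(eval_shape4, con_compute4, multiple_results=False)
op4(bm.ones(3), factor=3)  # expected: [3. 3. 3.]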


Example: A sparse operator#

To illustrate the effectiveness of this approach, in this section we define an event-driven sparse matrix-vector multiplication operator.

def abs_eval(data, indices, indptr, vector, shape):
    out_shape = shape[0]
    return [ShapedArray((out_shape,), data.dtype)]

@numba.njit(fastmath=True)
def sparse_op(outs, ins):
    res_val = outs[0]
    res_val.fill(0)
    values, col_indices, row_ptr, vector, shape = ins

    for row_i in range(shape[0]):
        v = vector[row_i]
        for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
            # 'values' is the 0-d homogeneous weight array passed at the call site
            res_val[col_indices[j]] += values * v

sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)


Let’s try this sparse matrix-vector multiplication operator.

size = 5000

vector = bm.random.randn(size)
sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')
f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))
f(1.)

[Array([ -2.2834747, -52.950108 ,  -5.0921535, ..., -40.264236 ,
-27.219269 ,  33.138054 ], dtype=float32)]
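
As a sanity check, the event accumulation above is equivalent to multiplying the transposed CSR matrix with the vector. A sketch, assuming scipy is available (it is not otherwise used in this tutorial):

import numpy as np
from scipy.sparse import csr_matrix

# Build the same matrix in scipy's CSR format, with a homogeneous weight of 1.0
data = np.ones(np.asarray(sparse_A[0]).shape[0], dtype=np.float32)
A = csr_matrix((data, np.asarray(sparse_A[0]), np.asarray(sparse_A[1])),
               shape=(size, size))
expected = A.T @ np.asarray(vector)  # res[c] = sum over rows r of A[r, c] * vector[r]
print(np.allclose(expected, np.asarray(f(1.)[0]), rtol=1e-3, atol=1e-3))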


brainpy.math.XLACustomOp#

brainpy.math.XLACustomOp is a newer method for customizing operators. It is similar to brainpy.math.CustomOpByNumba, but is more flexible and supports more advanced features. To use this method with Numba on the CPU, you only need to define a kernel using the @numba.jit or @numba.njit decorator and then pass the kernel to brainpy.math.XLACustomOp.

Detailed steps are as follows:

Define the kernel#

@numba.njit(fastmath=True)
def numba_event_csrmv(weight, indices, vector, outs):
    outs.fill(0)
    weight = weight[()]  # extract the scalar from the 0-d weight array
    for row_i in range(vector.shape[0]):
        if vector[row_i]:
            for j in indices[row_i]:
                outs[j] += weight


When declaring the kernel’s parameters, the output parameters must come last so that the operator can be compiled correctly. This numba_event_csrmv kernel receives four parameters: weight, indices, vector, and outs. The first three are input parameters and the last is the output parameter. Here weight is a 0D array (a scalar weight), indices is a 2D array, vector is a 1D array, and outs is the 1D output array.

Registering and Using Custom Operators#

After defining a custom operator, it can be registered and then used where needed. When registering, you can specify cpu_kernel and gpu_kernel so that the operator can run on different devices. When calling the operator, specify the outs parameter with jax.ShapeDtypeStruct to define the shape and data type of each output.

Note: keep the order of the arguments at the call site consistent with the order of the kernel’s declared parameters.

s = 5000  # size of the pre/post populations ('s' was undefined in the original snippet; any positive integer works)
prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)
indices = bm.random.randint(0, s, (s, 80))
vector = bm.random.rand(s) < 0.1
out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
print(out)
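
Like operators registered with CustomOpByNumba, the resulting primitive composes with jax.jit. A minimal sketch, assuming the definitions above:

@jit
def run(weight):
    # 'indices' and 'vector' are closed over as constants; 'weight' is traced
    return prim(weight, indices, vector,
                outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])[0]

print(run(2.))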

