CPU Operator Customization with Numba


English version

Brain dynamics is sparse and event-driven. However, the specialized operators that brain dynamics requires have not been well abstracted or consolidated, so we often need to write custom operators. In this tutorial, we explore how to customize brain dynamics operators with Numba.

Start by importing the relevant Python packages.

import brainpy as bp
import brainpy.math as bm

import jax
from jax import jit
import jax.numpy as jnp
from jax.core import ShapedArray

import numba

bm.set_platform('cpu')

brainpy.math.CustomOpByNumba

BrainPy provides brainpy.math.CustomOpByNumba for customizing operators on the CPU. CustomOpByNumba requires two function arguments:

  • eval_shape: infers the shapes and dtypes of the outputs from the shapes and dtypes of the inputs.

  • con_compute: receives the actual input arrays and performs the computation on them.
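The following minimal pure-NumPy sketch shows how the two functions cooperate in the single-output case. It assumes the framework first calls eval_shape to learn the output spec, then allocates the output buffer, and finally lets con_compute fill that buffer in place. The helper emulate_op and the Spec tuple are hypothetical stand-ins for illustration only. They are not BrainPy's implementation, which runs con_compute through Numba and XLA.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for jax.core.ShapedArray: only shape and dtype.
Spec = namedtuple("Spec", ["shape", "dtype"])


def emulate_op(eval_shape, con_compute, *args):
    """Hypothetical driver mimicking the eval_shape / con_compute contract
    for a single output."""
    ins = [np.asarray(x) for x in args]
    # 1) Infer the output spec from abstract input information only.
    out_spec = eval_shape(*[Spec(x.shape, x.dtype) for x in ins])
    # 2) Allocate an uninitialized output buffer from that spec.
    out = np.empty(out_spec.shape, dtype=out_spec.dtype)
    # 3) Let con_compute write the result in place; its return value is ignored.
    con_compute(out, ins[0] if len(ins) == 1 else tuple(ins))
    return out


def eval_shape(a):
    return Spec(a.shape, a.dtype)  # b = a + 1 keeps a's shape and dtype


def con_compute(outs, ins):
    b, a = outs, ins
    b[:] = a + 1


print(emulate_op(eval_shape, con_compute, np.zeros(3)))  # [1. 1. 1.]
```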

Suppose we want to customize an operator that computes b = a + 1. First, define an eval_shape function. Its arguments carry the abstract information (shape and dtype) of every input, and its return value describes every output.

from jax.core import ShapedArray

def eval_shape(a):
  b = ShapedArray(a.shape, dtype=a.dtype)
  return b

Since b in b = a + 1 has the same shape and dtype as a, eval_shape returns a's shape and dtype unchanged. Next, we define con_compute. con_compute takes exactly two arguments, (outs, ins): outs holds all the outputs and ins holds all the inputs. With a single output and a single input, as in the code below, outs and ins are the arrays themselves.

def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

Unlike eval_shape, con_compute cannot return values; every output must be updated in place. In addition, con_compute must obey the rules of Numba's just-in-time compilation; see:

  • https://numba.pydata.org/numba-doc/latest/reference/pysupported.html

  • https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html
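The in-place requirement is easy to get wrong, because rebinding a name inside the function looks very similar to writing into the buffer. The NumPy-only sketch below involves no Numba and compares the two approaches. Both function names are illustrative.

```python
import numpy as np


def con_compute_rebind(outs, ins):
    b = outs
    a = ins
    b = a + 1  # WRONG: rebinds the local name; the output buffer is untouched


def con_compute_inplace(outs, ins):
    b = outs
    a = ins
    b[:] = a + 1  # RIGHT: writes into the preallocated output buffer


a = np.zeros(4)
out1 = np.zeros(4)
out2 = np.zeros(4)
con_compute_rebind(out1, a)
con_compute_inplace(out2, a)
print(out1)  # [0. 0. 0. 0.]
print(out2)  # [1. 1. 1. 1.]
```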

You can also control how Numba compiles con_compute. For example, to simply enable JIT compilation, use:

@numba.njit
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

To enable parallel computation across multiple cores, use:

@numba.njit(parallel=True)
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1
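With parallel=True, Numba parallelizes the array expression above automatically. An equivalent form writes the parallel loop explicitly with numba.prange, as sketched below. The snippet falls back to a plain Python decorator and range when Numba is not installed, so it still runs without Numba. That fallback is an illustrative assumption and executes serially.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback: run as plain (serial) Python without Numba
    prange = range

    def njit(*args, **kwargs):
        # Support both @njit and @njit(parallel=True) usage.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True)
def con_compute_prange(outs, ins):
    b = outs
    a = ins
    # Iterations are independent, so prange can split them across cores.
    for i in prange(a.shape[0]):
        b[i] = a[i] + 1


out = np.zeros(8)
con_compute_prange(out, np.arange(8.0))
print(out)  # [1. 2. 3. 4. 5. 6. 7. 8.]
```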

For more advanced usage, we encourage readers to read the Numba online manual.

Finally, this customized operator can be registered and used as:

>>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
>>> op(bm.zeros(10))
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

Returning multiple values: multiple_results=True

If the computation needs to return multiple arrays, pass multiple_results=True when registering the operator. In this case, outs is a list of arrays rather than a single array.

def eval_shape2(a, b):
  c = ShapedArray(a.shape, dtype=a.dtype)
  d = ShapedArray(b.shape, dtype=b.dtype)
  return c, d

def con_compute2(outs, ins):
  c = outs[0]  # take out all the outputs
  d = outs[1]
  a = ins[0]  # take out all the inputs
  b = ins[1]
  c[:] = a + 1
  d[:] = b * 2

>>> op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
>>> op2(bm.zeros(10), bm.ones(10))
([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])
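Because a two-output kernel of this form only reads ins and writes outs, you can sanity-check it outside BrainPy. Call it directly on preallocated NumPy buffers, here using the same inputs as the call above. The version below computes d from b, which agrees with the [2. ...] output shown. Calling the kernel from plain Python like this is a testing convenience and is not how BrainPy invokes it.

```python
import numpy as np


def con_compute2(outs, ins):
    c, d = outs  # two outputs
    a, b = ins   # two inputs
    c[:] = a + 1
    d[:] = b * 2


a, b = np.zeros(10), np.ones(10)
outs = [np.empty_like(a), np.empty_like(b)]  # buffers matching eval_shape2
con_compute2(outs, (a, b))
print(outs[0])  # [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
print(outs[1])  # [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
```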

Non-Tracer parameters

In eval_shape, every argument that jax.jit can trace arrives as abstract information containing only its shape and dtype. If inferring the outputs requires more than that (for example, a concrete output size), that extra information must be passed as a non-Tracer parameter.

For an operator defined with brainpy.math.CustomOpByNumba, non-Tracer parameters are usually those passed as keyword arguments, i.e. key=value pairs. For example:

op2(a, b, c, d=d, e=e)

Here a, b, and c are parameters traceable by jax.jit, while d and e are static, non-Tracer parameters. Therefore, inside eval_shape(a, b, c, d, e), a, b, and c are ShapedArray objects, while d and e are concrete values.

For another example,


def eval_shape3(a, *, b):
  return ShapedArray((b,), a.dtype)  # the output shape is determined by the static input b

def con_compute3(outs, ins):
  c = outs  # Take out all the outputs
  a = ins[0] # Take out all inputs
  b = ins[1]
  c[:] = 2.

>>> op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)
>>> op3(bm.zeros(4), b=5)
[2. 2. 2. 2. 2.]
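b has to be a non-Tracer parameter because an output shape must be a concrete integer before any buffer can be allocated, and a traced value carries only shape and dtype. The NumPy sketch below mimics this. It uses a hypothetical Spec tuple in place of ShapedArray and is an illustration, not BrainPy internals.

```python
from collections import namedtuple

import numpy as np

Spec = namedtuple("Spec", ["shape", "dtype"])  # hypothetical ShapedArray stand-in


def eval_shape3(a, *, b):
    # b is a concrete Python int here, so it can be used to build a shape.
    return Spec((b,), a.dtype)


def con_compute3(outs, ins):
    c = outs
    c[:] = 2.0


a = np.zeros(4)
spec = eval_shape3(Spec(a.shape, a.dtype), b=5)
out = np.empty(spec.shape, spec.dtype)
con_compute3(out, (a, np.asarray(5)))
print(out.shape, out)  # (5,) [2. 2. 2. 2. 2.]
```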

Note: all arguments are converted to arrays, so both Tracer and non-Tracer parameters arrive in con_compute as arrays. For example, if 1 is passed in, con_compute sees the 0-dimensional array 1; if (1, 2) is passed in, con_compute sees the 1-dimensional array array([1, 2]).
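np.asarray reproduces the conversion described in the note. Whether BrainPy uses exactly this call internally is an assumption. The snippet also shows how to recover a plain scalar from a 0-dimensional array, which is what weight[()] does in the XLACustomOp kernel later in this tutorial.

```python
import numpy as np

# Illustration of the conversion described in the note:
# every argument reaches con_compute as an array.
scalar = np.asarray(1)
pair = np.asarray((1, 2))

print(scalar.ndim, pair.ndim)  # 0 1
print(repr(pair))              # array([1, 2])

# Indexing a 0-d array with an empty tuple extracts its value.
print(scalar[()] + 1)          # 2
```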

Example: a sparse operator

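To illustrate this approach, we define a sparse operator. For each row i of a CSR connectivity structure, it adds values * vector[i] to every output position listed in that row, where values is a single scalar weight shared by all connections.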

def abs_eval(data, indices, indptr, vector, shape):
  out_shape = shape[0]
  return ShapedArray((out_shape,), data.dtype),

@numba.njit(fastmath=True)
def sparse_op(outs, ins):
  res_val = outs[0]
  res_val.fill(0)
  values, col_indices, row_ptr, vector, shape = ins

  for row_i in range(shape[0]):
      v = vector[row_i]
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
          res_val[col_indices[j]] += values * v

sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)
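Comparing a compiled kernel against a slow but obviously correct version is a good way to validate it. The sketch below is a hypothetical NumPy reference (sparse_op_reference) for the same computation as sparse_op. Row i of the CSR structure lists the columns that row i contributes to, so the result equals weight * A.T @ vector, where A is the dense connection-count matrix.

```python
import numpy as np


def sparse_op_reference(weight, col_indices, row_ptr, vector, n_cols):
    """Hypothetical NumPy reference for sparse_op: res[col] += weight * vector[row]."""
    res = np.zeros(n_cols, dtype=np.result_type(weight, vector))
    for row_i in range(len(row_ptr) - 1):
        cols = col_indices[row_ptr[row_i]:row_ptr[row_i + 1]]
        # np.add.at accumulates correctly even when a column repeats
        # (duplicate connections, as with allow_multi_conn=True).
        np.add.at(res, cols, weight * vector[row_i])
    return res


# A 3x4 CSR pattern; row 0 connects to column 1 twice.
col_indices = np.array([1, 1, 3, 0, 2])
row_ptr = np.array([0, 3, 4, 5])
vector = np.array([1.0, 2.0, 3.0])

# Dense equivalent: entry (row, col) counts how often col appears in row.
dense = np.zeros((3, 4))
for row_i in range(3):
    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        dense[row_i, col_indices[j]] += 1.0

res = sparse_op_reference(0.5, col_indices, row_ptr, vector, 4)
print(res.tolist())                              # [1.0, 1.0, 1.5, 0.5]
print(np.allclose(res, 0.5 * dense.T @ vector))  # True
```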

Let’s try this sparse matrix-vector multiplication operator.

size = 5000

vector = bm.random.randn(size)
sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')
f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))
f(1.)
[Array([ -2.2834747, -52.950108 ,  -5.0921535, ..., -40.264236 ,
        -27.219269 ,  33.138054 ], dtype=float32)]
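The call above passes sparse_A[0] as the column indices and sparse_A[1] as the row pointers, which is a CSR layout of the pre-to-post connections. Assuming that layout, the NumPy sketch below builds the same pair of arrays from a dense boolean connection matrix, which is useful for constructing small test cases. The helper name dense_to_pre2post is hypothetical. A boolean matrix cannot express duplicate connections, so this covers only the allow_multi_conn=False case.

```python
import numpy as np


def dense_to_pre2post(conn):
    """Hypothetical helper: CSR (indices, indptr) from a boolean matrix."""
    rows, cols = np.nonzero(conn)  # row-major order, so cols are grouped by row
    indptr = np.zeros(conn.shape[0] + 1, dtype=np.int64)
    # Count nonzeros per row, then take a running sum to get row offsets.
    np.add.at(indptr, rows + 1, 1)
    indptr = np.cumsum(indptr)
    return cols.astype(np.int64), indptr


conn = np.array([[0, 1, 0, 1],
                 [0, 0, 0, 0],
                 [1, 0, 1, 0]], dtype=bool)
indices, indptr = dense_to_pre2post(conn)
print(indices.tolist())  # [1, 3, 0, 2]
print(indptr.tolist())   # [0, 2, 2, 4]
```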

brainpy.math.XLACustomOp

brainpy.math.XLACustomOp is a newer interface for customizing operators on the CPU. It is similar to brainpy.math.CustomOpByNumba but more flexible, and it supports more advanced features. To use it with Numba, define a kernel with @numba.jit or @numba.njit and pass that kernel to brainpy.math.XLACustomOp.

Detailed steps are as follows:

Define the kernel

@numba.njit(fastmath=True)
def numba_event_csrmv(weight, indices, vector, outs):
  outs.fill(0)
  weight = weight[()]  # 0d
  for row_i in range(vector.shape[0]):
    if vector[row_i]:
      for j in indices[row_i]:
        outs[j] += weight

In the kernel signature, the output parameters must come last, after all input parameters, so that Numba can compile the kernel correctly. The kernel numba_event_csrmv takes four parameters: weight, indices, vector, and outs. The first three are inputs and the last is the output. The output outs is a 1-D array. Among the inputs, weight is 0-D, indices is 2-D (row i lists the output positions that row i contributes to), and vector is 1-D.
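The kernel visits only rows whose vector entry is true. This is what makes it event-driven: inactive rows cost nothing. For testing, the same result can be computed with the vectorized NumPy sketch below. The helper event_csrmv_reference is hypothetical and not part of BrainPy. As in the tutorial's example, the sketch assumes that the output length equals the number of rows.

```python
import numpy as np


def event_csrmv_reference(weight, indices, vector):
    """Hypothetical NumPy reference for numba_event_csrmv."""
    n_out = vector.shape[0]  # assumption: square connectivity, as in the example
    out = np.zeros(n_out, dtype=np.float32)
    active_rows = np.nonzero(vector)[0]  # only rows with an event contribute
    np.add.at(out, indices[active_rows].ravel(), weight)
    return out


rng = np.random.default_rng(0)
s = 100
indices = rng.integers(0, s, size=(s, 8))
vector = rng.random(s) < 0.1

# Explicit loops, mirroring the kernel body line by line.
expected = np.zeros(s, dtype=np.float32)
for row_i in range(s):
    if vector[row_i]:
        for j in indices[row_i]:
            expected[j] += 1.0

print(np.allclose(event_csrmv_reference(1.0, indices, vector), expected))  # True
```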

Registering and using custom operators

After defining a kernel, you can register it as an operator and use it wherever needed. When registering, you can specify cpu_kernel and gpu_kernel so that the operator can run on different devices. When calling the operator, pass the outs argument, using jax.ShapeDtypeStruct to describe the shape and dtype of each output.

Note: pass the arguments in the same order in which the kernel declares its input parameters.
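The NumPy sketch below shows why the order matters. It follows the convention assumed in the text: inputs are passed positionally in call order, and the buffers described by outs are appended after them. The driver call_with_outs and the kernel scale_kernel are hypothetical. BrainPy itself describes outputs with jax.ShapeDtypeStruct rather than with (shape, dtype) pairs.

```python
import numpy as np


def call_with_outs(kernel, *inputs, outs):
    """Hypothetical driver: allocate outputs from (shape, dtype) specs and pass
    them after the inputs, matching the kernel's parameter order."""
    out_arrays = [np.zeros(shape, dtype=dtype) for shape, dtype in outs]
    kernel(*[np.asarray(x) for x in inputs], *out_arrays)
    return out_arrays


def scale_kernel(weight, vector, out):  # inputs first, output last
    out[:] = weight[()] * vector  # weight arrives as a 0-d array


(res,) = call_with_outs(scale_kernel, 2.0, np.arange(3.0), outs=[((3,), np.float64)])
print(res.tolist())  # [0.0, 2.0, 4.0]
```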

s = 1000  # number of neurons (example value)
prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)
indices = bm.random.randint(0, s, (s, 80))
vector = bm.random.rand(s) < 0.1
out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
print(out)

Chinese version

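Brain dynamics is sparse and event-driven. However, the dedicated operators for brain dynamics have not been well abstracted or consolidated, so we often need to write custom operators. In this tutorial, we explore how to use Numba to customize brain dynamics operators.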

First, import the relevant Python packages.

import brainpy as bp
import brainpy.math as bm

import jax
from jax import jit
import jax.numpy as jnp
from jax.core import ShapedArray

import numba

bm.set_platform('cpu')

The brainpy.math.CustomOpByNumba interface

brainpy.math.CustomOpByNumba is also called brainpy.math.XLACustomOp.

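BrainPy provides brainpy.math.CustomOpByNumba for customizing operators on the CPU. Using CustomOpByNumba requires two functions:

  • eval_shape: infers the shapes and dtypes of the outputs from the shapes and dtypes of the inputs.

  • con_compute: receives the actual arguments and performs the concrete computation with them.

Suppose we want to customize an operator that computes b = a + 1. First, define an eval_shape function. Its arguments describe all the input variables, and its return value describes the outputs.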

from jax.core import ShapedArray

def eval_shape(a):
  b = ShapedArray(a.shape, dtype=a.dtype)
  return b

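Since b in b = a + 1 has the same dtype and shape as a, eval_shape returns the same shape and dtype. Next, we need to define con_compute. con_compute takes only the arguments (outs, ins), where all outputs are in outs and all inputs are in ins.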


def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

eval_shape函数不同,con_compute函数不接收任何返回值。相反,所有的输出都必须通过in-place update的形式就行。另外,con_compute函数必须遵循Numba即时编译的规范,见:

  • https://numba.pydata.org/numba-doc/latest/reference/pysupported.html

  • https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

You can also choose Numba's JIT compilation options for con_compute. For example, to simply enable JIT compilation, use:

@numba.njit
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

To enable parallel computation across multiple cores, use:

@numba.njit(parallel=True)
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

For more advanced usage, we recommend reading the Numba online manual.

Finally, this custom operator can be registered and used as:

>>> op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
>>> op(bm.zeros(10))
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

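Returning multiple values: multiple_results=True

If the computation needs to return multiple arrays, we must pass multiple_results=True when registering the operator. In this case, outs is a list containing multiple arrays rather than a single array.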

def eval_shape2(a, b):
  c = ShapedArray(a.shape, dtype=a.dtype)
  d = ShapedArray(b.shape, dtype=b.dtype)
  return c, d  # return the abstract information of multiple arrays

def con_compute2(outs, ins):
  c = outs[0]  # take out all the outputs
  d = outs[1]
  a = ins[0]  # take out all the inputs
  b = ins[1]
  c[:] = a + 1
  d[:] = b * 2

>>> op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
>>> op2(bm.zeros(10), bm.ones(10))
([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.],
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.])

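Non-Tracer parameters

In eval_shape, if all arguments can be traced by jax.jit, they are all abstract information containing only shape and dtype. Sometimes, however, inferring the output data type requires information beyond the inputs; in that case we need to define non-Tracer parameters.

For an operator defined with brainpy.math.CustomOpByNumba, non-Tracer parameters are usually the ones passed as key-value pairs such as key=value. For example,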

op2(a, b, c, d=d, e=e)

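Here a, b, and c are parameters traceable by jax.jit, while d and e are static, non-Tracer parameters. In this case, inside eval_shape(a, b, c, d, e), a, b, and c are ShapedArray objects, while d and e are concrete values.

For example,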


def eval_shape3(a, *, b):
  return ShapedArray((b,), a.dtype)  # the output shape is determined by the input b

def con_compute3(outs, ins):
  c = outs  # take out all the outputs
  a = ins[0]  # take out all the inputs
  b = ins[1]
  c[:] = 2.

>>> op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)
>>> op3(bm.zeros(4), b=5)
[2. 2. 2. 2. 2.]

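Note: all input values are converted to arrays, so both Tracer and non-Tracer parameters are arrays in con_compute. For example, if 1 is passed in, con_compute sees the 0-dimensional array 1; if (1, 2) is passed in, con_compute sees the 1-dimensional array array([1, 2]).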

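Example: a sparse operator

To illustrate this approach, we define a sparse operator. For each row i of a CSR connectivity structure, it adds values * vector[i] to every output position listed in that row, where values is a single scalar weight shared by all connections.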

def abs_eval(data, indices, indptr, vector, shape):
  out_shape = shape[0]
  return [ShapedArray((out_shape,), data.dtype)]

@numba.njit(fastmath=True)
def sparse_op(outs, ins):
  res_val = outs[0]
  res_val.fill(0)
  values, col_indices, row_ptr, vector, shape = ins

  for row_i in range(shape[0]):
      v = vector[row_i]
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
          res_val[col_indices[j]] += values * v

sparse_cus_op = bm.CustomOpByNumba(eval_shape=abs_eval, con_compute=sparse_op)

We can use this operator as follows:

size = 5000

vector = bm.random.randn(size)
sparse_A = bp.conn.FixedProb(prob=0.1, allow_multi_conn=True)(size, size).require('pre2post')
f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))
f(1.)
[Array([ 17.464092,  -9.924386, -33.09052 , ..., -37.2057  , -12.551924,
         -9.046049], dtype=float32)]

brainpy.math.XLACustomOp

brainpy.math.XLACustomOp is a newer interface for customizing operators on the CPU. It is similar to brainpy.math.CustomOpByNumba but more flexible, and it supports more advanced features. To use it with Numba, define a kernel with the @numba.jit or @numba.njit decorator and pass that kernel to brainpy.math.XLACustomOp.
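The most visible difference between the two interfaces is the kernel signature. CustomOpByNumba passes (outs, ins), whereas the XLACustomOp kernels in this tutorial take the inputs followed by the outputs as separate parameters. The plain-Python sketch below makes that difference explicit. It uses a hypothetical helper, to_inputs_then_outputs, to wrap an old-style con_compute in the new calling convention. The sketch only demonstrates the argument mapping and is not meant to be passed to XLACustomOp. A Numba kernel for XLACustomOp should be written directly in the new style.

```python
import numpy as np


def to_inputs_then_outputs(con_compute, n_inputs):
    """Hypothetical adapter: wrap a con_compute(outs, ins) function so it can be
    called as kernel(*ins, *outs)."""
    def kernel(*args):
        ins, outs = args[:n_inputs], args[n_inputs:]
        con_compute(outs[0] if len(outs) == 1 else list(outs),
                    ins[0] if len(ins) == 1 else tuple(ins))
    return kernel


def con_compute(outs, ins):  # CustomOpByNumba style, from the b = a + 1 example
    b, a = outs, ins
    b[:] = a + 1


kernel = to_inputs_then_outputs(con_compute, n_inputs=1)
out = np.zeros(3)
kernel(np.arange(3.0), out)  # inputs first, output last
print(out.tolist())  # [1.0, 2.0, 3.0]
```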

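The detailed steps are as follows:

Define the kernel

In the kernel signature, the output parameters must come last so that Numba can compile the kernel correctly. The kernel numba_event_csrmv takes four parameters: weight, indices, vector, and outs. The first three are inputs and the last is the output. The output outs is a 1-D array. Among the inputs, weight is 0-D, indices is 2-D, and vector is 1-D.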

@numba.njit(fastmath=True)
def numba_event_csrmv(weight, indices, vector, outs):
  outs.fill(0)
  weight = weight[()]  # 0d
  for row_i in range(vector.shape[0]):
    if vector[row_i]:
      for j in indices[row_i]:
        outs[j] += weight

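Registering and using custom operators

After defining a custom operator, you can register it with a specific framework and use it wherever needed. When registering, you can specify cpu_kernel and gpu_kernel so that the operator can run on different devices. When calling, specify the outs argument, using jax.ShapeDtypeStruct to describe the shape and dtype of the outputs.

Note: the order of arguments at call time must match the order of the parameters declared by the operator.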

s = 1000  # number of neurons (example value)
prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)
indices = bm.random.randint(0, s, (s, 80))
vector = bm.random.rand(s) < 0.1
out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
print(out)