CPU and GPU Operator Customization with Taichi#


This functionality is only available for brainpylib>=0.2.0.

English version#

Brain dynamics computations are sparse and event-driven. However, the dedicated operators they require have not yet been well abstracted or standardized, so we often need to write custom operators ourselves. In this tutorial, we explore how to customize brain dynamics operators with Taichi.

Start by importing the relevant Python packages.

import brainpy.math as bm

import jax
import jax.numpy as jnp
import pytest
import platform

import taichi as ti


Basic Structure of Custom Operators#

Taichi uses Python functions and decorators to define custom operators. Here is a basic structure of a custom operator:

@ti.kernel
def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):
    # Internal logic of the operator
    pass

The @ti.kernel decorator tells Taichi that this function is a kernel: Taichi compiles it for the target backend (CPU or GPU) instead of executing it as ordinary Python.

Defining Helper Functions#

When defining complex custom operators, you can use the @ti.func decorator to define helper functions. These functions can be called from inside a kernel function (or from another @ti.func), but not directly from Python:

@ti.func
def helper_func(x: ti.f32) -> ti.f32:
    # Auxiliary computation
    return x * 2

@ti.kernel
def my_kernel(arg: ti.types.ndarray()):
    for i in ti.ndrange(arg.shape[0]):
        arg[i] *= helper_func(arg[i])

Example: Custom Event Processing Operator#

The following example demonstrates how to customize an event processing operator:

@ti.func
def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:
    return weight[None]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
    out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=0),
                  out: ti.types.ndarray(ndim=1)):
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    for i in range(num_rows):
        if vector[i]:
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)

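When declaring the parameters, the output arrays must come last so that Taichi can compile the operator correctly: the leading parameters are the inputs, and the trailing ones are the outputs the operator writes into. The operator event_ell_cpu takes an index array, an event vector, and a weight as inputs, plus an output array. For every row i whose event vector[i] is set, it adds the weight to out at each position listed in indices[i, :].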
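To make these semantics concrete, here is a NumPy reference of the same computation. The function event_ell_reference is hypothetical and not part of BrainPy; a reference like this can be handy for checking a custom operator's output in tests. It uses np.add.at so that repeated target indices accumulate, matching the += in update_output.

```python
import numpy as np

def event_ell_reference(indices, vector, weight, num_out):
    """Hypothetical reference: every row whose event fired adds `weight`
    to each output position listed in that row of `indices`."""
    out = np.zeros(num_out, dtype=np.float32)
    active_rows = indices[np.asarray(vector, dtype=bool)]  # rows with an event
    # np.add.at accumulates correctly when a target index repeats
    np.add.at(out, active_rows.ravel(), np.float32(weight))
    return out

indices = np.array([[0, 1], [1, 1], [2, 0]])
vector = np.array([True, False, True])
print(event_ell_reference(indices, vector, 1.0, 3))  # [2. 1. 1.]
```

Row 1 has no event, so its two writes to position 1 are skipped; position 0 is hit once by row 0 and once by row 2.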

Registering and Using Custom Operators#

After defining a custom operator, you can register it with BrainPy and use it wherever it is needed. When registering, you can specify both cpu_kernel and gpu_kernel so that the operator can run on either device. When calling the operator, pass the outs parameter, using jax.ShapeDtypeStruct to describe the shape and data type of each output.

Note: The order of the arguments at call time must match the order of the parameters declared in the operator.

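import jax
import jax.numpy as jnp
import brainpy.math as bm

# Taichi operator registration (event_ell_gpu is defined in the complete example below)
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

# Using the operator
def test_taichi_op():
    # Create input data
    # ...

    # Call the custom operator: inputs follow the kernel's parameter order,
    # and the output array is described by `outs` instead of being passed in
    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

    # Output the result
    return out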

Taichi Optimization Methods#

For Loop Decorators#

Taichi kernels automatically parallelize the for-loops in the outermost scope, and the compiler chooses the launch settings automatically to make good use of the target architecture. For users seeking the last few percent of performance, however, Taichi provides several APIs for fine-tuning a program. Specifying a proper block_dim is key.

You can use ti.loop_config to set the loop directives for the next for loop. Available directives are:

  • parallelize: Sets the number of threads to use on CPU

  • block_dim: Sets the number of threads in a block on GPU

  • serialize: If you set serialize to True, the for loop will run serially, and you can write break statements inside it (only applies to range/ndrange for-loops). Equivalent to setting parallelize to 1.

ti.init(arch=ti.cpu)  # use arch=ti.cuda to run on the CUDA backend

@ti.kernel
def break_in_serial_for() -> ti.i32:
    a = 0
    ti.loop_config(serialize=True)
    for i in range(100):  # This loop runs serially
        a += i
        if i == 10:
            break
    return a

break_in_serial_for()  # returns 55

n = 128
val = ti.field(ti.i32, shape=n)

@ti.kernel
def fill():
    ti.loop_config(parallelize=8, block_dim=16)
    # If the kernel is run on the CPU backend, 8 threads will be used to run it
    # If the kernel is run on the CUDA backend, each block will have 16 threads.
    for i in range(n):
        val[i] = i

fill()

Complete example#

Here is a complete example showing how to implement a simple operator as a Taichi custom operator:

import jax
import jax.numpy as jnp
import taichi as ti
import pytest
import platform

import brainpy.math as bm


@ti.func
def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:
  return weight[None]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
  out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=0),
                  out: ti.types.ndarray(ndim=1)):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=0),
                  out: ti.types.ndarray(ndim=1)):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

def test_taichi_op_register():
  s = 1000
  indices = bm.random.randint(0, s, (s, 1000))
  vector = bm.random.rand(s) < 0.1

  out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
  return out



More Examples#

For more examples, please refer to:

Clean the cache of taichi kernels#

Because BrainPy fuses Taichi and JAX through Taichi's AOT (ahead-of-time) compilation, compiled Taichi kernels are cached on the system. If you want to clean the cache, you can use the following code:

import brainpy.math as bm




Chinese version#

import brainpy.math as bm

import jax
import jax.numpy as jnp
import pytest
import platform

import taichi as ti



Taichi uses Python functions and decorators to define custom operators. Here is the basic structure of a custom operator:

@ti.kernel
def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):
    # Internal logic of the operator
    pass

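Here, the @ti.kernel decorator tells Taichi that this function is a kernel that Taichi compiles for the target backend instead of executing it as ordinary Python.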


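When defining complex custom operators, you can use the @ti.func decorator to define helper functions. These functions can be called inside the kernel function: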

@ti.func
def helper_func(x: ti.f32) -> ti.f32:
    # Auxiliary computation
    return x * 2

@ti.kernel
def my_kernel(arg: ti.types.ndarray()):
    for i in ti.ndrange(arg.shape[0]):
        arg[i] *= helper_func(arg[i])



@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
    return weight[0]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
    out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=1),
                  out: ti.types.ndarray(ndim=1)):
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    for i in range(num_rows):
        if vector[i]:
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)

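When declaring the parameters, the output arrays must be the last parameters so that Taichi can compile the operator correctly. The operator event_ell_cpu receives the indices, the event vector, the weight, and the output array, and updates the output array according to the logic above.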



Note: The order of the parameters declared in the operator must match the order of the arguments when calling it.

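import jax
import jax.numpy as jnp
import brainpy.math as bm

# Taichi operator registration (event_ell_gpu is defined in the complete example below)
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

# Using the operator
def test_taichi_op():
    # Create input data
    # ...

    # Call the custom operator
    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

    # Output the result
    return out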



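Taichi kernels automatically parallelize the for-loops in the outermost scope, and the compiler chooses the launch settings automatically to make good use of the target architecture. For users seeking the last few percent of performance, however, Taichi provides several APIs for fine-tuning a program. Specifying a proper block_dim is key.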

You can use ti.loop_config to set the loop directives for the next for loop. The available directives are:

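  • parallelize: Sets the number of threads to use on CPU

  • block_dim: Sets the number of threads in a block on GPU

  • serialize: If you set serialize to True, the for loop will run serially, and you can write break statements inside it (only applies to range/ndrange for-loops). Equivalent to setting parallelize to 1.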

ti.init(arch=ti.cpu)  # use arch=ti.cuda to run on the CUDA backend

@ti.kernel
def break_in_serial_for() -> ti.i32:
    a = 0
    ti.loop_config(serialize=True)
    for i in range(100):  # This loop runs serially
        a += i
        if i == 10:
            break
    return a

break_in_serial_for()  # returns 55

n = 128
val = ti.field(ti.i32, shape=n)

@ti.kernel
def fill():
    ti.loop_config(parallelize=8, block_dim=16)
    # If the kernel is run on the CPU backend, 8 threads will be used to run it
    # If the kernel is run on the CUDA backend, each block will have 16 threads.
    for i in range(n):
        val[i] = i

fill()


Below is a complete example showing how to implement a simple operator as a Taichi custom operator:

import jax
import jax.numpy as jnp
import taichi as ti
import pytest
import platform

import brainpy.math as bm


@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
  return weight[0]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
  out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=1),
                  out: ti.types.ndarray(ndim=1)):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
                  vector: ti.types.ndarray(ndim=1),
                  weight: ti.types.ndarray(ndim=1),
                  out: ti.types.ndarray(ndim=1)):
  weight_0 = weight[0]
  # One parallel task per (row, column) entry of `indices`
  for ij in ti.grouped(indices):
    if vector[ij[0]]:
      # Scatter to the target stored in indices[i, j], as in the CPU kernel
      out[indices[ij[0], ij[1]]] += weight_0

prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

def test_taichi_op_register():
  s = 1000
  indices = bm.random.randint(0, s, (s, 1000))
  vector = bm.random.rand(s) < 0.1
  weight = bm.array([1.0])

  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
  return out




For more examples, please refer to:

Clean the cache of Taichi kernels#


import brainpy.math as bm
