GPU Operator Customization with CuPy


This functionality is only available for brainpylib>=0.3.1.


Although the flexible Taichi custom-operator approach is now available, Taichi on CUDA lacks fine-grained control and optimization for some scenarios. In those cases, we can use CuPy's RawModule and jit.rawkernel to compile and run native CUDA code at runtime, written either as a source string or as a CuPy JIT function, for finer-grained control.

Start by importing the relevant Python packages.

import brainpy.math as bm

import jax
import cupy as cp
from cupyx import jit


CuPy RawModule

For a large raw CUDA source, or to load an existing CUDA binary, the RawModule class is more convenient. It can be initialized either with CUDA source code (code=...) or with a path to a compiled CUDA binary (path=...). The needed kernels are then retrieved by calling its get_function() method, which returns a RawKernel instance. Here, instead of launching that kernel directly, we register it with BrainPy.

Note that the kernel you call must place its outputs at the end of its parameter list: in the example below, the output y comes after the inputs x1, x2 and N.

source_code = r'''
    extern "C"{

    __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)
    {
        // one thread per element; the guard skips surplus threads
        unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
        if (tid < N)
        {
            y[tid] = x1[tid] + x2[tid];
        }
    }

    }
'''
mod = cp.RawModule(code=source_code)
kernel = mod.get_function('kernel')
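To make the launch arithmetic concrete, the sketch below emulates on the CPU, in plain Python, what each CUDA thread of kernel does when launched with grid=(N,) and block=(N,) over the N*N elements of the inputs. The function emulate_add_kernel is hypothetical and does not use CuPy or BrainPy; it only mirrors the CUDA source, keeping the same parameter order with the output y last.

```python
def emulate_add_kernel(x1, x2, n, y, grid, block):
    """Hypothetical CPU emulation of the CUDA `kernel` defined above.

    x1, x2 and y are flat Python lists; grid and block are 1-D tuples,
    like the launch configuration passed to the custom op.
    """
    block_dim = block[0]
    for block_idx in range(grid[0]):          # plays the role of blockIdx.x
        for thread_idx in range(block_dim):   # plays the role of threadIdx.x
            tid = block_dim * block_idx + thread_idx
            if tid < n:                       # same bounds check as the CUDA code
                y[tid] = x1[tid] + x2[tid]
    return y


N = 10
x1 = [1.0] * (N * N)   # flattened stand-in for bm.ones((N, N))
x2 = [1.0] * (N * N)
y = [0.0] * (N * N)    # output buffer, passed last
emulate_add_kernel(x1, x2, N ** 2, y, grid=(N,), block=(N,))
print(all(v == 2.0 for v in y))  # True: 10 blocks x 10 threads cover all 100 elements
```

If the launch provided fewer threads than elements (for example grid=(5,) with block=(10,)), this kernel would leave the tail of y untouched, because each thread handles exactly one index.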

After defining the RawModule and retrieving the kernel function, you can pass it as the gpu_kernel of bm.XLACustomOp and call the resulting operator with the grid and block sizes you want (both must be tuples).
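When the element count is not a neat product like N * N, the grid size is usually derived from it by ceiling division. The helper below, launch_config_1d, is a hypothetical plain-Python sketch (not part of BrainPy or CuPy) that returns the 1-D grid and block as tuples, as the custom op expects; the default of 256 threads per block is an arbitrary choice for illustration.

```python
def launch_config_1d(num_elements, threads_per_block=256):
    """Hypothetical helper: 1-D (grid, block) tuples covering num_elements.

    Ceiling division makes the last, partially filled block cover the tail;
    a bounds check like `if (tid < N)` in the kernel skips surplus threads.
    """
    if num_elements <= 0 or threads_per_block <= 0:
        raise ValueError("num_elements and threads_per_block must be positive")
    blocks = (num_elements + threads_per_block - 1) // threads_per_block
    return (blocks,), (threads_per_block,)


grid, block = launch_config_1d(10 * 10, threads_per_block=32)
print(grid, block)  # (4,) (32,)
```

Here 4 blocks of 32 threads give 128 threads for 100 elements, and the kernel's guard discards the extra 28.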

Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output.

# prepare inputs
N = 10
x1 = bm.ones((N, N))
x2 = bm.ones((N, N))

# register the kernel as a custom op
prim1 = bm.XLACustomOp(gpu_kernel=kernel)

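# call the custom op; N**2 is the element count passed to the kernel's N parameter,
# and grid=(N,), block=(N,) launch N * N threads, one per element
y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]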

CuPy JIT RawKernel

The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.

In this section, a Python function wrapped with the decorator is called a target function.

Launching a CUDA kernel on a GPU with predetermined grid/block sizes requires a basic understanding of the CUDA programming model. Compilation is deferred until the first call: CuPy's JIT compiler infers the argument types at call time and caches the compiled kernel to speed up subsequent calls.

Here is a short example of a cupyx.jit.rawkernel that copies the values from x to y using a grid-stride loop:

@jit.rawkernel()
def elementwise_copy(x, size, y):
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        y[i] = x[i]
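Because the loop advances by the total thread count ntid, the kernel stays correct even when the launch has fewer threads than size. The sketch below is a hypothetical CPU emulation in plain Python (it does not use CuPy) that checks this with only 20 threads for 100 elements and records how many elements each thread copies.

```python
def emulate_elementwise_copy(x, size, y, grid, block):
    """Hypothetical CPU emulation of the grid-stride `elementwise_copy` kernel.

    Returns a dict mapping (block_idx, thread_idx) to the number of
    elements that thread copied.
    """
    block_dim, grid_dim = block[0], grid[0]
    ntid = grid_dim * block_dim               # total threads in the launch
    work = {}
    for block_idx in range(grid_dim):
        for thread_idx in range(block_dim):
            tid = block_idx * block_dim + thread_idx
            count = 0
            for i in range(tid, size, ntid):  # the grid-stride loop
                y[i] = x[i]
                count += 1
            work[(block_idx, thread_idx)] = count
    return work


size = 100
x = list(range(size))
y = [None] * size
work = emulate_elementwise_copy(x, size, y, grid=(2,), block=(10,))
print(y == x)               # True: every element copied
print(set(work.values()))   # {5}: 20 threads, 5 elements each
```

With a larger launch such as grid=(10,), block=(10,), each of the 100 threads copies exactly one element; the same kernel works for both configurations without changes.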

After defining the jit.rawkernel, you can pass it as the gpu_kernel of bm.XLACustomOp and call it with the grid and block sizes you want (both must be tuples).

# prepare inputs
size = 100
x = bm.ones((size,))

# register the kernel as a custom op
prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)

# call the custom op; grid=(10,) and block=(10,) launch 10 * 10 = 100 threads for the 100 elements
y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]
