Low-level Operator Customization

@Tianqiu Zhang

BrainPy is built on JAX and accelerates models through just-in-time (JIT) compilation. To improve performance on CPU and GPU, we publish a separate package, BrainPyLib, which provides several built-in low-level operators for synaptic computation. These operators are written in C++ and wrapped as JAX primitives through XLA. However, users cannot easily write their own operators unless they are familiar with C++, JAX primitives, and XLA. To solve this problem, we introduce numba.cfunc here and provide convenient interfaces that let users customize operators without touching the underlying logic.

import brainpy as bp
import brainpy.math as bm
from jax import jit
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray

bm.set_platform('cpu')

In the Computation with Sparse Connections section, we discussed the benefits of computing with our built-in operators. These operators are provided by the brainpylib package and can be accessed through the brainpy.math module. More specifically, to speed up sparse synaptic computation, we implemented several low-level operators for CPU and GPU. They are written in C++ and converted into JAX/XLA-compatible primitives by using Pybind11.

Writing a C++ operator and carrying out the required conversions is not easy. Users have to learn how to write a C++ operator, how to write a customized JAX primitive, and how to convert the C++ operator into a JAX primitive. Here are some links for users who want to dive into the details: JAX primitives, XLA custom calls.

However, we can provide only a limited number of operators, and it would be great if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface, register_op, to register customized operators on CPU and GPU. Users no longer need to do any C++ programming or XLA compilation. This is accomplished with the help of numba.cfunc, which wraps Python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback so that it can be passed into XLA and registered as a jittable JAX primitive. The parameters and return types of register_op are listed in the API docs. The following sections walk through an example of using register_op on CPU.
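The idea of "a compiled callback that exposes its address" can be illustrated without Numba or XLA. The sketch below is not how register_op is implemented: it swaps in the standard-library ctypes module for numba.cfunc, and the names CALLBACK_TYPE, py_add, and recovered are hypothetical. It only shows that a C-callable wrapper has a plain integer address, and that a caller holding nothing but that address can invoke the function.

```python
import ctypes

# C signature of the callback: double (*)(double, double)
CALLBACK_TYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)


def py_add(a, b):
  # Ordinary Python function that will be wrapped as a C callback.
  return a + b


# Wrap the Python function so that foreign C code can call it.
# The wrapper object must stay alive for as long as the address is used.
c_add = CALLBACK_TYPE(py_add)

# The raw address is a plain integer. This is the piece of information
# a foreign runtime needs in order to register the callback.
address = ctypes.cast(c_add, ctypes.c_void_p).value

# Rebuild a callable from the address alone, as a foreign caller would.
recovered = CALLBACK_TYPE(address)
result = recovered(1.5, 2.0)
print(hex(address), result)
```

With numba.cfunc the wrapped function is compiled to machine code rather than interpreted, which is what makes the approach suitable for performance-sensitive operators.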

How to customize operators?

CPU version

First, users write a simple operator in Python. Note that this Python operator will be compiled by Numba in nopython mode, where some Python language features are not available. Please consult the Numba documentation for details.

def custom_op(outs, ins):
  y, y1 = outs
  x, x2 = ins
  y[:] = x + 1
  y1[:] = x2 + 2

There are some restrictions that users should know:

  • The operator takes exactly two parameters, outs and ins, which hold the output variable(s) and the input variable(s). Their order cannot be changed.

  • The function must not return a value. Results are written into outs in place.

  • In the GPU version, users should write the kernel function according to the numba cuda.jit documentation. When applying the CPU function to GPU, users only need to implement the CPU operator.
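The calling convention can be tried out by hand before registering anything. The following sketch uses NumPy arrays as stand-ins for the buffers that the runtime would pass in; allocating the outputs manually with np.empty_like is an assumption made for illustration, not something register_op asks users to do.

```python
import numpy as np


def custom_op(outs, ins):
  # Same operator as above: results are written into `outs` in place.
  y, y1 = outs
  x, x2 = ins
  y[:] = x + 1
  y1[:] = x2 + 2


# Inputs, plus output buffers allocated by the caller (here by hand).
x = np.ones((1, 2), dtype=np.float32)
x2 = np.ones((1, 2), dtype=np.float32)
y = np.empty_like(x)
y1 = np.empty_like(x2)

# Outputs come first, inputs second; nothing is returned.
returned = custom_op((y, y1), (x, x2))
print(returned)
print(y, y1)
```

Because the function returns nothing, the only way for the results to reach the caller is through the slice assignments into y and y1.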

Next, users should describe the shapes and types of the outputs. JAX can deduce the shapes and types of the inputs when the operator is called, but it cannot infer those of the outputs. The out_shapes argument can be:

  • a ShapedArray,

  • a sequence of ShapedArray,

  • a function that returns the output shapes and types as ShapedArray object(s).

Here we use a function to describe the output shapes and types. It receives all the inputs of the custom operator as arguments, but only their shapes and types are accessible.

def abs_eval_1(*ins):
  # ins: input arguments; only their shapes and types are accessible.
  # The outputs of custom_op have exactly the same shapes and types
  # as its inputs, so we can simply return the inputs.
  return ins

The function above may look somewhat abstract, so here is an equivalent function that passes the shape information explicitly. abs_eval_1 and abs_eval_2 do the same thing.

def abs_eval_2(*ins):
  return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)
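The equivalence of abs_eval_1 and abs_eval_2 can be checked without JAX. In the sketch below, FakeShapedArray is a hypothetical stand-in that carries only a shape and a dtype, which is all the shape-evaluation function is allowed to look at. It is not the real ShapedArray class.

```python
from collections import namedtuple

# Hypothetical stand-in for ShapedArray: it only carries shape and dtype.
FakeShapedArray = namedtuple('FakeShapedArray', ['shape', 'dtype'])


def abs_eval_1(*ins):
  # Return the abstract inputs unchanged.
  return ins


def abs_eval_2(*ins):
  # Rebuild each abstract output explicitly from shape and dtype.
  return (FakeShapedArray(ins[0].shape, ins[0].dtype),
          FakeShapedArray(ins[1].shape, ins[1].dtype))


a = FakeShapedArray((1, 2), 'float32')
b = FakeShapedArray((1, 2), 'float32')

# Both functions describe two outputs with the shapes and dtypes of the inputs.
same = tuple(abs_eval_1(a, b)) == tuple(abs_eval_2(a, b))
print(same)
```

The explicit form becomes necessary as soon as an output differs from the inputs in shape or dtype, since returning ins unchanged would then describe the wrong buffers.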

Now we are ready to register a CPU operator. register_op wraps the operator and returns a jittable JAX primitive. Users should define the following parameters:

  • op_name: Name of the operator.

  • cpu_func: Customized operator of CPU version.

  • out_shapes: The shapes and types of the outputs.

z = jnp.ones((1, 2), dtype=jnp.float32)
# Users could try out_shapes=abs_eval_2 and see if the result is different
op = bm.register_op(
  op_name='add',
  cpu_func=custom_op,
  out_shapes=abs_eval_1,
  apply_cpu_func_to_gpu=False)
jit_op = jit(op)
print(jit_op(z, z))
[DeviceArray([[2., 2.]], dtype=float32), DeviceArray([[3., 3.]], dtype=float32)]

GPU version

We have discussed how to customize a CPU operator above. Next we turn to the GPU operator, which is slightly different from the CPU version. There are two additional parameters users need to provide:

  • gpu_func: Customized operator of GPU version.

  • apply_cpu_func_to_gpu: Whether to run the CPU function as a fallback for the GPU version.

Warning

GPU operators will be wrapped by cuda.jit in numba, but numba currently does not support launching CUDA kernels from cfuncs. For this reason, gpu_func is None by default, and an error will be raised if users pass a GPU operator to gpu_func.

Therefore, BrainPy lets users set apply_cpu_func_to_gpu to True as a fallback. All the inputs are initialized on GPU and transferred to CPU for computing. The user-defined operator runs on CPU, and the results are transferred back to GPU for further tasks.

Performance

To illustrate the effectiveness of this approach, we compare a customized operator with the corresponding BrainPy built-in operator. Here we use event_sum as an example. Its implementation with our customization interface is shown below:

def abs_eval(events, indices, indptr, post_size, values):
  return post_size


def event_sum_op(outs, ins):
  post_val = outs
  events, indices, indptr, post_size, values = ins

  # indices/indptr store the connections in CSR form: the targets of
  # presynaptic neuron i are indices[indptr[i]:indptr[i+1]].
  for i in range(len(events)):
    if events[i]:
      for j in range(indptr[i], indptr[i+1]):
        index = indices[j]
        old_value = post_val[index]
        post_val[index] = values + old_value


event_sum = bm.register_op(op_name='event_sum', cpu_func=event_sum_op, out_shapes=abs_eval)
jit_event_sum = jit(event_sum)
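To make the loop concrete, here is a plain-Python version of the same accumulation on a tiny hand-written connection table. The function name event_sum_reference, its return-value style, and the explicit zero initialization of the output are assumptions made for illustration. It is a reference for checking the logic, not the registered operator.

```python
def event_sum_reference(events, indices, indptr, num_post, value):
  # Plain-Python version of the same loop, with an explicitly zeroed output.
  post_val = [0.0] * num_post
  for i in range(len(events)):
    if events[i]:
      # Targets of presynaptic neuron i in CSR form.
      for j in range(indptr[i], indptr[i + 1]):
        post_val[indices[j]] += value
  return post_val


# 3 presynaptic neurons, 4 postsynaptic neurons:
#   pre 0 -> post 0, 2
#   pre 1 -> post 1
#   pre 2 -> post 2, 3
indices = [0, 2, 1, 2, 3]
indptr = [0, 2, 3, 5]
events = [True, False, True]  # pre 0 and pre 2 spiked

result = event_sum_reference(events, indices, indptr, num_post=4, value=0.5)
print(result)  # [0.5, 0.0, 1.0, 0.5]
```

Post neuron 2 receives input from both spiking presynaptic neurons and therefore accumulates the value twice, while post neuron 1 stays at zero because its only source did not spike.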

The exponential COBA network will be our benchmark for testing the speed. We first use the built-in event_sum operator.

class ExpCOBA(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
               method='exp_auto'):
    super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def update(self, _t, _dt):
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # Built-in operator
    # --------------------------------------------------------------------------------------
    self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
    # --------------------------------------------------------------------------------------
    self.post.input += self.g * (self.E - self.post.V)


class EINet(bp.dyn.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = bp.models.LIF(num_exc, **pars, method=method)
    I = bp.models.LIF(num_inh, **pars, method=method)
    E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
    I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
    I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)


net = EINet(scale=10., method='euler')
# simulation
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.628559827804565

The total time is about 15.63 seconds. Next we use our customized operator.

class ExpCOBA(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
               method='exp_auto'):
    super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def update(self, _t, _dt):
    self.g.value = self.integral(self.g, _t, dt=_dt)
    post_size = bm.zeros(self.post.num)
    # Customized operator
    # ------------------------------------------------------------------------------------------------------------
    self.g += jit_event_sum(self.pre.spike, self.pre2post[0].value, self.pre2post[1].value, post_size, self.g_max)
    # ------------------------------------------------------------------------------------------------------------
    self.post.input += self.g * (self.E - self.post.V)


class EINet(bp.dyn.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = bp.models.LIF(num_exc, **pars, method=method)
    I = bp.models.LIF(num_inh, **pars, method=method)
    E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
    I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
    I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)


net = EINet(scale=10., method='euler')
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.703513145446777

In this comparison, the customized operator takes about 15.70 seconds versus about 15.63 seconds for the built-in one, so it is almost as fast. Users can build their own operators without a significant loss of computation speed.
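The size of the gap can be computed directly from the two timings printed above. The variable names below are hypothetical, and the figure applies only to this single run on this benchmark.

```python
t_builtin = 15.628559827804565  # run with the built-in pre2post_event_sum
t_custom = 15.703513145446777   # run with the operator from register_op

# Relative slowdown of the customized operator with respect to the built-in one.
overhead = (t_custom - t_builtin) / t_builtin
print(f'{overhead:.2%}')
```

The relative overhead comes out at roughly half a percent, which is small enough that repeated runs would be needed to tell it apart from timing noise.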