Low-level Operator Customization#

@Tianqiu Zhang @Chaoming Wang

In the Computation with Sparse Connections section, we discussed the benefits of computation with our built-in operators. These operators are provided by the brainpylib package and can be accessed through the brainpy.math module. However, these low-level operators for CPU and GPU devices are written in C++ and CUDA. It is not easy to write a C++ operator and implement the accompanying conversions: users have to learn how to write a C++ operator, how to write a customized JAX primitive, and how to convert the C++ operator into a JAX primitive. Here are some links for users who prefer to dive into the details: JAX primitives, XLA custom calls.

However, it would be great if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides the convenient interfaces brainpy.math.register_op() and brainpy.math.XLACustomOp, which register customized operators on CPU and GPU devices with pure Python syntax. Users no longer need to touch any C++ programming or XLA compilation.

import brainpy as bp
import brainpy.math as bm

import jax
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray

bm.set_platform('cpu')

Customize a CPU operator#

The customization of a CPU operator is accomplished with the help of numba.cfunc, which wraps Python code as a compiled function callable from foreign C code. The compiled function object exposes the address of its C callback, so it can be passed into XLA and registered as a jittable JAX primitive. The parameters and return types of register_op are listed in the API docs.
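
As a rough illustration of what happens under the hood, here is a minimal, hypothetical numba.cfunc sketch (register_op performs this kind of wrapping for you, so you never write it yourself; the fixed shape (10,) is an assumption for the example):

from numba import cfunc, types, carray

# A toy C-callable kernel: both arguments are raw float64 pointers.
@cfunc(types.void(types.CPointer(types.float64),
                  types.CPointer(types.float64)))
def double_it(out_ptr, in_ptr):
  # View the raw pointers as fixed-size arrays.
  out = carray(out_ptr, (10,))
  inp = carray(in_ptr, (10,))
  out[:] = inp * 2.

# The compiled callback exposes a C address, which is what can be
# handed to XLA as a custom-call target.
print(double_it.address)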

In general, the customization of a CPU operator needs two functions:

  • abstract evaluation function: specifies the abstract shape and dtype of the output according to the abstract information of the inputs. This information helps JAX infer the shapes and dtypes of the outputs. The abstract evaluation function can be provided as

    • a ShapedArray, like

      ShapedArray((10,), jnp.float32)
      
    • a sequence of ShapedArray, like

      [ShapedArray((10,), jnp.float32), ShapedArray((1,), jnp.int32)]
      
    • a function that returns the correct output ShapedArray(s), like

      def abs_eval(inp1, inp2):
        return (ShapedArray(inp1.shape, inp1.dtype),
                ShapedArray(inp2.shape, inp2.dtype))
      
  • concrete computation function: specifies how the output data are computed from the input data.

Here is an example of operator customization on the CPU device.

# What we want to do is a simple add operation.
# Therefore, the shapes and dtypes of the outputs are
# the same as those of the inputs.

def abs_eval(*ins):
  # ins: the input arguments; only their shapes and dtypes are accessible.
  # Because the outputs of our custom op have exactly the same shapes
  # and dtypes as the inputs, we can simply return the inputs here.
  return ins

# Note that the concrete computation function only receives two
# arguments, "outs" and "ins", and must not return a value.

def con_compute(outs, ins):
  y, y1 = outs
  x, x2 = ins
  y[:] = x + 1
  y1[:] = x2 + 2

There are some restrictions on the concrete computation function that users should know:

  • The parameters of the function are outs and ins, corresponding to the output variable(s) and input variable(s). Their order cannot be changed.

  • The function cannot have any return value.

Now we are ready to register a CPU operator. register_op or XLACustomOp will be called to wrap your operator and return a jittable JAX primitive. Here are the parameters users should define:

  • name: Name of the operator.

  • cpu_func: The CPU implementation of the customized operator.

  • eval_shape: The shapes and types of the outputs.

op = bm.register_op(name='add',
                    cpu_func=con_compute,
                    eval_shape=abs_eval)

Alternatively, the same operator can be created by subclassing bm.XLACustomOp:

class AddOp(bm.XLACustomOp):
  def __init__(self, name):

    def abs_eval(*ins):
      return ins

    def con_compute(outs, ins):
      y, y1 = outs
      x, x2 = ins
      y[:] = x + 1
      y1[:] = x2 + 2

    super(AddOp, self).__init__(name=name, cpu_func=con_compute, eval_shape=abs_eval)

op2 = AddOp('add')

Let’s try to use this operator.

z = jnp.ones((1, 2), dtype=jnp.float32)

jax.jit(op)(z, z)
[DeviceArray([[2., 2.]], dtype=float32),
 DeviceArray([[3., 3.]], dtype=float32)]
jax.jit(op2)(z, z)
[DeviceArray([[2., 2.]], dtype=float32),
 DeviceArray([[3., 3.]], dtype=float32)]

Note

Actually, the concrete computation function should be a function compatible with the nopython mode of numba.jit(). Users should refer to Numba's documentation to check how to write a function that can be jitted by Numba. Fortunately, Numba's JIT supports most Python and NumPy features, which means this customization interface is general enough to apply to almost any customized computation you want.
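
For example, the following concrete computation function uses plain Python loops and NumPy-style indexing, both of which compile without trouble under nopython mode. The operator itself (a three-point smoother named smooth_op) is a hypothetical sketch, not part of BrainPy:

def smooth(outs, ins):
  y = outs  # assuming a single output array, as in the EventSum example below
  x = ins   # assuming a single input array
  y[0] = x[0]
  y[-1] = x[-1]
  for i in range(1, x.shape[0] - 1):
    # plain Python loops and arithmetic are fine in nopython mode
    y[i] = (x[i - 1] + x[i] + x[i + 1]) / 3.

smooth_op = bm.register_op(name='smooth',
                           cpu_func=smooth,
                           eval_shape=lambda x: ShapedArray(x.shape, x.dtype))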

Customize a GPU operator#

Customizing operators for GPU devices directly is extremely hard, and we are still working on it. Support will come soon.

Currently, we support applying the CPU function of an operator on the GPU. This is controlled by the apply_cpu_func_to_gpu=True setting during operator registration. When this option is turned on, the input data on the GPU are moved to the host CPU for computation, and the results are then moved back to the GPU for subsequent computations.

op3 = bm.register_op(name='add2',
                     cpu_func=con_compute,
                     eval_shape=abs_eval,
                     apply_cpu_func_to_gpu=True)
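
On a GPU platform, op3 can then be used in the same way as the CPU version; a hedged sketch (assuming a GPU build of JAX is available):

# bm.set_platform('gpu')
# z = jnp.ones((1, 2), dtype=jnp.float32)
# jax.jit(op3)(z, z)  # inputs move GPU -> CPU, compute, results move back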

Benchmarking the customized operator performance#

To illustrate the effectiveness of this approach, we compare a customized operator with the corresponding BrainPy built-in operator. Here we use event_sum as the example. The implementation of event_sum through our customization interface is shown below:

The operator customized using the Python syntax:

class EventSum(bm.XLACustomOp):
  """Customized operator."""

  def __init__(self):

    def abs_eval(events, indices, indptr, post_val, values):
      return post_val

    def con_compute(outs, ins):
      post_val = outs  # a single output: the post-synaptic value array
      events, indices, indptr, _, values = ins
      # iterate over the pre-synaptic spike events
      for i in range(events.size):
        if events[i]:
          # accumulate onto every post neuron that pre neuron i
          # connects to (CSR connectivity: indices/indptr)
          for j in range(indptr[i], indptr[i + 1]):
            index = indices[j]
            old_value = post_val[index]
            post_val[index] = values + old_value

    super(EventSum, self).__init__(eval_shape=abs_eval,
                                   con_compute=con_compute)


event_sum = EventSum()
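
Before plugging the operator into a synapse model, we can sanity-check it on a tiny, hand-made CSR connectivity (illustrative values only):

events = jnp.array([True, False, True])  # spikes of 3 pre neurons
indices = jnp.array([0, 1, 1])           # post ids of each connection
indptr = jnp.array([0, 2, 2, 3])         # CSR offsets: pre 0 -> {0, 1}, pre 2 -> {1}
jax.jit(event_sum)(events, indices, indptr, bm.zeros(2), 1.)
# expected: post 0 receives one event, post 1 receives two -> [1., 2.]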

The Exponential synapse model implemented through the above Python-level operator:

class ExponentialV2(bp.dyn.TwoEndConn):
  """Exponential synapse model using the customized operator written in Python."""

  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
    super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')

  def update(self, tdi):
    self.g.value = self.integral(self.g, tdi.t, tdi.dt)
    self.g += event_sum(self.pre.spike, self.pre2post[0], self.pre2post[1],
                        bm.zeros(self.post.num), self.g_max)
    self.post.input += self.g * (self.E - self.post.V)

The Exponential synapse model implemented through the C++ built-in operator:

class ExponentialV1(bp.dyn.TwoEndConn):
  """Exponential synapse model using the built-in operator written in C++."""

  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
    super(ExponentialV1, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')

  def update(self, tdi):
    self.g.value = self.integral(self.g, tdi.t, tdi.dt)
    self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
    self.post.input += self.g * (self.E - self.post.V)

The E/I balanced network model.

class EINet(bp.dyn.Network):
  def __init__(self, scale, syn_type='v1'):
    syn_cls = ExponentialV1 if syn_type == 'v1' else ExponentialV2

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                V_initializer=bp.init.Normal(-55., 2.))
    E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto')
    I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto')

    # synapses
    E2E = syn_cls(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
    E2I = syn_cls(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
    I2E = syn_cls(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
    I2I = syn_cls(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)

Let’s compare the speed results.

net1 = EINet(scale=10., syn_type='v1')
runner1 = bp.dyn.DSRunner(net1, inputs=[('E.input', 20.), ('I.input', 20.)])
t, _ = runner1.predict(10000., eval_time=True)
print("Operator implemented through C++ :", t)
Operator implemented through C++ : 12.278022527694702
net2 = EINet(scale=10., syn_type='v2')
runner2 = bp.dyn.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
t, _ = runner2.predict(10000., eval_time=True)
print('Operator implemented through Python: ', t)
Operator implemented through Python:  11.629684686660767

This comparison shows that the customized operator is almost as fast as the built-in one. Users can thus build their own operators in pure Python without worrying about a loss of computation speed.