JIT Compilation

@Chaoming Wang

The core idea behind BrainPy is Just-In-Time (JIT) compilation, which compiles your Python code into machine code "just in time" for execution. The transformed code can then run at native machine-code speed!

Python already has excellent JIT compilers such as JAX and Numba. However, they are designed to work only on pure Python functions. In computational neuroscience, most models have so many parameters and variables that it is hard to manage and control the model logic with functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes code more readable, controllable, flexible and modular. Brain modeling therefore needs JIT compilation that also works on class objects.

BrainPy provides a JIT compilation interface for class objects, built on top of JAX and Numba. This section describes that interface.

import brainpy as bp
import brainpy.math as bm

JIT in Numba and JAX

Numba specializes in optimizing native NumPy code, including NumPy arrays, loops and conditional control flow. It is a cross-platform library that runs on Windows, Linux, macOS, etc. Best of all, Numba can JIT compile native Python loops (for or while statements) and conditionals (if ... else ...), so you can keep writing intuitive Python code.

However, Numba is a lightweight JIT compiler that suits only small network models; for large networks its parallel performance is poor. Furthermore, Numba does not let the same code run on multiple devices: the same code cannot run on GPU targets.

JAX is a rising star among JIT compilers for scientific computing in Python. It uses XLA to JIT compile and run your NumPy programs, and the same code can be deployed onto CPUs, GPUs and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX favors large network models and has excellent parallel performance.

However, JAX has intrinsic overhead and is not suitable for running small networks. Moreover, at the time of writing, JAX only supports Linux and macOS platforms; Windows users must install JAX on WSL or compile it from source. Further, coding in JAX is not very intuitive. For example, JAX:

  • Doesn’t support in-place mutating updates of arrays, like x[i] += y; instead you must write a functional update such as x = x.at[i].add(y) (jax.ops.index_add(x, i, y) in older JAX releases)

  • Doesn’t support JIT compilation of native Python loops and conditionals, like

import numpy as np

arr = np.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.

instead you should write

import jax
import jax.numpy as jnp

arr = jnp.zeros(5)
def loop_body(i, acc_arr):
    arr1 = acc_arr.at[i].add(2.)
    return jax.lax.cond(i % 2 == 0,
                        lambda a: a.at[i].add(1.),
                        lambda a: a,
                        arr1)
arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)
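You can see what the functional style computes without installing JAX by using this plain NumPy sketch. The helper `index_add` is hypothetical, not a JAX or BrainPy API. Like JAX's functional updates, it returns a new array and leaves its input untouched. Both loop versions above should produce `[3, 2, 3, 2, 3]`.

```python
import numpy as np

def index_add(x, i, y):
    """Hypothetical functional update: return a copy of x with y added at index i."""
    out = x.copy()
    out[i] += y
    return out

def loop_body(i, acc_arr):
    # Add 2 to element i, plus 1 more on even indices, without mutating acc_arr.
    arr1 = index_add(acc_arr, i, 2.)
    return index_add(arr1, i, 1.) if i % 2 == 0 else arr1

arr = np.zeros(5)
result = arr
for i in range(arr.shape[0]):
    result = loop_body(i, result)

print(result.tolist())  # [3.0, 2.0, 3.0, 2.0, 3.0]
print(arr.tolist())     # [0.0, 0.0, 0.0, 0.0, 0.0]  (the input array is unchanged)
```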

Moreover, both frameworks have poor support for class objects.

JIT compilation in BrainPy

To obtain an intuitive, flexible and high-performance framework for brain modeling, BrainPy aims to combine the advantages of both compilers and to overcome the pitfalls of each as much as possible (although this work is not yet finished).

Specifically, the brainpy.math module provides

  • flexible switch between NumPy (Numba) and JAX backends

  • unified numpy-like array operations

  • unified ndarray data structure which supports in-place update

  • unified random APIs

  • powerful jit() compilation that supports both functions and class objects

Backend Switch

To switch between backends, use:

# switch to NumPy backend
bm.use_backend('numpy')

bm.get_backend_name()
'numpy'
# switch to JAX backend
bm.use_backend('jax')

bm.get_backend_name()
'jax'

In BrainPy, the “numpy” and “jax” backends are interchangeable: both have nearly the same APIs, so the same code can run on either backend.

Math Operations

In each backend, the APIs of the brainpy.math module closely mirror those of the original NumPy. For a detailed comparison, please see the Comparison Table.

For example, the array creation functions:

bm.zeros((10, 3))
JaxArray(DeviceArray([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]], dtype=float32))
bm.arange(10)
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
x = bm.array([[1,2], [3,4]])

x
JaxArray(DeviceArray([[1, 2],
                      [3, 4]], dtype=int32))

The array manipulation functions:

bm.max(x)
DeviceArray(4, dtype=int32)
bm.repeat(x, 2)
JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))
bm.repeat(x, 2, axis=1)
JaxArray(DeviceArray([[1, 1, 2, 2],
                      [3, 3, 4, 4]], dtype=int32))

The random number generation functions:

bm.random.random((3, 5))
JaxArray(DeviceArray([[0.77751863, 0.45367956, 0.09705102, 0.35475922, 0.6678729 ],
                      [0.81603515, 0.83256996, 0.29231536, 0.8294617 , 0.41171193],
                      [0.37354684, 0.8089644 , 0.4921714 , 0.93098676, 0.4895897 ]],            dtype=float32))
y = bm.random.normal(loc=0.0, scale=2.0, size=(2, 5))

y
JaxArray(DeviceArray([[ 0.24033605,  1.6801527 ,  1.6716577 ,  0.19926837,
                       -2.3778987 ],
                      [ 0.58184516,  1.3625289 , -2.3364332 ,  2.082281  ,
                        0.23409791]], dtype=float32))

The linear algebra functions:

bm.dot(x, y)
JaxArray(DeviceArray([[ 1.4040264,  4.4052105, -3.0012088,  4.3638306, -1.9097029],
                      [ 3.0483887, 10.490574 , -4.3307595,  8.926929 , -6.1973042]],            dtype=float32))
bm.linalg.eig(x)
(JaxArray(DeviceArray([-0.37228107+0.j,  5.3722816 +0.j], dtype=complex64)),
 JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
                       [ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))
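As a sanity check on the eigenvalues above, you can solve the characteristic polynomial of x = [[1, 2], [3, 4]] by hand:

```latex
\det\begin{pmatrix} 1-\lambda & 2 \\ 3 & 4-\lambda \end{pmatrix}
  = (1-\lambda)(4-\lambda) - 6 = \lambda^2 - 5\lambda - 2 = 0
\quad\Longrightarrow\quad
\lambda = \frac{5 \pm \sqrt{33}}{2} \approx -0.3723,\; 5.3723
```

This matches the float32 eigenvalues returned by bm.linalg.eig(x).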

The discrete Fourier transform functions:

bm.fft.fft(bm.exp(2j * bm.pi * bm.arange(8) / 8))
JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j,  8.0000000e+00+4.8023384e-07j,
                      -3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
                      -3.8941437e-07-2.0663207e-07j,  2.3841858e-07-1.9411573e-07j,
                       3.8941437e-07-2.0663207e-07j,  1.6858739e-07+3.1786506e-08j],            dtype=complex64))
bm.fft.ifft(bm.array([0, 4, 0, 0]))
JaxArray(DeviceArray([ 1.+0.j,  0.+1.j, -1.+0.j,  0.-1.j], dtype=complex64))
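You can read both FFT outputs off analytically. The input exp(2πi·n/8) is a single complex exponential at frequency 1, so the discrete Fourier transform puts all of its weight, N = 8, into bin 1. The remaining entries of order 1e-7 are float32 round-off. Likewise, ifft([0, 4, 0, 0]) equals (1/4)·4·exp(2πi·n/4) = [1, i, -1, -i]. The cross-check below uses plain NumPy purely as a reference implementation:

```python
import numpy as np

n = np.arange(8)
spectrum = np.fft.fft(np.exp(2j * np.pi * n / 8))
# Expected: 8 in bin 1 and zeros elsewhere (up to round-off).
expected = np.zeros(8, dtype=complex)
expected[1] = 8.0
print(np.allclose(spectrum, expected, atol=1e-9))  # True

signal = np.fft.ifft(np.array([0, 4, 0, 0]))
print(np.allclose(signal, [1, 1j, -1, -1j]))       # True
```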

For the full list of implemented APIs, please see the Comparison Table.

JIT for Functions

To take advantage of JIT compilation, you just need to wrap your custom functions or objects with bm.jit(), which instructs BrainPy to transform your code into machine code.

Take a pure function as an example. Here we implement the Gaussian Error Linear Unit (GELU):

bm.use_backend('jax')


def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
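The gelu function above uses the common tanh approximation of GELU. GELU is defined as GELU(x) = x·Φ(x), where Φ is the standard normal CDF, and the approximation is Φ(x) ≈ 0.5·(1 + tanh(√(2/π)·(x + 0.044715·x³))). The sketch below uses only the standard library. It compares this approximation with the exact erf-based form Φ(x) = 0.5·(1 + erf(x/√2)) on a grid of points.

```python
import math

def gelu_tanh(x):
    # Same formula as `gelu` above, written with the math module for scalars.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x):
    # Exact definition: x times the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2)))

xs = [i / 10 for i in range(-50, 51)]
max_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in xs)
print(f"max |tanh approx - exact| on [-5, 5]: {max_err:.2e}")
```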

Let’s first measure the running speed without JIT.

# jax backend, without JIT

x = bm.random.random(100000)
%timeit gelu(x)
332 µs ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

After JIT compilation, the function runs about five times faster.

# jax backend, with JIT

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66.6 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

The speed-up also holds after switching to the ‘numpy’ backend.

bm.use_backend('numpy')
# numpy backend, without JIT

x = bm.random.random(100000)
%timeit gelu(x)
3.52 ms ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# numpy backend, with JIT

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
1.91 ms ± 55.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
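The timings above measure repeated calls. In both JAX and Numba, a jitted function is compiled on its first call for a given input type (for JAX, a given shape and dtype). The compiled version is then cached and reused, so %timeit mostly measures the cached code. The pure-Python toy below only mimics that caching behaviour. The `toy_jit` decorator is hypothetical, and its "compilation" step just records the input signature instead of generating machine code.

```python
import functools

import numpy as np

def toy_jit(fn):
    """Hypothetical stand-in for a JIT decorator: 'compile' once per input signature."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(x):
        key = (x.shape, x.dtype.str)   # JAX, for example, specializes on shape and dtype
        if key not in cache:
            wrapper.compile_count += 1  # a real compiler would generate code here
            cache[key] = fn             # toy: the "compiled" version is the function itself
        return cache[key](x)

    wrapper.compile_count = 0
    return wrapper

@toy_jit
def square(x):
    return x * x

a = np.ones(3)
square(a)
square(a + 1)        # same shape and dtype: reuses the cached entry
square(np.ones(4))   # new shape: triggers another "compilation"
print(square.compile_count)  # 2
```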

JIT for Objects

Beyond functions, BrainPy can also JIT compile class objects. Specifically, any instance of brainpy.Base can be just-in-time compiled into machine code.

Let’s try a simple example: training a logistic regression classifier.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables: the weights change during training, so they are a bm.Variable
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        # gradient of the logistic loss with respect to self.w
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        # gradient-descent step, written as an in-place update
        self.w[:] = self.w - u

num_dim, num_points = 10, 20000000

In this example, the model weights (self.w) are updated again and again during training, so we mark them as a bm.Variable.
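The update in __call__ is one step of gradient descent, with a learning rate of 1, on the logistic loss L(w) = Σ log(1 + exp(−y·(x·w))). Its gradient is Σ (σ(y·(x·w)) − 1)·y·x, where σ is the sigmoid. The NumPy sketch below uses small random data chosen for illustration, not the benchmark data. It checks that expression against a central finite-difference gradient. It then applies the update in place with [:], so the weight array keeps its identity.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((50, 3))                 # small illustrative dataset
Y = rng.choice([-1.0, 1.0], size=50)    # labels in {-1, +1}
w = 2.0 * np.ones(3) - 1.3              # same initialization as LogisticRegression

def loss(w):
    return np.sum(np.log1p(np.exp(-Y * (X @ w))))

def grad(w):
    # Same expression as `u` in LogisticRegression.__call__
    return np.dot((1.0 / (1.0 + np.exp(-Y * np.dot(X, w))) - 1.0) * Y, X)

# Central finite-difference check of the analytic gradient.
eps = 1e-6
fd = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps) for e in np.eye(3)])
grad_ok = np.allclose(grad(w), fd, rtol=1e-4, atol=1e-6)
print(grad_ok)  # True

w_id = id(w)
w[:] = w - grad(w)    # in-place update, as in the class above
print(id(w) == w_id)  # True: same array object, new values
```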

import time

def benckmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')

Let’s try to benchmark the ‘numpy’ backend.

bm.use_backend('numpy')

points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# numpy backend, without JIT

lr1 = LogisticRegression(num_dim)

benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 18.891679763793945 s
# numpy backend, with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 12.473472833633423 s

Similar speed-ups are obtained with the ‘jax’ backend.

bm.use_backend('jax')

points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# jax backend, without JIT

lr1 = LogisticRegression(num_dim)

benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 6.245944976806641 s
# jax backend, with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 4.174307346343994 s

As you can see, the jitted model shows a clear speed advantage on both backends.

Note, however, that:

  1. The dynamically changed variable (the weight w) is marked as a bm.Variable (in the __init__() function).

  2. The variable w is updated in place with [:] indexing (in the __call__() function).

These two requirements are the constraints BrainPy imposes on the JIT compilation of class objects. Otherwise, operations and coding style are the same as in normal object-oriented Python programming.

Mechanism of JIT in NumPy backend

bm.use_backend('numpy')

So why must dynamically changed variables be updated in place?

  • First, during compilation, any variable accessed through self. that is not an instance of bm.Variable is compiled as a static constant. For example, self.a = 1. is compiled as the constant 1., so trying to change the value of self.a inside the jitted function has no effect.

class Demo1(bp.Base):
    def __init__(self):
        super(Demo1, self).__init__()
        
        self.a = 1.
    
    def update(self, b):
        self.a = b
        

d1 = Demo1()
bm.jit(d1.update)(2.)
print(d1.a)
1.0
  • Second, all the variables you want to change during the function call must be labeled as bm.Variable. During JIT compilation, these variables are then recompiled as arguments of the jitted function.

class Demo2(bp.Base):
    def __init__(self):
        super(Demo2, self).__init__()
        
        self.a = bm.Variable(1.)
    
    def update(self, b):
        self.a = b
        

bm.jit(Demo2().update, show_code=True)
The recompiled function:
-------------------------

def update(b, Demo20_a=None):
    Demo20_a = b


The namespace of the above function:
{}

The recompiled function:
-------------------------
def new_update(b):
  update(b, 
		Demo20_a=Demo20.a.value,)

The namespace of the above function:
{'Demo20': <__main__.Demo2 object at 0x7fa9b055e2e0>,
 'update': CPUDispatcher(<function update at 0x7fa9d01ff160>)}
<function new_update(b)>

The original Demo2.update method is recompiled as the update() function, with the dynamical variable a turned into the argument Demo20_a. Then, at call time (in the new_update() function), Demo20.a.value is passed as Demo20_a to the jitted update() function.

  • Third, as the source code of the recompiled function above shows, the recompiled variable Demo20_a is not returned. Once the function finishes running, the value computed for it disappears. Therefore, dynamically changed variables must be updated in place to hold their updated values.

class Demo3(bp.Base):
    def __init__(self):
        super(Demo3, self).__init__()
        
        self.a = bm.Variable(1.)
    
    def update(self, b):
        self.a[...] = b
        

d3 = Demo3()
bm.jit(d3.update)(2.)
d3.a
Variable(2.)
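You can mimic the three points above in plain Python and NumPy. In the sketch below, `Var` and the two `update_*` functions are hypothetical stand-ins for bm.Variable and for the recompiled update() functions shown earlier. In each case the variable's array is passed in as an argument and nothing is returned, so only an in-place write survives the call.

```python
import numpy as np

class Var:
    """Hypothetical stand-in for bm.Variable: a holder around a NumPy array."""
    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)

def update_rebind(b, a_value=None):
    # Like Demo2: rebinds the local name only; the caller's array is untouched.
    a_value = b

def update_inplace(b, a_value=None):
    # Like Demo3: writes into the array the argument refers to.
    a_value[...] = b

a2, a3 = Var(1.), Var(1.)
update_rebind(2., a_value=a2.value)   # mirrors new_update(b) for Demo2
update_inplace(2., a_value=a3.value)  # mirrors new_update(b) for Demo3
print(float(a2.value), float(a3.value))  # 1.0 2.0
```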

These simple demonstrations illustrate the core mechanism of JIT compilation in the NumPy backend, where bm.jit() can recursively compile your class objects. So please try your own models and run them with JIT acceleration.

However, the JIT compilation mechanism of the JAX backend is quite different. We will detail it in upcoming tutorials.

In-place operators

Next, it is important to answer the question: what are in-place operators?

v = bm.arange(10)

id(v)
140366785129408
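An in-place operator modifies an existing array object instead of creating a new one, and id() makes the difference visible. The illustration below uses plain NumPy arrays, which behave the same way here. Ordinary assignment such as v = v + 1 binds the name to a brand-new array, which is exactly the kind of update that is lost inside a jitted function.

```python
import numpy as np

v = np.arange(10)
original_id = id(v)

# Index, slice, fancy-index and full-slice assignments write into the same object.
v[0] = 1
v[1:2] = 1
v[[1, 3]] = 2
v[...] = np.arange(10)
v += 1                        # augmented assignment calls __iadd__ on NumPy arrays
inplace_kept_identity = id(v) == original_id
print(inplace_kept_identity)  # True

# An ordinary assignment builds a new array and rebinds the name `v`.
v = v + 1
rebind_kept_identity = id(v) == original_id
print(rebind_kept_identity)   # False
```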

In-place operators include the following operations:

  1. Indexing and slicing (for more details, please refer to Array Objects Indexing), such as

  • Index: v[i] = a

  • Slice: v[i:j] = b

  • Slice the specific values: v[[1, 3]] = c

  • Slice all values: v[:] = d, v[...] = e

v[0] = 1

id(v)
140366785129408
v[1: 2] = 1

id(v)
140366785129408
v[[1, 3]] = 2

id(v)
140366785129408
v[:] = 0

id(v)
140366785129408
v[...] = bm.arange(10)

id(v)
140366785129408
  2. Augmented assignment. For BrainPy arrays, all augmented assignments are in-place operations, including

  • += (add)

  • -= (subtract)

  • /= (divide)

  • *= (multiply)

  • //= (floor divide)

  • %= (modulo)

  • **= (power)

  • &= (and)

  • |= (or)

  • ^= (xor)

  • <<= (left shift)

  • >>= (right shift)

v += 1

id(v)
140366785129408
v *= 2

id(v)
140366785129408
v |= bm.random.randint(0, 2, 10)

id(v)
140366785129408
v **= 2.

id(v)
140366785129408
v >>= 2

id(v)
140366785129408
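JAX arrays themselves are immutable, so an array wrapper has to emulate these in-place semantics. The sketch below shows one plausible way to do it. `Boxed` is a hypothetical class, not BrainPy's actual JaxArray implementation, and NumPy copies stand in for immutable arrays. Every in-place operator computes a new array and stores it back into the same wrapper object, so the wrapper's identity never changes.

```python
import numpy as np

class Boxed:
    """Hypothetical mutable wrapper around an array that is treated as immutable."""
    def __init__(self, value):
        self.value = np.asarray(value)

    def __setitem__(self, index, item):
        new = self.value.copy()  # functional update: never write into the old array
        new[index] = item
        self.value = new         # swap the stored array; the wrapper stays the same

    def __iadd__(self, other):
        self.value = self.value + other
        return self              # returning self keeps `v += ...` bound to this object

    def __imul__(self, other):
        self.value = self.value * other
        return self

v = Boxed(np.arange(5))
wrapper_id = id(v)
v[0] = 10
v += 1
v *= 2
print(v.value.tolist(), id(v) == wrapper_id)  # [22, 4, 6, 8, 10] True
```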

For more advanced usage, please see our forthcoming tutorials.