Just-In-Time Compilation

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code "just in time" for execution, so that the transformed code runs at native machine-code speed.

Python provides excellent JIT compilers such as JAX and Numba. However, they are designed to work only on pure Python functions. In computational neuroscience, most models have so many parameters and variables that it is hard to manage and control the model logic with functions alone. By contrast, object-oriented programming (OOP) based on Python classes makes your code more readable, controllable, flexible and modular. Therefore, supporting JIT compilation of class objects is necessary for brain modeling.

BrainPy therefore provides a JIT compilation interface for class objects, built on top of JAX and Numba. This section introduces how it works.

[1]:
import brainpy as bp

bp.__version__
[1]:
'1.1.0rc1'

JIT in Numba and JAX

Numba specializes in optimizing native NumPy code, including NumPy arrays, loops and conditional control flow. It is a cross-platform library that runs on Windows, Linux and macOS. Best of all, Numba can just-in-time compile your native Python loops (for or while syntax) and conditionals (if ... else ...), which means it supports intuitive Python programming.
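
For instance, a plain Python loop with an if/else branch can be handed to Numba as-is. The snippet below is a minimal sketch of plain Numba usage (the function name bump is illustrative and not part of BrainPy):

    import numpy as np
    import numba

    @numba.njit
    def bump(arr):
        # native Python loop and condition, compiled by Numba
        for i in range(arr.shape[0]):
            arr[i] += 2.
            if i % 2 == 0:
                arr[i] += 1.
        return arr

    bump(np.zeros(5))  # the first call triggers compilation; later calls run at native speed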

However, Numba is a lightweight JIT compiler and is only suitable for small network models. For large networks, its parallel performance is poor. Furthermore, Numba does not let the same code run on multiple devices; in particular, the same code cannot run on GPU targets.

JAX is a rising star among JIT compilers in Python scientific computing. It uses XLA to JIT compile and run your NumPy programs, and the same code can be deployed onto CPUs, GPUs and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX favors large network models and has excellent parallel performance.

However, JAX has intrinsic overhead and is not suitable for running small networks. Moreover, JAX only supports Linux and macOS; Windows users must install JAX on WSL or compile it from source. Further, coding in JAX is not very intuitive. For example,

  • Doesn’t support in-place mutating updates of arrays, such as x[i] += y; instead you should use x = jax.ops.index_update(x, i, y)

  • Doesn’t support JIT compilation of your native loops and conditions, like

    arr = np.zeros(5)
    for i in range(arr.shape[0]):
        arr[i] += 2.
        if i % 2 == 0:
            arr[i] += 1.
    

    instead you should use

    arr = np.zeros(5)
    def loop_body(i, acc_arr):
        arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
        return jax.lax.cond(i % 2 == 0,
                            arr1,
                            lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
                            arr1,
                            lambda arr1: arr1)
    arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)
    

What’s more, both frameworks have poor support for class objects.

JIT compilation in BrainPy

In order to obtain an intuitive, flexible and high-performance framework for brain modeling, BrainPy combines the advantages of both compilers and tries to overcome the gotchas of each framework as much as possible (although this work is not yet finished).

Specifically, we provide the brainpy.math module for

  • flexible switch between NumPy/Numba and JAX backends

  • unified numpy-like array operations

  • unified ndarray data structure which supports in-place update

  • unified random APIs

  • powerful jit() compilation which supports both functions and class objects

Backend Switch

The advantages and disadvantages of Numba and JAX are listed above. BrainPy supports both backends, so you can choose the one that fits your model.

If you are coding a small network model, the NumPy/Numba backend may be suitable for you. You can switch to this backend with:

[2]:
# switch to NumPy backend
bp.math.use_backend('numpy')

bp.math.get_backend_name()
[2]:
'numpy'

Actually, “numpy” is the default backend used in BrainPy. However, if you are coding a large-scale network model, or want to run on GPUs or TPUs, please switch to the JAX backend with:

[3]:
# switch to JAX backend
bp.math.use_backend('jax')

bp.math.get_backend_name()
[3]:
'jax'

In BrainPy, the “numpy” and “jax” backends are interchangeable. Both backends share the same APIs, and the same code can run on either backend (except native for ... loops and if ... else ... conditionals on the JAX backend; we are working to solve this problem).

Math Operations

The APIs in the brainpy.math module of each backend closely follow those of the original NumPy. For a detailed comparison, please see the Comparison Table.

For example, the array creation functions:

[4]:
bp.math.zeros((10, 3))
[4]:
JaxArray(DeviceArray([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]], dtype=float32))
[5]:
bp.math.arange(10)
[5]:
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
[6]:
x = bp.math.array([[1,2], [3,4]])

x
[6]:
JaxArray(DeviceArray([[1, 2],
                      [3, 4]], dtype=int32))

The array manipulation functions:

[7]:
bp.math.max(x)
[7]:
DeviceArray(4, dtype=int32)
[8]:
bp.math.repeat(x, 2)
[8]:
JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))
[9]:
bp.math.repeat(x, 2, axis=1)
[9]:
JaxArray(DeviceArray([[1, 1, 2, 2],
                      [3, 3, 4, 4]], dtype=int32))

The random number generation functions:

[10]:
bp.math.random.random((3, 5))
[10]:
JaxArray(DeviceArray([[0.3719796 , 0.40873682, 0.92993236, 0.9059397 , 0.6716608 ],
                      [0.84367204, 0.03071105, 0.36278927, 0.52648306, 0.64811254],
                      [0.57296276, 0.76439524, 0.05158341, 0.947039  , 0.640892  ]],            dtype=float32))
[11]:
y = bp.math.random.normal(loc=0.0, scale=2.0, size=(2, 5))

y
[11]:
JaxArray(DeviceArray([[ 1.4749131 ,  0.6137249 , -0.70614064,  2.148907  ,
                       -0.10390168],
                      [-0.11172882, -2.085571  , -2.9504242 ,  1.7263914 ,
                        3.2560937 ]], dtype=float32))

The linear algebra functions:

[12]:
bp.math.dot(x, y)
[12]:
JaxArray(DeviceArray([[  1.2514555,  -3.5574172,  -6.606989 ,   5.60169  ,
                         6.4082856],
                      [  3.9778242,  -6.5011096, -13.920118 ,  13.352286 ,
                        12.71267  ]], dtype=float32))
[13]:
bp.math.linalg.eig(x)
[13]:
(JaxArray(DeviceArray([-0.37228107+0.j,  5.3722816 +0.j], dtype=complex64)),
 JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
                       [ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))

The discrete Fourier transform functions:

[14]:
bp.math.fft.fft(bp.math.exp(2j * bp.math.pi * bp.math.arange(8) / 8))
[14]:
JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j,  8.0000000e+00+4.8023384e-07j,
                      -3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
                      -3.8941437e-07-2.0663207e-07j,  2.3841858e-07-1.9411573e-07j,
                       3.8941437e-07-2.0663207e-07j,  1.6858739e-07+3.1786506e-08j],            dtype=complex64))
[15]:
bp.math.fft.ifft(bp.math.array([0, 4, 0, 0]))
[15]:
JaxArray(DeviceArray([ 1.+0.j,  0.+1.j, -1.+0.j,  0.-1.j], dtype=complex64))

For the full list of implemented APIs, please see the Comparison Table.

JIT for Functions

To take advantage of JIT compilation, users just need to wrap their customized functions or objects in bp.math.jit() to instruct BrainPy to transform the code into machine code.

Take pure functions as the example. Here we implement the Gaussian Error Linear Unit (GELU) activation function.
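
The tanh approximation implemented in the code below is

    gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))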

[16]:
def gelu(x):
  sqrt = bp.math.sqrt(2 / bp.math.pi)
  cdf = 0.5 * (1.0 + bp.math.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
[17]:
# jax backend, without JIT

x = bp.math.random.random(100000)
%timeit gelu(x)
292 µs ± 7.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
[18]:
# jax backend, with JIT

gelu_jit = bp.math.jit(gelu)
%timeit gelu_jit(x)
66.4 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
[19]:
bp.math.use_backend('numpy')
[20]:
# numpy backend, without JIT

x = bp.math.random.random(100000)
%timeit gelu(x)
3.53 ms ± 55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[21]:
# numpy backend, with JIT

gelu_jit = bp.math.jit(gelu)
%timeit gelu_jit(x)
1.91 ms ± 65.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JIT for Objects

Moreover, in BrainPy, JIT compilation can be carried out on class objects. Specifically, any instance of a brainpy.Base object can be just-in-time compiled into machine code.

Let’s try a simple example that trains a logistic regression classifier.

[22]:
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bp.math.Variable(2.0 * bp.math.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bp.math.dot(((1.0 / (1.0 + bp.math.exp(-Y * bp.math.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u


num_dim, num_points = 10, 20000000
num_iter = 30

points = bp.math.random.random((num_points, num_dim))
labels = bp.math.random.random(num_points)
[23]:
# numpy backend, without JIT

lr1 = LogisticRegression(num_dim)
lr1(points, labels)

import time
t0 = time.time()
for i in range(num_iter):
    lr1(points, labels)

print(f'Logistic Regression model without jit used time {time.time() - t0} s')
Logistic Regression model without jit used time 19.143301725387573 s
[24]:
# numpy backend, with JIT

lr2 = LogisticRegression(num_dim)
jit_lr2 = bp.math.jit(lr2)
jit_lr2(points, labels)  # first call is the compiling

t0 = time.time()
for i in range(num_iter):
    jit_lr2(points, labels)

print(f'Logistic Regression model with jit used time {time.time() - t0} s')
Logistic Regression model with jit used time 11.75181531906128 s
[25]:
# numpy backend, with JIT + parallel

lr3 = LogisticRegression(num_dim)
jit_lr3 = bp.math.jit(lr3, parallel=True)
jit_lr3(points, labels)  # first call is the compiling

t0 = time.time()
for i in range(num_iter):
    jit_lr3(points, labels)

print(f'Logistic Regression model with jit+parallel used time {time.time() - t0} s')
Logistic Regression model with jit+parallel used time 7.351796865463257 s

Two things are worth noting here:

  1. The dynamically changed variable (the weight w) is marked as a bp.math.Variable in the __init__() function.

  2. The variable w is updated in place with [:] indexing in the __call__() function.

These two points are the only aspects special to the JIT compilation of class objects. All other operations and coding styles are the same as for class objects without JIT acceleration.

Mechanism of JIT in NumPy backend

So, why must the dynamically changed variables be updated in place?

  • First of all, in the compilation phase, an attribute accessed through self. that is not an instance of bp.math.Variable will be compiled as a static constant. For example, self.a = 1. will be compiled as the constant 1.. If you later try to change the value of self.a, the change will not take effect.

[26]:
class Demo1(bp.Base):
    def __init__(self):
        super(Demo1, self).__init__()

        self.a = 1.

    def update(self, b):
        self.a = b


d1 = Demo1()
bp.math.jit(d1.update)(2.)
print(d1.a)
1.0
  • Second, all variables you want to change during the function call must be labeled as bp.math.Variable. During JIT compilation, these variables are then recompiled as arguments of the jitted function.

[27]:
class Demo2(bp.Base):
    def __init__(self):
        super(Demo2, self).__init__()

        self.a = bp.math.Variable(1.)

    def update(self, b):
        self.a = b


bp.math.jit(Demo2().update, show_code=True)
The recompiled function:
-------------------------

def update(b, Demo20_a=None):
    Demo20_a = b


The namespace of the above function:
------------------------------------
{}

The recompiled function:
-------------------------
def new_update(b):
  update(b, Demo20_a=Demo20.a.value,)

The namespace of the above function:
------------------------------------
{'Demo20': <__main__.Demo2 object at 0x7f588c079f40>,
 'update': CPUDispatcher(<function update at 0x7f58167fbf70>)}

[27]:
<function new_update(b)>

The original Demo2.update function is recompiled as the update() function, with the dynamical variable a compiled into an argument Demo20_a. Then, during the function call (in the new_update() function), Demo20.a.value is passed as Demo20_a to the jitted update() function.

  • Third, as you can see in the source code of the recompiled function above, the recompiled variable Demo20_a is not returned. This means that once the function finishes running, the computed value disappears. Therefore, dynamically changed variables must be updated in place to hold their new values.

[28]:
class Demo3(bp.Base):
    def __init__(self):
        super(Demo3, self).__init__()

        self.a = bp.math.Variable(1.)

    def update(self, b):
        self.a[...] = b


d3 = Demo3()
bp.math.jit(d3.update)(2.)
d3.a
[28]:
Variable(2.)

The above simple demonstrations illustrate the core mechanism of JIT compilation in the NumPy backend. bp.math.jit() in the NumPy backend can recursively compile your class objects, as sketched below. So please try your models and run them with JIT acceleration.
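
For instance, a container object whose attributes are themselves brainpy.Base instances can be jitted as a whole. The sketch below is illustrative only (the Pair class is hypothetical and reuses the Demo3 class defined above); it assumes the recursive compilation of child variables described in this section:

    class Pair(bp.Base):
        def __init__(self):
            super(Pair, self).__init__()
            # children are Base objects, each holding its own bp.math.Variable
            self.left = Demo3()
            self.right = Demo3()

        def update(self, b):
            # each child updates its own Variable in place
            self.left.update(b)
            self.right.update(b * 2.)

    pair = Pair()
    bp.math.jit(pair.update)(1.5)  # the children's Variables are compiled recursively
    # expected: pair.left.a -> Variable(1.5), pair.right.a -> Variable(3.0)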

The mechanism of JIT compilation in the JAX backend is quite different. We will detail it in the upcoming tutorials.

In-place operators

The next important question is: what are in-place operators?

[29]:
v = bp.math.arange(10)

id(v)
[29]:
140016271963120

Actually, in-place operators include the following operations:

  1. Indexing and slicing (for more details, please refer to Array Objects Indexing). For example:

  • Index: v[i] = a

  • Slice: v[i:j] = b

  • Slice the specific values: v[[1, 3]] = c

  • Slice all values: v[:] = d, v[...] = e

[30]:
v[0] = 1

id(v)
[30]:
140016271963120
[31]:
v[1: 2] = 1

id(v)
[31]:
140016271963120
[32]:
v[[1, 3]] = 2

id(v)
[32]:
140016271963120
[33]:
v[:] = 0

id(v)
[33]:
140016271963120
[34]:
v[...] = bp.math.arange(10)

id(v)
[34]:
140016271963120
  2. Augmented assignment. All augmented assignments are in-place operations, which include

  • += (add)

  • -= (subtract)

  • /= (divide)

  • *= (multiply)

  • //= (floor divide)

  • %= (modulo)

  • **= (power)

  • &= (and)

  • |= (or)

  • ^= (xor)

  • <<= (left shift)

  • >>= (right shift)

[35]:
v += 1

id(v)
[35]:
140016271963120
[36]:
v *= 2

id(v)
[36]:
140016271963120
[37]:
v |= bp.math.random.randint(0, 2, 10)

id(v)
[37]:
140016271963120
[38]:
v **= 2.

id(v)
[38]:
140016271963120
[39]:
v >>= 2

id(v)
[39]:
140016271963120

For more advanced usage, please see our forthcoming tutorials.