# Just-In-Time Compilation

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just in time” for execution. The transformed code can then run at native machine-code speed!

Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions. In computational neuroscience, most models have so many parameters and variables that it is hard to manage and control the model logic using functions alone. In contrast, object-oriented programming (OOP) based on Python's `class` makes your code more readable, controllable, flexible and modular. Therefore, it is necessary to support JIT compilation of class objects for programming in brain modeling.

Here, BrainPy provides a JIT compilation interface for class objects, built on top of JAX and Numba. In this section, we will talk about this.

```
[1]:
```

```
import brainpy as bp
bp.__version__
```

```
[1]:
```

```
'1.1.0rc1'
```

## JIT in Numba and JAX

Numba specializes in optimizing your native NumPy code, including NumPy arrays, loops, condition controls, etc. It is a cross-platform library which runs on Windows, Linux, macOS, etc. Most wonderfully, Numba can just-in-time compile your native Python loops (`for` or `while` syntax) and condition controls (`if ... else ...`). This means that it supports your intuitive Python programming.

However, Numba is a lightweight JIT compiler, and is only suitable for small network models. For large networks, its parallel performance is poor. Furthermore, Numba does not let one piece of code run on multiple devices: the same code cannot run on GPU targets.

JAX is a rising star among JIT compilers in Python scientific computing. It uses XLA to JIT compile and run your NumPy programs. The same code can be deployed onto CPUs, GPUs and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX favors large network models, and has excellent parallel performance.

However, JAX has intrinsic overhead, and is not suitable for running small networks. Moreover, JAX only supports Linux and macOS platforms; Windows users must install JAX on WSL or compile JAX from source. Further, coding in JAX is not very intuitive. For example, JAX

- Doesn’t support in-place mutating updates of arrays, like `x[i] += y`; instead you should use `x = jax.ops.index_update(x, i, y)`.

- Doesn’t support JIT compilation of your native loops and conditions, like

```
arr = np.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.
```

instead you should use

```
arr = np.zeros(5)

def loop_body(i, acc_arr):
    arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
    return jax.lax.cond(i % 2 == 0,
                        arr1, lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
                        arr1, lambda arr1: arr1)

arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)
```

What’s more, both frameworks have poor support for class objects.

## JIT compilation in BrainPy

In order to obtain an *intuitive*, *flexible* and *high-performance* framework for brain modeling, BrainPy combines the advantages of both compilers and tries to overcome the drawbacks of each framework as much as possible (although this work is not yet finished).

Specifically, the BrainPy math module provides:

- flexible switching between the NumPy/Numba and JAX backends
- unified NumPy-like array operations
- a unified `ndarray` data structure which supports in-place updates
- unified `random` APIs
- a powerful `jit()` compilation which supports both functions and class objects

### Backend Switch

The advantages and disadvantages of Numba and JAX are listed above. We support both backends for different kinds of models.

If you are coding a small network model, NumPy/Numba backend may be very suitable for you. You can switch to this backend by:

```
[2]:
```

```
# switch to NumPy backend
bp.math.use_backend('numpy')
bp.math.get_backend_name()
```

```
[2]:
```

```
'numpy'
```

Actually, “numpy” is the default backend used in BrainPy. However, if you are coding a large-scale network model, or want to run on GPUs or TPUs, please switch to the JAX backend by:

```
[3]:
```

```
# switch to JAX backend
bp.math.use_backend('jax')
bp.math.get_backend_name()
```

```
[3]:
```

```
'jax'
```

In BrainPy, the “numpy” and “jax” backends are interchangeable. Both backends have the same APIs, and the same code can run on both backends (except for native `for ...` loops and `if ... else ...` conditions on the JAX backend; we are trying to solve this problem).

### Math Operations

The APIs in the `brainpy.math` module of each backend are very similar to the APIs of the original `numpy`. For a detailed comparison, please see the Comparison Table.

For example, the **array creation** functions,

```
[4]:
```

```
bp.math.zeros((10, 3))
```

```
[4]:
```

```
JaxArray(DeviceArray([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32))
```

```
[5]:
```

```
bp.math.arange(10)
```

```
[5]:
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
```

```
[6]:
```

```
x = bp.math.array([[1,2], [3,4]])
x
```

```
[6]:
```

```
JaxArray(DeviceArray([[1, 2],
[3, 4]], dtype=int32))
```

The **array manipulation** functions:

```
[7]:
```

```
bp.math.max(x)
```

```
[7]:
```

```
DeviceArray(4, dtype=int32)
```

```
[8]:
```

```
bp.math.repeat(x, 2)
```

```
[8]:
```

```
JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))
```

```
[9]:
```

```
bp.math.repeat(x, 2, axis=1)
```

```
[9]:
```

```
JaxArray(DeviceArray([[1, 1, 2, 2],
[3, 3, 4, 4]], dtype=int32))
```

The **random numbers** generation functions:

```
[10]:
```

```
bp.math.random.random((3, 5))
```

```
[10]:
```

```
JaxArray(DeviceArray([[0.3719796 , 0.40873682, 0.92993236, 0.9059397 , 0.6716608 ],
[0.84367204, 0.03071105, 0.36278927, 0.52648306, 0.64811254],
[0.57296276, 0.76439524, 0.05158341, 0.947039 , 0.640892 ]], dtype=float32))
```

```
[11]:
```

```
y = bp.math.random.normal(loc=0.0, scale=2.0, size=(2, 5))
y
```

```
[11]:
```

```
JaxArray(DeviceArray([[ 1.4749131 , 0.6137249 , -0.70614064, 2.148907 ,
-0.10390168],
[-0.11172882, -2.085571 , -2.9504242 , 1.7263914 ,
3.2560937 ]], dtype=float32))
```

The **linear algebra** functions:

```
[12]:
```

```
bp.math.dot(x, y)
```

```
[12]:
```

```
JaxArray(DeviceArray([[ 1.2514555, -3.5574172, -6.606989 , 5.60169 ,
6.4082856],
[ 3.9778242, -6.5011096, -13.920118 , 13.352286 ,
12.71267 ]], dtype=float32))
```

```
[13]:
```

```
bp.math.linalg.eig(x)
```

```
[13]:
```

```
(JaxArray(DeviceArray([-0.37228107+0.j, 5.3722816 +0.j], dtype=complex64)),
JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
[ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))
```

The **discrete fourier transform** functions:

```
[14]:
```

```
bp.math.fft.fft(bp.math.exp(2j * bp.math.pi * bp.math.arange(8) / 8))
```

```
[14]:
```

```
JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j, 8.0000000e+00+4.8023384e-07j,
-3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
-3.8941437e-07-2.0663207e-07j, 2.3841858e-07-1.9411573e-07j,
3.8941437e-07-2.0663207e-07j, 1.6858739e-07+3.1786506e-08j], dtype=complex64))
```

```
[15]:
```

```
bp.math.fft.ifft(bp.math.array([0, 4, 0, 0]))
```

```
[15]:
```

```
JaxArray(DeviceArray([ 1.+0.j, 0.+1.j, -1.+0.j, 0.-1.j], dtype=complex64))
```

For the full list of implemented APIs, please see the Comparison Table.

### JIT for Functions

To take advantage of JIT compilation, users just need to wrap their customized *functions* or *objects* into `bp.math.jit()` to instruct BrainPy to transform the code into machine code.

Take a **pure function** as an example. Here we try to implement the Gaussian Error Linear Unit (GELU),

```
[16]:
```

```
def gelu(x):
    sqrt = bp.math.sqrt(2 / bp.math.pi)
    cdf = 0.5 * (1.0 + bp.math.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    y = x * cdf
    return y
```
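As a sanity check, the same formula can be evaluated with plain NumPy outside of BrainPy (`gelu_np` is our own name for this sketch, not a BrainPy API):

```python
import numpy as np

def gelu_np(x):
    # tanh approximation of the Gaussian Error Linear Unit
    sqrt = np.sqrt(2 / np.pi)
    cdf = 0.5 * (1.0 + np.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    return x * cdf

print(gelu_np(0.0))   # 0.0 -- GELU vanishes at the origin
print(gelu_np(10.0))  # ~10.0 -- large positive inputs pass through almost unchanged
```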

```
[17]:
```

```
# jax backend, without JIT
x = bp.math.random.random(100000)
%timeit gelu(x)
```

```
292 µs ± 7.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

```
[18]:
```

```
# jax backend, with JIT
gelu_jit = bp.math.jit(gelu)
%timeit gelu_jit(x)
```

```
66.4 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```
[19]:
```

```
bp.math.use_backend('numpy')
```

```
[20]:
```

```
# numpy backend, without JIT
x = bp.math.random.random(100000)
%timeit gelu(x)
```

```
3.53 ms ± 55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

```
[21]:
```

```
# numpy backend, with JIT
gelu_jit = bp.math.jit(gelu)
%timeit gelu_jit(x)
```

```
1.91 ms ± 65.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

### JIT for Objects

Moreover, in BrainPy, JIT compilation can be carried out on **class objects**. Specifically, any instance of a `brainpy.Base` object can be just-in-time compiled into machine code.

Let’s try a simple example: training a logistic regression classifier.

```
[22]:
```

```
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bp.math.Variable(2.0 * bp.math.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bp.math.dot(((1.0 / (1.0 + bp.math.exp(-Y * bp.math.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u

num_dim, num_points = 10, 20000000
num_iter = 30

points = bp.math.random.random((num_points, num_dim))
labels = bp.math.random.random(num_points)
```

```
[23]:
```

```
# numpy backend, without JIT

lr1 = LogisticRegression(num_dim)
lr1(points, labels)

import time

t0 = time.time()
for i in range(num_iter):
    lr1(points, labels)
print(f'Logistic Regression model without jit used time {time.time() - t0} s')
```

```
Logistic Regression model without jit used time 19.143301725387573 s
```

```
[24]:
```

```
# numpy backend, with JIT

lr2 = LogisticRegression(num_dim)
jit_lr2 = bp.math.jit(lr2)
jit_lr2(points, labels)  # first call is the compiling

t0 = time.time()
for i in range(num_iter):
    jit_lr2(points, labels)
print(f'Logistic Regression model with jit used time {time.time() - t0} s')
```

```
Logistic Regression model with jit used time 11.75181531906128 s
```

```
[25]:
```

```
# numpy backend, with JIT + parallel

lr3 = LogisticRegression(num_dim)
jit_lr3 = bp.math.jit(lr3, parallel=True)
jit_lr3(points, labels)  # first call is the compiling

t0 = time.time()
for i in range(num_iter):
    jit_lr3(points, labels)
print(f'Logistic Regression model with jit+parallel used time {time.time() - t0} s')
```

```
Logistic Regression model with jit+parallel used time 7.351796865463257 s
```

What’s worth noting here is that:

1. The dynamically changed variable (the weight `w`) is marked as a `bp.math.Variable` in the `__init__()` function.
2. The variable `w` is in-place updated with the `[:]` indexing in the `__call__()` function.

These two points are the only *special* requirements for the JIT compilation of class objects. All other operations and coding styles are the same as for class objects without JIT acceleration.

#### Mechanism of JIT in NumPy backend

So, **why must we in-place update the dynamically changed variables?**

First of all, in the compilation phase, a `self.`-accessed variable which is not an instance of `bp.math.Variable` will be compiled as a static constant. For example, `self.a = 1.` will be compiled as the constant `1.`. If you try to change the value of `self.a`, it will not work.

```
[26]:
```

```
class Demo1(bp.Base):
    def __init__(self):
        super(Demo1, self).__init__()
        self.a = 1.

    def update(self, b):
        self.a = b

d1 = Demo1()
bp.math.jit(d1.update)(2.)
print(d1.a)
```

```
1.0
```
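This baking-in of constants can be mimicked in plain Python. The following is a hypothetical sketch for illustration only, not BrainPy's actual machinery:

```python
# Plain-Python sketch of how a non-Variable attribute is frozen into the
# compiled function (an illustration only, not BrainPy's real implementation).
class Demo1Like:
    def __init__(self):
        self.a = 1.

d1 = Demo1Like()

# "compile" time: the current value of d1.a is baked in as a constant
const_a = d1.a

def compiled_update(b):
    a = const_a  # the attribute became the constant 1.0
    a = b        # rebinds a local name only; d1.a is untouched

compiled_update(2.)
print(d1.a)  # 1.0 -- just like the jitted Demo1 above
```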

Second, all the variables you want to change during the function call must be labeled as `bp.math.Variable`. Then, during the JIT compilation period, these variables will be recompiled as arguments of the jitted functions.

```
[27]:
```

```
class Demo2(bp.Base):
    def __init__(self):
        super(Demo2, self).__init__()
        self.a = bp.math.Variable(1.)

    def update(self, b):
        self.a = b

bp.math.jit(Demo2().update, show_code=True)
```

```
The recompiled function:
-------------------------
def update(b, Demo20_a=None):
    Demo20_a = b

The namespace of the above function:
------------------------------------
{}

The recompiled function:
-------------------------
def new_update(b):
    update(b, Demo20_a=Demo20.a.value,)

The namespace of the above function:
------------------------------------
{'Demo20': <__main__.Demo2 object at 0x7f588c079f40>,
 'update': CPUDispatcher(<function update at 0x7f58167fbf70>)}
```

```
[27]:
```

```
<function new_update(b)>
```

The original `Demo2.update` function is recompiled as the `update()` function, with the dynamical variable `a` compiled as an argument `Demo20_a`. Then, during the function call (in the `new_update()` function), `Demo20.a.value` is passed as `Demo20_a` to the jitted `update()` function.
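The effect of this recompilation can be mimicked in plain Python (a hypothetical sketch with made-up names, not BrainPy's actual machinery):

```python
class Demo2Like:
    def __init__(self):
        self.a = 1.  # stands in for bp.math.Variable(1.)

obj = Demo2Like()

# the recompiled function: the variable is now an explicit argument
def update(b, obj_a=None):
    obj_a = b  # rebinds the argument only; the result is discarded

# the wrapper: passes the current value of obj.a into the jitted function
def new_update(b):
    update(b, obj_a=obj.a)

new_update(2.)
print(obj.a)  # 1.0 -- without an in-place write, the update is lost
```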

Third, as you can notice in the above source code of the recompiled function, the recompiled variable `Demo20_a` is not returned. This means that once the function finishes running, the computed value disappears. Therefore, the dynamically changed variables must be in-place updated to hold their updated values.

```
[28]:
```

```
class Demo3(bp.Base):
    def __init__(self):
        super(Demo3, self).__init__()
        self.a = bp.math.Variable(1.)

    def update(self, b):
        self.a[...] = b

d3 = Demo3()
bp.math.jit(d3.update)(2.)
d3.a
```

```
[28]:
```

```
Variable(2.)
```
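The same contrast between rebinding and in-place writing can be reproduced with plain NumPy arrays, independently of BrainPy:

```python
import numpy as np

def rebind(a):
    a = a + 1.  # rebinds the local name; the caller's array is unchanged

def update_inplace(a):
    a[...] = a + 1.  # writes through the existing buffer, as Demo3 does

arr = np.zeros(3)
rebind(arr)
print(arr)  # [0. 0. 0.] -- the rebinding was lost

update_inplace(arr)
print(arr)  # [1. 1. 1.] -- the in-place update persists
```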

The above simple demonstrations illustrate the core mechanism of JIT compilation on the NumPy backend. `bp.math.jit()` on the NumPy backend can recursively compile your class objects. So, please try your models and run them under JIT acceleration.

The mechanism of JIT compilation on the JAX backend is quite different. We will detail it in the upcoming tutorials.

#### In-place operators

Next, the most important question is: **what are in-place operators?**

```
[29]:
```

```
v = bp.math.arange(10)
id(v)
```

```
[29]:
```

```
140016271963120
```

Actually, in-place operators include the following operations:

1. **Indexing and slicing** (for more details please refer to Array Objects Indexing), like:

- Index: `v[i] = a`
- Slice: `v[i:j] = b`
- Slice specific values: `v[[1, 3]] = c`
- Slice all values: `v[:] = d` or `v[...] = e`

```
[30]:
```

```
v[0] = 1
id(v)
```

```
[30]:
```

```
140016271963120
```

```
[31]:
```

```
v[1: 2] = 1
id(v)
```

```
[31]:
```

```
140016271963120
```

```
[32]:
```

```
v[[1, 3]] = 2
id(v)
```

```
[32]:
```

```
140016271963120
```

```
[33]:
```

```
v[:] = 0
id(v)
```

```
[33]:
```

```
140016271963120
```

```
[34]:
```

```
v[...] = bp.math.arange(10)
id(v)
```

```
[34]:
```

```
140016271963120
```

2. **Augmented assignment**. All augmented assignments are in-place operations, including:

- `+=` (add)
- `-=` (subtract)
- `/=` (divide)
- `*=` (multiply)
- `//=` (floor divide)
- `%=` (modulo)
- `**=` (power)
- `&=` (and)
- `|=` (or)
- `^=` (xor)
- `<<=` (left shift)
- `>>=` (right shift)

```
[35]:
```

```
v += 1
id(v)
```

```
[35]:
```

```
140016271963120
```

```
[36]:
```

```
v *= 2
id(v)
```

```
[36]:
```

```
140016271963120
```

```
[37]:
```

```
v |= bp.math.random.randint(0, 2, 10)
id (v)
```

```
[37]:
```

```
140016271963120
```

```
[38]:
```

```
v **= 2.
id(v)
```

```
[38]:
```

```
140016271963120
```

```
[39]:
```

```
v >>= 2
id(v)
```

```
[39]:
```

```
140016271963120
```

For more advanced usage, please see our forthcoming tutorials.

Chaoming Wang (adaduo@outlook.com)

Update at 2021.09.06