Just-In-Time Compilation

@Chaoming Wang @Xiaoyu Chen

One of the core ideas of BrainPy is Just-In-Time (JIT) compilation. JIT compilation compiles Python code into machine code "just in time" for execution. The compiled code then runs at native machine-code speed, and the time saved usually more than offsets the one-time cost of compilation. It is therefore important to understand how to write JIT-compatible code.

This section will briefly introduce JIT compilation and its relation to BrainPy. For more details such as the JIT mechanism in BrainPy, please refer to the advanced Compilation tutorial.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

JIT Compilation for Functions

To take advantage of JIT compilation, users only need to wrap their custom functions or objects in bm.jit(), which instructs BrainPy to transform the Python code into machine code.

Take a pure function as an example. Here we implement the Gaussian Error Linear Unit (GELU), using its tanh approximation:

def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
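As a sanity check on the formula above, the following plain-Python sketch (standard library only, independent of BrainPy) compares the tanh-based expression used in gelu() with the exact erf-based definition of GELU. The names gelu_tanh and gelu_exact are introduced here for illustration only.

```python
import math

def gelu_tanh(x: float) -> float:
    """Scalar version of the tanh approximation used in gelu() above."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

def gelu_exact(x: float) -> float:
    """Exact GELU: x times the standard normal CDF, written with erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Largest absolute gap between the two on a grid over [-5, 5].
grid = [i / 100.0 for i in range(-500, 501)]
max_gap = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"max |tanh approximation - exact| on [-5, 5]: {max_gap:.2e}")
```

The two versions should agree closely over this range, which is why the cheaper tanh form is a common substitute for the exact one.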

Let’s first try to run the function without JIT.

x = bm.random.random(100000)
%timeit gelu(x)
295 µs ± 3.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

After JIT compilation, the function runs about 4.5 times faster in this measurement (66 µs versus 295 µs).

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

JIT Compilation for Objects

JIT compilation for functions is not enough for brain dynamics programming, since a multitude of dynamic variables and differential equations in a large system would make computation surprisingly complicated. Therefore, BrainPy enables JIT compilation to be performed on class objects, as long as users comply with the following rules:

  1. The class must be a subclass of brainpy.Base.

  2. Dynamically changed variables must be labeled as brainpy.math.Variable.

  3. Variable updating must be accomplished by in-place operations.
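Rule 3 can be illustrated with a toy model in plain Python. It is a sketch under stated assumptions, not BrainPy's actual mechanism: ToyVariable and toy_compile are hypothetical stand-ins, and "compilation" is modelled simply as capturing a reference to the variable object once. Under that assumption, an in-place update is visible to the compiled function, while rebinding the attribute to a new object is not.

```python
class ToyVariable:
    """Hypothetical stand-in for a tracked variable (not brainpy.math.Variable)."""
    def __init__(self, value):
        self.value = value

class Counter:
    def __init__(self):
        self.n = ToyVariable(0)

def toy_compile(obj):
    """Model 'compilation' as capturing the variable object once."""
    tracked = obj.n  # reference captured at compile time
    def compiled_read():
        return tracked.value
    return compiled_read

counter = Counter()
read = toy_compile(counter)

counter.n.value = counter.n.value + 1   # in-place: same object, new content
after_in_place = read()                 # the compiled function sees 1

counter.n = ToyVariable(100)            # rebinding: a different object
after_rebinding = read()                # the compiled function still sees 1

print(after_in_place, after_rebinding)
```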

Below is a simple example of a logistic regression classifier. When wrapped in bm.jit(), the class object will be JIT compiled.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters    
        self.dimension = dimension
    
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w.value = self.w - u

In this example, the model weights (self.w) are modified during training, so they are marked as bm.Variable. This marking matters because, during compilation, every attribute accessed through self. that is not an instance of bm.Variable is compiled as a static constant.
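The effect of treating plain attributes as static constants can be mimicked in a few lines of plain Python. This is a toy model under stated assumptions, not how BrainPy compiles objects: TrackedValue, Affine and toy_compile are hypothetical names, and "compilation" is modelled as copying plain attribute values once while keeping references to tracked ones.

```python
class TrackedValue:
    """Hypothetical stand-in for a tracked variable (not brainpy.math.Variable)."""
    def __init__(self, value):
        self.value = value

class Affine:
    def __init__(self):
        self.gain = 2.0                  # plain attribute
        self.offset = TrackedValue(1.0)  # tracked variable

def toy_compile(model):
    """Model 'compilation': copy plain attributes, keep references to tracked ones."""
    gain = model.gain        # value copied once, i.e. a static constant
    offset = model.offset    # reference kept, so later updates are seen
    def compiled(x):
        return gain * x + offset.value
    return compiled

model = Affine()
f = toy_compile(model)
first = f(3.0)               # 2.0 * 3.0 + 1.0

model.gain = 10.0            # ignored: gain was frozen at compile time
model.offset.value = 5.0     # seen: the tracked object is shared
second = f(3.0)              # 2.0 * 3.0 + 5.0

print(first, second)
```

In this toy, changing gain after compilation has no effect on the compiled function, which is the kind of silent staleness that marking changing state as a variable is meant to avoid.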

import time

def benchmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')

num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)

# without JIT

lr1 = LogisticRegression(num_dim)

benchmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 10.024710893630981 s

# with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benchmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 5.015154838562012 s

In the example above, JIT compilation roughly halves the running time (about 5.0 s versus 10.0 s). This example, however, is too simple to show the full benefit: in a large brain model, the speedup from JIT compilation is usually far greater.
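The benchmark above includes the first call in its total. With a JIT-compiled target, that first call typically also pays for compilation, so one way to separate the two costs is to time a warm-up call on its own. The helper below is a minimal sketch using only the standard library; timed_with_warmup is a hypothetical name, and fn can be any zero-argument callable, for example lambda: lr2(points, labels).

```python
import time

def timed_with_warmup(fn, num_iter=30):
    """Return (warmup_seconds, mean_seconds_per_call) for a zero-argument callable."""
    t0 = time.perf_counter()
    fn()                                   # first call, timed separately
    warmup = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(num_iter):
        fn()
    mean = (time.perf_counter() - t0) / num_iter
    return warmup, mean

# Usage with a cheap stand-in workload.
calls = []
warmup, mean = timed_with_warmup(lambda: calls.append(sum(range(1000))), num_iter=5)
print(f"warm-up: {warmup:.6f} s, steady state: {mean:.6f} s per call")
```

If the library under test dispatches work asynchronously, the callable should also wait for the result before returning, otherwise the measured time may only cover the dispatch.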

Automatic JIT Compilation in Runners

In a large dynamical system where many neurons and synapses are defined, it would be tedious to explicitly wrap every object in bm.jit(). Fortunately, in most cases users do not need to call bm.jit() themselves, because BrainPy applies JIT compilation automatically.

BrainPy provides a brainpy.Runner class that is inherited by the various runners used in simulation, training and integration. When initialized, a runner accepts a parameter named jit, which is True by default. This means that a runner automatically JIT compiles the target object once it is wrapped in the runner.

For example, to simulate the dynamics of an HH (Hodgkin-Huxley) model, users first wrap the model in a simulation runner:

model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.))
runner(1000)  # running 1000 ms
0.6139698028564453

Here, model is wrapped in a runner and will be JIT compiled during simulation.

If users do not want to use JIT compilation (for instance, because it prevents ordinary Python debugging), they can turn it off by setting jit=False:

model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.), jit=False)
runner(1000)
258.76088523864746

The output of each run is the time (in seconds) spent on simulation. Without JIT compilation, this simulation took about 259 s instead of about 0.61 s, which is roughly 420 times slower.
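The idea of a runner that compiles its target on the user's behalf can be sketched in plain Python. This is a toy model, not BrainPy's Runner implementation: ToyRunner, compile_fn and marking_compile are hypothetical names, and the "compiler" only tags the function so that the effect of the jit flag is visible.

```python
class ToyRunner:
    """Hypothetical runner that optionally 'compiles' its target step function."""
    def __init__(self, target, compile_fn, jit=True):
        # With jit=True the target is passed through compile_fn once, at construction.
        self.step = compile_fn(target) if jit else target

    def __call__(self, num_steps):
        return [self.step(i) for i in range(num_steps)]

def marking_compile(fn):
    """Stand-in for a JIT compiler: wraps fn and tags the wrapper as compiled."""
    def compiled(*args):
        return fn(*args)
    compiled.is_compiled = True
    return compiled

def square(i):
    return i * i

fast = ToyRunner(square, marking_compile)              # jit defaults to True
debug = ToyRunner(square, marking_compile, jit=False)  # target left as plain Python

print(fast(4), getattr(fast.step, "is_compiled", False))
print(debug(4), getattr(debug.step, "is_compiled", False))
```

Both runners produce the same results; only the path taken to compute them differs, which mirrors the trade-off between speed and debuggability described above.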

Besides simulation, runners are also used by integrators and trainers. For more details, please refer to the tutorial of runners.