Just-In-Time Compilation

One of the core ideas of BrainPy is Just-In-Time (JIT) compilation. JIT compilation turns your Python code into machine code “just in time”, that is, right before it is executed, so the transformed code can run at native machine-code speed. It is therefore important to understand how to write code that is compatible with JIT compilation.

For more details, please see our tutorials in “Math Foundation”.

import brainpy as bp
import brainpy.math as bm

JIT for Functions

To take advantage of JIT compilation, users just need to wrap their custom functions or objects with bm.jit(), which instructs BrainPy to transform the code into machine code.

Take a pure function as an example. Here we implement the Gaussian Error Linear Unit (GELU) activation function, using its common tanh approximation:

def gelu(x):
    sqrt = bm.sqrt(2 / bm.pi)
    cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    y = x * cdf
    return y
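As a quick sanity check (our own addition, not part of the tutorial), the tanh formula above can be compared with the exact GELU definition, x·Φ(x), where Φ is the standard normal CDF, Φ(x) = 0.5·(1 + erf(x/√2)). The sketch below uses only Python's math module; the grid and the helper names gelu_tanh and gelu_exact are our own choices.

```python
import math


def gelu_tanh(x: float) -> float:
    # Same tanh approximation as gelu() above, written for plain Python floats
    c = math.sqrt(2 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


grid = [i / 100 for i in range(-500, 501)]  # x in [-5, 5], step 0.01
max_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"max |tanh approx - exact| on [-5, 5]: {max_err:.2e}")
```

On this grid the two agree to well under 1e-3, which is why the cheaper tanh form is a common drop-in replacement.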

Let’s first measure the running speed without JIT.

# without JIT

x = bm.random.random(100000)
%timeit gelu(x)
334 µs ± 5.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

After JIT compilation, the same function runs significantly faster (roughly 4.6× in this run).

# with JIT

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
71.9 µs ± 144 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
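Two details matter when timing compiled code in general. The first call of a jitted function includes tracing and compilation, and JAX (on which BrainPy is built) dispatches work asynchronously, so a fair measurement warms the function up once and waits for results before stopping the clock. The sketch below shows this pattern with plain jax.jit and block_until_ready() rather than bm.jit; the array size and repetition count are arbitrary, and the printed timings depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


def gelu(x):
    # Same tanh approximation as above, written with jax.numpy
    sqrt = jnp.sqrt(2 / jnp.pi)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    return x * cdf


gelu_jit = jax.jit(gelu)
x = jnp.linspace(-3.0, 3.0, 100_000)

# First call: trace + compile + run
t0 = time.perf_counter()
gelu_jit(x).block_until_ready()
first_call = time.perf_counter() - t0

# Later calls with the same shape/dtype reuse the compiled executable
n_rep = 100
t0 = time.perf_counter()
for _ in range(n_rep):
    gelu_jit(x).block_until_ready()  # wait for the asynchronous result
steady = (time.perf_counter() - t0) / n_rep

print(f"first call: {first_call * 1e3:.2f} ms, later calls: {steady * 1e6:.1f} µs")
```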

JIT for Objects

Moreover, in BrainPy, JIT compilation can also be applied to class objects, subject to the following minimal constraints:

1. The class object must be a subclass of brainpy.Base.

2. Dynamically changed variables must be labeled as brainpy.math.Variable.

3. Variable updating must be accomplished by in-place operations.

Let’s try a simple example that trains a logistic regression classifier.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        # gradient of the logistic loss with respect to w
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        # in-place update through .value (one gradient-descent step, learning rate 1)
        self.w.value = self.w - u

In this example, the model weights (self.w) are updated on every training iteration, so we mark them as a bm.Variable.
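The update in __call__ is one gradient-descent step with a unit learning rate: u is the gradient of L(w) = Σ_i log(1 + exp(−y_i x_i·w)) with respect to w, namely u = Σ_i (σ(y_i x_i·w) − 1) y_i x_i, which is the standard logistic loss when the labels are y_i ∈ {−1, +1}. The NumPy sketch below checks this against a central finite-difference estimate on small random data; the helper names logistic_loss and update_direction are our own, not part of BrainPy.

```python
import numpy as np


def logistic_loss(w, X, Y):
    # L(w) = sum_i log(1 + exp(-y_i * x_i . w))
    return np.sum(np.log1p(np.exp(-Y * (X @ w))))


def update_direction(w, X, Y):
    # Same expression as LogisticRegression.__call__, written in NumPy
    return np.dot((1.0 / (1.0 + np.exp(-Y * np.dot(X, w))) - 1.0) * Y, X)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
Y = rng.choice([-1.0, 1.0], size=50)  # labels in {-1, +1}
w = rng.normal(size=4)

# Central finite differences along each coordinate direction
eps = 1e-6
numeric = np.array([
    (logistic_loss(w + eps * e, X, Y) - logistic_loss(w - eps * e, X, Y)) / (2 * eps)
    for e in np.eye(4)
])
analytic = update_direction(w, X, Y)
print("max abs difference:", np.max(np.abs(numeric - analytic)))
```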

import time

def benchmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')

num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)

# without JIT

lr1 = LogisticRegression(num_dim)

benchmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 6.86694073677063 s

# with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

# the measured time also includes compiling lr2 on its first call
benchmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 4.27644157409668 s
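For comparison, the same training step can be written without any mutable state in plain JAX: the weights are passed into a jitted function and the updated weights are returned and rebound. We understand this explicit threading of state to be roughly the bookkeeping that bm.Variable lets BrainPy handle for you inside a class, but the sketch below is our own illustration (a much smaller dataset and a hypothetical function name lr_step), not BrainPy's implementation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def lr_step(w, X, Y):
    # Same update as LogisticRegression.__call__, but w is an explicit input/output
    u = jnp.dot(((1.0 / (1.0 + jnp.exp(-Y * jnp.dot(X, w)))) - 1.0) * Y, X)
    return w - u


num_dim, num_points = 10, 1000  # much smaller than the benchmark above
kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.uniform(kx, (num_points, num_dim))
Y = jax.random.uniform(ky, (num_points,))

w = 2.0 * jnp.ones(num_dim) - 1.3  # same initialization as the class
for _ in range(30):
    w = lr_step(w, X, Y)  # rebind the name to the returned weights
w.block_until_ready()
```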

It is worth noting that in the LogisticRegression model above, the dynamically changed variable (the weight w) is marked as a bm.Variable. Otherwise, during the compilation phase, every attribute accessed through self. that is not an instance of bm.Variable would be compiled as a static constant, and later changes to it would not be seen by the compiled code.
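The same effect can be reproduced in plain JAX, on which BrainPy's JIT is built. In the sketch below (our own example; Scaler, apply_attr and apply_arg are hypothetical names), a plain Python attribute is read while the function is traced, so its value is baked into the compiled code, and changing it afterwards has no effect on later calls with the same input shape. Passing the value as an argument makes it dynamic again; for object attributes in BrainPy, the corresponding remedy is to store them in a bm.Variable.

```python
import jax
import jax.numpy as jnp


class Scaler:
    def __init__(self):
        self.scale = 1.0  # plain Python float, NOT a Variable


s = Scaler()


@jax.jit
def apply_attr(x):
    # s.scale is read at trace time and compiled in as a constant
    return x * s.scale


@jax.jit
def apply_arg(x, scale):
    # scale is a traced argument, so new values are respected
    return x * scale


x = jnp.ones(3)
before = apply_attr(x)          # traced and compiled with s.scale == 1.0
s.scale = 2.0
stale = apply_attr(x)           # cached compilation reused: still multiplies by 1.0
fresh = apply_arg(x, s.scale)   # multiplies by 2.0
print(before, stale, fresh)
```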

In-place operators

Variables should be updated in place. The commonly used in-place operations are listed below.

v = bm.Variable(bm.arange(10))

v
Variable(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
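Why in place? This follows from Python name semantics: other code (for example the model that owns a Variable) holds a reference to the Variable object itself, so an update is only visible to that code if it is written into the object. A plain assignment such as v = ... merely rebinds the local name. The toy sketch below illustrates this with a hypothetical Box class standing in for a state container; it is not BrainPy's Variable.

```python
class Box:
    """Hypothetical stand-in for a mutable state container (not bm.Variable)."""

    def __init__(self, value):
        self.value = value


state = Box([0, 1, 2])
owner_ref = state          # e.g. the reference held by the model that owns the state

state.value = [9, 9, 9]    # in-place: writes into the shared object
print(owner_ref.value)     # [9, 9, 9], the owner sees the update

state = Box([5, 5, 5])     # rebinding: only the local name now points elsewhere
print(owner_ref.value)     # still [9, 9, 9]
```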

1. Indexing and slicing (for more details, please refer to Array Objects Indexing):

• Index: v[i] = a

• Slice: v[i:j] = b

• Select specific indices: v[[1, 3]] = c

• Slice all values: v[:] = d, v[...] = e

v[0] = 2.

v
Variable(DeviceArray([2, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
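Underneath, JAX arrays are immutable, so indexed writes are expressed functionally with the .at[...] property, which returns a new array. The sketch below shows plain-JAX counterparts of the indexing forms listed above. We assume, without having checked BrainPy's source, that a Variable performs a similar functional update and then stores the result back into itself.

```python
import jax.numpy as jnp

x = jnp.arange(10)

x = x.at[0].set(2)                  # like v[0] = 2
x = x.at[2:4].set(-1)               # like v[i:j] = b
x = x.at[jnp.array([1, 3])].set(7)  # like v[[1, 3]] = c, with an index array instead of a list
# x now holds 2, 7, -1, 7, 4, 5, 6, 7, 8, 9

y = x.at[:].set(0)                  # like v[:] = d; returns a new all-zero array, x is unchanged
```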

2. Augmented assignment. All augmented assignments are in-place operations, including:

• += (add)

• -= (subtract)

• /= (divide)

• *= (multiply)

• //= (floor divide)

• %= (modulo)

• **= (power)

• &= (and)

• |= (or)

• ^= (xor)

• <<= (left shift)

• >>= (right shift)

v += 1

v
Variable(DeviceArray([ 3,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32))
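The contrast with plain JAX arrays is instructive. A jax.numpy array is immutable, so x += 1 on it amounts to x = x + 1: a new array is created and the name is rebound, while any other reference still sees the old values. The in-place behaviour shown above therefore comes from bm.Variable, not from the underlying array. A minimal sketch in plain JAX:

```python
import jax.numpy as jnp

x = jnp.arange(3)
alias = x            # second reference to the same array

x += 1               # for an immutable JAX array this behaves like x = x + 1
print(alias)         # unchanged: 0, 1, 2
print(x)             # new array: 1, 2, 3

# The functional way to express an indexed augmented update:
z = x.at[0].add(10)  # returns a new array; x itself is not modified
```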

3. .value assignment.

v.value = bm.arange(10)

v
Variable(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

4. .update() method.

v.update(bm.ones_like(v))

v
Variable(DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32))