Compilation

@Chaoming Wang @Xiaoyu Chen

In this section, we discuss code compilation, which can accelerate the running performance of your models.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.math.jit()

JAX provides JIT compilation for pure functions through jax.jit(). In most cases, however, we write models as Python classes. brainpy.math.jit() extends just-in-time compilation to class objects.

JIT compilation for class objects

The constraints for JIT compilation of class objects are:

  • The JIT target must be a subclass of brainpy.Base.

  • Dynamically changed variables must be labeled as brainpy.math.Variable.

  • Updating Variables must be accomplished by in-place operations.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u 
        # The above line can also be expressed as: 
        # 
        #   self.w.value = self.w - u 
        # 
        # or, 
        # 
        #   self.w.update(self.w - u)

In this example, the weight self.w changes dynamically, so it is marked as a Variable. During the update phase __call__(), self.w is updated in place through self.w[:] = .... Alternatively, one can replace the data in the variable with self.w.value = ... or self.w.update(...).
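One plausible reason for the in-place rule (our reading, not a statement from BrainPy): JAX arrays themselves are immutable, so an "in-place" update of a Variable can only mean swapping a new array into the container. The plain-JAX sketch below shows that immutability and the functional `.at[...].set(...)` update that produces a new array.

```python
import jax.numpy as jnp

w = jnp.ones(3)

# JAX arrays are immutable: item assignment raises a TypeError
try:
    w[0] = 5.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# the functional update builds a new array and leaves `w` unchanged
w_new = w.at[0].set(5.0)

print(assignment_failed)  # True
print(w)                  # [1. 1. 1.]
print(w_new)              # [5. 1. 1.]
```

A mutable container such as brainpy.math.Variable can then offer `self.w[:] = ...` by replacing the array it holds; we have not checked BrainPy's exact implementation.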

Now this logistic regression can be accelerated by JIT compilation.

num_dim, num_points = 10, 200000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr = LogisticRegression(10)

%timeit lr(points, labels)
3.11 ms ± 98.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
lr_jit = bm.jit(lr)

%timeit lr_jit(points, labels)
1.54 ms ± 25.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
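For comparison, here is a sketch of the same update rule written for plain jax.jit, where the weight is passed in and returned instead of being stored on `self`. The name `lr_step` and the random data are our own choices; this is not how BrainPy implements LogisticRegression.

```python
import jax
import jax.numpy as jnp

def lr_step(w, X, Y):
    # same update rule as LogisticRegression.__call__, but pure:
    # `w` comes in as an argument and the new weight is returned
    z = -Y * jnp.dot(X, w)
    u = jnp.dot((1.0 / (1.0 + jnp.exp(z)) - 1.0) * Y, X)
    return w - u

lr_step_jit = jax.jit(lr_step)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.uniform(key_x, (1000, 10))
Y = jax.random.uniform(key_y, (1000,))
w = 2.0 * jnp.ones(10) - 1.3

w_next = lr_step_jit(w, X, Y)  # the caller is responsible for keeping the new state
```

The price of this style is that every caller must thread the state through by hand; removing that bookkeeping for class objects is what brainpy.math.jit() is for.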

JIT mechanism

The mechanism behind class-object JIT compilation is that BrainPy automatically transforms your class methods into functions.

brainpy.math.jit() accepts a dyn_vars argument, which specifies the dynamically changed variables. If users do not provide it, BrainPy detects these variables automatically by calling Base.vars() (only variables labeled as Variable are retrieved by Base.vars()). Given “dyn_vars”, BrainPy treats them as function arguments and transforms the class object into a function.
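To make "treat dyn_vars as function arguments" concrete, here is a toy plain-JAX sketch. `Var` and `jit_with_dyn_vars` are hypothetical names invented for this illustration; BrainPy's real implementation differs in its details.

```python
import jax
import jax.numpy as jnp

class Var:
    """Hypothetical mutable holder standing in for brainpy.math.Variable."""
    def __init__(self, value):
        self.value = jnp.asarray(value)

def jit_with_dyn_vars(f, dyn_vars):
    """Hypothetical helper: JIT-compile `f`, treating `dyn_vars` as extra inputs and outputs."""
    def pure(values, *args):
        saved = [v.value for v in dyn_vars]
        for v, val in zip(dyn_vars, values):
            v.value = val  # expose the traced inputs through the holders
        try:
            out = f(*args)
            new_values = [v.value for v in dyn_vars]  # pick up any updates made by `f`
        finally:
            for v, val in zip(dyn_vars, saved):
                v.value = val  # never leave tracers inside the holders
        return out, new_values

    pure_jit = jax.jit(pure)

    def wrapper(*args):
        out, new_values = pure_jit([v.value for v in dyn_vars], *args)
        for v, val in zip(dyn_vars, new_values):
            v.value = val  # write the concrete results back
        return out

    return wrapper

# usage: a counter whose state lives in a Var
count = Var(0.0)

def step(increment):
    count.value = count.value + increment
    return count.value

step_jit = jit_with_dyn_vars(step, [count])
step_jit(1.0)
step_jit(1.0)
print(count.value)  # 2.0
```

Because the current values are passed in as arguments on every call, the compiled code always sees fresh state, and the updated values are returned and written back rather than mutated inside the trace.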

import types

isinstance(lr_jit, types.FunctionType)  # "lr" is a class instance, while "lr_jit" is a function
True

Therefore, the secret of brainpy.math.jit() lies in “dyn_vars”. Whether your target is a class object, a method of a class object, or a pure function, if it uses dynamically changed variables, just pass them to brainpy.math.jit() as “dyn_vars”. BrainPy then handles compilation and acceleration automatically. Let’s illustrate this with several examples.

Example 1: JIT compiled methods in a class

In this example, we JIT compile a method of a class object, where the object's variables are used to compute the result.

class Linear(bp.Base):
    def __init__(self, n_in, n_out):
        super(Linear, self).__init__()
        self.w = bm.random.random((n_in, n_out))
        self.b = bm.zeros(n_out)
    
    def update(self, x):
        return x @ self.w + self.b

x = bm.zeros(10)  # the input data
l = Linear(10, 3)  # the class we need

First, we mark “w” and “b” as dynamically changed variables. Changing “w” or “b” will change the final results.

update1 = bm.jit(
    l.update, dyn_vars=[l.w, l.b]  # make 'w' and 'b' dynamically change
)  

update1(x)  # x is 0., b is 0., therefore y is 0.
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
l.b[:] = 1.  # change b to 1; we expect y to be 1 too

update1(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))

This time, we only mark “w” as a dynamically changed variable. No matter how “b” is modified afterwards, the result does not change, because the value “b” had at compilation time (1.) is treated as a constant.

update2 = bm.jit(
    l.update, dyn_vars=[l.w]  # only make 'w' dynamically change
)

update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
l.b[:] = 2.  # change b to 2

update2(x)  # y is still 1, not 2
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
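Plain jax.jit shows the same freezing behaviour. In this sketch, `LinearSketch` is a hypothetical plain-Python class (not a brainpy.Base) and only `w` is passed as an argument, so `b` is read once during tracing and baked into the compiled code, mirroring update2 above.

```python
import jax
import jax.numpy as jnp

class LinearSketch:
    """Hypothetical plain-Python class holding JAX arrays (not a brainpy.Base)."""
    def __init__(self, n_in, n_out):
        self.w = jnp.ones((n_in, n_out))
        self.b = jnp.zeros(n_out)

l = LinearSketch(10, 3)
x = jnp.ones(10)

# `w` is an argument (dynamic); `l.b` is read during tracing (constant)
update_w_only = jax.jit(lambda x, w: x @ w + l.b)

y1 = update_w_only(x, l.w)     # 10.0 per output: ten ones summed, plus b = 0
l.b = jnp.full(3, 100.0)       # change b after compilation
y2 = update_w_only(x, l.w)     # still 10.0: the old b is baked in
l.w = 2.0 * jnp.ones((10, 3))  # change w, which is an argument
y3 = update_w_only(x, l.w)     # 20.0: w is read on every call
```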

Example 2: JIT compiled functions

Now, we rewrite the “Linear” object above as a plain function.

n_in, n_out = 10, 3

w = bm.random.random((n_in, n_out))
b = bm.zeros(n_out)

def update(x):
    return x @ w + b

If we do not provide dyn_vars, “w” and “b” will be compiled as constant values.

update1 = bm.jit(update)
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
b[:] = 1.  # modifying the value of 'b' will not
           # change the result, because in the
           # jitted function, 'b' is already
           # a constant
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))

Providing “w” and “b” as dyn_vars makes them dynamic again.

update2 = bm.jit(update, dyn_vars=(w, b))
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
b[:] = 2.  # change b to 2; y becomes 2 as well
update2(x)
JaxArray(DeviceArray([2., 2., 2.], dtype=float32))
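The same contrast can be sketched with plain jax.jit: module-level arrays read inside the function are captured as constants when the function is traced, whereas arrays passed as arguments are read on every call. Rebinding the global `b` here stands in for BrainPy's in-place `b[:] = ...`; the function names are our own.

```python
import jax
import jax.numpy as jnp

w = jnp.ones((10, 3))
b = jnp.zeros(3)
x = jnp.ones(10)

@jax.jit
def update_closure(x):
    return x @ w + b  # `w` and `b` become compile-time constants

@jax.jit
def update_args(x, w, b):
    return x @ w + b  # `w` and `b` are traced inputs

before = update_closure(x)         # 10.0 per output
b = jnp.ones(3)                    # rebind the global after compilation
after_closure = update_closure(x)  # still 10.0: cached code, old b
after_args = update_args(x, w, b)  # 11.0: the new b is passed in
```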

Example 3: JIT compiled neural networks

Now, let’s use SGD to train a neural network with JIT acceleration. Here we use the autograd function brainpy.math.grad(), which will be discussed in detail in the next section.

class LinearNet(bp.Base):
    def __init__(self, n_in, n_out):
        super(LinearNet, self).__init__()

        # weights
        self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
        self.b = bm.TrainVar(bm.zeros(n_out))
        self.r = bm.TrainVar(bm.random.random((n_out, 1)))
    
    def update(self, x):
        h = x @ self.w + self.b
        return h @ self.r
    
    def loss(self, x, y):
        predict = self.update(x)
        return bm.mean((predict - y) ** 2)


ln = LinearNet(100, 200)

# provide the variables we want to update
opt = bm.optimizers.SGD(lr=1e-6, train_vars=ln.vars())

# provide the variables that require gradients
f_grad = bm.grad(ln.loss, grad_vars=ln.vars(), return_value=True)


def train(X, Y):
    grads, loss = f_grad(X, Y)
    opt.update(grads)
    return loss

# JIT the train function 
train_jit = bm.jit(train, dyn_vars=ln.vars() + opt.vars())
xs = bm.random.random((1000, 100))
ys = bm.random.random((1000, 1))

for i in range(30):
    loss = train_jit(xs, ys)
    print(f'Train {i}, loss = {loss:.2f}')
Train 0, loss = 6542776.00
Train 1, loss = 3632715.50
Train 2, loss = 2029160.00
Train 3, loss = 1137243.50
Train 4, loss = 638561.00
Train 5, loss = 358928.81
Train 6, loss = 201870.97
Train 7, loss = 113577.14
Train 8, loss = 63915.05
Train 9, loss = 35973.81
Train 10, loss = 20250.74
Train 11, loss = 11402.26
Train 12, loss = 6422.33
Train 13, loss = 3619.55
Train 14, loss = 2042.07
Train 15, loss = 1154.22
Train 16, loss = 654.50
Train 17, loss = 373.25
Train 18, loss = 214.94
Train 19, loss = 125.85
Train 20, loss = 75.70
Train 21, loss = 47.47
Train 22, loss = 31.59
Train 23, loss = 22.65
Train 24, loss = 17.61
Train 25, loss = 14.78
Train 26, loss = 13.19
Train 27, loss = 12.29
Train 28, loss = 11.78
Train 29, loss = 11.50
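For readers coming from JAX, the training loop above corresponds roughly to the functional sketch below. It substitutes jax.value_and_grad and a hand-written SGD step for brainpy.math.grad and bm.optimizers.SGD, and the dictionary layout of `params` is our own choice; all state is passed in and returned explicitly instead of being declared as dyn_vars.

```python
import jax
import jax.numpy as jnp

LR = 1e-6  # same learning rate as the BrainPy example

def init_params(key, n_in, n_out):
    k_w, k_r = jax.random.split(key)
    return {
        "w": jax.random.uniform(k_w, (n_in, n_out)),
        "b": jnp.zeros(n_out),
        "r": jax.random.uniform(k_r, (n_out, 1)),
    }

def loss_fn(params, x, y):
    h = x @ params["w"] + params["b"]
    predict = h @ params["r"]
    return jnp.mean((predict - y) ** 2)

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # plain SGD applied to every leaf of the parameter dictionary
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss

key_p, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(key_p, 100, 200)
xs = jax.random.uniform(key_x, (1000, 100))
ys = jax.random.uniform(key_y, (1000, 1))

losses = []
for i in range(30):
    params, loss = train_step(params, xs, ys)  # the state is threaded explicitly
    losses.append(float(loss))
    print(f"Train {i}, loss = {losses[-1]:.2f}")
```

The exact loss values depend on the random initialization and will not match the BrainPy run above.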

RandomState

We have talked about RandomState in the Variables section. RandomState is also a Variable. Therefore, if the default RandomState (brainpy.math.random.DEFAULT) is used in your function, you should mark it as one of the dyn_vars of the function. Otherwise, it will be treated as a constant, and the jitted function will always return the same value.

def function():
    return bm.random.normal(0, 1, size=(10,))
f1 = bm.jit(function)

f1() == f1()
JaxArray(DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True], dtype=bool))

The correct way to make JIT for this function is:

bm.random.seed(1234)

f2 = bm.jit(function, dyn_vars=bm.random.DEFAULT)

f2() == f2()
JaxArray(DeviceArray([False, False, False, False, False, False, False, False,
                      False, False], dtype=bool))
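In plain JAX there is no hidden global random state: a PRNG key is an explicit value, and a key captured as a constant inside a jitted function yields the same numbers on every call. The sketch below shows the usual explicit key-threading pattern; the function name `sample` is our own.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sample(key):
    # split: one subkey is consumed now, the other key is returned for the next call
    key, subkey = jax.random.split(key)
    return key, jax.random.normal(subkey, (10,))

key = jax.random.PRNGKey(1234)
key, a = sample(key)
key, b = sample(key)
print(bool(jnp.all(a == b)))  # False: a fresh key was threaded through

_, c = sample(jax.random.PRNGKey(1234))
print(bool(jnp.all(a == c)))  # True: same key in, same numbers out
```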

Static arguments

Static arguments are treated as static/constant in the jitted function.

Two kinds of arguments must be marked as static: numerical arguments used in Python conditional statements (bool values, or values that produce bool values, such as x < 3) and strings. Otherwise, an error will be raised.

@bm.jit
def f(x):
  if x < 3:  # this will cause an error
    return 3. * x ** 2
  else:
    return -4 * x
try:
    f(1.)
except Exception as e:
    print(type(e), e)
<class 'jax._src.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at <ipython-input-23-fb33d5d11189>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In short, arguments that produce boolean values in Python control flow must be declared as static arguments. In brainpy.math.jit(), we can specify static arguments by name through static_argnames.

def f(x):
  if x < 3:  # no error now: 'x' is static, so this is a concrete Python comparison
    return 3. * x ** 2
  else:
    return -4 * x

f_jit = bm.jit(f, static_argnames=('x',))
f_jit(x=1.)
DeviceArray(3., dtype=float32, weak_type=True)

However, note that calling the jitted function with different values for these static arguments triggers recompilation. Therefore, declaring arguments as static is suitable in the following situations:

  1. Boolean arguments.

  2. Arguments that only have several possible values.
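The recompilation cost can be observed with plain jax.jit, as a sketch: a Python side effect inside the function runs only while JAX traces it, so counting it counts (re)compilations. The function `power` and the counter `trace_count` are hypothetical; `mode` also shows a string passed as a static argument.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def power(x, n, mode):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    if mode == "sum":  # a string must be static to be compared in Python
        return sum(x ** k for k in range(n))  # `n` must be concrete for range()
    return x ** n

power_jit = jax.jit(power, static_argnames=("n", "mode"))

x = jnp.arange(3.0)
power_jit(x, n=2, mode="pow")
power_jit(x + 1.0, n=2, mode="pow")  # same static values and shapes: cached code reused
power_jit(x, n=3, mode="pow")        # new value of `n`: retraced and recompiled
power_jit(x, n=3, mode="sum")        # new value of `mode`: retraced again
print(trace_count)  # 3
```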

If an argument takes many different values, it is better not to declare it as static.
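When an argument takes many values, one alternative in plain JAX (not part of the BrainPy API discussed here) is to keep it traced and express the branch with jnp.where or jax.lax.cond, which accept a traced condition. A sketch using the function from the error example above; `jnp.where` evaluates both branches element-wise, while `lax.cond` needs a scalar predicate.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_where(x):
    # both branches are computed; jnp.where selects element-wise on the traced condition
    return jnp.where(x < 3, 3.0 * x ** 2, -4.0 * x)

@jax.jit
def f_cond(x):
    # lax.cond takes a scalar traced predicate and two branch functions
    return jax.lax.cond(x < 3, lambda v: 3.0 * v ** 2, lambda v: -4.0 * v, x)

print(f_where(1.0), f_cond(5.0))  # 3.0 -20.0
```

Neither function is recompiled when the value of `x` changes, only when its shape or dtype does.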

For more information, please refer to the jax.jit API.