Control Flows for JIT Compilation

@Chaoming Wang

Control flow is central to any program because it determines the order in which the program’s code executes. In Python, control flow is governed by conditional statements, loops, and function calls.

Python has two types of control structures:

  • Selection: used for decisions and branching.

  • Repetition: used for looping, i.e., repeating a piece of code multiple times.

In this section, we discuss how to write effective control flows that work under JIT compilation.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
[Taichi] version 1.7.0, llvm 15.0.1, commit 37b8e80c, win, python 3.11.5
bp.__version__
'2.4.6'

1. Selection

In Python, selection statements are also known as decision-control statements or branching statements. They allow a program to test one or more conditions and execute different instructions depending on which condition is true. Commonly used selection statements include:

  • if-else

  • nested if

  • if-elif-else

Non-Variable-based control statements

When the condition depends only on non-Variable values, BrainPy (built on JAX) lets you write control flows just like in ordinary Python programs. For example:

class OddEven(bp.BrainPyObject):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a

In the above example, the condition of the if statement depends on self.type_, a Python scalar rather than an instance of brainpy.math.Variable. Its value is already known while the function is traced, so Python selects the branch at trace time and only that branch ends up in the compiled code. In this case, the conditional statements can be arbitrarily complex: you can write your models in ordinary Python, and they will work well with JIT compilation.

model = bm.jit(OddEven(type_=1))

model()
Variable(value=Array([1.]), dtype=float32)
model = bm.jit(OddEven(type_=2))

model()
Variable(value=Array([-1.]), dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
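The same pattern can be reproduced in plain JAX. The sketch below assumes jax is installed; odd_even_step is a hypothetical function, and marking type_ as static via static_argnums makes it an ordinary Python value during tracing, much like the type_ attribute of OddEven.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def odd_even_step(a, type_):
    # type_ is static, i.e. a plain Python int during tracing, so this
    # if/elif chain runs in Python and only the chosen branch is compiled
    if type_ == 1:
        return a + 1
    elif type_ == 2:
        return a - 1
    else:
        raise ValueError(f'Unknown type: {type_}')


print(odd_even_step(jnp.zeros(1), 1))  # [1.]
print(odd_even_step(jnp.zeros(1), 2))  # [-1.]

error_message = None
try:
    odd_even_step(jnp.zeros(1), 3)
except ValueError as e:
    error_message = str(e)
print(f'ValueError: {error_message}')  # ValueError: Unknown type: 3
```

Each distinct value of a static argument triggers a separate trace and compilation, which is fine for a small set of model types like this one.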

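Variable-based control statements

However, if the condition of an if ... else ... statement depends on instances of brainpy.math.Variable, Pythonic control flows cause errors under JIT compilation. During tracing, the Variable's value is an abstract tracer rather than a concrete number, so Python cannot convert the comparison into True or False: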

class OddEvenCauseError(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.rand < 0.5:  self.a += 1
        else:  self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())

try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..
The error occurred while tracing the function <unknown> for eval_shape. This value became a tracer due to JAX operations on these lines:

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line D:\codes\projects\brainpy-chaoming0625\brainpy\_src\math\ndarray.py:267:19 (__lt__)

  operation a:bool[1] = lt b c
    from line D:\codes\projects\brainpy-chaoming0625\brainpy\_src\math\ndarray.py:267:19 (__lt__)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
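For comparison, here is a minimal sketch of the same failure in plain JAX (assuming jax is installed; branch_on_value is a hypothetical name): using Python's if on a traced array inside jax.jit raises an error. The code catches jax.errors.ConcretizationTypeError, which TracerBoolConversionError subclasses, so it should also work with older JAX versions that raise the parent class.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_value(x):
    # Inside jit, x is an abstract tracer: `x[0] < 0.5` has no concrete
    # True/False value, so Python's `if` cannot pick a branch
    if x[0] < 0.5:
        return x + 1.
    return x - 1.


caught = None
try:
    branch_on_value(jnp.array([0.3]))
except jax.errors.ConcretizationTypeError as e:
    # recent JAX versions raise TracerBoolConversionError, a subclass
    caught = e
print(type(caught).__name__)
```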

To branch on Variable instances, we need structured control-flow primitives, which express the branching inside the compiled computation instead of deciding it in Python. BrainPy provides several options (based on JAX):

brainpy.math.where

The function where(condition, x, y) returns elements chosen from x where condition is true and from y elsewhere, broadcasting condition, x, and y against each other. It works on scalars, vectors, and higher-dimensional arrays.

a = 1.
bm.where(a < 0, 0., 1.)
Array(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
Array(value=Array([0., 1., 0., 0., 1.]), dtype=float32)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
Array(value=Array([[0., 0., 1.],
                   [0., 0., 0.],
                   [0., 1., 1.]]),
      dtype=float32)
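Because where performs element-wise selection rather than branching, x and y can also be arrays, and each element is picked independently. A short sketch with jax.numpy.where (assuming jax is installed; the inputs are fixed values rather than random ones so the results are reproducible):

```python
import jax.numpy as jnp

a = jnp.array([0.2, 0.7, 0.4, 0.9])

# Scalars 0. and 1. are broadcast to the shape of the condition
binary = jnp.where(a < 0.5, 0., 1.)
print(binary)   # [0. 1. 0. 1.]

# x and y can also be arrays; each element comes from x or y independently
x = jnp.array([1., 2., 3., 4.])
y = jnp.array([10., 20., 30., 40.])
picked = jnp.where(a < 0.5, x, y)
print(picked)   # [ 1. 20.  3. 40.]
```

Keep in mind that x and y are both computed in full before where selects between them, so where suits cheap, side-effect-free alternatives; for expensive or stateful branches, the ifelse function described below is usually a better fit.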

Using where, the OddEvenCauseError example can be rewritten as:

class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable(value=Array([1.]), dtype=float32)

brainpy.math.ifelse

Built on JAX’s control-flow primitive jax.lax.cond, brainpy.math.ifelse provides a more general conditional statement that also supports multiple branches.

In its simplest case, brainpy.math.ifelse(condition, branches, operands) behaves like the following Python function, except that condition may be a traced value under JIT:

def ifelse(condition, branches, operands):
  true_fun, false_fun = branches
  if condition:
    return true_fun(operands)
  else:
    return false_fun(operands)

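Using ifelse, the OddEvenCauseError example can be rewritten as follows. Note that the condition is the scalar self.rand[0] rather than the one-element array self.rand, since jax.lax.cond expects a scalar boolean, and that the branch functions take no arguments because no operands are passed: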

class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda: 1., lambda: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable(value=Array([-1.]), dtype=float32)
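The equivalent selection written directly with jax.lax.cond looks roughly like this. It is a sketch assuming jax is installed; odd_even_cond is a hypothetical name, and because plain JAX has no Variables, the state a is passed in and returned explicitly instead of being updated in place.

```python
import jax
import jax.numpy as jnp


@jax.jit
def odd_even_cond(a, rand):
    # lax.cond needs a scalar boolean predicate, hence rand[0]
    return jax.lax.cond(rand[0] < 0.5,
                        lambda a: a + 1.,   # true branch
                        lambda a: a - 1.,   # false branch
                        a)                  # operand passed to the chosen branch


print(odd_even_cond(jnp.zeros(1), jnp.array([0.3])))  # [1.]
print(odd_even_cond(jnp.zeros(1), jnp.array([0.8])))  # [-1.]
```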

brainpy.math.ifelse(conditions, branches, operands) also makes it easy to write control flows with multiple branches. With a list of conditions, it is equivalent to:

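def ifelse(conditions, branches, operands):
  # conditions = [pred1, pred2, ...]; branches = [func1, func2, ..., funcN],
  # where the last branch (funcN) is the "else" case
  for pred, func in zip(conditions, branches[:-1]):
    if pred:  # the first true condition wins, as in if/elif
      return func(operands)
  return branches[-1](operands)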

For example, if you have the following code:

def f(a):
  if a > 10:
    return 1.
  elif a > 5:
    return 2.
  elif a > 0:
    return 3.
  elif a > -5:
    return 4.
  else:
    return 5.

With ifelse, it can be expressed as follows (a branch may be a constant instead of a function):

def f(a):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[1., 2., 3., 4., 5.])
f(11.)
Array(1., dtype=float32, weak_type=True)
f(6.)
Array(2., dtype=float32, weak_type=True)
f(1.)
Array(3., dtype=float32, weak_type=True)
f(-4.)
Array(4., dtype=float32, weak_type=True)
f(-6.)
Array(5., dtype=float32, weak_type=True)
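One way to express the same first-true-wins selection in plain JAX is to turn the conditions into a branch index and call jax.lax.switch. This is only a sketch, assuming jax is installed; f_switch is a hypothetical name, and this is not necessarily how brainpy.math.ifelse is implemented internally.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_switch(a):
    # if/elif conditions in order, plus a trailing True acting as "else"
    conds = jnp.stack([a > 10, a > 5, a > 0, a > -5, True])
    # argmax returns the index of the first maximum, i.e. the first True
    index = jnp.argmax(conds.astype(jnp.int32))
    branches = [lambda: 1., lambda: 2., lambda: 3., lambda: 4., lambda: 5.]
    return jax.lax.switch(index, branches)


for v in [11., 6., 1., -4., -6.]:
    print(v, float(f_switch(v)))   # 1.0, 2.0, 3.0, 4.0, 5.0 respectively
```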

A more complex ifelse example mixes functions and constants as branches and passes x to the functions through operands:

def f2(a, x):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[lambda x: x*2,
                             2.,
                             lambda x: x**2 -1,
                             lambda x: x - 4.,
                             5.],
                   operands=x)
f2(11, 1.)
Array(2., dtype=float32, weak_type=True)
f2(6, 1.)
Array(2., dtype=float32, weak_type=True)
f2(1, 1.)
Array(0., dtype=float32, weak_type=True)
f2(-4, 1.)
Array(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
Array(5., dtype=float32, weak_type=True)

Branch functions can also update brainpy.math.Variable instances in place. Only the selected branch updates its Variables:

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f():  a.value += 1
def false_f(): b.value -= 1

bm.ifelse(True, [true_f, false_f])
bm.ifelse(False, [true_f, false_f])

print('a:', a)
print('b:', b)
a: Variable(value=Array([1., 1.]), dtype=float32)
b: Variable(value=Array([0., 0.]), dtype=float32)
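In plain JAX, which has no Variables, the same effect is obtained by passing the state through jax.lax.cond and returning it from both branches. A sketch, assuming jax is installed; update is a hypothetical name:

```python
import jax
import jax.numpy as jnp


@jax.jit
def update(pred, state):
    # The state (a, b) goes in as an operand; both branches must return
    # a tuple with the same shapes and dtypes
    return jax.lax.cond(pred,
                        lambda s: (s[0] + 1., s[1]),   # mirrors true_f: a += 1
                        lambda s: (s[0], s[1] - 1.),   # mirrors false_f: b -= 1
                        state)


a, b = update(True, (jnp.zeros(2), jnp.ones(2)))
a, b = update(False, (a, b))
print('a:', a)   # a: [1. 1.]
print('b:', b)   # b: [0. 0.]
```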

2. Repetition

A repetition statement repeats a group (block) of instructions.

In Python, there are generally two kinds of loops:

  • for loop: executes a set of statements once for each item in a sequence.

  • while loop: executes a block of statements repeatedly as long as a given condition holds.

Pythonic loop syntax

JAX also lets you write Pythonic loops: you simply iterate over your sequence data and apply your logic to each item. Such loops are compatible with JIT compilation, but they are unrolled during tracing: every iteration adds its operations to the traced program, so tracing and compilation take longer as the loop gets longer. For example:

class LoopSimple(bp.BrainPyObject):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = bm.Variable(rng.random(1000))
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time

def measure_time(f, return_res=False, verbose=True):
    t0 = time.time()
    r = f()
    t1 = time.time()
    if verbose:
        print(f'Result: {r}, Time: {t1 - t0}')
    return r if return_res else None
model = bm.jit(LoopSimple())

# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 1.4419348239898682
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
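The cost of unrolling can be made visible by comparing the size of the traced programs. The sketch below (assuming jax is installed; unrolled_sum and fori_sum are hypothetical names) uses jax.make_jaxpr to count the top-level equations produced by a Python loop and by jax.lax.fori_loop:

```python
import jax
import jax.numpy as jnp


def unrolled_sum(xs):
    total = jnp.zeros(())
    for i in range(xs.shape[0]):      # Python loop: unrolled while tracing
        total = total + xs[i]
    return total


def fori_sum(xs):
    # Structured loop: the body is traced once, whatever the length
    return jax.lax.fori_loop(0, xs.shape[0],
                             lambda i, total: total + xs[i],
                             jnp.zeros(()))


xs = jnp.arange(100.)
n_unrolled = len(jax.make_jaxpr(unrolled_sum)(xs).jaxpr.eqns)
n_fori = len(jax.make_jaxpr(fori_sum)(xs).jaxpr.eqns)
print(n_unrolled, n_fori)   # n_unrolled grows with len(xs); n_fori stays small
print(float(jax.jit(unrolled_sum)(xs)), float(jax.jit(fori_sum)(xs)))  # 4950.0 4950.0
```

The Python loop contributes equations for every iteration, whereas fori_loop contributes a single loop equation whose body is traced once, which keeps compilation time roughly independent of the loop length.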

When the model is complex and the loop is long, compilation on the first call can become prohibitively slow. In such cases, you need structured loop primitives, whose loop body is traced only once.

JAX provides several important structured loop primitives, including jax.lax.fori_loop, jax.lax.scan, and jax.lax.while_loop.

BrainPy also provides its own loop functions, which are especially suitable when you are using brainpy.math.Variable: brainpy.math.for_loop and brainpy.math.while_loop.

In this section, we only cover the loop functions provided by BrainPy.

brainpy.math.for_loop()

brainpy.math.for_loop() runs a structured for loop whose body function may read and update Variables.

Suppose that you are using several Variables to implement your body function body_fun, that body_fun returns a value at each step, and that you want to gather these returned values. In plain Python, this can be written as:


def for_loop_function(body_fun, xs):
  ys = []
  for x in xs:
    results = body_fun(x)
    ys.append(results)
  return ys

In BrainPy, you can define this logic using brainpy.math.for_loop():

import brainpy.math

# body_fun: the one-step function; operands: the sequence it loops over (xs above)
hist_of_out_vars = brainpy.math.for_loop(body_fun, operands)

The LoopSimple example can be rewritten with brainpy.math.for_loop as:

class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        def add(s):
            self.res += s
            return self.res.value

        return bm.for_loop(body_fun=add, operands=self.seq)
model = bm.jit(LoopStruct())

r = measure_time(model, verbose=False, return_res=True)
r.shape
(1000, 1)

In essence, body_fun defines the one-step rule by which the variables are updated. All values returned by body_fun are gathered as history values and stacked along a new leading axis, which is why r.shape is (1000, 1) above. operands specifies the inputs of body_fun; it is looped over along its first axis.
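brainpy.math.for_loop resembles JAX’s jax.lax.scan, which threads a carried state through the steps and stacks the per-step outputs. Below is a sketch of the running sum from LoopStruct, assuming jax is installed; add and run are hypothetical names, and the state res is passed explicitly because plain JAX has no Variables:

```python
import jax
import jax.numpy as jnp


def add(res, s):
    # One step: update the carried state, then record it in the history
    res = res + s
    return res, res            # (new carry, per-step output)


@jax.jit
def run(xs):
    # scan carries `res` across steps and stacks every per-step output
    # along a new leading axis, like the history returned by for_loop
    return jax.lax.scan(add, jnp.zeros(1), xs)


final, hist = run(jnp.arange(1., 6.))   # xs = [1., 2., 3., 4., 5.]
print(hist.shape)                       # (5, 1)
print(hist[:, 0].tolist())              # [1.0, 3.0, 6.0, 10.0, 15.0]
print(final.tolist())                   # [15.0]
```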

brainpy.math.while_loop()

brainpy.math.while_loop() runs a structured while loop whose body may use Variables. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.while_loop(), the condition should be wrapped in a cond_fun function that returns a boolean value, and the statements should be packed into a body_fun function that receives the values from the previous step and returns the updated values for the current step:


while cond_fun(x):
    x = body_fun(x)

Note the differences between brainpy.math.for_loop and brainpy.math.while_loop:

  1. The values returned by the body function of brainpy.math.for_loop are gathered as history values. In contrast, the values returned by the body function of brainpy.math.while_loop must have the same shape and dtype as its inputs, because they are the updated values passed to the next iteration.

  2. The body function of brainpy.math.for_loop has no explicit requirements on what it returns, whereas the body function of brainpy.math.while_loop must return values with the same structure as those it receives.
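The same constraint exists in jax.lax.while_loop, where it is checked when the loop is traced. In the sketch below (assuming jax is installed), a body that turns an int32 carry into a float32 one is rejected with a TypeError, while a body that keeps the carry type unchanged runs normally:

```python
import jax
import jax.numpy as jnp

caught = None
try:
    # body_fun turns the int32 carry into float32, so input and output types differ
    jax.lax.while_loop(lambda x: x < 10,
                       lambda x: x + 0.5,
                       jnp.int32(0))
except TypeError as e:
    caught = e
print(type(caught).__name__)  # TypeError

# Starting from a float32 carry keeps the type stable, so the loop is accepted
ok = jax.lax.while_loop(lambda x: x < 10,
                        lambda x: x + 0.5,
                        jnp.float32(0.))
print(float(ok))              # 10.0
```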

A concrete example of brainpy.math.while_loop is as follows:

i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f():
    return i[0] < 10

def body_f():
    i.value += 1.
    counter.value += i

bm.while_loop(body_f, cond_f, operands=())

print(counter, i)
Variable(value=Array([55.]), dtype=float32) Variable(value=Array([10.]), dtype=float32)

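In the above example, we compute the sum 1 + 2 + ... + 10 = 55 with two Variables, i and counter. Since operands is empty, cond_f and body_f take no arguments, and body_f updates the Variables in place instead of returning values.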
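For comparison, the same sum written with jax.lax.while_loop carries i and counter explicitly. This is a sketch assuming jax is installed; note that jax.lax.while_loop takes cond_fun before body_fun, the reverse of the argument order in the bm.while_loop call above.

```python
import jax
import jax.numpy as jnp


def cond_f(carry):
    i, counter = carry
    return i[0] < 10            # scalar boolean, as in the BrainPy cond_f


def body_f(carry):
    i, counter = carry
    i = i + 1.
    return i, counter + i       # must return the same structure it receives


i, counter = jax.lax.while_loop(cond_f, body_f, (jnp.zeros(1), jnp.zeros(1)))
print(counter, i)   # [55.] [10.]
```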

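Alternatively, counter can be passed to brainpy.math.while_loop as an operand, in which case cond_f and body_f receive it as an argument and body_f must return its updated value. Because counter starts at 1., the result is 56: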

i = bm.Variable(bm.zeros(1))

def cond_f(counter):
    return i[0] < 10

def body_f(counter):
    i.value += 1.
    return counter + i[0]

bm.while_loop(body_f, cond_f, operands=(1., ))
(Array(56., dtype=float32),)
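Because the number of iterations here is fixed, jax.lax.fori_loop(lower, upper, body_fun, init_val) is a simpler alternative in plain JAX; its body_fun receives the loop index and the carry. A sketch assuming jax is installed, with the counter starting at 1. as above:

```python
import jax
import jax.numpy as jnp

# Adds the loop index i = 1, 2, ..., 10 to a counter that starts at 1.
counter = jax.lax.fori_loop(1, 11, lambda i, c: c + i, jnp.float32(1.))
print(counter)   # 56.0
```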