Control Flows for JIT compilation
Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.
Python has two types of control structures:
Selection: used for decisions and branching.
Repetition: used for looping, i.e., repeating a piece of code multiple times.
In this section, we are going to talk about how to build effective control flows in the context of JIT compilation.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.4.0'
1. Selection
In Python, selection statements are also known as decision control statements or branching statements. A selection statement allows a program to test several conditions and execute instructions based on which condition is true. The commonly used selection statements include:
if-else
nested if
if-elif-else
Non-Variable-based control statements
BrainPy (based on JAX) allows you to write control flows just like ordinary Python programs, as long as the conditional statement depends on non-Variable instances. For example,
class OddEven(bp.BrainPyObject):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a
In the above example, the condition in the if (statement) syntax depends on a Python scalar (self.type_), which is not an instance of brainpy.math.Variable. Because such a value is a constant when the function is compiled, the conditional statements can be arbitrarily complex: you can write your models with normal Python code, and they work well with JIT compilation.
model = bm.jit(OddEven(type_=1))
model()
Variable(value=DeviceArray([1.]), dtype=float32)
model = bm.jit(OddEven(type_=2))
model()
Variable(value=DeviceArray([-1.]), dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
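Under the hood this is ordinary trace-time Python. The plain-JAX sketch below is an analogy rather than BrainPy's implementation, and the function name odd_even_step is hypothetical: it passes type_ as a static argument via static_argnums, so the if/elif chain runs once while tracing and only the selected branch ends up in the compiled code.

```python
from functools import partial

import jax
import jax.numpy as jnp


# type_ is static: it stays a plain Python int during tracing, so the
# if/elif below is evaluated once and only one branch is compiled.
@partial(jax.jit, static_argnums=0)
def odd_even_step(type_, a):
    if type_ == 1:
        return a + 1.
    elif type_ == 2:
        return a - 1.
    raise ValueError(f'Unknown type: {type_}')


a = jnp.zeros(1)
print(odd_even_step(1, a))
print(odd_even_step(2, a))
```

Each new value of type_ triggers a separate trace and compilation, which is the usual cost of branching on static values.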
Variable-based control statements
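However, if the condition of an if ... else ... statement depends on instances of brainpy.math.Variable, writing Pythonic control flows causes errors under JIT compilation: during compilation the Variable's value is an abstract tracer, so Python cannot evaluate the condition to a concrete True or False.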
class OddEvenCauseError(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        if self.rand < 0.5: self.a += 1
        else: self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())
try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[1])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function <unknown> for eval_shape. This value became a tracer due to JAX operations on these lines:
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line D:\codes\projects\brainpy-chaoming0625\brainpy\_src\math\ndarray.py:233 (__lt__)
operation a:bool[1] = lt b c
from line D:\codes\projects\brainpy-chaoming0625\brainpy\_src\math\ndarray.py:233 (__lt__)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
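The error is not specific to BrainPy. A minimal plain-JAX sketch (the function name is hypothetical) reproduces it by branching on a traced comparison inside jax.jit. Newer JAX versions raise TracerBoolConversionError, which is a subclass of ConcretizationTypeError, so catching the latter covers both.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_traced_value(rand, a):
    # Under jit, `rand < 0.5` is an abstract tracer, not a concrete bool,
    # so Python's `if` cannot decide which branch to take.
    if rand < 0.5:
        return a + 1.
    return a - 1.


try:
    branch_on_traced_value(jnp.array([0.3]), jnp.zeros(1))
except jax.errors.ConcretizationTypeError as e:
    print(type(e).__name__)
```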
To perform conditional statements on Variable instances, we need structural control flow syntax. Specifically, BrainPy provides several options (based on JAX):
brainpy.math.where: returns element-wise conditional selection results.
brainpy.math.ifelse: conditional statements of if-else, if-elif-else, … for a scalar-typed value.
brainpy.math.where
The where(condition, x, y) function returns elements chosen from x or y depending on condition. It works on scalars, vectors, and high-dimensional arrays alike.
a = 1.
bm.where(a < 0, 0., 1.)
DeviceArray(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
Array(value=DeviceArray([0., 0., 0., 0., 0.]), dtype=float32)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
Array(value=DeviceArray([[0., 0., 0.],
[1., 1., 1.],
[0., 1., 1.]]),
dtype=float32)
The OddEvenCauseError example above can be rewritten with where as:
class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable(value=DeviceArray([-1.]), dtype=float32)
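For comparison, here is a plain-JAX sketch of the same update, with the state passed in and returned explicitly instead of held in Variables (the function name is hypothetical). jnp.where accepts the shape-(1,) condition directly; note that both candidate values are computed and then selected element-wise, so neither branch is skipped.

```python
import jax
import jax.numpy as jnp


@jax.jit
def odd_even_where(rand, a):
    # Both candidates (1. and -1.) are materialized; where() picks per element.
    return a + jnp.where(rand < 0.5, 1., -1.)


a = jnp.zeros(1)
print(odd_even_where(jnp.array([0.3]), a))
print(odd_even_where(jnp.array([0.7]), a))
```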
brainpy.math.ifelse
Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.
In its simplest case, brainpy.math.ifelse(condition, branches, operands) is equivalent to:
def ifelse(condition, branches, operands):
    true_fun, false_fun = branches
    if condition:
        return true_fun(operands)
    else:
        return false_fun(operands)
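Using brainpy.math.ifelse, the OddEvenCauseError example above can be rewritten as follows. Because ifelse expects a scalar condition, the code indexes self.rand[0] instead of comparing the whole array: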
class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda _: 1., lambda _: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable(value=DeviceArray([-1.]), dtype=float32)
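Since brainpy.math.ifelse builds on jax.lax.cond, the OddEvenCond update can also be sketched in plain JAX. In this sketch (hypothetical function name, state passed explicitly rather than stored in Variables), lax.cond takes a scalar predicate and two branch functions whose outputs must share the same shape and dtype; only the selected branch runs at execution time unless the call is batched with vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def odd_even_cond(rand, a):
    # lax.cond needs a scalar boolean predicate, hence rand[0] instead of rand.
    return lax.cond(rand[0] < 0.5, lambda x: x + 1., lambda x: x - 1., a)


a = jnp.zeros(1)
print(odd_even_cond(jnp.array([0.3]), a))
print(odd_even_cond(jnp.array([0.7]), a))
```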
For control flows with multiple branches, brainpy.math.ifelse(conditions, branches, operands) works as well. The multi-branch case is equivalent to:
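def ifelse(conditions, branches, operands):
    # Semantically equivalent to:
    #   if pred1: return func1(operands)
    #   elif pred2: return func2(operands)
    #   ...
    #   else: return funcN(operands)
    *funcs, else_func = branches  # N branches for N-1 conditions
    for pred, func in zip(conditions, funcs):
        if pred:
            return func(operands)
    return else_func(operands)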
For example, if you have the following code:
def f(a):
    if a > 10:
        return 1.
    elif a > 5:
        return 2.
    elif a > 0:
        return 3.
    elif a > -5:
        return 4.
    else:
        return 5.
It can be expressed as:
def f(a):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[1., 2., 3., 4., 5.])
f(11.)
DeviceArray(1., dtype=float32, weak_type=True)
f(6.)
DeviceArray(2., dtype=float32, weak_type=True)
f(1.)
DeviceArray(3., dtype=float32, weak_type=True)
f(-4.)
DeviceArray(4., dtype=float32, weak_type=True)
f(-6.)
DeviceArray(5., dtype=float32, weak_type=True)
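One way such a first-true-wins chain could be lowered onto plain JAX is to turn the conditions into a branch index and call jax.lax.switch. This is a sketch under that assumption: the helper multi_ifelse is hypothetical, and BrainPy's own implementation may differ. jnp.argmax returns the position of the first True condition, and when none holds, the index falls through to the final else branch.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multi_ifelse(conditions, branches, operand):
    """Hypothetical helper: the first branch whose condition is True runs;
    the last branch is the `else` case (len(branches) == len(conditions) + 1)."""
    preds = jnp.stack([jnp.asarray(c, dtype=bool) for c in conditions]).astype(jnp.int32)
    # argmax returns the index of the first maximum, i.e. the first True condition;
    # if no condition holds, fall through to the final (else) branch.
    index = jnp.where(jnp.any(preds), jnp.argmax(preds), len(conditions))
    return lax.switch(index, branches, operand)


@jax.jit
def f_switch(a):
    branches = [lambda x: 1., lambda x: 2., lambda x: 3., lambda x: 4., lambda x: 5.]
    return multi_ifelse([a > 10, a > 5, a > 0, a > -5], branches, a)


print(f_switch(6.), f_switch(-6.))  # 2.0 5.0
```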
A more complex example is:
def f2(a, x):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[lambda x: x*2,
                               2.,
                               lambda x: x**2 -1,
                               lambda x: x - 4.,
                               5.],
                     operands=x)
f2(11, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(6, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(1, 1.)
DeviceArray(0., dtype=float32, weak_type=True)
f2(-4, 1.)
DeviceArray(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
DeviceArray(5., dtype=float32, weak_type=True)
Instances of brainpy.math.Variable can also be updated inside branching functions; only the selected branch applies its updates:
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
bm.ifelse(True, [true_f, false_f])
bm.ifelse(False, [true_f, false_f])
print('a:', a)
print('b:', b)
a: Variable(value=DeviceArray([1., 1.]), dtype=float32)
b: Variable(value=DeviceArray([0., 0.]), dtype=float32)
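In plain JAX there are no mutable Variables, so the same effect is usually obtained by threading the state through the branches explicitly: each branch receives the current (a, b) pair and returns the updated pair. The sketch below assumes that functional style; the names true_f, false_f, and step are hypothetical, and it is only an analogy for what the BrainPy code above does with Variables.

```python
import jax
import jax.numpy as jnp
from jax import lax


def true_f(state):
    a, b = state
    return a + 1., b          # update only `a`


def false_f(state):
    a, b = state
    return a, b - 1.          # update only `b`


@jax.jit
def step(pred, state):
    # Both branches must return a pytree with identical structure, shapes and dtypes.
    return lax.cond(pred, true_f, false_f, state)


state = (jnp.zeros(2), jnp.ones(2))
state = step(True, state)
state = step(False, state)
print('a:', state[0])
print('b:', state[1])
```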
2. Repetition
A repetition statement is used to repeat a group (block) of programming instructions.
In Python, we generally have two loops/repetitive statements:
for loop: Execute a set of statements once for each item in a sequence.
while loop: Execute a block of statements repeatedly as long as a given condition is satisfied.
Pythonic loop syntax
JAX also lets you write Pythonic loops: you simply iterate over your sequence data and apply your logic to each item. Such Pythonic loops are compatible with JIT compilation, but because they are unrolled during tracing, tracing and compilation can take a long time. For example,
class LoopSimple(bp.BrainPyObject):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = bm.Variable(rng.random(1000))
        self.res = bm.Variable(bm.zeros(1))
    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time
def measure_time(f, return_res=False, verbose=True):
    t0 = time.time()
    r = f()
    t1 = time.time()
    if verbose:
        print(f'Result: {r}, Time: {t1 - t0}')
    return r if return_res else None
model = bm.jit(LoopSimple())
# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 1.2315161228179932
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
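One way to see why compile time grows is to count the operations JAX records while tracing. In the sketch below (function names are hypothetical), the Python loop is unrolled into separate operations for every element, while jax.lax.fori_loop records a single loop primitive whose body is traced only once, regardless of the sequence length.

```python
import jax
import jax.numpy as jnp
from jax import lax


def python_loop_sum(seq):
    res = jnp.zeros(1)
    for s in seq:            # unrolled: every iteration adds ops to the trace
        res = res + s
    return res


def fori_loop_sum(seq):
    # The body is traced once; the whole loop is a single primitive in the trace.
    return lax.fori_loop(0, seq.shape[0], lambda i, res: res + seq[i], jnp.zeros(1))


seq = jnp.arange(100.)
n_unrolled = len(jax.make_jaxpr(python_loop_sum)(seq).jaxpr.eqns)
n_structured = len(jax.make_jaxpr(fori_loop_sum)(seq).jaxpr.eqns)
print(n_unrolled, n_structured)
```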
When the model is complex and the iteration is long, compilation during the first run can become unbearably slow. For such cases, you need structural loop syntax.
JAX provides several important structural loop primitives, including:
jax.lax.fori_loop
jax.lax.while_loop
jax.lax.scan
BrainPy also provides its own loop syntax, which is especially suitable when you are using brainpy.math.Variable. Specifically, they are:
brainpy.math.for_loop
brainpy.math.while_loop
In this section, we only cover how to use BrainPy's own loop functions.
brainpy.math.for_loop()
brainpy.math.for_loop() is used to generate a for-loop function when you use Variable.
Suppose that you are using several Variables to implement your body function body_fun. Sometimes the body function also returns something, and you want to gather the returned values. In plain Python, this can be realized as
def for_loop_function(body_fun, xs):
    ys = []
    for x in xs:
        results = body_fun(x)
        ys.append(results)
    return ys
In BrainPy, you can define this logic using brainpy.math.for_loop():
import brainpy.math
hist_of_out_vars = brainpy.math.for_loop(body_fun, operands)
The LoopSimple example above can be rewritten with brainpy.math.for_loop as:
class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))
    def __call__(self):
        def add(s):
            self.res += s
            return self.res.value
        return bm.for_loop(body_fun=add, operands=self.seq)
model = bm.jit(LoopStruct())
r = measure_time(model, verbose=False, return_res=True)
r.shape
(1000, 1)
In essence, body_fun defines the one-step rule for updating variables. All returns of body_fun are gathered as history values. operands specifies the inputs of body_fun; they are iterated over along the first axis.
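To make the history semantics concrete, here is a plain-JAX sketch of the LoopStruct example with jax.lax.scan (an analogy, not BrainPy's source). The accumulated value res is the scan carry, playing the role the Variable plays in BrainPy, and the value returned at each step is stacked along a new first axis. The random sequence is a hypothetical jax.random draw, so the numbers differ from the BrainPy output above.

```python
import jax
import jax.numpy as jnp
from jax import lax


seq = jax.random.uniform(jax.random.PRNGKey(123), (1000,))


def add(res, s):
    res = res + s          # one-step update rule (the body function)
    return res, res        # (new carry, value gathered into the history)


final_res, history = lax.scan(add, jnp.zeros(1), seq)
print(history.shape)       # one row per element of seq
```

Because every step returns the running sum, the gathered history is simply the cumulative sum of seq.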
brainpy.math.while_loop()
brainpy.math.while_loop() is used to generate a while-loop function when you use Variable. It supports the following loop logic:
while condition:
    statements
When using brainpy.math.while_loop(), the condition should be wrapped as a cond_fun function that returns a boolean value, and the statements should be packed as a body_fun function that receives the values from the previous step and returns the updated values for the current step:
while cond_fun(x):
    x = body_fun(x)
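For comparison, jax.lax.while_loop implements exactly this cond_fun/body_fun pattern on a single carried value. Note that the JAX argument order is (cond_fun, body_fun, init_val), whereas the BrainPy examples in this section pass body_fun first. The sketch below sums 1 to 10 with the carry (i, counter); the function names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax


def cond_fun(carry):
    i, counter = carry
    return i < 10


def body_fun(carry):
    i, counter = carry
    i = i + 1.
    return i, counter + i   # same structure, shapes and dtypes as the input


i, counter = lax.while_loop(cond_fun, body_fun, (jnp.float32(0.), jnp.float32(0.)))
print(counter, i)  # 55.0 10.0
```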
Note the difference between brainpy.math.for_loop and brainpy.math.while_loop:
The returns of brainpy.math.for_loop are gathered as history values, whereas the returns of the body function passed to brainpy.math.while_loop must have the same shape and type as its inputs, because they represent the updated values.
brainpy.math.for_loop places no explicit requirement on what the body function returns, whereas the body function of brainpy.math.while_loop must return what it receives.
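In plain JAX, the "return what it receives" rule is enforced while the loop is traced: a body function whose output shape differs from its input is rejected with a TypeError before anything runs. The sketch below (hypothetical function names) triggers that check by doubling the length of the carried array on each step.

```python
import jax.numpy as jnp
from jax import lax


def cond_fun(x):
    return x.sum() < 10.


def bad_body_fun(x):
    # Returns shape (2 * n,) for an input of shape (n,): not allowed.
    return jnp.concatenate([x, x + 1.])


try:
    lax.while_loop(cond_fun, bad_body_fun, jnp.zeros(1))
except TypeError as e:
    print('TypeError:', str(e).splitlines()[0])
```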
A concrete example of brainpy.math.while_loop is as follows:
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))
def cond_f():
    return i[0] < 10
def body_f():
    i.value += 1.
    counter.value += i
bm.while_loop(body_f, cond_f, operands=())
print(counter, i)
Variable(value=DeviceArray([55.]), dtype=float32) Variable(value=DeviceArray([10.]), dtype=float32)
In the above example, we compute the sum from 1 to 10 using two Variables, i and counter.
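Alternatively, the running sum can be passed in as an operand: body_f returns the updated value at each step, and since the initial operand is 1., the final result is 56.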
i = bm.Variable(bm.zeros(1))
def cond_f(counter):
    return i[0] < 10
def body_f(counter):
    i.value += 1.
    return counter + i[0]
bm.while_loop(body_f, cond_f, operands=(1., ))
(DeviceArray(56., dtype=float32),)