Control Flows
Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.
Python has two types of control structures:
Selection: used for decisions and branching.
Repetition: used for looping, i.e., repeating a piece of code multiple times.
In this section, we are going to talk about how to build effective control flows with BrainPy and JAX.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
1. Selection
In Python, selection statements are also known as decision control statements or branching statements. A selection statement allows a program to test several conditions and execute instructions based on which condition is true. The commonly used selection statements include:
if-else
nested if
if-elif-else
Non-Variable-based control statements
BrainPy (based on JAX) allows you to write control flow just as in ordinary Python programs, as long as the conditional statement depends on non-Variable instances. For example,
class OddEven(bp.Base):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a
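In the above example, the condition of the if statement relies on a Python scalar (self.type_), which is not an instance of brainpy.math.Variable. Its value is fixed when the model is constructed, so during JIT compilation the Python if is simply evaluated while tracing, and only the selected branch is compiled. In this case, the conditional statements can be arbitrarily complex, and you can write your models with normal Python code. Such models work well with JIT compilation.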
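The same behavior can be seen in plain JAX. In the minimal sketch below (the function name step and the argument mode are illustrative, not part of BrainPy), mode is marked static through jax.jit's static_argnums, so the Python if runs on a concrete integer at trace time and only the chosen branch is compiled. An unknown mode raises the ValueError during tracing, just like OddEven(type_=3).

```python
from functools import partial

import jax
import jax.numpy as jnp


# `mode` is static: jax.jit retraces once per distinct value, and the
# Python `if` below runs on a concrete integer during tracing.
@partial(jax.jit, static_argnums=1)
def step(a, mode):
    if mode == 1:
        return a + 1.0
    elif mode == 2:
        return a - 1.0
    raise ValueError(f'Unknown type: {mode}')


a = jnp.zeros(1)
print(step(a, 1))  # [1.]
print(step(a, 2))  # [-1.]
```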
model = bm.jit(OddEven(type_=1))
model()
Variable([1.], dtype=float32)
model = bm.jit(OddEven(type_=2))
model()
Variable([-1.], dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
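Variable-based control statements
However, if the condition of an if ... else ... statement relies on instances of brainpy.math.Variable, writing Pythonic control flow will cause errors under JIT compilation. During tracing, a Variable holds an abstract value rather than a concrete number, so Python cannot decide which branch to take.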
class OddEvenCauseError(bp.Base):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.rand < 0.5: self.a += 1
        else: self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())
try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: This problem may be caused by several ways:
1. Your if-else conditional statement relies on instances of brainpy.math.Variable.
2. Your if-else conditional statement relies on functional arguments which do not set in "static_argnames" when applying JIT compilation. More details please see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
3. The static variables which set in the "static_argnames" are provided as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].
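To see where this error comes from, here is a minimal plain-JAX sketch of the same logic as OddEvenCauseError (odd_even is a hypothetical name). Inside jax.jit, rand[0] < 0.5 is an abstract tracer, so Python's if cannot turn it into a bool. Depending on the JAX version, the raised error is ConcretizationTypeError or its subclass TracerBoolConversionError; the except clause below catches both.

```python
import jax
import jax.numpy as jnp


@jax.jit
def odd_even(rand, a):
    # Under jit, `rand[0] < 0.5` is an abstract tracer, not a Python bool,
    # so this `if` fails while the function is being traced.
    if rand[0] < 0.5:
        return a + 1.0
    return a - 1.0


try:
    odd_even(jnp.array([0.3]), jnp.zeros(1))
    failed = False
except jax.errors.ConcretizationTypeError as e:
    failed = True
    print(type(e).__name__)
```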
To perform conditional statements on Variable instances, we need structured control flow syntax. Specifically, BrainPy provides several options (based on JAX):
brainpy.math.where: return elements selected element-wise according to a condition.
brainpy.math.ifelse: conditional statements of if-else, or if-elif-else, … for a scalar-typed value.
brainpy.math.where
The where(condition, x, y) function returns elements chosen from x where condition is True, and from y elsewhere. It works on scalars, vectors, and high-dimensional arrays.
a = 1.
bm.where(a < 0, 0., 1.)
JaxArray(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
JaxArray([1., 0., 0., 1., 1.], dtype=float32, weak_type=True)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
JaxArray([[0., 0., 1.],
[1., 1., 0.],
[0., 0., 0.]], dtype=float32, weak_type=True)
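The examples above use scalar branches. As a sketch in plain jax.numpy (which brainpy.math.where mirrors; the data below is made up for illustration), x and y can also be arrays, in which case the selection is made element by element. Note that where is not lazy: both branch arrays are computed in full before the selection happens.

```python
import jax.numpy as jnp

x = jnp.array([0.2, 0.8, 0.4, 0.9])

# Scalar branches are broadcast to the shape of the condition.
labels = jnp.where(x < 0.5, 0.0, 1.0)

# Array branches are selected element by element. Both `x * 10` and `-x`
# are computed in full; `where` only chooses between them afterwards.
mixed = jnp.where(x < 0.5, x * 10, -x)

print(labels)  # [0. 1. 0. 1.]
print(mixed)
```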
The OddEvenCauseError example above can be rewritten with where as:
class OddEvenWhere(bp.Base):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable([-1.], dtype=float32)
brainpy.math.ifelse
Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.
In its simplest case, brainpy.math.ifelse(condition, branches, operands, dyn_vars=None) is equivalent to:
def ifelse(condition, branches, operands, dyn_vars=None):
    true_fun, false_fun = branches
    if condition:
        return true_fun(operands)
    else:
        return false_fun(operands)
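For comparison, the two-branch case maps onto JAX's jax.lax.cond(pred, true_fun, false_fun, *operands). The sketch below writes the OddEvenCauseError logic that way in plain JAX, without BrainPy Variables (odd_even_step is a hypothetical name). The state a is passed in and returned explicitly, because JAX branch functions are pure.

```python
import jax
import jax.numpy as jnp


@jax.jit
def odd_even_step(rand, a):
    # The predicate must be a scalar boolean, hence rand[0] instead of rand.
    return jax.lax.cond(rand[0] < 0.5,
                        lambda x: x + 1.0,  # true branch
                        lambda x: x - 1.0,  # false branch
                        a)                  # operand passed to the chosen branch


print(odd_even_step(jnp.array([0.3]), jnp.zeros(1)))  # [1.]
print(odd_even_step(jnp.array([0.8]), jnp.zeros(1)))  # [-1.]
```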
Based on this function, we can rewrite the above example with ifelse as:
class OddEvenCond(bp.Base):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda _: 1., lambda _: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable([1.], dtype=float32)
For control flow with multiple branches, brainpy.math.ifelse(conditions, branches, operands, dyn_vars=None) can also be used. The multiple-branch case is equivalent to:
def ifelse(conditions, branches, operands, dyn_vars=None):
    # There is one more branch than conditions: the last one is the final `else`.
    *funcs, else_func = branches
    for pred, func in zip(conditions, funcs):  # pred1 -> func1, pred2 -> func2, ...
        if pred:
            return func(operands)
    return else_func(operands)
For example, if you have the following code:
def f(a):
    if a > 10:
        return 1.
    elif a > 5:
        return 2.
    elif a > 0:
        return 3.
    elif a > -5:
        return 4.
    else:
        return 5.
It can be expressed as:
def f(a):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[1., 2., 3., 4., 5.])
f(11.)
DeviceArray(1., dtype=float32, weak_type=True)
f(6.)
DeviceArray(2., dtype=float32, weak_type=True)
f(1.)
DeviceArray(3., dtype=float32, weak_type=True)
f(-4.)
DeviceArray(4., dtype=float32, weak_type=True)
f(-6.)
DeviceArray(5., dtype=float32, weak_type=True)
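One way to express the same multi-branch f in plain JAX is jax.lax.switch driven by the index of the first true condition. This is a sketch of that technique under our own assumptions, not a description of how brainpy.math.ifelse is implemented internally; first_true_index and f_switch are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def first_true_index(conditions):
    # Append an always-true slot for the final `else`, then take the
    # position of the first True (argmax returns the first maximal index).
    preds = jnp.stack([jnp.asarray(c) for c in conditions] + [jnp.asarray(True)])
    return jnp.argmax(preds.astype(jnp.int32))


@jax.jit
def f_switch(a):
    idx = first_true_index([a > 10, a > 5, a > 0, a > -5])
    branches = [lambda: 1.0, lambda: 2.0, lambda: 3.0, lambda: 4.0, lambda: 5.0]
    return jax.lax.switch(idx, branches)


print([float(f_switch(v)) for v in (11., 6., 1., -4., -6.)])  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

The "else" slot guarantees that some index is always selected, so the final branch plays the role of the trailing else in the Python version of f.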
A more complex example is:
def f2(a, x):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[lambda x: x*2,
                               2.,
                               lambda x: x**2 -1,
                               lambda x: x - 4.,
                               5.],
                     operands=x)
f2(11, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(6, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(1, 1.)
DeviceArray(0., dtype=float32, weak_type=True)
f2(-4, 1.)
DeviceArray(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
DeviceArray(5., dtype=float32, weak_type=True)
If instances of brainpy.math.Variable are used in the branch functions, declare them in the dyn_vars argument.
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])
print('a:', a)
print('b:', b)
a: Variable([1., 1.], dtype=float32)
b: Variable([0., 0.], dtype=float32)
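In plain JAX there are no mutable Variables, so the same update is written by passing the whole state through the branches and returning a new state. The sketch below (update is a hypothetical name) is only meant to show, conceptually, the bookkeeping that declaring dyn_vars saves you from writing by hand; it is not BrainPy's implementation.

```python
import jax
import jax.numpy as jnp


def update(pred, a, b):
    # JAX branch functions are pure, so each one receives the whole state
    # (a, b) and returns a new state; the array it leaves alone is passed through.
    def true_f(state):
        a, b = state
        return a + 1.0, b

    def false_f(state):
        a, b = state
        return a, b - 1.0

    return jax.lax.cond(pred, true_f, false_f, (a, b))


a, b = jnp.zeros(2), jnp.ones(2)
a, b = update(True, a, b)
a, b = update(False, a, b)
print('a:', a)  # a: [1. 1.]
print('b:', b)  # b: [0. 0.]
```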
2. Repetition
A repetition statement is used to repeat a group (block) of programming instructions.
In Python, we generally have two loop (repetitive) statements:
for loop: execute a set of statements once for each item in a sequence.
while loop: execute a block of statements repeatedly as long as a given condition is true.
Pythonic loop syntax
JAX also lets you write Pythonic loops: iterate over your sequence data and apply your logic to each item. Such loops are compatible with JIT compilation, but because the loop is unrolled during tracing, they can take a long time to trace and compile. For example,
class LoopSimple(bp.Base):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time

def measure_time(f):
    t0 = time.time()
    r = f()
    t1 = time.time()
    print(f'Result: {r}, Time: {t1 - t0}')
model = bm.jit(LoopSimple())
# First time will trigger compilation
measure_time(model)
Result: [501.74673], Time: 2.7157142162323
# Second running
measure_time(model)
Result: [1003.49347], Time: 0.0
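JAX dispatches computations asynchronously, so wrapping a call in time.time() can stop the clock before the result has actually been computed; this may be one reason a second run can report 0.0. A sketch of a stricter timing helper for plain JAX functions is shown below (measure_time_sync and fast_sum are hypothetical names); it waits on the result with jax.block_until_ready before reading the clock.

```python
import time

import jax
import jax.numpy as jnp


def measure_time_sync(f, *args):
    t0 = time.perf_counter()
    r = f(*args)
    # Wait until the asynchronously dispatched computation has finished.
    jax.block_until_ready(r)
    t1 = time.perf_counter()
    print(f'Result: {r}, Time: {t1 - t0:.6f} s')
    return r, t1 - t0


fast_sum = jax.jit(jnp.sum)
measure_time_sync(fast_sum, jnp.arange(1000.0))  # first call includes compilation
measure_time_sync(fast_sum, jnp.arange(1000.0))  # second call reuses the compiled code
```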
When the model is complex and the iteration is long, the compilation during the first run becomes unbearable. For such cases, you need structured loop syntax.
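A quick way to see why unrolling is expensive is to count the operations JAX records while tracing. In the plain-JAX sketch below (python_loop and fori are hypothetical names, and the data is arbitrary), the Python loop produces one set of operations per iteration, whereas jax.lax.fori_loop traces its body once and wraps it in a single loop primitive, so the top-level program stays small.

```python
import jax
import jax.numpy as jnp

seq = jnp.arange(1000, dtype=jnp.float32)


def python_loop(seq):
    res = jnp.zeros(1)
    for i in range(seq.shape[0]):  # a Python loop: unrolled into 1000 steps at trace time
        res = res + seq[i]
    return res


def fori(seq):
    # A structured loop: the body is traced once and wrapped in a single loop primitive.
    return jax.lax.fori_loop(0, seq.shape[0], lambda i, res: res + seq[i], jnp.zeros(1))


n_unrolled = len(jax.make_jaxpr(python_loop)(seq).jaxpr.eqns)
n_structured = len(jax.make_jaxpr(fori)(seq).jaxpr.eqns)
print(n_unrolled, n_structured)
```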
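JAX provides several important structured loop primitives, including:
jax.lax.fori_loop: loop over a range of integer indices.
jax.lax.scan: loop over the leading axis of arrays while carrying a state and stacking per-step outputs.
jax.lax.while_loop: loop as long as a condition function returns True.
BrainPy also provides its own loop syntax, which is especially suitable for cases where you are using brainpy.math.Variable. Specifically, they are:
brainpy.math.make_loop(): generate a for-loop function over a sequence.
brainpy.math.make_while(): generate a while-loop function.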
In this section, we only talk about how to use our provided loop functions.
brainpy.math.make_loop()
brainpy.math.make_loop() is used to generate a for-loop function when you use Variable.
Suppose that you are using several JaxArrays (grouped as dyn_vars) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars). Sometimes the body function also returns something, and you want to gather the returned values as well. In plain Python, this can be realized as
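def for_loop_function(body_fun, dyn_vars, out_vars, xs):
    ys = []
    for x in xs:
        # 'dyn_vars' and 'out_vars' are updated in place inside 'body_fun()'
        results = body_fun(x)
        # record a snapshot of the current 'out_vars' values, plus the returned results
        ys.append([[v.value for v in out_vars], results])
    return ys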
In BrainPy, you can define this logic using brainpy.math.make_loop():
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)
hist_of_out_vars = loop_fun(xs)
Or,
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)
hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
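The history-gathering pattern above closely resembles JAX's jax.lax.scan, which threads a carry through the loop and stacks a per-step output. The following plain-JAX sketch illustrates that pattern (the body function and data are illustrative); we do not claim it is exactly how make_loop is implemented.

```python
import jax
import jax.numpy as jnp


def body(carry, x):
    res = carry + x   # update the loop state with the current item
    return res, res   # (new carry, value recorded for this step)


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, history = jax.lax.scan(body, jnp.zeros(()), xs)
print(final)    # 10.0
print(history)  # [ 1.  3.  6. 10.]
```

Here final plays the role of the last value of an out_var, and history the role of hist_of_out_vars.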
The LoopSimple example above can be rewritten with brainpy.math.make_loop as:
class LoopStruct(bp.Base):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

        def add(s): self.res += s

        self.loop = bm.make_loop(add, dyn_vars=[self.res])

    def __call__(self):
        self.loop(self.seq)
        return self.res.value
model = bm.jit(LoopStruct())
# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 0.028011560440063477
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
brainpy.math.make_while()
brainpy.math.make_while() is used to generate a while-loop function when you use JaxArray. It supports the following loop logic:
while condition:
    statements
When using brainpy.math.make_while(), condition should be wrapped as a cond_fun function which returns a boolean value, and statements should be packed as a body_fun function which does not support returned values:
while cond_fun(x):
    body_fun(x)
where x is the external input, which is not iterated. All the iterated variables should be marked as JaxArray, and all JaxArrays used in cond_fun and body_fun should be declared as dyn_vars.
Let’s look at an example:
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f(x):
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])
In the above example, we compute the sum from 0 to 10 using two Variables, i and counter.
loop()
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)
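For comparison, here is the same sum written directly with JAX's jax.lax.while_loop. This plain-JAX sketch reuses the names cond_f and body_f for readability, but they are new functions: instead of updating Variables in place, the loop state (i, counter) is passed in as the carry and returned from the body.

```python
import jax
import jax.numpy as jnp


def cond_f(state):
    i, counter = state
    return i < 10


def body_f(state):
    i, counter = state
    i = i + 1.0
    return i, counter + i


# The loop state is passed in and returned explicitly instead of living in Variables.
i_final, counter_final = jax.lax.while_loop(cond_f, body_f,
                                            (jnp.float32(0.0), jnp.float32(0.0)))
print(counter_final, i_final)  # 55.0 10.0
```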