Control Flows
In this section, we describe how to build structured control flows with the BrainPy data structure `JaxArray`. These control flows include:

- the for-loop syntax,
- the while-loop syntax,
- and the conditional syntax.
```python
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
```
In JAX, control flow must be expressed as structured control flow. The `JaxArray` in BrainPy provides an easier syntax for building such control flows.
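As a minimal pure-JAX illustration (not BrainPy code; the function names are hypothetical), the sketch below shows why plain Python branching is not enough: inside a function compiled with `jax.jit`, a Python `if` on a traced array value raises an error, whereas the structured primitive `jax.lax.cond` traces both branches and selects one at run time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_python_if(x):
    # A Python `if` needs a concrete bool, but under jit `x > 0` is an abstract tracer.
    if x > 0:
        return x
    return jnp.zeros_like(x)


@jax.jit
def relu_structured(x):
    # Structured control flow: both branches are traced, the choice happens at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


try:
    relu_python_if(jnp.float32(3.0))
    python_if_failed = False
except Exception:  # the exact error class differs between JAX versions
    python_if_failed = True

print(python_if_failed)                          # True
print(float(relu_structured(jnp.float32(3.0))))   # 3.0
print(float(relu_structured(jnp.float32(-2.0))))  # 0.0
```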
Note

None of the control flow APIs below are re-implementations of JAX's control flow API. We only guarantee that they are useful and intuitive when used with `brainpy.math.JaxArray`.
brainpy.math.make_loop()

`brainpy.math.make_loop()` is used to generate a for-loop function when you use `JaxArray`.
Suppose that you are using several `JaxArray`s (grouped as `dyn_vars`) to implement your body function `body_fun`, and you want to gather the history values of several of them (grouped as `out_vars`). Sometimes the body function also returns something, and you want to gather the returned values as well. In plain Python, this logic can be written as:
```python
def for_loop_function(body_fun, dyn_vars, out_vars, xs):
    ys = []
    for x in xs:
        # 'dyn_vars' and 'out_vars' are updated in 'body_fun()'
        results = body_fun(x)
        # record the current values of 'out_vars' and the returned values
        ys.append([out_vars, results])
    return ys
```
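The pseudocode above records `out_vars` after each step. Because the variables are mutated in place, a faithful history must store a copy of their values at each step rather than a reference to the same object. Below is a small runnable pure-Python sketch of that semantics; `Var` and `for_loop_reference` are hypothetical names standing in for a BrainPy `Variable` and the loop function, and `dyn_vars` is omitted because plain Python does not need it.

```python
import copy


class Var:
    """Hypothetical stand-in for a mutable BrainPy Variable (illustration only)."""

    def __init__(self, value):
        self.value = value


def for_loop_reference(body_fun, out_vars, xs, has_return=False):
    """Pure-Python model of the for-loop semantics above (not BrainPy code)."""
    hist_out, hist_ret = [], []
    for x in xs:
        ret = body_fun(x)  # updates the variables in place
        # Copy the current values (deepcopy also guards mutable values such as lists);
        # appending `out_vars` itself would record the same, later-mutated objects.
        hist_out.append(copy.deepcopy([v.value for v in out_vars]))
        hist_ret.append(ret)
    return (hist_out, hist_ret) if has_return else hist_out


total = Var(0)


def body(x):
    total.value += x
    return 2 * total.value


hist, rets = for_loop_reference(body, [total], [1, 2, 3], has_return=True)
print(hist)  # [[1], [3], [6]]
print(rets)  # [2, 6, 12]
```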
In BrainPy, you can define this logic using `brainpy.math.make_loop()`:

```python
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)
hist_of_out_vars = loop_fun(xs)
```

Or, when you also want the values returned by the body function:

```python
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)
hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
```
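For readers who know JAX, the two call forms above map naturally onto `jax.lax.scan`. The following is a hand-written pure-JAX analogue, not BrainPy's implementation: the loop carry plays the role of `dyn_vars`, and each step emits the current value of the tracked state (the `out_vars` history) together with the body's return value.

```python
import jax
import jax.numpy as jnp


def body_fun(carry, x):
    # `carry` plays the role of dyn_vars: the state updated at every step.
    carry = carry + x
    ret = carry * 2.0  # a value "returned" by the body function
    # The second element is stacked over time: (history of state, history of returns).
    return carry, (carry, ret)


xs = jnp.arange(1.0, 4.0)  # [1., 2., 3.]
final, (hist_of_out_vars, hist_of_return_vars) = jax.lax.scan(body_fun, jnp.float32(0.0), xs)

print(float(final))                  # 6.0
print(hist_of_out_vars.tolist())     # [1.0, 3.0, 6.0]
print(hist_of_return_vars.tolist())  # [2.0, 6.0, 12.0]
```

Dropping the `ret` element from the emitted tuple corresponds to the `has_return=False` form.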
Let’s implement a recurrent network to illustrate how to use this function.
```python
class RNN(bp.dyn.DynamicalSystem):
    def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
        super(RNN, self).__init__(**kwargs)

        # parameters
        self.n_in = n_in
        self.n_h = n_h
        self.n_out = n_out
        self.n_batch = n_batch
        self.g = g

        # weights
        self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
        self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
        self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
        self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
        self.b_ro = bm.TrainVar(bm.zeros((n_out,)))

        # variables
        self.h = bm.Variable(bm.random.random((n_batch, n_h)))

        # function
        self.predict = bm.make_loop(self.cell,
                                    dyn_vars=self.vars(),
                                    out_vars=self.h,
                                    has_return=True)

    def cell(self, x):
        self.h.value = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
        o = self.h @ self.w_ro + self.b_ro
        return o


rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)
```
In the above `RNN` model, we define a body function `RNN.cell` that will be looped over the input values. The loop function is defined as `self.predict` with `bm.make_loop()`. We care about the history values of `self.h` and the readout value `o`, so we set `out_vars=self.h` and `has_return=True`.
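Written as equations, one step of `RNN.cell` (read directly from the code above) is:

$$
\begin{aligned}
h_t &= \tanh\left(h_{t-1} W_{rr} + x_t W_{ir} + b_{rr}\right),\\
o_t &= h_t W_{ro} + b_{ro},
\end{aligned}
$$

where $h_t \in \mathbb{R}^{n_\text{batch} \times n_h}$ is `self.h`, $x_t \in \mathbb{R}^{n_\text{in}}$ is the $t$-th row of `xs` (broadcast across the batch), and $o_t \in \mathbb{R}^{n_\text{batch} \times n_\text{out}}$ is the returned readout. Stacking $h_t$ and $o_t$ over $T$ steps gives histories of shape $(T, n_\text{batch}, n_h)$ and $(T, n_\text{batch}, n_\text{out})$.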
```python
xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape  # the shape should be (num_time,) + h.shape
```

(100, 5, 100)

```python
hist_o.shape  # the shape should be (num_time,) + o.shape
```

(100, 5, 3)
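To make the shape bookkeeping concrete, here is a pure-JAX sketch of the same recurrence written with `jax.lax.scan`. It assumes the same sizes and weight scaling as the `RNN` class above (with `g=1.0`), but uses JAX's own random generator, so the numbers differ; it is an illustration, not BrainPy's internals. The stacked outputs have the same shapes as `hist_h` and `hist_o`.

```python
import jax
import jax.numpy as jnp

n_in, n_h, n_out, n_batch, n_time = 10, 100, 3, 5, 100
k = jax.random.split(jax.random.PRNGKey(0), 5)

# Weights with the same scaling as in the RNN class above (g = 1.0).
w_ir = jax.random.normal(k[0], (n_in, n_h)) / n_in ** 0.5
w_rr = jax.random.normal(k[1], (n_h, n_h)) / n_h ** 0.5
b_rr = jnp.zeros((n_h,))
w_ro = jax.random.normal(k[2], (n_h, n_out)) / n_h ** 0.5
b_ro = jnp.zeros((n_out,))


def cell(h, x):
    # Same update as RNN.cell: `h` is the carried state, `x` is one row of `xs`.
    h = jnp.tanh(h @ w_rr + x @ w_ir + b_rr)
    o = h @ w_ro + b_ro
    return h, (h, o)  # emit (history of h, history of o)


h0 = jax.random.uniform(k[3], (n_batch, n_h))
xs = jax.random.uniform(k[4], (n_time, n_in))
h_last, (hist_h, hist_o) = jax.lax.scan(cell, h0, xs)

print(hist_h.shape)  # (100, 5, 100)
print(hist_o.shape)  # (100, 5, 3)
```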
If you have multiple input values, wrap them in a container and call the loop function with `loop_fun(xs)`, where `xs` can be a `JaxArray` or a list/tuple/dict of `JaxArray`s. For example:
```python
a = bm.Variable(bm.zeros(10))

def body(x):
    x1, x2 = x  # "x" is a tuple/list of JaxArray
    a.value += (x1 + x2)

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
```
Variable([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32)
```python
a = bm.Variable(bm.zeros(10))

def body(x):  # "x" is a dict of JaxArray
    a.value += x['a'] + x['b']

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
```
Variable([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32)
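The same container behaviour exists in plain JAX: `jax.lax.scan` accepts a pytree (tuple, list, or dict of arrays) as `xs` and slices every leaf along its leading axis. A pure-JAX sketch reproducing the dict example above, where the loop carry replaces the Variable `a` (float inputs are used here, an assumption made to keep the carry dtype fixed):

```python
import jax
import jax.numpy as jnp


def body(a, x):
    # `x` is a dict holding one element of each input array at this step.
    a = a + x['a'] + x['b']
    return a, a  # new carry, and the value recorded for this step


xs = {'a': jnp.arange(10.0), 'b': jnp.ones(10)}
_, hist = jax.lax.scan(body, jnp.zeros(10), xs)

print(hist.shape)           # (10, 10)
print(hist[:, 0].tolist())  # [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0]
```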
`dyn_vars`, `out_vars`, `xs`, and the body function's return values can be arrays organized in containers such as tuple/list/dict. The history outputs preserve the container structure of `out_vars` and of the body function's return values. If `has_return=True`, the loop function returns a tuple `(hist_of_out_vars, hist_of_fun_returns)`; otherwise, it returns only `hist_of_out_vars`. If you are not interested in the history of any variable, set `out_vars=None`.
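As a pure-JAX analogy for this structure preservation (an analogy, not taken from BrainPy's source), `jax.lax.scan` stacks every leaf of whatever pytree the step function emits, so a dict of tracked state comes back as a dict of histories. The names `step`, `v`, and `n` are hypothetical.

```python
import jax
import jax.numpy as jnp


def step(state, x):
    # `state` is a dict; the emitted history keeps the same dict structure.
    state = {'v': state['v'] + x, 'n': state['n'] + 1}
    readout = state['v'] * 10.0
    return state, (state, readout)


init = {'v': jnp.float32(0.0), 'n': jnp.int32(0)}
_, (hist_state, hist_readout) = jax.lax.scan(step, init, jnp.array([1.0, 2.0, 3.0]))

print(sorted(hist_state))        # ['n', 'v']
print(hist_state['v'].tolist())  # [1.0, 3.0, 6.0]
print(hist_state['n'].tolist())  # [1, 2, 3]
print(hist_readout.tolist())     # [10.0, 30.0, 60.0]
```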
brainpy.math.make_while()

`brainpy.math.make_while()` is used to generate a while-loop function when you use `JaxArray`. It supports the following loop logic:
```
while condition:
    statements
```
When using `brainpy.math.make_while()`, the condition should be wrapped in a `cond_fun` function that returns a boolean value, and the statements should be packed into a `body_fun` function, which does not support return values:
```python
while cond_fun(x):
    body_fun(x)
```
where `x` is the external input, which is not iterated. All the iterated variables should be marked as `JaxArray`, and all `JaxArray`s used in `cond_fun` and `body_fun` should be declared in `dyn_vars`.
Let’s look at an example:

```python
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f(x):
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])
```
In the above example, we implement a sum from 0 to 10 using two `JaxArray`s, `i` and `counter`.
```python
loop()
counter
```

Variable([55.], dtype=float32)

```python
i
```

Variable([10.], dtype=float32)
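For comparison, the same counter can be written with JAX's structured primitive `jax.lax.while_loop`. Here the state that `make_while` keeps in `dyn_vars` is passed explicitly as the loop carry, and the body returns the new state instead of mutating variables. This is a pure-JAX sketch, not BrainPy code:

```python
import jax
import jax.numpy as jnp


def cond_f(state):
    i, counter = state
    return i < 10


def body_f(state):
    i, counter = state
    i = i + 1.0
    return i, counter + i  # the body must return a new state with the same structure


i, counter = jax.lax.while_loop(cond_f, body_f, (jnp.float32(0.0), jnp.float32(0.0)))
print(float(i), float(counter))  # 10.0 55.0
```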
brainpy.math.make_cond()

`brainpy.math.make_cond()` is used to generate a conditional function when you use `JaxArray`. It supports the following conditional logic:
```
if pred:
    true statements
else:
    false statements
```
When using `brainpy.math.make_cond()`, the true statements should be wrapped in a `true_fun` function, which implements the logic executed when the predicate is true, and the false statements should be wrapped in a `false_fun` function, which implements the logic executed when it is false. Neither function supports return values:
```python
if pred:
    true_fun(x)
else:
    false_fun(x)
```
All the `JaxArray`s used in `true_fun` and `false_fun` should be declared in the `dyn_vars` argument. `x` receives the external input value.
Let’s try it:

```python
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def true_f(x): a.value += 1

def false_f(x): b.value -= 1

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
```
Here, we have two variables. If the predicate is true, 1 is added to `a`; if it is false, 1 is subtracted from `b`.
```python
cond(pred=True)
a, b
```

(Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32))

```python
cond(True)
a, b
```

(Variable([2., 2.], dtype=float32), Variable([1., 1.], dtype=float32))

```python
cond(False)
a, b
```

(Variable([2., 2.], dtype=float32), Variable([0., 0.], dtype=float32))

```python
cond(False)
a, b
```

(Variable([2., 2.], dtype=float32), Variable([-1., -1.], dtype=float32))
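The same behaviour in pure JAX, as a sketch rather than BrainPy's implementation: `jax.lax.cond` does not mutate outside variables, so the state `(a, b)` is passed in and returned explicitly by both branches, which must produce outputs with identical structure, shapes, and dtypes.

```python
import jax
import jax.numpy as jnp


def true_f(state):
    a, b = state
    return a + 1.0, b


def false_f(state):
    a, b = state
    return a, b - 1.0


state = (jnp.zeros(2), jnp.ones(2))
for pred in [True, True, False, False]:  # the same sequence of calls as above
    state = jax.lax.cond(pred, true_f, false_f, state)

a, b = state
print(a.tolist(), b.tolist())  # [2.0, 2.0] [-1.0, -1.0]
```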
We can also define a conditional case that depends on the external input:

```python
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def true_f(x): a.value += x

def false_f(x): b.value -= x

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
```
```python
cond(True, 10.)
a, b
```

(Variable([10., 10.], dtype=float32), Variable([1., 1.], dtype=float32))

```python
cond(False, 5.)
a, b
```

(Variable([10., 10.], dtype=float32), Variable([-4., -4.], dtype=float32))
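With an external input, a pure-JAX counterpart passes `x` to `jax.lax.cond` as an extra operand. Wrapping the call in `jax.jit` shows that the predicate may itself be a traced value, which a Python `if` could not handle. This is a sketch with a hypothetical function name `cond_step`, not BrainPy code:

```python
import jax
import jax.numpy as jnp


@jax.jit
def cond_step(pred, state, x):
    # `pred` is traced under jit, so the branch is chosen at run time.
    return jax.lax.cond(
        pred,
        lambda s, v: (s[0] + v, s[1]),  # true branch: a += x
        lambda s, v: (s[0], s[1] - v),  # false branch: b -= x
        state, x,
    )


state = (jnp.zeros(2), jnp.ones(2))
state = cond_step(True, state, 10.0)
state = cond_step(False, state, 5.0)
print(state[0].tolist(), state[1].tolist())  # [10.0, 10.0] [-4.0, -4.0]
```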