Control Flows

@Chaoming Wang @Xiaoyu Chen

In this section, we discuss how to build structured control flows with the BrainPy data structure JaxArray. These control flows include

  • the for loop syntax,

  • the while loop syntax,

  • and the condition syntax.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

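In JAX, control flow inside a jit-compiled function should be expressed with structured control-flow primitives (jax.lax.scan, jax.lax.while_loop, jax.lax.cond, and so on) rather than with Python loops and if statements, which are either unrolled at trace time or fail when they depend on traced values. BrainPy's JaxArray provides an easier syntax for building such control flows.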

Note

None of the control-flow functions below are re-implementations of JAX's control-flow API. We only guarantee that they are useful and intuitive when you work with brainpy.math.JaxArray.

brainpy.math.make_loop()

brainpy.math.make_loop() is used to generate a for-loop function when you use JaxArray.

Suppose you use several JaxArrays (grouped as dyn_vars) in your body function body_fun, and you want to record the history of some of them (grouped as out_vars). If the body function also returns values, you may want to collect those as well. In plain Python, this logic can be written as


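import copy

def for_loop_function(body_fun, dyn_vars, out_vars, xs):
  # 'dyn_vars' lists every JaxArray read or written by 'body_fun';
  # plain Python does not need it, but make_loop() does.
  ys = []
  for x in xs:
    # 'dyn_vars' and 'out_vars' are updated in place by 'body_fun()'
    results = body_fun(x)
    # snapshot 'out_vars' so later in-place updates do not overwrite the history
    ys.append([copy.deepcopy(out_vars), results])
  return ys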

In BrainPy, you can define this logic using brainpy.math.make_loop():


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)

hist_of_out_vars = loop_fun(xs)

Or, if body_fun returns values that you also want to collect:


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)

hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
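To make the two calling conventions above concrete, here is a plain-Python model of what the returned loop function hands back. It is a sketch under stated assumptions, not BrainPy's implementation: make_loop_py is a hypothetical helper, a one-element Python list stands in for a mutable JaxArray, and histories are Python lists rather than stacked arrays. The snapshot is taken after each call to body_fun, which matches the outputs shown in the examples later in this section (the first recorded row is the value after the first step).

```python
import copy


def make_loop_py(body_fun, out_vars, has_return=False):
    """Hypothetical plain-Python model of the make_loop() calling convention."""
    def loop_fun(xs):
        hist_of_out_vars = []
        hist_of_return_vars = []
        for x in xs:
            result = body_fun(x)                               # updates state in place
            hist_of_out_vars.append(copy.deepcopy(out_vars))   # snapshot after the step
            hist_of_return_vars.append(result)
        if has_return:
            return hist_of_out_vars, hist_of_return_vars
        return hist_of_out_vars
    return loop_fun


# A one-element list plays the role of a mutable JaxArray.
state = [0.0]


def body_fun(x):
    state[0] += x          # in-place update, like 'a.value += x'
    return 2 * state[0]    # a per-step return value


loop_fun = make_loop_py(body_fun, out_vars=state, has_return=True)
hist_state, hist_ret = loop_fun([1.0, 2.0, 3.0])
print(hist_state)   # [[1.0], [3.0], [6.0]]
print(hist_ret)     # [2.0, 6.0, 12.0]
```

With has_return=False the same helper returns only the first list, mirroring hist_of_out_vars = loop_fun(xs).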

Let’s implement a recurrent network to illustrate how to use this function.

class RNN(bp.DynamicalSystem):
  def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
    super(RNN, self).__init__(**kwargs)

    # parameters
    self.n_in = n_in
    self.n_h = n_h
    self.n_out = n_out
    self.n_batch = n_batch
    self.g = g

    # weights
    self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
    self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
    self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
    self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
    self.b_ro = bm.TrainVar(bm.zeros((n_out,)))

    # variables
    self.h = bm.Variable(bm.random.random((n_batch, n_h)))

    # function
    self.predict = bm.make_loop(self.cell,
                                dyn_vars=self.vars(),
                                out_vars=self.h,
                                has_return=True)

  def cell(self, x):
    self.h.value = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
    o = self.h @ self.w_ro + self.b_ro
    return o


rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)

In the RNN model above, we define the body function RNN.cell, which the for-loop applies to each input value. The loop function self.predict is created with bm.make_loop(). Because we want the history of self.h and of the readout value o, we set out_vars=self.h and has_return=True.

xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape  # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape  # the shape should be (num_time, ) + o.shape
(100, 5, 3)
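The shapes above follow from stacking one entry per time step along a new leading axis. The sketch below unrolls the same recurrence as RNN.cell for a single hidden unit with scalar weights, using an explicit Python loop. The helper rnn_predict_scalar and the weight values are hypothetical and chosen only for illustration; the real model uses matrices and a batch dimension.

```python
import math


def rnn_predict_scalar(xs, h0, w_ir, w_rr, b_rr, w_ro, b_ro):
    """Hypothetical scalar version of RNN.cell, unrolled over the inputs.

    Mirrors h <- tanh(h * w_rr + x * w_ir + b_rr) and o = h * w_ro + b_ro,
    collecting the history of h (the out_vars) and o (the return value).
    """
    h = h0
    hist_h, hist_o = [], []
    for x in xs:                                   # one step per leading-axis entry
        h = math.tanh(h * w_rr + x * w_ir + b_rr)  # same update as RNN.cell
        o = h * w_ro + b_ro                        # readout returned by the cell
        hist_h.append(h)
        hist_o.append(o)
    return hist_h, hist_o


hist_h, hist_o = rnn_predict_scalar([0.5] * 100, h0=0.0, w_ir=1.0, w_rr=0.5,
                                    b_rr=0.0, w_ro=2.0, b_ro=0.1)
print(len(hist_h), len(hist_o))  # 100 100: one entry per time step
```

In the batched model each entry is an array of shape (5, 100) for h and (5, 3) for o, which gives the (100, 5, 100) and (100, 5, 3) histories shown above.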

If you have multiple input values, wrap them in a container and call the loop function as loop_fun(xs), where xs can be a JaxArray or a list/tuple/dict of JaxArrays. The loop iterates over the leading axis of every array in the container. For example:

a = bm.zeros(10)

def body(x):
    x1, x2 = x  # "x" is a tuple/list of JaxArray
    a.value += (x1 + x2)

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32))
a = bm.zeros(10)

def body(x):  # "x" is a dict of JaxArray
    a.value += x['a'] + x['b']

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32))

dyn_vars, out_vars, xs, and the values returned by the body function can all be containers such as tuples, lists, or dicts of arrays. The history outputs preserve the container structure of out_vars and of the body function's returns. If has_return=True, the loop function returns a tuple (hist_of_out_vars, hist_of_fun_returns); otherwise it returns only hist_of_out_vars. If you do not need the history of any variable, set out_vars=None.
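The structure-preserving behavior can be pictured as two steps: slice time step t out of every leaf of the input container, then stack the per-step outputs back into a container of the same shape. The following plain-Python sketch shows this idea under stated assumptions: take, stack, and run_loop are hypothetical helpers, tuples and dicts act as containers, Python lists stand in for arrays, and the out_vars snapshot is folded into the body's return value for brevity. It is not BrainPy's implementation.

```python
def take(xs, t):
    # Index step t from every leaf; tuples and dicts are containers, lists are leaves.
    if isinstance(xs, dict):
        return {k: take(v, t) for k, v in xs.items()}
    if isinstance(xs, tuple):
        return tuple(take(v, t) for v in xs)
    return xs[t]


def stack(steps):
    # Turn a list of per-step containers into one container of per-leaf histories.
    first = steps[0]
    if isinstance(first, dict):
        return {k: stack([s[k] for s in steps]) for k in first}
    if isinstance(first, tuple):
        return tuple(stack([s[i] for s in steps]) for i in range(len(first)))
    return list(steps)


def run_loop(body_fun, xs, length):
    # Slice the inputs per step, run the body, then stack the per-step outputs.
    return stack([body_fun(take(xs, t)) for t in range(length)])


state = {'a': 0.0}


def body(x):
    state['a'] += x['a'] + x['b']   # same update as the dict example above
    return {'a': state['a'], 'double': (2 * state['a'],)}


hist = run_loop(body, {'a': [0.0, 1.0, 2.0], 'b': [1.0, 1.0, 1.0]}, length=3)
print(hist)  # {'a': [1.0, 3.0, 6.0], 'double': ([2.0, 6.0, 12.0],)}
```

The output keeps the dict-with-a-tuple layout of the per-step return, with each leaf replaced by its history.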

brainpy.math.make_while()

brainpy.math.make_while() is used to generate a while-loop function when you use JaxArray. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.make_while(), the condition must be wrapped in a function cond_fun that returns a boolean value, and the statements must be packed into a function body_fun, which cannot return values:


while cond_fun(x):
    body_fun(x)

where x is an external input that stays fixed across iterations. All variables that change during the loop must be JaxArrays, and every JaxArray used in cond_fun or body_fun must be declared in dyn_vars.
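The while-loop semantics described above can be modeled in plain Python as follows. This is a sketch, not BrainPy's implementation: make_while_py is a hypothetical helper, and one-element lists stand in for JaxArrays. It shows that x is passed unchanged to both functions on every iteration and that all progress happens through in-place state updates, since body_fun returns nothing.

```python
def make_while_py(cond_fun, body_fun):
    """Hypothetical plain-Python model of the make_while() semantics."""
    def loop(x=None):
        # 'x' is passed unchanged to both functions on every iteration.
        while cond_fun(x):
            body_fun(x)   # returns nothing; all progress is in-place state
    return loop


# Halve 'value' until it drops to the threshold 'x' or below, counting the steps.
value = [100.0]
steps = [0]


def cond_f(x):
    return value[0] > x


def body_f(x):
    value[0] /= 2
    steps[0] += 1


make_while_py(cond_f, body_f)(1.0)
print(steps[0], value[0])  # 7 0.78125
```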

Let's look at an example:

i = bm.zeros(1)
counter = bm.zeros(1)

def cond_f(x): 
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])

In the example above, we compute the sum of the integers from 1 to 10 using two JaxArrays, i and counter.

loop()
counter
JaxArray(DeviceArray([55.], dtype=float32))
i
JaxArray(DeviceArray([10.], dtype=float32))

brainpy.math.make_cond()

brainpy.math.make_cond() is used to generate a conditional function when you use JaxArray. It supports the following conditional logic:


if pred:
    true statements
else:
    false statements

When using brainpy.math.make_cond(), the true statements must be wrapped in a function true_fun, which runs when the predicate is true, and the false statements must be wrapped in a function false_fun, which runs when the predicate is false. Neither function can return values.


if pred:
    true_fun(x)
else:
    false_fun(x)

All JaxArrays used in true_fun and false_fun must be declared in the dyn_vars argument, and x receives the external input value.
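The observable behavior of the returned conditional function can be modeled in plain Python as below. This is a sketch only: make_cond_py is a hypothetical helper, and one-element lists stand in for JaxArrays. Exactly one branch runs per call, the branches communicate only through in-place updates, and x is forwarded to whichever branch is selected.

```python
def make_cond_py(true_fun, false_fun):
    """Hypothetical plain-Python model of the make_cond() semantics."""
    def cond(pred, x=None):
        # Exactly one branch runs; neither branch returns a value.
        if pred:
            true_fun(x)
        else:
            false_fun(x)
    return cond


# Accumulate positive and negative inputs into separate totals.
pos, neg = [0], [0]


def true_f(x):
    pos[0] += x


def false_f(x):
    neg[0] += x


cond = make_cond_py(true_f, false_f)
for v in [3, -1, 4, -1, 5]:
    cond(v > 0, v)
print(pos[0], neg[0])  # 12 -2
```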

Let's try it:

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += 1

def false_f(x): b.value -= 1

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])

Here we have two arrays. If the predicate is true, 1 is added to a; if it is false, 1 is subtracted from b.

cond(pred=True)

a, b
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(True)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([0., 0.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([-1., -1.], dtype=float32)))

We can also define a condition whose branches depend on the external input x:

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += x

def false_f(x): b.value -= x

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False, 5.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([-4., -4.], dtype=float32)))