Control Flows

In this section, we will talk about how to build structured control flows with the ‘jax’ backend. These control flows include

  • for loop syntax,

  • while loop syntax,

  • and condition syntax.

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

In JAX, control flow syntax is not easy to use: users must transform their intuitive Python control flows into structured control flows.

Based on the JaxArray provided in BrainPy, we present a more convenient syntax for building control flows.
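To see why a transformation is needed, consider the "carry" style that structured loops such as jax.lax.scan expect: instead of mutating state in place, the loop state is threaded explicitly through the body function. The following pure-Python sketch (no JAX required; all names are illustrative) shows the pattern:

```python
# A pure-Python sketch of the structured-loop pattern that JAX expects:
# state is passed in and out as an explicit "carry" value instead of
# being mutated in place.

def scan(body_fun, init_carry, xs):
    """Mimic the scan pattern: body_fun(carry, x) -> (new_carry, y)."""
    carry, ys = init_carry, []
    for x in xs:
        carry, y = body_fun(carry, x)
        ys.append(y)
    return carry, ys

# Example: a running sum, written in carry style.
def body(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # also record the history

final, history = scan(body, 0, [1, 2, 3, 4])
```

BrainPy's helpers hide this carry bookkeeping behind JaxArray variables, so that the body function can read and write state in the familiar imperative style.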

make_loop()

brainpy.math.jax.make_loop() is used to generate a for-loop function when you are using JaxArray.

Let’s imagine your requirement: you are using several JaxArrays (grouped as dyn_vars) to implement your body function body_fun, and you want to gather the history values of several of them (grouped as out_vars). Sometimes, your body function returns something, and you also want to gather the returned values. With Python syntax, your requirement is equivalent to


def for_loop_function(body_fun, dyn_vars, out_vars, xs):
  ys = []
  for x in xs:
    # 'dyn_vars' and 'out_vars' 
    # are updated in 'body_fun()'
    results = body_fun(x)
    ys.append([out_vars, results])
  return ys

In BrainPy, using brainpy.math.jax.make_loop() you can define this logic like:


loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=False)

hist_of_out_vars = loop_fun(xs)

Or,


loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=True)

hist_of_out_vars, hist_of_return_vars = loop_fun(xs)

Let’s implement a recurrent network to illustrate how to use this function.

class RNN(bp.DynamicalSystem):
  def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
    super(RNN, self).__init__(**kwargs)

    # parameters
    self.n_in = n_in
    self.n_h = n_h
    self.n_out = n_out
    self.n_batch = n_batch
    self.g = g

    # weights
    self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
    self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
    self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
    self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
    self.b_ro = bm.TrainVar(bm.zeros((n_out,)))

    # variables
    self.h = bm.Variable(bm.random.random((n_batch, n_h)))

    # function
    self.predict = bm.make_loop(self.cell,
                                dyn_vars=self.vars(),
                                out_vars=self.h,
                                has_return=True)

  def cell(self, x):
    self.h[:] = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
    o = self.h @ self.w_ro + self.b_ro
    return o


rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)

In the above RNN model, we define a body function RNN.cell for the later for-loop over input values. The loop function is defined as self.predict with bm.make_loop(). We care about the history values of self.h and the readout value o, so we set out_vars=self.h and has_return=True.

xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape  # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape  # the shape should be (num_time, ) + o.shape
(100, 5, 3)
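The leading axis of each history array is the time axis: one slice per loop step, stacked on top of the per-step shape. A small NumPy sketch (with the same hypothetical shapes as above) illustrates where that extra dimension comes from:

```python
import numpy as np

num_time, n_batch, n_h = 100, 5, 100

# One hidden-state snapshot per time step, each of shape (n_batch, n_h) ...
per_step = [np.zeros((n_batch, n_h)) for _ in range(num_time)]

# ... stacked along a new leading axis, which is what produces the
# (num_time,) + h.shape result shown above.
hist = np.stack(per_step)
```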

If you have multiple input values, you should wrap them as a container and call the loop function with loop_fun(xs), where xs can be a JaxArray or a list/tuple/dict of JaxArrays. For example:

a = bm.zeros(10)

def body(x):
    x1, x2 = x  # "x" is a tuple/list of JaxArray
    a.value += (x1 + x2)

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32))
a = bm.zeros(10)

def body(x):  # "x" is a dict of JaxArray
    a.value += x['a'] + x['b']

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32))

dyn_vars, out_vars, xs and the body function's returns can be arrays with a container structure like tuple/list/dict. The history output values will preserve the container structure of out_vars and of the body function's returns. If has_return=True, the loop function returns a tuple of (hist_of_out_vars, hist_of_fun_returns). If you are not interested in the history of any variable, set out_vars=None; the loop function then returns only the history of the body function's returns.
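The has_return=True behavior can be summarized by a simplified pure-Python sketch (the helper name and the out_var_getter callback are hypothetical; the real make_loop additionally handles JaxArray containers and JIT compilation):

```python
def make_loop_py(body_fun, out_var_getter, has_return=False):
    """Sketch of make_loop: out_var_getter() reads the tracked state
    after each step, mimicking the out_vars history."""
    def loop_fun(xs):
        hist_out, hist_ret = [], []
        for x in xs:
            ret = body_fun(x)
            hist_out.append(out_var_getter())
            hist_ret.append(ret)
        if has_return:
            return hist_out, hist_ret
        return hist_out
    return loop_fun

# Example: track a running sum, and also return each input doubled.
state = {'a': 0}

def body(x):
    state['a'] += x
    return 2 * x

loop = make_loop_py(body, lambda: state['a'], has_return=True)
hist_a, hist_ret = loop([1, 2, 3])
```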

make_while()

brainpy.math.jax.make_while() is used to generate a while-loop function when you are using JaxArray. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.jax.make_while(), the condition should be wrapped as a cond_fun function which returns a boolean value, and the statements should be packed as a body_fun function which does not support return values:


while cond_fun(x):
    body_fun(x)

where x is the external input, which is not iterated over. All iterated variables should be marked as JaxArrays, and all JaxArrays used in cond_fun and body_fun should be declared in the dyn_vars argument.
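The semantics can be sketched in plain Python (the helper name is hypothetical; state lives in variables closed over by the two functions, just as JaxArrays are closed over in the real API):

```python
def make_while_py(cond_fun, body_fun):
    """Sketch of make_while: run body_fun while cond_fun holds."""
    def loop(x=None):
        while cond_fun(x):
            body_fun(x)
    return loop

# Sum the integers 1 + 2 + ... + 10 using closed-over state.
state = {'i': 0.0, 'counter': 0.0}

def cond_f(x):
    return state['i'] < 10

def body_f(x):
    state['i'] += 1.0
    state['counter'] += state['i']

loop = make_while_py(cond_f, body_f)
loop()
```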

Let’s look at an example:

i = bm.zeros(1)
counter = bm.zeros(1)

def cond_f(x): 
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])

In the above, we compute the sum of the integers from 1 to 10, using two JaxArrays, i and counter.

loop()
counter
JaxArray(DeviceArray([55.], dtype=float32))
i
JaxArray(DeviceArray([10.], dtype=float32))

make_cond()

brainpy.math.jax.make_cond() is used to generate a condition function when you are using JaxArray. It supports the following condition logic:


if condition:
    true statements
else:
    false statements

When using brainpy.math.jax.make_cond(), the true statements should be wrapped as a true_fun function which implements the logic of the true branch (no return values), and the false statements should be wrapped as a false_fun function which implements the logic of the false branch (also no return values):


if pred:
    true_fun(x)
else:
    false_fun(x)

All JaxArrays used in true_fun and false_fun should be declared in the dyn_vars argument. x receives the external input value.
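As with make_while, the dispatch semantics can be sketched in plain Python (the helper name is hypothetical; the real make_cond operates on JaxArrays and compiles both branches):

```python
def make_cond_py(true_fun, false_fun):
    """Sketch of make_cond: dispatch on a boolean predicate."""
    def cond(pred, x=None):
        if pred:
            true_fun(x)
        else:
            false_fun(x)
    return cond

# Closed-over state mutated by the two branches.
state = {'a': 0, 'b': 1}

def true_f(x):
    state['a'] += 1

def false_f(x):
    state['b'] -= 1

cond = make_cond_py(true_f, false_f)
cond(True)   # increments a
cond(False)  # decrements b
```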

Let’s try it:

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += 1

def false_f(x): b.value -= 1

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])

Here, we have two tensors. If the condition is true, tensor a is increased by 1; if it is false, tensor b is decreased by 1.

cond(pred=True)

a, b
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(True)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([0., 0.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([-1., -1.], dtype=float32)))

Alternatively, we can define a conditional whose operations depend on the external input.

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += x

def false_f(x): b.value -= x

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False, 5.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([-4., -4.], dtype=float32)))