Gotchas of BrainPy Transformations


import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.4.2'

BrainPy provides a novel concept for object-oriented transformations based on brainpy.math.Variable. However, these transformations come with several gotchas:

1. Variables that will be changed cannot be function arguments

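A Variable that the transformed function modifies should not be passed in as an argument. The following does not work with the object-oriented transformations: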

@bm.jit
def f(a, b):
  a.value = b

a = bm.Variable(bm.ones(1))
b = bm.Variable(bm.ones(1) * 10)
f(a, b)

try:
  assert bm.allclose(a, b)
  print('a equals b.')
except AssertionError:
  print('a is not equal to b.')
a is not equal to b.
a
Variable(value=Array([1.]), dtype=float32)

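The update to a is lost. Variables that a transformed function modifies should be used in the global context instead of being passed in as arguments.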

Instead, this works:

@bm.jit
def f(b):
  a.value = b

a = bm.Variable(bm.ones(1))
b = bm.Variable(bm.ones(1) * 10)
f(b)

a
Variable(value=Array([10.]), dtype=float32)
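For intuition about why the first version fails, here is a deliberately simplified pure-Python toy. It assumes that the transformation hands the function fresh copies of its arguments (standing in for the tracer arrays that JAX substitutes for arguments during tracing), while objects reached through the function's globals are used directly. ToyVariable, toy_jit, set_from_arg and set_global are hypothetical names for illustration only; this is not BrainPy's implementation.

```python
import copy


class ToyVariable:
    """Hypothetical stand-in for a mutable state container such as bm.Variable."""

    def __init__(self, value):
        self.value = value


def toy_jit(fn):
    """Toy transformation (an assumption, not BrainPy's code).

    Arguments are replaced by deep copies before the body runs, loosely
    mimicking how a traced function receives new objects instead of the
    caller's containers. Globals are looked up normally, so they are not copied.
    """
    def wrapper(*args):
        copied_args = [copy.deepcopy(arg) for arg in args]
        return fn(*copied_args)
    return wrapper


@toy_jit
def set_from_arg(a, b):
    a.value = b.value  # mutates the copy of `a`, not the caller's object


@toy_jit
def set_global(b):
    a_global.value = b.value  # mutates the object referenced from globals


a = ToyVariable(1.0)
b = ToyVariable(10.0)
set_from_arg(a, b)
print(a.value)  # 1.0: the update was applied to a copy and then discarded

a_global = ToyVariable(1.0)
set_global(b)
print(a_global.value)  # 10.0: the global object itself was updated
```

The toy only captures the shape of the problem: once the function sees a substitute for its argument, writing to that substitute cannot reach the caller's Variable, whereas a Variable found in the global context can be tracked and updated.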

2. Functions to be transformed are called twice

The core mechanism of any BrainPy transformation is that it first calls the function to automatically find all Variables used in the model, and then calls the function again to compile the model with the found Variables.

Therefore, the Python body of any function that the user creates is called at least twice.

@bm.jit
def f(inp):
  print('calling f ...')
  return inp

@bm.jit
def g(inp):
  print('calling g ...')
  return f(inp)

Taking the above functions as an example, calling g gives:

g(1.)
calling g ...
calling f ...
calling g ...
calling f ...
Array(1., dtype=float32, weak_type=True)

It first calls g, which in turn calls f, to infer all dynamical variables (instances of Variable) used in these two functions. This produces the first two lines, calling g ... and calling f ....

Then, it calls the functions again to compile them, which produces the next two lines, calling g ... and calling f ....

Note that this behavior can give unexpected results for Python-level side effects. For example, suppose we use a global variable to record the number of times a function is called:

num = [0]

@bm.jit
def h(inp):
  num[0] += 1
  return inp
h(1.)
Array(1., dtype=float32, weak_type=True)

Although we called h only once, the counter reads 2, one increment for each pass:

num
[2]
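One way to avoid this is to keep Python-level side effects such as call counters outside the transformed function, in a plain wrapper that runs exactly once per call. The sketch below shows the idea with a pure-Python toy; toy_two_pass and h_body are hypothetical names, and toy_two_pass only mimics the "run the body twice" behavior described above. With BrainPy, the same idea would mean decorating only the inner function with bm.jit and leaving the counting wrapper undecorated.

```python
def toy_two_pass(fn):
    """Hypothetical toy transformation: runs the Python body twice per call
    (a discovery pass and a "compilation" pass). No caching, for simplicity."""
    def wrapper(*args):
        fn(*args)         # pass 1: variable discovery
        return fn(*args)  # pass 2: "compilation"
    return wrapper


num = [0]


@toy_two_pass
def h_body(inp):
    return inp  # no Python side effects inside the transformed body


def h(inp):
    num[0] += 1  # the side effect lives outside the transformation
    return h_body(inp)


h(1.)
print(num)  # [1]
```

Because h itself is never transformed, its body runs once per call no matter how many passes the transformation makes over h_body.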