Variable and BrainPyObject

@Chaoming Wang @Xiaoyu Chen

In BrainPy, JIT compilation of class objects relies on brainpy.math.Variable, and brainpy.math.BrainPyObject is the container that holds Variables inside a class object.

In this section, we are going to understand:

  • What is a Variable?

  • What is a BrainPyObject?

  • How to update a Variable?

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bp.__version__
'2.4.0'

brainpy.math.Variable

brainpy.math.Variable is a pointer to a JAX array: it stores an array as its value, and the data it holds can be changed during BrainPy's object-oriented JIT compilation. Labeling an array as a Variable declares it a dynamical variable whose value changes over time.

Arrays that are not marked as Variables are treated as static data by JIT compilation. Modifying them inside a compiled function either has no effect or raises an error.
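The difference between Variables and static arrays can be pictured with a plain-Python analogy. This is a hypothetical toy, not how BrainPy or JAX actually compile code: PointerBox stands in for a Variable, and compile_step mimics tracing by copying ordinary data into the "compiled" function once, while reading the box's .value on every call.

```python
import copy


class PointerBox:
    """Hypothetical stand-in for brainpy.math.Variable: a fixed object
    whose .value attribute points at the current data."""

    def __init__(self, value):
        self.value = value


def compile_step(static_data, box):
    """Toy 'compiler': plain data is frozen at compile time, while the
    box is kept by reference and read on every call."""
    frozen = copy.deepcopy(static_data)

    def step():
        return sum(frozen) + sum(box.value)

    return step


static = [1.0, 1.0]
box = PointerBox([1.0, 1.0])
step = compile_step(static, box)

static[0] = 100.0       # ignored: the compiled step holds a frozen copy
box.value = [5.0, 5.0]  # visible: the step reads through the box
print(step())           # 2.0 + 10.0 -> 12.0
```

In this analogy, only data reached through the box behaves like dynamical state; everything else is baked in when the function is built.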

  • Creating a Variable

Passing an array to brainpy.math.Variable creates a Variable. For example:

b1 = bm.random.random(5)
b1
Array(value=DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818  , 0.86386406],            dtype=float32), dtype=float32)
b2 = bm.Variable(b1)
b2
Variable(value=DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818  , 0.86386406],            dtype=float32), dtype=float32)
  • Accessing the value in a Variable

The data in a Variable can be obtained through .value.

b2.value
DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818  , 0.86386406],            dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
  • Supported operations on Variables

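Variables support almost all operations available on a JAX array. As the outputs below show, the result of such an operation is a new brainpy.math.Array; the Variable itself is not modified.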

b2 + 1.
Array(value=DeviceArray([1.5946547, 1.5306349, 1.1628337, 1.857818 , 1.8638641]), dtype=float32)
b2 ** 2
Array(value=DeviceArray([0.35361418, 0.28157339, 0.02651481, 0.7358517 , 0.7462611 ],            dtype=float32), dtype=float32)
bm.floor(b2)
Array(value=DeviceArray([0., 0., 0., 0., 0.]), dtype=float32)

brainpy.math.BrainPyObject

BrainPyObject can be viewed as a base class that holds all the Variables of a class.

Its .vars() method pulls out every Variable defined in the class.

# for example

hh = bp.neurons.HH(1)

hh.vars().keys()  # a HH model has 6 variables
dict_keys(['HH0.V', 'HH0.h', 'HH0.input', 'HH0.m', 'HH0.n', 'HH0.spike'])
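How a container might collect its Variables can be sketched in plain Python. This is a hypothetical toy, not BrainPy's implementation: Var and Obj stand in for Variable and BrainPyObject, and the "<ClassName><index>.<attribute>" naming is an assumption modeled on the 'HH0.V' keys shown above.

```python
class Var:
    """Hypothetical stand-in for brainpy.math.Variable."""

    def __init__(self, value):
        self.value = value


class Obj:
    """Hypothetical stand-in for brainpy.math.BrainPyObject."""

    _counts = {}  # per-class instance counters used to build names like "HH0"

    def __init__(self):
        cls_name = type(self).__name__
        idx = Obj._counts.get(cls_name, 0)
        Obj._counts[cls_name] = idx + 1
        self.name = f"{cls_name}{idx}"

    def vars(self):
        """Return {'<name>.<attribute>': Var}, recursing into child objects."""
        found = {}
        for attr, val in self.__dict__.items():
            if isinstance(val, Var):
                found[f"{self.name}.{attr}"] = val
            elif isinstance(val, Obj):
                found.update(val.vars())
        return found


class Neuron(Obj):
    def __init__(self):
        super().__init__()
        self.V = Var([0.0])
        self.spike = Var([False])
        self.tau = 10.0  # plain attribute: not collected


class Net(Obj):
    def __init__(self):
        super().__init__()
        self.pop = Neuron()  # child object: its Vars are collected too
        self.w = Var([0.5])


net = Net()
print(sorted(net.vars()))  # ['Net0.w', 'Neuron0.V', 'Neuron0.spike']
```

Only attributes that are Var instances are reported; the plain float tau is skipped, which mirrors why only labeled Variables show up in hh.vars().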

Subtypes of Variable

Customizing Variable types

Sometimes, different variables play different roles in a computation. To distinguish them, we can subclass brainpy.math.Variable and define our own variable classes.

class YourVar(bm.Variable):
    pass
class YourModel(bm.BrainPyObject):
    def __init__(self):
        super().__init__()

        self.a = bm.Variable(bm.ones(1))
        self.b = YourVar(bm.zeros(1))
model = YourModel()
model.vars().subset(YourVar)
{'YourModel0.b': YourVar(value=DeviceArray([0.]), dtype=float32)}
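Type-based selection can be sketched as an isinstance filter over the collected dictionary. This is an assumption about how subset behaves, written with hypothetical stand-in classes rather than BrainPy's code; under that assumption, asking for the base Variable type also returns every subtype, while asking for a subtype returns only its own instances.

```python
class Variable:
    """Hypothetical stand-in for brainpy.math.Variable."""

    def __init__(self, value):
        self.value = value


class TrainVar(Variable):  # hypothetical stand-in for a built-in subtype
    pass


class YourVar(Variable):  # user-defined subtype, as in the example above
    pass


def subset(collection, var_type):
    """Keep only entries that are instances of var_type or its subclasses."""
    return {k: v for k, v in collection.items() if isinstance(v, var_type)}


collected = {
    "Model0.a": Variable([1.0]),
    "Model0.b": YourVar([0.0]),
    "Model0.w": TrainVar([0.5]),
}
print(sorted(subset(collected, YourVar)))   # ['Model0.b']
print(sorted(subset(collected, Variable)))  # ['Model0.a', 'Model0.b', 'Model0.w']
```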

In BrainPy, we have provided many Variable subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.RandomState.

1. TrainVar

brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Trainable variables are usually the ones whose gradients are computed and whose values are then updated accordingly. However, users can also use TrainVar for other purposes.

b = bm.random.rand(4)

b
Array(value=DeviceArray([0.27101624, 0.7630037 , 0.71727633, 0.73568165]), dtype=float32)
bm.TrainVar(b)
TrainVar(value=DeviceArray([0.27101624, 0.7630037 , 0.71727633, 0.73568165]), dtype=float32)
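One typical use of the TrainVar label is letting an optimizer pick out the values it should update. The sketch below is a hypothetical plain-Python toy (the Variable and TrainVar classes and sgd_step are stand-ins, not BrainPy's optimizer API): it minimizes loss(w) = (w - 3)^2 with a hand-written gradient and leaves ordinary Variables untouched.

```python
class Variable:
    """Hypothetical stand-in for brainpy.math.Variable."""

    def __init__(self, value):
        self.value = value


class TrainVar(Variable):
    """Hypothetical stand-in for brainpy.math.TrainVar."""


def sgd_step(variables, grads, lr=0.1):
    """Update only TrainVar entries in place: value <- value - lr * grad."""
    for name, var in variables.items():
        if isinstance(var, TrainVar):
            var.value = var.value - lr * grads[name]


# loss(w) = (w - 3)^2, so d loss / d w = 2 * (w - 3)
params = {"w": TrainVar(0.0), "counter": Variable(7.0)}
for _ in range(50):
    grads = {"w": 2.0 * (params["w"].value - 3.0)}
    sgd_step(params, grads)

print(params["w"].value)        # approximately 3.0
print(params["counter"].value)  # 7.0, never touched by the optimizer
```

Each step shrinks the error w - 3 by a factor of 0.8, so after 50 steps w is within about 1e-4 of the minimum.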

2. Parameter

brainpy.math.Parameter labels a parameter that changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can be easily retrieved with the Collector.subset method.

b = bm.random.rand(1)

b
Array(value=DeviceArray([0.00502753]), dtype=float32)
bm.Parameter(b)
Parameter(value=DeviceArray([0.00502753]), dtype=float32)

3. RandomState

brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. A RandomState must store a dynamically changing key (see the JAX design of random numbers). Each time a RandomState performs a random sampling, its “key” changes. Therefore, it is appropriate to label a RandomState as a Variable.

state = bm.random.RandomState(1234)

state
RandomState(key=([   0, 1234], dtype=uint32))
# perform a "random" sampling 
state.random(1)

state  # the value changed
RandomState(key=([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling 
state.sample(1)

state  # the value changed too
RandomState(key=([1076515368, 3893328283], dtype=uint32))
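The "key changes after every draw" behavior can be sketched with a toy stateful generator. Everything here is hypothetical: split is a hash-based stand-in, not JAX's key-splitting algorithm, and ToyRandomState is not BrainPy's RandomState. The point is the pattern: each draw splits the stored key, keeps one half as the new state, and uses the other half to produce the number.

```python
import hashlib
import random as pyrandom


def split(key):
    """Toy key split (NOT JAX's algorithm): hash the key and cut the
    digest into two new (uint32, uint32) keys."""
    digest = hashlib.sha256(repr(key).encode()).digest()
    words = [int.from_bytes(digest[i:i + 4], "little") for i in range(0, 16, 4)]
    return (words[0], words[1]), (words[2], words[3])


class ToyRandomState:
    """Hypothetical stand-in for RandomState: it stores a key and
    replaces that key every time it draws a number."""

    def __init__(self, seed):
        self.key = (0, seed)  # same (0, seed) layout as the outputs above

    def random(self):
        self.key, subkey = split(self.key)  # advance the stored key
        rng = pyrandom.Random(subkey[0] * 2**32 + subkey[1])
        return rng.random()  # draw from the fresh sub-key


state = ToyRandomState(1234)
k0 = state.key
x1 = state.random()
k1 = state.key
x2 = state.random()
print(k0 != k1 and k1 != state.key)         # True: the key changed after each draw
print(ToyRandomState(1234).random() == x1)  # True: same seed, same first draw
```

Because the key is mutable state that must survive across calls, it has to live in something like a Variable for a compiled function to keep advancing it.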

Every instance of RandomState can create a new key from the current key with .split_key().

state.split_key()
DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)
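Why parallel workers need separate keys can be shown with a toy. The split_keys and worker functions below are hypothetical plain-Python stand-ins (hash-based, not JAX's algorithm, and not BrainPy's split_keys): workers seeded from distinct sub-keys produce distinct streams, whereas workers that all reuse the parent key produce identical numbers.

```python
import hashlib
import random as pyrandom


def split_keys(key, n):
    """Toy split of one key into n sub-keys (NOT JAX's algorithm):
    hash the parent key together with each sub-key index."""
    keys = []
    for i in range(n):
        digest = hashlib.sha256(f"{key}-{i}".encode()).digest()
        keys.append((int.from_bytes(digest[:4], "little"),
                     int.from_bytes(digest[4:8], "little")))
    return keys


def worker(subkey, size=3):
    """A hypothetical parallel worker that draws from its own key."""
    rng = pyrandom.Random(subkey[0] * 2**32 + subkey[1])
    return tuple(rng.random() for _ in range(size))


parent = (0, 1234)
subkeys = split_keys(parent, 4)
distinct = [worker(k) for k in subkeys]
shared = [worker(parent) for _ in range(4)]

print(len(set(distinct)))  # 4: every worker has its own stream
print(len(set(shared)))    # 1: sharing one key gives identical streams
```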

There is a default RandomState in brainpy.math.random module: DEFAULT.

bm.random.DEFAULT
RandomState(key=([1573595401,  117587871], dtype=uint32))

The module-level random functions, such as randint(), rand(), and shuffle(), use this DEFAULT state. To reseed the default RandomState, use the seed() function.

bm.random.seed(654321)

bm.random.DEFAULT
RandomState(key=([     0, 654321], dtype=uint32))
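The delegation described above can be sketched in plain Python. The names DEFAULT, rand(), and seed() mirror this section, but the code is a hypothetical toy backed by Python's own PRNG. Resetting the existing object in place, rather than creating a new one, is an assumption chosen so that any code already holding a reference to DEFAULT sees the new seed.

```python
import random as pyrandom


class ToyRandomState:
    """Hypothetical stand-in for a RandomState backed by Python's PRNG."""

    def __init__(self, seed):
        self.seed(seed)

    def seed(self, seed):
        # Reset the generator in place; the object identity is unchanged.
        self._rng = pyrandom.Random(seed)

    def rand(self):
        return self._rng.random()


DEFAULT = ToyRandomState(0)  # hypothetical module-level default state


def rand():
    """Module-level function that delegates to the shared DEFAULT state."""
    return DEFAULT.rand()


def seed(value):
    """Reseed DEFAULT in place instead of replacing the object."""
    DEFAULT.seed(value)


default_before = DEFAULT
seed(654321)
first = rand()
seed(654321)
again = rand()
print(first == again)             # True: same seed, same draw
print(DEFAULT is default_before)  # True: still the same object
```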

How to update a Variable?

In BrainPy, transformations such as JIT usually need to update variables or arrays in place. An in-place update changes the data stored in a variable without changing the reference to the variable itself.

For example, here we have a variable a.

a = bm.Variable(bm.zeros(5))

The ids of the variable and the data stored in the variable are:

id_of_a = id(a)
id_of_data = id(a.value)

assert id_of_a != id_of_data

print('id(a)       = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a)       =  2151307205504
id(a.value) =  2151397524800

An in-place update (here we use [:]) does not change the object that a refers to, but it does change the data stored in it:

a[:] = 1.

print('id(a)       = ', id(a))
print('id(a.value) = ', id(a.value))
id(a)       =  2151307205504
id(a.value) =  2151399078480
print('(id(a) == id_of_a)          =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a)          = True
(id(a.value) == id_of_data) = False
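The pattern above (same id(a), new id(a.value)) follows from JAX arrays being immutable: an in-place update cannot edit the stored array, so the Variable presumably points its .value at a newly built array, which is consistent with the output shown. The hypothetical toy below reproduces that pattern in plain Python, using an immutable tuple in place of a JAX array; it is an illustration, not BrainPy's __setitem__.

```python
class ToyVariable:
    """Hypothetical stand-in for brainpy.math.Variable. The stored data is
    an immutable tuple, playing the role of an immutable JAX array."""

    def __init__(self, data):
        self.value = tuple(data)

    def __setitem__(self, index, new):
        # The tuple cannot be edited, so build a new one with the update
        # applied and point .value at it. The wrapper object itself stays.
        items = list(self.value)
        if isinstance(index, slice):
            n = len(range(*index.indices(len(items))))
            items[index] = [new] * n
        else:
            items[index] = new
        self.value = tuple(items)


a = ToyVariable([0.0] * 5)
wrapper, old_data = a, a.value
a[:] = 1.0
print(a is wrapper)         # True: a still refers to the same wrapper
print(a.value is old_data)  # False: the wrapper now holds new data
print(a.value)              # (1.0, 1.0, 1.0, 1.0, 1.0)
```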

However, if you assign data without an in-place operator, the name a is rebound to a different object, so its id changes. This causes serious errors when using BrainPy transformations, because the new object is no longer the original Variable.

a = 10.

print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) =  2151305428432
(id(a) == id_of_a) = False
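Why rebinding breaks things can be sketched with a toy "transformation" that records the object it was given. This is a hypothetical plain-Python analogy (Box and make_reader are illustrative names, not BrainPy APIs): in-place updates are visible through the recorded object, but after a = 10. the transformation keeps using the old object and never sees the new value.

```python
class Box:
    """Hypothetical stand-in for a Variable."""

    def __init__(self, value):
        self.value = value


def make_reader(var):
    """Toy 'transformation': it records the object it was given and
    reads that object's .value every time it runs."""
    def read():
        return var.value
    return read


a = Box(0.0)
read = make_reader(a)

a.value = 1.0  # in-place update: the transformation sees it
print(read())  # 1.0

a = 10.0       # rebinding: the name a now points at a different object
print(read())  # still 1.0; the new value is invisible to the transformation
```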
Note: the following in-place operators are not limited to brainpy.math.Variable and its subclasses; they also apply to brainpy.math.Array.

Here, we list several commonly used in-place operators.

v = bm.Variable(bm.arange(10))
old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'

1. Indexing and slicing

Indexing and slicing are the two most commonly used operators. For details, see Array Objects Indexing.

Indexing: v[i] = a or v[[1, 3]] = c (index multiple values)

v[0] = 1

check_no_change(v)

Slicing: v[i:j] = b

v[1: 2] = 1

check_no_change(v)

Slicing all values: v[:] = d, v[...] = e

v[:] = 0

check_no_change(v)
v[...] = bm.arange(10)

check_no_change(v)

2. Augmented assignment

All augmented assignments are in-place operations. They include:

  • add: +=

  • subtract: -=

  • divide: /=

  • multiply: *=

  • floor divide: //=

  • modulo: %=

  • power: **=

  • and: &=

  • or: |=

  • xor: ^=

  • left shift: <<=

  • right shift: >>=
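These operators keep the Variable's identity because of how Python evaluates augmented assignment: x += y calls x.__iadd__(y) and rebinds x to whatever that method returns; if the class defines no __iadd__, Python falls back to x = x + y, which creates a new object. The two classes below are hypothetical toys, not BrainPy's implementation, showing the difference.

```python
class InPlaceBox:
    """Hypothetical Variable-like wrapper that implements in-place operators."""

    def __init__(self, value):
        self.value = value

    def __iadd__(self, other):
        self.value = self.value + other  # replace the stored data...
        return self                      # ...but keep the wrapper itself

    def __imul__(self, other):
        self.value = self.value * other
        return self


class PlainBox:
    """Wrapper without __iadd__: Python falls back to p = p + 1."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return PlainBox(self.value + other)


v = InPlaceBox(1)
old_v = v
v += 1
v *= 2
print(v.value, v is old_v)  # 4 True

p = PlainBox(1)
old_p = p
p += 1
print(p.value, p is old_p)  # 2 False
```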

v += 1

check_no_change(v)
v *= 2

check_no_change(v)
v |= bm.random.randint(0, 2, 10)

check_no_change(v)
v **= 2

check_no_change(v)
v >>= 2

check_no_change(v)

3. .value assignment

Another way to update a variable in place is to assign new data to .value. This operation is safe because it checks whether the shape and dtype of the new data match those of the current data.

v.value = bm.arange(10)

check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
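A validating setter of this kind can be sketched with a Python property. The class below is a hypothetical toy, not BrainPy's code: "shape" is simplified to the list length and "dtype" to the set of Python element types, and it raises ValueError/TypeError rather than brainpy's MathError. Routing the .value setter through update() is one way to make both spellings perform the same checks.

```python
class CheckedVariable:
    """Hypothetical Variable-like wrapper around a flat list. As a toy
    simplification, 'shape' is the list length and 'dtype' is the set of
    Python element types."""

    def __init__(self, data):
        self._value = list(data)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        self.update(new)  # assignment and update() share one code path

    def update(self, new):
        new = list(new)
        if len(new) != len(self._value):
            raise ValueError(f"shape mismatch: {len(self._value)} vs {len(new)}")
        old_types = {type(x).__name__ for x in self._value}
        new_types = {type(x).__name__ for x in new}
        if old_types != new_types:
            raise TypeError(f"dtype mismatch: {old_types} vs {new_types}")
        self._value = new  # only replaced after both checks pass


v = CheckedVariable(range(10))  # ten ints
v.value = list(range(10, 20))   # accepted: same length and element type
for bad in ([1], [0.5] * 10):
    try:
        v.value = bad
    except (ValueError, TypeError) as e:
        print(type(e).__name__, e)
```

Because rejected data never reaches self._value, a failed assignment leaves the variable exactly as it was.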

4. .update() method

Assigning to .value performs the same operation as calling the .update() method, so users who want a safe assignment can use this method too.

v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.