More Details for Variables#

@Chaoming Wang @Xiaoyu Chen

In BrainPy, the JIT compilation for class objects relies on Variables. In this section, we are going to understand:

  • What is a Variable?

  • What are the subtypes of Variable?

  • How to update a Variable?

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

brainpy.math.Variable#

brainpy.math.Variable is a pointer to a tensor: it stores the tensor as its value. The data in a Variable can still be changed when the code is JIT compiled. Labeling a tensor as a Variable marks it as a dynamical variable that changes over time.

Tensors that are not marked as Variables are JIT compiled as static data. Modifying such tensors either has no effect or raises an error.

  • Creating a Variable

Passing a tensor to brainpy.math.Variable creates a Variable. For example:

b1 = bm.random.random(5)
b1
JaxArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ], dtype=float32)
b2 = bm.Variable(b1)
b2
Variable([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ], dtype=float32)
  • Accessing the value in a Variable

The data in a Variable can be obtained through .value.

b2.value
DeviceArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ], dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
  • Supported operations on Variables

Variables support almost all tensor operations. In fact, brainpy.math.Variable is a subclass of brainpy.math.ndarray.

isinstance(b2, bm.ndarray)
True
isinstance(b2, bm.JaxArray)
True
# `bp.math.ndarray` is an alias for `bp.math.JaxArray` in 'jax' backend

bm.ndarray is bm.JaxArray
True

Note

The result of any operation on a Variable is a JaxArray (brainpy.math.ndarray is an alias for brainpy.math.JaxArray), not a Variable. In other words, a Variable refers to a single value; the result of an operation is a new array that the Variable does not track.

b2 + 1.
JaxArray([1.9116168, 1.6901083, 1.4392058, 1.1322064, 1.771458 ], dtype=float32)
b2 ** 2
JaxArray([0.8310452 , 0.47624946, 0.1929017 , 0.01747854, 0.5951475 ], dtype=float32)
bm.floor(b2)
JaxArray([0., 0., 0., 0., 0.], dtype=float32)
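
The behavior described in the note can be mimicked in plain Python. The sketch below is only an analogy: PlainArray and Var are hypothetical stand-ins written for this illustration, not BrainPy classes, and they use Python lists instead of tensors. It shows that an operation on the labeled holder returns the plain base type, and that the holder only carries the new data once the result is written back into it.

```python
class PlainArray:
    """Hypothetical stand-in for a plain array type (not a BrainPy class)."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # Arithmetic always builds a new PlainArray, even when called on a subclass.
        return PlainArray(x + other for x in self.data)


class Var(PlainArray):
    """Hypothetical stand-in for a Variable: a labeled holder of one value."""


v = Var([0.5, 1.5])
result = v + 1.0

print(type(result).__name__)  # PlainArray
print(v.data)                 # [0.5, 1.5]

# To make the holder carry the new data, write the result back into it.
v.data[:] = result.data
print(v.data)                 # [1.5, 2.5]
```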

Subtypes of Variable#

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.RandomState. Subtypes can also be customized and extended by users.

1. TrainVar#

brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Trainable variables are usually those whose gradients are computed and used to derive update values. However, users can also use TrainVar for other purposes.

b = bm.random.rand(4)

b
JaxArray([0.59062696, 0.618052  , 0.84173155, 0.34012556], dtype=float32)
bm.TrainVar(b)
TrainVar([0.59062696, 0.618052  , 0.84173155, 0.34012556], dtype=float32)

2. Parameter#

brainpy.math.Parameter labels a parameter that changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the Collector.subsets method (please see Base class).

b = bm.random.rand(1)

b
JaxArray([0.14782536], dtype=float32)
bm.Parameter(b)
Parameter([0.14782536], dtype=float32)

3. RandomState#

brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. A RandomState stores key information that changes dynamically (see JAX random number designs): each time a RandomState performs random sampling, its “key” changes. It is therefore worthwhile to label a RandomState as a Variable.

state = bm.random.RandomState(seed=1234)

state
RandomState([   0, 1234], dtype=uint32)
# perform a "random" sampling 
state.random(1)

state  # the value changed
RandomState([2113592192, 1902136347], dtype=uint32)
# perform a "sample" sampling 
state.sample(1)

state  # the value changed too
RandomState([1076515368, 3893328283], dtype=uint32)

Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()
DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)

There is a default RandomState in the brainpy.math.random module: DEFAULT.

bm.random.DEFAULT
RandomState([601887926, 339370966], dtype=uint32)

The built-in random methods such as randint(), rand(), and shuffle() use this DEFAULT state. To change the default RandomState, use the seed() method.

bm.random.seed(654321)

bm.random.DEFAULT
RandomState([     0, 654321], dtype=uint32)

In-place updating#

In BrainPy, transformations (like JIT) usually need to update variables or tensors in-place. An in-place update changes the data stored in a variable without changing the reference that points to the variable.

For example, here we have a variable a.

a = bm.Variable(bm.zeros(5))

The ids of the variable and the data stored in the variable are:

id_of_a = id(a)
id_of_data = id(a.value)

assert id_of_a != id_of_data

print('id(a)       = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a)       =  2101001001088
id(a.value) =  2101018127136

An in-place update (here we use [:]) does not change the reference to the variable, but it does change the data stored in it:

a[:] = 1.

print('id(a)       = ', id(a))
print('id(a.value) = ', id(a.value))
id(a)       =  2101001001088
id(a.value) =  2101019514880
print('(id(a) == id_of_a)          =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a)          = True
(id(a.value) == id_of_data) = False

However, if you assign data without an in-place operator, the name a is rebound to a new object and its id changes. This will cause serious errors when using transformations in BrainPy.

a = 10.

print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) =  2101001187280
(id(a) == id_of_a) = False
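
Why rebinding is harmful can be sketched in plain Python. The example below is an analogy under stated assumptions: Holder and make_reader are hypothetical names invented for this illustration, and make_reader merely stands in for a transformation that records a variable when it is built. It is not how BrainPy implements its transformations.

```python
class Holder:
    """Hypothetical stand-in for a variable: a stable object wrapping mutable data."""

    def __init__(self, data):
        self.data = list(data)

    def __setitem__(self, index, value):
        # Replace the selected contents; the Holder object itself is unchanged.
        if isinstance(index, slice):
            n = len(self.data[index])
            self.data[index] = [value] * n
        else:
            self.data[index] = value


def make_reader(holder):
    # Stands in for a transformation that records the object when it is built.
    return lambda: list(holder.data)


a = Holder([0.0, 0.0, 0.0])
read = make_reader(a)

a[:] = 1.0          # in-place: the recorded object sees the new data
print(read())       # [1.0, 1.0, 1.0]

a = Holder([10.0])  # rebinding: `a` now names a different object
print(read())       # still [1.0, 1.0, 1.0]; the reader never sees the new object
```

The reader keeps working with the object it recorded, so data assigned by rebinding the name is invisible to it.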
The following in-place operators are not limited to brainpy.math.Variable and its subclasses. They can also apply to brainpy.math.JaxArray.

Here, we list several commonly used in-place operators.

v = bm.Variable(bm.arange(10))
old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'

1. Indexing and slicing#

Indexing and slicing are the two most commonly used operators. The details of indexing and slicing are in Array Objects Indexing.

Indexing: v[i] = a or v[(1, 3)] = c (a tuple index, with one entry per axis of a multi-dimensional array)

v[0] = 1

check_no_change(v)

Slicing: v[i:j] = b

v[1: 2] = 1

check_no_change(v)

Slicing all values: v[:] = d, v[...] = e

v[:] = 0

check_no_change(v)
v[...] = bm.arange(10)

check_no_change(v)

2. Augmented assignment#

All augmented assignments are in-place operations, including:

  • add: +=

  • subtract: -=

  • divide: /=

  • multiply: *=

  • floor divide: //=

  • modulo: %=

  • power: **=

  • and: &=

  • or: |=

  • xor: ^=

  • left shift: <<=

  • right shift: >>=

v += 1

check_no_change(v)
v *= 2

check_no_change(v)
v |= bm.random.randint(0, 2, 10)

check_no_change(v)
v **= 2

check_no_change(v)
v >>= 2

check_no_change(v)
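
The reason augmented assignment can preserve identity is a general Python rule: `x += y` calls `x.__iadd__(y)` when that method is defined and binds the name to whatever it returns. The sketch below uses a hypothetical Box class (not part of BrainPy) to contrast this with `x = x + y`, which always binds the name to a new object.

```python
class Box:
    """Hypothetical holder used only to show how Python dispatches `+=`."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # `box + 1` builds and returns a brand-new Box.
        return Box(x + other for x in self.data)

    def __iadd__(self, other):
        # `box += 1` mutates the stored data and returns the same object.
        self.data = [x + other for x in self.data]
        return self


box = Box([1, 2, 3])
original_id = id(box)

box += 1
print(id(box) == original_id, box.data)  # True [2, 3, 4]

box = box + 1
print(id(box) == original_id, box.data)  # False [3, 4, 5]
```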

3. .value assignment#

Another way to update a variable in place is to assign new data to .value. This operation is safe because it checks whether the dtype and shape of the new data match the current ones.

v.value = bm.arange(10)

check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.

4. .update() method#

The .value assignment is in fact the same operation as the .update() method. Users who want a safe assignment can use this method as well.

v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
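
The relationship between a checked .value assignment and an .update() method can be sketched in plain Python. This is a minimal illustration under stated assumptions, not BrainPy's implementation: CheckedVar is a hypothetical class, list length stands in for shape, element type stands in for dtype, and built-in exceptions stand in for MathError.

```python
class CheckedVar:
    """Hypothetical holder whose `value` setter validates new data."""

    def __init__(self, data):
        self._data = list(data)

    @property
    def value(self):
        return self._data

    @value.setter
    def value(self, new_data):
        new_data = list(new_data)
        # Reject data whose length (standing in for shape) differs.
        if len(new_data) != len(self._data):
            raise ValueError(
                f"expected length {len(self._data)}, got {len(new_data)}")
        # Reject data whose element types (standing in for dtype) differ.
        old_types = {type(x) for x in self._data}
        new_types = {type(x) for x in new_data}
        if new_types != old_types:
            raise TypeError(
                f"expected element types {old_types}, got {new_types}")
        self._data = new_data

    def update(self, new_data):
        # Same checks as the setter: simply delegate to it.
        self.value = new_data


cv = CheckedVar([0, 1, 2])
cv.update([5, 6, 7])      # accepted: same length, same element type
print(cv.value)           # [5, 6, 7]

for bad in ([1], [0.1, 0.2, 0.3]):
    try:
        cv.value = bad
    except (ValueError, TypeError) as e:
        print(type(e).__name__)  # ValueError, then TypeError
```

Routing both entry points through one setter keeps the validation in a single place, which is one plausible way to make the two operations equivalent.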