Variable and BrainPyObject
In BrainPy, the JIT compilation of class objects relies on brainpy.math.Variable. Moreover, brainpy.math.BrainPyObject is the container that wraps Variables in a class object.
In this section, we are going to understand:
- What is a Variable?
- What is a BrainPyObject?
- How to update a Variable?
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.4.0'
brainpy.math.Variable
brainpy.math.Variable is a pointer referring to a JAX array; it stores the array as its value. The data in a Variable can be changed during BrainPy's object-oriented JIT compilation. Labeling an array as a Variable marks it as a dynamical variable whose value changes over time.
Arrays that are not marked as Variables are compiled as static data by JIT. Modifications to such arrays inside a compiled function either do not take effect or raise an error.
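BrainPy runs on JAX, and the static-data behavior described above comes from how jax.jit works: an array that a compiled function reads from the enclosing scope, rather than receiving as an argument, is baked into the compiled program as a constant. The plain-JAX sketch below uses no BrainPy API, and the names scale, scaled and scaled_dynamic are hypothetical. It shows that rebinding such an array is silently ignored by the cached compilation, whereas passing the array as an argument picks up the new value.

```python
import jax
import jax.numpy as jnp

scale = jnp.array(2.0)

@jax.jit
def scaled(x):
    # `scale` is read from the enclosing scope, so it is captured as a
    # compile-time constant the first time `scaled` is traced.
    return x * scale

first = scaled(1.0)       # traced and compiled with scale == 2.0

scale = jnp.array(3.0)    # rebinding the global does not invalidate the cache
second = scaled(1.0)      # reuses the cached program, which still uses 2.0

@jax.jit
def scaled_dynamic(x, s):
    # Passing the array as an argument makes it dynamic data instead.
    return x * s

third = scaled_dynamic(1.0, scale)  # uses the current value, 3.0

print(float(first), float(second), float(third))  # 2.0 2.0 3.0
```

As the paragraph above explains, marking such data as a brainpy.math.Variable is how BrainPy lets it change across calls of a JIT-compiled object instead of being frozen at trace time.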
Creating a Variable
Passing an array into brainpy.math.Variable creates a Variable. For example:
b1 = bm.random.random(5)
b1
Array(value=DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818 , 0.86386406], dtype=float32), dtype=float32)
b2 = bm.Variable(b1)
b2
Variable(value=DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818 , 0.86386406], dtype=float32), dtype=float32)
Accessing the value in a Variable
The data in a Variable can be obtained through .value.
b2.value
DeviceArray([0.5946547 , 0.5306349 , 0.16283369, 0.857818 , 0.86386406], dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
Supported operations on Variables
Variables support almost all operations of a JAX array. As the outputs below show, the results are returned as Array objects rather than Variables.
b2 + 1.
Array(value=DeviceArray([1.5946547, 1.5306349, 1.1628337, 1.857818 , 1.8638641]), dtype=float32)
b2 ** 2
Array(value=DeviceArray([0.35361418, 0.28157339, 0.02651481, 0.7358517 , 0.7462611 ], dtype=float32), dtype=float32)
bm.floor(b2)
Array(value=DeviceArray([0., 0., 0., 0., 0.]), dtype=float32)
brainpy.math.BrainPyObject
BrainPyObject can be viewed as the base class for wrapping Variables in a class. Using the .vars() method, any Variable defined in the class can easily be retrieved.
# for example
hh = bp.neurons.HH(1)
hh.vars().keys()  # an HH model has 6 variables
dict_keys(['HH0.V', 'HH0.h', 'HH0.input', 'HH0.m', 'HH0.n', 'HH0.spike'])
Subtypes of Variable
Customizing Variable types
Sometimes, different variables serve different roles in a computation. To distinguish them, you can subclass brainpy.math.Variable to define your own variable types.
class YourVar(bm.Variable):
    pass

class YourModel(bm.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.a = bm.Variable(bm.ones(1))
        self.b = YourVar(bm.zeros(1))

model = YourModel()
model.vars().subset(YourVar)
{'YourModel0.b': YourVar(value=DeviceArray([0.]), dtype=float32)}
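To make the collection mechanism concrete, here is a pure-Python toy model of what .vars() and .subset() do conceptually: walk an object's attributes, keep those that are variables, key them as "<instance name>.<attribute>", and filter them by type. All classes here (ToyVariable, ToyTrainVar, ToyObject, MyModel) and the subset helper are hypothetical; this is a sketch of the idea, not BrainPy's implementation.

```python
class ToyVariable:
    """Hypothetical stand-in for a Variable: it just holds a value."""
    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    """Hypothetical Variable subtype, used only to tag a role."""


class ToyObject:
    """Hypothetical stand-in for a BrainPyObject container."""
    _instance_counts = {}  # per-class counters used to build names like 'MyModel0'

    def __init__(self):
        cls_name = type(self).__name__
        index = ToyObject._instance_counts.get(cls_name, 0)
        ToyObject._instance_counts[cls_name] = index + 1
        self.name = f'{cls_name}{index}'

    def vars(self):
        # Keep only the attributes that are variables, keyed as '<name>.<attr>'.
        return {f'{self.name}.{attr}': value
                for attr, value in self.__dict__.items()
                if isinstance(value, ToyVariable)}


def subset(collection, var_type):
    # Filter a collection down to the variables of one (sub)type.
    return {key: var for key, var in collection.items() if isinstance(var, var_type)}


class MyModel(ToyObject):
    def __init__(self):
        super().__init__()
        self.a = ToyVariable([1.0])
        self.b = ToyTrainVar([0.0])
        self.not_a_variable = [2.0]  # plain data is not collected


model = MyModel()
print(sorted(model.vars()))                     # ['MyModel0.a', 'MyModel0.b']
print(list(subset(model.vars(), ToyTrainVar)))  # ['MyModel0.b']
```

Tagging a role through a subclass, as YourVar does above, turns filtering into a simple isinstance check.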
BrainPy provides several built-in Variable subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.random.RandomState.
1. TrainVar
brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Usually, trainable variables are those whose gradients are computed and then used to update their values. However, users can also use TrainVar for other purposes.
b = bm.random.rand(4)
b
Array(value=DeviceArray([0.27101624, 0.7630037 , 0.71727633, 0.73568165]), dtype=float32)
bm.TrainVar(b)
TrainVar(value=DeviceArray([0.27101624, 0.7630037 , 0.71727633, 0.73568165]), dtype=float32)
2. Parameter
brainpy.math.Parameter labels a parameter that changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can easily be retrieved with the Collector.subset method, for example model.vars().subset(bm.Parameter).
b = bm.random.rand(1)
b
Array(value=DeviceArray([0.00502753]), dtype=float32)
bm.Parameter(b)
Parameter(value=DeviceArray([0.00502753]), dtype=float32)
3. RandomState
brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. A RandomState must store its key, which changes dynamically (see the design of JAX's random number generation). Every time a RandomState performs random sampling, its key changes. Therefore, it is natural to implement RandomState as a Variable.
state = bm.random.RandomState(1234)
state
RandomState(key=([ 0, 1234], dtype=uint32))
# perform a "random" sampling
state.random(1)
state  # the key has changed
RandomState(key=([2113592192, 1902136347], dtype=uint32))
# perform another sampling with .sample()
state.sample(1)
state  # the key has changed again
RandomState(key=([1076515368, 3893328283], dtype=uint32))
Every RandomState instance can create a new key from its current key with .split_key().
state.split_key()
DeviceArray([3028232624, 826525938], dtype=uint32)
It can also create multiple keys from its current key with .split_keys(n). This is used internally by pmap and vmap to ensure that the random numbers differ across parallel computations.
state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
[1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
[2548793527, 3057026599],
[ 874320145, 4142002431],
[3368470122, 3462971882],
[1756854521, 1662729797]], dtype=uint32)
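The stateful behavior shown above can be sketched on top of JAX's functional PRNG, where a key is an explicit value that must be split before it is reused. The class ToyRandomState below is hypothetical and is not BrainPy's RandomState; it only mirrors the idea. Every sampling consumes the current key and stores a fresh one, and split_keys(n) hands out n independent keys, which jax.vmap can map over to draw a different number in each parallel lane.

```python
import jax
import jax.numpy as jnp


class ToyRandomState:
    """Hypothetical stateful wrapper around a JAX PRNG key (not BrainPy's RandomState)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def split_key(self):
        # Split the current key: keep one half as the new state, return the other.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def split_keys(self, n):
        # Split into n + 1 keys: one becomes the new state, n are returned.
        keys = jax.random.split(self.key, n + 1)
        self.key = keys[0]
        return keys[1:]

    def random(self, shape):
        # Every sampling consumes a fresh subkey, so the stored key changes.
        return jax.random.uniform(self.split_key(), shape)


state = ToyRandomState(1234)
old_key = state.key
state.random((1,))
print(jnp.array_equal(old_key, state.key))  # False: the key changed

keys = state.split_keys(5)                  # shape (5, 2)
samples = jax.vmap(lambda k: jax.random.uniform(k))(keys)
print(samples)                              # five different numbers, one per key
```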
The brainpy.math.random module provides a default RandomState: DEFAULT.
bm.random.DEFAULT
RandomState(key=([1573595401, 117587871], dtype=uint32))
The module-level random functions, such as randint(), rand(), and shuffle(), use this DEFAULT state. To reset the default RandomState, use the seed() function.
bm.random.seed(654321)
bm.random.DEFAULT
RandomState(key=([ 0, 654321], dtype=uint32))
How to update a Variable?
In BrainPy, transformations (such as JIT) usually need to update variables or arrays in place. An in-place update changes the data stored in a variable without changing the reference that points to the variable.
For example, here we have a variable a.
a = bm.Variable(bm.zeros(5))
The ids of the variable and the data stored in the variable are:
id_of_a = id(a)
id_of_data = id(a.value)
assert id_of_a != id_of_data
print('id(a) = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a) = 2151307205504
id(a.value) = 2151397524800
An in-place update (here we use [:]) does not change the object that a refers to, but it does change the data stored in it:
a[:] = 1.
print('id(a) = ', id(a))
print('id(a.value) = ', id(a.value))
id(a) = 2151307205504
id(a.value) = 2151399078480
print('(id(a) == id_of_a) =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a) = True
(id(a.value) == id_of_data) = False
However, if you assign data without an in-place operator, the name a is rebound to a new object, so its id changes. This causes serious errors when using transformations in BrainPy.
a = 10.
print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) = 2151305428432
(id(a) == id_of_a) = False
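Why does rebinding break things? Other objects, such as a model's attribute or the state captured by a compiled function, keep a reference to the original Variable object, not to the name a. The pure-Python sketch below uses a hypothetical Box class, plus a dictionary standing in for such a holder: an in-place change is visible through every reference, whereas rebinding only moves the local name.

```python
class Box:
    """Hypothetical mutable holder standing in for a Variable."""
    def __init__(self, data):
        self.data = data


a = Box([0.0, 0.0, 0.0])
holder = {'a': a}          # stands in for anything else that kept a reference to `a`

a.data = [1.0, 1.0, 1.0]   # in-place: swap the data inside the same object
print(holder['a'].data)    # [1.0, 1.0, 1.0]: the holder sees the update

a = 10.0                   # rebinding: the name now points to a new object
print(holder['a'] is a)    # False
print(holder['a'].data)    # [1.0, 1.0, 1.0]: the holder never sees 10.0
```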
The following in-place operators are not limited to brainpy.math.Variable and its subclasses; they also apply to brainpy.math.Array.
Here, we list several commonly used in-place operators.
v = bm.Variable(bm.arange(10))
old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'
1. Indexing and slicing
Indexing and slicing are the two most commonly used operators. For details, see Array Objects Indexing.
Indexing: v[i] = a or v[[1, 3]] = c (index multiple values)
v[0] = 1
check_no_change(v)
Slicing: v[i:j] = b
v[1: 2] = 1
check_no_change(v)
Slicing all values: v[:] = d, v[...] = e
v[:] = 0
check_no_change(v)
v[...] = bm.arange(10)
check_no_change(v)
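JAX arrays themselves are immutable, so v[0] = 1 cannot literally overwrite memory. A plausible way for a Variable to offer this syntax, sketched below, is to compute a new array with JAX's functional .at[...].set(...) update and store it back into the same holder object. ToyVar is hypothetical and is not BrainPy's source code; the .at API is standard JAX.

```python
import jax.numpy as jnp


class ToyVar:
    """Hypothetical holder that mimics in-place indexing on an immutable JAX array."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def __setitem__(self, index, new_value):
        # Build an updated copy functionally, then store it in the same holder.
        self.value = self.value.at[index].set(new_value)


x = jnp.arange(10)
try:
    x[0] = 1                     # plain JAX arrays reject item assignment
except TypeError as e:
    print('TypeError:', e)

v = ToyVar(jnp.arange(10))
same_object = v
v[0] = 5                         # indexing
v[jnp.array([1, 3])] = 0         # several positions at once
v[4:6] = -1                      # slicing
print(v.value.tolist())          # [5, 0, 2, 0, -1, -1, 6, 7, 8, 9]
print(v is same_object)          # True: the holder's identity is unchanged
```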
2. Augmented assignment
All augmented assignments are in-place operations, including:
- add: +=
- subtract: -=
- divide: /=
- multiply: *=
- floor divide: //=
- modulo: %=
- power: **=
- bitwise and: &=
- bitwise or: |=
- bitwise xor: ^=
- left shift: <<=
- right shift: >>=
v += 1
check_no_change(v)
v *= 2
check_no_change(v)
v |= bm.random.randint(0, 2, 10)
check_no_change(v)
v **= 2
check_no_change(v)
v >>= 2
check_no_change(v)
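Whether v += 1 keeps the same object is decided by the class: Python calls v.__iadd__(1) and rebinds v to whatever it returns, falling back to v = v + 1 (a new object) when __iadd__ is not defined. The pure-Python sketch below, with hypothetical classes Plain and InPlace, shows both outcomes; the check_no_change calls above confirm that Variable behaves like the second one.

```python
class Plain:
    """Hypothetical holder with only __add__: `+=` falls back to a new object."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return Plain(self.value + other)


class InPlace(Plain):
    """Hypothetical holder with __iadd__: `+=` swaps the data and keeps the object."""
    def __iadd__(self, other):
        self.value = self.value + other
        return self  # returning self makes Python rebind the name to the same object


p = Plain(1)
p_before = p
p += 1
print(p is p_before, p.value)   # False 2

q = InPlace(1)
q_before = q
q += 1
print(q is q_before, q.value)   # True 2
```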
3. .value assignment
Another way to update a variable in place is to assign new data to .value. This operation is safe because it checks whether the shape and dtype of the new data are consistent with those of the current data.
v.value = bm.arange(10)
check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
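The safety check can be pictured as a property setter that compares shape and dtype before storing the new data. CheckedVar below is a hypothetical NumPy-based sketch of that idea, not BrainPy's code: it raises a plain ValueError, whereas BrainPy raises brainpy.errors.MathError as shown above, and BrainPy's real check also involves the batch_axis mentioned in its error message.

```python
import numpy as np


class CheckedVar:
    """Hypothetical holder whose .value setter validates shape and dtype."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        new_value = np.asarray(new_value)
        if new_value.shape != self._value.shape:
            raise ValueError(f'The shape of the original data is {self._value.shape}, '
                             f'while we got {new_value.shape}.')
        if new_value.dtype != self._value.dtype:
            raise ValueError(f'The dtype of the original data is {self._value.dtype}, '
                             f'while we got {new_value.dtype}.')
        self._value = new_value

    def update(self, new_value):
        # The same checked assignment, exposed as a method.
        self.value = new_value


v = CheckedVar(np.arange(10))
v.value = np.arange(10, 20)     # same shape and dtype: accepted

errors = []
for bad in (np.asarray(1), np.random.random(10)):
    try:
        v.update(bad)           # wrong shape, then wrong dtype
    except ValueError as e:
        errors.append(str(e))
        print(type(e).__name__, e)
```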
4. .update() method
The .value assignment performs the same operation as the .update() method, so users who want a safe assignment can use this method as well.
v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.