# brainpy.math.Variable

In BrainPy, JIT compilation of class objects relies on brainpy.math.Variable. In this section, we will cover:

• What is a Variable?

• What are the subtypes of Variable?

• How to update a Variable?

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__

'2.3.1'


## brainpy.math.Variable

brainpy.math.Variable is a pointer to a JAX array: it stores an array as its value. The data in a Variable can be changed during BrainPy's object-oriented JIT compilation. Labeling an array as a Variable marks it as a dynamical variable, one whose value changes over time.

Arrays that are not marked as Variables are compiled as static data by JIT. Modifying such arrays inside compiled code therefore either has no effect or raises an error.
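The static-data behavior is easiest to see in plain JAX, which BrainPy builds on. The sketch below is an illustration under that assumption, not BrainPy code. A jitted function that reads an array from an outer dict captures the array's value at trace time, so the cached compiled function ignores later changes to the dict. Passing the array as an argument keeps it dynamic. The names store, read_captured, and read_argument are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical names throughout. The dict stands in for an object
# attribute that is NOT marked as a Variable.
store = {"x": jnp.zeros(3)}

@jax.jit
def read_captured():
    # store["x"] is read while tracing, so its value is baked into the
    # compiled function as a constant.
    return store["x"] + 1.0

@jax.jit
def read_argument(x):
    # x is a traced input, so every call sees the data it is given.
    return x + 1.0

first = read_captured()              # compiles; uses the zeros
store["x"] = jnp.ones(3)             # change the data after compilation
second = read_captured()             # cache hit: still uses the old zeros
dynamic = read_argument(store["x"])  # sees the new ones
```

BrainPy's object-oriented transformations treat Variables roughly like the argument version: their current values are passed in and the updated values are written back. This is a simplified picture of the mechanism.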

• Creating a Variable

Passing an array to brainpy.math.Variable creates a Variable. For example:

b1 = bm.random.random(5)
b1

Array([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)

b2 = bm.Variable(b1)
b2

Variable([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)

• Accessing the value in a Variable

The data in a Variable can be obtained through .value.

b2.value

Array([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)

(b2.value == b1).all()

Array(True, dtype=bool)

• Supported operations on Variables

Variables support almost all the operations of a JAX array. As the outputs below show, the result of such an operation is a plain Array rather than a Variable.

b2 + 1.

Array([1.7049059, 1.9825947, 1.79977  , 1.2186428, 1.7095991], dtype=float32)

b2 ** 2

Array([0.49689227, 0.9654924 , 0.63963205, 0.04780469, 0.5035309 ],      dtype=float32)

bm.floor(b2)

Array([0., 0., 0., 0., 0.], dtype=float32)


## Subtypes of Variable

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.random.RandomState. Users can also define and extend their own subtypes.

### 1. TrainVar

brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Trainable variables are usually the values whose gradients are computed and which are then updated with the resulting values. However, TrainVar can also be used for other purposes.

b = bm.random.rand(4)

b

Array([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)

bm.TrainVar(b)

TrainVar([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)
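The following plain-JAX sketch makes "computing gradients and the corresponding update values" concrete by taking one gradient-descent step on a trainable array. In BrainPy the array would typically be wrapped in a TrainVar. Here a bare jnp array stands in for it, and the toy loss and the learning rate lr are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical toy quadratic loss; its gradient is 2 * w.
    return jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 0.5])   # the "trainable" values
grad = jax.grad(loss)(w)          # gradient of the loss with respect to w
lr = 0.1                          # hypothetical learning rate
w_new = w - lr * grad             # the corresponding update
```

With a Variable subclass such as TrainVar, the new values would be written back in place so that the Variable object itself is preserved (see the section on in-place updating).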


### 2. Parameter

brainpy.math.Parameter labels a parameter that changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can be easily retrieved with the Collector.subset method.

b = bm.random.rand(1)

b

Array([0.47972953], dtype=float32)

bm.Parameter(b)

Parameter([0.47972953], dtype=float32)
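A dedicated subclass is useful because a collection of variables can be filtered by type. The toy sketch below mimics the idea behind Collector.subset with a plain dict. The classes Var and Param, the helper subset, and the entry names are hypothetical stand-ins, not BrainPy's implementation.

```python
class Var:
    """Hypothetical stand-in for a Variable: just holds a value."""
    def __init__(self, value):
        self.value = value


class Param(Var):
    """Hypothetical stand-in for a Parameter subclass."""


def subset(collection, var_type):
    # Hypothetical helper: keep only entries that are instances of
    # var_type (subclasses included), like filtering a collector by type.
    return {name: v for name, v in collection.items() if isinstance(v, var_type)}


collected = {
    "neuron.V": Var([0.0, 0.0]),
    "neuron.tau": Param(10.0),
    "syn.g_max": Param(0.5),
}
params = subset(collected, Param)   # only the Param entries
all_vars = subset(collected, Var)   # Param is a Var, so everything matches
```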


### 3. RandomState

brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. A RandomState must store a key that changes dynamically (see the design of JAX's random number generation). Every time a RandomState draws random samples, its key changes. That is why it makes sense for RandomState to be a Variable.

state = bm.random.RandomState(1234)

state

RandomState(key=([   0, 1234], dtype=uint32))

# perform a "random" sampling
state.random(1)

state  # the value changed

RandomState(key=([2113592192, 1902136347], dtype=uint32))

# perform a "sample" sampling
state.sample(1)

state  # the value changed too

RandomState(key=([1076515368, 3893328283], dtype=uint32))


Every RandomState instance can derive a new key from its current key with .split_key().

state.split_key()

Array([3028232624,  826525938], dtype=uint32)
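JAX's functional PRNG API can model the key behavior shown above. The class below is a hypothetical toy, not brainpy.math.random.RandomState. It keeps a key, splits it before every draw, and stores one half as the new key, which is why the printed key differs after each sample. jax.random.PRNGKey, jax.random.split, and jax.random.uniform are real JAX functions. The class name ToyRandomState and its methods are assumptions chosen to mirror the text.

```python
import jax


class ToyRandomState:
    """Hypothetical stateful wrapper around a JAX PRNG key."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def split_key(self):
        # Split the current key: keep one half as the new state and
        # return the other half for a single use.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def random(self, shape=()):
        # Every draw consumes a fresh subkey, so the stored key changes.
        return jax.random.uniform(self.split_key(), shape)


state = ToyRandomState(1234)
initial_key = state.key          # the raw key for seed 1234
x1 = state.random((1,))
x2 = state.random((1,))          # a different key, so a different sample
```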


It can also derive multiple keys from the current key with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers differ across parallel computations, such as different batch elements or devices.

state.split_keys(2)

Array([[4198471980, 1111166693],
       [1457783592, 2493283834]], dtype=uint32)

state.split_keys(5)

Array([[3244149147, 2659778815],
       [2548793527, 3057026599],
       [ 874320145, 4142002431],
       [3368470122, 3462971882],
       [1756854521, 1662729797]], dtype=uint32)
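Plain JAX shows why each parallel lane needs its own key under vmap. This sketch illustrates the idea and is not BrainPy's internal code. jax.random.split(key, 4) plays the role of split_keys(4). Mapping over the split keys gives four different samples, while reusing one key in every lane gives four identical ones.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One independent key per batch element, analogous to split_keys(4).
keys = jax.random.split(key, 4)          # shape (4, 2)
different = jax.vmap(lambda k: jax.random.uniform(k))(keys)

# Reusing the same key in every lane: each lane draws the same number.
same = jax.vmap(lambda _: jax.random.uniform(key))(jnp.arange(4))
```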


The brainpy.math.random module has a default RandomState named DEFAULT.

bm.random.DEFAULT

RandomState(key=([1682297581, 3751629511], dtype=uint32))


The module-level random functions, such as randint(), rand(), and shuffle(), use this DEFAULT state. To reseed the default RandomState, use the seed() function.

bm.random.seed(654321)

bm.random.DEFAULT

RandomState(key=([     0, 654321], dtype=uint32))
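A few lines of JAX can sketch the pattern of a module-level default state plus a seed() function. Everything here is a hypothetical toy: the names DEFAULT, seed, and rand mirror the text but are not BrainPy's code. Note that jax.random.PRNGKey(654321) yields the key [0, 654321], the same value printed above after bm.random.seed(654321).

```python
import jax


class _State:
    """Hypothetical holder for the module-wide key."""
    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)


DEFAULT = _State(0)  # hypothetical shared state used by rand() below


def seed(s):
    # Reseed in place: DEFAULT stays the same object, only its key changes.
    DEFAULT.key = jax.random.PRNGKey(s)


def rand(*shape):
    # Hypothetical module-level function that draws from, and advances, DEFAULT.
    DEFAULT.key, subkey = jax.random.split(DEFAULT.key)
    return jax.random.uniform(subkey, shape)


seed(654321)
reseeded_key = DEFAULT.key
a = rand(3)
seed(654321)
b = rand(3)   # same seed, so the same first draw
```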


## In-place updating

In BrainPy, transformations (such as JIT) usually require variables and arrays to be updated in place. An in-place update changes the data stored in a variable without changing the object that the variable's name refers to.

For example, here we have a variable a.

a = bm.Variable(bm.zeros(5))


The ids of the variable and the data stored in the variable are:

id_of_a = id(a)
id_of_data = id(a.value)

assert id_of_a != id_of_data

print('id(a)       = ', id_of_a)
print('id(a.value) = ', id_of_data)

id(a)       =  2781947736704
id(a.value) =  2781965742144


An in-place update (here using [:]) changes the data without changing the object that a refers to. Note that a.value now holds a new array, so its id does change:

a[:] = 1.

print('id(a)       = ', id(a))
print('id(a.value) = ', id(a.value))

id(a)       =  2781947736704
id(a.value) =  2781965752128

print('(id(a) == id_of_a)          =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)

(id(a) == id_of_a)          = True
(id(a.value) == id_of_data) = False
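Why does id(a.value) change while id(a) does not? JAX arrays are immutable, so an "in-place" write has to build a new array and store it in the same container. The toy class below assumes that Variable works roughly this way and uses JAX's functional .at[...].set(...) update. The class name Box is hypothetical.

```python
import jax.numpy as jnp


class Box:
    """Hypothetical container that mimics a Variable's pointer semantics."""

    def __init__(self, value):
        self.value = value

    def __setitem__(self, index, new_data):
        # x.at[index].set(...) returns a NEW array; the old one is untouched.
        # The box keeps its identity and simply points at the new array.
        self.value = self.value.at[index].set(new_data)


box = Box(jnp.zeros(5))
box_id = id(box)
old_array = box.value

box[:] = 1.0   # same Box object, new array inside it
```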


However, if you assign data without an in-place operator, the name a is rebound to a different object and its id changes. This causes serious errors with BrainPy's transformations, because they keep referring to the original Variable.

a = 10.

print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)

id(a) =  2781946941520
(id(a) == id_of_a) = False
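Plain Python objects show why rebinding is harmful. In the sketch below, registry stands in for the references a transformation collected when it was set up. That is an assumption about how such tools track state, and all names are hypothetical. An in-place change is visible through the registry, but after rebinding, the name a and the registry point to different objects.

```python
class Holder:
    """Hypothetical stand-in for a Variable."""
    def __init__(self, value):
        self.value = value


a = Holder(0.0)
registry = [a]   # hypothetical: references collected earlier, e.g. by a transformation

a.value = 1.0                          # in-place: the registry sees the change
seen_after_inplace = registry[0].value

a = Holder(10.0)                       # rebinding: a now names a new object
seen_after_rebind = registry[0].value  # still the old object's data
```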

Here, we list several commonly used in-place operators. They are not limited to brainpy.math.Variable and its subclasses; they also apply to brainpy.math.Array.

v = bm.Variable(bm.arange(10))

old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'


### 1. Indexing and slicing

Indexing and slicing are the two most commonly used in-place operators. For details on indexing and slicing, see Array Objects Indexing.

Indexing: v[i] = a, or v[[1, 3]] = c to set multiple elements at once

v[0] = 1

check_no_change(v)
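Indexing in BrainPy arrays largely follows NumPy conventions, so a tuple index and a list index mean different things. This sketch uses NumPy as a stand-in. A tuple such as (1, 3) is the same as writing v[1, 3], a two-dimensional index that fails on a 1-D array. The list [1, 3] selects elements 1 and 3 and sets them together.

```python
import numpy as np

v = np.arange(10)

# A tuple index means one index per axis: v[(1, 3)] is v[1, 3].
try:
    v[(1, 3)] = 100
    tuple_error = None
except IndexError as err:
    tuple_error = err   # "too many indices" for a 1-D array

# A list index selects several positions along the same axis.
v[[1, 3]] = 100
```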


Slicing: v[i:j] = b

v[1:2] = 1

check_no_change(v)


Slicing all values: v[:] = d, v[...] = e

v[:] = 0

check_no_change(v)

v[...] = bm.arange(10)

check_no_change(v)


### 2. Augmented assignment

All augmented assignments are in-place operations. They include:

• add: +=

• subtract: -=

• divide: /=

• multiply: *=

• floor divide: //=

• modulo: %=

• power: **=

• and: &=

• or: |=

• xor: ^=

• left shift: <<=

• right shift: >>=
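In Python, the object being updated decides what an augmented assignment does. If it defines the matching in-place method (for example __iadd__ for +=), x += y calls that method and binds x to its return value. Otherwise Python falls back to x = x + y, which rebinds the name. The two hypothetical classes below show both cases. That Variable implements the in-place methods and returns itself is an assumption, consistent with the check_no_change calls in this section.

```python
class InPlaceCell:
    """Hypothetical container that defines __iadd__ (in-place add)."""
    def __init__(self, value):
        self.value = value

    def __iadd__(self, other):
        self.value = self.value + other
        return self   # returning self keeps the name bound to this object


class PlainCell:
    """Hypothetical container with only __add__, no __iadd__."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return PlainCell(self.value + other)


c = InPlaceCell(1)
c_before = c
c += 2            # calls __iadd__: same object, new data

p = PlainCell(1)
p_before = p
p += 2            # no __iadd__: Python falls back to p = p + 2
```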

v += 1

check_no_change(v)

v *= 2

check_no_change(v)

v |= bm.random.randint(0, 2, 10)

check_no_change(v)

v **= 2

check_no_change(v)

v >>= 2

check_no_change(v)


### 3. .value assignment

Another way to update a variable in place is to assign new data to .value. This operation is safe because it checks whether the dtype and shape of the new data match those of the current data.

v.value = bm.arange(10)

check_no_change(v)

try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)

<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.

try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)

<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.


### 4. .update() method

Assigning to .value is the same operation as calling the .update() method, so users who want a checked assignment can use either.

v.update(bm.random.randint(0, 20, size=10))

try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)

<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.

try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)

<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
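NumPy can sketch the kind of check that .value assignment and .update() perform. CheckedVar and its error messages are hypothetical and are not BrainPy's MathError. The setter compares shape and dtype before storing anything, and the .value setter and update() share one code path, matching what the text says about Variable. The last line shows one way to resolve a dtype mismatch: cast the new data to the existing dtype with astype before assigning.

```python
import numpy as np


class CheckedVar:
    """Hypothetical Variable-like container with a checked setter."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        # Assigning to .value goes through the same checks as update().
        self.update(new_value)

    def update(self, new_value):
        new_value = np.asarray(new_value)
        if new_value.shape != self._value.shape:
            raise ValueError(f"shape {self._value.shape} != {new_value.shape}")
        if new_value.dtype != self._value.dtype:
            raise ValueError(f"dtype {self._value.dtype} != {new_value.dtype}")
        self._value = new_value


v = CheckedVar(np.arange(10, dtype=np.int32))

try:
    v.update(np.int32(1))            # scalar into a (10,) container
    shape_error = None
except ValueError as err:
    shape_error = err

try:
    v.value = np.random.random(10)   # float64 data into an int32 container
    dtype_error = None
except ValueError as err:
    dtype_error = err

v.value = np.random.random(10).astype(v.value.dtype)  # cast first: accepted
```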