Variables

In BrainPy, the JIT compilation for class objects relies on Variable. In this section, we are going to understand:

• what a Variable is

• the subtypes of Variable

import brainpy as bp
import brainpy.math as bm


Variable

brainpy.math.Variable is a pointer that refers to a tensor and stores its value. The concrete value held by a Variable can be changed. Labeling a tensor as a Variable marks it as a dynamical variable whose data may change.

During JIT compilation, tensors that are not marked as Variable are compiled as static data. Changing such a tensor afterwards either has no effect or raises an error.
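The static-versus-dynamic distinction can be pictured with a small toy model in plain Python. This is only an analogy under stated assumptions, not BrainPy's compiler: `Box` and `toy_compile` are hypothetical names. The toy "compiler" copies plain values into the compiled function once, while values wrapped in a `Box` are read afresh on every call.

```python
class Box:
    """Hypothetical stand-in for a Variable: a mutable holder for a value."""

    def __init__(self, value):
        self.value = value


def toy_compile(fn, **inputs):
    """Toy 'compilation' step (an assumption, not BrainPy's real JIT).

    Plain inputs are snapshotted once, like static data baked into a
    compiled program. Box inputs are kept by reference and read per call.
    """
    static = {k: v for k, v in inputs.items() if not isinstance(v, Box)}
    dynamic = {k: v for k, v in inputs.items() if isinstance(v, Box)}

    def compiled():
        live = {k: box.value for k, box in dynamic.items()}
        return fn(**static, **live)

    return compiled


def weighted_state(gain, state):
    return gain * state


gain = 2.0            # plain value: treated as static
state = Box(10.0)     # boxed value: treated as dynamic

step = toy_compile(weighted_state, gain=gain, state=state)
first = step()        # uses gain=2.0 and state=10.0

gain = 5.0            # rebinding the plain value is not seen by `step`
state.value = 20.0    # updating the box is seen by `step`
second = step()       # still uses gain=2.0, but state=20.0

print(first, second)
```

In this toy, only the boxed value behaves like data that can evolve after compilation, which is the role the text assigns to Variable.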

• Create a Variable

Passing a tensor to brainpy.math.Variable creates a Variable, for example:

bm.use_backend('numpy')

a1 = bm.random.random(5)
a2 = bm.Variable(a1)

a1, a2

(array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]),
Variable([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]))

bm.use_backend('jax')

b1 = bm.random.random(5)
b2 = bm.Variable(b1)

b1, b2

(JaxArray(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004], dtype=float32)),
 Variable(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004], dtype=float32)))

• Access the value in a Variable

The concrete value of a Variable can be obtained through .value.

a2.value

array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844])

(a2.value == a1).all()

True

b2.value

DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004], dtype=float32)

(b2.value == b1).all()

DeviceArray(True, dtype=bool)

• Supported operations on a Variable

A Variable supports almost all tensor operations. In fact, brainpy.math.Variable is a subclass of brainpy.math.ndarray.

isinstance(a2, bp.math.numpy.ndarray)

True

isinstance(b2, bp.math.jax.ndarray)

True

isinstance(b2, bp.math.jax.JaxArray)

True

# bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in 'jax' backend

bp.math.jax.ndarray is bp.math.jax.JaxArray

True


Note

In the ‘jax’ backend, the result of any operation on a Variable is a JaxArray, not a Variable (bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in the ‘jax’ backend). A Variable therefore serves only as a reference to a value.

b2 + 1.

JaxArray(DeviceArray([1.7053047, 1.9984136, 1.815271 , 1.926391 , 1.84018  ], dtype=float32))

b2 ** 2

JaxArray(DeviceArray([0.4974548 , 0.9968296 , 0.66466683, 0.8582003 , 0.7059025 ], dtype=float32))

bp.math.jax.floor(b2)

JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
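The behaviour described in the note follows from ordinary subclassing. The sketch below is a plain-Python analogy with hypothetical classes `ToyArray` and `ToyVariable`; it does not use BrainPy. Arithmetic on the subclass returns the base type, so in this toy, keeping a result inside the variable means writing it back into the stored value.

```python
class ToyArray:
    """Hypothetical minimal array: wraps a list of floats."""

    def __init__(self, data):
        self.value = list(data)

    def __add__(self, other):
        # The result is always built as the base class, never the subclass.
        return ToyArray(x + other for x in self.value)


class ToyVariable(ToyArray):
    """Hypothetical stand-in for a Variable: same data, different label."""


v = ToyVariable([0.5, 1.5])
result = v + 1.0

print(type(result).__name__)   # ToyArray
print(result.value)            # [1.5, 2.5]

# Rebinding the name (v = v + 1.0) would replace the ToyVariable with a
# ToyArray. Writing into the stored value keeps the ToyVariable object.
v.value = result.value
print(type(v).__name__, v.value)
```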

• Subtypes of Variable

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar and brainpy.math.Parameter. Users can also define and extend their own subtypes. The built-in subtypes are described below.

TrainVar

brainpy.math.TrainVar is a trainable variable (a subclass of brainpy.math.Variable). Trainable variables are usually the ones for which gradients are computed and the corresponding updates applied. However, users can also use TrainVar for other purposes.

bm.use_backend('numpy')

a = bm.random.rand(4)

a, bm.TrainVar(a)

(array([0.81515042, 0.40363449, 0.89924935, 0.29827197]),
TrainVar([0.81515042, 0.40363449, 0.89924935, 0.29827197]))

bm.use_backend('jax')

b = bm.random.rand(4)

b, bm.TrainVar(b)

(JaxArray(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)),
TrainVar(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)))


Parameter

brainpy.math.Parameter labels a parameter that changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can be easily retrieved with the Collector.subsets method (please see Base class).

bm.use_backend('numpy')

a = bm.random.rand(1)

a, bm.Parameter(a)

(array([0.5776296]), Parameter([0.5776296]))

bm.use_backend('jax')

b = bm.random.rand(1)

b, bm.Parameter(b)

(JaxArray(DeviceArray([0.61128676], dtype=float32)),
Parameter(DeviceArray([0.61128676], dtype=float32)))
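Why a dedicated subtype makes retrieval easy can be shown with an isinstance filter. The sketch below is a plain-Python analogy: `ToyVariable`, `ToyTrainVar`, `ToyParameter`, and `subset_by_type` are hypothetical names and do not reproduce the Collector API.

```python
class ToyVariable:
    """Hypothetical base label for dynamically changing data."""

    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    """Hypothetical label for trainable data."""


class ToyParameter(ToyVariable):
    """Hypothetical label for dynamically changing parameters."""


def subset_by_type(variables, var_type):
    """Return the entries of `variables` whose value is a `var_type`."""
    return {name: v for name, v in variables.items()
            if isinstance(v, var_type)}


variables = {
    "V": ToyVariable(-65.0),
    "w": ToyTrainVar(0.1),
    "tau": ToyParameter(10.0),
}

params = subset_by_type(variables, ToyParameter)
everything = subset_by_type(variables, ToyVariable)

print(sorted(params))       # ['tau']
print(sorted(everything))   # ['V', 'tau', 'w']
```

Filtering on the base class returns every variable, while filtering on a subtype returns only the entries carrying that label.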


RandomState

In the ‘jax’ backend, brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable, because a RandomState in the ‘jax’ backend must store key information that changes dynamically. Each time a RandomState performs a random sampling, its “key” changes. For example:

bm.use_backend('jax')

state = bm.random.RandomState(seed=1234)

state

RandomState(DeviceArray([   0, 1234], dtype=uint32))

# perform a "random" sampling
state.random(1)

# the value changed
state

RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))

# perform a "sample" sampling
state.sample(1)

# the value changed too
state

RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))
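The idea of a generator whose only state is a key that is replaced on every draw can be sketched in plain Python. This is a toy under stated assumptions: `ToyRandomState` is hypothetical and derives keys with SHA-256, which is not the algorithm JAX or BrainPy uses.

```python
import hashlib


class ToyRandomState:
    """Hypothetical key-based generator (not JAX's algorithm)."""

    def __init__(self, seed):
        # Mirrors the (0, seed) layout of the initial key shown above.
        self.key = (0, seed)

    def _advance(self):
        digest = hashlib.sha256(repr(self.key).encode()).digest()
        # New key: two 32-bit words taken from the hash of the old key.
        self.key = (int.from_bytes(digest[:4], "big"),
                    int.from_bytes(digest[4:8], "big"))
        return digest

    def random(self):
        """Return a float in [0, 1) and replace the stored key."""
        digest = self._advance()
        # Keep 53 bits so the quotient is exactly representable and < 1.
        return (int.from_bytes(digest[8:16], "big") >> 11) / 2**53


state = ToyRandomState(seed=1234)
key_before = state.key
x = state.random()
key_after = state.key
print(key_before, key_after, x)

# The key is the only state, so the same seed replays the same draw.
replay = ToyRandomState(seed=1234)
y = replay.random()
print(x == y)
```

Because the key must be overwritten after every draw, it has to live in something mutable, which is the reason the text gives for making RandomState a Variable.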


Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()

DeviceArray([3028232624,  826525938], dtype=uint32)


It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)

DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)

state.split_keys(5)

DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)
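How splitting one key into several gives each parallel worker its own random stream can be sketched as follows. The functions `toy_split_keys` and `toy_sample` are hypothetical and use SHA-256 for illustration; they are not the key-splitting algorithm used by JAX or BrainPy.

```python
import hashlib


def toy_split_keys(key, n):
    """Derive n child keys from `key` (hypothetical, not JAX's algorithm)."""
    children = []
    for i in range(n):
        digest = hashlib.sha256(f"{key}:{i}".encode()).digest()
        children.append((int.from_bytes(digest[:4], "big"),
                         int.from_bytes(digest[4:8], "big")))
    return children


def toy_sample(key):
    """Map a key deterministically to a float in [0, 1)."""
    digest = hashlib.sha256(f"sample:{key}".encode()).digest()
    return (int.from_bytes(digest[:8], "big") >> 11) / 2**53


parent = (0, 1234)
keys = toy_split_keys(parent, 4)

# One draw per simulated worker: each worker uses its own child key.
samples = [toy_sample(k) for k in keys]

print(keys)
print(samples)
```

If every worker reused the parent key instead, all of them would produce the same number, which is the situation key splitting avoids.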


The brainpy.math.jax.random module provides a default RandomState: DEFAULT.

bm.random.DEFAULT

RandomState(DeviceArray([2580684476, 2503630841], dtype=uint32))


The built-in random functions such as randint(), rand(), and shuffle() use this DEFAULT state. To change the default RandomState, use the seed() method.

bm.random.seed(654321)

bm.random.DEFAULT

RandomState(DeviceArray([     0, 654321], dtype=uint32))
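Python's standard `random` module follows a similar pattern: its module-level functions share one hidden default generator, and `random.seed()` resets it. The sketch below uses only the standard library as an analogy and does not call BrainPy.

```python
import random

# Module-level functions draw from one shared default generator.
random.seed(654321)
first = [random.randint(0, 9) for _ in range(3)]

# Re-seeding resets that default generator, so the draws repeat.
random.seed(654321)
second = [random.randint(0, 9) for _ in range(3)]

# A separate generator with the same seed yields the same sequence
# without touching the shared default state.
private = random.Random(654321)
third = [private.randint(0, 9) for _ in range(3)]

print(first == second, first == third)
```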