Variables

In BrainPy, JIT compilation of class objects relies on Variable. In this section, we will cover:

  • what a Variable is

  • the subtypes of Variable

import brainpy as bp
import brainpy.math as bm

Variable

brainpy.math.Variable is a pointer that refers to a tensor. It stores the tensor's value, and that concrete value can be changed. Labeling a tensor as a Variable marks it as a dynamical variable whose data can change.

During JIT compilation, tensors that are not marked as Variable are compiled as static data. Changing such a tensor afterwards either has no effect or raises an error.
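To make the difference concrete, here is a toy model in plain Python. It is not BrainPy's JIT: the names ToyVariable and toy_jit are hypothetical, and the "compiler" simply deep-copies every non-Variable argument once, at compile time, while keeping Variable arguments by reference. It only models the "has no effect" case, not the error case.

```python
import copy


class ToyVariable:
    """Hypothetical stand-in for a Variable: a mutable box around a value."""

    def __init__(self, value):
        self.value = value


def toy_jit(fn, *args):
    """Hypothetical 'compiler' (not BrainPy's JIT).

    Plain arguments are deep-copied once, at compile time, so they behave
    like static data. ToyVariable arguments are kept by reference and read
    through `.value` on every call.
    """
    captured = [a if isinstance(a, ToyVariable) else copy.deepcopy(a) for a in args]

    def compiled():
        values = [a.value if isinstance(a, ToyVariable) else a for a in captured]
        return fn(*values)

    return compiled


def total(x, y):
    return sum(x) + sum(y)


static = [1.0, 2.0]
dynamic = ToyVariable([1.0, 2.0])
f = toy_jit(total, static, dynamic)

first = f()                        # 3.0 + 3.0
static.append(10.0)                # ignored: the copy was frozen at "compile" time
dynamic.value = [1.0, 2.0, 10.0]   # seen: read through the reference
second = f()                       # 3.0 + 13.0
print(first, second)  # 6.0 16.0
```

The change to the plain list never reaches the compiled function, while the change made through the Variable does.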

  • Create a Variable

Passing a tensor to brainpy.math.Variable creates a Variable, for example:

bm.use_backend('numpy')

a1 = bm.random.random(5)
a2 = bm.Variable(a1)

a1, a2 
(array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]),
 Variable([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]))
bm.use_backend('jax')

b1 = bm.random.random(5)
b2 = bm.Variable(b1)

b1, b2
(JaxArray(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)),
 Variable(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)))
  • Access the value in a Variable

The concrete value of a Variable can be obtained through .value.

a2.value
array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844])
(a2.value == a1).all()
True
b2.value
DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
  • Supported operations on a Variable

A Variable supports almost all tensor operations, because brainpy.math.Variable is a subclass of brainpy.math.ndarray.

isinstance(a2, bp.math.numpy.ndarray)
True
isinstance(b2, bp.math.jax.ndarray)
True
isinstance(b2, bp.math.jax.JaxArray)
True
# `bp.math.jax.ndarray` is an alias for `bp.math.jax.JaxArray` in 'jax' backend

bp.math.jax.ndarray is bp.math.jax.JaxArray
True

Note

In the ‘jax’ backend, any operation on a Variable returns a JaxArray (bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in the ‘jax’ backend), not a new Variable. In other words, a Variable only serves as a reference to its value; the results of computations on it are ordinary arrays.

b2 + 1.
JaxArray(DeviceArray([1.7053047, 1.9984136, 1.815271 , 1.926391 , 1.84018  ], dtype=float32))
b2 ** 2
JaxArray(DeviceArray([0.4974548 , 0.9968296 , 0.66466683, 0.8582003 , 0.7059025 ],            dtype=float32))
bp.math.jax.floor(b2)
JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
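The same behavior can be sketched with two hypothetical toy classes, ToyArray (standing in for JaxArray) and ToyVariable (standing in for Variable). Arithmetic on the variable builds a fresh plain array, so keeping a result "inside" the Variable means writing it back. Writing back by assigning to .value is an assumption of this sketch; check how your BrainPy version updates a Variable.

```python
class ToyArray:
    """Hypothetical stand-in for JaxArray: wraps a list of floats."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # Arithmetic always builds a brand-new ToyArray, even when
        # `self` is a ToyVariable, mirroring the note above.
        return ToyArray(x + other for x in self.data)

    def __repr__(self):
        return f"{type(self).__name__}({self.data})"


class ToyVariable(ToyArray):
    """Hypothetical stand-in for Variable: same data plus a `.value` handle."""

    @property
    def value(self):
        return self.data

    @value.setter
    def value(self, new):
        self.data = list(new)


v = ToyVariable([0.5, 1.5])
w = v + 1.0
print(repr(w))    # ToyArray([1.5, 2.5]): a plain array, not a Variable
v.value = w.data  # write the result back to keep it in the Variable
print(repr(v))    # ToyVariable([1.5, 2.5])
```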
  • Subtypes of Variable

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar and brainpy.math.Parameter. Users can also define their own subtypes. The built-in subtypes are described below.

TrainVar

brainpy.math.TrainVar is a trainable variable (a subclass of brainpy.math.Variable). Trainable variables are usually the ones whose gradients are computed and whose values are updated accordingly during training. However, users can also use TrainVar for other purposes.
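Why a marker subtype is useful can be sketched without BrainPy: an optimizer can decide what to update purely by type. In this minimal toy (ToyVariable, ToyTrainVar and sgd_step are hypothetical names, and the gradient is written by hand rather than computed automatically), only the trainable entry moves toward the minimum of the loss (w - 3)^2.

```python
class ToyVariable:
    """Hypothetical stand-in for a Variable."""

    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    """Hypothetical marker subclass meaning 'update me from my gradient'."""


def sgd_step(variables, grads, lr=0.1):
    """Gradient-descent update applied only to the trainable entries."""
    for name, var in variables.items():
        if isinstance(var, ToyTrainVar):
            var.value = var.value - lr * grads[name]


# Loss = (w - 3)^2, so dLoss/dw = 2 * (w - 3). The "counter" entry is a
# plain ToyVariable and is left alone even though a gradient is supplied.
variables = {"w": ToyTrainVar(0.0), "counter": ToyVariable(0.0)}
for _ in range(50):
    w = variables["w"].value
    grads = {"w": 2.0 * (w - 3.0), "counter": 1.0}
    sgd_step(variables, grads)

print(round(variables["w"].value, 3), variables["counter"].value)  # 3.0 0.0
```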

bm.use_backend('numpy')

a = bm.random.rand(4)

a, bm.TrainVar(a)
(array([0.81515042, 0.40363449, 0.89924935, 0.29827197]),
 TrainVar([0.81515042, 0.40363449, 0.89924935, 0.29827197]))
bm.use_backend('jax')

b = bm.random.rand(4)

b, bm.TrainVar(b)
(JaxArray(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)),
 TrainVar(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)))

Parameter

brainpy.math.Parameter labels a parameter whose value changes dynamically. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can be easily retrieved with the Collector.subsets method (please see Base class).
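Retrieval by type can be modeled with a plain dictionary and isinstance. This is a toy sketch, not BrainPy's Collector: the classes and the subset helper are hypothetical. It also shows that a user-defined subtype (ToyStateVar here) is picked up the same way, and that asking for the base type returns everything.

```python
class ToyVariable:
    """Hypothetical stand-in for a Variable."""

    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    pass


class ToyParameter(ToyVariable):
    pass


class ToyStateVar(ToyVariable):
    """Hypothetical user-defined subtype."""


def subset(collection, var_type):
    """Keep only the entries whose type is `var_type` or one of its subclasses."""
    return {name: v for name, v in collection.items() if isinstance(v, var_type)}


all_vars = {
    "w": ToyTrainVar([0.1, 0.2]),
    "tau": ToyParameter([10.0]),
    "V": ToyStateVar([-65.0]),
}
print(sorted(subset(all_vars, ToyParameter)))  # ['tau']
print(sorted(subset(all_vars, ToyVariable)))   # ['V', 'tau', 'w']
```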

bm.use_backend('numpy')

a = bm.random.rand(1)

a, bm.Parameter(a)
(array([0.5776296]), Parameter([0.5776296]))
bm.use_backend('jax')

b = bm.random.rand(1)

b, bm.Parameter(b)
(JaxArray(DeviceArray([0.61128676], dtype=float32)),
 Parameter(DeviceArray([0.61128676], dtype=float32)))

RandomState

In the ‘jax’ backend, brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable, because a RandomState must store its key, which changes dynamically: every time a RandomState draws a random sample, its “key” changes. For example,

bm.use_backend('jax')

state = bm.random.RandomState(seed=1234)

state
RandomState(DeviceArray([   0, 1234], dtype=uint32))
# draw a sample with .random()
state.random(1)

# the value changed
state
RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))
# draw a sample with .sample()
state.sample(1)

# the value changed too
state
RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))
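The "split, then consume the subkey" pattern behind this can be sketched with the standard library. This is a toy, not JAX's Threefry PRNG and not BrainPy's RandomState: the hash-based split function and the ToyRandomState class are hypothetical. The only state is the key held in .value, and each draw replaces it, so the same seed still reproduces the same sequence.

```python
import hashlib


def split(key, n=2):
    """Hypothetical splitter: derive n new 64-bit keys from one key by hashing.

    This is NOT JAX's Threefry-based split; it only mimics its contract
    (same key in gives same keys out, different index gives a different key).
    """
    keys = []
    for i in range(n):
        digest = hashlib.sha256(f"{key}:{i}".encode()).digest()
        keys.append(int.from_bytes(digest[:8], "little"))
    return keys


class ToyRandomState:
    """Toy RandomState: its only state is the current key, held in `.value`."""

    def __init__(self, seed):
        self.value = seed

    def split_key(self):
        # Advance the stored key and hand out a fresh subkey.
        self.value, subkey = split(self.value)
        return subkey

    def random(self):
        # Consume one subkey; its top 53 bits give a float in [0, 1).
        return (split(self.split_key(), 1)[0] >> 11) / 2**53


state = ToyRandomState(seed=1234)
key_before = state.value
x = state.random()
key_after = state.value
print(key_before != key_after)  # True: sampling changed the stored key
```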

Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()
DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). pmap and vmap use this internally so that the random numbers differ across parallel threads.

state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)
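Why parallel lanes need their own keys can be shown with another stdlib toy (the hash-based split, uniform and split_keys helpers are hypothetical and do not model how pmap or vmap work internally). If every lane reuses one key, all lanes draw the same number; giving each lane one key from split_keys(n) makes them differ.

```python
import hashlib


def split(key, n=2):
    """Hypothetical hash-based splitter (not JAX's PRNG): n new keys from one."""
    return [
        int.from_bytes(hashlib.sha256(f"{key}:{i}".encode()).digest()[:8], "little")
        for i in range(n)
    ]


def uniform(key):
    """Map a 64-bit key to a float in [0, 1) using its top 53 bits."""
    return (key >> 11) / 2**53


def split_keys(key, n):
    """Toy analogue of .split_keys(n): one independent key per parallel lane."""
    return split(key, n)


# Without splitting, every lane reuses the same key and draws the same number.
shared_key = 1234
same = [uniform(split(shared_key, 1)[0]) for _ in range(4)]

# With one key per lane, the lanes draw different numbers.
lane_keys = split_keys(shared_key, 4)
different = [uniform(split(k, 1)[0]) for k in lane_keys]

print(len(set(same)), len(set(different)))  # 1 4
```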

The brainpy.math.jax.random module provides a default RandomState named DEFAULT.

bm.random.DEFAULT
RandomState(DeviceArray([2580684476, 2503630841], dtype=uint32))

Module-level random functions such as randint(), rand(), and shuffle() use this DEFAULT state. To reset the default RandomState, use the seed() method:

bm.random.seed(654321)

bm.random.DEFAULT
RandomState(DeviceArray([     0, 654321], dtype=uint32))
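As an analogy only (this is Python's standard library, not BrainPy), the built-in random module follows the same design: its module-level functions are bound to one hidden random.Random instance, and random.seed() resets that instance. A separately created generator keeps its own, independent state.

```python
import random

# Module-level functions such as random.random() and random.shuffle() use a
# hidden, shared random.Random instance; random.seed() resets it, much like
# DEFAULT and seed() above.
random.seed(654321)
first_run = [random.random() for _ in range(3)]

random.seed(654321)  # reset the default state
second_run = [random.random() for _ in range(3)]

# An explicitly created generator has its own state, seeded the same way.
own = random.Random(654321)
own_run = [own.random() for _ in range(3)]

print(first_run == second_run, first_run == own_run)  # True True
```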