Variables
In BrainPy, JIT compilation of class objects relies on Variable. In this section, we will cover:
- what a Variable is
- the subtypes of Variable
import brainpy as bp
import brainpy.math as bm
Variable
brainpy.math.Variable is a pointer that refers to a tensor. It stores the value of that tensor, and this concrete value can be changed. Labeling a tensor as a Variable marks it as a dynamical variable whose data may change.
During JIT compilation, tensors that are not marked as Variable are compiled as static data. Changes to such tensors will either not take effect or raise an error.
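A pure-Python toy can make the static-versus-dynamic distinction concrete. It is not BrainPy's compiler. The hypothetical toy_jit below imitates a tracing compiler by taking a shallow copy of the object. Plain attributes are therefore frozen at their compile-time values, while attributes wrapped in the hypothetical ToyVariable class are shared and keep reflecting updates. The toy reproduces only the "change does not take effect" case, not the error case.

```python
import copy


class ToyVariable:
    """Hypothetical stand-in for a Variable: a mutable box around a value."""

    def __init__(self, value):
        self.value = value


class Counter:
    def __init__(self):
        self.step = 1                # plain data: frozen when "compiled"
        self.count = ToyVariable(0)  # marked as a variable: stays dynamic

    def update(self):
        self.count.value += self.step
        return self.count.value


def toy_jit(obj, method_name):
    """Hypothetical compiler: freeze plain attributes, share ToyVariables.

    copy.copy makes a shallow copy. Rebinding a plain attribute on `obj`
    later does not affect the copy, but the ToyVariable objects are the
    same objects in both, so updates to their .value remain visible.
    """
    frozen = copy.copy(obj)
    return getattr(frozen, method_name)


c = Counter()
f = toy_jit(c, "update")
print(f(), f())      # 1 2
c.step = 10          # change plain data after "compiling"
print(f())           # 3: the new step is ignored
c.count.value = 100  # change the variable's value
print(f())           # 101: the new value is used
```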
Create a Variable
Passing a tensor to brainpy.math.Variable creates a Variable. For example:
bm.use_backend('numpy')
a1 = bm.random.random(5)
a2 = bm.Variable(a1)
a1, a2
(array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]),
Variable([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]))
bm.use_backend('jax')
b1 = bm.random.random(5)
b2 = bm.Variable(b1)
b1, b2
(JaxArray(DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)),
Variable(DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)))
Access the value in a Variable
The concrete value of a Variable can be obtained through .value.
a2.value
array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844])
(a2.value == a1).all()
True
b2.value
DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
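Because a Variable is a pointer, .value gives whatever tensor it currently refers to. Anyone holding the same Variable sees a new value once it is assigned. The sketch below uses a hypothetical ToyVariable class with a .value property (not BrainPy's implementation) to contrast this with simply rebinding a Python name.

```python
class ToyVariable:
    """Hypothetical pointer-like box; not BrainPy's Variable."""

    def __init__(self, value):
        self._value = list(value)

    @property
    def value(self):
        # Return the data the box currently refers to.
        return self._value

    @value.setter
    def value(self, new_value):
        # Re-point the box at new data; every holder of the box sees it.
        self._value = list(new_value)


v = ToyVariable([0.1, 0.2, 0.3])
alias = v                      # a second reference to the same variable
v.value = [1.0, 2.0, 3.0]      # update through one reference
print(alias.value)             # [1.0, 2.0, 3.0]

plain = [0.1, 0.2, 0.3]
plain_alias = plain
plain = [1.0, 2.0, 3.0]        # rebinding a name does not affect the alias
print(plain_alias)             # [0.1, 0.2, 0.3]
```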
Supported operations on a Variable
A Variable supports almost all tensor operations. In fact, brainpy.math.Variable is a subclass of brainpy.math.ndarray.
isinstance(a2, bp.math.numpy.ndarray)
True
isinstance(b2, bp.math.jax.ndarray)
True
isinstance(b2, bp.math.jax.JaxArray)
True
# `bp.math.jax.ndarray` is an alias for `bp.math.jax.JaxArray` in 'jax' backend
bp.math.jax.ndarray is bp.math.jax.JaxArray
True
Note
In the ‘jax’ backend, any operation on a Variable returns a JaxArray rather than a new Variable (bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in the ‘jax’ backend). A Variable therefore only serves as a reference to a value, and the results computed from it are ordinary tensors.
b2 + 1.
JaxArray(DeviceArray([1.7053047, 1.9984136, 1.815271 , 1.926391 , 1.84018 ], dtype=float32))
b2 ** 2
JaxArray(DeviceArray([0.4974548 , 0.9968296 , 0.66466683, 0.8582003 , 0.7059025 ], dtype=float32))
bp.math.jax.floor(b2)
JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
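One common way to obtain this behavior is for the subclass to inherit its parent's operators, with those operators always constructing a parent-type object. The pure-Python toy below sketches that pattern with hypothetical ToyArray and ToyVariable classes. It is an assumption about the general design, not BrainPy's code, but it matches both the isinstance results and the fact that b2 + 1. is a JaxArray rather than a Variable.

```python
class ToyArray:
    """Hypothetical array wrapper standing in for JaxArray."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # Always build a plain ToyArray, whatever subclass `self` is.
        return ToyArray(x + other for x in self.data)

    def __pow__(self, p):
        return ToyArray(x ** p for x in self.data)

    def __repr__(self):
        return f"{type(self).__name__}({self.data})"


class ToyVariable(ToyArray):
    """Subclass: passes isinstance checks for ToyArray, inherits operators."""


v = ToyVariable([1.0, 2.0])
print(isinstance(v, ToyArray))  # True
print(v + 1.0)                  # ToyArray([2.0, 3.0])
print(v ** 2)                   # ToyArray([1.0, 4.0])
```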
Subtypes of Variable
brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar and brainpy.math.Parameter. Users can also define and extend their own subtypes. The built-in subtypes are described below.
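A subtype acts as a label. It is still a Variable, so generic code keeps working, but it can also be singled out by type. The pure-Python sketch below uses hypothetical stand-in classes rather than BrainPy's. RecordVar is an illustrative user-defined subtype, not something BrainPy provides.

```python
class ToyVariable:
    """Hypothetical stand-in for brainpy.math.Variable."""

    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    """Stand-in for a built-in subtype such as TrainVar."""


class RecordVar(ToyVariable):
    """Hypothetical user-defined subtype, e.g. to tag state to be recorded."""


variables = [ToyVariable(0.0), ToyTrainVar(1.0), RecordVar(2.0)]

# Every subtype is still a variable, so generic handling applies to all...
print(all(isinstance(x, ToyVariable) for x in variables))          # True
# ...while the subtype can be tested separately as a label.
print([x.value for x in variables if isinstance(x, RecordVar)])    # [2.0]
```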
TrainVar
brainpy.math.TrainVar is a trainable variable (a subclass of brainpy.math.Variable). Trainable variables are usually those whose gradients are computed and whose values are updated accordingly. However, users can also use TrainVar for other purposes.
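The pure-Python sketch below shows how a TrainVar label is typically used. A training loop picks out trainable variables by type and applies a gradient step only to them. The classes, the quadratic loss (w - 3)**2 and its hand-written gradient are all hypothetical. In practice, gradients would come from automatic differentiation rather than being written by hand.

```python
class ToyVariable:
    """Hypothetical stand-in for a Variable."""

    def __init__(self, value):
        self.value = value


class ToyTrainVar(ToyVariable):
    """Hypothetical stand-in for TrainVar."""


# Hypothetical model state: one trainable weight and one plain state variable.
state = {"w": ToyTrainVar(0.0), "t": ToyVariable(0.0)}


def grad_loss(w, target=3.0):
    # d/dw (w - target)**2 = 2 * (w - target)
    return 2.0 * (w - target)


lr = 0.1
for _ in range(50):
    for var in state.values():
        if isinstance(var, ToyTrainVar):   # only trainable variables get a gradient step
            var.value -= lr * grad_loss(var.value)
    state["t"].value += 1.0                # ordinary dynamical state still changes

print(state["w"].value)  # close to 3.0
print(state["t"].value)  # 50.0
```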
bm.use_backend('numpy')
a = bm.random.rand(4)
a, bm.TrainVar(a)
(array([0.81515042, 0.40363449, 0.89924935, 0.29827197]),
TrainVar([0.81515042, 0.40363449, 0.89924935, 0.29827197]))
bm.use_backend('jax')
b = bm.random.rand(4)
b, bm.TrainVar(b)
(JaxArray(DeviceArray([0.4008 , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)),
TrainVar(DeviceArray([0.4008 , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)))
Parameter
brainpy.math.Parameter is used to label a dynamically changing parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that parameters can be easily retrieved with the Collector.subsets method (see the Base class section).
bm.use_backend('numpy')
a = bm.random.rand(1)
a, bm.Parameter(a)
(array([0.5776296]), Parameter([0.5776296]))
bm.use_backend('jax')
b = bm.random.rand(1)
b, bm.Parameter(b)
(JaxArray(DeviceArray([0.61128676], dtype=float32)),
Parameter(DeviceArray([0.61128676], dtype=float32)))
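Retrieving parameters by type amounts to filtering a name-to-variable mapping with isinstance. The pure-Python sketch below uses a hypothetical ToyCollector class with a subset(var_type) method. Its name and signature are assumptions for illustration, not the actual Collector API.

```python
class ToyVariable:
    """Hypothetical stand-in for a Variable."""

    def __init__(self, value):
        self.value = value


class ToyParameter(ToyVariable):
    """Hypothetical stand-in for Parameter."""


class ToyCollector(dict):
    """Hypothetical name -> variable mapping; not BrainPy's Collector."""

    def subset(self, var_type):
        # Keep only the entries whose value is an instance of `var_type`.
        return ToyCollector((k, v) for k, v in self.items() if isinstance(v, var_type))


coll = ToyCollector(V=ToyVariable(-65.0), tau=ToyParameter(10.0), gain=ToyParameter(0.5))
print(sorted(coll.subset(ToyParameter)))  # ['gain', 'tau']
print(sorted(coll.subset(ToyVariable)))   # ['V', 'gain', 'tau']
```

Because parameters are a subclass of the variable type, filtering by the base class returns them too. Filtering by the subtype is what isolates them.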
RandomState
In the ‘jax’ backend, brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. This is because a RandomState in the ‘jax’ backend must store its key, which changes dynamically: every time a RandomState performs random sampling, its key changes. For example,
bm.use_backend('jax')
state = bm.random.RandomState(seed=1234)
state
RandomState(DeviceArray([ 0, 1234], dtype=uint32))
# perform a "random" sampling
state.random(1)
# the value changed
state
RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling
state.sample(1)
# the value changed too
state
RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))
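The key has to be stored in a Variable because a JAX-style functional PRNG has no hidden global state. Each draw consumes a key, and the object must keep a fresh key for the next draw. The pure-Python toy below sketches that bookkeeping. Its SHA-256 key derivation (from hashlib) is an assumption standing in for JAX's actual PRNG and is not how BrainPy or JAX compute keys.

```python
import hashlib


def _derive(key, tag):
    # Deterministically derive a new 64-bit key from an existing key and a tag.
    digest = hashlib.sha256(f"{key}:{tag}".encode()).digest()
    return int.from_bytes(digest[:8], "big")


class ToyRandomState:
    """Hypothetical key-carrying RNG; not JAX's or BrainPy's implementation."""

    def __init__(self, seed):
        self.key = seed

    def random(self):
        # Split the current key: one child is consumed for this draw,
        # the other replaces the stored key for the next draw.
        use_key, self.key = _derive(self.key, 0), _derive(self.key, 1)
        return (use_key >> 11) / 2**53  # float in [0, 1)


state = ToyRandomState(seed=1234)
k0 = state.key
x = state.random()
print(state.key != k0)                      # True: sampling changed the stored key
print(ToyRandomState(1234).random() == x)   # True: same seed, same first draw
```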
Every RandomState instance can create a new key from its current key with .split_key().
state.split_key()
DeviceArray([3028232624, 826525938], dtype=uint32)
It can also create multiple keys from the current key with .split_keys(n). This is used internally by pmap and vmap to ensure that parallel threads draw different random numbers.
state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
[1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
[2548793527, 3057026599],
[ 874320145, 4142002431],
[3368470122, 3462971882],
[1756854521, 1662729797]], dtype=uint32)
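Splitting matters because a deterministic generator given the same key always produces the same numbers. The toy below again uses a hypothetical SHA-256-based derivation rather than JAX's PRNG. It sketches what split_keys(n) achieves: one parent key fans out into n distinct child keys, one per parallel worker.

```python
import hashlib


def derive(key, tag):
    # Deterministically derive a 64-bit key from a parent key and a tag.
    digest = hashlib.sha256(f"{key}:{tag}".encode()).digest()
    return int.from_bytes(digest[:8], "big")


def split_keys(key, n):
    """Hypothetical: derive n child keys from one parent key."""
    return [derive(key, i) for i in range(n)]


def worker_sample(key):
    # Each worker deterministically turns its key into a float in [0, 1).
    return (derive(key, "sample") >> 11) / 2**53


same = [worker_sample(1234) for _ in range(5)]     # every worker reuses one key
keys = split_keys(1234, 5)                         # one child key per worker
samples = [worker_sample(k) for k in keys]

print(len(set(same)))      # 1: identical numbers in every "thread"
print(len(set(samples)))   # 5: split keys give different numbers
```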
The brainpy.math.jax.random module provides a default RandomState named DEFAULT.
bm.random.DEFAULT
RandomState(DeviceArray([2580684476, 2503630841], dtype=uint32))
The module-level random functions, such as randint(), rand() and shuffle(), use this DEFAULT state. To change the default RandomState, use seed().
bm.random.seed(654321)
bm.random.DEFAULT
RandomState(DeviceArray([ 0, 654321], dtype=uint32))
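The pattern of module-level functions backed by one shared state can be sketched in pure Python. Everything below (ToyRandomState, DEFAULT, rand, seed) is a hypothetical toy with a SHA-256-based key update, not BrainPy's module. Re-seeding resets the shared state, so the same seed reproduces the same stream.

```python
import hashlib


class ToyRandomState:
    """Hypothetical key-carrying RNG; not BrainPy's RandomState."""

    def __init__(self, seed):
        self.key = seed

    def seed(self, seed):
        self.key = seed

    def random(self):
        digest = hashlib.sha256(f"{self.key}".encode()).digest()
        self.key = int.from_bytes(digest[8:16], "big")           # advance the key
        return (int.from_bytes(digest[:8], "big") >> 11) / 2**53  # float in [0, 1)


DEFAULT = ToyRandomState(seed=0)


def rand():
    """Module-level helper: draws from the shared DEFAULT state."""
    return DEFAULT.random()


def seed(s):
    """Reset the shared DEFAULT state in place instead of replacing it."""
    DEFAULT.seed(s)


seed(654321)
first = [rand(), rand()]
seed(654321)
print([rand(), rand()] == first)  # True: re-seeding reproduces the stream
```

In this sketch, resetting the key in place, rather than rebinding DEFAULT to a new object, keeps any existing reference to DEFAULT valid.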