Arrays and Variables


@Xiaoyu Chen @Chaoming Wang

In this section, we will briefly introduce two basic and important data structures: arrays and variables. They form the foundation of the mathematical operations in brain dynamics programming (BDP) with BrainPy.

Arrays

Definition and Attributes

An array is a data structure that organizes algebraic objects in a multidimensional vector space. Put simply, in BrainPy an array is a multidimensional container whose elements all share the same data type, most commonly numeric or boolean.

The dimensions of an array are called axes. In the following illustration, the 1-D array [7, 2, 9, 10] has only one axis. There are 4 elements along this axis, so the shape of the array is (4,).

By contrast, the 2-D array in the illustration has 2 axes. The first axis is of length 2 and the second of length 3. Therefore, the shape of the 2-D array is (2, 3).

Similarly, the 3-D array has 3 axes, with lengths 4, 3, and 2 respectively, so its shape is (4, 3, 2).
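The shapes described above can be reproduced in a few lines. This sketch uses NumPy as a stand-in, since brainpy.math largely follows the NumPy array API and NumPy runs without BrainPy installed. Only the 1-D values come from the illustration; the 2-D and 3-D arrays are filled with zeros as placeholders.

```python
import numpy as np

a1 = np.array([7, 2, 9, 10])  # the 1-D array from the illustration
a2 = np.zeros((2, 3))         # 2 axes of lengths 2 and 3 (placeholder values)
a3 = np.zeros((4, 3, 2))      # 3 axes of lengths 4, 3 and 2 (placeholder values)

for a in (a1, a2, a3):
    # ndim counts the axes; shape lists the length of each axis
    print(a.ndim, a.shape)
```

This prints `1 (4,)`, then `2 (2, 3)`, then `3 (4, 3, 2)`.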

To enable array operations, users should import the brainpy.math module:

import brainpy.math as bm

# bm.set_platform('cpu')
t1 = bm.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]], 
               [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])
t1
JaxArray([[[ 0,  1,  2,  3],
           [ 1,  2,  3,  4],
           [ 4,  5,  6,  7]],

          [[ 0,  0,  0,  0],
           [-1,  1, -1,  1],
           [ 2, -2,  2, -2]]], dtype=int32)

Here we create a 3-dimensional array with shape (2, 3, 4) and dtype int32. Arrays created by brainpy.math are stored as JaxArray objects, so that subsequent operations on them can be accelerated by just-in-time (JIT) compilation.

An array has several important attributes:

  • .ndim: the number of axes (dimensions) of the array.

  • .shape: the dimensions of the array. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, the shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

  • .size: the total number of elements of the array. This is equal to the product of the elements of shape.

  • .dtype: an object describing the type of the elements in the array. One can create or specify dtypes using standard Python types.

print('t1.ndim: {}'.format(t1.ndim))
print('t1.shape: {}'.format(t1.shape))
print('t1.size: {}'.format(t1.size))
print('t1.dtype: {}'.format(t1.dtype))
t1.ndim: 3
t1.shape: (2, 3, 4)
t1.size: 24
t1.dtype: int32
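The relationships between these attributes (ndim equals the length of shape, and size equals the product of the entries of shape) can be checked directly. This sketch uses NumPy as a stand-in for brainpy.math and rebuilds an array with the same contents as t1.

```python
import math

import numpy as np

# Same values as t1 above, built with NumPy instead of brainpy.math
t1_np = np.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
                  [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])

assert t1_np.ndim == len(t1_np.shape)        # 3 == len((2, 3, 4))
assert t1_np.size == math.prod(t1_np.shape)  # 24 == 2 * 3 * 4
print(t1_np.ndim, t1_np.shape, t1_np.size)   # 3 (2, 3, 4) 24
```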

Below we will give a few examples of array operations that are commonly used in brain dynamics programming. For more details about array operations, please refer to the array tutorial.

Creating an array

t2 = bm.arange(4)
# t2: JaxArray([0, 1, 2, 3], dtype=int32)

t3 = bm.ones((2, 4)) * 1.5
# t3: JaxArray([[1.5, 1.5, 1.5, 1.5],
#               [1.5, 1.5, 1.5, 1.5]], dtype=float32)
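The dtypes in the comments above are inferred from the inputs: arange over integers yields an integer array, and multiplying by the float 1.5 yields a floating-point array. The sketch below checks the same inference with NumPy as a stand-in. One difference to keep in mind: NumPy typically defaults to 64-bit types, whereas the outputs above show 32-bit types, which is JAX's default unless its 64-bit mode is enabled.

```python
import numpy as np

t2_np = np.arange(4)           # integer input -> integer dtype
t3_np = np.ones((2, 4)) * 1.5  # float array times a float -> floating dtype

# dtype.kind is 'i' for signed integers and 'f' for floating point
print(t2_np, t2_np.dtype.kind)        # [0 1 2 3] i
print(t3_np.shape, t3_np.dtype.kind)  # (2, 4) f
```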

Array operations

# indexing and slicing
t3[1]
# DeviceArray([1.5, 1.5, 1.5, 1.5], dtype=float32)

t3[:, 2:]
# DeviceArray([[1.5, 1.5],
#              [1.5, 1.5]], dtype=float32)

# algebraic operations
t2 + t3[0]
# JaxArray([1.5, 2.5, 3.5, 4.5], dtype=float32)

t3[0] / t1[0, 1]
# DeviceArray([1.5  , 0.75 , 0.5  , 0.375], dtype=float32)

# broadcasting
t2 + 3
# JaxArray([3, 4, 5, 6], dtype=int32)

t2 + t3
# JaxArray([[1.5, 2.5, 3.5, 4.5],
#           [1.5, 2.5, 3.5, 4.5]], dtype=float32)

# some functions
bm.dot(t2, t3.T)
# JaxArray([9., 9.], dtype=float32)

bm.max(t1, axis=2)
# JaxArray([[3, 4, 7],
#           [0, 1, 2]], dtype=int32)

t3.flatten()
# JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)
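The shapes of these results follow NumPy's broadcasting and reduction rules, which brainpy.math is assumed here to share: shapes are compared from the trailing axis, two axis lengths are compatible when they are equal or one of them is 1 (a missing leading axis counts as length 1), and reducing along an axis removes it. The sketch checks these rules with NumPy; np.broadcast_shapes requires NumPy 1.20 or later.

```python
import numpy as np

# t2 has shape (4,) and t3 has shape (2, 4): (4,) is treated as (1, 4) and stretched to (2, 4)
print(np.broadcast_shapes((4,), (2, 4)))  # (2, 4)

# A scalar broadcasts against any shape, as in `t2 + 3`
print(np.broadcast_shapes((4,), ()))      # (4,)

# Incompatible trailing lengths (3 vs 4) cannot be broadcast
try:
    np.broadcast_shapes((3,), (2, 4))
except ValueError as err:
    print("not broadcastable:", err)

# Reducing a (2, 3, 4) array along axis=2 drops that axis, as in bm.max(t1, axis=2)
x = np.arange(24).reshape(2, 3, 4)
print(np.max(x, axis=2).shape)            # (2, 3)
```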

In BrainPy, arrays can be used to store parameters of dynamical models. For example, if we define a group of leaky integrate-and-fire (LIF) neurons and wish to assign each neuron a different time constant \(\tau\), we can store all the time constants in a single array.

n = 6  # assume there are 6 LIF neurons
tau = bm.random.randn(n)*2. + 20.
tau
JaxArray([18.485964, 19.765427, 15.078529, 21.210836, 17.134335,
          21.495173], dtype=float32)

The code above draws the time constants from a normal distribution with a mean of 20 and a standard deviation of 2 (that is, a variance of 4).
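Scaling a standard normal sample by 2 and shifting it by 20 gives a mean of 20 and a standard deviation of 2, so the variance is 2² = 4. A quick empirical check, using NumPy's random generator as a stand-in for bm.random (the seed and sample size are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
# Same transformation as tau = randn(n) * 2. + 20., with a large n so the statistics are stable
samples = rng.standard_normal(100_000) * 2.0 + 20.0

print(round(samples.mean(), 2))  # close to 20
print(round(samples.std(), 2))   # close to 2, i.e. a variance close to 4
```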

Variables

We have covered the definition, operations, and applications of arrays in BrainPy. In some situations, however, arrays are not applicable. Once an array is given to the JIT compiler, the values inside it cannot be changed. This is a severe limitation, because some properties of a dynamical system, such as the membrane potential, change over time. Therefore, we need a new data structure to store such dynamic variables: brainpy.math.Variable.
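To make the distinction concrete, the sketch below advances a group of LIF neurons with forward-Euler steps of the standard LIF equation τ dV/dt = -(V - V_rest) + R·I, resetting a neuron to V_reset when V reaches the threshold V_th. It is plain NumPy, and all parameter values (V_rest, V_th, V_reset, R, I, dt) are illustrative assumptions, not values from this tutorial. The time constants tau stay fixed, while the membrane potential V is rewritten at every step; that changing state is what needs to live in a Variable under JIT.

```python
import numpy as np

n = 6
rng = np.random.default_rng(seed=0)
tau = rng.standard_normal(n) * 2.0 + 20.0     # fixed per-neuron time constants (ms)

# Illustrative values, assumed for this sketch
V_rest, V_th, V_reset = -65.0, -50.0, -65.0   # mV
R, I, dt = 1.0, 20.0, 0.1                     # resistance, input current, time step (ms)

V = np.full(n, V_rest)                        # membrane potential: the state that changes
for _ in range(1000):                         # 1000 steps of 0.1 ms = 100 ms
    dV = (-(V - V_rest) + R * I) / tau * dt   # Euler step, elementwise over all neurons
    V = V + dV
    spiked = V >= V_th
    V[spiked] = V_reset                       # reset neurons that crossed the threshold

print(V.shape)  # (6,)
```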

brainpy.math.Variable

brainpy.math.Variable is a pointer referring to an array; the array is stored as its value. The data in a Variable can be changed inside JIT-compiled code. Labeling an array as a Variable marks it as a dynamical variable that changes over time.

To create a Variable, or to turn an existing array into one, wrap the array in brainpy.math.Variable:

v = bm.Variable(t2)
v
Variable([0, 1, 2, 3], dtype=int32)

Note that the array is contained in a “Variable” instead of a “JaxArray”.

Note

Arrays that are not marked as Variables are treated as static data during JIT compilation, so modifications to them are not valid inside JIT-compiled code, as demonstrated in the JIT compilation tutorial.
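The static-data behaviour comes from JAX, on which BrainPy is built: a JIT-compiled function treats arrays it reads from the enclosing scope as constants fixed at trace time. The sketch below uses plain jax.jit (not BrainPy's own JIT wrapper) to show that rebinding such an array afterwards is not seen by the already compiled function.

```python
import jax
import jax.numpy as jnp

offset = jnp.zeros(3)        # a plain array captured from the enclosing scope


@jax.jit
def shift(x):
    return x + offset        # `offset` is read once at trace time and baked in as a constant


x = jnp.arange(3.0)
first = shift(x)             # traces and compiles with offset = [0, 0, 0]

offset = jnp.ones(3)         # rebinding the global array...
second = shift(x)            # ...is ignored: the cached compilation is reused

print(first, second)         # both are [0. 1. 2.]
```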

Users can access the value in the Variable through its attribute .value:

v.value
DeviceArray([0, 1, 2, 3], dtype=int32)

Since the data inside a Variable is an array, common array operations apply directly to Variables.

In-place updating

Although the operations are the same, updating a Variable comes with some requirements. If we apply an ordinary (out-of-place) operation to a Variable, the result is a new array rather than a Variable:

v2 = v + 2
v2
JaxArray([2, 3, 4, 5], dtype=int32)
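Why `v + 2` yields a plain array while `v += 2` keeps the Variable can be illustrated with Python's operator protocol. The class below is a hypothetical, heavily simplified stand-in, not BrainPy's implementation: it holds an array in `.value`, its binary `+` returns a new bare array, and its augmented `+=` replaces the stored value and returns the same wrapper object, so every reference to the wrapper sees the update.

```python
import numpy as np


class Box:
    """Hypothetical minimal container, used only to illustrate reference semantics."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def __add__(self, other):
        # Out-of-place: build and return a new bare array; the Box is untouched
        return self.value + other

    def __iadd__(self, other):
        # In-place: swap the stored value but keep the same Box object
        self.value = self.value + other
        return self


v = Box([0, 1, 2, 3])
alias = v                # a second reference to the same Box

v2 = v + 2               # bare ndarray, like the JaxArray returned above
v += 2                   # still the same Box; alias sees the new value

print(type(v2).__name__, v is alias, alias.value)  # ndarray True [2 3 4 5]
```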

To update a Variable, users must use in-place updating, which modifies the value inside the Variable without changing the reference to the Variable itself. In-place updating operations include:

1. Indexing and slicing

  • Indexing: v[i] = a

  • Slicing: v[i:j] = b

  • Indexing specific positions: v[[1, 3]] = c

  • Slicing all values: v[:] = d, v[...] = e

For more details, please refer to Array Objects Indexing.

v[0] = 10
v[1:3] = 9
v
Variable([10,  9,  9,  3], dtype=int32)
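All four indexing forms listed above write into the existing container instead of creating a new one. NumPy arrays are mutable and follow the same indexing semantics, so one serves as a stand-in here; the values assigned for c, d and e are arbitrary.

```python
import numpy as np

v = np.array([0, 1, 2, 3])
original_id = id(v)

v[0] = 10          # indexing
v[1:3] = 9         # slicing
print(v)           # [10  9  9  3], matching the Variable output above

v[[1, 3]] = 7      # specific positions
print(v)           # [10  7  9  7]

v[:] = 0           # all values
v[...] = 5         # all values, via Ellipsis
print(v)           # [5 5 5 5]

assert id(v) == original_id  # the same object was modified throughout
```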

2. Augmented assignment

  • += (add)

  • -= (subtract)

  • /= (divide)

  • *= (multiply)

  • //= (floor divide)

  • %= (modulo)

  • **= (power)

  • &= (bitwise and)

  • |= (bitwise or)

  • ^= (bitwise xor)

  • <<= (left shift)

  • >>= (right shift)

v -= 3
v <<= 1
v
Variable([14, 12, 12,  0], dtype=int32)
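Tracing the values explains this output: after the indexing step v holds [10, 9, 9, 3]; v -= 3 gives [7, 6, 6, 0]; and v <<= 1 shifts each integer's bits left by one, which doubles it, giving [14, 12, 12, 0]. The same steps with a mutable NumPy array as a stand-in for the Variable:

```python
import numpy as np

v = np.array([10, 9, 9, 3])
v -= 3             # [7 6 6 0]
v <<= 1            # shifting left by 1 doubles each integer
print(v)           # [14 12 12  0]

# Shift (and other bitwise) operators need integer or boolean data; NumPy rejects floats
f = np.array([1.5, 2.5])
try:
    f <<= 1
except TypeError as err:
    print("TypeError:", err)
```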

3. .value assignment

v.value = bm.arange(4)
v
Variable([0, 1, 2, 3], dtype=int32)

.value assignment directly replaces the data stored in the Variable. When using .value, the new data must have the same dtype and shape as the original data.

try:
    v.value = bm.array([1., 1., 1., 0.])
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
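The rule behind this error can be sketched as follows. The function `checked_assign` is hypothetical: it mimics the dtype-and-shape rule stated above, not BrainPy's source, and NumPy stands in for brainpy.math. If a value of another dtype really is intended, converting it explicitly (here with NumPy's astype) before assigning satisfies the check.

```python
import numpy as np


def checked_assign(old, new):
    """Hypothetical helper: accept `new` only if it matches `old` in dtype and shape."""
    new = np.asarray(new)
    if new.dtype != old.dtype:
        raise TypeError(f"The dtype of the original data is {old.dtype}, while we got {new.dtype}.")
    if new.shape != old.shape:
        raise ValueError(f"The shape of the original data is {old.shape}, while we got {new.shape}.")
    return new


old = np.arange(4, dtype=np.int32)
candidate = np.array([1., 1., 1., 0.], dtype=np.float32)

try:
    checked_assign(old, candidate)            # rejected: float32 vs int32
except TypeError as err:
    print(err)

value = checked_assign(old, candidate.astype(old.dtype))  # explicit cast is accepted
print(value, value.dtype)                     # [1 1 1 0] int32
```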

4. .update() method

Like .value assignment, this method checks that the new data has the same dtype and shape as the original data.

v.update(bm.array([3, 4, 5, 6]))
v
Variable([3, 4, 5, 6], dtype=int32)

For more details, such as the subtypes of Variables and more information about in-place updating, please see the advanced tutorial for Variables.