# Tensors and Variables

In this section, we briefly introduce two basic and important data structures: tensors and variables. They form the foundation of the mathematical operations used in brain dynamics programming (BDP) with BrainPy.

## Tensors

### Definition and Attributes

A tensor is a data structure that organizes algebraic objects in a multidimensional vector space. Put simply, in BrainPy a tensor is a multidimensional array whose elements all share the same data type, most commonly a numeric or boolean type.

The dimensions of an array are called axes. In the following illustration, the 1-D array ([7, 2, 9, 10]) has only one axis. There are 4 elements along this axis, so the shape of the array is (4,).

By contrast, the 2-D array in the illustration has 2 axes. The first axis is of length 2 and the second of length 3. Therefore, the shape of the 2-D array is (2, 3).

Similarly, the 3-D array has 3 axes, of lengths 4, 3, and 2, so its shape is (4, 3, 2).
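The three arrays described above can be rebuilt to check their shapes. This sketch uses NumPy rather than brainpy.math, assuming (as the examples in this section suggest) that brainpy.math arrays expose the same .ndim and .shape attributes; the zero-filled contents of the 2-D and 3-D arrays are placeholders, since only their shapes are given.

```python
import numpy as np

# 1-D: a single axis holding 4 elements
a1 = np.array([7, 2, 9, 10])
print(a1.ndim, a1.shape)   # 1 (4,)

# 2-D: first axis of length 2, second axis of length 3 (placeholder zeros)
a2 = np.zeros((2, 3))
print(a2.ndim, a2.shape)   # 2 (2, 3)

# 3-D: axes of lengths 4, 3 and 2 (placeholder zeros)
a3 = np.zeros((4, 3, 2))
print(a3.ndim, a3.shape)   # 3 (4, 3, 2)
```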

To enable tensor operations, users should import the brainpy.math module:

```python
import brainpy.math as bm

# bm.set_platform('cpu')

t1 = bm.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
               [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])
t1
```

```
JaxArray([[[ 0,  1,  2,  3],
           [ 1,  2,  3,  4],
           [ 4,  5,  6,  7]],

          [[ 0,  0,  0,  0],
           [-1,  1, -1,  1],
           [ 2, -2,  2, -2]]], dtype=int32)
```


Here we create a 3-dimensional tensor with shape (2, 3, 4) and dtype int32. Tensors created by brainpy.math are stored as JaxArray objects so that subsequent operations on them can be accelerated by just-in-time (JIT) compilation.

A tensor has several important attributes:

• .ndim: the number of axes (dimensions) of the tensor.

• .shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, the shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

• .size: the total number of elements of the tensor. This is equal to the product of the elements of shape.

• .dtype: an object describing the type of the elements in the tensor. One can create or specify dtypes using standard Python types.

```python
print('t1.ndim: {}'.format(t1.ndim))
print('t1.shape: {}'.format(t1.shape))
print('t1.size: {}'.format(t1.size))
print('t1.dtype: {}'.format(t1.dtype))
```

```
t1.ndim: 3
t1.shape: (2, 3, 4)
t1.size: 24
t1.dtype: int32
```
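The relations stated in the attribute list (ndim equals the length of shape, size equals the product of its entries) can be checked on the same data. This sketch again substitutes NumPy for brainpy.math. Note that NumPy maps the Python type float to float64, whereas the float arrays produced by brainpy.math in this section are float32.

```python
import math

import numpy as np

# Same values as t1 above, built with NumPy instead of brainpy.math
t1_np = np.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
                  [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]], dtype=np.int32)

# ndim is the length of the shape tuple
print(t1_np.ndim == len(t1_np.shape))        # True
# size is the product of the entries of shape: 2 * 3 * 4
print(t1_np.size == math.prod(t1_np.shape))  # True

# dtypes can be specified with standard Python types
f = np.array([1, 2, 3], dtype=float)   # Python float -> float64 in NumPy
b = np.array([1, 0, 2], dtype=bool)    # nonzero -> True, zero -> False
print(f.dtype, b.dtype)                # float64 bool
```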


Below we will give a few examples of tensor operations that are commonly used in brain dynamics programming. For more details about tensor operations, please refer to the tensor tutorial.

### Creating a tensor

```python
t2 = bm.arange(4)
# t2: JaxArray([0, 1, 2, 3], dtype=int32)

t3 = bm.ones((2, 4)) * 1.5
# t3: JaxArray([[1.5, 1.5, 1.5, 1.5],
#               [1.5, 1.5, 1.5, 1.5]], dtype=float32)
```


### Tensor operations

```python
# indexing and slicing
t3[1]
# DeviceArray([1.5, 1.5, 1.5, 1.5], dtype=float32)

t3[:, 2:]
# DeviceArray([[1.5, 1.5],
#              [1.5, 1.5]], dtype=float32)
```


```python
# algebraic operations
t2 + t3[0]
# JaxArray([1.5, 2.5, 3.5, 4.5], dtype=float32)

t3[0] / t1[0, 1]
# DeviceArray([1.5  , 0.75 , 0.5  , 0.375], dtype=float32)

t2 + 3
# JaxArray([3, 4, 5, 6], dtype=int32)

t2 + t3
# JaxArray([[1.5, 2.5, 3.5, 4.5],
#           [1.5, 2.5, 3.5, 4.5]], dtype=float32)
```


```python
# some functions
bm.dot(t2, t3.T)
# JaxArray([9., 9.], dtype=float32)

bm.max(t1, axis=2)
# JaxArray([[3, 4, 7],
#           [0, 1, 2]], dtype=int32)

t3.flatten()
# JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)
```
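Several results above, such as t2 + t3 and bm.dot(t2, t3.T), depend on broadcasting and on simple arithmetic over the operands. The NumPy sketch below reproduces them and prints the broadcast shape. It assumes that the JAX-based brainpy.math follows NumPy's broadcasting rules, as jax.numpy does.

```python
import numpy as np

t2 = np.arange(4)              # shape (4,)
t3 = np.ones((2, 4)) * 1.5     # shape (2, 4)

# Trailing axes match (4 and 4), so t2 is broadcast across the 2 rows of t3
print(np.broadcast(t2, t3).shape)  # (2, 4)
print(t2 + t3)

# Each row of t3 contributes 1.5 * (0 + 1 + 2 + 3) = 9 to the dot product
print(t2 @ t3.T)               # [9. 9.]

# Trailing axes 3 and 4 differ and neither is 1, so broadcasting fails
try:
    np.ones(3) + t3
except ValueError as e:
    print("ValueError:", e)
```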



In BrainPy, tensors can be used to store parameters of dynamical models. For example, if we define a group of Leaky Integrate-and-Fire (LIF) neurons and wish to assign each neuron a different time constant $\tau$, we can generate a tensor containing the time constants.

```python
n = 6  # assume there are 6 LIF neurons
tau = bm.random.randn(n)*2. + 20.
tau
```

```
JaxArray([18.485964, 19.765427, 15.078529, 21.210836, 17.134335,
          21.495173], dtype=float32)
```


The code above draws the time constants from a normal distribution with a mean of 20 and a standard deviation of 2 (that is, a variance of 4).
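A large sample makes the difference between the scale factor and the variance visible. This NumPy sketch follows the same recipe (standard-normal samples times 2, plus 20) with a seeded generator; the seed and sample size are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Same recipe as above: scale standard-normal samples by 2, then shift by 20
tau = rng.standard_normal(100_000) * 2. + 20.

print(tau.mean())  # close to 20
print(tau.std())   # close to 2, the standard deviation
print(tau.var())   # close to 4, the variance (2 ** 2)
```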

## Variables

So far we have covered the definition, operations, and applications of tensors in BrainPy. In some situations, however, tensors are not suitable. Once a tensor is handed to the JIT compiler, the values inside it cannot be changed. This is a severe limitation, because some properties of a dynamical system, such as the membrane potential, change over time. We therefore need a data structure to store such dynamic variables, and that is brainpy.math.Variable.
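To make the distinction concrete, the sketch below simulates the LIF group introduced above with a plain NumPy Euler loop, without any JIT compilation. The time constants tau stay fixed, while the membrane potential V is overwritten at every step; in BrainPy, V is the kind of quantity that would be wrapped in brainpy.math.Variable. The update uses the standard LIF equation $\tau \, dV/dt = -(V - V_{rest}) + RI$ with a threshold-and-reset rule, and every parameter value (V_rest, V_reset, V_th, R, I, dt) is a hypothetical choice for illustration, not a BrainPy default.

```python
import numpy as np

# Hypothetical parameters chosen for illustration only (not BrainPy defaults)
n, dt = 6, 0.1                              # number of neurons, time step (ms)
V_rest, V_reset, V_th = -65., -68., -50.    # potentials (mV)
R, I = 1., 20.                              # resistance and constant input current

rng = np.random.default_rng(42)
tau = rng.standard_normal(n) * 2. + 20.     # fixed per-neuron time constants (ms)

V = np.full(n, V_rest)                      # membrane potential: the changing state
spike_count = np.zeros(n, dtype=int)

for _ in range(1000):                       # 1000 steps of 0.1 ms = 100 ms
    # Euler step of tau * dV/dt = -(V - V_rest) + R * I, element-wise over neurons
    V += dt * (-(V - V_rest) + R * I) / tau
    spiked = V >= V_th
    spike_count += spiked                   # count threshold crossings
    V[spiked] = V_reset                     # reset the neurons that fired

print(spike_count)
print(V)
```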

### brainpy.math.Variable

brainpy.math.Variable is a pointer referring to a tensor, which is stored as its value. Unlike the data in an ordinary tensor, the data in a Variable can be changed inside JIT-compiled code. Labeling a tensor as a Variable marks it as a dynamical variable that changes over time.

To turn a tensor into a Variable, users just need to wrap the tensor in brainpy.math.Variable:

```python
v = bm.Variable(t2)
v
```

```
Variable([0, 1, 2, 3], dtype=int32)
```


Note that the array is contained in a “Variable” instead of a “JaxArray”.

Note

Tensors that are not marked as Variables are JIT compiled as static data. As shown in the JIT compilation tutorial, modifications to such tensors are invalid in a JIT-compiled environment.

Users can access the value in the Variable through its attribute .value:

```python
v.value
```

```
DeviceArray([0, 1, 2, 3], dtype=int32)
```


Since the data inside a Variable is a tensor, common tensor operations can be applied directly to Variables.

## In-place updating

Although the operations are the same, updating a Variable comes with some requirements. If we apply an ordinary (non-in-place) operation to a Variable, the result is a new tensor rather than a Variable:

```python
v2 = v + 2
v2
```

```
JaxArray([2, 3, 4, 5], dtype=int32)
```
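The same distinction between updating data in place and binding a name to a new result exists for ordinary NumPy arrays, where it is easy to observe. In the sketch below, alias is a second reference to the original array: an in-place += is visible through it, while v = v + 2 creates a new array and leaves alias pointing at the old one. This is only an analogy using NumPy, not a look at BrainPy internals.

```python
import numpy as np

v = np.arange(4)
alias = v                # a second name for the same array object

v += 2                   # in place: the existing array is modified
print(alias)             # [2 3 4 5], the change is visible through alias
print(alias is v)        # True

v = v + 2                # not in place: a new array is created and bound to v
print(alias)             # [2 3 4 5], alias still refers to the old array
print(alias is v)        # False
```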


To update a Variable, users must use in-place updating, which modifies only the value stored inside the Variable and leaves the reference to the Variable unchanged. In-place updating operations include:

1. Indexing and slicing

• Indexing: v[i] = a

• Slicing: v[i:j] = b

• Indexing specific positions: v[[1, 3]] = c

• Slicing all values: v[:] = d, v[...] = e

For more details, please refer to Array Objects Indexing.

```python
v[0] = 10
v[1:3] = 9
v
```

```
Variable([10,  9,  9,  3], dtype=int32)
```
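The two remaining forms in the list, indexing specific positions and slicing all values, can be tried on a NumPy array in the same way. The sketch keeps a second reference, alias, to show that each assignment writes into the existing array instead of creating a new one; that brainpy.math Variables behave alike here is an assumption based on the list above.

```python
import numpy as np

v = np.zeros(6, dtype=np.int32)
alias = v                # second reference to the same array object

v[[1, 3]] = 7            # write to the positions listed in the index array
print(v)                 # [0 7 0 7 0 0]

v[...] = 5               # Ellipsis selects every element
v[:] = v * 2             # full slice: copy the new values into the existing buffer
print(alias)             # [10 10 10 10 10 10]
print(alias is v)        # True
```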


2. Augmented assignment

• += (add)

• -= (subtract)

• /= (divide)

• *= (multiply)

• //= (floor divide)

• %= (modulo)

• **= (power)

• &= (bitwise and)

• |= (bitwise or)

• ^= (bitwise xor)

• <<= (left shift)

• >>= (right shift)

```python
v -= 3
v <<= 1
v
```

```
Variable([14, 12, 12,  0], dtype=int32)
```
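The output follows from the arithmetic: [10, 9, 9, 3] minus 3 gives [7, 6, 6, 0], and shifting left by one bit doubles each integer, giving [14, 12, 12, 0]. The NumPy sketch below replays those two steps. It also shows a NumPy-specific caveat: /= on an integer array raises a TypeError because the floating-point quotient cannot be cast back to the integer dtype, while //= keeps the dtype. How brainpy.math handles /= on an integer Variable is not covered here.

```python
import numpy as np

v = np.array([10, 9, 9, 3], dtype=np.int32)  # value of v before "v -= 3" above
v -= 3                   # [7 6 6 0]
v <<= 1                  # shifting left by one bit doubles each integer
print(v)                 # [14 12 12  0]

# NumPy-specific: true division produces floats, which cannot be written
# back into an int32 array in place, so v is left unchanged
try:
    v /= 2
except TypeError as e:
    print("TypeError:", e)

v //= 2                  # floor division keeps the int32 dtype
print(v)                 # [7 6 6 0]
```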


3. .value assignment

```python
v.value = bm.arange(4)
v
```

```
Variable([0, 1, 2, 3], dtype=int32)
```


.value assignment directly replaces the data stored in the JaxArray. The new data must have the same dtype and shape as the original data; otherwise an error is raised:

```python
try:
    v.value = bm.array([1., 1., 1., 0.])
except Exception as e:
    print(type(e), e)
```

```
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
```


4. .update() method

This method also checks that the new data has the same dtype and shape as the original data.

```python
v.update(bm.array([3, 4, 5, 6]))
v
```

```
Variable([3, 4, 5, 6], dtype=int32)
```
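The dtype and shape checks performed by .value assignment and .update() can be mimicked with a small wrapper class. ToyVariable below is hypothetical: it is a NumPy sketch of a container that points at an array, not BrainPy's implementation. In particular, it raises TypeError and ValueError, whereas the BrainPy output above shows a brainpy.errors.MathError for the dtype mismatch.

```python
import numpy as np


class ToyVariable:
    """Hypothetical stand-in for a Variable: a container pointing at an array.

    Not BrainPy's implementation; it only mimics the checks described above.
    """

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        new = np.asarray(new)
        # Reject data whose dtype or shape differs from the original
        if new.dtype != self._value.dtype:
            raise TypeError(f"The dtype of the original data is {self._value.dtype}, "
                            f"while we got {new.dtype}.")
        if new.shape != self._value.shape:
            raise ValueError(f"The shape of the original data is {self._value.shape}, "
                             f"while we got {new.shape}.")
        self._value = new

    def update(self, new):
        # Reuse the setter so both update paths share the same checks
        self.value = new


v = ToyVariable(np.arange(4, dtype=np.int32))
ref = v                                           # another reference to the same container
v.update(np.array([3, 4, 5, 6], dtype=np.int32))  # accepted: same dtype and shape
print(ref.value)                                  # [3 4 5 6]

# Both assignments are rejected, and the stored value stays [3 4 5 6]
for bad in (np.array([1., 1., 1., 0.], dtype=np.float32), np.arange(5, dtype=np.int32)):
    try:
        v.value = bad
    except (TypeError, ValueError) as e:
        print(type(e).__name__, e)
```

Because the container object itself never changes, every reference to it (here ref) sees the updated data, which is the property in-place updating is meant to preserve.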