Tensors

In this section, we are going to understand:

  • what a tensor is;

  • how to create a tensor;

  • what operations are supported on a tensor.

import brainpy.math as bm

What is a tensor?

A tensor is a homogeneous multidimensional array. It is a table of elements (usually numbers), all of the same type, indexed by a tuple of non-negative integers. The dimensions of an array are called axes.

For example, the 1D array [7, 2, 9, 10] has only one axis. That axis has 4 elements, so we say it has a shape of (4,).

The 2D array

[[5.2, 3.0, 4.5], 
 [9.1, 0.1, 0.3]]

has 2 axes. The first axis has a length of 2, the second axis has a length of 3. So, we say it has a shape of (2, 3).

Similarly, a 3D array has 3 axes; for example, a 3D array whose axes have lengths 4, 3 and 2 has a shape of (4, 3, 2).

Each tensor has several important attributes:

  • .ndim: the number of axes (dimensions) of the tensor.

  • .shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

  • .size: the total number of elements of the tensor. This is equal to the product of the elements of shape.

  • .dtype: an object describing the type of the elements in the tensor. dtypes can be created or specified using standard Python types (e.g. int, float).
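The relations between these attributes can be checked directly. The sketch below uses plain NumPy, on the assumption (stated just below in this section) that under the 'numpy' backend a tensor is exactly a NumPy array; under the 'jax' backend the default dtypes differ (e.g. float32 instead of float64).

```python
import math

import numpy as np

# A 3D tensor with axes of length 4, 3 and 2.
a = np.zeros((4, 3, 2), dtype=float)

assert a.ndim == len(a.shape)        # ndim is the length of the shape tuple
assert a.size == math.prod(a.shape)  # size is the product of the shape entries
print(a.ndim, a.shape, a.size, a.dtype)  # 3 (4, 3, 2) 24 float64
```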

Under the ‘numpy’ backend, a tensor is exactly a NumPy array. For example:

bm.use_backend('numpy')

a = bm.arange(15).reshape((3, 5))

a
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int64')

Under the ‘jax’ backend, however, BrainPy wraps the original jax.numpy.ndarray in a new data structure, JaxArray. Its attributes and operations are the same as those of NumPy tensors. For example:

bm.use_backend('jax')

a = bm.arange(15).reshape((3, 5))

a
JaxArray(DeviceArray([[ 0,  1,  2,  3,  4],
                      [ 5,  6,  7,  8,  9],
                      [10, 11, 12, 13, 14]], dtype=int32))
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int32')

How to create a tensor?

There are several ways to create tensors. The tensor creation methods are the same under the “numpy” and “jax” backends.

array(), zeros() and ones()

The basic method is to convert a Python sequence into a tensor with bm.array(). For example:

bm.array([2, 3, 4])
JaxArray(DeviceArray([2, 3, 4], dtype=int32))
bm.array([(1.5, 2, 3), (4, 5, 6)])
JaxArray(DeviceArray([[1.5, 2. , 3. ],
                      [4. , 5. , 6. ]], dtype=float32))
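The element type and shape are inferred from the sequence: a single float makes the whole tensor floating point, and a sequence of sequences becomes a 2D tensor. A minimal NumPy sketch of this inference, assuming the 'numpy' backend behaves like NumPy (the exact integer width is platform dependent, so only the kind is checked):

```python
import numpy as np

ints = np.array([2, 3, 4])                      # all ints -> integer dtype
mixed = np.array([(1.5, 2, 3), (4, 5, 6)])      # one float -> float dtype, 2 rows
print(ints.dtype.kind, mixed.dtype, mixed.shape)  # i float64 (2, 3)
```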

Often, the elements of an array are initially unknown, but its size is known. In that case, you can use functions that create tensors filled with placeholder values, for example:

# "bm.zeros()" creates an array full of zeros

bm.zeros((3, 4))
JaxArray(DeviceArray([[0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.]], dtype=float32))
# "bm.ones()" creates an array full of ones

bm.ones((3, 4))
JaxArray(DeviceArray([[1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.]], dtype=float32))
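Both functions take the shape as a single tuple and accept an optional dtype. This NumPy sketch assumes NumPy semantics; note that NumPy defaults to float64, whereas the 'jax' backend output above shows float32.

```python
import numpy as np

z = np.zeros((3, 4))            # shape passed as one tuple; default float dtype
o = np.ones((3, 4), dtype=int)  # dtype can be given explicitly
print(z.dtype, o.dtype.kind)    # float64 i
```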

linspace() and arange()

Two other commonly used functions for creating 1D arrays are bm.linspace() and bm.arange().

bm.arange() creates arrays with regularly incrementing values. It takes “start”, “end” (exclusive), and “step” arguments.

# if only one argument "A" is provided, the function
# uses "start = 0", "end = A", and "step = 1".

bm.arange(10)  
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
# if two arguments "A, B" are provided, the function
# uses "start = A", "end = B", and "step = 1".

bm.arange(2, 10, dtype=float)
JaxArray(DeviceArray([2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32))
# if three arguments "A, B, C" are provided, the function
# uses "start = A", "end = B", and "step = C".

bm.arange(2, 3, 0.1)
JaxArray(DeviceArray([2.       , 2.1      , 2.1999998, 2.2999997, 2.3999996,
                      2.4999995, 2.5999994, 2.6999993, 2.7999992, 2.8999991],            dtype=float32))

Due to the finite floating point precision, it is generally not possible to predict the number of elements obtained by bm.arange(). For this reason, it is usually better to use the function bm.linspace() that receives “start”, “end”, and “num” settings.
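A small NumPy sketch of this pitfall, assuming NumPy's arange semantics (its output length is ceil((stop - start) / step)). With a step of 0.1 the quotient lands just above 3, so an extra element close to "end" appears even though "end" is meant to be exclusive; linspace fixes the count and includes "end" exactly.

```python
import math

import numpy as np

start, stop, step = 1.0, 1.3, 0.1
expected_len = math.ceil((stop - start) / step)  # (1.3 - 1.0) / 0.1 is slightly > 3
a = np.arange(start, stop, step)
print(len(a), expected_len)  # 4 4: the last element is approximately 1.3

l = np.linspace(start, stop, 4)  # exactly 4 points, endpoint included
print(len(l), l[-1])             # 4 1.3
```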

bm.linspace(2, 3, 10)
JaxArray(DeviceArray([2.       , 2.1111112, 2.2222223, 2.3333333, 2.4444447,
                      2.5555556, 2.6666665, 2.777778 , 2.8888888, 3.       ],            dtype=float32))

Random sampling

The brainpy.math.random module provides convenient random sampling functions, including simple random data generation methods, permutation and distribution functions, and random generator functions. Here are several examples.

  • brainpy.math.random.rand(d0, d1, ..., dn)

This function generates random values of a given shape, drawn uniformly from the half-open interval [0.0, 1.0).

bm.random.rand(5, 2)
JaxArray(DeviceArray([[0.99398685, 0.39656162],
                      [0.5161425 , 0.81978667],
                      [0.31676686, 0.083583  ],
                      [0.16560888, 0.40949285],
                      [0.43086028, 0.22965682]], dtype=float32))
  • brainpy.math.random.randn(d0, d1, ..., dn)

This function returns samples from the “standard normal” distribution.

bm.random.randn(5, 2)
JaxArray(DeviceArray([[-0.7701253 ,  0.00965391],
                      [-0.11301948,  0.1409633 ],
                      [-0.11914475,  0.068143  ],
                      [ 1.6409276 ,  1.3378068 ],
                      [ 1.8202178 , -0.37819335]], dtype=float32))
  • brainpy.math.random.randint(low, high[, size, dtype])

This function generates random integers from low (inclusive) to high (exclusive).

bm.random.randint(0, 3, size=10)  
JaxArray(DeviceArray([0, 1, 1, 2, 0, 1, 0, 2, 0, 2], dtype=int32))
  • brainpy.math.random.random([size])

This function generates random floats in the half-open interval [0.0, 1.0).

bm.random.random((3, 2))
JaxArray(DeviceArray([[0.76483357, 0.559957  ],
                      [0.50227726, 0.41693842],
                      [0.65068877, 0.8199152 ]], dtype=float32))
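The shape conventions differ slightly between these functions: rand() and randn() take the dimensions as separate arguments, while random() and randint(size=...) take them as a tuple or integer. A NumPy sketch, assuming bm.random mirrors NumPy's legacy np.random functions; the seed is only there to make the draws reproducible.

```python
import numpy as np

np.random.seed(42)                    # fix the seed for reproducible draws
u = np.random.rand(5, 2)              # uniform on [0, 1), dims as separate args
z = np.random.randn(5, 2)             # standard normal samples
k = np.random.randint(0, 3, size=10)  # integers in {0, 1, 2}; 3 is excluded
r = np.random.random((3, 2))          # uniform on [0, 1), shape as a tuple
print(u.shape, z.shape, k.shape, r.shape)  # (5, 2) (5, 2) (10,) (3, 2)
```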

The brainpy.math.random module also provides permutation functions.

  • brainpy.math.random.shuffle()

This function is used for modifying a sequence in-place by shuffling its contents.

bm.random.shuffle( bm.arange(10) )  
JaxArray(DeviceArray([8, 4, 9, 5, 7, 0, 3, 6, 1, 2], dtype=int32))
  • brainpy.math.random.permutation()

This function randomly permutes a sequence, or returns a permuted range.

bm.random.permutation( bm.arange(10) )  
JaxArray(DeviceArray([2, 1, 0, 9, 5, 4, 7, 6, 8, 3], dtype=int32))
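In plain NumPy the two functions differ in how they deliver the result: np.random.shuffle() reorders its argument in place and returns None, while np.random.permutation() leaves its argument untouched and returns a shuffled copy. This sketch shows NumPy's behavior only; the 'jax' backend output above shows bm.random.shuffle() returning an array, so the in-place semantics may not carry over to BrainPy.

```python
import numpy as np

np.random.seed(0)
x = np.arange(10)
p = np.random.permutation(x)    # new shuffled array; x is unchanged
print(x.tolist() == list(range(10)))  # True
result = np.random.shuffle(x)   # reorders x itself
print(result)                   # None
```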

brainpy.math module also provides functions to sample distributions.

  • beta(a, b[, size])

This function is used to draw samples from a Beta distribution.

bm.random.beta(2, 3, 10) 
JaxArray(DeviceArray([0.48173192, 0.09183226, 0.5617174 , 0.4964077 , 0.5717186 ,
                      0.60861576, 0.3472139 , 0.58446443, 0.41256   , 0.07920451],            dtype=float32))
  • exponential([scale, size])

This function draws samples from an exponential distribution.

bm.random.exponential(1, 10) 
JaxArray(DeviceArray([1.5618182 , 0.18306465, 1.0619484 , 1.2519189 , 0.6019476 ,
                      1.0401233 , 0.37211612, 0.06336975, 3.796705  , 0.03766083],            dtype=float32))
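A quick sanity check on such samplers is to compare a large sample's mean with the theoretical mean: a / (a + b) for Beta(a, b), and the scale parameter for the exponential distribution. A NumPy sketch, assuming bm.random.beta and bm.random.exponential follow the same parameterization as NumPy's functions of the same names:

```python
import numpy as np

np.random.seed(0)
b = np.random.beta(2, 3, 100_000)        # Beta(2, 3): mean 2 / (2 + 3) = 0.4
e = np.random.exponential(1.0, 100_000)  # scale 1.0: mean equals the scale
print(b.mean(), e.mean())                # both close to the theoretical means
```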

For more sampling methods, please see the random sampling functions.

And more

Moreover, there are many other methods we can use to create tensors, including:

  • Conversion from other Python structures (e.g., lists and tuples)

  • Intrinsic NumPy array creation functions (e.g. arange, ones, zeros, etc.)

  • Use of special library functions (e.g., random)

  • Replicating, joining, or mutating existing tensors

  • Reading arrays from disk, either from standard or custom formats

  • Creating arrays from raw bytes through the use of strings or buffers

For details of these methods, please see the NumPy tutorial: Array creation. Most of these methods are supported in BrainPy.
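For instance, conversion from a tuple, joining and replicating look like this in NumPy. Whether each of these names (concatenate, tile, stack) is exposed under bm with the same signature is an assumption to check against the BrainPy API reference.

```python
import numpy as np

t = np.array((1, 2, 3))          # conversion from a tuple
joined = np.concatenate([t, t])  # joining along the existing axis
tiled = np.tile(t, 2)            # replicating: same values as `joined` here
stacked = np.stack([t, t])       # joining along a new leading axis
print(joined, stacked.shape)     # [1 2 3 1 2 3] (2, 3)
```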

Supported operations on tensors

All the operations in BrainPy are based on tensors. Therefore, it is necessary to know which operations each tensor object supports.

Basic operations

Arithmetic operators on tensors apply element-wise. Let’s take “+”, “-”, “*”, and “/” as examples.

We first create two tensors:

data = bm.array([1, 2])

data
JaxArray(DeviceArray([1, 2], dtype=int32))
ones = bm.ones(2)

ones
JaxArray(DeviceArray([1., 1.], dtype=float32))

data + ones
JaxArray(DeviceArray([2., 3.], dtype=float32))

data - ones
JaxArray(DeviceArray([0., 1.], dtype=float32))
data * data
JaxArray(DeviceArray([1, 4], dtype=int32))
data / data
JaxArray(DeviceArray([1., 1.], dtype=float32))
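The output dtypes above follow type promotion rules: mixing an integer tensor with a float tensor gives floats, integer times integer stays integer, and true division always gives floats. A NumPy sketch of the same four operations, assuming NumPy promotion (NumPy yields float64 where the 'jax' backend above yields float32):

```python
import numpy as np

data = np.array([1, 2])
ones = np.ones(2)
print((data + ones).tolist(), (data + ones).dtype)  # [2.0, 3.0] float64
print((data - ones).tolist())                       # [0.0, 1.0]
print((data * data).tolist(), (data * data).dtype.kind)  # [1, 4] i
print((data / data).dtype)                          # float64
```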

Aggregation functions can also be applied to tensors, for example:

  • .min(): get the minimum element;

  • .max(): get the maximum element;

  • .sum(): get the summation;

  • .mean(): get the average;

  • .prod(): get the result of multiplying the elements together;

  • .std(): get the standard deviation.

data = bm.array([1, 2, 3])

data.max()
DeviceArray(3, dtype=int32)
data.min()
DeviceArray(1, dtype=int32)
data.sum()
DeviceArray(6, dtype=int32)
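The remaining aggregations in the list, .mean(), .prod() and .std(), work the same way. This NumPy sketch also spells out what .std() computes: by default NumPy uses the population standard deviation (ddof=0), which is assumed to carry over to BrainPy.

```python
import math

import numpy as np

data = np.array([1, 2, 3])
mean = data.mean()  # (1 + 2 + 3) / 3
prod = data.prod()  # 1 * 2 * 3
std = data.std()    # sqrt(((1 - 2)**2 + (2 - 2)**2 + (3 - 2)**2) / 3)
print(mean, prod)                           # 2.0 6
print(math.isclose(std, math.sqrt(2 / 3)))  # True
```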

It’s very common to want to aggregate along a row or column. You can specify the axis along which the aggregation function is computed. For example, you can find the maximum value within each column by specifying axis=0, and within each row by specifying axis=1.

a = bm.array([[1, 2],
              [5, 3],
              [4, 6]])
a.max(axis=0)
JaxArray(DeviceArray([5, 6], dtype=int32))
a.max(axis=1)
JaxArray(DeviceArray([2, 5, 6], dtype=int32))
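One way to read the axis argument is "the axis that gets collapsed": axis=0 reduces over the rows, leaving one value per column, and axis=1 reduces over the columns, leaving one value per row. A NumPy sketch that checks this against plain Python loops:

```python
import numpy as np

a = np.array([[1, 2], [5, 3], [4, 6]])  # shape (3, 2)
col_max = a.max(axis=0)  # collapses axis 0: one value per column
row_max = a.max(axis=1)  # collapses axis 1: one value per row
print(col_max.tolist(), row_max.tolist())  # [5, 6] [2, 5, 6]

# The same results with loops over the nested list
rows = a.tolist()
print([max(col) for col in zip(*rows)], [max(r) for r in rows])  # [5, 6] [2, 5, 6]
```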

Broadcasting

Tensor operations are usually done on pairs of arrays on an element-by-element basis. In the simplest case, the two tensors must have exactly the same shape, as in the following example:

a = bm.array([1.0, 2.0, 3.0])
b = bm.array([2.0, 2.0, 2.0])

a * b
JaxArray(DeviceArray([2., 4., 6.], dtype=float32))

However, broadcasting relaxes this constraint when the tensors’ shapes meet certain conditions. The simplest broadcasting example occurs when a tensor and a scalar value are combined in an operation:

a = bm.array([1, 2])
b = 1.6

a * b
JaxArray(DeviceArray([1.6, 3.2], dtype=float32, weak_type=True))

Similarly, broadcasting also works with matrices. Below is an example.

data = bm.array([[1, 2],
                 [3, 4],
                 [5, 6]])

ones_row = bm.array([[1, 1]])

data + ones_row
JaxArray(DeviceArray([[2, 3],
                      [4, 5],
                      [6, 7]], dtype=int32))

Under certain constraints, the smaller tensor can be “broadcast” across the larger tensor so that they have compatible shapes. Broadcasting provides a means of vectorizing tensor operations so that looping occurs in C instead of Python. It does this without making needless copies of data and usually leads to efficient algorithm implementations.

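Generally, when operating on two tensors, their shapes are compared element-wise, starting from the trailing (rightmost) dimension and working leftward. Two dimensions are compatible when

  • they are equal, or

  • one of them is 1.

If one tensor has fewer dimensions than the other, its shape is treated as if it were padded with 1s on the left. If these conditions are not met, an error is raised.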
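The rule can be written as a short function. The helper broadcast_shape below is hypothetical (not part of BrainPy); it is a plain-Python sketch of NumPy's broadcasting rule: pad the shorter shape with leading 1s, then compare axis by axis, where equal lengths match and a length of 1 stretches to the other length.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise ValueError."""
    ndim = max(len(shape_a), len(shape_b))
    # Pad the shorter shape with 1s on the left so both have `ndim` entries.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for da, db in zip(a, b):  # compare the shapes axis by axis
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible shapes {tuple(shape_a)} and {tuple(shape_b)}")
    return tuple(result)


print(broadcast_shape((256, 256, 3), (3,)))      # (256, 256, 3)
print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
```

Calling broadcast_shape((3,), (4,)) or broadcast_shape((2, 1), (8, 4, 3)) raises ValueError, matching the failing examples in this section.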

For example, according to the broadcast rules, the following two shapes are compatible:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3
image = bm.random.random((256, 256, 3))
scale = bm.random.random(3)

_ = image + scale 
_ = image - scale 
_ = image * scale 
_ = image / scale 

These shapes are also compatible:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
A = bm.random.random((8, 1, 6, 1))
B = bm.random.random((7, 1, 5))

_ = A + B 
_ = A - B 
_ = A * B 
_ = A / B 

However, these examples of shapes do not broadcast:

A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
A = bm.random.random((3,))
B = bm.random.random((4,))

try:
    _ = A + B
except Exception as e:
    print(e)
add got incompatible shapes for broadcasting: (3,), (4,).
A = bm.random.random((2, 1))
B = bm.random.random((8, 4, 3))

try:
    _ = A + B
except Exception as e:
    print(e)
Incompatible shapes for broadcasting: ((1, 2, 1), (8, 4, 3))

For more details about broadcasting, please see the NumPy documentation: broadcasting.

Indexing, Slicing and Iterating

Tensors can be indexed, sliced and iterated over, much like lists and other Python sequences. For example:

a = bm.arange(10) ** 3

a
JaxArray(DeviceArray([  0,   1,   8,  27,  64, 125, 216, 343, 512, 729], dtype=int32))
a[2]
DeviceArray(8, dtype=int32)
a[2:5]
DeviceArray([ 8, 27, 64], dtype=int32)
# from start to position 6, exclusive, set every 2nd element to 1000,
# equivalent to a[0:6:2] = 1000

a[:6:2] = 1000

a
JaxArray(DeviceArray([1000,    1, 1000,   27, 1000,  125,  216,  343,  512,  729], dtype=int32))
a[::-1]  # reversed a
DeviceArray([ 729,  512,  343,  216,  125, 1000,   27, 1000,    1, 1000], dtype=int32)
for i in a:  # iterate a
    print(i**(1 / 3.))
10.000001
1.0
10.000001
3.0
10.000001
5.0000005
6.0000005
7.0000005
8.000001
9.000001
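Since slicing follows Python's own list semantics, the same indices can be tried on a plain list of cubes. One difference shows up in assignment: a tensor broadcasts a scalar over an extended slice (a[:6:2] = 1000), while a list needs a sequence of the same length.

```python
cubes = [i ** 3 for i in range(10)]
print(cubes[2], cubes[2:5])   # 8 [8, 27, 64]

# Lists need an equal-length sequence for an extended-slice assignment.
cubes[:6:2] = [1000] * 3
print(cubes)        # [1000, 1, 1000, 27, 1000, 125, 216, 343, 512, 729]
print(cubes[::-1])  # [729, 512, 343, 216, 125, 1000, 27, 1000, 1, 1000]
```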

For multi-dimensional tensors, the indices for each axis are given in a tuple separated by commas. For example,

b = bm.arange(20).reshape((5, 4))
b[2, 3]
DeviceArray(11, dtype=int32)
b[0:5, 1]  # each row in the second column of b
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[:, 1]    # equivalent to the previous example
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[1:3, :]  # each column in the second and third row of b
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
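Because arange() followed by reshape() fills the tensor row by row, element b[i, j] of the (5, 4) tensor holds i * 4 + j; this explains why b[2, 3] is 11. A NumPy sketch of that layout, also showing that chained indexing b[2][3] reaches the same element:

```python
import numpy as np

b = np.arange(20).reshape((5, 4))
n_rows, n_cols = b.shape
# Row-major filling: the value at (i, j) is i * n_cols + j.
for i in range(n_rows):
    for j in range(n_cols):
        assert b[i, j] == i * n_cols + j
print(b[2, 3], b[2][3])  # 11 11
```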

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

b[-1]   # the last row. Equivalent to b[-1, :]
DeviceArray([16, 17, 18, 19], dtype=int32)

You can also write this using dots as b[i, ...]. The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is an array with 5 axes, then

  • x[1, 2, ...] is equivalent to x[1, 2, :, :, :],

  • x[..., 3] to x[:, :, :, :, 3] and

  • x[4, ..., 5, :] to x[4, :, :, 5, :].

c = bm.arange(48).reshape((6, 4, 2))
c[1, ...]  # same as c[1, :, :] or c[1]
DeviceArray([[ 8,  9],
             [10, 11],
             [12, 13],
             [14, 15]], dtype=int32)
c[..., 1]  # same as c[:, :, 1]
DeviceArray([[ 1,  3,  5,  7],
             [ 9, 11, 13, 15],
             [17, 19, 21, 23],
             [25, 27, 29, 31],
             [33, 35, 37, 39],
             [41, 43, 45, 47]], dtype=int32)
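The three equivalences listed above for a 5-axis array can be checked in NumPy. The shape (5, 3, 4, 6, 4) is an arbitrary choice, picked only so that every index used (4 on the first axis, 5 on the fourth, 3 on the last) is in range.

```python
import numpy as np

x = np.arange(5 * 3 * 4 * 6 * 4).reshape((5, 3, 4, 6, 4))  # 5 axes
same1 = np.array_equal(x[1, 2, ...], x[1, 2, :, :, :])
same2 = np.array_equal(x[..., 3], x[:, :, :, :, 3])
same3 = np.array_equal(x[4, ..., 5, :], x[4, :, :, 5, :])
print(same1, same2, same3)    # True True True
print(x[4, ..., 5, :].shape)  # (3, 4, 4)
```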

Iterating over multidimensional tensors is done with respect to the first axis:

for row in b:
    print(row)
[0 1 2 3]
[4 5 6 7]
[ 8  9 10 11]
[12 13 14 15]
[16 17 18 19]
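Each step of the loop yields one sub-tensor b[i], so visiting every individual element takes one nested loop per axis. A NumPy sketch of both points:

```python
import numpy as np

b = np.arange(20).reshape((5, 4))
rows = list(b)                   # iteration yields b[0], b[1], ..., b[4]
print(len(rows), rows[0].shape)  # 5 (4,)

total = 0
for row in b:            # outer loop: axis 0
    for element in row:  # inner loop: axis 1
        total += element
print(total == b.sum())  # True
```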

For more methods, such as advanced indexing and indexing tricks, please see the NumPy tutorial: Indexing.

Mathematical functions

Tensors support many other mathematical functions. Most of them can be found in the brainpy.math module. Let’s take a look at trigonometric, hyperbolic, and rounding functions.

d = bm.linspace(0, 1, 10)

d
JaxArray(DeviceArray([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
                      0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],            dtype=float32))
# trigonometric functions

bm.sin(d)
JaxArray(DeviceArray([0.        , 0.11088263, 0.22039774, 0.32719472, 0.42995638,
                      0.5274154 , 0.6183698 , 0.7016979 , 0.7763719 , 0.84147096],            dtype=float32))
bm.arcsin(d)
JaxArray(DeviceArray([0.        , 0.11134101, 0.2240931 , 0.33983693, 0.46055397,
                      0.589031  , 0.7297277 , 0.8911225 , 1.0949141 , 1.5707964 ],            dtype=float32))
# hyperbolic functions

bm.sinh(d)
JaxArray(DeviceArray([0.        , 0.11133985, 0.22405571, 0.33954054, 0.45922154,
                      0.58457786, 0.7171585 , 0.8586021 , 1.0106566 , 1.1752012 ],            dtype=float32))
# rounding functions

bm.round(d)
JaxArray(DeviceArray([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], dtype=float32))
# sum function

bm.sum(d)
DeviceArray(5., dtype=float32)
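These functions act element-wise, so, for example, arcsin undoes sin on inputs in [-π/2, π/2], which includes [0, 1]. Rounding deserves a note: NumPy rounds exact halves to the nearest even value. The sketch below shows NumPy's behavior; that bm.round behaves the same way is an assumption.

```python
import numpy as np

d = np.linspace(0, 1, 10)
print(np.allclose(np.arcsin(np.sin(d)), d))  # True
print(np.round([0.5, 1.5, 2.5]))             # [0. 2. 2.]
print(np.isclose(np.sum(d), 5.0))            # True: 10 points symmetric about 0.5
```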