Tensors

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to understand:

  • What is a tensor?

  • How to create a tensor?

  • What operations are supported for a tensor?

If you already have basic knowledge of NumPy (a tensor here is the same as an ndarray in NumPy), you can skip this section.

import brainpy.math as bm

bm.set_platform('cpu')

What is a tensor?

A tensor is a homogeneous multidimensional array. It is a table of elements (usually numbers), all of the same type, indexed by a tuple of non-negative integers. The dimensions of an array are called axes.

In the following picture, the 1D array [7, 2, 9, 10] has only one axis. There are 4 elements along this axis, so the shape of the array is (4,).

By contrast, the 2D array

[[5.2, 3.0, 4.5], 
 [9.1, 0.1, 0.3]]

has 2 axes. The first axis has a length of 2 and the second has a length of 3. Therefore, the shape of the 2D array is (2, 3).

Similarly, the 3D array has 3 axes, with lengths 4, 3, and 2, respectively, so its shape is (4, 3, 2).

A tensor has several important attributes:

  • .ndim: the number of axes (dimensions) of the tensor.

  • .shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, the shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

  • .size: the total number of elements of the tensor. This is equal to the product of the elements of shape.

  • .dtype: an object describing the type of the elements in the tensor. One can create or specify dtypes using standard Python types.

a = bm.arange(15).reshape((3, 5))

a
JaxArray(DeviceArray([[ 0,  1,  2,  3,  4],
                      [ 5,  6,  7,  8,  9],
                      [10, 11, 12, 13, 14]], dtype=int32))
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int32')
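Because a tensor here behaves like a NumPy ndarray, the relationships between these attributes can be checked with plain NumPy. The sketch below uses numpy in place of brainpy.math (an assumption for illustration); note that NumPy's default integer dtype is platform-dependent and often 64-bit, unlike the int32 shown above.

```python
import math

import numpy as np

# Same construction as above, but with NumPy instead of brainpy.math
a = np.arange(15).reshape((3, 5))

print(a.ndim)   # 2: number of axes
print(a.shape)  # (3, 5): length along each axis
print(a.size)   # 15: total number of elements
print(a.dtype)  # element type (platform-dependent integer type here)

# ndim is the length of the shape tuple
assert a.ndim == len(a.shape)
# size is the product of the entries of shape
assert a.size == math.prod(a.shape)
```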

How to create a tensor?

There are several ways to create a tensor.

1. array(), zeros() and ones()

The most basic method is to convert a Python sequence into a tensor with bm.array(). For example:

bm.array([2, 3, 4])
JaxArray(DeviceArray([2, 3, 4], dtype=int32))
bm.array([(1.5, 2, 3), (4, 5, 6)])
JaxArray(DeviceArray([[1.5, 2. , 3. ],
                      [4. , 5. , 6. ]], dtype=float32))

Often, the elements of an array are initially unknown, but its size is known. In that case, you can use placeholder functions to create tensors, for example:

# "bm.zeros()" creates an array full of zeros

bm.zeros((3, 4))
JaxArray(DeviceArray([[0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.]], dtype=float32))
# "bm.ones()" creates an array full of ones

bm.ones((3, 4))
JaxArray(DeviceArray([[1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.]], dtype=float32))
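The same creation functions exist in NumPy, which makes it easy to check how the shape is inferred from nested sequences and how placeholder arrays are filled. This is a NumPy sketch, not BrainPy code; the dtypes differ from the brainpy.math output above because NumPy defaults to 64-bit floats.

```python
import numpy as np

# Nested sequences: the outer list gives axis 0, the inner tuples give axis 1
b = np.array([(1.5, 2, 3), (4, 5, 6)])
print(b.shape)  # (2, 3)
print(b.dtype)  # float64: a single float makes every element floating-point

# Placeholder arrays of a known shape
z = np.zeros((3, 4))
o = np.ones((3, 4))

# An explicit dtype overrides the floating-point default
zi = np.zeros((2, 2), dtype=np.int32)
print(zi.dtype)  # int32
```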

2. linspace() and arange()

Two other commonly used functions for creating 1D arrays are bm.linspace() and bm.arange().

bm.arange() creates arrays with regularly incrementing values. It receives “start”, “end” (exclusive), and “step” settings.

# if only one argument A is provided, it is interpreted as
# start = 0, end = A, and step = 1.

bm.arange(10)  
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
# if two arguments A, B are provided, they are interpreted as
# start = A, end = B, and step = 1.

bm.arange(2, 10, dtype=float)
JaxArray(DeviceArray([2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32))
# if three arguments A, B, C are provided, they are interpreted as
# start = A, end = B, and step = C.

bm.arange(2, 3, 0.1)
JaxArray(DeviceArray([2.       , 2.1      , 2.1999998, 2.2999997, 2.3999996,
                      2.4999995, 2.5999994, 2.6999993, 2.7999992, 2.8999991],
                      dtype=float32))

Due to finite floating-point precision, it is generally impossible to predict the number of elements obtained by bm.arange() with a non-integer step. For this reason, it is usually better to use bm.linspace(), which receives “start”, “end”, and “num” settings, where “num” is the number of elements (the end point is included by default).

bm.linspace(2, 3, 10)
JaxArray(DeviceArray([2.       , 2.1111112, 2.2222223, 2.3333333, 2.4444447,
                      2.5555556, 2.6666665, 2.777778 , 2.8888888, 3.       ],
                      dtype=float32))
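The pitfall described above is easy to reproduce with NumPy, whose arange and linspace these functions mirror. This sketch uses NumPy (float64) rather than brainpy.math, so the exact values differ from the float32 output above, but the behavior is the same in kind.

```python
import numpy as np

# With a non-integer step, the element count depends on rounding:
# (1.3 - 1.0) / 0.1 evaluates to slightly more than 3, so arange produces
# 4 elements, and the last one is approximately 1.3 even though the
# end point is supposed to be excluded.
r = np.arange(1.0, 1.3, 0.1)
print(len(r))  # 4

# linspace takes the number of elements directly and includes the end point
lin = np.linspace(1.0, 1.3, 4)
print(len(lin))        # 4
print(lin[-1] == 1.3)  # True: the last element is exactly the end point
```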

3. Random sampling

The brainpy.math.random module provides convenient random sampling functions, including simple random data generation methods, permutation and distribution functions, and random generator functions. Here are several examples.

  • brainpy.math.random.rand(d0, d1, ..., dn)

This function generates random values of a given shape, drawn uniformly from the half-open interval [0.0, 1.0).

bm.random.rand(5, 2)
JaxArray(DeviceArray([[0.07216692, 0.2877245 ],
                      [0.29642737, 0.43941212],
                      [0.9228879 , 0.14709306],
                      [0.8345591 , 0.8134085 ],
                      [0.6776234 , 0.42747045]], dtype=float32))
  • brainpy.math.random.randn(d0, d1, ..., dn)

This function returns samples of a given shape from the “standard normal” distribution (mean 0, variance 1).

bm.random.randn(5, 2)
JaxArray(DeviceArray([[-1.2456564 ,  0.93986976],
                      [ 1.2722825 , -0.5604058 ],
                      [ 0.42648995,  1.2291526 ],
                      [ 0.7283678 ,  0.12521066],
                      [ 0.29101485,  2.0382316 ]], dtype=float32))
  • brainpy.math.random.randint(low, high[, size, dtype])

This function generates random integers from low (inclusive) to high (exclusive).

bm.random.randint(0, 3, size=10)  
JaxArray(DeviceArray([0, 0, 0, 1, 2, 0, 1, 1, 1, 1], dtype=int32))
  • brainpy.math.random.random([size])

This function generates random floats in the half-open interval [0.0, 1.0).

bm.random.random((3, 2))
JaxArray(DeviceArray([[0.30104673, 0.08174968],
                      [0.46729672, 0.30544508],
                      [0.37684703, 0.7211865 ]], dtype=float32))
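These function names follow NumPy's numpy.random module, so their shape and range conventions can be checked with NumPy. The sketch below uses NumPy's legacy global-state functions as a stand-in for brainpy.math.random; the seeding call is NumPy's and is not taken from BrainPy.

```python
import numpy as np

np.random.seed(0)  # fix NumPy's global seed so the draws are reproducible

u = np.random.rand(5, 2)              # uniform on [0, 1), dimensions as separate args
g = np.random.randn(5, 2)             # standard normal, shape (5, 2)
k = np.random.randint(0, 3, size=10)  # integers in {0, 1, 2}; 3 is excluded
f = np.random.random((3, 2))          # uniform on [0, 1), shape as one tuple

print(u.shape, f.shape)            # (5, 2) (3, 2)
print(k.min() >= 0, k.max() <= 2)  # True True
```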

The brainpy.math.random module also provides permutation functions.

  • brainpy.math.random.shuffle()

This function is used to modify a sequence in-place by shuffling its contents.

bm.random.shuffle(bm.arange(10))
JaxArray(DeviceArray([5, 9, 4, 0, 3, 2, 6, 8, 7, 1], dtype=int32))
  • brainpy.math.random.permutation()

This function randomly permutes a sequence, or returns a permuted range when given an integer.

bm.random.permutation(bm.arange(10))
JaxArray(DeviceArray([1, 4, 8, 5, 2, 7, 3, 9, 0, 6], dtype=int32))
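The difference between the two permutation functions is clearest in NumPy, used here as a stand-in: Generator.shuffle works in place and returns None, while Generator.permutation returns a new array. The brainpy.math output above shows shuffle returning an array, so BrainPy's return behavior may differ from NumPy's; check its documentation.

```python
import numpy as np

rng = np.random.default_rng(0)

x = np.arange(10)
p = rng.permutation(x)  # returns a new, shuffled array; x is unchanged
q = rng.permutation(10)  # an integer argument gives a permuted range(10)

y = np.arange(10)
result = rng.shuffle(y)  # NumPy shuffles y in place and returns None
print(result is None)    # True

# Either way, the result is a rearrangement of the same elements
print(np.array_equal(np.sort(p), x), np.array_equal(np.sort(y), x))  # True True
```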

The brainpy.math.random module also provides functions to sample from distributions.

  • beta(a, b[, size])

This function is used to draw samples from a Beta distribution.

bm.random.beta(2, 3, 10) 
JaxArray(DeviceArray([0.5404635 , 0.5812938 , 0.22380598, 0.20873289, 0.5460086 ,
                      0.11106081, 0.69154614, 0.28446677, 0.24081689, 0.6313805 ],
                      dtype=float32))
  • exponential([scale, size])

This function is used to draw samples from an exponential distribution.

bm.random.exponential(1, 10) 
JaxArray(DeviceArray([0.828246  , 0.21370666, 0.32322478, 1.7631028 , 2.4889555 ,
                      1.133313  , 0.37169302, 0.32336047, 0.34985116, 0.2314292 ],
                      dtype=float32))
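A quick way to sanity-check a sampler is to compare the sample mean with the known mean of the distribution: a / (a + b) for Beta(a, b), and the scale for the exponential distribution. This sketch uses NumPy's Generator as a stand-in for brainpy.math.random; the sample size and tolerances are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Beta(a=2, b=3): samples lie in [0, 1] and the mean is a / (a + b) = 0.4
beta_samples = rng.beta(2, 3, size=n)

# Exponential with scale 1: samples are non-negative and the mean equals the scale
exp_samples = rng.exponential(1.0, size=n)

print(beta_samples.mean())  # close to 0.4
print(exp_samples.mean())   # close to 1.0
```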

For more sampling methods, please see random sampling functions.

4. Other methods

Moreover, there are many other methods we can use to create tensors, including:

  • Conversion from other Python structures (i.e. lists and tuples)

  • Intrinsic NumPy array creation functions (e.g. arange, ones, zeros, etc.)

  • Use of special library functions (e.g., random)

  • Replicating, joining, or mutating existing tensors

  • Reading arrays from disk, either from standard or custom formats

  • Creating arrays from raw bytes through the use of strings or buffers

For details of these methods, please see the NumPy tutorial: Array creation. Most of these methods are also supported in BrainPy.
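As an illustration of "replicating and joining existing tensors", the sketch below uses NumPy's concatenate, stack, and tile. It is written against NumPy; that brainpy.math offers functions of the same names is an assumption here, so check its API reference before relying on them.

```python
import numpy as np

x = np.array([1, 2])
y = np.array([3, 4])

joined = np.concatenate([x, y])  # join along an existing axis: shape (4,)
stacked = np.stack([x, y])       # join along a new first axis: shape (2, 2)
tiled = np.tile(x, 3)            # replicate x three times along its axis

print(joined.tolist())   # [1, 2, 3, 4]
print(stacked.tolist())  # [[1, 2], [3, 4]]
print(tiled.tolist())    # [1, 2, 1, 2, 1, 2]
```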

Supported operations on tensors

All operations in BrainPy are based on tensors. Therefore, it is necessary to know which operations each tensor object supports.

Basic operations

Arithmetic operators on tensors apply element-wise. Let’s take “+”, “-”, “*”, and “/” as examples.

We first create two tensors:

data = bm.array([1, 2])

data
JaxArray(DeviceArray([1, 2], dtype=int32))
ones = bm.ones(2)

ones
JaxArray(DeviceArray([1., 1.], dtype=float32))

data + ones
JaxArray(DeviceArray([2., 3.], dtype=float32))

data - ones
JaxArray(DeviceArray([0., 1.], dtype=float32))
data * data
JaxArray(DeviceArray([1, 4], dtype=int32))
data / data
JaxArray(DeviceArray([1., 1.], dtype=float32))
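The outputs above also show dtype promotion: adding an integer tensor to a float tensor gives floats, integer multiplication stays integer, and "/" always returns floats. The NumPy sketch below (a stand-in for brainpy.math, with 64-bit rather than 32-bit default types) checks the same rules by dtype kind.

```python
import numpy as np

data = np.array([1, 2])
ones = np.ones(2)

# Each operator acts on matching positions: [1 + 1, 2 + 1], and so on
print(data + ones)  # [2. 3.]
print(data - ones)  # [0. 1.]
print(data * data)  # [1 4]
print(data / data)  # [1. 1.]

# int + float promotes to a floating-point dtype, int * int stays integer,
# and true division "/" returns floats even for two integer arrays
assert (data + ones).dtype.kind == "f"
assert (data * data).dtype.kind == "i"
assert (data / data).dtype.kind == "f"
```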

Aggregation functions can also be applied to tensors, such as:

  • .min(): to get the minimum element;

  • .max(): to get the maximum element;

  • .sum(): to get the summation;

  • .mean(): to get the average;

  • .prod(): to get the result of multiplying the elements together;

  • .std(): to get the standard deviation.

data = bm.array([1, 2, 3])

data.max()
DeviceArray(3, dtype=int32)
data.min()
DeviceArray(1, dtype=int32)
data.sum()
DeviceArray(6, dtype=int32)
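The remaining aggregations in the list (.mean(), .prod(), .std()) can be checked by hand on the same data. This is a NumPy sketch standing in for brainpy.math; the ddof argument shown at the end is NumPy's.

```python
import numpy as np

data = np.array([1, 2, 3])

print(data.mean())  # 2.0
print(data.prod())  # 6
print(data.std())   # about 0.8165, i.e. sqrt(((1-2)**2 + 0 + (3-2)**2) / 3)

# By default std divides by N (population standard deviation);
# ddof=1 divides by N - 1 instead
print(data.std(ddof=1))  # 1.0
```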

It’s very common to aggregate along a row or column. For example, you can find the maximum value within each column by specifying axis=0, and within each row by specifying axis=1.

a = bm.array([[1, 2],
              [5, 3],
              [4, 6]])
a.max(axis=0)
JaxArray(DeviceArray([5, 6], dtype=int32))
a.max(axis=1)
JaxArray(DeviceArray([2, 5, 6], dtype=int32))
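A useful way to read the axis argument is that it names the axis being collapsed. The NumPy sketch below (a stand-in for brainpy.math) applies sum to the same matrix; keepdims is a NumPy parameter that keeps the reduced axis with length 1.

```python
import numpy as np

a = np.array([[1, 2],
              [5, 3],
              [4, 6]])

# axis=0 collapses the rows, giving one result per column
print(a.sum(axis=0).tolist())  # [10, 11]
# axis=1 collapses the columns, giving one result per row
print(a.sum(axis=1).tolist())  # [3, 8, 10]

# keepdims=True keeps the reduced axis with length 1
col_max = a.max(axis=0, keepdims=True)
print(col_max.shape)  # (1, 2)
```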

Broadcasting

Tensor operations are usually done on pairs of arrays on an element-by-element basis. In the simplest case, the two tensors must have exactly the same shape, as in the following example:

a = bm.array([1.0, 2.0, 3.0])
b = bm.array([2.0, 2.0, 2.0])

a * b
JaxArray(DeviceArray([2., 4., 6.], dtype=float32))

However, this same-shape requirement is relaxed when the shapes of the tensors meet certain constraints; this is called broadcasting. The simplest broadcasting example occurs when a tensor and a scalar value are combined in an operation:

a = bm.array([1, 2])
b = 1.6

a * b
JaxArray(DeviceArray([1.6, 3.2], dtype=float32, weak_type=True))

Similarly, broadcasting can be applied to matrices:

data = bm.array([[1, 2],
                 [3, 4],
                 [5, 6]])

ones_row = bm.array([[1, 1]])

data + ones_row
JaxArray(DeviceArray([[2, 3],
                      [4, 5],
                      [6, 7]], dtype=int32))

Under certain constraints, the smaller tensor can be “broadcast” across the larger tensor so that they have compatible shapes. Broadcasting provides a means of vectorizing tensor operations so that looping happens in compiled code rather than in Python. It does this without making needless copies of data and usually leads to efficient implementations.

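Generally, the shapes of two tensors are compared dimension by dimension, starting from the trailing (rightmost) dimension. Two dimensions are compatible when

  • they are equal,

  • one of them is 1, or

  • one of them is missing because that tensor has fewer dimensions (a missing dimension is treated as 1).

If these conditions are not met, an error will occur.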
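The rules above can be written as a short function that computes the result shape. broadcast_shape below is a hypothetical helper written only for illustration; NumPy's np.broadcast_shapes (NumPy 1.20 or later) performs the same computation and is used as a cross-check.

```python
from itertools import zip_longest

import numpy as np


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: return the broadcast shape of two shapes or raise ValueError."""
    result = []
    # Walk both shapes from the trailing (rightmost) dimension; a missing
    # dimension in the shorter shape is treated as length 1.
    for da, db in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible shapes: {shape_a} and {shape_b}")
    return tuple(reversed(result))


print(broadcast_shape((256, 256, 3), (3,)))      # (256, 256, 3)
print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)

# Cross-check against NumPy's own implementation of the rules
assert broadcast_shape((8, 1, 6, 1), (7, 1, 5)) == np.broadcast_shapes((8, 1, 6, 1), (7, 1, 5))
```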

For example, according to the broadcasting rules, the following two shapes are compatible:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3
image = bm.random.random((256, 256, 3))
scale = bm.random.random(3)

_ = image + scale 
_ = image - scale 
_ = image * scale 
_ = image / scale 

These shapes are also compatible:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
A = bm.random.random((8, 1, 6, 1))
B = bm.random.random((7, 1, 5))

_ = A + B 
_ = A - B 
_ = A * B 
_ = A / B 

However, in these examples, the shapes cannot be broadcast:

A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
A = bm.random.random((3,))
B = bm.random.random((4,))

try:
    _ = A + B
except Exception as e:
    print(e)
add got incompatible shapes for broadcasting: (3,), (4,).
A = bm.random.random((2, 1))
B = bm.random.random((8, 4, 3))

try:
    _ = A + B
except Exception as e:
    print(e)
Incompatible shapes for broadcasting: ((1, 2, 1), (8, 4, 3))

For more details about broadcasting, please see the NumPy documentation: broadcasting.

Indexing, Slicing and Iterating

Tensors can be indexed, sliced, and iterated over, much like lists and other Python sequences. For example:

a = bm.arange(10) ** 3

a
JaxArray(DeviceArray([  0,   1,   8,  27,  64, 125, 216, 343, 512, 729], dtype=int32))
a[2]
DeviceArray(8, dtype=int32)
a[2:5]
DeviceArray([ 8, 27, 64], dtype=int32)
# from start to position 6, exclusive, set every 2nd element to 1000,
# equivalent to a[0:6:2] = 1000

a[:6:2] = 1000

a
JaxArray(DeviceArray([1000,    1, 1000,   27, 1000,  125,  216,  343,  512,  729], dtype=int32))
a[::-1]  # reversed a
DeviceArray([ 729,  512,  343,  216,  125, 1000,   27, 1000,    1, 1000], dtype=int32)
for i in a:  # iterate a
    print(i**(1 / 3.))
10.000001
1.0
10.000001
3.0
10.000001
5.0000005
6.0000005
7.0000005
8.000001
9.000001
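Slices follow the start:stop:step pattern with stop excluded, and omitted parts take their defaults. The NumPy sketch below (a stand-in for brainpy.math) replays the examples above and checks which positions each slice selects.

```python
import numpy as np

a = np.arange(10) ** 3  # [0, 1, 8, 27, 64, 125, 216, 343, 512, 729]

print(a[2:5].tolist())       # [8, 27, 64]: positions 2, 3, 4
print(a[:6:2].tolist())      # [0, 8, 64]: positions 0, 2, 4
print(a[::-1][:3].tolist())  # [729, 512, 343]: a negative step walks backwards

# Assigning to a slice updates every selected position (0, 2, and 4 here)
a[:6:2] = 1000
print(a[:6].tolist())  # [1000, 1, 1000, 27, 1000, 125]
```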

For multidimensional tensors, one index per axis can be given, separated by commas. For example:

b = bm.arange(20).reshape((5, 4))
b[2, 3]
DeviceArray(11, dtype=int32)
b[0:5, 1]  # each row in the second column of b
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[:, 1]    # equivalent to the previous example
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[1:3, :]  # each column in the second and third row of b
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

b[-1]   # the last row. Equivalent to b[-1, :]
DeviceArray([16, 17, 18, 19], dtype=int32)

You can also write this using dots as b[i, ...]. The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is an array with 5 axes, then

  • x[1, 2, ...] is equivalent to x[1, 2, :, :, :],

  • x[..., 3] to x[:, :, :, :, 3] and

  • x[4, ..., 5, :] to x[4, :, :, 5, :].

c = bm.arange(48).reshape((6, 4, 2))
c[1, ...]  # same as c[1, :, :] or c[1]
DeviceArray([[ 8,  9],
             [10, 11],
             [12, 13],
             [14, 15]], dtype=int32)
c[..., 1]  # same as c[:, :, 1]
DeviceArray([[ 1,  3,  5,  7],
             [ 9, 11, 13, 15],
             [17, 19, 21, 23],
             [25, 27, 29, 31],
             [33, 35, 37, 39],
             [41, 43, 45, 47]], dtype=int32)
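The three equivalences listed for a 5-axis array can be verified directly. This NumPy sketch (a stand-in for brainpy.math) picks the shape (5, 3, 2, 6, 4) only so that every index used in the list is in range.

```python
import numpy as np

# A tensor with 5 axes; the shape is chosen so all indices below are valid
x = np.arange(720).reshape((5, 3, 2, 6, 4))

assert np.array_equal(x[1, 2, ...], x[1, 2, :, :, :])
assert np.array_equal(x[..., 3], x[:, :, :, :, 3])
assert np.array_equal(x[4, ..., 5, :], x[4, :, :, 5, :])

# Integer indices remove their axes; "..." and ":" keep the rest
print(x[4, ..., 5, :].shape)  # (3, 2, 4)
```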

Iterating over multidimensional tensors is done with respect to the first axis:

for row in b:
    print(row)
[0 1 2 3]
[4 5 6 7]
[ 8  9 10 11]
[12 13 14 15]
[16 17 18 19]

For advanced indexing and indexing tricks, please see the NumPy tutorial: Indexing.

Mathematical functions

Tensors support many other mathematical functions, most of which can be found in the brainpy.math module. Let’s take a look at trigonometric, hyperbolic, and rounding functions.

d = bm.linspace(0, 1, 10)

d
JaxArray(DeviceArray([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
                      0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],
                      dtype=float32))
# trigonometric functions

bm.sin(d)
JaxArray(DeviceArray([0.        , 0.11088263, 0.22039774, 0.32719472, 0.42995638,
                      0.5274154 , 0.6183698 , 0.7016979 , 0.7763719 , 0.84147096],
                      dtype=float32))
bm.arcsin(d)
JaxArray(DeviceArray([0.        , 0.11134101, 0.2240931 , 0.3398369 , 0.460554  ,
                      0.58903104, 0.7297277 , 0.8911225 , 1.0949141 , 1.5707964 ],
                      dtype=float32))
# hyperbolic functions

bm.sinh(d)
JaxArray(DeviceArray([0.        , 0.11133985, 0.22405571, 0.33954054, 0.45922154,
                      0.58457786, 0.7171585 , 0.8586021 , 1.0106566 , 1.1752012 ],
                      dtype=float32))
# rounding functions

bm.round(d)
JaxArray(DeviceArray([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], dtype=float32))
# sum function

bm.sum(d)
DeviceArray(5., dtype=float32)
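A few identities make these outputs easier to trust: arcsin inverts sin on this range, sinh(x) = (exp(x) - exp(-x)) / 2, and the 10 evenly spaced points are symmetric around 0.5, so they sum to 5. The sketch uses NumPy as a stand-in; the round-half-to-even behavior shown is NumPy's documented rule, and whether brainpy.math.round matches it exactly is an assumption.

```python
import numpy as np

d = np.linspace(0, 1, 10)

# arcsin inverts sin on [-pi/2, pi/2], which contains [0, 1]
assert np.allclose(np.arcsin(np.sin(d)), d)

# sinh(x) = (exp(x) - exp(-x)) / 2
assert np.allclose(np.sinh(d), (np.exp(d) - np.exp(-d)) / 2)

# NumPy's round sends exact halves to the nearest even integer
print(np.round([0.5, 1.5, 2.5]))  # [0. 2. 2.]

# 10 points symmetric around 0.5 sum to 10 * 0.5
print(np.sum(d))  # approximately 5.0
```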