brainpy.math Overview#

@Chaoming Wang @Xiaoyu Chen

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just in time” for execution, so that the compiled code runs at native machine-code speed!

Python provides excellent JIT compilers such as JAX and Numba. However, they are designed to work only on pure Python functions, whereas most computational neuroscience models have too many parameters and variables to manage using functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes code more readable, controllable, flexible, and modular. Therefore, it is necessary to support JIT compilation of class objects for brain modeling.

To provide a platform that satisfies the needs of brain dynamics programming, we offer the brainpy.math module.

import numpy as np

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

Why use brainpy.math?#

Specifically, brainpy.math makes the following contributions:

1. Numpy-like ndarray.#

Python users are familiar with NumPy, especially its ndarray. JAX provides similar ndarray structures and operations, but several basic features are fundamentally different from the NumPy ndarray. For example, a JAX ndarray does not support in-place mutating updates such as x[i] += y. To overcome these drawbacks, brainpy.math provides JaxArray, which can be used in the same way as a NumPy ndarray.
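For comparison, the same update on a raw JAX ndarray must go through JAX's functional .at syntax, which returns a new array instead of mutating the old one. A minimal sketch (the index and value are illustrative):

# in-place update on a raw JAX ndarray

import jax.numpy as jnp

x = jnp.arange(5)
# x[0] += 5  # raises TypeError: JAX arrays are immutable
x = x.at[0].add(5)  # functional update that returns a new array
x
DeviceArray([5, 1, 2, 3, 4], dtype=int32)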

# ndarray in "numpy"

a = np.arange(5)
a
array([0, 1, 2, 3, 4])
a[0] += 5
a
array([5, 1, 2, 3, 4])
# ndarray in "brainpy.math"

b = bm.arange(5)
b
JaxArray([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
JaxArray([5, 1, 2, 3, 4], dtype=int32)

For more details, please see the Tensors tutorial.

2. Numpy-like random sampling.#

JAX has its own way of generating random numbers, which is very different from the original NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling, just like the numpy.random module. For example:

# random sampling in "numpy"

np.random.seed(12345)
np.random.random(5)
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
np.random.normal(0., 2., 5)
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
# random sampling in "brainpy.math.random"

bm.random.seed(12345)
bm.random.random(5)
JaxArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ], dtype=float32)
bm.random.normal(0., 2., 5)
JaxArray([-1.5375282, -0.5970201, -2.272839 ,  3.233081 , -0.2738593], dtype=float32)
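For comparison, here is a minimal sketch of JAX's native random API, which threads an explicit PRNG key through every call (the seed 12345 is illustrative):

# random sampling in raw JAX

import jax

key = jax.random.PRNGKey(12345)      # randomness is controlled by explicit keys
key, subkey = jax.random.split(key)  # split the key to obtain fresh randomness
jax.random.uniform(subkey, (5,))     # every sampling call consumes a key

brainpy.math.random hides this key bookkeeping behind the familiar numpy.random interface.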

For more details, please see the Tensors tutorial.

3. JAX transformations on class objects.#

OOP is at the heart of Python programming. However, JAX’s excellent transformations (like JIT compilation) only support pure functions. To make them work with object-oriented code in brain dynamics programming, brainpy.math extends JAX transformations to Python classes.

Example 1: JIT compilation performed on class objects.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables: mutable state must be declared as bm.Variable
        # so that its updates are tracked during JIT compilation
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u
        
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr1 = LogisticRegression(num_dim)

%timeit lr1(points, labels)
255 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr2 = bm.jit(LogisticRegression(num_dim))

%timeit lr2(points, labels)
162 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
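The speedup relies on self.w being declared as a bm.Variable: JIT-compiled class objects may update their Variable attributes in place, while ordinary arrays are treated as compile-time constants. A minimal sketch with a hypothetical Counter class (not part of BrainPy):

class Counter(bp.Base):
    def __init__(self):
        super(Counter, self).__init__()
        self.count = bm.Variable(bm.zeros(1))  # mutable state tracked by JIT

    def __call__(self):
        self.count += 1  # the in-place update survives compilation

counter = Counter()
jitted = bm.jit(counter)
jitted()
jitted()
counter.count  # now holds 2.0: both compiled calls took effect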

Example 2: Autograd performed on variables of a class object.

class Linear(bp.Base):
  def __init__(self, num_hidden, num_input, **kwargs):
    super(Linear, self).__init__(**kwargs)

    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden

    # variables
    self.w = bm.random.random((num_input, num_hidden))
    self.b = bm.zeros((num_hidden,))

  def __call__(self, x):
    r = x @ self.w + self.b
    return r.mean()

l = Linear(num_hidden=3, num_input=2)
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
              [0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32),
 DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32))
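The returned gradients match grad_vars one-to-one, so they can be applied directly. A minimal sketch of one hand-written gradient-descent step (the learning rate 0.1 is arbitrary):

grads = bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
lr = 0.1                      # hypothetical learning rate
l.w[:] = l.w - lr * grads[0]  # in-place update of the weights
l.b[:] = l.b - lr * grads[1]  # in-place update of the bias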

What is the difference between brainpy.math and other frameworks?#

brainpy.math is not intended to reimplement the API of any other framework. All we are trying to do is to build a better brain dynamics programming framework for Python users.

However, there are important differences between brainpy.math and other frameworks. As stated above, JAX and many frameworks built on it follow a functional programming paradigm. Applying this coding style to brain dynamics models becomes a serious problem because of the overwhelmingly large number of parameters and variables. In contrast, brainpy.math allows an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is Objax, which also supports OOP on top of JAX, but it targets the deep learning domain and cannot be used directly for brain dynamics programming.