brainpy.math Overview
The core idea behind BrainPy is Just-In-Time (JIT) compilation. A JIT compiler translates Python code into machine code at run time, just before it is executed, so the compiled code can then run at native machine-code speed.
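As a small illustration in plain JAX (not BrainPy-specific), `jax.jit` traces a pure function on its first call, compiles it with XLA, and reuses the compiled program on later calls with the same input shapes and dtypes. The function `firing_rate` below is a hypothetical example, not part of BrainPy.

```python
import jax
import jax.numpy as jnp

def firing_rate(v, gain=2.0, threshold=0.5):
    # Pure function: the output depends only on the inputs, with no side effects.
    return 1.0 / (1.0 + jnp.exp(-gain * (v - threshold)))

# jax.jit returns a compiled wrapper; tracing and XLA compilation happen on
# the first call for a given input shape and dtype, and later calls reuse it.
fast_rate = jax.jit(firing_rate)

v = jnp.linspace(-1.0, 2.0, 7)
eager_out = firing_rate(v)   # executed operation by operation
jit_out = fast_rate(v)       # executed as one compiled program
```

Both calls compute the same values; only the execution strategy differs.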
Python already has excellent JIT compilers, such as JAX and Numba. However, they are primarily designed to work on pure Python functions, while most computational neuroscience models have too many parameters and variables to manage with functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes code more readable, controllable, flexible, and modular. Brain modeling therefore needs JIT compilation that works on class objects.
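To make the bookkeeping problem concrete, here is a hypothetical leaky integrate-and-fire update written in the pure-function style that `jax.jit` expects. The function name and parameter values are illustrative assumptions, not BrainPy code. Even this single neuron model threads six parameters through every call, and the caller must collect the updated state and pass it back in by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical leaky integrate-and-fire (LIF) step in pure-function style.
# All six parameters and the membrane potential are passed in explicitly,
# and the updated state is returned for the caller to pass back in.
def lif_step(v, inp, tau=10.0, v_rest=-65.0, v_reset=-68.0,
             v_th=-50.0, r=1.0, dt=0.1):
    dv = (-(v - v_rest) + r * inp) / tau   # membrane equation
    v = v + dv * dt                        # forward Euler step
    spike = v >= v_th                      # threshold crossing
    v = jnp.where(spike, v_reset, v)       # reset neurons that spiked
    return v, spike

lif_step_jit = jax.jit(lif_step)

v = jnp.full(3, -65.0)                     # state: membrane potentials
inp = jnp.array([0.0, 10.0, 200.0])        # constant input currents
n_spikes = jnp.zeros(3, dtype=jnp.int32)   # spike counter per neuron
for _ in range(100):
    v, spike = lif_step_jit(v, inp)        # state threaded by hand
    n_spikes = n_spikes + spike
```

With networks of many neuron and synapse types, this explicit threading of parameters and state grows quickly, which is the motivation for class-based models.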
To meet the needs of brain dynamics programming, BrainPy provides the brainpy.math module.
import numpy as np
import brainpy as bp
import brainpy.math as bm
# run all computations on the CPU
bp.math.set_platform('cpu')
Why use brainpy.math?
Specifically, brainpy.math makes the following contributions:
1. NumPy-like ndarray
Python users are familiar with NumPy, especially its ndarray. JAX provides similar array structures and operations, but some basic features differ fundamentally from NumPy's ndarray. For example, JAX arrays are immutable, so in-place updates such as x[i] += y are not supported. To overcome this limitation, brainpy.math provides JaxArray, which can be used in much the same way as a NumPy ndarray.
# ndarray in "numpy"
a = np.arange(5)
a
array([0, 1, 2, 3, 4])
a[0] += 5
a
array([5, 1, 2, 3, 4])
# ndarray in "brainpy.math"
b = bm.arange(5)
b
JaxArray([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
JaxArray([5, 1, 2, 3, 4], dtype=int32)
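For comparison, the same update on a plain JAX array fails, because JAX arrays are immutable. JAX's own substitute is the `.at[...]` syntax, which returns a new array and leaves the original untouched. A minimal sketch in plain JAX:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place item assignment is rejected because JAX arrays are immutable.
try:
    x[0] += 5
    inplace_ok = True
except TypeError:
    inplace_ok = False

# The functional alternative builds a new array; x itself is unchanged.
y = x.at[0].add(5)
```

JaxArray lets you write `b[0] += 5` directly instead of rebinding the result of `.at[...]` yourself.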
For more details, please see the Tensors tutorial.
2. NumPy-like random sampling
JAX generates random numbers in its own style, with explicit PRNG keys that must be created and split by hand, which is very different from NumPy. To provide a consistent experience, brainpy.math offers brainpy.math.random, which supports random sampling in the style of the numpy.random module. For example:
# random sampling in "numpy"
np.random.seed(12345)
np.random.random(5)
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
np.random.normal(0., 2., 5)
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
# random sampling in "brainpy.math.random"
bm.random.seed(12345)
bm.random.random(5)
JaxArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ], dtype=float32)
bm.random.normal(0., 2., 5)
JaxArray([-1.5375282, -0.5970201, -2.272839 , 3.233081 , -0.2738593], dtype=float32)
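Note from the outputs above that seeding both libraries with 12345 yields different numbers: the interface mirrors numpy.random, but the underlying generators differ. For contrast, the explicit-key style of plain JAX looks like the sketch below; it is shown only to compare the two styles, not as a description of how brainpy.math.random is implemented.

```python
import jax

# JAX has no hidden global random state: every draw takes an explicit key.
key = jax.random.PRNGKey(12345)

# Reusing the same key reproduces exactly the same numbers.
a = jax.random.uniform(key, (5,))
b = jax.random.uniform(key, (5,))

# To get fresh numbers, split the key and draw with the new subkey.
key, subkey = jax.random.split(key)
c = 2.0 * jax.random.normal(subkey, (5,))  # normal with mean 0 and std 2
```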
For more details, please see the Tensors tutorial.
3. JAX transformations on class objects
OOP is central to Python programming. However, JAX's excellent transformations (such as JIT compilation) only support pure functions. To make them work with object-oriented code in brain dynamics programming, brainpy.math extends JAX transformations to Python classes.
Example 1: JIT compilation performed on class objects.
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()
        # parameters
        self.dimension = dimension
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)
    def __call__(self, X, Y):
        # gradient of the logistic loss sum_n log(1 + exp(-y_n * x_n . w))
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        # in-place update of the Variable (step size 1)
        self.w[:] = self.w - u
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr1 = LogisticRegression(num_dim)
%timeit lr1(points, labels)
255 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr2 = bm.jit(LogisticRegression(num_dim))
%timeit lr2(points, labels)
162 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
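Conceptually, JIT-compiling a class object means turning its mutable attributes into explicit inputs and outputs of a pure function. In the example above the state is declared with `bm.Variable`; in the hypothetical sketch below, a hand-written list of attribute names plays that role. `Accumulator` and `jit_method` are illustrative names, and this is a plain-JAX sketch of the idea under those assumptions, not BrainPy's actual implementation of `bm.jit`.

```python
import jax
import jax.numpy as jnp

class Accumulator:
    """Hypothetical stateful object holding one piece of mutable state."""
    def __init__(self):
        self.count = jnp.zeros(())

    def step(self, inc):
        self.count = self.count + inc
        return self.count * 2.0

def jit_method(obj, method_name, state_names):
    """Hypothetical helper: JIT-compile a method that mutates attributes.

    The attributes named in state_names become explicit inputs and outputs
    of a pure function, so jax.jit never sees hidden side effects.
    """
    method = getattr(type(obj), method_name)

    def pure_fn(state, *args):
        # Install the (traced) state on the object ...
        for name, value in zip(state_names, state):
            setattr(obj, name, value)
        out = method(obj, *args)
        # ... and return the updated state as explicit outputs.
        new_state = tuple(getattr(obj, name) for name in state_names)
        return out, new_state

    compiled = jax.jit(pure_fn)

    def wrapper(*args):
        state = tuple(getattr(obj, name) for name in state_names)
        out, new_state = compiled(state, *args)
        # Write the concrete results back to the object after the call.
        for name, value in zip(state_names, new_state):
            setattr(obj, name, value)
        return out

    return wrapper

acc = Accumulator()
step = jit_method(acc, "step", ("count",))
out1 = step(1.0)   # acc.count is now 1.0
out2 = step(2.5)   # acc.count is now 3.5
```

The user keeps writing ordinary methods that mutate `self`, while the compiled program only ever sees a pure function of its inputs.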
Example 2: Autograd performed on variables of a class object.
class Linear(bp.Base):
    def __init__(self, num_hidden, num_input, **kwargs):
        super(Linear, self).__init__(**kwargs)
        # parameters
        self.num_input = num_input
        self.num_hidden = num_hidden
        # variables
        self.w = bm.random.random((num_input, num_hidden))
        self.b = bm.zeros((num_hidden,))
    def __call__(self, x):
        # mean of the affine output x @ w + b over all samples and units
        r = x @ self.w + self.b
        return r.mean()
l = Linear(num_hidden=3, num_input=2)
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
[0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32),
DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32))
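The printed gradients can be checked by hand. For x of shape (N, I), w of shape (I, H), and b of shape (H,), r is the mean over all N×H entries of x @ w + b, so ∂r/∂b_h = 1/H and ∂r/∂w_ih = (1/H)·mean_n x_ni. With H = 3 this explains the 0.3333 entries for b and the identical columns in the gradient for w. A plain-JAX sketch of that check, where `linear_mean` is a hypothetical pure-function rewrite of `Linear.__call__`:

```python
import jax
import jax.numpy as jnp

def linear_mean(w, b, x):
    # Same computation as Linear.__call__, written as a pure function.
    return (x @ w + b).mean()

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.uniform(key_w, (2, 3))   # (num_input, num_hidden)
b = jnp.zeros((3,))
x = jax.random.uniform(key_x, (5, 2))   # (num_samples, num_input)

# Gradients with respect to w (argument 0) and b (argument 1).
grad_w, grad_b = jax.grad(linear_mean, argnums=(0, 1))(w, b, x)

# Closed form: dr/db_h = 1/H and dr/dw_ih = mean_n(x_ni) / H.
num_hidden = w.shape[1]
expected_b = jnp.full((num_hidden,), 1.0 / num_hidden)
expected_w = jnp.broadcast_to((x.mean(axis=0) / num_hidden)[:, None], w.shape)
```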
What is the difference between brainpy.math and other frameworks?
brainpy.math is not intended to reimplement the API of any other framework. Its goal is to provide a better brain dynamics programming framework for Python users.
Nevertheless, there are important differences between brainpy.math and other frameworks. As stated above, JAX and many JAX-based frameworks follow a functional programming paradigm. Applying this coding style to brain dynamics models quickly becomes a serious problem because of the overwhelmingly large number of parameters and variables. In contrast, brainpy.math supports an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is Objax, which also supports OOP on top of JAX; however, it is designed for deep learning and cannot be used directly for brain dynamics programming.