The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation turns your Python code into machine code “just-in-time” for execution, so the transformed code can run at native machine-code speed.
Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions, whereas most computational neuroscience models have too many parameters and variables to manage with functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes code more readable, controllable, flexible, and modular. It is therefore necessary to support JIT compilation on class objects for brain modeling.
To provide a platform that satisfies the needs of brain dynamics programming, we provide the brainpy.math module.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
import numpy as np
brainpy.math makes the following contributions:
1. NumPy-like ndarray.
Python users are familiar with NumPy, especially its ndarray. JAX provides similar ndarray structures and operations, but some basic features differ fundamentally from NumPy's ndarray. For example, JAX arrays do not support in-place mutating updates such as x[i] += y. To overcome this limitation, brainpy.math provides JaxArray, which can be used in the same way as a NumPy ndarray.
# ndarray in "numpy"
a = np.arange(5)
a
array([0, 1, 2, 3, 4])
a[0] += 5
a
array([5, 1, 2, 3, 4])
# ndarray in "brainpy.math"
b = bm.arange(5)
b
JaxArray([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
JaxArray([5, 1, 2, 3, 4], dtype=int32)
For more details, please see the Tensors tutorial.
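JAX arrays are immutable. JAX expresses an update such as x[i] += y functionally, for example x = x.at[i].add(y), which returns a new array instead of modifying the old one. A wrapper object can hide this by rebinding its internal buffer whenever it is "mutated". Below is a minimal sketch of that idea. It uses a read-only NumPy array as a stand-in for an immutable JAX array. The class name MutableArray is hypothetical, and this is not BrainPy's actual JaxArray implementation.

```python
import numpy as np


class MutableArray:
    """Hypothetical mutable wrapper around an immutable (read-only) buffer.

    Illustrative sketch only; a read-only NumPy array stands in for an
    immutable JAX array. Not BrainPy's JaxArray implementation.
    """

    def __init__(self, value):
        # np.array copies its input, so the caller's array is never frozen.
        self.value = self._freeze(np.array(value))

    @staticmethod
    def _freeze(arr):
        # Mark the buffer read-only so it cannot be changed in place.
        arr.flags.writeable = False
        return arr

    def __getitem__(self, index):
        return self.value[index]

    def __setitem__(self, index, new):
        # Functional update: copy, modify the copy, then rebind the buffer.
        # (In JAX this step would be `self.value.at[index].set(new)`.)
        updated = self.value.copy()
        updated[index] = new
        self.value = self._freeze(updated)

    def __repr__(self):
        return f"MutableArray({self.value.tolist()})"


x = MutableArray(np.arange(5))
old = x.value
x[0] += 5      # read via __getitem__, add 5, write back via __setitem__
print(x)       # MutableArray([5, 1, 2, 3, 4])
print(old)     # [0 1 2 3 4]  (the old buffer was never modified)
```

Python evaluates x[0] += 5 as a read through __getitem__ followed by a write through __setitem__, so the wrapper only has to intercept the write.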
2. NumPy-like random sampling.
JAX has its own style of generating random numbers: PRNG keys must be passed and split explicitly, which is very different from NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling, which works just like the numpy.random module. For example:
# random sampling in "numpy"
np.random.seed(12345)
np.random.random(5)
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
np.random.normal(0., 2., 5)
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
# random sampling in "brainpy.math.random"
bm.random.seed(12345)
bm.random.random(5)
JaxArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ], dtype=float32)
bm.random.normal(0., 2., 5)
JaxArray([-1.5375282, -0.5970201, -2.272839 , 3.233081 , -0.2738593], dtype=float32)
For more details, please see the Tensors tutorial.
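In JAX, random functions are pure. Each one takes an explicit PRNG key, and new keys are derived with jax.random.split, so calling a function twice with the same key returns the same numbers. A NumPy-style interface can be built on top of this by storing the current key in an object and splitting it on every call. The sketch below illustrates that pattern with NumPy only. The names split, uniform, and StatefulRandom are hypothetical, and NumPy's SeedSequence stands in for JAX's key derivation. This is not the actual implementation of brainpy.math.random.

```python
import numpy as np


def split(key):
    """Derive two new integer keys from `key` deterministically.

    Hypothetical stand-in for jax.random.split, built on NumPy's SeedSequence.
    """
    words = np.random.SeedSequence(key).generate_state(2, dtype=np.uint64)
    return int(words[0]), int(words[1])


def uniform(key, n):
    """Pure function: the same key always yields the same n numbers in [0, 1)."""
    return np.random.default_rng(key).random(n)


class StatefulRandom:
    """Hypothetical NumPy-style facade that keeps the key as hidden state."""

    def __init__(self, seed=0):
        self.seed(seed)

    def seed(self, seed):
        self._key = seed

    def random(self, n):
        # Split the stored key: keep one half as the next state and
        # consume the other half for this draw, so no key is reused.
        self._key, subkey = split(self._key)
        return uniform(subkey, n)


rng = StatefulRandom()
rng.seed(12345)
first = rng.random(5)
second = rng.random(5)   # differs from `first`: the hidden key has advanced
rng.seed(12345)
again = rng.random(5)    # equals `first`: same seed, same key sequence
print(np.array_equal(first, again), np.array_equal(first, second))  # True False
```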
3. JAX transformations on class objects.
OOP is the essence of Python. However, JAX's excellent transformations (such as JIT compilation) only support pure functions. To make them work with object-oriented code in brain dynamics programming, brainpy.math extends JAX transformations to Python classes.
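JAX transformations such as jax.jit trace a function once and assume it is pure. A Python side effect like assigning to self.w runs only while tracing, not on later compiled calls. One general way to apply such transformations to an object is to lift its state into explicit arguments: read the attributes, run a pure function of (state, inputs) that returns (output, new_state), and write the new state back. The NumPy-only sketch below shows that round trip. The names functionalize, stateful, and Accumulator are hypothetical, and the actual brainpy.math mechanism may differ (Example 1 below marks its state with bm.Variable).

```python
import numpy as np


class Accumulator:
    """Toy stateful object: each call adds `x` to an internal total."""

    def __init__(self, size):
        self.total = np.zeros(size)

    def __call__(self, x):
        self.total = self.total + x
        return self.total.sum()


def functionalize(obj, attr_names):
    """Hypothetical helper: turn `obj(...)` into a function of explicit state.

    The returned function maps (state, *args) -> (output, new_state) and
    restores the object's attributes before returning, so callers observe
    no mutation.
    """
    def pure_fn(state, *args):
        saved = {name: getattr(obj, name) for name in attr_names}
        try:
            for name in attr_names:
                setattr(obj, name, state[name])
            out = obj(*args)
            new_state = {name: getattr(obj, name) for name in attr_names}
        finally:
            for name, value in saved.items():
                setattr(obj, name, value)
        return out, new_state

    return pure_fn


def stateful(obj, attr_names, pure_fn):
    """Hypothetical inverse: read the object's state, run `pure_fn`, and
    write the returned state back onto the object."""
    def call(*args):
        state = {name: getattr(obj, name) for name in attr_names}
        out, new_state = pure_fn(state, *args)
        for name, value in new_state.items():
            setattr(obj, name, value)
        return out

    return call


acc = Accumulator(3)
pure = functionalize(acc, ["total"])
# Same state in, same result out: this is the form a function-only
# transformation (e.g. jax.jit) could be applied to.
out1, _ = pure({"total": np.zeros(3)}, np.ones(3))
out2, _ = pure({"total": np.zeros(3)}, np.ones(3))
print(out1, out2)   # 3.0 3.0

call = stateful(acc, ["total"], pure)
call(np.ones(3))
call(np.ones(3))
print(acc.total)    # [2. 2. 2.]
```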
Example 1: JIT compilation performed on class objects.
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()
        # parameters
        self.dimension = dimension
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u

num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr1 = LogisticRegression(num_dim)
%timeit lr1(points, labels)
255 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr2 = bm.jit(LogisticRegression(num_dim))
%timeit lr2(points, labels)
162 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
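For comparison, here is roughly how the update in Example 1 looks in the functional style that function-only compilers expect. The weights become an explicit argument and return value instead of an attribute. The sketch uses plain NumPy, much smaller sizes than the benchmark, and a hypothetical function name, logistic_step. It applies the same update rule but is not compiled and is not meant to match the timings above.

```python
import numpy as np


def logistic_step(w, X, Y):
    """Functional version of LogisticRegression.__call__ (hypothetical name).

    Same update rule, but the weights are passed in and the updated
    weights are returned instead of being written to self.w.
    """
    margin = Y * (X @ w)                                 # shape (num_points,)
    coeff = (1.0 / (1.0 + np.exp(-margin)) - 1.0) * Y    # shape (num_points,)
    u = coeff @ X                                        # shape (num_dim,)
    return w - u


num_dim, num_points = 10, 1000        # far smaller than the benchmark above
rng = np.random.default_rng(0)
points = rng.random((num_points, num_dim))
labels = rng.random(num_points)
w = 2.0 * np.ones(num_dim) - 1.3      # same initial value as self.w

# The caller, not an object, carries the state from one step to the next.
for _ in range(3):
    w = logistic_step(w, points, labels)
print(w.shape)  # (10,)
```

With a single weight vector this bookkeeping is easy. In a brain model with many populations, synapses, and state variables, every one of them would have to be threaded through every call in the same way.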
Example 2: Autograd performed on variables of a class object.
class Linear(bp.Base):
    def __init__(self, num_hidden, num_input, **kwargs):
        super(Linear, self).__init__(**kwargs)
        # parameters
        self.num_input = num_input
        self.num_hidden = num_hidden
        # variables
        self.w = bm.random.random((num_input, num_hidden))
        self.b = bm.zeros((num_hidden,))

    def __call__(self, x):
        r = x @ self.w + self.b
        return r.mean()
l = Linear(num_hidden=3, num_input=2)
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
(DeviceArray([[0.14844148, 0.14844148, 0.14844148], [0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32), DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32))
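The gradients printed above can be checked by hand. Linear.__call__ returns L = mean(x @ w + b) over N*H entries, with N = 5 samples and H = num_hidden = 3. Therefore dL/dw[i, j] = sum_n x[n, i] / (N*H), which does not depend on j, and dL/db[j] = N / (N*H) = 1/H. This explains why all columns of the first array are equal and why every entry of the second is 1/3 (0.33333334 in float32). The NumPy sketch below compares these closed-form gradients against central finite differences. The helper names are hypothetical, and the inputs are fresh random draws, so the values of dw differ from the output above.

```python
import numpy as np


def loss(x, w, b):
    """Same computation as Linear.__call__: the mean of x @ w + b."""
    return (x @ w + b).mean()


def analytic_grads(x, w):
    """Closed-form gradients of `loss` with respect to w and b."""
    n, h = x.shape[0], w.shape[1]
    # dL/dw[i, j] = sum_n x[n, i] / (n * h): identical for every column j.
    dw = np.repeat(x.sum(axis=0)[:, None] / (n * h), h, axis=1)
    # Each b[j] appears in n of the n * h averaged entries: dL/db[j] = 1 / h.
    db = np.full(h, 1.0 / h)
    return dw, db


def numeric_grad(f, param, eps=1e-6):
    """Central finite differences; perturbs `param` in place, then restores it."""
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        orig = param[idx]
        param[idx] = orig + eps
        f_plus = f()
        param[idx] = orig - eps
        f_minus = f()
        param[idx] = orig
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.random((5, 2))      # N = 5 samples, num_input = 2
w = rng.random((2, 3))      # num_hidden = 3
b = np.zeros(3)

dw, db = analytic_grads(x, w)
print(db)  # [0.33333333 0.33333333 0.33333333]
print(np.allclose(dw, numeric_grad(lambda: loss(x, w, b), w)))  # True
print(np.allclose(db, numeric_grad(lambda: loss(x, w, b), b)))  # True
```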
What is the difference between brainpy.math and other frameworks?
brainpy.math is not intended to reimplement the API of any other framework. All we are trying to do is build a better brain dynamics programming framework for Python users.
However, there are important differences between brainpy.math and other frameworks. As stated above, JAX and many JAX-based frameworks follow a functional programming paradigm. Applied to brain dynamics models, this coding style becomes a serious burden, because the overwhelmingly large number of parameters and variables must all be passed in and returned explicitly. In contrast, brainpy.math supports an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is Objax, which also supports OOP on top of JAX, but it is geared toward deep learning and cannot be used directly for brain dynamics programming.