BrainPy documentation
BrainPy is a highly flexible and extensible framework for high-performance brain modeling. Among its key ingredients, BrainPy supports:
JIT compilation for functions and class objects.
Numerical solvers for ODEs, SDEs and others.
Dynamics simulation tools for various brain objects, like neurons, synapses, networks, soma, dendrites, channels, and even more.
Dynamics analysis tools for differential equations, including phase plane analysis, bifurcation analysis, and linearization analysis.
Seamless integration with deep learning models, with high-speed acceleration thanks to JIT compilation.
And more ……
Note
For comprehensive examples of BrainPy, please see:
BrainModels: https://github.com/PKU-NIP-Lab/BrainModels
BrainPyExamples: https://brainpy-examples.readthedocs.io/
Installation
BrainPy is designed to run cross-platform, including on Windows, GNU/Linux, and macOS. It relies only on Python libraries.
Installation with pip
You can install BrainPy from PyPI. To do so, use:
pip install brain-py
To update BrainPy to the latest version, use
pip install -U brain-py
If you want to install the pre-release version (the latest development version) of BrainPy, you can use:
pip install --pre brain-py
Installation from source
If you decide not to use conda or pip, you can install BrainPy from GitHub or OpenI. To do so, use:
pip install git+https://github.com/PKU-NIP-Lab/BrainPy
# or
pip install git+https://git.openi.org.cn/OpenI/BrainPy
Package Dependency
To make BrainPy work properly, users should install several dependent Python packages.
NumPy & Matplotlib
The basic functions of BrainPy rely only on NumPy and Matplotlib. Installing these two packages is easy with pip or conda:
pip install numpy matplotlib
# or
conda install numpy matplotlib
JAX
We highly recommend installing JAX. JAX is a high-performance JIT compiler that enables users to run Python code on CPU, GPU, or TPU devices. Most functionalities of BrainPy are based on JAX.
Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. The provided binary releases of JAX for Linux and macOS systems are available at https://storage.googleapis.com/jax-releases/jax_releases.html .
To install a CPU-only version of JAX, you can run
pip install --upgrade "jax[cpu]"
If you want to install JAX with both CPU and NVidia GPU support, you must first install CUDA and CuDNN, if they have not already been installed. Next, run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Alternatively, you can download the preferred release “.whl” file, and install it via pip
:
pip install xxxx.whl
For Windows users, JAX can be installed by the following methods:
Method 1: On Windows 10 or later, you can use the Windows Subsystem for Linux (WSL). The installation guide can be found in the WSL Installation Guide for Windows 10. Then, you can install JAX in WSL just as you would on Linux.
Method 2: There are several community-supported Windows builds of jax; please refer to https://github.com/cloudhan/jax-windows-builder for more details. In short, you can run:
# for CPU only
pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
# for GPU support
pip install <downloaded jaxlib>
Method 3: You can also build JAX from source.
Numba
Numba is also an excellent JIT compiler, which can accelerate your Python code to approach the speed of C or FORTRAN. Numba works best with NumPy. Many BrainPy modules rely on Numba for speed acceleration, such as connectivity, simulation, analysis, and measurement modules. Numba is also a suitable framework for the computation of the sparse synaptic connections commonly used in computational neuroscience projects.
Numba is a cross-platform package which can be installed on Windows, Linux, and macOS. Installing Numba is a piece of cake. You just need to type the following commands in your terminal:
pip install numba
# or
conda install numba
SymPy
In BrainPy, several modules rely on symbolic inference with SymPy. For example, the Exponential Euler numerical solver needs SymPy to compute the linear part of your differential equations, and the phase plane and bifurcation analyses in the dynamics analysis module need symbolic computation from SymPy. Therefore, we highly recommend installing SymPy (a brief example of where it is used follows the install commands), just by typing
pip install sympy
# or
conda install sympy
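As a brief sketch of where SymPy comes in (the ODE below and its parameters are made up for illustration; the odeint call mirrors the LIF example in the later tutorials), choosing the Exponential Euler method lets BrainPy derive the linear part of the equation symbolically:

import brainpy as bp

def dv(V, t, Iext):
    # a simple leaky ODE: dV/dt = (-V + Iext) / 10.
    return (-V + Iext) / 10.

# SymPy analyses dv() to extract its linear part for the Exponential Euler update
integral = bp.odeint(f=dv, method='exponential_euler')
V_next = integral(0., 0., 1.)  # one integration step from V=0 at t=0 with input 1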
JIT Compilation
The core idea behind BrainPy is the Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just-in-time” for execution. Subsequently, such transformed code can run at native machine code speed!
Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions. In computational neuroscience, most models have many parameters and variables, and it is hard to manage and control the model logic with functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes your code more readable, controllable, flexible, and modular. Therefore, it is necessary to support JIT compilation of class objects for brain modeling.
In BrainPy, we provide a JIT compilation interface for class objects, built on top of JAX and Numba. In this section, we will talk about how it works.
import brainpy as bp
import brainpy.math as bm
JIT in Numba and JAX
Numba is specialized to optimize your native NumPy code, including NumPy arrays, loops, condition controls, etc. It is a cross-platform library which runs on Windows, Linux, macOS, etc. The most wonderful thing is that Numba can just-in-time compile your native Python loops (for or while syntax) and condition controls (if ... else ...). This means that it supports intuitive Python programming.
However, Numba is a lightweight JIT compiler and is only suitable for small network models. For large networks, its parallel performance is poor. Furthermore, Numba does not let the same code run on multiple devices: the same code cannot run on GPU targets.
JAX is a rising-star JIT compiler in Python scientific computing. It uses XLA to JIT compile and run your NumPy programs. The same code can be deployed onto CPUs, GPUs, and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX prefers large network models and has excellent parallel performance.
However, JAX has intrinsic overhead and is not suitable for running small networks. Moreover, JAX only supports Linux and macOS platforms; Windows users must install JAX on WSL or compile it from source. Further, coding in JAX is not very intuitive. For example:
It doesn't support in-place mutating updates of arrays, like x[i] += y; instead you should use x = jax.ops.index_update(x, i, y).
It doesn't support JIT compilation of native Python loops and conditions, like

arr = np.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.
instead you should use

arr = np.zeros(5)

def loop_body(i, acc_arr):
    arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
    return jax.lax.cond(i % 2 == 0,
                        arr1,
                        lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
                        arr1,
                        lambda arr1: arr1)

arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)
What's more, both frameworks have poor support for class objects.
JIT compilation in BrainPy
In order to obtain an intuitive, flexible, and high-performance framework for brain modeling, BrainPy combines the advantages of both compilers and tries to overcome the gotchas of each framework as much as possible (although this work is not yet finished).
Specifically, we provide the brainpy.math module for:
flexible switching between the NumPy (Numba) and JAX backends
unified NumPy-like array operations
a unified ndarray data structure which supports in-place update
unified random APIs
powerful jit() compilation which supports both functions and class objects
Backend Switch
To switch between backends, you can use:
# switch to NumPy backend
bm.use_backend('numpy')
bm.get_backend_name()
'numpy'
# switch to JAX backend
bm.use_backend('jax')
bm.get_backend_name()
'jax'
In BrainPy, the "numpy" and "jax" backends are interchangeable. Both backends have nearly the same APIs, and the same code can run on both backends.
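As a small sketch of this interchangeability (it parallels the gelu example in the next section; the function here is made up), the same code can be executed and jitted under either backend:

def squared_sum(x):
    return bm.sum(x * x)

for backend in ('numpy', 'jax'):
    bm.use_backend(backend)
    data = bm.arange(5.)
    print(backend, bm.jit(squared_sum)(data))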
Math Operations
The APIs in the brainpy.math module of each backend are very similar to the APIs of the original numpy. For a detailed comparison, please see the Comparison Table.
For example, the array creation functions,
bm.zeros((10, 3))
JaxArray(DeviceArray([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32))
bm.arange(10)
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
x = bm.array([[1,2], [3,4]])
x
JaxArray(DeviceArray([[1, 2],
[3, 4]], dtype=int32))
The array manipulation functions:
bm.max(x)
DeviceArray(4, dtype=int32)
bm.repeat(x, 2)
JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))
bm.repeat(x, 2, axis=1)
JaxArray(DeviceArray([[1, 1, 2, 2],
[3, 3, 4, 4]], dtype=int32))
The random number generation functions:
bm.random.random((3, 5))
JaxArray(DeviceArray([[0.77751863, 0.45367956, 0.09705102, 0.35475922, 0.6678729 ],
[0.81603515, 0.83256996, 0.29231536, 0.8294617 , 0.41171193],
[0.37354684, 0.8089644 , 0.4921714 , 0.93098676, 0.4895897 ]], dtype=float32))
y = bm.random.normal(loc=0.0, scale=2.0, size=(2, 5))
y
JaxArray(DeviceArray([[ 0.24033605, 1.6801527 , 1.6716577 , 0.19926837,
-2.3778987 ],
[ 0.58184516, 1.3625289 , -2.3364332 , 2.082281 ,
0.23409791]], dtype=float32))
The linear algebra functions:
bm.dot(x, y)
JaxArray(DeviceArray([[ 1.4040264, 4.4052105, -3.0012088, 4.3638306, -1.9097029],
[ 3.0483887, 10.490574 , -4.3307595, 8.926929 , -6.1973042]], dtype=float32))
bm.linalg.eig(x)
(JaxArray(DeviceArray([-0.37228107+0.j, 5.3722816 +0.j], dtype=complex64)),
JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
[ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))
The discrete Fourier transform functions:
bm.fft.fft(bm.exp(2j * bm.pi * bm.arange(8) / 8))
JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j, 8.0000000e+00+4.8023384e-07j,
-3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
-3.8941437e-07-2.0663207e-07j, 2.3841858e-07-1.9411573e-07j,
3.8941437e-07-2.0663207e-07j, 1.6858739e-07+3.1786506e-08j], dtype=complex64))
bm.fft.ifft(bm.array([0, 4, 0, 0]))
JaxArray(DeviceArray([ 1.+0.j, 0.+1.j, -1.+0.j, 0.-1.j], dtype=complex64))
For the full list of implemented APIs, please see the Comparison Table.
JIT for Functions
To take advantage of JIT compilation, users just need to wrap their customized functions or objects in bm.jit() to instruct BrainPy to transform the code into machine code.
Take a pure function as an example. Here we implement the Gaussian Error Linear Unit (GELU) function:
bm.use_backend('jax')
def gelu(x):
    sqrt = bm.sqrt(2 / bm.pi)
    cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    y = x * cdf
    return y
Let’s first try the running speed without JIT.
# jax backend, without JIT
x = bm.random.random(100000)
%timeit gelu(x)
332 µs ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
After JIT compilation, the function speeds up significantly.
# jax backend, with JIT
gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66.6 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
After switching to the 'numpy' backend, the same result holds.
bm.use_backend('numpy')
# numpy backend, without JIT
x = bm.random.random(100000)
%timeit gelu(x)
3.52 ms ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# numpy backend, with JIT
gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
1.91 ms ± 55.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
JIT for Objects
Moreover, in BrainPy, JIT compilation can be carried out on class objects. Specifically, any instance of brainpy.Base can be just-in-time compiled into machine code.
Let's try a simple example that trains a logistic regression classifier.
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()
        # parameters
        self.dimension = dimension
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u
num_dim, num_points = 10, 20000000
In this example, we want to train the model weights (self.w) again and again. So, we mark the weights as a bm.Variable.
import time
def benckmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)
    print(f'{name} used time {time.time() - t0} s')
Let’s try to benchmark the ‘numpy’ backend.
bm.use_backend('numpy')
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# numpy backend, without JIT
lr1 = LogisticRegression(num_dim)
benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 18.891679763793945 s
# numpy backend, with JIT
lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)
benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 12.473472833633423 s
The same results can be obtained with the 'jax' backend.
bm.use_backend('jax')
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# jax backend, without JIT
lr1 = LogisticRegression(num_dim)
benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 6.245944976806641 s
# jax backend, with JIT
lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)
benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 4.174307346343994 s
As you can see, the jitted function shows its speed advantage.
What's worth noting here is that:
The dynamically changed variable (the weight w) is marked as a bm.Variable (in the __init__() function).
The variable w is updated in place with [:] indexing (in the __call__() function).
These two points are the only constraints BrainPy imposes for the JIT compilation of class objects. Other operations and coding styles are the same as in normal object-oriented Python programming.
Mechanism of JIT in NumPy backend
bm.use_backend('numpy')
So, why must we update the dynamically changed variables in place?
First of all, in the compilation phase, a variable accessed through self. which is not an instance of bm.Variable will be compiled as a static constant. For example, self.a = 1. will be compiled as the constant 1.. If you try to change the value of self.a, it will not work.
class Demo1(bp.Base):
    def __init__(self):
        super(Demo1, self).__init__()
        self.a = 1.

    def update(self, b):
        self.a = b
d1 = Demo1()
bm.jit(d1.update)(2.)
print(d1.a)
1.0
Second, all the variables you want to change during the function call must be labeled as bm.Variable. During JIT compilation, these variables will be recompiled as arguments of the jitted functions.
class Demo2(bp.Base):
    def __init__(self):
        super(Demo2, self).__init__()
        self.a = bm.Variable(1.)

    def update(self, b):
        self.a = b
bm.jit(Demo2().update, show_code=True)
The recompiled function:
-------------------------
def update(b, Demo20_a=None):
    Demo20_a = b
The namespace of the above function:
{}
The recompiled function:
-------------------------
def new_update(b):
    update(b,
           Demo20_a=Demo20.a.value,)
The namespace of the above function:
{'Demo20': <__main__.Demo2 object at 0x7fa9b055e2e0>,
'update': CPUDispatcher(<function update at 0x7fa9d01ff160>)}
<function new_update(b)>
The original Demo2.update function is recompiled into the update() function, with the dynamical variable a compiled as the argument Demo20_a. Then, during the function call (in the new_update() function), Demo20.a.value is passed as Demo20_a to the jitted update() function.
Third, as you can notice in the source code of the recompiled function above, the recompiled variable Demo20_a is not returned. This means that once the function finishes running, the computed value disappears. Therefore, the dynamically changed variables must be updated in place to hold their updated values.
class Demo3(bp.Base):
    def __init__(self):
        super(Demo3, self).__init__()
        self.a = bm.Variable(1.)

    def update(self, b):
        self.a[...] = b
d3 = Demo3()
bm.jit(d3.update)(2.)
d3.a
Variable(2.)
The above simple demonstrations illustrate the core mechanism of JIT compilation in the NumPy backend. bm.jit() in the NumPy backend can recursively compile your class objects. So, please try your models and run them with JIT acceleration.
However, the mechanism of JIT compilation in the JAX backend is quite different. We will detail this in the upcoming tutorials.
In-place operators
Next, it is important to answer: what are in-place operators?
v = bm.arange(10)
id(v)
140366785129408
Actually, in-place operators include the following operations:
Indexing and slicing (for more details please refer to Array Objects Indexing), like
Index: v[i] = a
Slice: v[i:j] = b
Slice specific values: v[[1, 3]] = c
Slice all values: v[:] = d, v[...] = e
v[0] = 1
id(v)
140366785129408
v[1: 2] = 1
id(v)
140366785129408
v[[1, 3]] = 2
id(v)
140366785129408
v[:] = 0
id(v)
140366785129408
v[...] = bm.arange(10)
id(v)
140366785129408
Augmented assignment. All augmented assignments are in-place operations, including
+= (add)
-= (subtract)
/= (divide)
*= (multiply)
//= (floor divide)
%= (modulo)
**= (power)
&= (and)
|= (or)
^= (xor)
<<= (left shift)
>>= (right shift)
v += 1
id(v)
140366785129408
v *= 2
id(v)
140366785129408
v |= bm.random.randint(0, 2, 10)
id (v)
140366785129408
v **= 2.
id(v)
140366785129408
v >>= 2
id(v)
140366785129408
For more advanced usage, please see our forthcoming tutorials.
Dynamics Introduction
What I cannot create, I do not understand. — Richard Feynman
The brain is a complex dynamical system. In order to simulate it, we provide brainpy.DynamicalSystem, which can be used to define any brain object that has dynamics. Various child classes are implemented to model these brain elements, such as brainpy.Channel for neuron channels, brainpy.NeuGroup for neuron groups, brainpy.TwoEndConn for synaptic connections, and brainpy.Network for networks. An arbitrary composition of these objects is also an instance of brainpy.DynamicalSystem. Therefore, brainpy.DynamicalSystem is the universal language for defining dynamical models in BrainPy.
import brainpy as bp
import brainpy.math as bm
brainpy.DynamicalSystem
In this section, let's try to understand the mechanism and the function of brainpy.DynamicalSystem.
What is DynamicalSystem?
First, what can be defined as a DynamicalSystem?
Intuitively, a dynamical system is a system that has a time-dependent state.
Mathematically, it can be expressed as
$$ \dot{X} = f(X, t) $$
where $X$ is the state of the system, $t$ is the time, and $f$ is a function describing the time dependence of the system state.
Alternatively, the evolution of the system over time can be given by
$$ X(t+dt) = F\left(X(t), t, dt\right) $$
where $dt$ is the time step, and $F$ is the evolution rule to update the system’s state.
Accordingly, in BrainPy, any subclass of brainpy.DynamicalSystem must implement this updating rule in the update function (def update(self, _t, _dt)). One dynamical system may have multiple updating rules; therefore, users can define multiple update functions. All update functions are wrapped into an inner data structure, self.steps (a Python dictionary specifying the name and the function of each updating rule). Let's take a look.
class FitzHughNagumoModel(bp.DynamicalSystem):
    def __init__(self, a=0.8, b=0.7, tau=12.5, **kwargs):
        super(FitzHughNagumoModel, self).__init__(**kwargs)
        # parameters
        self.a = a
        self.b = b
        self.tau = tau
        # variables
        self.v = bm.Variable([0.])
        self.w = bm.Variable([0.])
        self.I = bm.Variable([0.])

    def update(self, _t, _dt):
        # _t : the current time, the system keyword
        # _dt : the time step, the system keyword
        self.w += (self.v + self.a - self.b * self.w) / self.tau * _dt
        self.v += (self.v - self.v ** 3 / 3 - self.w + self.I) * _dt
        self.I[:] = 0.
Here, we have defined a dynamical system called the FitzHugh–Nagumo neuron model, whose dynamics are given by:
$$ \begin{aligned} \dot{v} &= v - \frac{v^{3}}{3} - w + I, \\ \tau \dot{w} &= v + a - b w. \end{aligned} $$
By using the Euler method, this system can be updated by the following rule:
$$ \begin{aligned} v(t+dt) &= v(t) + [v(t) - v(t)^{3}/3 - w(t) + I] \, dt, \\ w(t+dt) &= w(t) + [v(t) + a - b\, w(t)] / \tau \, dt. \end{aligned} $$
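For clarity, here is the same Euler step written as a plain Python function (a sketch using the class's default parameters); note that, as in the update() method above, w is updated first and v then uses the already-updated w:

def fhn_euler_step(v, w, I, a=0.8, b=0.7, tau=12.5, dt=0.01):
    # one Euler step of the FitzHugh-Nagumo model, as in update() above
    w = w + (v + a - b * w) / tau * dt
    v = v + (v - v ** 3 / 3 - w + I) * dt
    return v, w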
We can inspect all update functions in the model via xxx.steps.
fnh = FitzHughNagumoModel()
fnh.steps # all update functions
{'update': <bound method FitzHughNagumoModel.update of <__main__.FitzHughNagumoModel object at 0x7f7f4c7ceaf0>>}
Why use DynamicalSystem?
So, why should I define my dynamical system as brainpy.DynamicalSystem?
There are several benefits.
brainpy.DynamicalSystem has a systematic naming system.
First, every instance of DynamicalSystem has its own unique name.
fnh.name # name for "fnh" instance
'FitzHughNagumoModel0'
# every instance has its unique name
for _ in range(5):
    print(FitzHughNagumoModel().name)
FitzHughNagumoModel1
FitzHughNagumoModel2
FitzHughNagumoModel3
FitzHughNagumoModel4
FitzHughNagumoModel5
# the model name can be specified by yourself
fnh2 = FitzHughNagumoModel(name='X')
fnh2.name
'X'
# same name will cause error
try:
    FitzHughNagumoModel(name='X')
except bp.errors.UniqueNameError as e:
    print(e)
In BrainPy, each object should have a unique name. However, we detect that <__main__.FitzHughNagumoModel object at 0x7f7f4c7f3850> has a used name "X".
Second, variables, child nodes, etc. inside an instance can be easily accessed by their absolute or relative paths.
# All variables can be accessed by
# 1). the absolute path
fnh2.vars()
{'X.I': Variable([0.]), 'X.v': Variable([0.]), 'X.w': Variable([0.])}
# 2). or, the relative path
fnh2.vars(method='relative')
{'I': Variable([0.]), 'v': Variable([0.]), 'w': Variable([0.])}
If we wrap many instances into a container such as brainpy.Network, variables and nodes can also be accessed by their absolute or relative paths.
fnh_net = bp.Network(f1=fnh, f2=fnh2)
# absolute access of variables
fnh_net.vars()
{'FitzHughNagumoModel0.I': Variable([0.]),
'FitzHughNagumoModel0.v': Variable([0.]),
'FitzHughNagumoModel0.w': Variable([0.]),
'X.I': Variable([0.]),
'X.v': Variable([0.]),
'X.w': Variable([0.])}
# relative access of variables
fnh_net.vars(method='relative')
{'f1.I': Variable([0.]),
'f1.v': Variable([0.]),
'f1.w': Variable([0.]),
'f2.I': Variable([0.]),
'f2.v': Variable([0.]),
'f2.w': Variable([0.])}
# absolute access of nodes
fnh_net.nodes()
{'FitzHughNagumoModel0': <__main__.FitzHughNagumoModel at 0x7f7f4c7ceaf0>,
'X': <__main__.FitzHughNagumoModel at 0x7f7f4c7f3700>,
'Network0': <brainpy.simulation.brainobjects.network.Network at 0x7f7f4c7f3340>}
# relative access of nodes
fnh_net.nodes(method='relative')
{'': <brainpy.simulation.brainobjects.network.Network at 0x7f7f4c7f3340>,
'f1': <__main__.FitzHughNagumoModel at 0x7f7f4c7ceaf0>,
'f2': <__main__.FitzHughNagumoModel at 0x7f7f4c7f3700>}
Automatic monitors. Any instance of brainpy.DynamicalSystem can call .run(duration). During running, a brainpy.Monitor inside the dynamical system (xxx.mon) can be used to automatically record the history values of the variables of interest. For details, please see the tutorial of Monitors and Inputs.
# in "fnh3" instance, we try to monitor "v", "w", and "I" variables
fnh3 = FitzHughNagumoModel(monitors=['v', 'w', 'I'])
# in "fnh4" instance, we only monitor "v" variable
fnh4 = FitzHughNagumoModel(monitors=['v'], name='Y')
Convenient input operations. During model running, users can specify the inputs for each model component in the format (target, value, [type, operation]); for details please see the tutorial of Monitors and Inputs.
The target is the variable accessed by its absolute or relative path. Absolute path access is very useful in a huge network model.
The default input type is "fix", which means the value must be a constant scalar or array over time. The "iter" type of input is also allowed, which means the value can be an iterable object (an array, an iterable function, etc.).
The default operation is +, which means the input value will be added to the target. Allowed operations include +, -, *, /, and =.
%matplotlib inline
import matplotlib.pyplot as plt
bm.set_dt(dt=0.01)
fnh3.run(duration=100,
# relative path to access variable 'I'
inputs=('I', 1.5))
plt.plot(fnh3.mon.ts, fnh3.mon.v, label='v')
plt.plot(fnh3.mon.ts, fnh3.mon.w, label='w')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c761e80>

inputs = bm.linspace(1., 2., 10000)
fnh4.run(duration=100,
inputs=('Y.I', # specify 'target' by the absolute path access
inputs, # specify 'value' with an iterable array
'iter')) # "iter" input 'type' must be explicitly specified
plt.plot(fnh4.mon.ts, fnh4.mon.v, label='v')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c7f3b80>

def inputs():
    for i in range(10000):
        yield 1.5
fnh4.run(duration=100,
inputs=('Y.I', # specify 'target' by the absolute path access
inputs(), # specify 'value' with an iterable function
'iter')) # "iter" input 'type' must be explicitly specified
plt.plot(fnh4.mon.ts, fnh4.mon.v, label='v')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c082d60>

brainpy.DynamicalSystem is a subclass of brainpy.Base; therefore, any instance of brainpy.DynamicalSystem can be just-in-time compiled into efficient machine code targeting CPUs, GPUs, or TPUs.
fnh3_jit = bm.jit(fnh3)
fnh3_jit.run(duration=100, inputs=('I', 1.5))
plt.plot(fnh3_jit.mon.ts, fnh3_jit.mon.v, label='v')
plt.plot(fnh3_jit.mon.ts, fnh3_jit.mon.w, label='w')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f2c595b50>

Instances of brainpy.DynamicalSystem can be combined arbitrarily. Any composed system can also benefit from the above convenient interfaces.
# compose two FitzHughNagumoModel instances into a Network
net2 = bp.Network(f1=fnh3, f2=fnh4, monitors=['f1.v', 'Y.v'])
net2.run(100, inputs=[
('f1.I', 1.5), # relative access variable "I" in 'fnh3'
('Y.I', 1.0), # absolute access variable "I" in 'fnh4'
])
plt.plot(net2.mon.ts, net2.mon['f1.v'], label='v1')
plt.plot(net2.mon.ts, net2.mon['Y.v'], label='v2')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f2c2c6ee0>

In the next sections, we will illustrate how to define common brain objects (specifically, neurons, synapses, and networks) with subclasses of brainpy.DynamicalSystem.
brainpy.NeuGroup
brainpy.NeuGroup is used for neuron group modeling. User-defined neuron group models should inherit from brainpy.NeuGroup. Let's take the leaky integrate-and-fire (LIF) model as an illustrative example.
LIF neuron model
The formal equations of the LIF model are given by:
$$ \begin{aligned} & \tau_m \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \quad\quad (1) \\ & \text{after } V(t) \gt V_{th}, \; V(t) = V_{rest} \text{ lasting } \tau_{ref} \text{ ms} \quad\quad (2) \end{aligned} $$
where $V$ is the membrane potential, $V_{rest}$ is the rest membrane potential, $V_{th}$ is the spike threshold, $\tau_m$ is the time constant, $\tau_{ref}$ is the refractory time period, and $I$ is the time-variant synaptic inputs.
The above two equations mean that when the membrane potential $V$ is below $V_{th}$, the model integrates $V$ with equation (1); once $V > V_{th}$, according to equation (2), the membrane potential is reset to $V_{rest}$ and the model enters the refractory period, which lasts $\tau_{ref}$ ms. During the refractory period, the membrane potential $V$ no longer changes.
Let's start to code this LIF neuron model. First, we will define the following items to store the neuron state:
V: the membrane potential.
input: the synaptic input.
spike: whether the neuron produces a spike.
refractory: whether the neuron is in the refractory state.
t_last_spike: the last spike time.
Based on these states, the updating logic of LIF model from the current time $t$ to the next time $t+dt$ will be coded as:
class LIF(bp.NeuGroup):
    def __init__(self, size, t_refractory=1., V_rest=0., V_reset=-5.,
                 V_th=20., R=1., tau=10., **kwargs):
        super(LIF, self).__init__(size=size, **kwargs)
        # parameters
        self.V_rest = V_rest
        self.V_reset = V_reset
        self.V_th = V_th
        self.R = R
        self.tau = tau
        self.t_refractory = t_refractory
        # variables
        self.V = bm.Variable(bm.random.randn(self.num) * 5. + V_reset)
        self.input = bm.Variable(bm.zeros(self.num))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
        self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # functions
        self.integral = bp.odeint(f=self.derivative, method='exponential_euler')

    def derivative(self, V, t, Iext):
        dvdt = (- (V - self.V_rest) + self.R * Iext) / self.tau
        return dvdt

    def update(self, _t, _dt):
        for i in range(self.num):
            if _t - self.t_last_spike[i] <= self.t_refractory:
                self.refractory[i] = True
                self.spike[i] = False
            else:
                V = self.integral(self.V[i], _t, self.input[i])
                if V >= self.V_th:
                    self.V[i] = self.V_reset
                    self.t_last_spike[i] = _t
                    self.spike[i] = True
                    self.refractory[i] = True
                else:
                    self.V[i] = V
                    self.spike[i] = False
                    self.refractory[i] = False
            self.input[i] = 0.
That's all; we have coded a LIF neuron model. Note that here we define equation (1) with brainpy.odeint as an ODEIntegrator. We will illustrate how to define ODE numerical integration in the Numerical Solvers for ODEs tutorial.
Each NeuGroup has a powerful function: .run(). It receives the following arguments:
duration: the simulation duration. It can be a tuple (start time, end time), or an int specifying the duration length (the default start time is then 0).
inputs: the inputs for each model component, in the format (target, value, [type, operation]). For details, please see the tutorial of Monitors and Inputs.
report: a float specifying the progress fraction at which to report. 0 (the default) means the running progress is not reported.
Now, let’s run the defined model.
group = LIF(100, monitors=['V'])
group.run(duration=200., inputs=('input', 26.), report=0.5)
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
Compilation used 0.0004 s.
Start running ...
Run 50.0% used 3.319 s.
Run 100.0% used 6.479 s.
Simulation is done in 6.480 s.

group.run(duration=(200, 400.), report=0.2)
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
Compilation used 0.0003 s.
Start running ...
Run 20.0% used 1.364 s.
Run 40.0% used 2.711 s.
Run 60.0% used 4.024 s.
Run 80.0% used 5.299 s.
Run 100.0% used 6.589 s.
Simulation is done in 6.590 s.

In the model definition, BrainPy gives you full control over the data and logic flow. You can define models with any data you need and any logic you want. There are only a few constraints on your customization (a minimal skeleton satisfying them is sketched below):
you should initialize the brainpy.NeuGroup superclass with the keyword size of the group;
you should define the update function.
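As a reference, here is a minimal sketch that satisfies these two constraints (the class and variable names are arbitrary):

class MinimalGroup(bp.NeuGroup):
    def __init__(self, size, **kwargs):
        # constraint 1: initialize the superclass with the group size
        super(MinimalGroup, self).__init__(size=size, **kwargs)
        self.x = bm.Variable(bm.zeros(self.num))

    # constraint 2: define the update function
    def update(self, _t, _dt):
        self.x[:] = self.x + _dt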
brainpy.TwoEndConn
For synaptic computations, BrainPy provides brainpy.TwoEndConn to help you construct the connections between pre-synaptic and post-synaptic neuron groups, and provides brainpy.connect.TwoEndConnector for synaptic projections between pre- and post-synaptic groups.
brainpy.TwoEndConn can help construct automatic delays in synaptic computations. The modeling of synapses usually includes a delay time (typically 0.3–0.5 ms) required for a neurotransmitter to be released from the presynaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane. BrainPy provides register_constant_delay() for automatic state delay.
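A minimal sketch of the push/pull pattern provided by register_constant_delay() (the class and attribute names here are made up; the same pattern appears in the Exponential synapse example below):

class DelayedRelay(bp.DynamicalSystem):
    def __init__(self, size, delay=0.5, **kwargs):
        super(DelayedRelay, self).__init__(**kwargs)
        self.inp = bm.Variable(bm.zeros(size))
        self.out = bm.Variable(bm.zeros(size))
        # a buffer holding `delay` ms of history of `inp`
        self.delayed_inp = self.register_constant_delay('inp', size=size, delay=delay)

    def update(self, _t, _dt):
        self.delayed_inp.push(self.inp)        # store the newest value
        self.out[:] = self.delayed_inp.pull()  # read the value from `delay` ms ago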
brainpy.connect.TwoEndConnector provides a convenient interface for constructing connectivity structures. Various synaptic structures, like pre_ids, post_ids, conn_mat, pre2post, post2pre, pre2syn, post2syn, pre_slice, and post_slice, can be constructed. Users just need to request such data structures by calling connector.require('pre_ids', 'post_ids', ...). We will detail this function in Efficient Synaptic Connections.
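A hedged sketch of requesting such structures directly from a connector; the binding call conn(pre_size, post_size) is an assumption here (inside a TwoEndConn, as in the Exponential example below, the connector is already bound and only the requires() call is needed):

conn = bp.connect.FixedProb(prob=0.2)
conn(5, 8)  # assumed step: bind the connector to the pre/post group sizes
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')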
Here, let's illustrate how to use brainpy.TwoEndConn with the Exponential synapse model.
Exponential synapse model
The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state rises instantaneously and then decays with a certain time constant $\tau_{decay}$. Its dynamics are given by:
$$ \frac{d s}{d t} = -\frac{s}{\tau_{decay}}+\sum_{k} \delta(t-D-t^{k}) $$
where $s$ is the synaptic state, $t^{k}$ is the spike time of the pre-synaptic neuron, and $D$ is the synaptic delay.
Afterward, the current output onto the post-synaptic neuron is given in the conductance-based form
$$ I_{syn}(t) = g_{max} s \left( V(t)-E \right) $$
where $E$ is the reversal potential of the synapse, $V$ is the post-synaptic membrane potential, $g_{max}$ is the maximum synaptic conductance.
So, let’s try to implement this synapse model.
class Exponential(bp.TwoEndConn):
    def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., **kwargs):
        super(Exponential, self).__init__(pre=pre, post=post, conn=conn, **kwargs)
        # parameters
        self.g_max = g_max
        self.E = E
        self.tau = tau
        self.delay = delay
        # connections
        self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
        self.num = len(self.pre_ids)
        # variables
        self.s = bm.Variable(bm.zeros(self.num))
        self.pre_spike = self.register_constant_delay('ps', size=self.pre.num, delay=delay)
        # functions
        self.integral = bp.odeint(self.derivative, method='exponential_euler')

    def derivative(self, s, t):
        dsdt = - s / self.tau
        return dsdt

    def update(self, _t, _dt):
        # P1: push the pre-synaptic spikes into the delay buffer
        self.pre_spike.push(self.pre.spike)
        # P2: pull the delayed pre-synaptic spikes
        delayed_pre_spike = self.pre_spike.pull()
        # P3: update the synaptic state
        self.s[:] = self.integral(self.s, _t)
        for syn_i in range(self.num):
            pre_i, post_i = self.pre_ids[syn_i], self.post_ids[syn_i]
            # P4: whether the pre-synaptic neuron generated a spike
            if delayed_pre_spike[pre_i]:
                self.s[syn_i] += 1.
            # P5: output the synaptic current onto the post-synaptic neuron
            self.post.input[post_i] += self.g_max * self.s[syn_i] * (self.E - self.post.V[post_i])
Here, we create a synaptic model using the synaptic structures pre_ids and post_ids, which look like this:

The pre-synaptic neuron indices (pre ids) are shown in green. The post-synaptic neuron indices (post ids) are shown in red. Each (pre id, post id) pair denotes a synapse between the two neuron groups. Each synaptic connection also has a unique index, called the synapse index, which is shown in the third row (syn ids).
brainpy.Network
Above, we have illustrated how to define neurons with brainpy.NeuGroup and synapses with brainpy.TwoEndConn. Next, we talk about how to create a network using brainpy.Network.
E/I balanced network
Here, we try to create an E/I balanced network according to reference [1].
This EI network has 4000 leaky integrate-and-fire neurons. Each integrate-and-fire neuron is characterized by a time constant, $\tau$ = 20 ms, and a resting membrane potential, $V_{rest}$ = -60 mV. Whenever the membrane potential crosses a spiking threshold of -50 mV, an action potential is generated and the membrane potential is reset to the resting potential, where it remains clamped for a 5 ms refractory period.
num_exc = 3200
num_inh = 800
E = LIF(num_exc, tau=20, V_th=-50, V_rest=-60, V_reset=-60, t_refractory=5., monitors=['spike'])
I = LIF(num_inh, tau=20, V_th=-50, V_rest=-60, V_reset=-60, t_refractory=5.)
E.V[:] = bm.random.randn(num_exc) * 5. - 55.
I.V[:] = bm.random.randn(num_inh) * 5. - 55.
The ratio of excitatory to inhibitory neurons is 4:1. The neurons connect to each other randomly with a connection probability of 2%.
The kinetics of the synapses are governed by the exponential synapse model shown above. Specifically, the synaptic time constants are $\tau_e$ = 5 ms for excitatory synapses and $\tau_i$ = 10 ms for inhibitory synapses. The maximum synaptic conductance is $0.6$ for excitatory synapses and $6.7$ for inhibitory synapses. Reversal potentials are $E_e$ = 0 mV and $E_i$ = -80 mV.
E2E = Exponential(E, E, bp.connect.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5)
E2I = Exponential(E, I, bp.connect.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5)
I2E = Exponential(I, E, bp.connect.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10)
I2I = Exponential(I, I, bp.connect.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10)
After this, we can create a network to wrap these objects together.
net = bp.Network(E2E, E2I, I2I, I2E, E=E, I=I)
net = bm.jit(net)
net.run(100., inputs=[('E.input', 20.), ('I.input', 20.)], report=0.1)
bp.visualize.raster_plot(E.mon.ts, E.mon.spike, show=True)
Compilation used 4.8391 s.
Start running ...
Run 10.0% used 1.485 s.
Run 20.0% used 2.806 s.
Run 30.0% used 4.120 s.
Run 40.0% used 5.506 s.
Run 50.0% used 6.886 s.
Run 60.0% used 8.214 s.
Run 70.0% used 9.513 s.
Run 80.0% used 10.834 s.
Run 90.0% used 12.197 s.
Run 100.0% used 13.507 s.
Simulation is done in 13.507 s.

References:
[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
Tensors
In this section, we are going to understand:
what is a tensor?
how to create a tensor?
what operations are supported for a tensor?
import brainpy.math as bm
What is a tensor?
A tensor is a homogeneous multidimensional array. It is a table of elements (usually numbers), all of the same type, indexed by a tuple of non-negative integers. The dimensions of an array are called axes.
In the following picture, the 1D array ([7, 2, 9, 10]) has only one axis. That axis has 4 elements in it, so we say it has a shape of (4,).
The 2D array

[[5.2, 3.0, 4.5],
 [9.1, 0.1, 0.3]]

has 2 axes. The first axis has a length of 2, and the second axis has a length of 3. So, we say it has a shape of (2, 3).
Similarly, a 3D array has 3 axes, with dimensions such as (4, 3, 2) along its axes.
Each tensor has several important attributes:
.ndim: the number of axes (dimensions) of the tensor.
.shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, shape will be (n, m). The length of the shape tuple is therefore the number of axes, ndim.
.size: the total number of elements of the tensor. This is equal to the product of the elements of shape.
.dtype: an object describing the type of the elements in the tensor. One can create or specify dtypes using standard Python types.
In the 'numpy' backend, the tensor is exactly the same as the NumPy array. For example:
bm.use_backend('numpy')
a = bm.arange(15).reshape((3, 5))
a
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int64')
However, in the 'jax' backend, we wrap the original jax.numpy.ndarray and create a new data structure, JaxArray. Its attributes and operations are the same as those of NumPy tensors. For example:
bm.use_backend('jax')
a = bm.arange(15).reshape((3, 5))
a
JaxArray(DeviceArray([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]], dtype=int32))
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int32')
How to create a tensor?
There are several ways to create tensors. The methods for tensor creation are the same under the "numpy" and "jax" backends.
array(), zeros() and ones()
The basic method is to convert Python sequences into tensors with bm.array(). For example:
bm.array([2, 3, 4])
JaxArray(DeviceArray([2, 3, 4], dtype=int32))
bm.array([(1.5, 2, 3), (4, 5, 6)])
JaxArray(DeviceArray([[1.5, 2. , 3. ],
[4. , 5. , 6. ]], dtype=float32))
Often, the elements of an array are originally unknown, but its size is known. Therefore, you can use placeholder functions to create tensors, like:
# "bm.zeros()" creates an array full of zeros
bm.zeros((3, 4))
JaxArray(DeviceArray([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32))
# "bm.ones()" creates an array full of ones
bm.ones((3, 4))
JaxArray(DeviceArray([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32))
linspace() and arange()
Two other commonly used 1D array creation functions are bm.linspace() and bm.arange().
bm.arange() creates arrays with regularly incrementing values. It receives "start", "end", and "step" settings.
# if only one argument "A" is provided, the function will
# recognize "start = 0", "end = A", and "step = 1".
bm.arange(10)
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
# if two arguments "A, B" are provided, the function will
# recognize "start = A", "end = B", and "step = 1".
bm.arange(2, 10, dtype=float)
JaxArray(DeviceArray([2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32))
# if three arguments "A, B, C" are provided, the function will
# recognize "start = A", "end = B", and "step = C".
bm.arange(2, 3, 0.1)
JaxArray(DeviceArray([2. , 2.1 , 2.1999998, 2.2999997, 2.3999996,
2.4999995, 2.5999994, 2.6999993, 2.7999992, 2.8999991], dtype=float32))
Due to finite floating point precision, it is generally not possible to predict the number of elements obtained by bm.arange(). For this reason, it is usually better to use the function bm.linspace(), which receives "start", "end", and "num" settings.
bm.linspace(2, 3, 10)
JaxArray(DeviceArray([2. , 2.1111112, 2.2222223, 2.3333333, 2.4444447,
2.5555556, 2.6666665, 2.777778 , 2.8888888, 3. ], dtype=float32))
Random sampling
The brainpy.math module provides convenient random sampling functions. It contains some simple random data generation methods, some permutation and distribution functions, and random generator functions. Here we just give several examples.
brainpy.math.random.rand(d0, d1, ..., dn)
This function of the random module generates random values in a given shape.
bm.random.rand(5, 2)
JaxArray(DeviceArray([[0.99398685, 0.39656162],
[0.5161425 , 0.81978667],
[0.31676686, 0.083583 ],
[0.16560888, 0.40949285],
[0.43086028, 0.22965682]], dtype=float32))
brainpy.math.random.randn(d0, d1, ..., dn)
This function of the random module returns samples from the "standard normal" distribution.
bm.random.randn(5, 2)
JaxArray(DeviceArray([[-0.7701253 , 0.00965391],
[-0.11301948, 0.1409633 ],
[-0.11914475, 0.068143 ],
[ 1.6409276 , 1.3378068 ],
[ 1.8202178 , -0.37819335]], dtype=float32))
brainpy.math.random.randint(low, high[, size, dtype])
This function of the random module generates random integers from low (inclusive) to high (exclusive).
bm.random.randint(0, 3, size=10)
JaxArray(DeviceArray([0, 1, 1, 2, 0, 1, 0, 2, 0, 2], dtype=int32))
brainpy.math.random.random([size])
This function of the random module generates random floats in the half-open interval [0.0, 1.0).
bm.random.random((3, 2))
JaxArray(DeviceArray([[0.76483357, 0.559957 ],
[0.50227726, 0.41693842],
[0.65068877, 0.8199152 ]], dtype=float32))
The brainpy.math module also provides permutation functions.
brainpy.math.random.shuffle()
This function is used for modifying a sequence in-place by shuffling its contents.
bm.random.shuffle( bm.arange(10) )
JaxArray(DeviceArray([8, 4, 9, 5, 7, 0, 3, 6, 1, 2], dtype=int32))
brainpy.math.random.permutation()
This function permutes a sequence randomly or returns a permuted range.
bm.random.permutation( bm.arange(10) )
JaxArray(DeviceArray([2, 1, 0, 9, 5, 4, 7, 6, 8, 3], dtype=int32))
The brainpy.math module also provides functions to sample from distributions.
beta(a, b[, size])
This function is used to draw samples from a Beta distribution.
bm.random.beta(2, 3, 10)
JaxArray(DeviceArray([0.48173192, 0.09183226, 0.5617174 , 0.4964077 , 0.5717186 ,
0.60861576, 0.3472139 , 0.58446443, 0.41256 , 0.07920451], dtype=float32))
exponential([scale, size])
This function is used to draw samples from an exponential distribution.
bm.random.exponential(1, 10)
JaxArray(DeviceArray([1.5618182 , 0.18306465, 1.0619484 , 1.2519189 , 0.6019476 ,
1.0401233 , 0.37211612, 0.06336975, 3.796705 , 0.03766083], dtype=float32))
For more sampling methods, please see the random sampling functions.
And more
Moreover, there are many other methods we can use to create tensors, including:
Conversion from other Python structures (i.e., lists and tuples)
Intrinsic NumPy array creation functions (e.g., arange, ones, zeros, etc.)
Use of special library functions (e.g., random)
Replicating, joining, or mutating existing tensors
Reading arrays from disk, either in standard or custom formats
Creating arrays from raw bytes through the use of strings or buffers
For details of these methods, please see the NumPy tutorial: Array creation. Most of these methods are supported in BrainPy.
Supported operations on tensors
All the operations in BrainPy are based on tensors. Therefore, it is necessary to know which operations are supported by each tensor object.
Basic operations
Arithmetic operators on tensors apply element-wise. Let’s take “+”, “-”, “*”, and “/” as examples.
We first create two tensors:
data = bm.array([1, 2])
data
JaxArray(DeviceArray([1, 2], dtype=int32))
ones = bm.ones(2)
ones
JaxArray(DeviceArray([1., 1.], dtype=float32))
data + ones
JaxArray(DeviceArray([2., 3.], dtype=float32))
data - ones
JaxArray(DeviceArray([0., 1.], dtype=float32))
data * data
JaxArray(DeviceArray([1, 4], dtype=int32))
data / data
JaxArray(DeviceArray([1., 1.], dtype=float32))
Aggregation functions can also be performed on tensors, like:
.min(): get the minimum element;
.max(): get the maximum element;
.sum(): get the summation;
.mean(): get the average;
.prod(): get the result of multiplying the elements together;
.std(): get the standard deviation.
data = bm.array([1, 2, 3])
data.max()
DeviceArray(3, dtype=int32)
data.min()
DeviceArray(1, dtype=int32)
data.sum()
DeviceArray(6, dtype=int32)
It's very common to want to aggregate along a row or column. You can specify on which axis you want the aggregation function to be computed. For example, you can find the maximum value within each column by specifying axis=0.
a = bm.array([[1, 2],
[5, 3],
[4, 6]])
a.max(axis=0)
JaxArray(DeviceArray([5, 6], dtype=int32))
a.max(axis=1)
JaxArray(DeviceArray([2, 5, 6], dtype=int32))
Broadcasting
Tensor operations are usually done on pairs of arrays on an element-by-element basis. In the simplest case, the two tensors must have exactly the same shape, as in the following example:
a = bm.array([1.0, 2.0, 3.0])
b = bm.array([2.0, 2.0, 2.0])
a * b
JaxArray(DeviceArray([2., 4., 6.], dtype=float32))
However, the broadcasting rule relaxes this constraint when the tensors' shapes meet certain conditions. The simplest broadcasting example occurs when a tensor and a scalar value are combined in an operation:
a = bm.array([1, 2])
b = 1.6
a * b
JaxArray(DeviceArray([1.6, 3.2], dtype=float32, weak_type=True))
Similarly, broadcasting can happen with matrices. Below is an example.
data = bm.array([[1, 2],
[3, 4],
[5, 6]])
ones_row = bm.array([[1, 1]])
data + ones_row
JaxArray(DeviceArray([[2, 3],
[4, 5],
[6, 7]], dtype=int32))
Under certain constraints, the smaller tensor can be “broadcast” across the larger tensor so that they have compatible shapes. Broadcasting provides a means of vectorizing tensor operations so that looping occurs in C instead of Python. It does this without making needless copies of data and usually leads to efficient algorithm implementations.
Generally, the dimensions of two tensors are compatible when
they are equal, or
one of them is 1, or
one of them has fewer dimensions.
If these conditions are not met, an error will happen.
For example, according to the broadcast rules, the following two shapes are compatible:
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3
image = bm.random.random((256, 256, 3))
scale = bm.random.random(3)
_ = image + scale
_ = image - scale
_ = image * scale
_ = image / scale
These shapes are also compatible:
A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
Result (4d array): 8 x 7 x 6 x 5
A = bm.random.random((8, 1, 6, 1))
B = bm.random.random((7, 1, 5))
_ = A + B
_ = A - B
_ = A * B
_ = A / B
However, these examples of shapes do not broadcast:
A (1d array): 3
B (1d array): 4 # trailing dimensions do not match
A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # second from last dimensions mismatched
A = bm.random.random((3,))
B = bm.random.random((4,))
try:
    _ = A + B
except Exception as e:
    print(e)
add got incompatible shapes for broadcasting: (3,), (4,).
A = bm.random.random((2, 1))
B = bm.random.random((8, 4, 3))
try:
    _ = A + B
except Exception as e:
    print(e)
Incompatible shapes for broadcasting: ((1, 2, 1), (8, 4, 3))
For more details about broadcasting, please see the NumPy documentation: broadcasting.
Indexing, Slicing and Iterating
Tensors can be indexed, sliced, and iterated over, much like lists and other Python sequences. For example:
a = bm.arange(10) ** 3
a
JaxArray(DeviceArray([ 0, 1, 8, 27, 64, 125, 216, 343, 512, 729], dtype=int32))
a[2]
DeviceArray(8, dtype=int32)
a[2:5]
DeviceArray([ 8, 27, 64], dtype=int32)
# from start to position 6, exclusive, set every 2nd element to 1000,
# equivalent to a[0:6:2] = 1000
a[:6:2] = 1000
a
JaxArray(DeviceArray([1000, 1, 1000, 27, 1000, 125, 216, 343, 512, 729], dtype=int32))
a[::-1] # reversed a
DeviceArray([ 729, 512, 343, 216, 125, 1000, 27, 1000, 1, 1000], dtype=int32)
for i in a: # iterate a
    print(i**(1 / 3.))
10.000001
1.0
10.000001
3.0
10.000001
5.0000005
6.0000005
7.0000005
8.000001
9.000001
For multi-dimensional tensors, these indices should be given in a tuple separated by commas. For example,
b = bm.arange(20).reshape((5, 4))
b[2, 3]
DeviceArray(11, dtype=int32)
b[0:5, 1] # each row in the second column of b
DeviceArray([ 1, 5, 9, 13, 17], dtype=int32)
b[:, 1] # equivalent to the previous example
DeviceArray([ 1, 5, 9, 13, 17], dtype=int32)
b[1:3, :] # each column in the second and third row of b
DeviceArray([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
When fewer indices are provided than the number of axes, the missing indices are considered complete slices:
b[-1] # the last row. Equivalent to b[-1, :]
DeviceArray([16, 17, 18, 19], dtype=int32)
You can also write this using dots, as b[i, ...]. The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is an array with 5 axes, then
x[1, 2, ...] is equivalent to x[1, 2, :, :, :],
x[..., 3] to x[:, :, :, :, 3], and
x[4, ..., 5, :] to x[4, :, :, 5, :].
c = bm.arange(48).reshape((6, 4, 2))
c[1, ...] # same as c[1, :, :] or c[1]
DeviceArray([[ 8, 9],
[10, 11],
[12, 13],
[14, 15]], dtype=int32)
c[..., 2] # same as c[:, :, 2]
DeviceArray([[ 1, 3, 5, 7],
[ 9, 11, 13, 15],
[17, 19, 21, 23],
[25, 27, 29, 31],
[33, 35, 37, 39],
[41, 43, 45, 47]], dtype=int32)
Iterating over multidimensional tensors is done with respect to the first axis:
for row in b:
    print(row)
[0 1 2 3]
[4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]
[16 17 18 19]
For more methods and advanced indexing tricks, please see the NumPy tutorial: Indexing.
Mathematical functions
Tensors support many other functions, including
and others …
Most of these functions can be found in the brainpy.math module. Let's take a look at the trigonometric, hyperbolic, and rounding functions.
d = bm.linspace(0, 1, 10)
d
JaxArray(DeviceArray([0. , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1. ], dtype=float32))
# trigonometric functions
bm.sin(d)
JaxArray(DeviceArray([0. , 0.11088263, 0.22039774, 0.32719472, 0.42995638,
0.5274154 , 0.6183698 , 0.7016979 , 0.7763719 , 0.84147096], dtype=float32))
bm.arcsin(d)
JaxArray(DeviceArray([0. , 0.11134101, 0.2240931 , 0.33983693, 0.46055397,
0.589031 , 0.7297277 , 0.8911225 , 1.0949141 , 1.5707964 ], dtype=float32))
# hyperbolic functions
bm.sinh(d)
JaxArray(DeviceArray([0. , 0.11133985, 0.22405571, 0.33954054, 0.45922154,
0.58457786, 0.7171585 , 0.8586021 , 1.0106566 , 1.1752012 ], dtype=float32))
# rounding functions
bm.round(d)
JaxArray(DeviceArray([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], dtype=float32))
# sum function
bm.sum(d)
DeviceArray(5., dtype=float32)
Variables
In BrainPy, the JIT compilation of class objects relies on Variable. In this section, we are going to understand:
what is a Variable?
what are the subtypes of Variable?
import brainpy as bp
import brainpy.math as bm
Variable
brainpy.math.Variable is a pointer that refers to a tensor. It stores the value of the tensor, and the concrete value in a Variable can be changed. If a tensor is labeled as a Variable, it means that it is a dynamical variable and its data can be changed.
During JIT compilation, tensors which are not marked as Variable will be compiled as static data; changing such a tensor will not take effect, or will cause an error (see the sketch below).
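A small sketch of this point, mirroring the Demo examples in the JIT compilation tutorial (the class and attribute names are made up):

class Accumulator(bp.Base):
    def __init__(self):
        super(Accumulator, self).__init__()
        self.offset = bm.zeros(1)              # not a Variable: compiled as a static constant
        self.total = bm.Variable(bm.zeros(1))  # a Variable: its data can be changed

    def update(self, x):
        self.total[:] = self.total + x         # the in-place update is kept after JIT
        # re-assigning `self.offset` here would not take effect once jitted

acc = Accumulator()
bm.jit(acc.update)(1.)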
Create a Variable
Passing a tensor into brainpy.math.Variable creates a Variable, for example:
bm.use_backend('numpy')
a1 = bm.random.random(5)
a2 = bm.Variable(a1)
a1, a2
(array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]),
Variable([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]))
bm.use_backend('jax')
b1 = bm.random.random(5)
b2 = bm.Variable(b1)
b1, b2
(JaxArray(DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)),
Variable(DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)))
Access the value in a Variable
The concrete value of a Variable can be obtained through .value.
a2.value
array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844])
(a2.value == a1).all()
True
b2.value
DeviceArray([0.70530474, 0.99841356, 0.815271 , 0.926391 , 0.84018004], dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
Supported operations on a Variable
A Variable supports almost all the operations of a tensor. Actually, brainpy.math.Variable is a subclass of brainpy.math.ndarray.
isinstance(a2, bp.math.numpy.ndarray)
True
isinstance(b2, bp.math.jax.ndarray)
True
isinstance(b2, bp.math.jax.JaxArray)
True
# `bp.math.jax.ndarray` is an alias for `bp.math.jax.JaxArray` in 'jax' backend
bp.math.jax.ndarray is bp.math.jax.JaxArray
True
Note
In the 'jax' backend, after performing any operation on a Variable, the resulting value will be a JaxArray (bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in the 'jax' backend). This means that the Variable can only be used to refer to a value.
b2 + 1.
JaxArray(DeviceArray([1.7053047, 1.9984136, 1.815271 , 1.926391 , 1.84018 ], dtype=float32))
b2 ** 2
JaxArray(DeviceArray([0.4974548 , 0.9968296 , 0.66466683, 0.8582003 , 0.7059025 ], dtype=float32))
bp.math.jax.floor(b2)
JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
Subtypes of Variable
brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar and brainpy.math.Parameter. Subtypes can also be customized and extended by the user. We are going to talk about them here.
TrainVar
brainpy.math.TrainVar is a trainable variable (a subclass of brainpy.math.Variable). Usually, trainable variables are meant to require gradients and compute the corresponding update values. However, users can also use TrainVar for other purposes.
bm.use_backend('numpy')
a = bm.random.rand(4)
a, bm.TrainVar(a)
(array([0.81515042, 0.40363449, 0.89924935, 0.29827197]),
TrainVar([0.81515042, 0.40363449, 0.89924935, 0.29827197]))
bm.use_backend('jax')
b = bm.random.rand(4)
b, bm.TrainVar(b)
(JaxArray(DeviceArray([0.4008 , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)),
TrainVar(DeviceArray([0.4008 , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)))
Parameter
brainpy.math.Parameter is used to label a dynamically changed parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved with the Collector.subsets method (please see the Base class tutorial); a sketch of this retrieval follows the examples below.
bm.use_backend('numpy')
a = bm.random.rand(1)
a, bm.Parameter(a)
(array([0.5776296]), Parameter([0.5776296]))
bm.use_backend('jax')
b = bm.random.rand(1)
b, bm.Parameter(b)
(JaxArray(DeviceArray([0.61128676], dtype=float32)),
Parameter(DeviceArray([0.61128676], dtype=float32)))
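A hedged sketch of that retrieval (the class below is made up, and the .subsets() call signature is an assumption based on the Collector.subsets method named above):

class Cell(bp.Base):
    def __init__(self):
        super(Cell, self).__init__()
        self.gain = bm.Parameter(bm.ones(1))   # a dynamically changed parameter
        self.state = bm.Variable(bm.zeros(1))  # an ordinary variable

cell = Cell()
params = cell.vars().subsets(bm.Parameter)  # assumed: keep only the Parameter entries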
RandomState
In the 'jax' backend, brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. This is because a RandomState in the 'jax' backend must store its dynamically changing key information. Every time a RandomState performs a random sampling, the "key" changes. For example:
bm.use_backend('jax')
state = bm.random.RandomState(seed=1234)
state
RandomState(DeviceArray([ 0, 1234], dtype=uint32))
# perform a "random" sampling
state.random(1)
# the value changed
state
RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling
state.sample(1)
# the value changed too
state
RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))
Every instance of RandomState can create a new seed from the current seed with .split_key()
.
state.split_key()
DeviceArray([3028232624, 826525938], dtype=uint32)
It can also create multiple seeds from the current seed with .split_keys(n)
. This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.
state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
[1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
[2548793527, 3057026599],
[ 874320145, 4142002431],
[3368470122, 3462971882],
[1756854521, 1662729797]], dtype=uint32)
There is a default RandomState in brainpy.math.jax.random
module: DEFAULT
.
bm.random.DEFAULT
RandomState(DeviceArray([2580684476, 2503630841], dtype=uint32))
The built-in random methods like randint()
, rand()
, shuffle()
, etc. use this DEFAULT state. If you want to change the default RandomState, please use the seed() method.
bm.random.seed(654321)
bm.random.DEFAULT
RandomState(DeviceArray([ 0, 654321], dtype=uint32))
Base Class
In this section, we are going to talk about:
the Base class for the BrainPy ecosystem,
the Collector to facilitate variable collection and manipulation.
import brainpy as bp
import brainpy.math as bm
Base
The foundation of BrainPy is brainpy.Base. A Base instance is an object which has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class object that you want to JIT compile or automatically differentiate must inherit from brainpy.Base
.
A Base object can have many variables, children Base objects, integrators, and methods. For example, let’s implement a FitzHugh-Nagumo neuron model.
class FHN(bp.Base):
def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
super(FHN, self).__init__(name=name)
# parameters
self.num = num
self.a = a
self.b = b
self.tau = tau
self.Vth = Vth
# variables
self.V = bm.Variable(bm.zeros(num))
self.w = bm.Variable(bm.zeros(num))
self.spike = bm.Variable(bm.zeros(num, dtype=bool))
# integral
self.integral = bp.odeint(method='rk4', f=self.derivative)
def derivative(self, V, w, t, Iext):
dw = (V + self.a - self.b * w) / self.tau
dV = V - V * V * V / 3 - w + Iext
return dV, dw
def update(self, _t, _dt, x):
V, w = self.integral(self.V, self.w, _t, x)
self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
self.w[:] = w
self.V[:] = V
Note this model has three variables: self.V
, self.w
, and self.spike
. It also has an integrator self.integral
.
Naming system
Every Base object has a unique name. You can specify a unique name when you instantiate a Base class. Reusing an existing name will cause an error.
FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
FHN(10, name='Y').name
except Exception as e:
print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x7f4a7406bd60> has a used name "Y".
When you instantiate a Base class without specifying a "name", BrainPy will assign a name to this object automatically. The rule for generating the object name is class_name + number_of_instances
. For example, FHN0
, FHN1
, etc.
FHN(10).name
'FHN0'
FHN(10).name
'FHN1'
Therefore, in BrainPy, you can access any object by its unique name, no matter how small or deeply nested the object is.
Collection functions
Three important collection functions are implemented for each Base object. Specifically, they are:
nodes()
: to collect all instances of Base objects, including the children nodes of a node.
ints(): to collect all integrators defined in the Base node and in its children nodes.
vars(): to collect all variables defined in the Base node and in its children nodes.
All integrators can be collected through one method Base.ints()
. The result container is a Collector.
fhn = FHN(10)
ints = fhn.ints()
ints
{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>}
type(ints)
brainpy.base.collector.Collector
Similarly, all variables in a Base object can be collected through Base.vars()
. The returned container is a TensorCollector (a subclass of Collector
).
vars = fhn.vars()
vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
type(vars)
brainpy.base.collector.TensorCollector
All nodes in the model can also be collected through one method Base.nodes()
. The result container is an instance of Collector.
nodes = fhn.nodes()
nodes # note: integrator is also a node
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>,
'FHN2': <__main__.FHN at 0x7f4a7406b5e0>}
type(nodes)
brainpy.base.collector.Collector
Now, let’s make a more complicated model by using the previously defined model FHN
.
class FeedForwardCircuit(bp.Base):
def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
super(FeedForwardCircuit, self).__init__(name=name)
self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
conn = bm.ones((num1, num2), dtype=bool)
self.conn = bm.fill_diagonal(conn, False) * w
def update(self, _t, _dt, x):
self.pre.update(_t, _dt, x)
x2 = self.pre.spike @ self.conn
self.post.update(_t, _dt, x2)
This model FeedForwardCircuit
defines two layers. Each layer is modeled as a FitzHugh-Nagumo model (FHN
). The first layer is densely connected to the second layer. The input to the second layer is the first layer's spikes multiplied by the connection strength w
.
net = FeedForwardCircuit(8, 5)
We can retrieve all integrators in the network with .ints()
:
net.ints()
{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
Or, retrieve all variables by .vars()
:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN4.V': Variable([0., 0., 0., 0., 0.]),
'FHN4.spike': Variable([False, False, False, False, False]),
'FHN4.w': Variable([0., 0., 0., 0., 0.])}
Or, retrieve all nodes (instances of Base class) with .nodes()
:
net.nodes()
{'FHN3': <__main__.FHN at 0x7f4a74077670>,
'FHN4': <__main__.FHN at 0x7f4a740771c0>,
'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>,
'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x7f4a74077790>}
Absolute path
It's worth noting that there are two ways to access variables, integrators, and nodes: the "absolute" path and the "relative" path. The default way is the absolute path.
“Absolute” path means that all keys in the resulting Collector (Base.nodes()
) have the format of key = node_name [+ field_name]
.
.nodes() example 1: In the above fhn
instance, there are two nodes: "fhn" and its integrator "fhn.integral".
fhn.integral.name, fhn.name
('RK44', 'FHN2')
Calling .nodes()
returns the names and the instances of the models.
fhn.nodes().keys()
dict_keys(['RK44', 'FHN2'])
.nodes() example 2: In the above net
instance, there are five nodes:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')
Calling .nodes()
also returns the names and instances of all models.
net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])
.vars() example 1: In the above fhn
instance, there are three variables: "V", "w" and "spike". Calling .vars()
returns a dict of <node_name + var_name, var_value>
.
fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
.vars() example 2: This also applies in the net
instance:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN4.V': Variable([0., 0., 0., 0., 0.]),
'FHN4.spike': Variable([False, False, False, False, False]),
'FHN4.w': Variable([0., 0., 0., 0., 0.])}
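Because every key follows this node_name.field_name format, a specific variable can also be looked up directly from the returned Collector, which is just a dictionary (an illustrative access using the names printed above):
net.vars()['FHN3.V']   # the membrane potential variable of the first layer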
Relative path
Variables, integrators, and nodes can also be accessed by relative path. For example, the pre
instance in the net
can be accessed by
net.pre
<__main__.FHN at 0x7f4a74077670>
Relative path preserves the dependence relationship. For example, all nodes retrieved from the perspective of net
are:
net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x7f4a74077790>,
'pre': <__main__.FHN at 0x7f4a74077670>,
'post': <__main__.FHN at 0x7f4a740771c0>,
'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
However, nodes retrieved from the start point of net.pre
will be:
net.pre.nodes('relative')
{'': <__main__.FHN at 0x7f4a74077670>,
'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>}
Variables can also be accessed by relative path. For example, all variables that can be relatively accessed from net
are:
net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'pre.spike': Variable([False, False, False, False, False, False, False, False]),
'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'post.V': Variable([0., 0., 0., 0., 0.]),
'post.spike': Variable([False, False, False, False, False]),
'post.w': Variable([0., 0., 0., 0., 0.])}
In contrast, variables relatively accessed from the view of net.post
are:
net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.]),
'spike': Variable([False, False, False, False, False]),
'w': Variable([0., 0., 0., 0., 0.])}
Elements in containers
To avoid surprising, unintended behaviors, collection functions do not look for elements inside a list, dict, or any other container structure.
class ATest(bp.Base):
def __init__(self):
super(ATest, self).__init__()
self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()
The above class defines a list of variables and a dict of children nodes. However, they cannot be retrieved by the collection functions vars()
and nodes()
.
t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x7f4a7401b2b0>}
Fortunately, in BrainPy, we provide implicit_vars
and implicit_nodes
(an instance of “dict”) to hold variables and nodes in container structures. Any variable registered in implicit_vars
, or any integrator or node registered in implicit_nodes
can be retrieved by the collection functions. Let's give it a try.
class AnotherTest(bp.Base):
def __init__(self):
super(AnotherTest, self).__init__()
self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
self.implicit_vars = {f'v{i}': v for i, v in enumerate(self.all_vars)} # must be a dict
self.implicit_nodes = {k: v for k, v in self.sub_nodes.items()} # must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator.
# Therefore, there are five Base objects.
t2.nodes()
{'T1': <__main__.FHN at 0x7f4a740777c0>,
'T2': <__main__.FHN at 0x7f4a74077a90>,
'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74077340>,
'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74011040>,
'AnotherTest0': <__main__.AnotherTest at 0x7f4a740157f0>}
# This model has five Base objects (seen above),
# each FHN node has three variables,
# moreover, this model has two implicit variables.
t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'T2.V': Variable([0., 0., 0., 0., 0.]),
'T2.spike': Variable([False, False, False, False, False]),
'T2.w': Variable([0., 0., 0., 0., 0.]),
'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.]),
'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.])}
Saving and loading
Because Base.vars()
returns a Python dictionary object Collector, the variables can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details please see Saving and Loading). Specifically, they are implemented by Base.save_states()
and Base.load_states()
.
Save
Base.save_states(PATH, [vars])
Model exporting in BrainPy supports various standard Python file formats, including:
HDF5: .h5, .hdf5
.npz (NumPy file format)
.pkl (Python's pickle utility)
.mat (Matlab file format)
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
# Unknown file format will cause error
try:
net.save_states('./data/net.xxx')
except Exception as e:
print(type(e).__name__, ":", e)
BrainPyError : Unknown file format: ./data/net.xxx. We only supports ['.h5', '.hdf5', '.npz', '.pkl', '.mat']
Load
Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
Collector
Collection functions return a brainpy.Collector
. This class is a dictionary that maps names to elements. It has some useful methods.
subset()
Collector.subset(cls)
returns the subset of elements whose type is the given cls
. For example, Base.nodes()
returns all instances of Base class. If you are only interested in one type, like ODEIntegrator
, you can use:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
Actually, Collector.subset(cls)
traverses all the elements in this collection, and finds the elements whose type matches the given cls
.
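For instance, restricting the node collection of the network above to the FHN class keeps only the two neuron groups (a small illustrative call; the integrators and the container object are filtered out):
net.nodes().subset(FHN)   # expected: {'FHN3': <__main__.FHN ...>, 'FHN4': <__main__.FHN ...>}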
unique()
It is common in machine learning that weights are shared by several objects, or that the same weight can be reached through different dependency relationships. Collection functions of Base therefore often return a collection in which the same value has multiple keys. Duplicate elements are not automatically excluded. However, it is important not to apply an operation (e.g., applying gradients and updating weights) to the same element more than once.
Therefore, Collector provides method Collector.unique()
to handle this automatically. Collector.unique()
returns a copy of collection in which all elements are unique.
class ModelA(bp.Base):
def __init__(self):
super(ModelA, self).__init__()
self.a = bm.Variable(bm.zeros(5))
class SharedA(bp.Base):
def __init__(self, source):
super(SharedA, self).__init__()
self.source = source
self.a = source.a # shared variable
class Group(bp.Base):
def __init__(self):
super(Group, self).__init__()
self.A = ModelA()
self.A_shared = SharedA(self.A)
g = Group()
g.vars('relative') # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.]),
'A_shared.a': Variable([0., 0., 0., 0., 0.]),
'A_shared.source.a': Variable([0., 0., 0., 0., 0.])}
g.vars('relative').unique() # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.])}
g.nodes('relative') # "ModelA" is accessed twice
{'': <__main__.Group at 0x7f4a74049190>,
'A': <__main__.ModelA at 0x7f4a740490a0>,
'A_shared': <__main__.SharedA at 0x7f4a74049550>,
'A_shared.source': <__main__.ModelA at 0x7f4a740490a0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x7f4a74049190>,
'A': <__main__.ModelA at 0x7f4a740490a0>,
'A_shared': <__main__.SharedA at 0x7f4a74049550>}
update()
Collector is a dict, but it has mechanisms to catch potential conflicts during assignment. The bracket assignment of a Collector ([key]
) and Collector.update()
will check whether the same key is mapped to a different value. If so, an error will be raised.
tc = bp.Collector({'a': bm.zeros(10)})
tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
try:
tc['a'] = bm.zeros(1) # same key "a", different tensor
except Exception as e:
print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
tc.update({'a': bm.ones(1)}) # same key "a", different tensor
except Exception as e:
print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
replace()
If you want to replace the value of an existing key with a new one, you should use the Collector.replace(old_key, new_value)
function.
tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
tc.replace('a', bm.ones(3))
tc
{'a': array([1., 1., 1.])}
TensorCollector
TensorCollector
is a subclass of Collector, designed specifically for collecting tensors.
Compilation
In this section, we are going to talk about code compilation, which accelerates your model's running performance.
import brainpy as bp
import brainpy.math.jax as bm
bp.math.use_backend('jax')
jit()
We have talked about the mechanism of JIT compilation for class objects in the NumPy backend. In this section, we try to understand how to apply JIT when you are using the JAX backend.
jax.jit()
is excellent, but it only supports pure functions. brainpy.math.jax.jit()
is based on jax.jit()
, but extends its ability to just-in-time compile your class objects.
JIT for pure functions
First, brainpy.math.jax.jit()
can just-in-time compile your functions.
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * bm.where(x > 0, x, alpha * bm.exp(x) - alpha)
x = bm.random.normal(size=(1000000,))
%timeit selu(x)
2.86 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
selu_jit = bm.jit(selu) # jit acceleration
%timeit selu_jit(x)
346 µs ± 21.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
JIT for class objects
Moreover, brainpy.math.jax.jit()
is powerful enough to just-in-time compile your class objects. The constraints for class-object JIT compilation are:
The JIT target must be a subclass of brainpy.Base.
Dynamically changed variables must be labeled as brainpy.math.Variable.
Variable changes must be made in-place.
class LogisticRegression(bp.Base):
def __init__(self, dimension):
super(LogisticRegression, self).__init__()
# parameters
self.dimension = dimension
# variables
self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)
def __call__(self, X, Y):
u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
self.w[:] = self.w - u
num_dim, num_points = 10, 200000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr = LogisticRegression(10)
%timeit lr(points, labels)
2.77 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
lr_jit = bm.jit(lr)
%timeit lr_jit(points, labels)
1.29 ms ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
JIT mechanism
The mechanism of JIT compilation is that BrainPy automatically transforms your class methods into functions.
brainpy.math.jax.jit()
receives a dyn_vars
argument, which denotes the dynamically changed variables. If you do not provide it, BrainPy will automatically detect them by calling Base.vars()
. Once it gets the "dyn_vars", BrainPy will treat them as function arguments, thus allowing them to change dynamically.
import types
isinstance(lr_jit, types.FunctionType) # "lr" is a class object, while "lr_jit" is a function
True
Therefore, the secret of brainpy.math.jax.jit()
is providing "dyn_vars". No matter whether your target is a class object, a method of a class object, or a pure function, if there are dynamically changed variables, you just pack them into brainpy.math.jax.jit()
as “dyn_vars”. Then, all the compilation and acceleration will be handled by BrainPy automatically.
Example 1: JIT a class method
class Linear(bp.Base):
def __init__(self, n_in, n_out):
super(Linear, self).__init__()
self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
self.b = bm.TrainVar(bm.zeros(n_out))
def update(self, x):
return x @ self.w + self.b
x = bm.zeros(10)
l = Linear(10, 3)
This time, we mark “w” and “b” as dynamically changed variables.
update1 = bm.jit(l.update, dyn_vars=[l.w, l.b]) # make 'w' and 'b' dynamically change
update1(x) # x is 0., b is 0., therefore y is 0.
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
l.b[:] = 1. # change b to 1, we expect y will be 1 too
update1(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
This time, we only mark "w" as a dynamically changed variable. We will find that even if we also modify "b", the results will not change.
update2 = bm.jit(l.update, dyn_vars=[l.w]) # make 'w' dynamically change
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
l.b[:] = 2. # change b to 2, while y will not be 2
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
Example 2: JIT a function
Now, we change the above “Linear” object to a function.
n_in = 10; n_out = 3
w = bm.TrainVar(bm.random.random((n_in, n_out)))
b = bm.TrainVar(bm.zeros(n_out))
def update(x):
return x @ w + b
If we do not provide dyn_vars
, “w” and “b” will be compiled as constant values.
update1 = bm.jit(update)
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
b[:] = 1. # modifying the value of 'b' will not
# change the result, because in the
# jitted function, 'b' is already
# a constant
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
Providing "w" and "b" as dyn_vars will make them dynamically changeable again.
update2 = bm.jit(update, dyn_vars=(w, b))
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
b[:] = 2. # change b to 2, while y will not be 2
update2(x)
JaxArray(DeviceArray([2., 2., 2.], dtype=float32))
RandomState
We have talked about RandomState in the Variables section. We said that it is also a Variable. Therefore, if your function uses the default RandomState (brainpy.math.jax.random.DEFAULT
), you should add it into the dyn_vars
scope of the function. Otherwise, it will be treated as a constant and the jitted function will always return the same value.
def function():
return bm.random.normal(0, 1, size=(10,))
f1 = bm.jit(function)
f1() == f1()
JaxArray(DeviceArray([ True, True, True, True, True, True, True, True,
True, True], dtype=bool))
The correct way to make JIT for this function is:
bm.random.seed(1234)
f2 = bm.jit(function, dyn_vars=bm.random.DEFAULT)
f2() == f2()
JaxArray(DeviceArray([False, False, False, False, False, False, False, False,
False, False], dtype=bool))
Example 3: JIT a neural network
Now, let’s use SGD to train a neural network with JIT acceleration. Here we will use the autograd function brainpy.math.jax.grad()
, which will be detailed in the next section.
class LinearNet(bp.Base):
def __init__(self, n_in, n_out):
super(LinearNet, self).__init__()
# weights
self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
self.b = bm.TrainVar(bm.zeros(n_out))
self.r = bm.TrainVar(bm.random.random((n_out, 1)))
def update(self, x):
h = x @ self.w + self.b
return h @ self.r
def loss(self, x, y):
predict = self.update(x)
return bm.mean((predict - y) ** 2)
ln = LinearNet(100, 200)
# provide the variables we want to update
opt = bm.optimizers.SGD(lr=1e-6, train_vars=ln.vars())
# provide the variables that require gradients
f_grad = bm.grad(ln.loss, grad_vars=ln.vars(), return_value=True)
def train(X, Y):
grads, loss = f_grad(X, Y)
opt.update(grads)
return loss
# JIT the train function
train_jit = bm.jit(train, dyn_vars=ln.vars() + opt.vars())
xs = bm.random.random((1000, 100))
ys = bm.random.random((1000, 1))
for i in range(30):
loss = train_jit(xs, ys)
print(f'Train {i}, loss = {loss:.2f}')
Train 0, loss = 103.74
Train 1, loss = 63.50
Train 2, loss = 40.54
Train 3, loss = 27.44
Train 4, loss = 19.97
Train 5, loss = 15.71
Train 6, loss = 13.27
Train 7, loss = 11.89
Train 8, loss = 11.09
Train 9, loss = 10.64
Train 10, loss = 10.38
Train 11, loss = 10.24
Train 12, loss = 10.15
Train 13, loss = 10.10
Train 14, loss = 10.08
Train 15, loss = 10.06
Train 16, loss = 10.05
Train 17, loss = 10.05
Train 18, loss = 10.04
Train 19, loss = 10.04
Train 20, loss = 10.04
Train 21, loss = 10.04
Train 22, loss = 10.04
Train 23, loss = 10.04
Train 24, loss = 10.04
Train 25, loss = 10.04
Train 26, loss = 10.04
Train 27, loss = 10.04
Train 28, loss = 10.04
Train 29, loss = 10.04
Static arguments
Static arguments are arguments that are treated as static/constant in the jitted function.
Numerical arguments used in conditional syntax (boolean values or values that produce booleans) and string arguments must be marked as static. Otherwise, an error will be raised.
@bm.jit
def f(x):
if x < 3: # this will cause error
return 3. * x ** 2
else:
return -4 * x
try:
f(1.)
except Exception as e:
print(type(e), e)
<class 'jax._src.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at <ipython-input-70-14a993a83941>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Simply speaking, arguments that produce boolean values must be declared as static arguments. In the brainpy.math.jax.jit()
function, you can set the names of static arguments via static_argnames.
def f(x):
if x < 3: # 'x' will be declared as a static argument, so this is fine
return 3. * x ** 2
else:
return -4 * x
f_jit = bm.jit(f, static_argnames=('x', ))
f_jit(x=1.)
DeviceArray(3., dtype=float32, weak_type=True)
However, it's worth noting that calling the jitted function with different values for these static arguments will trigger recompilation. Therefore, declaring static arguments is suitable in the following situations:
Boolean arguments.
Arguments that take only a few possible values.
If an argument's value changes frequently, you'd better not declare it as static.
For more information, please refer to jax.jit API.
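To make the recompilation point concrete, here is a small sketch (the caching behavior follows jax.jit: every new value of a static argument compiles a new specialization, while repeated values reuse the cache):
f_jit(x=1.)   # traced and compiled for x = 1.
f_jit(x=2.)   # a different static value triggers a fresh compilation
f_jit(x=1.)   # reuses the compilation cached for x = 1.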
vmap()
Coming soon.
pmap()
Coming soon.
Differentiation
In this section, we are going to talk about how to perform automatic differentiation on your functions and class objects with the 'jax' backend. In today's machine learning systems, computing and using gradients is common in various situations. So, we try to understand
how to calculate derivatives of arbitrary complex functions,
how to compute high-order gradients.
import brainpy as bp
import brainpy.math.jax as bm
bp.math.use_backend('jax')
All autodiff functions in BrainPy support pure functions and class objects.
grad()
brainpy.math.jax.grad()
takes a function/object and returns a new function which computes the gradient of the original function/object.
Pure functions
For a pure function, the gradient is taken with respect to the first argument:
def f(a, b):
return a * 2 + b
grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32)
However, this can be controlled via the argnums
argument.
grad_f2 = bm.grad(f, argnums=(0, 1))
grad_f2(2., 1.)
(DeviceArray(2., dtype=float32), DeviceArray(1., dtype=float32))
Class objects
For a class object or a class bound function, the gradient is taken with respect to the provided grad_vars
argument:
class F(bp.Base):
def __init__(self):
super(F, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab2 = ab * 2
vv = ab2 + c
return vv.mean()
f = F()
The grad_vars
can be a JaxArray, or a list/tuple/dict of JaxArray.
bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': TrainVar(DeviceArray([2.], dtype=float32)),
'F0.b': TrainVar(DeviceArray([2.], dtype=float32))}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
[TrainVar(DeviceArray([2.], dtype=float32)),
TrainVar(DeviceArray([2.], dtype=float32))]
If there are dynamically changed values in the gradient function, you can provide them in the dyn_vars
argument.
class F2(bp.Base):
def __init__(self):
super(F2, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab = ab * 2
self.a.value = ab
return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
TrainVar(DeviceArray([2.], dtype=float32))
Also, if you are interested in the gradients of the input values, please use the argnums
argument. For this situation, calling the gradient function will return (grads_of_grad_vars, *grads_of_args)
.
class F3(bp.Base):
def __init__(self):
super(F3, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c, d):
ab = self.a * self.b
ab = ab * 2
return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_arg0 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
f3 = F3()
grads_of_gv, grad_of_arg0, grad_of_arg1 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
print("grad_of_arg1 :", grad_of_arg1)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
grad_of_arg1 : 10.0
Actually, we recommend you to provide all dynamically changed variables (no matter whether they are updated in the gradient function) in the dyn_vars
argument.
Auxiliary data
Usually, we want to get the value of the loss, or we want to return some intermediate variables during the gradient computation. For these situations, users can set has_aux=True
to return auxiliary data, and set return_value=True
to return the loss value.
# return loss
grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)
print('grad: ', grad)
print('loss: ', loss)
grad: TrainVar(DeviceArray([2.], dtype=float32))
loss: 12.0
class F4(bp.Base):
def __init__(self):
super(F4, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab2 = ab * 2
loss = (ab + c).mean()
return loss, (ab, ab2)
f4 = F4()
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)
print('grad: ', grad)
print('aux_data: ', aux_data)
grad: TrainVar(DeviceArray([1.], dtype=float32))
aux_data: (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32)))
Note: Any function whose gradients are computed through brainpy.math.jax.grad()
must return a scalar value. Otherwise an error will be raised.
try:
bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output was [0. 0.].
# this is right
bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray(DeviceArray([0.5, 0.5], dtype=float32))
If you want to take gradients of vector-valued outputs, please use the brainpy.math.jax.jacobian()
function.
jacobian()
Coming soon.
hessian()
Coming soon.
Control Flows
In this section, we are going to talk about how to build structured control flows in the 'jax' backend. These control flows include
for loop syntax,
while loop syntax,
and condition syntax.
import brainpy as bp
import brainpy.math.jax as bm
bp.math.use_backend('jax')
In JAX, the control flow syntax is not easy to use. Users must transform intuitive Python control flow into structured control flow.
Based on JaxArray
provided in BrainPy, we try to present a better syntax for writing control flows.
make_loop()
brainpy.math.jax.make_loop()
is used to generate a for-loop function when you are using JaxArray
.
Let's imagine your requirement: you are using several JaxArrays (grouped as dyn_vars
) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars
). Sometimes, your body function returns something, and you also want to gather the return values. With Python syntax, your requirement is equivalent to
def for_loop_function(body_fun, dyn_vars, out_vars, xs):
ys = []
for x in xs:
# 'dyn_vars' and 'out_vars'
# are updated in 'body_fun()'
results = body_fun(x)
ys.append([out_vars, results])
return ys
In BrainPy, using brainpy.math.jax.make_loop()
you can define this logic like:
loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=False)
hist_of_out_vars = loop_fun(xs)
Or,
loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=True)
hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
Let’s implement a recurrent network to illustrate how to use this function.
class RNN(bp.DynamicalSystem):
def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
super(RNN, self).__init__(**kwargs)
# parameters
self.n_in = n_in
self.n_h = n_h
self.n_out = n_out
self.n_batch = n_batch
self.g = g
# weights
self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
self.b_ro = bm.TrainVar(bm.zeros((n_out,)))
# variables
self.h = bm.Variable(bm.random.random((n_batch, n_h)))
# function
self.predict = bm.make_loop(self.cell,
dyn_vars=self.vars(),
out_vars=self.h,
has_return=True)
def cell(self, x):
self.h[:] = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
o = self.h @ self.w_ro + self.b_ro
return o
rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)
In the above RNN
model, we define a body function RNN.cell
for later for-loop over input values. The loop function is defined as self.predict
with bm.make_loop()
. We care about the history values of “self.h” and the readout value “o”, so we set out_vars = self.h
and has_return=True
.
xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape # the shape should be (num_time, ) + o.shape
(100, 5, 3)
If you have multiple input values, you should wrap them as a container and call the loop function with loop_fun(xs)
, where "xs" can be a JaxArray or a list/tuple/dict of JaxArray. For example:
a = bm.zeros(10)
def body(x):
x1, x2 = x # "x" is a tuple/list of JaxArray
a.value += (x1 + x2)
loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
JaxArray(DeviceArray([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32))
a = bm.zeros(10)
def body(x): # "x" is a dict of JaxArray
a.value += x['a'] + x['b']
loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
JaxArray(DeviceArray([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32))
dyn_vars
, out_vars
, xs
and the body-function returns can all be organized in container structures like tuple/list/dict. The history output values will preserve the container structure of out_vars
and body function returns. If has_return=True
, the loop function will return a tuple of (hist_of_out_vars, hist_of_fun_returns)
. If none of the out_vars values are of interest, please set out_vars=None, and the loop function will only return the history of the body-function returns
.
make_while()
brainpy.math.jax.make_while()
is used to generate a while-loop function when you are using JaxArray
. It supports the following loop logic:
while condition:
statements
When using brainpy.math.jax.make_while()
, condition should be wrapped as a cond_fun
function which returns a boolean value, and statements should be packed as a body_fun
function which does not support return values:
while cond_fun(x):
body_fun(x)
where x
is the external input which is not iterated. All the iterated variables should be marked as JaxArray
. All JaxArray
used in cond_fun
and body_fun
should be declared in a dyn_vars
variable.
Let's look at an example:
i = bm.zeros(1)
counter = bm.zeros(1)
def cond_f(x):
return i[0] < 10
def body_f(x):
i.value += 1.
counter.value += i
loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])
In the above, we try to implement a sum from 0 to 10. We use two JaxArrays, i
and counter
.
loop()
counter
JaxArray(DeviceArray([55.], dtype=float32))
i
JaxArray(DeviceArray([10.], dtype=float32))
make_cond()
brainpy.math.jax.make_cond()
is used to generate a condition function when you are using JaxArray
. It supports the following condition logic:
if True:
true statements
else:
false statements
When using brainpy.math.jax.make_cond()
, true statements should be wrapped as a true_fun
function which implements the logic of the true branch (no return values), and false statements should be wrapped as a false_fun
function which implements the logic of the false branch (it also does not support return values):
if True:
true_fun(x)
else:
false_fun(x)
All the JaxArray
used in true_fun
and false_fun
should be declared in the dyn_vars
argument. x
is also used to receive the external input value.
Let's give it a try:
a = bm.zeros(2)
b = bm.ones(2)
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
Here, we have two tensors. If the predicate is true, tensor a is incremented by 1; if false, tensor b is decremented by 1.
cond(pred=True)
a, b
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(True)
a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False)
a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
JaxArray(DeviceArray([0., 0.], dtype=float32)))
cond(False)
a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
JaxArray(DeviceArray([-1., -1.], dtype=float32)))
Or, we can define a condition whose branches depend on the external input.
a = bm.zeros(2)
b = bm.ones(2)
def true_f(x): a.value += x
def false_f(x): b.value -= x
cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)
a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False, 5.)
a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
JaxArray(DeviceArray([-4., -4.], dtype=float32)))
Optimizers
Gradient descent is one of the most popular algorithms to perform optimization. By far, gradient descent optimizers, combined with the loss function, are the key pieces that enable machine learning to work for your data. In this section, we are going to understand:
how to use optimizers in BrainPy?
how to customize your own optimizer?
import brainpy as bp
import brainpy.math.jax as bm
bp.math.use_backend('jax')
import matplotlib.pyplot as plt
Optimizers
The basic optimizer class in BrainPy is brainpy.math.jax.optimizers.Optimizer (accessible as bm.optimizers.Optimizer).
Following are some optimizers in BrainPy:
SGD
Momentum
Nesterov momentum
Adagrad
Adadelta
RMSProp
Adam
Users can also easily implement their own optimizers.
Generally, an Optimizer initialization receives a learning rate lr
, the trainable variables train_vars
, and other hyperparameters for the specific optimizer.
lr can be a float, or an instance of bm.optimizers.Scheduler.
train_vars should be a dict of JaxArray.
Here we launch a SGD optimizer.
a = bm.ones((5, 4))
b = bm.zeros((3, 3))
op = bm.optimizers.SGD(lr=0.001, train_vars={'a': a, 'b': b})
When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update()
method.
op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})
print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.99970317, 0.99958736, 0.9991706 , 0.99929893],
[0.99986506, 0.9994412 , 0.9996797 , 0.9995855 ],
[0.99980134, 0.999285 , 0.99970514, 0.99927545],
[0.99907184, 0.9993837 , 0.99917775, 0.99953413],
[0.9999124 , 0.99908406, 0.9995969 , 0.9991523 ]], dtype=float32))
b: JaxArray(DeviceArray([[-5.8195234e-04, -4.3874790e-04, -3.3398748e-05],
[-5.7411409e-04, -7.0666044e-04, -9.4130711e-04],
[-7.1995187e-04, -1.1736620e-04, -9.5254736e-04]], dtype=float32))
You can process gradients before applying them. For example, we clip the gradients by the maximum L2-norm.
grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}
grads_pre
{'a': JaxArray(DeviceArray([[0.99927866, 0.03028023, 0.8803668 , 0.64568734],
[0.64056313, 0.04791141, 0.7399359 , 0.87378347],
[0.96773326, 0.7771431 , 0.9618045 , 0.8374212 ],
[0.64901245, 0.24517596, 0.06224799, 0.6327405 ],
[0.31687486, 0.6385107 , 0.9160483 , 0.67039466]], dtype=float32)),
'b': JaxArray(DeviceArray([[0.14722073, 0.52626574, 0.9817407 ],
[0.7333363 , 0.39472723, 0.82928896],
[0.7657701 , 0.93165004, 0.88332164]], dtype=float32))}
grads_post = bm.clip_by_norm(grads_pre, 1.)
grads_post
{'a': JaxArray(DeviceArray([[0.31979424, 0.00969043, 0.28173944, 0.20663615],
[0.20499626, 0.01533285, 0.23679803, 0.27963263],
[0.3096989 , 0.24870528, 0.30780157, 0.26799577],
[0.20770025, 0.07846245, 0.01992092, 0.20249282],
[0.1014079 , 0.20433943, 0.2931584 , 0.2145431 ]], dtype=float32)),
'b': JaxArray(DeviceArray([[0.0666547 , 0.23826863, 0.4444865 ],
[0.33202055, 0.17871413, 0.37546346],
[0.34670505, 0.4218078 , 0.39992693]], dtype=float32))}
op.update(grads_post)
print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.9993834 , 0.99957764, 0.99888885, 0.9990923 ],
[0.9996601 , 0.9994259 , 0.9994429 , 0.9993059 ],
[0.99949163, 0.99903625, 0.99939734, 0.99900746],
[0.9988641 , 0.99930525, 0.99915785, 0.99933165],
[0.999811 , 0.99887973, 0.99930376, 0.9989378 ]], dtype=float32))
b: JaxArray(DeviceArray([[-0.00064861, -0.00067702, -0.00047789],
[-0.00090613, -0.00088537, -0.00131677],
[-0.00106666, -0.00053917, -0.00135247]], dtype=float32))
Note
Optimizers usually have their own dynamically changed variables. If you JIT a function whose logic contains an optimizer update, you should add the variables in Optimizer.vars()
onto the dyn_vars
in bm.jit()
.
op.vars() # the SGD optimizer only has a `step` variable to record the training step
{'Constant0.step': Variable(DeviceArray([2], dtype=int32))}
bm.optimizers.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Momentum has velocity variables
{'Constant1.step': Variable(DeviceArray([0], dtype=int32)),
'Momentum0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)),
'Momentum0.b_v': Variable(DeviceArray([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32))}
bm.optimizers.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Adam has more variables
{'Constant2.step': Variable(DeviceArray([0], dtype=int32)),
'Adam0.a_m': Variable(DeviceArray([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)),
'Adam0.b_m': Variable(DeviceArray([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)),
'Adam0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)),
'Adam0.b_v': Variable(DeviceArray([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32))}
Creating a custom optimizer
If you intend to create your own optimization algorithm, simply inherit from the bm.optimizers.Optimizer class and override the following methods:
__init__(): the init function, which receives the learning rate (lr) and the trainable variables (train_vars).
update(grads): the update function, which computes the updated parameters from the gradients.
For example,
class CustomizeOp(bm.optimizers.Optimizer):
def __init__(self, lr, train_vars, *params, **other_params):
super(CustomizeOp, self).__init__(lr, train_vars)
# customize your initialization
def update(self, grads):
# customize your update logic
pass
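For concreteness, here is a minimal plain gradient-descent optimizer written in this style (a sketch only: it keeps its own reference to the trainable variables instead of relying on any particular base-class attribute, and it ignores the built-in step counter):
class SimpleGD(bm.optimizers.Optimizer):
    def __init__(self, lr, train_vars):
        super(SimpleGD, self).__init__(lr, train_vars)
        self._train_vars = dict(train_vars)  # our own name -> variable mapping
    def update(self, grads):
        lr = self.lr()  # after initialization, self.lr is a Scheduler; calling it gives the current rate
        for key, param in self._train_vars.items():
            # in-place update of the Variable's value with a plain gradient-descent step
            param.value -= lr * grads[key]
It would then be used exactly like the built-in SGD shown above, e.g. SimpleGD(lr=0.001, train_vars={'a': a, 'b': b}).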
Schedulers
Scheduler seeks to adjust the learning rate during training by reducing the learning rate according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.
For example, we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.
sc = bm.optimizers.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
plt.plot(steps, rates)
plt.xlabel('Train Step')
plt.ylabel('Learning Rate')
plt.show()
steps = bm.arange(1000)
rates = sc(steps)
show(steps, rates)

After Optimizer initialization, the learning rate self.lr
will always be an instance of bm.optimizers.Scheduler
. A scalar float learning rate initialization will result in a Constant
scheduler.
op.lr
<brainpy.math.jax.optimizers.Constant at 0x2ab375a3700>
One can get the current learning rate value by calling Scheduler.__call__(i=None)
.
If i is not provided, the learning rate value will be evaluated at the built-in training step.
Otherwise, the learning rate value will be evaluated at the given step i.
op.lr()
0.001
In BrainPy, several commonly used learning rate schedulers are provided:
Constant
ExponentialDecay
InverseTimeDecay
PolynomialDecay
PiecewiseConstant
# InverseTimeDecay scheduler
rates = bm.optimizers.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)

# PolynomialDecay scheduler
rates = bm.optimizers.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)

Creating a custom scheduler
If users want to implement their own scheduler, simply inherit from the bm.optimizers.Scheduler
class and override the following methods:
__init__(): the init function.
__call__(i=None): the learning rate value evaluation.
class CustomizeScheduler(bm.optimizers.Scheduler):
def __init__(self, lr, *params, **other_params):
super(CustomizeScheduler, self).__init__(lr)
# customize your initialization
def __call__(self, i=None):
# customize your update logic
pass
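As an illustration, a simple step-decay scheduler might look like the following sketch (it keeps its own copy of the initial rate instead of relying on base-class attributes, and the fallback to the built-in training step when i is None is deliberately simplified):
class StepDecay(bm.optimizers.Scheduler):
    def __init__(self, lr, decay_steps=100, decay_rate=0.5):
        super(StepDecay, self).__init__(lr)
        self._init_lr = lr            # our own copy of the initial learning rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
    def __call__(self, i=None):
        i = 0 if i is None else i     # a full implementation would fall back to the built-in step counter
        # multiply the initial rate by decay_rate once every decay_steps steps
        return self._init_lr * self.decay_rate ** (i // self.decay_steps)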
Numerical Integrator
Numerical Solvers for ODEs
The brain modeling toolkit provided in BrainPy is focused on differential equations. How to solve differential equations is the essence of neurodynamics simulation. Exact algebraic solutions are only available for low-order differential equations. For coupled, high-dimensional, non-linear brain dynamical systems, we need to resort to numerical methods for solving such differential equations.
In this section, I will illustrate how to define ordinary differential equations (ODEs), and how to define the numerical integration methods for them in BrainPy.
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
%matplotlib inline
How to define ODE functions?
BrainPy provides a convenient and intuitive way to define ODE systems. For the ODE
$$ \frac{dx}{dt} = f_1(x, t, y, p_1), \qquad \frac{dy}{dt} = f_2(y, t, x, p_2) $$
we can define this system as a Python function:
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
where t denotes the current time, the arguments after t (here p1 and p2) are the parameters needed by the system, and the arguments before t (here x and y) are the dynamical variables. In the function body, the derivative of each variable is computed according to the user-defined f1 and f2. Finally, we return the corresponding derivatives dx and dy in the same order as the variables appear in the function arguments.
For each variable x
or y
, it can be a scalar (var_type = bp.integrators.SCALAR_VAR
), a vector/matrix (var_type = bp.integrators.POP_VAR
), or a system (var_type = bp.integrators.SYSTEM_VAR
). Here, the “system” means that the argument x
denotes an array of variables. Taking the above example as the demonstration again, we can redefine it as:
def diff(xy, t, p1, p2):
x, y = xy
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return bm.array([dx, dy])
How to define the numerical integration for ODEs?
After the definition of the ODE functions, the numerical integration of these functions is very easy in BrainPy. We just need to add a decorator (bp.odeint
).
@bp.odeint
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
After wrapping the derivative function by bp.odeint
, the function becomes an instance of ODEIntegrator
.
isinstance(diff, bp.ode.ODEIntegrator)
True
bp.odeint
receives several arguments:
“method”: a string used to specify the numerical method for integrating the ODE functions. The default method is Euler.
diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x7fa92852d700>
“dt”: A float, used to set the default numerical precision. The default “dt” is 0.1.
diff.dt
0.1
“show_code”: a bool, indicating whether to show the numerical integration code. Let's take the Euler method and the RK4 method as illustrative examples.
@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
dx_k1, dy_k1 = f(x, y, t, p1, p2)
x_new = x + dx_k1 * dt * 1
y_new = y + dy_k1 * dt * 1
return x_new, y_new
{'f': <function diff at 0x7fa928517e50>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x7fa92853d040>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
dx_k1, dy_k1 = f(x, y, t, p1, p2)
k2_x_arg = x + dt * dx_k1 * 0.5
k2_y_arg = y + dt * dy_k1 * 0.5
k2_t_arg = t + dt * 0.5
dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
k3_x_arg = x + dt * dx_k2 * 0.5
k3_y_arg = y + dt * dy_k2 * 0.5
k3_t_arg = t + dt * 0.5
dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
k4_x_arg = x + dt * dx_k3
k4_y_arg = y + dt * dy_k3
k4_t_arg = t + dt
dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
return x_new, y_new
{'f': <function diff at 0x7fa928517a60>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa92853d910>
Two Illustrated Examples
Example 1: FitzHugh–Nagumo model
Now, let's take the well-known FitzHugh–Nagumo model as an example to illustrate how to define ODE solvers for brain modeling. The FitzHugh–Nagumo (FHN) model has two dynamical variables, which are governed by the following equations:
$$ \begin{split} \tau \dot{w} &= v + a - b w \\ \dot{v} &= v - \frac{v^{3}}{3} - w + I_{\rm ext} \end{split} $$
For this FHN model, we can code it in BrainPy like this:
@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
dw = (V + a - b * w) / tau
dV = V - V * V * V / 3 - w + Iext
return dV, dw
After defining the numerical solver, the solution of the ODE system at the given time points can be easily computed. For example, with the given parameters,
a=0.7; b=0.8; tau=12.5; Iext=1.
the solution of the FHN model between 0 and 100 ms can be approximated by
import matplotlib.pyplot as plt
%matplotlib inline
hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
V, w = integral(V, w, t, Iext, a, b, tau)
hist_V.append(V)
plt.plot(hist_times, hist_V)
[<matplotlib.lines.Line2D at 0x7fa92846fa30>]

Example 2: Hodgkin–Huxley model
Another more complex example is the classical Hodgkin–Huxley neuron model. In the HH model, four dynamical variables (V, m, n, h) are used for modeling the initiation and propagation of the action potential. Specifically, they are governed by the following equations:
$$ \begin{aligned} C_{m} \frac{dV}{dt} &= -\bar{g}_{\mathrm{K}} n^{4}\left(V-V_{K}\right) - \bar{g}_{\mathrm{Na}} m^{3} h \left(V-V_{Na}\right) - \bar{g}_{l}\left(V-V_{l}\right) + I_{syn} \\ \frac{dm}{dt} &= \alpha_{m}(V)(1-m) - \beta_{m}(V) m \\ \frac{dh}{dt} &= \alpha_{h}(V)(1-h) - \beta_{h}(V) h \\ \frac{dn}{dt} &= \alpha_{n}(V)(1-n) - \beta_{n}(V) n \end{aligned} $$
In BrainPy, such a dynamical system can be coded as:
@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
I_K = (gK * n ** 4.0) * (V - EK)
I_leak = gL * (V - EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / C
return dVdt, dmdt, dhdt, dndt
As with the FHN model, we can also integrate the HH model with the given parameters over the given time interval:
Iext=10.; ENa=50.; EK=-77.; EL=-54.387
C=1.0; gNa=120.; gK=36.; gL=0.03
hist_times = bm.arange(0, 100, 0.01)
hist_V, hist_m, hist_h, hist_n = [], [], [], []
V, m, h, n = 0., 0., 0., 0.
for t in hist_times:
V, m, h, n = integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C)
hist_V.append(V)
hist_m.append(m)
hist_h.append(h)
hist_n.append(n)
plt.subplot(211)
plt.plot(hist_times, hist_V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(hist_times, hist_m, label='m')
plt.plot(hist_times, hist_h, label='h')
plt.plot(hist_times, hist_n, label='n')
plt.legend()
plt.show()

Provided ODE Numerical Solvers
BrainPy
provides several numerical methods for ordinary differential equations (ODEs). Specifically, we provide explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler method for ODE numerical integration.
Explicit Runge-Kutta methods for ODEs
The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are selected by the following keywords:
euler
midpoint
heun2
ralston2
rk2
rk3
rk4
heun3
ralston3
ssprk3
ralston4
rk4_38rule
Users can utilize these methods by specifying the method
option in brainpy.odeint()
with their corresponding keyword. For example:
@bp.odeint(method='rk4')
def int_v(v, t, p):
# do something
return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa9283b6100>
Or, you can directly instantiate your favorite integrator:
@bp.ode.RK4
def int_v(v, t, p):
# do something
return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa92852dc70>
def derivative(v, t, p):
# do something
return v
int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa9283b63a0>
Adaptive Runge-Kutta methods for ODEs
The second category of ODE numerical support is the adaptive RK methods. What's different from the explicit RK methods is that adaptive methods produce an estimate of the local truncation error within a single Runge-Kutta step, and this error is then used to adaptively control the numerical step size. Specifically, if $error > tol$, then $dt$ is replaced with $dt_{new}$ and the step is repeated. Therefore, adaptive RK methods allow a varying step size. In BrainPy, the adaptive RK methods are selected by the following keywords:
rkf45
rkf12
rkdp
ck
bs
heun_euler
By default, the above methods are not adaptive, unless users provide the keyword adaptive=True in brainpy.odeint(). When users use the adaptive RK methods for numerical integration, the instantaneously adjusted stepsize dt will be appended to the function arguments. Moreover, the tolerance tol for the stepsize adjustment can also be controlled by users. Let's take the Lorenz system as an example:
# adaptively adjust stepsize
@bp.odeint(method='rkf45',
adaptive=True, # active the "adaptive" option
tol=0.001) # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
# should provide one more argument "dt" when using the adaptive rk method
x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)
hist_x.append(x)
hist_y.append(y)
hist_z.append(z)
hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()


Exponential Euler methods for ODEs
Finally, BrainPy provides exponential integrators for ODEs. For linear ODE systems, we highly recommend the Exponential Euler method.
Methods | Keywords
---|---
Exponential Euler | exponential_euler
For a linear system,
$$ {dy \over dt} = A - By $$
the exponential Euler scheme is given by:
$$ y(t+dt) = y(t) e^{-Bdt} + {A \over B}(1 - e^{-Bdt}) $$
As you can see, for such linear systems, the exponential Euler scheme is nearly the exact solution.
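To make this concrete, here is a minimal sketch (plain NumPy, with hypothetical constants A and B) that checks one exponential Euler step against the analytic solution of the linear system above:
import numpy as np

# hypothetical linear system dy/dt = A - B*y
A, B = 2.0, 0.5
y0, dt = 0.0, 0.2

# one exponential Euler step: y(t+dt) = y(t)*exp(-B*dt) + (A/B)*(1 - exp(-B*dt))
y_exp_euler = y0 * np.exp(-B * dt) + (A / B) * (1 - np.exp(-B * dt))

# analytic solution started from y0: y(t) = A/B + (y0 - A/B)*exp(-B*t)
y_exact = A / B + (y0 - A / B) * np.exp(-B * dt)

print(y_exp_euler, y_exact)  # identical for this purely linear system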
In BrainPy, in order to automatically find out the linear part, we utilize SymPy to parse user-defined functions. Therefore, users need to install sympy first when using the Exponential Euler method.
Interestingly, the computationally expensive Hodgkin-Huxley neuron model is a linear-like ODE system. In the following, you will find that the Exponential Euler method allows a much larger numerical step, which saves computation time.
$$ \begin{aligned} C_{m}\frac{dV}{dt} &= -\left[\bar{g}_{\text{K}} n^{4} + \bar{g}_{\text{Na}} m^{3} h + \bar{g}_{l}\right] V + \bar{g}_{\text{K}} n^{4} V_{K} + \bar{g}_{\text{Na}} m^{3} h V_{Na} + \bar{g}_{l} V_{l} + I_{syn} \\ \frac{dm}{dt} &= \left[-\alpha_{m}(V)-\beta_{m}(V)\right] m + \alpha_{m}(V) \\ \frac{dh}{dt} &= \left[-\alpha_{h}(V)-\beta_{h}(V)\right] h + \alpha_{h}(V) \\ \frac{dn}{dt} &= \left[-\alpha_{n}(V)-\beta_{n}(V)\right] n + \alpha_{n}(V) \end{aligned} $$
Iext=10.; ENa=50.; EK=-77.; EL=-54.387
C=1.0; gNa=120.; gK=36.; gL=0.03
def derivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
I_K = (gK * n ** 4.0) * (V - EK)
I_leak = gL * (V - EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / C
return dVdt, dmdt, dhdt, dndt
def run(method, Iext=10., dt=0.1):
hist_times = bm.arange(0, 100, dt)
hist_V, hist_m, hist_h, hist_n = [], [], [], []
V, m, h, n = 0., 0., 0., 0.
for t in hist_times:
V, m, h, n = method(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C)
hist_V.append(V)
hist_m.append(m)
hist_h.append(h)
hist_n.append(n)
plt.subplot(211)
plt.plot(hist_times, hist_V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(hist_times, hist_m, label='m')
plt.plot(hist_times, hist_h, label='h')
plt.plot(hist_times, hist_n, label='n')
plt.legend()
Euler Method
int1 = bp.odeint(f=derivative, method='euler', dt=0.1)
run(int1, Iext=10, dt=0.1)
<ipython-input-25-35d6bfdac53f>:2: RuntimeWarning: overflow encountered in exp
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
<ipython-input-25-35d6bfdac53f>:3: RuntimeWarning: overflow encountered in exp
beta = 4.0 * bm.exp(-(V + 65) / 18)
<ipython-input-25-35d6bfdac53f>:6: RuntimeWarning: overflow encountered in exp
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
<ipython-input-25-35d6bfdac53f>:7: RuntimeWarning: overflow encountered in exp
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
<ipython-input-25-35d6bfdac53f>:10: RuntimeWarning: overflow encountered in exp
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
<ipython-input-25-35d6bfdac53f>:11: RuntimeWarning: overflow encountered in exp
beta = 0.125 * bm.exp(-(V + 65) / 80)
<ipython-input-25-35d6bfdac53f>:4: RuntimeWarning: invalid value encountered in double_scalars
dmdt = alpha * (1 - m) - beta * m
<ipython-input-25-35d6bfdac53f>:8: RuntimeWarning: invalid value encountered in double_scalars
dhdt = alpha * (1 - h) - beta * h
<ipython-input-25-35d6bfdac53f>:12: RuntimeWarning: invalid value encountered in double_scalars
dndt = alpha * (1 - n) - beta * n

int2 = bp.odeint(f=derivative, method='euler', dt=0.02)
run(int2, Iext=10, dt=0.02)

RK4 Method
int3 = bp.odeint(f=derivative, method='rk4', dt=0.1)
run(int3, Iext=10, dt=0.1)

int4 = bp.odeint(f=derivative, method='rk4', dt=0.2)
run(int4, Iext=10, dt=0.2)
<ipython-input-25-35d6bfdac53f>:2: RuntimeWarning: overflow encountered in exp
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
<ipython-input-25-35d6bfdac53f>:3: RuntimeWarning: overflow encountered in exp
beta = 4.0 * bm.exp(-(V + 65) / 18)
<ipython-input-25-35d6bfdac53f>:6: RuntimeWarning: overflow encountered in exp
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
<ipython-input-25-35d6bfdac53f>:7: RuntimeWarning: overflow encountered in exp
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
<ipython-input-25-35d6bfdac53f>:10: RuntimeWarning: overflow encountered in exp
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
<ipython-input-25-35d6bfdac53f>:11: RuntimeWarning: overflow encountered in exp
beta = 0.125 * bm.exp(-(V + 65) / 80)
<ipython-input-25-35d6bfdac53f>:4: RuntimeWarning: invalid value encountered in double_scalars
dmdt = alpha * (1 - m) - beta * m
<ipython-input-25-35d6bfdac53f>:8: RuntimeWarning: invalid value encountered in double_scalars
dhdt = alpha * (1 - h) - beta * h
<ipython-input-25-35d6bfdac53f>:12: RuntimeWarning: invalid value encountered in double_scalars
dndt = alpha * (1 - n) - beta * n
<ipython-input-25-35d6bfdac53f>:17: RuntimeWarning: invalid value encountered in double_scalars
dVdt = (- I_Na - I_K - I_leak + Iext) / C

Exponential Euler Method
int5 = bp.odeint(f=derivative, method='exponential_euler', dt=0.2)
run(int5, Iext=10, dt=0.2)

Numerical Solvers for SDEs
BrainPy
provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.
import brainpy as bp
bp.__version__
'1.1.0'
import matplotlib.pyplot as plt
%matplotlib inline
How to define SDE functions?
For a one-dimensional stochastic differential equation (SDE) with scalar Wiener noise, it is given by
$$ \begin{aligned} d X_{t}&=f\left(X_{t}, t, p_1\right) d t+g\left(X_{t}, t, p_2\right) d W_{t} \quad (1) \end{aligned} $$
where $X_t = X(t)$ is the realization of a stochastic process or random variable, $f(X_t, t)$ is the drift coefficient, $g(X_t, t)$ denotes the diffusion coefficient, and the stochastic process $W_t$ is called a Wiener process.
For this SDE system, we can define two Python functions $f$ and $g$ to represent it.
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df
As with ODE functions, the arguments before $t$ denote the random variables, while the arguments after $t$ represent the parameters. For an SDE function with scalar noise, the sizes of the returned $df$ and $dg$ should be the same, for example $df \in R^d, dg \in R^d$.
However, a more general SDE system usually has a multi-dimensional driving Wiener process:
$$ dX_t=f(X_t)dt+\sum_{\alpha=1}^{m}g_{\alpha }(X_t)dW_t ^{\alpha} $$
For such an $m$-dimensional noise system, the coding scheme is the same as the scalar one, except that the returned $dg$ has one more dimension, for example $df \in R^{d}, dg \in R^{d \times m}$.
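As a sketch (with a hypothetical drift f and diffusion g, d = 3 state variables and m = 2 Wiener processes), only the returned shapes matter here:
import brainpy.math as bm

d, m = 3, 2  # hypothetical state dimension and number of Wiener processes

def f_part(x, t, p1, p2):
    df = -p1 * x  # drift term, shape (d,)
    return df

def g_part(x, t, p1, p2):
    # diffusion term, shape (d, m): one column per driving Wiener process
    dg = p2 * bm.ones((d, m))
    return dg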
How to define the numerical integration for SDEs?
Before the numerical integration of SDE functions, we should distinguish two kinds of SDE integrals. For the integration of system (1), we can get
$$ \begin{aligned} X_{t}&=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} \quad (2) \end{aligned} $$
In the 1940s, the Japanese mathematician K. Ito introduced a type of integral called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol "$\circ$" to distinguish it from the former Ito integral.
$$ \begin{aligned} d X_{t} &=f\left(X_{t}, t\right) d t+g\left(X_{t}, t\right) \circ d W_{t} \\ X_{t} &=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) \circ d W_{s} \quad (3) \end{aligned} $$
The difference between the Ito integral (2) and the Stratonovich integral (3) lies in the second integral term, which can be written in a general form as
$$ \begin{split} \int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} &=\lim_{h \rightarrow 0} \sum_{k=0}^{m-1} g\left(X_{\tau_{k}}, \tau_{k}\right)\left(W\left(t_{k+1}\right)-W\left(t_{k}\right)\right) \\ \mathrm{where} \quad & h = t_{k+1} - t_{k} \\ & \tau_k = (1-\lambda)t_k +\lambda t_{k+1} \end{split} $$
In the stochastic integral of the Ito SDE, $\lambda=0$, thus $\tau_k=t_k$;
In the definition of the Stratonovich integral, $\lambda=0.5$, thus $\tau_k=(t_{k+1} + t_{k}) / 2$.
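The two conventions can be checked numerically with plain NumPy (this is just an illustration of the definition above, not a BrainPy API). For the toy integral $\int W \, dW$, the Ito sum converges to $(W_T^2 - T)/2$, while the Stratonovich sum converges to $W_T^2/2$:
import numpy as np

rng = np.random.default_rng(0)
T, n = 1.0, 100000
dW = rng.normal(0.0, np.sqrt(T / n), n)       # Wiener increments
W = np.concatenate([[0.0], np.cumsum(dW)])    # Wiener path at the grid points

ito = np.sum(W[:-1] * dW)                     # lambda = 0: evaluate at t_k
strato = np.sum(0.5 * (W[:-1] + W[1:]) * dW)  # lambda = 0.5: midpoint value (approximated by the endpoint average)

print(ito, (W[-1] ** 2 - T) / 2)              # Ito theory
print(strato, W[-1] ** 2 / 2)                 # Stratonovich theory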
In BrainPy, these two different integrals can be easily implemented. All users need to do is to provide the keyword intg_type in bp.sdeint(). intg_type can be "bp.integrators.STRA_SDE" or "bp.integrators.ITO_SDE" (default). Also, the type of Wiener process can be easily specified with the wiener_type keyword. It can be "bp.integrators.SCALAR_WIENER" (default) or "bp.integrators.VECTOR_WIENER".
Now, let's numerically integrate the SDE (1) in the Ito sense with the Milstein method:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(d,)
@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
Or, it can be expressed as:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(d,)
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
integral = bp.sdeint(f=f_part, g=g_part, method='milstein')
However, to numerically integrate an SDE with a multi-dimensional Wiener process in the Stratonovich sense, you can code it like this:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(d, m)
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
integral = bp.sdeint(f=f_part,
g=g_part,
method='milstein',
intg_type=bp.integrators.STRA_SDE,
wiener_type=bp.integrators.VECTOR_WIENER)
Example: Noisy Lorenz system
Here, let's demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:
$$ \begin{aligned} \frac{dx}{dt} &= \sigma(y-x) + p x \xi_x \\ \frac{dy}{dt} &= x(\rho-z) - y + p y \xi_y \\ \frac{dz}{dt} &= x y - \beta z + p z \xi_z \end{aligned} $$
sigma = 10; beta = 8/3;
rho = 28; p = 0.1
def lorenz_g(x, y, z, t):
return p * x, p * y, p * z
def lorenz_f(x, y, z, t):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz
lorenz = bp.sdeint(f=lorenz_f,
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER,
dt=0.005)
hist_times = bp.math.arange(0, 50, 0.005)
hist_x, hist_y, hist_z = [], [], []
x, y, z = 1., 1., 1.
for t in hist_times:
x, y, z = lorenz(x, y, z, t)
hist_x.append(x)
hist_y.append(y)
hist_z.append(z)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Supported SDE Numerical Methods
BrainPy
provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.
Methods | Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support
---|---|---|---|---|---
Strong SRK scheme: SRI1W1 | srk1w1_scalar | Yes | | Yes |
Strong SRK scheme: SRI2W1 | srk2w1_scalar | Yes | | Yes |
Strong SRK scheme: KlPl | KlPl_scalar | Yes | | Yes |
Euler method | euler | Yes | Yes | Yes | Yes
Heun method | heun | | Yes | Yes | Yes
Derivative-free Milstein | milstein | Yes | Yes | Yes | Yes
Exponential Euler | exponential_euler | Yes | | Yes | Yes
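For example, according to the table, heun supports the Stratonovich interpretation; a sketch of selecting it (reusing the scalar-noise f_part and g_part functions defined earlier) might look like:
integral = bp.sdeint(f=f_part,
                     g=g_part,
                     method='heun',
                     intg_type=bp.integrators.STRA_SDE,
                     wiener_type=bp.integrators.SCALAR_WIENER)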
Dynamics Simulation
Efficient Synaptic Computation
In a real project, most of the simulation time is spent on the computation of synapses. Therefore, figuring out the most efficient way to perform synaptic computation is a necessary step for accelerating your computational project. Here, let's take an E/I balanced network as an example to illustrate how to code efficient synaptic computation.
import brainpy as bp
import brainpy.math as bm
bm.use_backend('numpy')
%matplotlib inline
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
The E/I balanced network COBA is adapted from (Vogels & Abbott, 2005) [1].
# Parameters for network structure
num = 4000
num_exc = int(num * 0.75)
num_inh = int(num * 0.25)
Neuron Model
In COBA network, each integrate-and-fire neuron is characterized by a time constant, $\tau$ = 20 ms, and a resting membrane potential, $V_{rest}$ = -60 mV. Whenever the membrane potential crosses a spiking threshold of -50 mV, an action potential is generated and the membrane potential is reset to the resting potential, where it remains clamped for a 5 ms refractory period. The membrane voltages are calculated as follows:
$$ \tau {dV \over dt} = (V_{rest} - V) + g_{exc}(E_{exc} - V) + g_{inh}(E_{inh} - V) $$
where reversal potentials are $E_{exc} = 0$ mV and $E_{inh} = -80$ mV.
# Parameters for the neuron
tau = 20 # ms
Vt = -50 # mV
Vr = -60 # mV
El = -60 # mV
ref_time = 5.0 # refractory time, ms
I = 20.
class LIF(bp.NeuGroup):
def __init__(self, size, **kwargs):
super(LIF, self).__init__(size=size, **kwargs)
# variables
self.V = bm.Variable(bm.zeros(size))
self.input = bm.Variable(bm.zeros(size))
self.spike = bm.Variable(bm.zeros(size, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, V, t, Iexc):
dV = (Iexc + El - V) / tau
return dV
def update(self, _t, _dt):
for i in range(self.num):
self.spike[i] = 0.
if (_t - self.t_last_spike[i]) > ref_time:
V = self.integral(self.V[i], _t, self.input[i])
if V >= Vt:
self.V[i] = Vr
self.spike[i] = 1.
self.t_last_spike[i] = _t
else:
self.V[i] = V
self.input[i] = I
Synapse Model
In the COBA network, when a neuron fires, the appropriate synaptic variables of its postsynaptic targets are increased: $g_{exc} \gets g_{exc} + \Delta g_{exc}$ for an excitatory presynaptic neuron and $g_{inh} \gets g_{inh} + \Delta g_{inh}$ for an inhibitory presynaptic neuron. Otherwise, these variables obey the following equations:
$$ \tau_{exc} {dg_{exc} \over dt} = -g_{exc} \quad \quad (1) \\ \tau_{inh} {dg_{inh} \over dt} = -g_{inh} \quad \quad (2) $$
with synaptic time constants $\tau_{exc} = 5$ ms, $\tau_{inh} = 10$ ms, $\Delta g_{exc} = 0.6$ and $\Delta g_{inh} = 6.7$.
# Parameters for the synapse
tau_exc = 5 # ms
tau_inh = 10 # ms
E_exc = 0. # mV
E_inh = -80. # mV
delta_exc = 0.6 # excitatory synaptic weight
delta_inh = 6.7 # inhibitory synaptic weight
def run_net(neu_model, syn_model):
E = neu_model(num_exc, monitors=['spike'])
E.V[:] = bm.random.randn(num_exc) * 5. + Vr
I = neu_model(num_inh, monitors=['spike'])
I.V[:] = bm.random.randn(num_inh) * 5. + Vr
E2E = syn_model(pre=E, post=E, conn=bp.connect.FixedProb(0.02),
tau=tau_exc, weight=delta_exc, E=E_exc)
E2I = syn_model(pre=E, post=I, conn=bp.connect.FixedProb(0.02),
tau=tau_exc, weight=delta_exc, E=E_exc)
I2E = syn_model(pre=I, post=E, conn=bp.connect.FixedProb(0.02),
tau=tau_inh, weight=delta_inh, E=E_inh)
I2I = syn_model(pre=I, post=I, conn=bp.connect.FixedProb(0.02),
tau=tau_inh, weight=delta_inh, E=E_inh)
net = bp.Network(E, I, E2E, E2I, I2E, I2I)
net = bm.jit(net)
t = net.run(100., report=0.1)
fig, gs = bp.visualize.get_figure(row_num=5, col_num=1, row_len=1, col_len=10)
fig.add_subplot(gs[:4, 0])
bp.visualize.raster_plot(E.mon.ts, E.mon.spike, xlim=(0, 100.), ylabel='E Group', xlabel='')
fig.add_subplot(gs[4, 0])
bp.visualize.raster_plot(I.mon.ts, I.mon.spike, xlim=(0, 100.), ylabel='I Group', show=True)
return t
Matrix-based connection
The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through connector.requires('conn_mat') (for details, please see Synaptic Connectivity). Each connection matrix is an array with the shape (num_pre, num_post), like

Based on conn_mat, the updating logic of the above synapses can be coded as:
class SynMat1(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynMat1, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# p1: connections
self.conn = conn(pre.size, post.size)
self.conn_mat = self.conn.requires('conn_mat')
# variables
self.g = bm.Variable(bm.zeros(self.conn_mat.shape))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
spike_on_syn = bm.expand_dims(self.pre.spike, 1) * self.conn_mat # p2
self.g[:] += spike_on_syn * self.weight # p3
self.post.input[:] += bm.sum(self.g, axis=0) * (self.E - self.post.V) # p4
In the SynMat1 class defined above, at line "p1" we require a "conn_mat" structure for the later synaptic computation; at "p2" we get the spike state of each synaptic connection according to "conn_mat" and the presynaptic spikes; then at "p3", the spike-triggered increments are added onto the synaptic variables; finally, at line "p4", all connected synaptic values are summed to get the current effective conductance by bm.sum(self.g, axis=0).
Now, let’s inspect the performance of this matrix-based synapse.
t_syn_mat1 = run_net(neu_model=LIF, syn_model=SynMat1)
Compilation used 8.9806 s.
Start running ...
Run 10.0% used 11.922 s.
Run 20.0% used 24.164 s.
Run 30.0% used 36.259 s.
Run 40.0% used 48.411 s.
Run 50.0% used 60.550 s.
Run 60.0% used 72.775 s.
Run 70.0% used 85.545 s.
Run 80.0% used 98.326 s.
Run 90.0% used 110.973 s.
Run 100.0% used 123.404 s.
Simulation is done in 123.404 s.

This matrix-based synapse structure is very inefficient, because more than 99.9% of the time is spent on the synaptic computation. We can verify this by running the neuron group models alone.
group = bm.jit(LIF(num, monitors=['spike']))
group.V[:] = bm.random.randn(num) * 5. + Vr
group.run(100., inputs=('input', 5.), report=True)
Compilation used 0.1588 s.
Start running ...
Run 100.0% used 0.027 s.
Simulation is done in 0.027 s.
0.02666616439819336
As you can see, the neuron group only takes about 0.027 s to run. Normalized by the total running time of 120+ s, the neuron group update accounts for only about 0.02%.
Event-based updating
The inefficiency of the above matrix-based computation comes from the huge waste of time on synaptic computation. First, it is uncommon for a neuron to generate a spike at any given time step; second, within a group of neurons, the generated spikes (self.pre.spike) are usually sparse. Therefore, at many time points self.pre.spike contains mostly zeros, so self.g += spike_on_syn * self.weight adds many unnecessary zeros to self.g.
Alternatively, we can update self.g only when a pre-synaptic neuron produces a spike event (this is called the event-based updating method):
class SynMat2(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynMat2, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.conn_mat = self.conn.requires('conn_mat')
# variables
self.g = bm.Variable(bm.zeros(self.conn_mat.shape))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
# p1
for pre_i, spike in enumerate(self.pre.spike):
if spike:
self.g[pre_i] += self.conn_mat[pre_i] * self.weight
self.post.input[:] += bm.sum(self.g, axis=0) * (self.E - self.post.V)
Compared with SynMat1, we replace "p2" and "p3" in SynMat1 with "p1" in SynMat2. Now, the connected post-synaptic state g is updated (self.g[pre_i] += self.conn_mat[pre_i] * self.weight) only when the pre-synaptic neuron emits a spike (if spike).
t_syn_mat2 = run_net(neu_model=LIF, syn_model=SynMat2)
Compilation used 8.1212 s.
Start running ...
Run 10.0% used 5.830 s.
Run 20.0% used 11.736 s.
Run 30.0% used 17.713 s.
Run 40.0% used 23.714 s.
Run 50.0% used 29.624 s.
Run 60.0% used 35.564 s.
Run 70.0% used 41.559 s.
Run 80.0% used 47.508 s.
Run 90.0% used 53.456 s.
Run 100.0% used 59.436 s.
Simulation is done in 59.436 s.

Such event-based updating of the matrix connection boosts the running speed nearly 2 times, but it is still not good enough.
Vector-based connection
Matrix-based synaptic computation may be straightforward, but it can waste a great deal of RAM and computation. Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need $10^8$ floats to store the synaptic state, and at each update step you compute on $10^8$ floats, while the number of values you really need is only $10^7$. This is a huge waste of memory and computing resources.
pre_ids and post_ids
An effective way to solve this problem is to use vectors to store the connectivity between the neuron groups and the corresponding synaptic states. For the connectivity conn_mat defined above, we can align the connected pre-synaptic and post-synaptic neurons by two one-dimensional arrays: pre_ids and post_ids,

In such a way, we only need two vectors (pre_ids and post_ids, each with $10^7$ elements) to store the synaptic connectivity. And, at each time step, we only need to update a synaptic state vector with $10^7$ floats.
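As a small illustration (the sizes here are hypothetical and unrelated to the COBA network above), both index vectors can be obtained from a connector via requires(), following the same pattern used inside the synapse classes below:
conn = bp.connect.FixedProb(prob=0.1)
conn(100, 100)  # hypothetical small pre/post populations
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')
# pre_ids[k] and post_ids[k] give the pre/post neuron of the k-th synapse,
# so both vectors have the same length: the number of synapses.
print(pre_ids.shape, post_ids.shape)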
class SynVec1(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVec1, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
self.num = len(self.pre_ids)
# variables
self.g = bm.Variable(bm.zeros(self.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
for syn_i in range(self.num):
# p1: update
pre_i = self.pre_ids[syn_i]
if self.pre.spike[pre_i]:
self.g[syn_i] += self.weight
# p2: output
post_i = self.post_ids[syn_i]
self.post.input[post_i] += self.g[syn_i] * (self.E - self.post.V[post_i])
In the SynVec1 class, we first update the synaptic state in the "p1" code block, where self.g[syn_i] is updated when the pre-synaptic neuron generates a spike (if self.pre.spike[pre_i]); then, in the "p2" code block, we output the synaptic states onto the post-synaptic neurons.
t_syn_vec1 = run_net(neu_model=LIF, syn_model=SynVec1)
Compilation used 2.8904 s.
Start running ...
Run 10.0% used 0.124 s.
Run 20.0% used 0.240 s.
Run 30.0% used 0.358 s.
Run 40.0% used 0.473 s.
Run 50.0% used 0.586 s.
Run 60.0% used 0.698 s.
Run 70.0% used 0.812 s.
Run 80.0% used 0.928 s.
Run 90.0% used 1.040 s.
Run 100.0% used 1.150 s.
Simulation is done in 1.150 s.

Great! Transforming the matrix-based connection into a vector-based connection gives us a huge speed boost. However, there is still redundancy in the SynVec1 class: a pre-synaptic neuron may connect to many post-synaptic neurons, so at each update step we check many times whether the same pre-synaptic neuron generates a spike (self.pre.spike[pre_i]).
pre2syn and post2syn
In order to solve the above problem, we create another two synaptic structures, pre2syn and post2syn, to help us retrieve the synapse states connected with pre-synaptic neuron $i$ and post-synaptic neuron $j$.
In a pre2syn list, each pre2syn[i] stores the indexes of the synaptic states projected from pre-synaptic neuron $i$.

Similarly, we can create a post2syn list to indicate the connections between synapses and post-synaptic neurons. For each post-synaptic neuron $j$, post2syn[j] stores the indexes of the synaptic elements which are connected to the post neuron $j$.

Based on these connectivity mappings, we can define another version of the synapse model using pre2syn and post2syn:
class SynVec2(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVec2, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre_ids, self.pre2syn, self.post2syn = self.conn.requires('pre_ids', 'pre2syn', 'post2syn')
self.num = len(self.pre_ids)
# variables
self.g = bm.Variable(bm.zeros(self.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
# p1: update
for pre_i in range(self.pre.num):
if self.pre.spike[pre_i]:
for syn_i in self.pre2syn[pre_i]:
self.g[syn_i] += self.weight
# p2: output
for post_i in range(self.post.num):
for syn_i in self.post2syn[post_i]:
self.post.input[post_i] += self.g[syn_i] * (self.E - self.post.V[post_i])
In this SynVec2 class, in the "p1" code block, we update the synaptic states with a for-loop over the pre-synaptic neurons. If a pre-synaptic neuron elicits a spike (self.pre.spike[pre_i]), we loop over its connected synaptic states with for syn_i in self.pre2syn[pre_i]. In this way, we only need to check the spike state of pre-synaptic neuron pre_i once. Similarly, in the "p2" code block, the synaptic output is implemented with a for-loop over the post-synaptic neurons.
t_syn_vec2 = run_net(neu_model=LIF, syn_model=SynVec2)
Compilation used 3.2760 s.
Start running ...
Run 10.0% used 0.125 s.
Run 20.0% used 0.252 s.
Run 30.0% used 0.385 s.
Run 40.0% used 0.513 s.
Run 50.0% used 0.640 s.
Run 60.0% used 0.780 s.
Run 70.0% used 0.919 s.
Run 80.0% used 1.049 s.
Run 90.0% used 1.181 s.
Run 100.0% used 1.308 s.
Simulation is done in 1.308 s.

We only get a similar speed performance. This is because the optimization of the "update" block has run its course; currently, most of the running cost is spent on the "output" block.
pre2post and post2pre
Notice that for this kind of synapse model, the synaptic states $g$ onto a post-synaptic neuron can be modeled together. This is because the evolution of the synaptic state according to differential equations (1) and (2) after the pre-synaptic spikes can be superposed. This means that we can declare a synaptic state self.g with the shape of post.num, instead of the shape of the synapse number.
In order to achieve this goal, we create another two synaptic structures (pre2post and post2pre) which establish the direct mapping between the pre-synaptic and the post-synaptic neurons. pre2post contains the connected post-synaptic neuron indexes, in which pre2post[i] retrieves the post neuron ids projected from pre-synaptic neuron $i$. post2pre contains the pre-synaptic neuron indexes, in which post2pre[j] retrieves the pre-synaptic neuron ids which project to post-synaptic neuron $j$.


class SynVec3(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVec3, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre2post = self.conn.requires('pre2post')
# variables
self.g = bm.Variable(bm.zeros(post.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
# p1: update
for pre_i in range(self.pre.num):
if self.pre.spike[pre_i]:
for post_i in self.pre2post[pre_i]:
self.g[post_i] += self.weight
# p2: output
self.post.input[:] += self.g * (self.E - self.post.V)
In the SynVec3 class, we require a pre2post structure; then, in the "p1" code block, when pre-synaptic neuron pre_i emits a spike, the conductance self.g[post_i] of each connected post-synaptic neuron is increased.
t_syn_vec3 = run_net(neu_model=LIF, syn_model=SynVec3)
Compilation used 4.1941 s.
Start running ...
Run 10.0% used 0.006 s.
Run 20.0% used 0.014 s.
Run 30.0% used 0.020 s.
Run 40.0% used 0.029 s.
Run 50.0% used 0.038 s.
Run 60.0% used 0.045 s.
Run 70.0% used 0.051 s.
Run 80.0% used 0.059 s.
Run 90.0% used 0.067 s.
Run 100.0% used 0.075 s.
Simulation is done in 0.075 s.

The running speed gets a huge boost, which demonstrates the effectiveness of this kind of synaptic computation.
pre_slice and post_slice
However, it is not perfect. This is because pre2syn, post2syn, pre2post and post2pre are all data with the list type, which cannot be directly deployed to GPU devices. What GPU devices prefer are arrays.
To solve this problem, we can instead create a post_slice connection structure which stores the start and end positions in the synapse state for each connected post-synaptic neuron $j$. post_slice can be implemented by aligning the pre ids according to the sequential post ids $0, 1, 2, …$ (see the following illustrating figure). For each post neuron $j$, start, end = post_slice[j] retrieves the start/end positions of the connected synapse states.

Therefore, the updating logic based on pre2syn and post2syn (in the SynVec2 class) can alternatively be implemented with post_slice and pre_ids:
class SynVec4(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVec4, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre_ids, self.post_slice = self.conn.requires('pre_ids', 'post_slice')
self.num = len(self.pre_ids)
# variables
self.g = bm.Variable(bm.zeros(self.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
# p1: update
for syn_i in range(self.num):
pre_i = self.pre_ids[syn_i]
if self.pre.spike[pre_i]:
self.g[syn_i] += self.weight
# p2: output
for post_i in range(self.post.num):
start, end = self.post_slice[post_i]
# for syn_i in range(start, end):
self.post.input[post_i] += self.g[start: end].sum() * (self.E - self.post.V[post_i])
t_syn_vec4 = run_net(neu_model=LIF, syn_model=SynVec4)
Compilation used 3.3534 s.
Start running ...
Run 10.0% used 0.103 s.
Run 20.0% used 0.207 s.
Run 30.0% used 0.314 s.
Run 40.0% used 0.432 s.
Run 50.0% used 0.536 s.
Run 60.0% used 0.635 s.
Run 70.0% used 0.739 s.
Run 80.0% used 0.842 s.
Run 90.0% used 0.946 s.
Run 100.0% used 1.050 s.
Simulation is done in 1.050 s.

Similarly, a connection mapping pre_slice can also be implemented, in which for each pre-synaptic neuron $i$, start, end = pre_slice[i] retrieves the start/end positions of the connected synapse states.

Moreover, the updating logic based on pre2post (in the SynVec3 class) can also be replaced with pre_slice and post_ids:
class SynVec5(bp.TwoEndConn):
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVec5, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre_slice, self.post_ids = self.conn.requires('pre_slice', 'post_ids')
# variables
self.g = bm.Variable(bm.zeros(post.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
self.g[:] = self.integral(self.g, _t)
# p1: update
for pre_i in range(self.pre.num):
if self.pre.spike[pre_i]:
start, end = self.pre_slice[pre_i]
for post_i in self.post_ids[start: end]:
self.g[post_i] += self.weight
# p2: output
self.post.input[:] += self.g * (self.E - self.post.V)
t_syn_vec5 = run_net(neu_model=LIF, syn_model=SynVec5)
Compilation used 4.2100 s.
Start running ...
Run 10.0% used 0.005 s.
Run 20.0% used 0.011 s.
Run 30.0% used 0.017 s.
Run 40.0% used 0.024 s.
Run 50.0% used 0.031 s.
Run 60.0% used 0.038 s.
Run 70.0% used 0.045 s.
Run 80.0% used 0.051 s.
Run 90.0% used 0.057 s.
Run 100.0% used 0.064 s.
Simulation is done in 0.064 s.

Speed comparison
In this tutorial, we introduce nine different synaptic connection structures:
- conn_mat: The connection matrix with the shape of (pre_num, post_num).
- pre_ids: The connected pre-synaptic neuron indexes, a vector with the shape of syn_num.
- post_ids: The connected post-synaptic neuron indexes, a vector with the shape of syn_num.
- pre2syn: A list (with the length of pre_num) containing the synaptic indexes connected by each pre-synaptic neuron. pre2syn[i] denotes the synapse ids connected by pre-synaptic neuron $i$.
- post2syn: A list (with the length of post_num) containing the synaptic indexes connected by each post-synaptic neuron. post2syn[j] denotes the synapse ids connected by post-synaptic neuron $j$.
- pre2post: A list (with the length of pre_num) containing the post-synaptic indexes connected by each pre-synaptic neuron. pre2post[i] retrieves the post neurons connected by pre neuron $i$.
- post2pre: A list (with the length of post_num) containing the pre-synaptic indexes connected by each post-synaptic neuron. post2pre[j] retrieves the pre neurons connected by post neuron $j$.
- pre_slice: A two-dimensional matrix with the shape (pre_num, 2) storing the start and end positions in the synapse state for each pre-synaptic neuron $i$.
- post_slice: A two-dimensional matrix with the shape (post_num, 2) storing the start and end positions in the synapse state for each post-synaptic neuron $j$.
We illustrate their efficiency with a sparse, randomly connected E/I balanced network, COBA [1]. We summarize their speeds in the following comparison figure:
names = ['mat 1', 'mat 2', 'pre_ids', 'pre2syn', 'pre2post', 'post_slice', 'pre_slice']
times = [t_syn_mat1, t_syn_mat2, t_syn_vec1, t_syn_vec2, t_syn_vec3, t_syn_vec4, t_syn_vec5]
xs = list(range(len(times)))
def autolabel(rects):
"""Attach a text label above each bar in *rects*, displaying its height."""
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.3f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 0.5), # 0.5 points vertical offset
textcoords="offset points",
ha='center', va='bottom')
fig, gs = bp.visualize.get_figure(1, 1, 5, 8)
ax = fig.add_subplot(gs[0, 0])
rects = ax.bar(xs, times)
ax.set_xticks(xs)
ax.set_xticklabels(names)
ax.set_yscale('log')
plt.ylabel('Running Time [s]')
autolabel(rects)

However, the speed comparison presented here does not mean that vector-based connections are always better than matrix-based connections. Vector-based synaptic models are well suited to simulating sparse connections, whereas matrix-based synaptic models are best for dense connections, such as all-to-all connectivity.
Synaptic computing in JAX backend
The above examples are all illustrated under the 'numpy' backend in BrainPy. However, the JIT compilation in the 'jax' backend is much more constrained. Specifically, JAX transformations only support tensors. Therefore, for sparse synaptic connections, we highly recommend you to use the pre_ids and post_ids connectivities (see the above illustration).
The core problem of synaptic computation on the JAX backend is how to convert values among tensors of different shapes. Specifically, in the above exponential synapse model, we have three kinds of tensor shapes (see the following figure): tensors with the dimension of the pre-synaptic group, tensors with the dimension of the post-synaptic group, and tensors with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state, and grouping the synaptic variables into the post-synaptic current values, are the central problems of synaptic computation.
Here BrainPy provides two operators brainpy.math.pre2syn
and brainpy.math.syn2post
to convert vectors among different dimensions.
- brainpy.math.pre2syn() receives two arguments: pre_values (the variable of the pre-synaptic dimension) and pre_ids (the connected pre-synaptic neuron indexes).
- brainpy.math.syn2post() receives three arguments: syn_values (the variable with the synaptic size), post_ids (the connected post-synaptic neuron indexes) and post_num (the number of post-synaptic neurons).
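As a toy sketch of what these operators do (a hypothetical connectivity of 3 pre neurons, 2 post neurons and 4 synapses; the aggregation into the post-synaptic dimension is assumed here to be a sum), running under the 'jax' backend:
import brainpy.math as bm

pre_ids = bm.array([0, 0, 1, 2])    # pre neuron of each synapse
post_ids = bm.array([0, 1, 1, 1])   # post neuron of each synapse

pre_spike = bm.array([1., 0., 1.])                  # one value per pre neuron
syn_values = bm.pre2syn(pre_spike, pre_ids)         # gathered to the synapse dimension, expected shape (4,)
post_values = bm.syn2post(syn_values, post_ids, 2)  # aggregated to the post dimension, expected shape (2,)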
Based on these two operators, we can define the Exponential Synapse in JAX backend as:
class SynVecJax(bp.TwoEndConn):
target_backend = 'jax'
def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
super(SynVecJax, self).__init__(pre=pre, post=post, **kwargs)
# parameters
self.tau = tau
self.weight = weight
self.E = E
# connections
self.conn = conn(pre.size, post.size)
self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
self.num = len(self.pre_ids)
# variables
self.g = bm.Variable(bm.zeros(self.num))
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, g, t):
dg = - g / self.tau
return dg
def update(self, _t, _dt):
syn_sps = bm.pre2syn(self.pre.spike, self.pre_ids)
self.g[:] = self.integral(self.g, _t) + syn_sps
post_g = bm.syn2post(self.g, self.post_ids, self.post.num)
self.post.input += post_g * (self.E - self.post.V)
For more synapse models, please see BrainModels.
References:
[1] Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
Synaptic Connectivity
Contents
Built-in regular connections
Built-in random connections
Customize your connections
BrainPy provides several commonly used connection methods in the brainpy.connect module (see below). They all inherit from the base class brainpy.connect.Connector. Users can also customize their synaptic connectivity through class inheritance.
import brainpy as bp
import numpy as np
import matplotlib.pyplot as plt
Built-in regular connections
brainpy.connect.One2One
The neurons in the pre-synaptic neuron group only connect to the neurons at the same position in the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size. Otherwise, an error occurs.

conn = bp.connect.One2One()
brainpy.connect.All2All
All neurons of the post-synaptic population form connections with all neurons of the pre-synaptic population (dense connectivity). Users can choose whether to connect neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
brainpy.connect.GridFour
GridFour is the four nearest neighbors connection. Each neuron connects to its nearest four neurons.

conn = bp.connect.GridFour(include_self=False)
brainpy.connect.GridEight
GridEight is the eight nearest neighbors connection. Each neuron connects to its nearest eight neurons.

conn = bp.connect.GridEight(include_self=False)
brainpy.connect.GridN
GridN is also a nearest neighbors connection. Each neuron connects to its nearest $2N \cdot 2N$ neurons.

conn = bp.connect.GridN(N=2, include_self=False)
Built-in random connections
brainpy.connect.FixedProb
For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all-to-all projection, except that some synapses are not created, making the projection sparser.

conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=1234)
brainpy.connect.FixedPreNum
Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

conn = bp.connect.FixedPreNum(num=10, include_self=True, seed=1234)
brainpy.connect.FixedPostNum
Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

conn = bp.connect.FixedPostNum(num=10, include_self=True, seed=1234)
brainpy.connect.GaussianProb
Builds a Gaussian connection pattern between the two populations, where the connection probability decays according to a Gaussian function.
Specifically,
$$ p=\exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right) $$
where $(x, y)$ is the position of the pre-synaptic neuron and $(x_c,y_c)$ is the position of the post-synaptic neuron.
For example, in a $30 \times 30$ two-dimensional network, when $\beta = \frac{1}{2\sigma^2} = 0.1$, the connection pattern is shown as follows:

conn = bp.connect.GaussianProb(sigma=0.2, p_min=0.01, normalize=True, include_self=True, seed=1234)
brainpy.connect.GaussianWeight
Builds a Gaussian connection pattern between the two populations, where the weights decay with a Gaussian function.
Specifically,
$$w(x, y) = w_{max} \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right)$$
where $(x, y)$ is the position of the pre-synaptic neuron (normalized to [0,1]), $(x_c,y_c)$ is the position of the post-synaptic neuron (normalized to [0,1]), and $w_{max}$ is the maximum weight. In order to avoid creating useless synapses, $w_{min}$ can be set to restrict the creation of synapses to the cases where the value of the weight would be superior to $w_{min}$. Default is $0.01 w_{max}$.
def show_weight(pre_ids, post_ids, weights, geometry, neu_id):
height, width = geometry
ids = np.where(pre_ids == neu_id)[0]
post_ids = post_ids[ids]
weights = weights[ids]
X, Y = np.arange(height), np.arange(width)
X, Y = np.meshgrid(X, Y)
Z = np.zeros(geometry)
for id_, weight in zip(post_ids, weights):
h, w = id_ // width, id_ % width
Z[h, w] = weight
fig = plt.figure()
ax = fig.gca(projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap=plt.cm.coolwarm, linewidth=0, antialiased=False)
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.show()
conn = bp.connect.GaussianWeight(sigma=0.1, w_max=1., w_min=0.01,
normalize=True, include_self=True)
pre_geom = post_geom = (40, 40)
conn(pre_geom, post_geom)
pre_ids = conn.pre_ids
post_ids = conn.post_ids
weights = conn.weights
show_weight(pre_ids, post_ids, weights, pre_geom, 820)

brainpy.connect.DOG
Builds a Difference-Of-Gaussian (DOG) connection pattern between the two populations.
Mathematically,
$$ w(x, y) = w_{max}^+ \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma_+^2}\right) - w_{max}^- \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma_-^2}\right) $$
where weights smaller than $0.01 * abs(w_{max} - w_{min})$ are not created and self-connections are avoided by default (parameter allow_self_connections).
dog = bp.connect.DOG(sigmas=(0.08, 0.15), ws_max=(1.0, 0.7), w_min=0.01,
normalize=True, include_self=True)
h = 40
pre_geom = post_geom = (h, h)
dog(pre_geom, post_geom)
pre_ids = dog.pre_ids
post_ids = dog.post_ids
weights = dog.weights
show_weight(pre_ids, post_ids, weights, (h, h), h * h // 2 + h // 2)

brainpy.connect.SmallWorld
SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined as a network where the typical distance L between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes N in the network, that is:
$$ L\propto \log N $$
[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.
Currently, SmallWorld only supports a one-dimensional network with a ring structure. It receives four settings:
- num_neighbor: the number of the nearest neighbors to connect.
- prob: the probability of rewiring each edge.
- directed: whether the edges are directed ("directed=True") or undirected ("directed=False").
- include_self: whether a node is allowed to connect to itself.
conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
brainpy.connect.ScaleFreeBA
ScaleFreeBA is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA receives the following settings:
- m: Number of edges to attach from a new node to existing nodes.
- directed: whether the edges are directed ("directed=True") or undirected ("directed=False").
- seed: Indicator of random number generation state.
[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.
conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
brainpy.connect.ScaleFreeBADual
ScaleFreeBADual is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:
- p: The probability of attaching $m_1$ edges (as opposed to $m_2$ edges).
- m1: Number of edges to attach from a new node to existing nodes with probability $p$.
- m2: Number of edges to attach from a new node to existing nodes with probability $1-p$.
- directed: whether the edges are directed ("directed=True") or undirected ("directed=False").
- seed: Indicator of random number generation state.
[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.
conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
brainpy.connect.PowerLaw
PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:
- m: the number of random edges to add for each new node.
- p: the probability of adding a triangle after adding a random edge.
- directed: whether the edges are directed ("directed=True") or undirected ("directed=False").
- seed: Indicator of random number generation state.
[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.
conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)
Customize your connections
BrainPy also allows you to customize your model connections. Users only need to take care of the following aspects:
- Your connection class should inherit brainpy.connect.Connector.
- Initialize the conn_mat or pre_ids + post_ids synaptic structures.
- Provide the num_pre and num_post information.
In this way, based on the customized connection class, users can easily generate any other synaptic structures (such as pre2post, pre2syn, pre_slice_syn, etc.).
Here, let's take a simple connection as an example. In this example, we create a connection method which receives the user's hand-crafted index projections.
class IndexConn(bp.connect.Connector):
def __init__(self, i, j):
super(IndexConn, self).__init__()
# initialize the class via "pre_ids" and "post_ids"
self.pre_ids = bp.ops.as_tensor(i)
self.post_ids = bp.ops.as_tensor(j)
def __call__(self, pre_size, post_size):
self.num_pre = bp.size2len(pre_size)  # this is necessary when creating "pre2post",
# "pre2syn" etc. structures
self.num_post = bp.size2len(post_size)  # this is necessary when creating "post2pre",
# "post2syn" etc. structures
return self
Let’s try to use it.
conn = IndexConn(i=[0, 1, 2], j=[0, 0, 0])
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
array([[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
conn.requires('pre2post')
[array([0]),
array([0]),
array([0]),
array([], dtype=int32),
array([], dtype=int32)]
conn.requires('pre2syn')
[array([0]),
array([1]),
array([2]),
array([], dtype=int32),
array([], dtype=int32)]
conn.requires('pre_slice_syn')
array([[0, 1],
[1, 2],
[2, 3],
[3, 3],
[3, 3]])
Chaoming Wang (adaduo@outlook.com)
Update at 2021.04.16
Monitors and Inputs
BrainPy has a systematic naming system. Any model in BrainPy has a unique name, so nodes, integrators, and variables can be easily accessed even in a huge network. Based on this naming system, BrainPy provides a set of convenient monitoring and input supports. In this section, we are going to talk about them.
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt
Monitors
In BrainPy, any instance of brainpy.DynamicalSystem has a built-in monitor. Users can set up the monitor when initializing the brain object. For example, if you have the following HH neuron model,
class HH(bp.NeuGroup):
def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
super(HH, self).__init__(size=size, **kwargs)
# parameters
self.ENa = ENa
self.EK = EK
self.EL = EL
self.C = C
self.gNa = gNa
self.gK = gK
self.gL = gL
self.V_th = V_th
# variables
self.V = bm.Variable(bm.ones(self.num) * -65.)
self.m = bm.Variable(bm.ones(self.num) * 0.5)
self.h = bm.Variable(bm.ones(self.num) * 0.6)
self.n = bm.Variable(bm.ones(self.num) * 0.32)
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
# functions
self.integral = bp.odeint(self.derivative, method='exponential_euler')
def derivative(self, V, m, h, n, t, Iext):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
I_Na = (self.gNa * m ** 3 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
return dVdt, dmdt, dhdt, dndt
def update(self, _t, _dt):
V, m, h, n = self.integral(self.V, self.m, self.h, self.n, _t, self.input)
self.spike[:] = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V[:] = V
self.m[:] = m
self.h[:] = h
self.n[:] = n
self.input[:] = 0
The monitor can be set up when users create an HH neuron group.
The first method to initialize a monitor is using a list/tuple of strings.
# set up a monitor using a list of str
group1 = HH(size=10, monitors=['V', 'spike'])
type(group1.mon)
brainpy.simulation.monitor.Monitor
The initialized monitor is an instance of brainpy.Monitor. Therefore, users can also directly use the Monitor class to initialize a monitor.
# set up a monitor using brainpy.Monitor
group2 = HH(size=10, monitors=bp.Monitor(variables=['V', 'spike']))
Once we call the model's .run() function, the monitor automatically records the variable evolution in the corresponding model. Afterwards, users can access these variable trajectories with [model_name].mon.[variable_name]. The history times [model_name].mon.ts are also generated after the model finishes running. Let's see an example.
group1.run(100., inputs=('input', 10))
bp.visualize.line_plot(group1.mon.ts, group1.mon.V, show=True)

The monitor in group1
has recorded the evolution of V
. Therefore, it can be accessed by group1.mon.V
or equivalently group1.mon['V']
. Similarly, the recorded trajectory of variable spike
can also be obtained through group1.mon.spike
.
group1.mon.spike
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
The mechanism of monitors
We want to record HH.V and HH.spike; why does defining monitors=['V', 'spike'] during HH initialization work? How does brainpy.Monitor recognize which variables we want to trace?
Given a monitor target, BrainPy first checks whether the key is an attribute of the node which defines this monitor. For the monitor targets 'V' and 'spike', they are indeed attributes of the HH model. If not, BrainPy checks whether the key's host (a brainpy.DynamicalSystem object) is accessible in .nodes(), and then checks whether that host has the specified variable. For example, we can define a network and specify the monitor targets by their absolute paths.
net = bp.Network(HH(size=10, name='X'),
HH(size=20, name='Y'),
HH(size=30),
monitors=['X.V', 'Y.spike'])
net.build() # it's ok
In the above net, there are HH instances named "X" and "Y". Therefore, monitoring "X.V" and "Y.spike" succeeds.
However, in the following example, the node named "Z" is not accessible in the generated net, so the monitoring setup fails.
z = HH(size=30, name='Z')
net = bp.Network(HH(size=10), HH(size=20), monitors=['Z.V'])
# node "Z" can not be accessed in 'net.nodes()'
try:
net.build()
except Exception as e:
print(type(e).__name__, ":", e)
BrainPyError : Cannot find target Z.V in monitor of <brainpy.simulation.brainobjects.network.Network object at 0x7fc14c454610>, please check.
Note
BrainPy only supports monitoring Variables. This is because BrainPy assumes that monitoring a Variable's trajectory is meaningful, since Variables change dynamically, while anything not marked as a Variable will be compiled as a constant.
try:
HH(size=1, monitors=['gNa']).build()
except Exception as e:
print(type(e).__name__, ":", e)
BrainPyError : "gNa" in <__main__.HH object at 0x7fc14c3d1f40> is not a dynamically changed Variable, its value will not change, we cannot monitor its trajectory.
Note
The monitors in BrainPy only record the flattened tensor values. This means if your target variable is a matrix with the shape of (N, M)
, the resulting trajectory value in the monitor after running T
times will be a tensor with the shape of (T, N x M)
.
class MatrixVarModel(bp.DynamicalSystem):
def __init__(self, **kwargs):
super(MatrixVarModel, self).__init__(**kwargs)
self.a = bm.Variable(bm.zeros((4, 4)))
def update(self, _t, _dt):
self.a += 0.01
duration = 10
model = MatrixVarModel(monitors=['a'])
model.run(duration)
print(f'The expected shape of "model.mon.a" is: {(int(duration/bm.get_dt()), model.a.size)}')
print(f'The actual shape of "model.mon.a" is: {model.mon.a.shape}')
The expected shape of "model.mon.a" is: (100, 16)
The actual shape of "model.mon.a" is: (100, 16)
Monitor variables at the selected index
Sometimes we do not care about all the contents of a variable; we may only be interested in the values at selected indices. Moreover, for a huge network with a long simulation, monitors can consume a large amount of RAM. So, monitoring variables only at selected indices can be a good solution. Fortunately, BrainPy supports monitoring a part of the elements in a Variable with a tuple/dict format like this:
group3 = HH(
size=10,
monitors=[
'V', # monitor all values of Variable 'V'
('spike', [1, 2, 3]), # monitor values of Variable at index of [1, 2, 3]
]
)
group3.run(100., inputs=('input', 10.))
print(f'The monitor shape of "V" is (run length, variable size) = {group3.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group3.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)
Or, we can use a dictionary to specify the indices of interest for each variable:
group4 = HH(
size=10,
monitors={'V': None, # 'None' means all values will be monitored
'spike': [1, 2, 3]} # specify the interested index
)
group4.run(100., inputs=('input', 10.))
print(f'The monitor shape of "V" is (run length, variable size) = {group4.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group4.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)
Also, we can directly instantiate the brainpy.Monitor class:
class:
group5 = HH(
size=10,
monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])])
)
group5.run(100., inputs=('input', 10.))
print(f'The monitor shape of "V" is (run length, variable size) = {group5.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group5.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)
group6 = HH(
size=10,
monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]})
)
group6.run(100., inputs=('input', 10.))
print(f'The monitor shape of "V" is (run length, variable size) = {group6.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group6.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)
Note
When users want to record a small part of a variable whose dimension is greater than 1, they must provide the index positions in the flattened tensor, because brainpy.Monitor records flattened tensor values.
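For example, a minimal sketch reusing the MatrixVarModel defined above (the flattened index of element (i, j) in its (4, 4) variable a is i * 4 + j; the tuple index format follows the examples above):
model2 = MatrixVarModel(monitors=[('a', [0 * 4 + 1, 2 * 4 + 3])])  # elements (0, 1) and (2, 3)
model2.run(10.)
print(model2.mon.a.shape)  # expected: (100, 2)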
Monitor variables with a customized period
In a long simulation with a small time step dt, what we usually care about is the trend of the variable's evolution rather than its exact value at every time point (especially when dt is very small). For this scenario, we can initialize the monitors with the intervals item specification:
group7 = HH(
size=10,
monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]},
intervals={'V': None, 'spike': 1.}) # in 1 ms, we record 'spike' only once
)
# The above instantiation is equivalent to:
#
# group7 = HH(
# size=10, monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])],
# intervals=[None, 1.])
# )
In this example, we monitor the "spike" variable at indices [1, 2, 3] once every 1 ms.
group7.run(100., inputs=('input', 10.))
print(f'The monitor shape of "V" = {group7.mon.V.shape}')
print(f'The monitor shape of "spike" = {group7.mon.spike.shape}')
The monitor shape of "V" = (1000, 10)
The monitor shape of "spike" = (99, 3)
It is worth noting that for a monitored variable [variable_name] with a non-None intervals specification, a corresponding time item [variable_name].t will be generated in the monitor. This is because its time trajectory differs from the default time trajectory.
print('The shape of ["spike"]: ', group7.mon['spike'].shape)
print('The shape of ["spike.t"]: ', group7.mon['spike.t'].shape)
print('group7.mon["spike.t"]: ', group7.mon["spike.t"])
The shape of ["spike"]: (99, 3)
The shape of ["spike.t"]: (99,)
group7.mon["spike.t"]: [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. 89. 90.
91. 92. 93. 94. 95. 96. 97. 98. 99.]
Inputs
BrainPy also provides an inputs operation for each instance of brainpy.DynamicalSystem. It is carried out by calling the .run(..., inputs=xxx) function.
The aim of inputs is to mimic the input operations in experiments, such as Transcranial Magnetic Stimulation (TMS) and patch-clamp recording. inputs should have the format (target, value, [type, operation]), where
- target is the target variable into which the input is injected.
- value is the input value. It can be a scalar, a tensor, or an iterable object/function.
- type is the type of the input value. Two types of input are supported: fix and iter.
- operation is the input operation on the target variable. It should be one of { + , - , * , / , = }. If users do not provide this item explicitly, it defaults to '+', which means the target variable will be updated as val = val + input.
You can also give multiple inputs for different target variables, like:
inputs=[(target1, value1, [type1, op1]),
(target2, value2, [type2, op2]),
... ]
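As a concrete sketch with the HH group used in the monitor examples (the 'V' and 'input' variable names follow those examples; the explicit 'fix' type and '+' operation are optional):
group = HH(size=10, monitors=['V'])
group.run(100., inputs=[('input', 10.),            # defaults: type 'fix', operation '+'
                        ('V', 0.05, 'fix', '+')])  # type and operation given explicitly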
The mechanism of inputs
The mechanism of inputs is the same as that of monitors (see The mechanism of monitors). BrainPy first checks whether the user-specified target can be accessed by the relative path. If not, BrainPy separates the host name and the variable name, and further checks whether the host name is defined in .nodes() and whether the variable name can be accessed from the retrieved host. Therefore, in an input setting, the target can be given as an absolute or a relative path. For example, in the network model below,
class Model(bp.DynamicalSystem):
  def __init__(self, num_sizes, **kwargs):
    super(Model, self).__init__(**kwargs)
    self.l1 = HH(num_sizes[0], name='L')
    self.l2 = HH(num_sizes[1])
    self.l3 = HH(num_sizes[2])

  def update(self, _t, _dt):
    self.l1.update(_t, _dt)
    self.l2.update(_t, _dt)
    self.l3.update(_t, _dt)
model = Model([10, 20, 30])
model.run(100, inputs=[('L.V', 2.0), # access with the absolute path
('l2.V', 1), # access with the relative path
])
0.7689826488494873
inputs supports two types of data: fix and iter. The former means the data is static; the latter means the data is iterable, no matter whether the input value is a tensor or a function. Note that the 'iter' type must be stated explicitly.
# a tensor
model.run(100, inputs=('L.V', bm.ones(1000) * 2., 'iter'))
0.7576150894165039
# a function
def current():
  while True:
    yield 2.
model.run(100, inputs=('L.V', current(), 'iter'))
0.767667293548584
Current construction functions
Inputs are common in computational experiments, and we often need various kinds of them. BrainPy provides several convenient input functions to help users construct input currents.
section_input()
brainpy.inputs.section_input() is an updated version of the previous brainpy.inputs.constant_input() (see below).
Sometimes, we need input currents with different values in different periods. For example, if you want to get an input in which 0-100 ms is zero, 100-400 ms is value 1., and 400-500 ms is zero, then, you can define:
import numpy as np
import matplotlib.pyplot as plt

current, duration = bp.inputs.section_input(values=[0, 1., 0.],
                                            durations=[100, 300, 100],
                                            return_length=True)

def show(current, duration, title):
  ts = np.arange(0, duration, 0.1)
  plt.plot(ts, current)
  plt.title(title)
  plt.xlabel('Time [ms]')
  plt.ylabel('Current Value')
  plt.show()
show(current, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')

constant_input()
The brainpy.inputs.constant_input() function helps you to format constant currents over several periods.
For the input created above, we can define it again with constant_input() as follows:
current, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])
show(current, duration, '[(0, 100), (1, 300), (0, 100)]')

Another example is this:
current, duration = bp.inputs.constant_input([(-1, 10), (1, 3), (3, 30), (-0.5, 10)], dt=0.1)
show(current, duration, '[(-1, 10), (1, 3), (3, 30), (-0.5, 10)]')

spike_input()
brainpy.inputs.spike_input() helps you to construct an input like a series of short-time spikes. It receives the following settings:
- sp_times: The spike time points. Must be an iterable object, e.g., a list, tuple, or array.
- sp_lens: The length of each point current, mimicking the spike duration. It can be a scalar float to specify a uniform duration, or a list/tuple/array of time lengths with the same length as sp_times.
- sp_sizes: The current sizes. It can be a scalar value, or a list/tuple/array of spike current sizes with the same length as sp_times.
- duration: The total current duration.
- dt: The time step precision. The default is None (the default dt step will be used).
For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, and 300 ms, where each spike lasts 1 ms with a spike current of 0.5, you can use the following function:
current = bp.inputs.spike_input(
sp_times=[10, 20, 30, 200, 300],
sp_lens=1., # can be a list to specify the spike length at each point
sp_sizes=0.5, # can be a list to specify the spike current size at each point
duration=400.)
show(current, 400, 'Spike Input Example')

ramp_input()
brainpy.inputs.ramp_input() mimics a ramp or a step current input to the circuit. It receives the following settings:
- c_start: The minimum (or maximum) current size.
- c_end: The maximum (or minimum) current size.
- duration: The total duration.
- t_start: The time point at which the ramped current starts.
- t_end: The time point at which the ramped current ends. Default is None.
- dt: The current precision.
We illustrate the usage of brainpy.inputs.ramp_input() with two examples.
In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (1000 ms).
duration = 1000
current = bp.inputs.ramp_input(0, 1, duration)
show(current, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=0, $t_{end}$=None' % (duration))

In the second example, we increase the current size from 0. to 1. from the 200 ms to 800 ms.
duration, t_start, t_end = 1000, 200, 800
current = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)
show(current, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))

General property of current functions
There are several general properties for input construction functions.
Property 1: All input functions can automatically broadcast the current shapes if they are heterogeneous across different periods. For example, during period 1 we give a scalar input, during period 2 a vector input, and during period 3 a matrix input. The input functions will broadcast them to the maximum shape. For example,
current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
durations=[100, 300, 100])
current.shape
(5000, 3, 10)
Property 2: Every input function accepts a dt specification. If dt is not provided, the input functions will use the default dt of the whole BrainPy system.
bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.02).shape
(3000,)
bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.2).shape
(300,)
# the default 'dt' is 0.1
bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30]).shape
(600,)
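A constructed current can be fed into a model run through the 'iter' input type described above. A minimal sketch (assuming the HH group from the monitor examples and the default dt of 0.1 ms, so a 500 ms run consumes the 5000 generated values):
current = bp.inputs.section_input(values=[0., 10., 0.], durations=[100., 300., 100.])
group = HH(size=10, monitors=['V'])
group.run(500., inputs=('input', current, 'iter'))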
Dynamics Training
Build Artificial Neural Networks
Artificial neural networks in BrainPy are used to build dynamical systems. Here we only talk about how to build a neural network and how to train it.
The brainpy.simulation.layers module provides various classes representing the layers of a neural network. All of them are subclasses of the brainpy.simulation.layers.Module base class.
import brainpy as bp
bp.set_platform('cpu')
import brainpy.simulation.layers as nn
import brainpy.math.jax as bm
bp.math.use_backend('jax')
Creating a layer
A layer can be created as an instance of a brainpy.layers.Module subclass. For example, a dense layer can be created as follows:
l = nn.Dense(num_hidden=100, num_input=128)
type(l)
brainpy.simulation.layers.dense.Dense
This creates a dense layer with 100 units, connected to an input layer of dimension 128.
Creating a network
Chaining layer instances together like this will allow you to specify your desired network structure.
This can be done by inheriting from brainpy.layers.Module:
class MLP(nn.Module):
  def __init__(self, n_in, n_l1, n_l2, n_out):
    super(MLP, self).__init__()
    self.l1 = nn.Dense(num_hidden=n_l1, num_input=n_in)
    self.l2 = nn.Dense(num_hidden=n_l2, num_input=n_l1)
    self.l3 = nn.Dense(num_hidden=n_out, num_input=n_l2)

  def update(self, x):
    x = bm.relu(self.l1(x))
    x = bm.relu(self.l2(x))
    x = self.l3(x)
    return x
mlp1 = MLP(10, 50, 100, 2)
Or, using brainpy.layers.Sequential:
mlp2 = nn.Sequential(
l1=nn.Dense(num_hidden=50, num_input=10),
r1=nn.Activation('relu'),
l2=nn.Dense(num_hidden=100, num_input=50),
r2=nn.Activation('relu'),
l3=nn.Dense(num_hidden=2, num_input=100),
)
Naming a layer
For convenience, you can name a layer by specifying the name keyword argument:
l_hidden = nn.Dense(num_hidden=50, num_input=10, name='hidden_layer')
Initializing parameters
Many types of layers, such as brainpy.layers.Dense
, have trainable parameters. These are referred to by short names that match the conventions used in modern deep learning literature. For example, a weight matrix will usually be called w, and a bias vector will usually be b.
When creating a layer with trainable parameters, a TrainVar will be created and initialized automatically for each of them. You can optionally specify your own initialization strategy by using keyword arguments that match the parameter variable names. For example:
l = nn.Dense(num_hidden=50, num_input=10, w=bp.initialize.Normal(0.01))
The weight matrix w of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).
There are several ways to manually initialize parameters:
Tensors
If a tensor variable instance is provided, this is used unchanged as the parameter variable. For example:
w = bm.random.normal(0, 0.01, size=(10, 50))
nn.Dense(num_hidden=50, num_input=10, w=w)
<brainpy.simulation.layers.dense.Dense at 0x23cff9bb910>
callable
If a callable is provided (e.g. a function or a brainpy.initialize.Initializer instance), the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:
nn.Dense(num_hidden=50, num_input=10, w=bp.initialize.Normal(0.01))
<brainpy.simulation.layers.dense.Dense at 0x23cff9bf2b0>
Or, using a custom initialization function:
def init_w(shape):
  return bm.random.normal(0, 0.01, shape)

nn.Dense(num_hidden=50, num_input=10, w=init_w)
<brainpy.simulation.layers.dense.Dense at 0x23cff9ac670>
Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:
nn.Dense(num_hidden=50, num_input=10, b=None)
<brainpy.simulation.layers.dense.Dense at 0x23cff99fa30>
Setup a training
Here, we show an example of training an MLP to classify the MNIST images.
import numpy as np
import tensorflow as tf
# Data
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
num_train, num_test = X_train.shape[0], X_test.shape[0]
num_dim = bp.tools.size2num(X_train.shape[1:])
X_train = np.asarray(X_train.reshape((num_train, num_dim)) / 255.0, dtype=bm.float_)
X_test = np.asarray(X_test.reshape((num_test, num_dim)) / 255.0, dtype=bm.float_)
Y_train = np.asarray(Y_train.flatten(), dtype=bm.float_)
Y_test = np.asarray(Y_test.flatten(), dtype=bm.float_)
model = MLP(n_in=num_dim, n_l1=256, n_l2=128, n_out=10)
opt = bm.optimizers.Momentum(lr=1e-3, train_vars=model.train_vars())
gv = bm.grad(lambda X, Y: bm.losses.cross_entropy_loss(model(X), Y),
dyn_vars=model.vars(),
grad_vars=model.train_vars(),
return_value=True)
@bm.jit
@bm.function(nodes=(model, opt))
def train(x, y):
  grads, loss = gv(x, y)
  opt.update(grads=grads)
  return loss
predict = bm.jit(lambda X: bm.softmax(model(X)), dyn_vars=model.vars())
# Training
num_batch = 128
for epoch in range(30):
  # Train
  loss = []
  sel = np.arange(len(X_train))
  np.random.shuffle(sel)
  for it in range(0, X_train.shape[0], num_batch):
    l = train(X_train[sel[it:it + num_batch]], Y_train[sel[it:it + num_batch]])
    loss.append(l)
  # Eval
  test_predictions = predict(X_test).argmax(1)
  accuracy = np.array(test_predictions).flatten() == Y_test
  print(f'Epoch {epoch + 1:4d} Train Loss {np.mean(loss):.3f} Test Accuracy {100 * np.mean(accuracy):.3f}')
Epoch 1 Train Loss 1.212 Test Accuracy 86.410
Epoch 2 Train Loss 0.467 Test Accuracy 89.810
Epoch 3 Train Loss 0.367 Test Accuracy 90.670
Epoch 4 Train Loss 0.325 Test Accuracy 91.470
Epoch 5 Train Loss 0.298 Test Accuracy 92.220
Epoch 6 Train Loss 0.278 Test Accuracy 92.890
Epoch 7 Train Loss 0.261 Test Accuracy 93.140
Epoch 8 Train Loss 0.248 Test Accuracy 93.530
Epoch 9 Train Loss 0.235 Test Accuracy 93.810
Epoch 10 Train Loss 0.224 Test Accuracy 94.010
Epoch 11 Train Loss 0.214 Test Accuracy 94.100
Epoch 12 Train Loss 0.205 Test Accuracy 94.350
Epoch 13 Train Loss 0.196 Test Accuracy 94.540
Epoch 14 Train Loss 0.189 Test Accuracy 94.680
Epoch 15 Train Loss 0.182 Test Accuracy 94.910
Epoch 16 Train Loss 0.175 Test Accuracy 95.070
Epoch 17 Train Loss 0.169 Test Accuracy 95.190
Epoch 18 Train Loss 0.163 Test Accuracy 95.280
Epoch 19 Train Loss 0.157 Test Accuracy 95.410
Epoch 20 Train Loss 0.153 Test Accuracy 95.570
Epoch 21 Train Loss 0.148 Test Accuracy 95.760
Epoch 22 Train Loss 0.143 Test Accuracy 95.760
Epoch 23 Train Loss 0.139 Test Accuracy 95.930
Epoch 24 Train Loss 0.135 Test Accuracy 95.910
Epoch 25 Train Loss 0.131 Test Accuracy 96.150
Epoch 26 Train Loss 0.128 Test Accuracy 96.110
Epoch 27 Train Loss 0.124 Test Accuracy 96.310
Epoch 28 Train Loss 0.121 Test Accuracy 96.330
Epoch 29 Train Loss 0.118 Test Accuracy 96.410
Epoch 30 Train Loss 0.115 Test Accuracy 96.410
Create Custom Layers
To implement a custom layer in BrainPy, you have to write a Python class that subclasses brainpy.simulation.layers.Module and implements at least one method: update(). This method computes the output of the module given its input.
import brainpy as bp
bp.set_platform('cpu')
import brainpy.simulation.layers as nn
import brainpy.math.jax as bm
bp.math.use_backend('jax')
The following is an example implementation of a layer that multiplies its input by 2:
class DoubleLayer(nn.Module):
  def update(self, x):
    return 2 * x
This is all that’s required to implement a functioning custom module class in BrainPy.
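A quick usage sketch (assuming that calling a module instance dispatches to its update() method, as with the Dense layers above):
layer = DoubleLayer()
print(layer(bm.ones(3)))  # expected: [2., 2., 2.]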
A layer with parameters
If the layer has parameters, these should be initialized in the constructor. In BrainPy, we recommend you to mark parameters as brainpy.math.TrainVar.
To show how this can be used, here is a layer that multiplies its input by a matrix W (much like a typical fully connected layer in a neural network would). This matrix is a parameter of the layer. The shape of the matrix will be (num_input, num_hidden), where num_input is the number of input features and num_hidden has to be specified when the layer is created.
class DotLayer(nn.Module):
  def __init__(self, num_input, num_hidden, W=bp.initialize.Normal(), **kwargs):
    super(DotLayer, self).__init__(**kwargs)
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.W = bm.TrainVar(W([num_input, num_hidden]))

  def update(self, x):
    return bm.dot(x, self.W)
A few things are worth noting here: when overriding the constructor, we need to call the superclass constructor on the first line. This is important to ensure the layer functions properly. Note that we pass **kwargs; although this is not strictly necessary, it enables some other useful features, such as making it possible to give the layer a name:
l_dot = DotLayer(10, 50, name='my_dot_layer')
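The layer can then be applied to a batch of inputs, e.g. (a sketch; the batch shape is only illustrative):
x = bm.random.random((5, 10))   # a batch of 5 samples with 10 input features each
y = l_dot(x)
print(y.shape)                  # expected: (5, 50)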
A layer with multiple behaviors
Some layers can have multiple behaviors. For example, a layer implementing dropout should be able to be switched on or off. During training, we want it to apply dropout noise to its input and scale up the remaining values, but during evaluation we don’t want it to do anything.
For this purpose, the update() method takes optional keyword arguments (kwargs). When update() is called to compute an expression for the output of a network, all specified keyword arguments are passed to the update() methods of all layers in the network.
class Dropout(nn.Module):
  def __init__(self, prob, seed=None, **kwargs):
    super(Dropout, self).__init__(**kwargs)
    self.prob = prob
    self.rng = bm.random.RandomState(seed=seed)

  def update(self, x, **kwargs):
    if kwargs.get('train', True):
      keep_mask = self.rng.bernoulli(self.prob, x.shape)
      return bm.where(keep_mask, x / self.prob, 0.)
    else:
      return x
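A usage sketch (again assuming that calling the module dispatches to update() and forwards keyword arguments):
drop = Dropout(prob=0.8)      # keep each entry with probability 0.8
x = bm.ones(10)
print(drop(x, train=True))    # roughly 20% of the entries are zeroed, the rest scaled by 1/0.8
print(drop(x, train=False))   # evaluation mode: x is returned unchanged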
Dynamics Analysis
Dynamics Analysis (Symbolic)
Dynamics analysis is essential in neurodynamics, because blindly simulating nonlinear systems is likely to produce few or misleading results. For example, attractors and repellors can easily be obtained through simulation forward and backward in time, while saddles can be hard to find.
Currently, BrainPy supports two kinds of analysis methods (see the brainpy.analysis documentation):
Dynamics analysis with symbolic inference
The first class of analysis method supports neurodynamics analysis for low-dimensional dynamical systems. Specifically, BrainPy provides the following methods for dynamics analysis:
phase plane analysis for one-dimensional and two-dimensional systems;
codimension one and codimension two bifurcation analysis;
bifurcation analysis of the fast-slow system.
In this section, I will illustrate how to do neuron dynamics analysis in BrainPy and how BrainPy implements it.
import brainpy as bp
bp.__version__
'1.1.0'
Phase Plane Analysis
We provide a fundamental class PhasePlane to help users perform phase plane analysis of 1D/2D dynamical systems. Five methods are provided, which can help you to plot:
Fixed points
Nullcline (zero-growth isoclines)
Vector field
Limit cycles
Trajectory
Here, I will illustrate how to do phase plane analysis using the well-known FitzHugh-Nagumo neuron model.
FitzHugh-Nagumo model
The FitzHugh-Nagumo model is given by:
$$ \frac{dV}{dt} = V(1 - \frac{V^2}{3}) - w + I_{ext} \\ \tau \frac{dw}{dt} = V + a - b w $$
There are two variables $V$ and $w$, so this is a two-dimensional system with three parameters $a, b$ and $\tau$.
a = 0.7
b = 0.8
tau = 12.5
Vth = 1.9
@bp.odeint
def int_fhn(V, w, t, Iext):
  dw = (V + a - b * w) / tau
  dV = V - V * V * V / 3 - w + Iext
  return dV, dw
Phase plane analysis is implemented in brainpy.sym_analysis.PhasePlane. It receives the following parameters:
- integrals: The integral functions or the brainpy.DynamicalSystem instance to be analyzed.
- target_vars: The variables to be analyzed. It must be a dictionary with the format {var: variable range}.
- fixed_vars: The variables to be fixed (optional).
- pars_update: Parameters to update (optional).
brainpy.analysis.PhasePlane provides an interface to analyze the system's:
- nullclines: the zero-growth isoclines, such as $f(x, y)=0$ and $g(x, y)=0$.
- fixed points: the equilibrium points of the system, which are located where the nullclines intersect.
- vector field: the vector field of the system.
- trajectory: a simulated trajectory with the given fixed variables.
- limit cycles: the limit cycles of the system.
Here we perform a phase plane analysis with parameters $a=0.7, b=0.8, \tau=12.5$, and input $I_{ext} = 0.8$.
analyzer = bp.symbolic.PhasePlane(
int_fhn,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
pars_update={'Iext': 0.8})
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory([{'V': -2.8, 'w': -1.8}],
duration=100.,
show=True)

We can see an unstable node at the point (V=-0.27, w=0.53) inside a limit cycle. We can then run a simulation with the same parameters and initial values to see the periodic activity that corresponds to the limit cycle.
class FHN(bp.NeuGroup):
  def __init__(self, num, **kwargs):
    super(FHN, self).__init__(size=num, **kwargs)
    self.V = bp.math.Variable(bp.math.ones(num) * -2.8)
    self.w = bp.math.Variable(bp.math.ones(num) * -1.8)
    self.Iext = bp.math.Variable(bp.math.zeros(num))

  def update(self, _t, _dt):
    self.V[:], self.w[:] = int_fhn(self.V, self.w, _t, self.Iext)
    self.Iext[:] = 0.
group = FHN(1, monitors=['V', 'w'])
group.run(100., inputs=('Iext', 0.8))
bp.visualize.line_plot(group.mon.ts, group.mon.V, legend='v', )
bp.visualize.line_plot(group.mon.ts, group.mon.w, legend='w', show=True)

Note that fixed_vars can be used to specify the neuron model's state ST; it can also be used to fix the functional arguments of the integrators (like Iext in int_fhn()).
Bifurcation Analysis
Bifurcation analysis is implemented in brainpy.sym_analysis.Bifurcation, which supports codimension-1 and codimension-2 bifurcation analysis. Specifically, it receives the following parameter settings:
- integrals: The integral functions or the brainpy.DynamicalSystem instance to be analyzed.
- target_pars: The target parameters. Must be a dictionary with the format {par: parameter range}.
- target_vars: The target variables. Must be a dictionary with the format {var: variable range}.
- fixed_vars: The fixed variables.
- pars_update: The parameters to update.
Codimension 1 bifurcation analysis
We will first look at the codimension-1 bifurcation analysis of the model. For example, we vary the input $I_{ext}$ between 0 and 1 and see how the system changes its stability.
analyzer = bp.symbolic.Bifurcation(
int_fhn,
target_pars={'Iext': [0., 1.]},
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
numerical_resolution=0.001,
)
res = analyzer.plot_bifurcation(show=True)


Codimension 2 bifurcation analysis
We simultaneously change $I_{ext}$ and the parameter $a$.
analyzer = bp.symbolic.Bifurcation(
int_fhn,
target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
target_vars=dict(V=[-3, 3], w=[-3., 3.]),
numerical_resolution=0.01,
)
res = analyzer.plot_bifurcation(show=True)


Fast-Slow System Bifurcation
BrainPy also provides a tool for fast-slow system bifurcation analysis, brainpy.sym_analysis.FastSlowBifurcation. This method was proposed by John Rinzel [1, 2, 3] (J. Rinzel, 1985, 1986, 1987), who suggested that in a fast-slow dynamical system we can treat the slow variables as bifurcation parameters, and then study how different values of the slow variables affect the bifurcation of the fast subsystem.
brainpy.sym_analysis.FastSlowBifurcation is very useful for analyzing bursting neurons. I will illustrate this with the Hindmarsh-Rose model. The Hindmarsh–Rose model of neuronal activity aims to study the spiking-bursting behavior of the membrane potential observed in experiments on a single neuron. Its dynamics are governed by:
$$ \begin{aligned} \frac{d V}{d t} &= y - a V^3 + b V^2 - z + I \\ \frac{d y}{d t} &= c - d V^2 - y \\ \frac{d z}{d t} &= r (s (V - V_{rest}) - z) \end{aligned} $$
First of all, let’s define the Hindmarsh–Rose model with BrainPy.
a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9
@bp.odeint(method='rk4', dt=0.02)
def int_hr(x, y, z, t, Isyn):
  dx = y - a * x ** 3 + b * x * x - z + Isyn
  dy = c - d * x * x - y
  dz = r * (s * (x - x_r) - z)
  return dx, dy, dz
We can now start to analyze the underlying bifurcation mechanism.
analyzer = bp.symbolic.FastSlowBifurcation(
int_hr,
fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
slow_vars={'z': [-5., 5.]},
pars_update={'Isyn': 0.5},
numerical_resolution=0.001
)
analyzer.plot_bifurcation()
analyzer.plot_trajectory([{'x': 1., 'y': 0., 'z': -0.0}],
duration=100.,
show=True)


References:
[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.
[2] Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.
[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.
Dynamics Analysis (Numeric)
brainpy.base module
The base module for the whole BrainPy ecosystem.
- This module provides the most fundamental class, Base, and its associated helper classes Collector and ArrayCollector. For each instance of the Base class, users can retrieve all the variables (or trainable variables), integrators, and nodes.
- This module also provides a Function class to wrap user-defined functions. A function may use several nodes, and users can initialize a Function by providing the nodes used in the function. Unfortunately, the Function class cannot gather nodes automatically.
- This module provides io helper functions to help users save/load model states, or share their customized models with others.
For details, please see the following.
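For example, a minimal sketch of the io helpers (the HH group follows the earlier tutorials; the file name is hypothetical and the supported formats depend on the save_states/load_states implementation):
group = HH(size=10)
group.save_states('hh_states.h5')   # hypothetical file name; writes the model variables to disk
group.load_states('hh_states.h5')   # restores the saved variables later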
Base Class
- class brainpy.base.Base(name=None)[source]
The Base class for whole BrainPy ecosystem.
The subclasses of Base include:
- DynamicalSystem in brainpy.simulation.brainobjects.base.py
- Function in brainpy.base.function.py
- AutoGrad in brainpy.math.jax.autograd.py
- Optimizer in brainpy.math.jax.optimizers.py
- Scheduler in brainpy.math.jax.optimizers.py
- implicit_nodes = None
Used to wrap the implicit children nodes which cannot be accessed by self.xxx
- implicit_vars = None
Used to wrap the implicit variables which cannot be accessed by self.xxx
- ints(method='absolute')[source]
Collect all integrators in this node and the children nodes.
- Parameters
method (str) – The method to access the integrators.
- Returns
collector – The collection contained (the path, the integrator).
- Return type
- load_states(filename, verbose=False, check=False)[source]
Load the model states.
- Parameters
filename (str) – The filename which stores the model states.
- nodes(method='absolute', _paths=None)[source]
Collect all children nodes.
- Parameters
method (str) – The method to access the nodes.
_paths (set, Optional) – The data structure to solve the circular reference.
- Returns
gather – The collection contained (the path, the node).
- Return type
- save_states(filename, all_vars=None, **setting)[source]
Save the model states.
- Parameters
filename (str) – The file name which to store the model states.
- target_backend = None
Used to specify the target backend on which the model runs.
- train_vars(method='absolute')[source]
The shortcut for retrieving all trainable variables.
- Parameters
method (str) – The method to access the variables. Support ‘absolute’ and ‘relative’.
- Returns
gather – The collection contained (the path, the trainable variable).
- Return type
- unique_name(name=None, type=None)[source]
Get the unique name for this object.
- Parameters
name (str, optional) – The expected name. If None, the default unique name will be returned. Otherwise, the provided name will be checked to guarantee its uniqueness.
type (str, optional) – The type of this class, used for object naming.
- Returns
name – The unique name for this object.
- Return type
str
Function Wrapper
- class brainpy.base.Function(f, nodes=None, dyn_vars=None, name=None)[source]
The wrapper for Python functions.
- Parameters
f (function) – The function to wrap.
nodes (optional, Base, sequence of Base, dict) – The nodes in the defined function
f
.dyn_vars (optional, ndarray, sequence of ndarray, dict) – The dynamically changed variables.
name (optional, str) – The function name.
Collectors
- class brainpy.base.Collector[source]
A Collector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A Collector is ordered by insertion order. It is the object returned by Base.vars() and is used as input in many instances: optimizers, jit, etc.
- subset(var_type, judge_func=None)[source]
Get the subset of the (key, value) pair.
subset()
can be used to get a subset of some class:>>> import brainpy as bp >>> >>> some_collector = Collector() >>> >>> # get all trainable variables >>> some_collector.subset(bp.math.TrainVar) >>> >>> # get all JaxArray >>> some_collector.subset(bp.math.Variable)
or, it can be used to get a subset of integrators:
>>> # get all ODE integrators
>>> some_collector.subset(bp.ode.ODEIntegrator)
- Parameters
var_type (Any) – The type/class to match.
judge_func (optional, callable) –
- class brainpy.base.TensorCollector[source]
An ArrayCollector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. An ArrayCollector is ordered by insertion order. It is the object returned by DynamicalSystem.vars() and is used as input in many instances: optimizers, jit, etc.
Exporting and Loading
brainpy.math module
The math module for the whole BrainPy ecosystem.
This module provides basic mathematical operations, including:
- numpy-like array operations
- linear algebra functions
- random sampling functions
- discrete fourier transform functions
- compilations of jit, vmap, pmap for class objects
- automatic differentiation of grad, jacobian, hessian, etc. for class objects
- loss functions
- activation functions
- optimization classes
Details are in the following.
General Functions
- Get the current backend name.
- Set the numerical integrator precision.
- Get the numerical integrator precision.
- Set the defaults of the BrainPy system.
JAX backend Supports
Compilations
- JIT (Just-In-Time) Compilation for JAX backend.
Variables
- The pointer to specify the dynamical variable.
- The pointer to specify the trainable variable.
- The pointer to specify the parameter.
Functions
NumPy backend Supports
Compilations
- Just-In-Time (JIT) Compilation in NumPy backend.
Variables
- Variable.
- Trainable Variable.
- Parameter.
Functions
JAX Special Supports
Parallel Compilation
The parallel compilation tools for JAX backend.
Vectorize compilation is implemented by the ‘vmap()’ function
Parallel compilation is implemented by the ‘pmap()’ function
- Vectorization compilation in JAX backend.
- Parallel compilation in JAX backend.
Operators
- Computes the sum within segments of an array.
- Computes the product within segments of an array.
Control Flows
- Make a for-loop function, which iterates over inputs.
- Make a while-loop function.
- Make a condition (if-else) function.
Automatic Differentiation
- Automatic Gradient Computation in JAX backend.
- Jacobian computation (several variants).
- Hessian computation.
- Compute the gradients of trainable variables for the given object.
- Base Class to Compute Jacobian Matrix.
- class brainpy.math.jax.autograd.Grad(fun, grad_vars, grad_tree, vars, argnums=None, has_aux=None, holomorphic=False, allow_int=False, reduce_axes=(), return_value=False, name=None)[source]
Compute the gradients of trainable variables for the given object.
Examples
This example is that we return two auxiliary data, i.e.,
has_aux=True
.>>> import brainpy as bp >>> import brainpy.math as bm >>> >>> class Test(bp.Base): >>> def __init__(self): >>> super(Test, self).__init__() >>> self.a = bm.TrainVar(bp.math.ones(1)) >>> self.b = bm.TrainVar(bp.math.ones(1)) >>> >>> def __call__(self, c): >>> ab = self.a * self.b >>> ab2 = ab * 2 >>> vv = ab2 + c >>> return vv, (ab, ab2) >>> >>> test = Test() >>> test_grad = Grad(test, test.vars(), argnums=0, has_aux=True) >>> grads, outputs = test_grad(10.) >>> grads (DeviceArray(1., dtype=float32), {'Test3.a': DeviceArray([2.], dtype=float32), 'Test3.b': DeviceArray([2.], dtype=float32)}) >>> outputs (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32)))
This example is that we return two auxiliary data, i.e.,
has_aux=True
.>>> import brainpy as bp >>> >>> class Test(bp.dnn.Module): >>> def __init__(self): >>> super(Test, self).__init__() >>> self.a = bp.TrainVar(bp.math.ones(1)) >>> self.b = bp.TrainVar(bp.math.ones(1)) >>> >>> def __call__(self, c): >>> ab = self.a * self.b >>> ab2 = ab * 2 >>> vv = ab2 + c >>> return vv, (ab, ab2) >>> >>> test = Test() >>> test_grad = ValueAndGrad(test, argnums=0, has_aux=True) >>> outputs, grads = test_grad(10.) >>> grads (DeviceArray(1., dtype=float32), {'Test3.a': DeviceArray([2.], dtype=float32), 'Test3.b': DeviceArray([2.], dtype=float32)}) >>> outputs (JaxArray(DeviceArray(12., dtype=float32)), (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32))))
Activation Functions
This module provides commonly used activation functions.
Activation functions are a critical part of the design of a neural network. The choice of activation function in the hidden layer will control how well the network model learns the training dataset. The choice of activation function in the output layer will define the type of predictions the model can make.
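As a small sketch (bm.relu and bm.softmax are used in the MLP training example above; bm.array is assumed to be the numpy-like array constructor):
x = bm.array([-2., 0., 2.])
print(bm.relu(x))      # [0., 0., 2.]
print(bm.softmax(x))   # positive values summing to 1
The available activation functions are summarized below.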
- Continuously-differentiable exponential linear unit activation.
- Exponential linear unit activation function.
- Gaussian error linear unit activation function.
- Gated linear unit activation function.
- Hard tanh activation function.
- Hard Sigmoid activation function.
- Hard SiLU activation function.
- Leaky rectified linear unit activation function.
- Log-sigmoid activation function.
- Log-Softmax function.
- One-hot encodes the given indices.
- Normalizes an array by subtracting the mean and dividing by sqrt(var).
- Rectified Linear Unit 6 activation function.
- Sigmoid activation function.
- Soft-sign activation function.
- Softmax function.
- Softplus activation function.
- SiLU activation function.
- Scaled exponential linear unit activation.
Loss Functions
This module implements many commonly used loss functions.
The references used are included:
https://github.com/deepmind/optax/blob/master/optax/_src/loss.py
https://github.com/google/jaxopt/blob/main/jaxopt/_src/loss.py
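A small sketch of the cross-entropy loss used in the MLP training example above (the float label dtype follows that example):
logits = bm.random.random((4, 10))    # a batch of 4 predictions over 10 classes
targets = bm.array([1., 3., 5., 7.])  # the corresponding class labels
loss = bm.losses.cross_entropy_loss(logits, targets)
The available loss functions are summarized below.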
- This criterion combines …
- Creates a criterion that measures the mean absolute error (MAE) between each element in …
- Computes the L2 loss.
- Huber loss.
- Computes the mean absolute error between x and y.
- Computes the mean squared error between x and y.
- Computes the mean squared logarithmic error between y_true and y_pred.
Optimizers
- Optimizer: Base Optimizer Class.
- SGD: Stochastic gradient descent optimizer.
- Momentum: Momentum optimizer.
- MomentumNesterov: Nesterov accelerated gradient optimizer.
- Adagrad: Optimizer that implements the Adagrad algorithm.
- Adadelta: Optimizer that implements the Adadelta algorithm.
- RMSProp: Optimizer that implements the RMSprop algorithm.
- Adam: Optimizer that implements the Adam algorithm.
- Scheduler: The learning rate scheduler.
- class brainpy.math.jax.optimizers.Optimizer(train_vars, lr, name)[source]
Base Optimizer Class.
- target_backend = 'jax'
Used to specify the target backend on which the model runs.
- class brainpy.math.jax.optimizers.SGD(lr, train_vars, name=None)[source]
Stochastic gradient descent optimizer.
SGD performs a parameter update for training examples \(x\) and label \(y\):
\[\theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)\]
- class brainpy.math.jax.optimizers.Momentum(lr, train_vars, momentum=0.9, name=None)[source]
Momentum optimizer.
Momentum 1 is a method that helps accelerate SGD in the relevant direction and dampens oscillations. It does this by adding a fraction \(\gamma\) of the update vector of the past time step to the current update vector:
\[\begin{split} v_t &= \gamma v_{t-1} + \eta \nabla_\theta J(\theta) \\ \theta &= \theta - v_t \end{split}\]
- 1
Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks : The Official Journal of the International Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
- class brainpy.math.jax.optimizers.MomentumNesterov(lr, train_vars, momentum=0.9, name=None)[source]
Nesterov accelerated gradient optimizer 2.
\[\begin{split} v_t &= \gamma v_{t-1} + \eta \nabla_\theta J(\theta - \gamma v_{t-1}) \\ \theta &= \theta - v_t \end{split}\]
- class brainpy.math.jax.optimizers.Adagrad(lr, train_vars, epsilon=1e-06, name=None)[source]
Optimizer that implements the Adagrad algorithm.
Adagrad 3 is an optimizer with parameter-specific learning rates, which are adapted relative to how frequently a parameter gets updated during training. The more updates a parameter receives, the smaller the updates.
\[\theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}\]
where \(G_{t}\) contains the sum of the squares of the past gradients.
One of Adagrad’s main benefits is that it eliminates the need to manually tune the learning rate. Most implementations use a default value of 0.01 and leave it at that. Adagrad’s main weakness is its accumulation of the squared gradients in the denominator: Since every added term is positive, the accumulated sum keeps growing during training. This in turn causes the learning rate to shrink and eventually become infinitesimally small, at which point the algorithm is no longer able to acquire additional knowledge.
References
- 3
Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
- class brainpy.math.jax.optimizers.Adadelta(train_vars, lr=0.01, epsilon=1e-06, rho=0.95, name=None)[source]
Optimizer that implements the Adadelta algorithm.
Adadelta 4 optimization is a stochastic gradient descent method that is based on adaptive learning rate per dimension to address two drawbacks:
The continual decay of learning rates throughout training.
The need for a manually selected global learning rate.
Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving window of gradient updates, instead of accumulating all past gradients. This way, Adadelta continues learning even when many updates have been done. Compared to Adagrad, in the original version of Adadelta you don’t have to set an initial learning rate.
\[\begin{split}
\boldsymbol{s}_t &\leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
\boldsymbol{g}_t' &\leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
\boldsymbol{x}_t &\leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}_t', \\
\Delta\boldsymbol{x}_t &\leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}_t' \odot \boldsymbol{g}_t'.
\end{split}\]
\(\rho\) should be between 0 and 1. A value of \(\rho\) close to 1 will decay the moving average slowly and a value close to 0 will decay the moving average fast.
\(\rho = 0.95\) and \(\epsilon = 10^{-6}\) are suggested in the paper and reported to work for multiple datasets (MNIST, speech).
In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to keep it at this value. epsilon is important for the very first update (so the numerator does not become 0).
- 4
Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701
- class brainpy.math.jax.optimizers.RMSProp(lr, train_vars, epsilon=1e-06, rho=0.9, name=None)[source]
Optimizer that implements the RMSprop algorithm.
RMSprop 5 and Adadelta have both been developed independently around the same time stemming from the need to resolve Adagrad’s radically diminishing learning rates.
The gist of RMSprop is to:
Maintain a moving (discounted) average of the square of gradients
Divide the gradient by the root of this average
\[\begin{split} c_t &= \rho c_{t-1} + (1-\rho) g^2 \\ p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} g \end{split}\]
The centered version additionally maintains a moving average of the gradients, and uses that average to estimate the variance.
- 5
Tieleman, T. and Hinton, G. (2012): Neural Networks for Machine Learning, Lecture 6.5 - rmsprop. Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
- class brainpy.math.jax.optimizers.Adam(lr, train_vars, beta1=0.9, beta2=0.999, eps=1e-08, name=None)[source]
Optimizer that implements the Adam algorithm.
Adam 6 - a stochastic gradient descent method (SGD) that computes individual adaptive learning rates for different parameters from estimates of first- and second-order moments of the gradients.
- Parameters
beta1 (optional, float) – A positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).
beta2 (optional, float) – A positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).
eps (optional, float) – A positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).
name (optional, str) – The optimizer name.
References
- 6
Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
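For instance, a sketch of swapping the Momentum optimizer of the MLP training example for Adam (the constructor and update call follow the signature documented here and the usage shown in that example):
opt = bm.optimizers.Adam(lr=1e-3, train_vars=model.train_vars())
# inside the jitted training step, exactly as with Momentum:
# opt.update(grads=grads)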
Comparison Table
Here is a list of NumPy APIs and their corresponding BrainPy implementations. A '-' in a BrainPy column denotes that the implementation is not provided yet. We welcome contributions for these functions.
Multi-dimensional Array
Summary
- Number of NumPy functions: 56
- Number of functions covered by brainpy.math.numpy: 56
- Number of functions covered by brainpy.math.jax: 44
Array Operations
Summary
- Number of NumPy functions: 401
- Number of functions covered by brainpy.math.numpy: 225
- Number of functions covered by brainpy.math.jax: 225
Linear Algebra
Summary
- Number of NumPy functions: 20
- Number of functions covered by brainpy.math.numpy: 15
- Number of functions covered by brainpy.math.jax: 15
Discrete Fourier Transform
Summary
- Number of NumPy functions: 18
- Number of functions covered by brainpy.math.numpy: 18
- Number of functions covered by brainpy.math.jax: 18
Random Sampling
Summary
- Number of NumPy functions: 51
- Number of functions covered by brainpy.math.numpy: 26
- Number of functions covered by brainpy.math.jax: 26
brainpy.integrators module
This module provides numerical solvers for various differential equations, including:
ordinary differential equations (ODEs)
stochastic differential equations (SDEs)
For details, please see the following.
General Functions
- Numerical integration for ODEs.
- Numerical integration for SDEs.
- Set the default ODE numerical integrator method for differential equations.
- Get the default ODE numerical integrator method.
- Set the default SDE numerical integrator method for differential equations.
- Get the default SDE numerical integrator method.
- ODE Integrator.
- SDE Integrator.
Numerical Methods for ODEs
Numerical methods for ordinary differential equations.
Explicit Runge-Kutta Methods
This module provides explicit Runge-Kutta methods for ODEs.
Given an initial value problem specified as
\[\frac{d y}{d t}=f(t, y), \qquad y(t_{0})=y_{0},\]
let the step size be \(h > 0\). Then, the general schema of explicit Runge–Kutta methods is [1]_:
\[y_{n+1}=y_{n}+h \sum_{i=1}^{s} b_{i} k_{i},\]
where
\[k_{i}=f\left(t_{n}+c_{i}h,\; y_{n}+h\sum_{j=1}^{i-1}a_{ij}k_{j}\right).\]
To specify a particular method, one needs to provide the integer \(s\) (the number of stages), and the coefficients \(a_{ij}\) (for \(1 \le j < i \le s\)), \(b_i\) (for \(i = 1, 2, \cdots, s\)) and \(c_i\) (for \(i = 2, 3, \cdots, s\)).
The matrix \([a_{ij}]\) is called the Runge–Kutta matrix, while the \(b_i\) and \(c_i\) are known as the weights and the nodes. These data are usually arranged in a mnemonic device, known as a Butcher tableau (named after John C. Butcher):
A Taylor series expansion shows that the Runge–Kutta method is consistent if and only if
\[\sum_{j=1}^{i-1} a_{ij} = c_{i}, \qquad i = 2, \ldots, s.\]
Another popular condition for determining coefficients is:
For more details, please see references [2]_, [3]_ and [4]_.
- 1
Press, W. H., B. P. Flannery, S. A. Teukolsky, and W. T. Vetterling. “Section 17.1 Runge-Kutta Method.” Numerical Recipes: The Art of Scientific Computing (2007).
- 2
- 3
Butcher, John Charles. Numerical methods for ordinary differential equations. John Wiley & Sons, 2016.
- 4
Iserles, A., 2009. A first course in the numerical analysis of differential equations (No. 44). Cambridge university press.
- Explicit Runge–Kutta methods for ordinary differential equations.
- The Euler method for ODEs.
- Explicit midpoint method for ODEs.
- Heun's method for ODEs.
- Ralston's method for ODEs.
- Generic second order Runge-Kutta method for ODEs.
- Classical third-order Runge-Kutta method for ODEs.
- Heun's third-order method for ODEs.
- Ralston's third-order method for ODEs.
- Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).
- Classical fourth-order Runge-Kutta method for ODEs.
- Ralston's fourth-order method for ODEs.
- 3/8-rule fourth-order method for ODEs.
- class brainpy.integrators.ode.explicit_rk.ExplicitRKIntegrator(f, var_type=None, dt=None, name=None, show_code=False)[source]
Explicit Runge–Kutta methods for ordinary differential equation.
For the system,
\[\frac{d y}{d t}=f(t, y)\]Explicit Runge-Kutta methods take the form
\[\begin{split}k_{i}=f\left(t_{n}+c_{i}h,y_{n}+h\sum _{j=1}^{s}a_{ij}k_{j}\right) \\ y_{n+1}=y_{n}+h \sum_{i=1}^{s} b_{i} k_{i}\end{split}\]Each method listed on this page is defined by its Butcher tableau, which puts the coefficients of the method in a table as follows:
\[\begin{split}\begin{array}{c|cccc} c_{1} & a_{11} & a_{12} & \ldots & a_{1 s} \\ c_{2} & a_{21} & a_{22} & \ldots & a_{2 s} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ c_{s} & a_{s 1} & a_{s 2} & \ldots & a_{s s} \\ \hline & b_{1} & b_{2} & \ldots & b_{s} \end{array}\end{split}\]- Parameters
f (callable) – The derivative function.
show_code (bool) – Whether show the formatted code.
dt (float) – The numerical precision.
- class brainpy.integrators.ode.explicit_rk.Euler(f, var_type=None, dt=None, name=None, show_code=False)[source]
The Euler method for ODEs.
Also named as Forward Euler method, or Explicit Euler method.
Given an ODE system,
\[y'(t)=f(t,y(t)),\qquad y(t_{0})=y_{0},\]by using Euler method [1]_, we should choose a value \(h\) for the size of every step and set \(t_{n}=t_{0}+nh\). Now, one step of the Euler method from \(t_{n}\) to \(t_{n+1}=t_{n}+h\) is:
\[y_{n+1}=y_{n}+hf(t_{n},y_{n}).\]Note that the method increments a solution through an interval \(h\) while using derivative information from only the beginning of the interval. As a result, the step’s error is \(O(h^2)\).
Geometric interpretation
Illustration of the Euler method. The unknown curve is in blue, and its polygonal approximation is in red [2]_:
Derivation
There are several ways to get Euler method [2]_.
The first is to consider the Taylor expansion of the function \(y\) around \(t_{0}\):
\[y(t_{0}+h)=y(t_{0})+hy'(t_{0})+{\frac {1}{2}}h^{2}y''(t_{0})+O(h^{3}).\]where \(y'(t_0)=f(t_0,y)\). We ignore the quadratic and higher-order terms, then we get Euler method. The Taylor expansion is used below to analyze the error committed by the Euler method, and it can be extended to produce Runge–Kutta methods.
The second way is to replace the derivative with the forward finite difference formula:
\[y'(t_{0})\approx {\frac {y(t_{0}+h)-y(t_{0})}{h}}.\]The third method is integrate the differential equation from \(t_{0}\) to \(t_{0}+h\) and apply the fundamental theorem of calculus to get:
\[y(t_{0}+h)-y(t_{0})=\int _{t_{0}}^{t_{0}+h}f(t,y(t))\,\mathrm {d} t \approx hf(t_{0},y(t_{0})).\]Note
Euler method is a first order numerical procedure for solving ODEs with a given initial value. The lack of stability and accuracy limits its popularity mainly to use as a simple introductory example of a numeric solution method.
References
- 1
W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. Numerical Recipes in FORTRAN: The Art of Scientific Computing, 2nd ed. Cambridge, England: Cambridge University Press, p. 710, 1992.
- 2
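A minimal sketch of selecting this scheme through bp.odeint (assuming 'euler' is an accepted method name, by analogy with the method='rk4' example earlier in this documentation):
import brainpy as bp

@bp.odeint(method='euler', dt=0.01)
def int_x(x, t, tau):
  dx = -x / tau
  return dx

x_next = int_x(1.0, 0., 10.)   # one Euler step of dx/dt = -x / tau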
- class brainpy.integrators.ode.explicit_rk.MidPoint(f, var_type=None, dt=None, name=None, show_code=False)[source]
Explicit midpoint method for ODEs.
Also known as the modified Euler method [1]_.
The midpoint method is a one-step method for numerically solving the differential equation given by:
\[y'(t) = f(t, y(t)), \quad y(t_0) = y_0 .\]The formula of the explicit midpoint method is:
\[y_{n+1} = y_n + hf\left(t_n+\frac{h}{2},y_n+\frac{h}{2}f(t_n, y_n)\right).\]Therefore, the Butcher tableau of the midpoint method is:
\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 \\ \hline & 0 & 1 \end{array}\end{split}\]Derivation
Compared to the slope formula of Euler method \(y'(t) \approx \frac{y(t+h) - y(t)}{h}\), the midpoint method use
\[y'\left(t+\frac{h}{2}\right) \approx \frac{y(t+h) - y(t)}{h},\]The reason why we use this, please see the following geometric interpretation. Then, we get
\[y(t+h) \approx y(t) + hf\left(t+\frac{h}{2},y\left(t+\frac{h}{2}\right)\right).\]However, we do not know \(y(t+h/2)\). The solution is then to use a Taylor series expansion exactly as the Euler method to solve:
\[y\left(t + \frac{h}{2}\right) \approx y(t) + \frac{h}{2}y'(t)=y(t) + \frac{h}{2}f(t, y(t)),\]Finally, we can get the final step function:
\[y(t + h) \approx y(t) + hf\left(t + \frac{h}{2}, y(t) + \frac{h}{2}f(t, y(t))\right).\]Geometric interpretation
In the basic Euler’s method, the tangent of the curve at \((t_{n},y_{n})\) is computed using \(f(t_{n},y_{n})\). The next value \(y_{n+1}\) is found where the tangent intersects the vertical line \(t=t_{n+1}\). However, if the second derivative is only positive between \(t_{n}\) and \(t_{n+1}\), or only negative, the curve will increasingly veer away from the tangent, leading to larger errors as \(h\) increases.
Compared with the Euler method, the midpoint method uses the tangent at the midpoint (the upper, green line segment in the following figure [2]_), which most likely gives a more accurate approximation of the curve in that interval.
Although this midpoint tangent cannot be calculated exactly, we can estimate the midpoint value of \(y(t)\) by using the original Euler's method. Finally, the improved tangent is used to calculate the value of \(y_{n+1}\) from \(y_{n}\). This last step is represented by the red chord in the diagram.
Note
Note that the red chord is not exactly parallel to the green segment (the true tangent), due to the error in estimating the value of \(y(t)\) at the midpoint.
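As a minimal illustration of the update rule derived above, here is a pure-Python sketch (not part of BrainPy; the test problem is an arbitrary choice):

# Explicit midpoint: y_{n+1} = y_n + h * f(t_n + h/2, y_n + h/2 * f(t_n, y_n))
def midpoint_step(f, t, y, h):
    k1 = f(t, y)                                  # Euler estimate of the slope at t_n
    return y + h * f(t + h / 2, y + h / 2 * k1)   # slope re-evaluated at the midpoint

# One step of y' = -y from y(0) = 1 with h = 0.1
print(midpoint_step(lambda t, y: -y, 0.0, 1.0, 0.1))   # 0.905, vs. exp(-0.1) ~ 0.9048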
References
- 1
Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003.
- 2
- class brainpy.integrators.ode.explicit_rk.Heun2(f, var_type=None, dt=None, name=None, show_code=False)[source]
Heun’s method for ODEs.
This method is named after Karl Heun [1]_. It is also known as the explicit trapezoid rule, improved Euler’s method, or modified Euler’s method.
Given ODEs with a given initial value,
\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]the two-stage Heun’s method is formulated as:
\[\tilde{y}_{n+1} = y_n + h f(t_n,y_n)\]\[y_{n+1} = y_n + \frac{h}{2}[f(t_n, y_n) + f(t_{n+1},\tilde{y}_{n+1})],\]where \(h\) is the step size and \(t_{n+1}=t_n+h\).
Therefore, the Butcher tableau of the two-stage Heun’s method is:
\[\begin{split}\begin{array}{c|cc} 0.0 & 0.0 & 0.0 \\ 1.0 & 1.0 & 0.0 \\ \hline & 0.5 & 0.5 \end{array}\end{split}\]Geometric interpretation
In brainpy.integrators.ode.midpoint(), we have already seen that the Euler method has a large estimation error because it uses the line tangent to the function at the beginning of the interval \(t_n\) as an estimate of the slope of the function over the interval \((t_n, t_{n+1})\). To address this problem, Heun's method considers the tangent lines to the solution curve at both ends of the interval (\(t_n\) and \(t_{n+1}\)): one (\(f(t_n, y_n)\)) underestimates, and the other (\(f(t_{n+1},\tilde{y}_{n+1})\), approximated using Euler's method) overestimates the ideal vertical coordinate. The ideal point lies approximately halfway between the erroneous overestimation and underestimation, at the average of the two slopes.
\[\begin{split}\begin{aligned} {\text{Slope}}_{\text{left}}=&f(t_{n},y_{n}) \\ {\text{Slope}}_{\text{right}}=&f(t_{n}+h,y_{n}+hf(t_{n},y_{n})) \\ {\text{Slope}}_{\text{ideal}}=&{\frac {1}{2}}({\text{Slope}}_{\text{left}}+{\text{Slope}}_{\text{right}}) \end{aligned}\end{split}\]References
- 1
Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003.
- class brainpy.integrators.ode.explicit_rk.Ralston2(f, var_type=None, dt=None, name=None, show_code=False)[source]
Ralston’s method for ODEs.
Ralston’s method is a second-order method with two stages and a minimum local error bound.
Given ODEs with a given initial value,
\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]Ralston's second-order method is given by
\[y_{n+1}=y_{n}+\frac{h}{4} f\left(t_{n}, y_{n}\right)+ \frac{3 h}{4} f\left(t_{n}+\frac{2 h}{3}, y_{n}+\frac{2 h}{3} f\left(t_{n}, y_{n}\right)\right)\]Therefore, the corresponding Butcher tableau is:
\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ 2 / 3 & 2 / 3 & 0 \\ \hline & 1 / 4 & 3 / 4 \end{array}\end{split}\]
- class brainpy.integrators.ode.explicit_rk.RK2(f, beta=0.6666666666666666, var_type=None, dt=None, name=None, show_code=False)[source]
Generic second order Runge-Kutta method for ODEs.
Derivation
In brainpy.integrators.ode.midpoint(), brainpy.integrators.ode.heun2(), and brainpy.integrators.ode.ralston2(), we have already seen that the first-order Euler method brainpy.integrators.ode.euler() has a large estimation error. Here, we seek to derive a generic second-order Runge-Kutta method [1]_ for the given ODE system with a given initial value,
\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]we want to get a generic solution:
\[\begin{align} y_{n+1} &= y_{n} + h \left ( a_1 K_1 + a_2 K_2 \right ) \tag{1} \end{align}\]where \(a_1\) and \(a_2\) are some weights to be determined, and \(K_1\) and \(K_2\) are derivatives on the form:
\[\begin{align} K_1 & = f(t_n,y_n) \qquad \text{and} \qquad K_2 = f(t_n + p_1 h,y_n + p_2 K_1 h ) \tag{2} \end{align}\]By substitution of (2) in (1) we get:
\[\begin{align} y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h f(t_n + p_1 h,y_n + p_2 K_1 h) \tag{3} \end{align}\]Now, we may find a Taylor-expansion of \(f(t_n + p_1 h, y_n + p_2 K_1 h )\)
\[\begin{split}\begin{align} f(t_n + p_1 h, y_n + p_2 K_1 h ) &= f + p_1 h f_t + p_2 K_1 h f_y + \text{h.o.t.} \nonumber \\ & = f + p_1 h f_t + p_2 h f f_y + \text{h.o.t.} \tag{4} \end{align}\end{split}\]where \(f_t \equiv \frac{\partial f}{\partial t}\) and \(f_y \equiv \frac{\partial f}{\partial y}\).
By substitution of (4) in (3) we eliminate the implicit dependency of \(y_{n+1}\)
\[\begin{split}\begin{align} y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h \left (f + p_1 h f_t + p_2 h f f_y \right ) \nonumber \\ &= y_{n} + (a_1 + a_2) h f + \left (a_2 p_1 f_t + a_2 p_2 f f_y \right) h^2 \tag{5} \end{align}\end{split}\]In the next, we try to get the second order Taylor expansion of the solution:
\[\begin{align} y(t_n+h) = y_n + h y' + \frac{h^2}{2} y'' + O(h^3) \tag{6} \end{align}\]where the second order derivative is given by
\[\begin{align} y'' = \frac{d^2 y}{dt^2} = \frac{df}{dt} = \frac{\partial{f}}{\partial{t}} \frac{dt}{dt} + \frac{\partial{f}}{\partial{y}} \frac{dy}{dt} = f_t + f f_y \tag{7} \end{align}\]Substitution of (7) into (6) yields:
\[\begin{align} y(t_n+h) = y_n + h f + \frac{h^2}{2} \left (f_t + f f_y \right ) + O(h^3) \tag{8} \end{align}\]Finally, in order to approximate (8) by using (5), we get the generic second order Runge-Kutta method, where
\[\begin{split}\begin{aligned} a_1 + a_2 = 1 \\ a_2 p_1 = \frac{1}{2} \\ a_2 p_2 = \frac{1}{2}. \end{aligned}\end{split}\]Furthermore, letting \(p_1=\beta\), we get
\[\begin{split}\begin{aligned} p_1 = & \beta \\ p_2 = & \beta \\ a_2 = &\frac{1}{2\beta} \\ a_1 = &1 - \frac{1}{2\beta} . \end{aligned}\end{split}\]Therefore, the corresponding Butcher tableau is:
\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ \beta & \beta & 0 \\ \hline & 1 - \frac{1}{2\beta} & \frac{1}{2\beta} \end{array}\end{split}\]References
- 1
Chapra, Steven C., and Raymond P. Canale. Numerical methods for engineers. Vol. 1221. New York: Mcgraw-hill, 2011.
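The derivation above fixes \(p_1=p_2=\beta\), \(a_2=\frac{1}{2\beta}\), and \(a_1=1-\frac{1}{2\beta}\). Here is a pure-Python sketch of the resulting one-parameter family (not part of BrainPy; the default \(\beta=2/3\) mirrors the beta default in the RK2 signature above):

# Generic two-stage RK update:
#   k1 = f(t_n, y_n)
#   k2 = f(t_n + beta*h, y_n + beta*h*k1)
#   y_{n+1} = y_n + h * ((1 - 1/(2*beta)) * k1 + 1/(2*beta) * k2)
def rk2_step(f, t, y, h, beta=2 / 3):
    k1 = f(t, y)
    k2 = f(t + beta * h, y + beta * h * k1)
    return y + h * ((1 - 1 / (2 * beta)) * k1 + k2 / (2 * beta))

# beta = 1/2 -> midpoint, beta = 1 -> Heun's method, beta = 2/3 -> Ralston's method.
# Both calls print 0.905 here because the test equation y' = -y is linear.
print(rk2_step(lambda t, y: -y, 0.0, 1.0, 0.1, beta=1.0))   # Heun
print(rk2_step(lambda t, y: -y, 0.0, 1.0, 0.1, beta=0.5))   # midpoint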
- class brainpy.integrators.ode.explicit_rk.RK3(f, var_type=None, dt=None, name=None, show_code=False)[source]
Classical third-order Runge-Kutta method for ODEs.
For the given initial value problem \(y'(t) = f(t,y);\, y(t_0) = y_0\), the third-order Runge-Kutta method is given by:
\[y_{n+1} = y_n + 1/6 ( k_1 + 4 k_2 + k_3),\]where
\[\begin{split}k_1 = h f(t_n, y_n), \\ k_2 = h f(t_n + h / 2, y_n + k_1 / 2), \\ k_3 = h f(t_n + h, y_n - k_1 + 2 k_2 ),\end{split}\]where \(t_n = t_0 + n h.\)
The error term is \(O(h^4)\), i.e., the method is correct up to the third-order term in the Taylor series expansion.
The Taylor series expansion is \(y(t+h)=y(t)+\frac{k_{1}}{6}+\frac{2 k_{2}}{3}+\frac{k_{3}}{6}+O\left(h^{4}\right)\).
The corresponding Butcher tableau is:
\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 \\ 1 & -1 & 2 & 0 \\ \hline & 1 / 6 & 2 / 3 & 1 / 6 \end{array}\end{split}\]
- class brainpy.integrators.ode.explicit_rk.Heun3(f, var_type=None, dt=None, name=None, show_code=False)[source]
Heun’s third-order method for ODEs.
It has the characteristics of:
method stage = 3
method order = 3
Butcher Tables:
\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 3 & 1 / 3 & 0 & 0 \\ 2 / 3 & 0 & 2 / 3 & 0 \\ \hline & 1 / 4 & 0 & 3 / 4 \end{array}\end{split}\]
- class brainpy.integrators.ode.explicit_rk.Ralston3(f, var_type=None, dt=None, name=None, show_code=False)[source]
Ralston’s third-order method for ODEs.
It has the characteristics of:
method stage = 3
method order = 3
Butcher Tables:
\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 \\ 3 / 4 & 0 & 3 / 4 & 0 \\ \hline & 2 / 9 & 1 / 3 & 4 / 9 \end{array}\end{split}\]References
- 1
Ralston, Anthony (1962). “Runge-Kutta Methods with Minimum Error Bounds”. Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0
- class brainpy.integrators.ode.explicit_rk.SSPRK3(f, var_type=None, dt=None, name=None, show_code=False)[source]
Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).
It has the characteristics of:
method stage = 3
method order = 3
Butcher Tables:
\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 / 2 & 1 / 4 & 1 / 4 & 0 \\ \hline & 1 / 6 & 1 / 6 & 2 / 3 \end{array}\end{split}\]
- class brainpy.integrators.ode.explicit_rk.RK4(f, var_type=None, dt=None, name=None, show_code=False)[source]
Classical fourth-order Runge-Kutta method for ODEs.
For the given initial value problem of
\[{\frac {dy}{dt}}=f(t,y),\quad y(t_{0})=y_{0}.\]The fourth-order RK method is formulated as:
\[\begin{split}\begin{aligned} y_{n+1}&=y_{n}+{\frac {1}{6}}h\left(k_{1}+2k_{2}+2k_{3}+k_{4}\right),\\ t_{n+1}&=t_{n}+h\\ \end{aligned}\end{split}\]for \(n = 0, 1, 2, 3, \cdots\), using
\[\begin{split}\begin{aligned} k_{1}&=\ f(t_{n},y_{n}),\\ k_{2}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{1}}{2}}\right),\\ k_{3}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{2}}{2}}\right),\\ k_{4}&=\ f\left(t_{n}+h,y_{n}+hk_{3}\right). \end{aligned}\end{split}\]Here \(y_{n+1}\) is the RK4 approximation of \(y(t_{n+1})\), and the next value (\(y_{n+1}\)) is determined by the present value (\(y_{n}\)) plus the weighted average of four increments, where each increment is the product of the size of the interval, \(h\), and an estimated slope specified by function \(f\) on the right-hand side of the differential equation.
\(k_{1}\) is the slope at the beginning of the interval, using \(y\) (Euler’s method);
\(k_{2}\) is the slope at the midpoint of the interval, using \(y\) and \(k_{1}\);
\(k_{3}\) is again the slope at the midpoint, but now using \(y\) and \(k_{2}\);
\(k_{4}\) is the slope at the end of the interval, using \(y\) and \(k_{3}\).
The RK4 method is a fourth-order method, meaning that the local truncation error is on the order of \(O(h^{5})\), while the total accumulated error is on the order of \(O(h^{4})\).
The corresponding Butcher tableau is:
\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 & 0 \\ 1 / 2 & 0 & 1 / 2 & 0 & 0 \\ 1 & 0 & 0 & 1 & 0 \\ \hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6 \end{array}\end{split}\]References
- 1
Lambert, J. D. and Lambert, D. Ch. 5 in Numerical Methods for Ordinary Differential Systems: The Initial Value Problem. New York: Wiley, 1991.
- 2
Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. “Runge-Kutta Method” and “Adaptive Step Size Control for Runge-Kutta.” §16.1 and 16.2 in Numerical Recipes in FORTRAN: The Art of Scientific Computing, 2nd ed. Cambridge, England: Cambridge University Press, pp. 704-716, 1992.
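A compact pure-Python sketch of one classical RK4 step following the formulas above (not part of BrainPy; the test problem is illustrative):

# Classical RK4 step
def rk4_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h * k1 / 2)
    k3 = f(t + h / 2, y + h * k2 / 2)
    k4 = f(t + h, y + h * k3)
    return y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6

print(rk4_step(lambda t, y: -y, 0.0, 1.0, 0.1))   # 0.9048375, vs. exp(-0.1) ~ 0.9048374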
- class brainpy.integrators.ode.explicit_rk.Ralston4(f, var_type=None, dt=None, name=None, show_code=False)[source]
Ralston’s fourth-order method for ODEs.
It has the characteristics of:
method stage = 4
method order = 4
Butcher Tables:
\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ .4 & .4 & 0 & 0 & 0 \\ .45573725 & .29697761 & .15875964 & 0 & 0 \\ 1 & .21810040 & -3.05096516 & 3.83286476 & 0 \\ \hline & .17476028 & -.55148066 & 1.20553560 & .17118478 \end{array}\end{split}\]References
- 1
Ralston, Anthony (1962). “Runge-Kutta Methods with Minimum Error Bounds”. Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0
- class brainpy.integrators.ode.explicit_rk.RK4Rule38(f, var_type=None, dt=None, name=None, show_code=False)[source]
3/8-rule fourth-order method for ODEs.
A slight variation of “the” Runge–Kutta method is also due to Kutta in 1901 [1]_ and is called the 3/8-rule. The primary advantage this method has is that almost all of the error coefficients are smaller than in the popular method, but it requires slightly more FLOPs (floating-point operations) per time step.
It has the characteristics of:
method stage = 4
method order = 4
Butcher Tables:
\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ 1 / 3 & 1 / 3 & 0 & 0 & 0 \\ 2 / 3 & -1 / 3 & 1 & 0 & 0 \\ 1 & 1 & -1 & 1 & 0 \\ \hline & 1 / 8 & 3 / 8 & 3 / 8 & 1 / 8 \end{array}\end{split}\]References
- 1
Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving ordinary differential equations I: Nonstiff problems, Berlin, New York: Springer-Verlag, ISBN 978-3-540-56670-0.
Adaptive Runge-Kutta Methods
This module provides adaptive Runge-Kutta methods for ODEs.
Adaptive methods are designed to produce an estimate of the local truncation error of a single Runge–Kutta step. This is done by having two methods, one with order \(p\) and one with order \(p-1\). These methods are interwoven, i.e., they have common intermediate steps. Thanks to this, estimating the error has little or negligible computational cost compared to a step with the higher-order method.
During the integration, the step size is adapted such that the estimated error stays below a user-defined threshold: If the error is too high, a step is repeated with a lower step size; if the error is much smaller, the step size is increased to save time. This results in an (almost) optimal step size, which saves computation time. Moreover, the user does not have to spend time on finding an appropriate step size.
The lower-order step is given by
\[y^*_{n+1} = y_n + h\sum_{i=1}^s b^*_i k_i,\]where the \(k_{i}\) are the same as for the higher-order method. Then the error is
\[e_{n+1} = y_{n+1} - y^*_{n+1} = h\sum_{i=1}^s (b_i - b^*_i) k_i,\]which is \(O(h^{p})\).
The Butcher tableau for this kind of method is extended to give the values of \(b_{i}^{*}\):
\[\begin{split}\begin{array}{c|cccc} c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ \vdots & \vdots & \vdots& \ddots& \vdots\\ c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ \hline & b_1 & b_2 & \dots & b_s\\ & b_1^* & b_2^* & \dots & b_s^*\\ \end{array}\end{split}\]
For more details, please check [1]_ [2]_ [3]_.
- 1
- 2
Press, W.H., Press, W.H., Flannery, B.P., Teukolsky, S.A., Vetterling, W.T., Flannery, B.P. and Vetterling, W.T., 1989. Numerical recipes in Pascal: the art of scientific computing (Vol. 1). Cambridge university press.
- 3
Press, W. H., & Teukolsky, S. A. (1992). Adaptive Stepsize Runge‐Kutta Integration. Computers in Physics, 6(2), 188-191.
- AdaptiveRKIntegrator: Adaptive Runge-Kutta method for ordinary differential equations.
- RKF12: The Fehlberg RK1(2) method for ODEs.
- RKF45: The Runge–Kutta–Fehlberg method for ODEs.
- DormandPrince: The Dormand–Prince method for ODEs.
- CashKarp: The Cash–Karp method for ODEs.
- BogackiShampine: The Bogacki–Shampine method for ODEs.
- HeunEuler: The Heun–Euler method for ODEs.
- class brainpy.integrators.ode.adaptive_rk.AdaptiveRKIntegrator(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
Adaptive Runge-Kutta method for ordinary differential equations.
The embedded methods are designed to produce an estimate of the local truncation error of a single Runge-Kutta step and, as a result, allow the error to be controlled with an adaptive step size. This is done by having two methods in the tableau, one with order \(p\) and one with order \(p-1\).
The lower-order step is given by
\[y^*_{n+1} = y_n + h\sum_{i=1}^s b^*_i k_i,\]where the \(k_{i}\) are the same as for the higher order method. Then the error is
\[e_{n+1} = y_{n+1} - y^*_{n+1} = h\sum_{i=1}^s (b_i - b^*_i) k_i,\]which is \(O(h^{p})\). The Butcher Tableau for this kind of method is extended to give the values of \(b_{i}^{*}\)
\[\begin{split}\begin{array}{c|cccc} c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ \vdots & \vdots & \vdots& \ddots& \vdots\\ c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ \hline & b_1 & b_2 & \dots & b_s\\ & b_1^* & b_2^* & \dots & b_s^*\\ \end{array}\end{split}\]- Parameters
f (callable) – The derivative function.
show_code (bool) – Whether show the formatted code.
dt (float) – The numerical precision.
adaptive (bool) – Whether use the adaptive updating.
tol (float) – The error tolerance.
var_type (str) – The variable type.
- class brainpy.integrators.ode.adaptive_rk.RKF12(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Fehlberg RK1(2) method for ODEs.
The Fehlberg RK1(2) method embeds two methods, of orders 1 and 2.
It has the characteristics of:
method stage = 2
method order = 1
Butcher Tables:
\[\begin{split}\begin{array}{l|lll} 0 & & & \\ 1 / 2 & 1 / 2 & & \\ 1 & 1 / 256 & 255 / 256 & \\ \hline & 1 / 512 & 255 / 256 & 1 / 512 \\ & 1 / 256 & 255 / 256 & 0 \end{array}\end{split}\]References
- 1
Fehlberg, E. (1969-07-01). “Low-order classical Runge-Kutta formulas with stepsize control and their application to some heat transfer problems”
- class brainpy.integrators.ode.adaptive_rk.RKF45(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Runge–Kutta–Fehlberg method for ODEs.
The method presented in Fehlberg’s 1969 paper has been dubbed the RKF45 method, and is a method of order \(O(h^4)\) with an error estimator of order \(O(h^5)\). The novelty of Fehlberg’s method is that it is an embedded method from the Runge–Kutta family, meaning that identical function evaluations are used in conjunction with each other to create methods of varying order and similar error constants.
Its Butcher table is:
\[\begin{split}\begin{array}{l|lllll} 0 & & & & & & \\ 1 / 4 & 1 / 4 & & & & \\ 3 / 8 & 3 / 32 & 9 / 32 & & \\ 12 / 13 & 1932 / 2197 & -7200 / 2197 & 7296 / 2197 & \\ 1 & 439 / 216 & -8 & 3680 / 513 & -845 / 4104 & & \\ 1 / 2 & -8 / 27 & 2 & -3544 / 2565 & 1859 / 4104 & -11 / 40 & \\ \hline & 16 / 135 & 0 & 6656 / 12825 & 28561 / 56430 & -9 / 50 & 2 / 55 \\ & 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0 \end{array}\end{split}\]References
- 1
https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
- 2
Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step size control and their application to some heat transfer problems . NASA Technical Report 315. https://ntrs.nasa.gov/api/citations/19690021375/downloads/19690021375.pdf
- class brainpy.integrators.ode.adaptive_rk.DormandPrince(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Dormand–Prince method for ODEs.
The Dormand–Prince (DOPRI) method is an explicit method for solving ordinary differential equations (Dormand & Prince 1980). The Dormand–Prince method has seven stages, but it uses only six function evaluations per step because it has the FSAL (First Same As Last) property: the last stage is evaluated at the same point as the first stage of the next step. Dormand and Prince chose the coefficients of their method to minimize the error of the fifth-order solution. This is the main difference from the Fehlberg method, which was constructed so that the fourth-order solution has a small error. For this reason, the Dormand–Prince method is more suitable when the higher-order solution is used to continue the integration, a practice known as local extrapolation (Shampine 1986; Hairer, Nørsett & Wanner 2008, pp. 178–179).
Its Butcher table is:
\[\begin{split}\begin{array}{l|llllll} 0 & \\ 1 / 5 & 1 / 5 & & & \\ 3 / 10 & 3 / 40 & 9 / 40 & & & \\ 4 / 5 & 44 / 45 & -56 / 15 & 32 / 9 & & \\ 8 / 9 & 19372 / 6561 & -25360 / 2187 & 64448 / 6561 & -212 / 729 & \\ 1 & 9017 / 3168 & -355 / 33 & 46732 / 5247 & 49 / 176 & -5103 / 18656 & \\ 1 & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & \\ \hline & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & 0 \\ & 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40 \end{array}\end{split}\]References
- 1
- 2
Dormand, J. R.; Prince, P. J. (1980), “A family of embedded Runge-Kutta formulae”, Journal of Computational and Applied Mathematics, 6 (1): 19–26, doi:10.1016/0771-050X(80)90013-3.
- class brainpy.integrators.ode.adaptive_rk.CashKarp(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Cash–Karp method for ODEs.
The Cash–Karp method was proposed by Professor Jeff R. Cash from Imperial College London and Alan H. Karp from the IBM Scientific Center. It uses six function evaluations to calculate fourth- and fifth-order accurate solutions. The difference between these solutions is then taken to be the error of the (fourth-order) solution. This error estimate is very convenient for adaptive step-size integration algorithms.
It has the characteristics of:
method stage = 6
method order = 4
Butcher Tables:
\[\begin{split}\begin{array}{l|lllll} 0 & & & & & & \\ 1 / 5 & 1 / 5 & & & & & \\ 3 / 10 & 3 / 40 & 9 / 40 & & & \\ 3 / 5 & 3 / 10 & -9 / 10 & 6 / 5 & & \\ 1 & -11 / 54 & 5 / 2 & -70 / 27 & 35 / 27 & & \\ 7 / 8 & 1631 / 55296 & 175 / 512 & 575 / 13824 & 44275 / 110592 & 253 / 4096 & \\ \hline & 37 / 378 & 0 & 250 / 621 & 125 / 594 & 0 & 512 / 1771 \\ & 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4 \end{array}\end{split}\]References
- 1
- 2
J. R. Cash, A. H. Karp. “A variable order Runge-Kutta method for initial value problems with rapidly varying right-hand sides”, ACM Transactions on Mathematical Software 16: 201-222, 1990. doi:10.1145/79505.79507
- class brainpy.integrators.ode.adaptive_rk.BogackiShampine(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Bogacki–Shampine method for ODEs.
The Bogacki–Shampine method was proposed by Przemysław Bogacki and Lawrence F. Shampine in 1989 (Bogacki & Shampine 1989). The Bogacki–Shampine method is a Runge–Kutta method of order three with four stages with the First Same As Last (FSAL) property, so that it uses approximately three function evaluations per step. It has an embedded second-order method which can be used to implement adaptive step size.
It has the characteristics of:
method stage = 4
method order = 3
Butcher Tables:
\[\begin{split}\begin{array}{l|llll} 0 & & & & \\ 1 / 2 & 1 / 2 & & & \\ 3 / 4 & 0 & 3 / 4 & & \\ 1 & 2 / 9 & 1 / 3 & 4 / 9 & \\ \hline & 2 / 9 & 1 / 3 & 4 / 9 & 0 \\ & 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8 \end{array}\end{split}\]References
- 1
https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method
- 2
Bogacki, Przemysław; Shampine, Lawrence F. (1989), “A 3(2) pair of Runge–Kutta formulas”, Applied Mathematics Letters, 2 (4): 321–325, doi:10.1016/0893-9659(89)90079-7
- class brainpy.integrators.ode.adaptive_rk.HeunEuler(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]
The Heun–Euler method for ODEs.
The simplest adaptive Runge–Kutta method involves combining Heun’s method, which is order 2, with the Euler method, which is order 1.
It has the characteristics of:
method stage = 2
method order = 1
Butcher Tables:
\[\begin{split}\begin{array}{c|cc} 0&\\ 1& 1 \\ \hline & 1/2& 1/2\\ & 1 & 0 \end{array}\end{split}\]
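To illustrate how the embedded pair above drives step-size adaptation, here is a pure-Python sketch of one error-controlled Heun–Euler step (not part of BrainPy; the tolerance and the simple halve/double update rule are illustrative choices, real controllers typically use a smoother formula):

# Heun-Euler embedded pair (tableau above):
#   order-2 solution : y_{n+1}  = y_n + h*(k1 + k2)/2
#   order-1 solution : y*_{n+1} = y_n + h*k1
#   error estimate   : |y_{n+1} - y*_{n+1}| = h*|k2 - k1|/2
def heun_euler_step(f, t, y, h, tol=1e-4):
    while True:
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        err = abs(h * (k2 - k1) / 2)
        if err <= tol:                                 # accept the higher-order solution
            h_next = 2 * h if err < tol / 4 else h     # grow the step if the error is tiny
            return t + h, y + h * (k1 + k2) / 2, h_next
        h = h / 2                                      # reject and retry with a smaller step

t, y, h = 0.0, 1.0, 0.5
while t < 2.0:
    t, y, h = heun_euler_step(lambda s, u: -u, t, y, min(h, 2.0 - t))
print(y)   # close to exp(-2) ~ 0.1353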
Exponential Integrators
This module provides exponential integrators for ODEs.
Exponential integrators are a large class of methods from numerical analysis based on the exact integration of the linear part of the initial value problem. Because the linear part is integrated exactly, this can help to mitigate the stiffness of a differential equation.
We consider initial value problems of the form
\[u'(t)=f(u(t)),\qquad u(t_{0})=u_{0},\]which can be decomposed as
\[u'(t)=Lu(t)+N(u(t)),\qquad u(t_{0})=u_{0},\]where \(L={\frac {\partial f}{\partial u}}\) (the Jacobian of \(f\)) is composed of linear terms, and \(N=f(u)-Lu\) is composed of the non-linear terms.
This procedure enjoys the advantage, in each step, that \({\frac {\partial N_{n}}{\partial u}}(u_{n})=0\). This considerably simplifies the derivation of the order conditions and improves the stability when integrating the nonlinearity \(N(u(t))\).
Exact integration of this problem from time 0 to a later time \(t\) can be performed using matrix exponentials to define an integral equation for the exact solution:
\[u(t)=e^{Lt}u_{0}+\int _{0}^{t}e^{L(t-\tau )}N\left(u(\tau )\right)\,\mathrm{d}\tau .\]This representation of the exact solution is also called the variation-of-constants formula. In the case of \(N\equiv 0\), this formulation is the exact solution to the linear differential equation.
Exponential Rosenbrock methods
Exponential Rosenbrock methods were shown to be very efficient in solving large systems of stiff ODEs. Applying the variation-of-constants formula gives the exact solution at time \(t_{n+1}\) in terms of the numerical solution \(u_n\) as
\[u(t_{n+1})=e^{h_{n}L}u_{n}+\int _{0}^{h_{n}}e^{(h_{n}-\tau )L}N(u(t_{n}+\tau ))\,\mathrm{d}\tau , \tag{1}\]
where \(h_n=t_{n+1}-t_n\).
The idea now is to approximate the integral in (1) by some quadrature rule with nodes \(c_{i}\) and weights \(b_{i}(h_{n}L)\) (\(1\leq i\leq s\)). This yields the following class of \(s\)-stage explicit exponential Rosenbrock methods:
\[U_{ni}=e^{c_{i}h_{n}L}u_{n}+h_{n}\sum_{j=1}^{i-1}a_{ij}(h_{n}L)N(U_{nj}),\]\[u_{n+1}=e^{h_{n}L}u_{n}+h_{n}\sum_{i=1}^{s}b_{i}(h_{n}L)N(U_{ni}),\]where \(U_{ni}\approx u(t_{n}+c_{i}h_{n})\).
The coefficients \(a_{ij}(z),b_{i}(z)\) are usually chosen as linear combinations of the entire functions \(\varphi _{k}(c_{i}z),\varphi _{k}(z)\), respectively, where
\[\varphi _{k}(z)=\int _{0}^{1}e^{(1-\theta )z}{\frac {\theta ^{k-1}}{(k-1)!}}\,\mathrm{d}\theta ,\qquad k\geq 1,\qquad \varphi _{0}(z)=e^{z}.\]
By introducing the difference \(D_{ni}=N(U_{ni})-N(u_{n})\), they can be reformulated in a more efficient way for implementation as
\[U_{ni}=u_{n}+c_{i}h_{n}\varphi _{1}(c_{i}h_{n}L)f(u_{n})+h_{n}\sum _{j=2}^{i-1}a_{ij}(h_{n}L)D_{nj},\]\[u_{n+1}=u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}b_{i}(h_{n}L)D_{ni},\]
where \(\varphi_{1}(z)=\frac{e^z-1}{z}\).
In order to implement this scheme with adaptive step size, one can consider, for the purpose of local error estimation, the following embedded methods
\[{\bar {u}}_{n+1}=u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}{\bar {b}}_{i}(h_{n}L)D_{ni},\]
which use the same stages \(U_{ni}\) but with weights \({\bar {b}}_{i}\).
For convenience, the coefficients of the explicit exponential Rosenbrock methods together with their embedded methods can be represented by using the so-called reduced Butcher tableau as follows:
- 1
- 2
Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
- ExponentialEuler: The exponential Euler method for ODEs.
- class brainpy.integrators.ode.exponential.ExponentialEuler(f, var_type=None, dt=None, name=None, show_code=False)[source]
The exponential Euler method for ODEs.
The simplest exponential Rosenbrock method is the exponential Rosenbrock–Euler scheme, which has order 2.
For an ODE equation of the form
\[u^{\prime}=f(u), \quad u(0)=u_{0}\]its scheme is given by
\[u_{n+1}= u_{n}+h \varphi(hL) f (u_{n})\]where \(L=f^{\prime}(u_{n})\) and \(\varphi(z)=\frac{e^{z}-1}{z}\).
For a linear ODE system \(u^{\prime} = Au + B\), the above equation is equal to \(u_{n+1}= u_{n}e^{hA}-B/A(1-e^{hA})\), which is the exact solution for this ODE system.
- Parameters
f (function) – The derivative function.
dt (optional, float) – The numerical precision.
var_type (optional, str) – The variable type.
show_code (bool) – Whether show the code.
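A scalar pure-Python sketch of the scheme above (not part of BrainPy; the test equation \(u'=-2u+1\) is an illustrative linear problem, for which the step reproduces the exact solution as stated above):

import math

# Exponential (Rosenbrock-)Euler step for a scalar ODE u' = f(u):
#   u_{n+1} = u_n + h * phi(h*L) * f(u_n),  L = f'(u_n),  phi(z) = (exp(z) - 1)/z
def phi(z):
    return (math.exp(z) - 1.0) / z if abs(z) > 1e-12 else 1.0

def exp_euler_step(f, dfdu, u, h):
    L = dfdu(u)
    return u + h * phi(h * L) * f(u)

f = lambda u: -2.0 * u + 1.0     # linear problem, so L = -2 everywhere
dfdu = lambda u: -2.0
print(exp_euler_step(f, dfdu, 1.0, 0.5))   # 0.68394...
print(0.5 + 0.5 * math.exp(-2.0 * 0.5))    # closed-form solution at t = 0.5, identical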
Error Analysis of Numerical Methods
In order to identify the essential properties of numerical methods, we define basic notions 1.
For the given ODE system
\[y'(t) = f(t, y(t)), \qquad y(t_0)=y_0,\]we define \(y(t_n)\) as the solution of the IVP evaluated at \(t=t_n\), and \(y_n\) as a numerical approximation of \(y(t_n)\) at the same location obtained by a generic numerical scheme (no matter explicit, implicit, or multi-step):
\[y_{n+1} = y_n + h \, \phi(t_n, y_n, h),\]where \(h\) is the discretization step for \(t\), i.e., \(h=t_{n+1}-t_n\), and \(\phi(t_n,y_n,h)\) is the increment function. We say that the defined numerical scheme is consistent if \(\lim_{h\to0} \phi(t,y,h) = \phi(t,y,0) = f(t,y)\).
Then, the approximation error is defined as
\[e_{n} = y(t_n) - y_n .\]The absolute error is defined as
\[|e_{n}| = |y(t_n) - y_n| .\]The relative error is defined as
\[\frac{|y(t_n) - y_n|}{|y(t_n)|} .\]
The exact differential operator is defined as
The approximate differential operator is defined as
Finally, the local truncation error (LTE) is defined as
In practice, the evaluation of the exact solution for different \(t\) around \(t_n\) (required by \(L_a\)) is performed using a Taylor series expansion.
Finally, we can state that a scheme is \(p\)-th order accurate by examining its LTE and observing its leading term,
\[\mathrm{LTE} = C h^{p+1} + \mathrm{H.O.T.},\]
where \(C\) is a constant, independent of \(h\), and \(H.O.T.\) are the higher order terms of the LTE.
Example: LTE for Euler’s scheme
Consider the IVP defined by \(y' = \lambda y\), with initial condition \(y(0)=1\).
The approximation operator for Euler’s scheme is
then the LTE can be computed by
where we assume \(y_n = y(t_n)\).
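A quick numerical check of this example (a pure-Python sketch, not part of BrainPy): for \(y'=\lambda y\) with \(y_n=y(t_n)\), the difference between the exact solution after one step and the Euler prediction scales as \(h^2\), consistent with a first-order scheme. Note that the sign and normalization convention of the LTE may differ from the definitions above; the sketch only verifies the \(O(h^2)\) scaling.

import math

# One-step defect of Euler's scheme on y' = lam*y with y_n = y(t_n):
#   y(t_n + h) - (y_n + h*lam*y_n) = (exp(lam*h) - 1 - lam*h) * y_n
lam, y_n = -1.0, 1.0
for h in (0.1, 0.05, 0.025):
    defect = (math.exp(lam * h) - 1.0 - lam * h) * y_n
    print(h, defect, defect / h**2)   # the ratio tends to lam**2 * y_n / 2 = 0.5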
Numerical Methods for SDEs
Numerical methods for stochastic differential equations.
- Euler
- Heun
- Milstein
- ExponentialEuler: First order, explicit exponential Euler method.
- SRK1W1: Order 2.0 weak SRK methods for SDEs with scalar Wiener process.
- SRK2W1: Order 1.5 strong SRK methods for SDEs with scalar noise.
- class brainpy.integrators.sde.Euler(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
- class brainpy.integrators.sde.Heun(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
- class brainpy.integrators.sde.Milstein(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
- class brainpy.integrators.sde.ExponentialEuler(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
First order, explicit exponential Euler method.
For a SDE equation of the form
\[d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \quad y(0)=y_{0}\]its scheme is given by [1]_
\[\begin{split}y_{n+1} & =e^{\Delta t A}(y_{n}+ g(y_n)\Delta W_{n})+\varphi(\Delta t A) F(y_{n}) \Delta t \\ &= y_n + \Delta t \varphi(\Delta t A) f(y) + e^{\Delta t A}g(y_n)\Delta W_{n}\end{split}\]where \(\varphi(z)=\frac{e^{z}-1}{z}\).
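A scalar pure-Python sketch of the update above (not part of BrainPy; the Ornstein–Uhlenbeck-like test equation and its parameters are illustrative choices):

import math, random

# Stochastic exponential Euler step for dy = (A*y + F(y))dt + g(y)dW:
#   y_{n+1} = exp(dt*A) * (y_n + g(y_n)*dW_n) + phi(dt*A) * F(y_n) * dt
def phi(z):
    return (math.exp(z) - 1.0) / z if abs(z) > 1e-12 else 1.0

def sde_exp_euler_step(A, F, g, y, dt):
    dW = random.gauss(0.0, math.sqrt(dt))   # Wiener increment ~ N(0, dt)
    return math.exp(dt * A) * (y + g(y) * dW) + phi(dt * A) * F(y) * dt

# Example: dy = (-y + 1)dt + 0.1 dW
y = 0.0
for _ in range(1000):
    y = sde_exp_euler_step(A=-1.0, F=lambda u: 1.0, g=lambda u: 0.1, y=y, dt=0.01)
print(y)   # fluctuates around the deterministic fixed point y = 1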
References
- 1
Erdoğan, Utku, and Gabriel J. Lord. “A new class of exponential integrators for stochastic differential equations with multiplicative noise.” arXiv preprint arXiv:1608.07096 (2016).
- class brainpy.integrators.sde.SRK1W1(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
Order 2.0 weak SRK methods for SDEs with scalar Wiener process.
This method has strong orders \((p_d, p_s) = (2.0, 1.5)\).
The Butcher table is:
\[\begin{split}\begin{array}{l|llll|llll|llll} 0 &&&&& &&&& &&&& \\ 3/4 &3/4&&&& 3/2&&& &&&& \\ 0 &0&0&0&& 0&0&0&& &&&&\\ \hline 0 \\ 1/4 & 1/4&&& & 1/2&&&\\ 1 & 1&0&&& -1&0&\\ 1/4& 0&0&1/4&& -5&3&1/2\\ \hline & 1/3& 2/3& 0 & 0 & -1 & 4/3 & 2/3&0 & -1 &4/3 &-1/3 &0 \\ \hline & &&&& 2 &-4/3 & -2/3 & 0 & -2 & 5/3 & -2/3 & 1 \end{array}\end{split}\]References
- 1
Rößler, Andreas. “Strong and weak approximation methods for stochastic differential equations—some recent developments.” Recent developments in applied probability and statistics. Physica-Verlag HD, 2010. 127-153.
- 2
Rößler, Andreas. “Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations.” SIAM Journal on Numerical Analysis 48.3 (2010): 922-952.
- class brainpy.integrators.sde.SRK2W1(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
Order 1.5 strong SRK methods for SDEs with scalar noise.
This method has strong orders \((p_d, p_s) = (3.0, 1.5)\).
The Butcher table is:
\[\begin{split}\begin{array}{c|cccc|cccc|cccc} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 / 2 & 1 / 4 & 1 / 4 & 0 & 0 & 1 & 1 / 2 & 0 & 0 & & & & \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ \hline 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 / 4 & 1 / 4 & 0 & 0 & 0 & -1 / 2 & 0 & 0 & 0 & & & & \\ 1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & & & & \\ 1 / 4 & 0 & 0 & 1 / 4 & 0 & 2 & -1 & 1 / 2 & 0 & & & & \\ \hline & 1 / 6 & 1 / 6 & 2 / 3 & 0 & -1 & 4 / 3 & 2 / 3 & 0 & -1 & -4 / 3 & 1 / 3 & 0 \\ \hline & & & & & 2 & -4 / 3 & -2 / 3 & 0 & -2 & 5 / 3 & -2 / 3 & 1 \end{array}\end{split}\]References
- 1
Rößler, Andreas. “Strong and weak approximation methods for stochastic differential equations—some recent developments.” Recent developments in applied probability and statistics. Physica-Verlag HD, 2010. 127-153.
- 2
Rößler, Andreas. “Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations.” SIAM Journal on Numerical Analysis 48.3 (2010): 922-952.
brainpy.simulation module
This module provides APIs for brain simulations.
Brain Objects
This module provides various interfaces to model brain objects. You can access them through brainpy.XXX or brainpy.brainobjects.XXX.
- BrainArea
- DynamicalSystem: Base Dynamical System class.
- Container: Container object which is designed to add other instances of DynamicalSystem.
- Delay: Base class to model delay variables.
- ConstantDelay: Class used to model constant delay variables.
- SpikeTimeInput: The input neuron group characterized by spikes emitting at given times.
- PoissonInput: Poisson Neuron Group.
- Molecular: Base class to model molecular objects.
- Network: Base class to model network objects, an alias of Container.
- NeuGroup: Base class to model neuronal groups.
- Channel: Base class to model ion channels.
- Soma: Base class to model soma in multi-compartment neuron models.
- Dendrite: Base class to model dendrites.
- TwoEndConn: Base class to model two-end synaptic connections.
- class brainpy.simulation.brainobjects.BrainArea(*ds_tuple, steps=None, monitors=None, name=None, **ds_dict)[source]
- class brainpy.simulation.brainobjects.DynamicalSystem(steps=None, monitors=None, name=None)[source]
Base Dynamical System class.
Any object that has step functions is a dynamical system. That is to say, in BrainPy, the essence of a dynamical system is its “step functions”.
- Parameters
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- register_constant_delay(key, size, delay, dtype=None)[source]
Register a constant delay.
- Parameters
key (str) – The delay name.
size (int, list of int, tuple of int) – The delay data size.
delay (int, float, ndarray) – The delay time, with the unit same with brainpy.math.get_dt().
dtype (optional) – The data type.
- Returns
delay – An instance of ConstantDelay.
- Return type
ConstantDelay
- run(duration, dt=None, report=0.0, inputs=(), extra_func=None)[source]
The running function.
- Parameters
inputs (list, tuple) –
The inputs for this instance of DynamicalSystem. It should have the format of [(target, value, [type, operation])], where target is the input target, value is the input value, type is the input type (such as “fix” or “iter”), and operation is the operation for inputs (such as “+”, “-”, “*”, “/”, “=”).
target: should be a string, specified by absolute or relative access.
value: should be a scalar, vector, matrix, iterable function, or object.
type: should be a string. “fix” means the input value is a constant; “iter” means the input value can change over time.
operation: should be a string; it supports +, -, *, /, =.
Also, if you want to specify multiple inputs, just give multiple (target, value, [type, operation]) tuples, for example [(target1, value1), (target2, value2)].
duration (float, int, tuple, list) – The running duration.
report (float) – The percentage of progress to report, in [0, 1]. If zero, the model will not report progress.
dt (float, optional) – The numerical integration step size.
extra_func (function, callable) – The extra function to run during each time step.
method (optional, str) – The method to run the model.
- Returns
running_time – The total running time.
- Return type
float
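A hedged usage sketch of run(), based only on the parameters documented above and the SpikeTimeInput group documented later on this page; the monitored variable name 'spike' and the mon attribute access are common BrainPy conventions but may differ across versions, so treat them as assumptions:

import brainpy as bp

# Build a small spike-emitting input group and monitor its 'spike' variable
# (the variable name is an assumption; check the group's own variables).
group = bp.SpikeTimeInput(2, times=[10., 20., 30.], indices=[0, 1, 0],
                          monitors=['spike'])

# Run for 50 ms and report progress every 10% of the simulation.
elapsed = group.run(duration=50., report=0.1)
print(elapsed)                 # total running time, in seconds
print(group.mon.spike.shape)   # monitored history, assumed (num_time_steps, num_neurons)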
- class brainpy.simulation.brainobjects.Container(*ds_tuple, steps=None, monitors=None, name=None, **ds_dict)[source]
Container object which is designed to add other instances of DynamicalSystem.
- Parameters
steps (tuple of function, tuple of str, dict of (str, function), optional) – The step functions.
monitors (tuple, list, Monitor, optional) – The monitor object.
name (str, optional) – The object name.
show_code (bool) – Whether show the formatted code.
ds_dict (dict of (str, DynamicalSystem)) – The instances of DynamicalSystem with the format of “key=dynamic_system”.
- class brainpy.simulation.brainobjects.Delay(steps=('update',), name=None, monitors=None)[source]
Base class to model delay variables.
- Parameters
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.ConstantDelay(size, delay, dtype=None, dt=None, **kwargs)[source]
Class used to model constant delay variables.
This class automatically supports batch size on the last axis. For example, if you run a batch with the shape of (10, 100), where 100 is the batch size, then this class can automatically support your batched data.
For examples:
>>> import brainpy as bp >>> >>> bp.ConstantDelay(size=10, delay=10.) >>> bp.ConstantDelay(size=100, delay=bp.math.random.random(100) * 4 + 10)
- Parameters
size (int, list of int, tuple of int) – The delay data size.
delay (int, float, function, ndarray) – The delay time. With the unit of dt.
num_batch (optional, int) – The batch size.
steps (optional, tuple of str, tuple of function, dict of (str, function)) – The callable function, or a list of callable functions.
monitors (optional, list, tuple, datastructures.Monitor) – Variables to monitor.
name (optional, str) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.SpikeTimeInput(size, times, indices, need_sort=True, **kwargs)[source]
The input neuron group characterized by spikes emitting at given times.
>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. >>> SpikeTimeInput(2, times=[10, 20]) >>> # or >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms. >>> SpikeTimeInput(2, times=[10, 20], indices=[0, 0]) >>> # or >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms. >>> SpikeTimeInput(2, times=[10, 20, 30], indices=[0, 1, 0]) >>> # or >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire; >>> # at 30 ms, neuron 1 fires. >>> SpikeTimeInput(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
- Parameters
size (int, tuple, list) – The neuron group geometry.
indices (int, list, tuple) – The neuron indices at each time point to emit spikes.
times (list, np.ndarray) – The time points which generate the spikes.
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.PoissonInput(size, freqs, seed=None, **kwargs)[source]
Poisson Neuron Group.
- Parameters
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.Molecular(name, **kwargs)[source]
Base class to model molecular objects.
- Parameters
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.Network(*ds_tuple, monitors=None, name=None, **ds_dict)[source]
Base class to model network objects, an alias of Container.
Network instantiates a network, which is aimed at loading neurons, synapses, and other brain objects.
- Parameters
name (str, Optional) – The network name.
monitors (optional, list of str, tuple of str) – The items to monitor.
ds_tuple – A list/tuple container of dynamical system.
ds_dict – A dict container of dynamical system.
- class brainpy.simulation.brainobjects.NeuGroup(size, name=None, steps=('update',), **kwargs)[source]
Base class to model neuronal groups.
There are several essential attributes:
size: the geometry of the neuron group. For example, (10, ) denotes a line of neurons, (10, 10) denotes a neuron group aligned in a 2D space, (10, 15, 4) denotes a 3-dimensional neuron group.
num: the flattened number of neurons in the group. For example, size=(10, ) => num=10, size=(10, 10) => num=100, size=(10, 15, 4) => num=600.
shape: the variable shape with (num, num_batch).
- Parameters
size (int, tuple of int, list of int) – The neuron group geometry.
num_batch (optional, int) – The batch size.
steps (tuple of str, tuple of function, dict of (str, function), optional) – The step functions.
steps – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (optional, str) – The name of the dynamic system.
- class brainpy.simulation.brainobjects.Channel(**kwargs)[source]
Base class to model ion channels.
Notes
The
__init__()
function inChannel
is used to specify the parameters of the channel. The__call__()
function is used to initialize the variables in this channel.
- class brainpy.simulation.brainobjects.Soma(name, **kwargs)[source]
Base class to model soma in multi-compartment neuron models.
- class brainpy.simulation.brainobjects.Dendrite(name, **kwargs)[source]
Base class to model dendrites.
- class brainpy.simulation.brainobjects.TwoEndConn(pre, post, conn=None, name=None, steps=('update',), **kwargs)[source]
Base class to model two-end synaptic connections.
- Parameters
steps (tuple of str, tuple of function, dict of (str, function), optional) – The step functions.
pre (NeuGroup) – Pre-synaptic neuron group.
post (NeuGroup) – Post-synaptic neuron group.
conn (math.ndarray, dict of (str, math.ndarray), TwoEndConnector) – The connection method between pre- and post-synaptic groups.
steps – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
show_code (bool) – Whether show the formatted code.
DNN Layers
This module provides various interfaces to model DNN layers.
You can access them through brainpy.layers.XXX.
- Module: Basic module class for DNN networks.
- Sequential: Basic sequential object to control data flow.
- Activation: Activation Layer.
- Conv2D: Apply a 2D convolution on a 4D-input batch of shape (N,C,H,W).
- Dense: A fully connected layer implemented as the dot product of inputs and weights.
- Dropout: A layer that stochastically ignores a subset of inputs each training step.
- LinearReadout: Neuron group to readout information linearly.
- VanillaRNN: Basic fully-connected RNN core.
- GRU: Gated Recurrent Unit.
- LSTM: Long short-term memory (LSTM) RNN core.
- class brainpy.simulation.layers.Module(steps=None, monitors=None, name=None)[source]
Basic module class for DNN networks.
- target_backend = 'jax'
Used to specify the target backend on which the model runs.
- class brainpy.simulation.layers.Sequential(*arg_ds, monitors=None, name=None, **kwarg_ds)[source]
Basic sequential object to control data flow.
- Parameters
arg_ds – The modules without name specifications.
name (str, optional) – The name of the sequential module.
kwarg_ds – The modules with name specifications.
- update(*args, **kwargs)[source]
Functional call.
- Parameters
args (list, tuple) – The *args arguments.
kwargs (dict) – The config arguments. The configuration is used across modules. If the “__call__” function in a submodule receives a “config” argument, this “config” parameter will be passed into that function.
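A hedged sketch of composing the layers documented on this page (the constructor arguments come from the signatures below; calling the model directly to forward data, and the 'relu' activation name, are assumptions that may differ across BrainPy versions):

import brainpy as bp

# Stack documented layers into a Sequential module.
model = bp.layers.Sequential(
    bp.layers.Dense(num_hidden=64, num_input=100),   # fully connected layer
    bp.layers.Activation('relu'),                     # activation layer selected by name
    bp.layers.Dense(num_hidden=10, num_input=64),
)

# Assumed functional call: forwards a dummy batch through the layers in order.
x = bp.math.random.random((1, 100))
y = model(x)   # expected shape: (1, 10)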
- class brainpy.simulation.layers.Activation(activation, name=None, **setting)[source]
Activation Layer.
- Parameters
activation (str) – The name of the activation function.
name (optional, str) – The name of the class.
setting (Any) – The settings for the activation function.
- class brainpy.simulation.layers.Conv2D(num_input, num_output, kernel_size, strides=1, dilations=1, groups=1, padding='SAME', w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]
Apply a 2D convolution on a 4D-input batch of shape (N,C,H,W).
- Parameters
num_input (int) – The number of channels of the input tensor.
num_output (int) – The number of channels of the output tensor.
kernel_size (int, tuple of int) – The size of the convolution kernel, either tuple (height, width) or single number if they’re the same.
strides (int, tuple of int) – The convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.
dilations (int, tuple of int) – The spacing between kernel points (also known as atrous convolution), either a tuple (dilation_y, dilation_x) or a single number if they’re the same.
groups (int) – The number of groups into which the input and output channels are divided. When groups > 1, the convolution operation is applied individually to each group; nin and nout must both be divisible by groups.
padding (int, str) – The padding of the input tensor, either “SAME”, “VALID” or numerical values (low, high).
w (Initializer, JaxArray, jax.numpy.ndarray) – The initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).
b (Initializer, JaxArray, jax.numpy.ndarray, optional) – The bias initialization.
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.layers.Dense(num_hidden, num_input, w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]
A fully connected layer implemented as the dot product of inputs and weights.
- Parameters
num_hidden (int) – The neuron group size.
num_input (int) – The input size.
w (Initializer, JaxArray, jax.numpy.ndarray) – Initializer for the weights.
b (Initializer, JaxArray, jax.numpy.ndarray, optional) – Initializer for the bias.
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.layers.Dropout(prob, seed=None, **kwargs)[source]
A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).
The parameter shared_axes allows to specify a list of axes on which the mask will be shared: we will use size 1 on those axes for dropout mask and broadcast it. Sharing reduces randomness, but can save memory.
This layer is active only during training (mode=’train’). In other circumstances it is a no-op.
Originally introduced in the paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting” available under the following link: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf
- Parameters
prob (float) – Probability to keep element of the tensor.
steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.
monitors (None, list, tuple, Monitor) – Variables to monitor.
name (str, optional) – The name of the dynamic system.
- class brainpy.simulation.layers.LinearReadout(num_hidden, num_input, num_batch=1, w_init=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b_init=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, has_bias=True, s_init=<brainpy.simulation.initialize.random_inits.Uniform object>, train_mask=None, **kwargs)[source]
Neuron group to readout information linearly.
- Parameters
num_hidden (int) – The neuron group size.
num_input (int) – The input size.
w_init (Initializer) – Initializer for the weights.
b_init (Initializer) – Initializer for the bias.
has_bias (bool) – Whether to compute with a bias term.
s_init (Initializer) – Initializer for variable states.
train_mask (optional, math.ndarray) – The training mask for the weights.
- class brainpy.simulation.layers.VanillaRNN(num_hidden, num_input, num_batch, h=<brainpy.simulation.initialize.random_inits.Uniform object>, w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]
Basic fully-connected RNN core.
Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes
\[h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]The output is equal to the new state, \(h_t\).
- class brainpy.simulation.layers.GRU(num_hidden, num_input, num_batch, wx=<brainpy.simulation.initialize.random_inits.Orthogonal object>, wh=<brainpy.simulation.initialize.random_inits.Orthogonal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, h=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]
Gated Recurrent Unit.
The implementation is based on (Chung, et al., 2014) [1]_ with biases.
Given \(x_t\) and the previous state \(h_{t-1}\) the core computes
\[\begin{split}\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\end{split}\]where \(z_t\) and \(r_t\) are the update and reset gates, respectively.
The output is equal to the new hidden state, \(h_t\).
Warning: Backwards compatibility of GRU weights is currently unsupported.
References
- 1
Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
- class brainpy.simulation.layers.LSTM(num_hidden, num_input, num_batch, w=<brainpy.simulation.initialize.random_inits.Orthogonal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, hc=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]
Long short-term memory (LSTM) RNN core.
The implementation is based on (Zaremba, et al., 2014) [1]_. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
The output is equal to the new hidden, \(h_t\).
Notes
Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_, we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.
References
- 1
Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. “Recurrent neural network regularization.” arXiv preprint arXiv:1409.2329 (2014).
- 2
Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. “An empirical exploration of recurrent network architectures.” In International conference on machine learning, pp. 2342-2350. PMLR, 2015.
Synaptic Connectivity
This module provides methods to construct connectivity between neuron groups.
You can access them through brainpy.connect.XXX.
Base Class
- Base synaptic connector class.
- Synaptic connector to build synapse connections between two neuron groups.
- Synaptic connector to build synapse connections within a population of neurons.
Custom Connections
- Connector built from the connection matrix.
- Connector built from the
Random Connections
- FixedProb: Connect the post-synaptic neurons with fixed probability.
- FixedPreNum: Connect the pre-synaptic neurons with a fixed number for each post-synaptic neuron.
- FixedPostNum: Connect the post-synaptic neurons with a fixed number for each pre-synaptic neuron.
- GaussianProb: Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to the Gaussian function.
- SmallWorld: Build a Watts–Strogatz small-world graph.
- ScaleFreeBA: Build a random graph according to the Barabási–Albert preferential attachment model.
- ScaleFreeBADual: Build a random graph according to the dual Barabási–Albert preferential attachment model.
- PowerLaw: Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering.
- class brainpy.simulation.connect.FixedProb(prob, include_self=True, seed=None)[source]
Connect the post-synaptic neurons with fixed probability.
- Parameters
prob (float) – The conn probability.
include_self (bool) – Whether to create the (i, i) connection.
seed (optional, int) – Seed the random generator.
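A small construction sketch using the signatures documented in this section (the numeric arguments are arbitrary examples); a connector instance is typically passed as the conn argument of a synapse model such as TwoEndConn, documented earlier on this page:

import brainpy as bp

# Connector constructors taken from the signatures documented here.
conn_prob = bp.connect.FixedProb(prob=0.1, include_self=False, seed=123)
conn_gauss = bp.connect.GaussianProb(sigma=0.2, periodic_boundary=True)
conn_sw = bp.connect.SmallWorld(num_neighbor=5, prob=0.2)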
- class brainpy.simulation.connect.FixedPreNum(num, include_self=True, seed=None)[source]
Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron.
- Parameters
num (float, int) – The conn probability (if “num” is float) or the fixed number of connectivity (if “num” is int).
include_self (bool) – Whether to create the (i, i) connection.
seed (None, int) – Seed the random generator.
method (str) – The method used to create the connection.
matrix: this method first creates a big matrix \((N_{pre}, N_{post})\), and then the connectivity is constructed from this matrix. In a large network, this method consumes huge memory, including a matrix of \((N_{pre}, N_{post})\) and two vectors of \(2 * N_{need} * N_{post}\).
iter: this method iteratively builds the synaptic connections. It has the minimum memory consumption, only \(2 * N_{need} * N_{post}\) (the i and j vectors).
- class brainpy.simulation.connect.FixedPostNum(num, include_self=True, seed=None)[source]
Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron.
- Parameters
num (float, int) – The conn probability (if “num” is float) or the fixed number of connectivity (if “num” is int).
include_self (bool) – Whether to create the (i, i) connection.
seed (None, int) – Seed the random generator.
method (str) – The method used to create the connection.
matrix: this method first creates a big matrix \((N_{pre}, N_{post})\), and then the connectivity is constructed from this matrix. In a large network, this method consumes huge memory, including a matrix of \((N_{pre}, N_{post})\) and two vectors of \(2 * N_{need} * N_{pre}\).
iter: this method iteratively builds the synaptic connections. It has the minimum memory consumption, only \(2 * N_{need} * N_{pre}\) (the i and j vectors).
- class brainpy.simulation.connect.GaussianProb(sigma, encoding_values=None, normalize=True, include_self=True, periodic_boundary=False, seed=None)[source]
Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to the Gaussian function.
Specifically, for any pair of neurons \((i, j)\),
\[p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2})\]where \(v_k^i\) is the \(i\)-th neuron’s encoded value at dimension \(k\).
- Parameters
sigma (float) – Width of the Gaussian function.
encoding_values (optional, list, tuple, int, float) –
The value ranges to encode for neurons at each axis.
If values is not provided, the neuron only encodes each positional information, i.e., \((i, j, k, ...)\), where \(i, j, k\) is the index in the high-dimensional space.
If values is a single tuple/list of int/float, neurons at each dimension will encode the same range of values. For example, values=(0, np.pi), neurons at each dimension will encode a continuous value space [0, np.pi].
If values is a tuple/list of list/tuple, it means the value space will be different for each dimension. For example, values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).
periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.
normalize (bool) – Whether to normalize the connection probability.
include_self (bool) – Whether to create the connection at the same position.
seed (optional, int) – The random seed.
- class brainpy.simulation.connect.SmallWorld(num_neighbor, prob, directed=False, include_self=False)[source]
Build a Watts–Strogatz small-world graph.
- Parameters
num_neighbor (int) – Each node is joined with its k nearest neighbors in a ring topology.
prob (float) – The probability of rewiring each edge
directed (bool) – Whether the graph is a directed graph.
include_self (bool) – Whether to include the node itself.
Notes
First create a ring over \(num\_node\) nodes [1]_. Then each node in the ring is joined to its \(num\_neighbor\) nearest neighbors (or \(num\_neighbor - 1\) neighbors if \(num\_neighbor\) is odd). Then shortcuts are created by replacing some edges as follows: for each edge \((u, v)\) in the underlying “\(num\_node\)-ring with \(num\_neighbor\) nearest neighbors” with probability \(prob\) replace it with a new edge \((u, w)\) with uniformly random choice of existing node \(w\).
References
- 1
Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.
- class brainpy.simulation.connect.ScaleFreeBA(m, directed=False, seed=None)[source]
Build a random graph according to the Barabási–Albert preferential attachment model.
A graph of \(num\_node\) nodes is grown by attaching new nodes each with \(m\) edges that are preferentially attached to existing nodes with high degree.
- Parameters
m (int) – Number of edges to attach from a new node to existing nodes
seed (integer, random_state, or None (default)) – Indicator of random number generation state.
- Raises
ValueError – If m does not satisfy 1 <= m < n.
References
- 1
A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.
- class brainpy.simulation.connect.ScaleFreeBADual(m1, m2, p, directed=False, seed=None)[source]
Build a random graph according to the dual Barabási–Albert preferential attachment model.
A graph of \(num\_node\) nodes is grown by attaching new nodes each with either \(m_1\) edges (with probability \(p\)) or \(m_2\) edges (with probability \(1-p\)) that are preferentially attached to existing nodes with high degree.
- Parameters
m1 (int) – Number of edges to attach from a new node to existing nodes with probability \(p\).
m2 (int) – Number of edges to attach from a new node to existing nodes with probability \(1-p\).
p (float) – The probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges).
seed (integer, random_state, or None (default)) – Indicator of random number generation state.
- Raises
ValueError – If m1 and m2 do not satisfy 1 <= m1, m2 < n, or p does not satisfy 0 <= p <= 1.
References
- 1
Moshiri “The dual-Barabasi-Albert model”, arXiv:1810.10538.
- class brainpy.simulation.connect.PowerLaw(m, p, directed=False, seed=None)[source]
Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering.
- Parameters
m (int) – The number of random edges to add for each new node.
p (float) – Probability of adding a triangle after adding a random edge.
seed (integer, random_state, or None (default)) – Indicator of random number generation state.
Notes
The average clustering has a hard time getting above a certain cutoff that depends on \(m\). This cutoff is often quite low. The transitivity (fraction of triangles to possible triangles) seems to decrease with network size.
It is essentially the Barabási–Albert (BA) growth model with an extra step that each random edge is followed by a chance of making an edge to one of its neighbors too (and thus a triangle).
This algorithm improves on BA in the sense that it enables a higher average clustering to be attained if desired.
It seems possible to have a disconnected graph with this algorithm since the initial \(m\) nodes may not be all linked to a new node on the first iteration like the BA model.
- Raises
ValueError – If \(m\) does not satisfy \(1 <= m <= n\) or \(p\) does not satisfy \(0 <= p <= 1\).
References
- 1
P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.
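The scale-free connectors are constructed in the same way (the values below are illustrative):
>>> import brainpy as bp
>>> # Barabasi-Albert graph: each new node attaches 3 preferential edges
>>> ba_conn = bp.simulation.connect.ScaleFreeBA(m=3, seed=42)
>>> # Holme-Kim power-law graph with a 30% chance of closing a triangle per added edge
>>> pl_conn = bp.simulation.connect.PowerLaw(m=3, p=0.3, seed=42)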
Regular Connections
- One2One: Connect two neuron groups one by one.
- All2All: Connect each neuron in the first group to all neurons in the post-synaptic neuron group.
- GridFour: The nearest four-neighbors connection method.
- GridEight: The nearest eight-neighbors connection method.
- GridN: The nearest (2*N+1) * (2*N+1) neighbors connection method.
- class brainpy.simulation.connect.One2One[source]
Connect two neuron groups one by one. This means the two neuron groups must have the same size.
- class brainpy.simulation.connect.All2All(include_self=True)[source]
Connect each neuron in the first group to all neurons in the post-synaptic neuron group. This kind of connection creates (num_pre x num_post) synapses.
- class brainpy.simulation.connect.GridFour(include_self=False)[source]
The nearest four-neighbors connection method.
- class brainpy.simulation.connect.GridEight(include_self=False)[source]
The nearest eight-neighbors connection method.
- class brainpy.simulation.connect.GridN(N=1, include_self=False)[source]
The nearest (2*N+1) * (2*N+1) neighbors connection method.
- Parameters
N (int) –
Extent of the connection scope. For example, when N=1:
[x x x]
[x I x]
[x x x]
When N=2:
[x x x x x]
[x x x x x]
[x x I x x]
[x x x x x]
[x x x x x]
include_self (bool) – Whether to create the (i, i) connection.
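For example, a short sketch of the grid connectors (assuming the neuron group is arranged on a 2D grid):
>>> import brainpy as bp
>>> conn4 = bp.simulation.connect.GridFour()   # the 4 nearest neighbors
>>> conn24 = bp.simulation.connect.GridN(N=2)  # the nearest (2*2+1)*(2*2+1)-1 = 24 neighbors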
Formatter Functions
- Convert i-j connection to matrix connection.
- Get the i-j connections from connectivity matrix.
- Get pre2post connections from i and j indexes.
- Get post2pre connections from i and j indexes.
- Get pre2syn connections from i and j indexes.
- Get post2syn connections from i and j indexes.
- Get post slicing connections by pre-synaptic ids.
- Get pre slicing connections by post-synaptic ids.
Weight Initialization
This module provides methods to initialize weights.
You can access them through brainpy.initialize.XXX.
Base Class
- Base initialization class.
Regular Initializers
- ZeroInit: Zero initializer.
- OneInit: One initializer.
- Identity: Returns the identity matrix.
- class brainpy.simulation.initialize.ZeroInit[source]
Zero initializer.
Initialize the weights with zeros.
- class brainpy.simulation.initialize.OneInit(value=1.0)[source]
One initializer.
Initialize the weights with the given values.
- Parameters
value (float, int, math.ndarray) – The value used to fill the weights.
- class brainpy.simulation.initialize.Identity(value=1.0)[source]
Returns the identity matrix.
This initializer was proposed in (Le, et al., 2015) 1.
- Parameters
value (float) – The optional scaling factor.
- Returns
shape – The weight shape/size.
- Return type
tuple of int
References
- 1
Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. “A simple way to initialize recurrent networks of rectified linear units.” arXiv preprint arXiv:1504.00941 (2015).
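A small sketch of the regular initializers; the weight shape and the assumption that an initializer instance can be called with that shape are illustrative:
>>> import brainpy as bp
>>> zero_init = bp.simulation.initialize.ZeroInit()         # all weights are 0
>>> one_init = bp.simulation.initialize.OneInit(value=0.5)  # all weights are 0.5
>>> eye_init = bp.simulation.initialize.Identity(value=2.)  # scaled identity matrix
>>> # assumed call convention: weights = one_init((num_pre, num_post))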
Random Initializers
- Normal: Initialize weights with normal distribution.
- Uniform: Initialize weights with uniform distribution.
- VarianceScaling, KaimingUniform, KaimingNormal, XavierUniform, XavierNormal, LecunUniform, LecunNormal: variance-scaling initializers (see the class signatures below).
- Orthogonal initializers: construct initializers for uniformly distributed orthogonal matrices and for delta orthogonal kernels (see arXiv:1806.05393).
- class brainpy.simulation.initialize.Normal(scale=1.0)[source]
Initialize weights with normal distribution.
- Parameters
scale (float) – The standard deviation (gain) of the normal distribution.
- class brainpy.simulation.initialize.Uniform(min_val=0.0, max_val=1.0, scale=0.01)[source]
Initialize weights with uniform distribution.
- Parameters
min_val (float) – The lower limit of the uniform distribution.
max_val (float) – The upper limit of the uniform distribution.
- class brainpy.simulation.initialize.VarianceScaling(scale, mode, distribution, in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.KaimingUniform(scale=2.0, mode='fan_in', distribution='uniform', in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.KaimingNormal(scale=2.0, mode='fan_in', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.XavierUniform(scale=1.0, mode='fan_avg', distribution='uniform', in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.XavierNormal(scale=1.0, mode='fan_avg', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.LecunUniform(scale=1.0, mode='fan_in', distribution='uniform', in_axis=-2, out_axis=-1)[source]
- class brainpy.simulation.initialize.LecunNormal(scale=1.0, mode='fan_in', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]
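The random initializers follow the signatures above; a brief sketch with illustrative values:
>>> import brainpy as bp
>>> w1 = bp.simulation.initialize.Normal(scale=0.1)
>>> w2 = bp.simulation.initialize.Uniform(min_val=-0.05, max_val=0.05)
>>> w3 = bp.simulation.initialize.XavierNormal()    # fan_avg variance scaling
>>> w4 = bp.simulation.initialize.KaimingUniform()  # fan_in variance scaling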
Decay Initializers
- GaussianDecay: Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function.
- DOGDecay: Builds a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons.
- class brainpy.simulation.initialize.GaussianDecay(sigma, max_w, min_w=None, encoding_values=None, periodic_boundary=False, include_self=True, normalize=False)[source]
Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function.
Specifically, for any pair of neurons \((i, j)\), the weight is computed as
\[w(i, j) = w_{max} \cdot \exp\left(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}\right)\]
where \(v_k^i\) is the value encoded by the \(i\)-th neuron at dimension \(k\).
- Parameters
sigma (float) – Width of the Gaussian function.
max_w (float) – The weight amplitude of the Gaussian function.
min_w (float, None) – The minimum weight value below which synapses are not created (default: \(0.005 * max\_w\)).
include_self (bool) – Whether to create the connection at the same position (self-connection).
encoding_values (optional, list, tuple, int, float) –
The value ranges to encode for neurons at each axis.
If values is not provided, each neuron encodes only its positional information, i.e., \((i, j, k, ...)\), where \(i, j, k\) are the indices in the high-dimensional space.
If values is a single tuple/list of int/float, neurons at each dimension encode the same range of values. For example, with values=(0, np.pi), neurons at each dimension encode the continuous value space [0, np.pi].
If values is a tuple/list of lists/tuples, the value space differs across dimensions. For example, values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).
periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.
normalize (bool) – Whether to normalize the connection probability.
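A construction sketch (the values are illustrative):
>>> import brainpy as bp
>>> # weights decay with the Gaussian of the distance between encoded positions,
>>> # and weights below 0.005 * max_w are not created (the default min_w)
>>> w_init = bp.simulation.initialize.GaussianDecay(sigma=2., max_w=1., periodic_boundary=True)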
- class brainpy.simulation.initialize.DOGDecay(sigmas, max_ws, min_w=None, encoding_values=None, periodic_boundary=False, normalize=True, include_self=True)[source]
Builds a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons.
Mathematically, for the given pair of neurons \((i, j)\), the weight between them is computed as
\[w(i, j) = w_{max}^+ \cdot \exp\left(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}\right) - w_{max}^- \cdot \exp\left(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2}\right)\]
where weights smaller than \(0.005 * \max(w_{max}^+, w_{max}^-)\) are not created, and self-connections are avoided by default (parameter include_self).
- Parameters
sigmas (tuple) – Widths of the positive and negative Gaussian functions.
max_ws (tuple) – The weight amplitudes of the positive and negative Gaussian functions.
min_w (float, None) – The minimum weight value below which synapses are not created (default: \(0.005 * \max(max\_ws)\)).
include_self (bool) – Whether to create the connection at the same position (self-connection).
normalize (bool) – Whether to normalize the connection probability.
encoding_values (optional, list, tuple, int, float) –
The value ranges to encode for neurons at each axis.
If values is not provided, each neuron encodes only its positional information, i.e., \((i, j, k, ...)\), where \(i, j, k\) are the indices in the high-dimensional space.
If values is a single tuple/list of int/float, neurons at each dimension encode the same range of values. For example, with values=(0, np.pi), neurons at each dimension encode the continuous value space [0, np.pi].
If values is a tuple/list of lists/tuples, the value space differs across dimensions. For example, values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).
periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.
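A construction sketch of the difference-of-Gaussian profile (the values are illustrative):
>>> import brainpy as bp
>>> # narrow strong excitation (sigma=1, amplitude=1) minus broad weak inhibition (sigma=3, amplitude=0.5)
>>> w_init = bp.simulation.initialize.DOGDecay(sigmas=(1., 3.), max_ws=(1., 0.5), include_self=False)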
Current Inputs
This module provides various methods to form current inputs.
You can access them through brainpy.inputs.XXX.
- Format an input current with different sections.
- Format constant input in durations (two variants).
- Format current input like a series of short-time spikes (two variants).
- Get the gradually changed input current (two variants).
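As a usage sketch, assuming the module exposes a section-based constructor named section_input (this function name and its keyword names are assumptions and may differ in your installed version):
>>> import brainpy as bp
>>> # 100 ms of 0, 300 ms of 1, then 100 ms of 0 (assumed API)
>>> current = bp.inputs.section_input(values=[0., 1., 0.], durations=[100., 300., 100.])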
Measurements
This module aims to provide commonly used analysis methods for simulated neuronal data.
You can access them through brainpy.measure.XXX.
- cross_correlation: Calculate the cross correlation index between neurons.
- voltage_fluctuation: Calculate neuronal synchronization via voltage variance.
- raster_plot: Get the spike raster plot, which displays the spiking activity of a group of neurons over time.
- firing_rate: Calculate the mean firing rate in a neuron group.
Monitors
- Monitor: The basic Monitor class to store the past variable trajectories.
- class brainpy.simulation.monitor.Monitor(variables, intervals=None, target=None)[source]
The basic Monitor class to store the past variable trajectories.
Currently, brainpy.simulation.Monitor supports specifying:
- the variable key by strings;
- the variable index by None, int, list, tuple, or 1D array/tensor (all of these will be transformed into a 1D array/tensor);
- the variable monitor interval by None, int, or float.
Users can instantiate a monitor object in multiple ways:
1. a list of strings:
>>> Monitor(target=..., variables=['a', 'b', 'c'])
1.1. a list of strings and a list of intervals:
>>> Monitor(target=..., variables=['a', 'b', 'c'],
>>>         intervals=[None, 1, 2])  # ms
2. a list of strings and string + indices:
>>> Monitor(target=..., variables=['a', ('b', math.array([1,2,3])), 'c'])
2.1. a list of strings (+ indices) and a list of intervals:
>>> Monitor(target=..., variables=['a', ('b', math.array([1,2,3])), 'c'],
>>>         intervals=[None, 2, 3])
3. a dictionary with the format of {key: indices}:
>>> Monitor(target=..., variables={'a': None, 'b': math.array([1,2,3])})
3.1. a dictionary of variables and indices, and a dictionary of time intervals:
>>> Monitor(target=..., variables={'a': None, 'b': math.array([1,2,3])},
>>>         intervals={'b': 2.})
Note
brainpy.simulation.Monitor
records any target variable as a two-dimensional array/tensor with the shape of (num_time_step, variable_size). This means that any variable, no matter the shape of its data (int, float, vector, matrix, 3D array/tensor), will be reshaped into a one-dimensional vector at each time step.
brainpy.analysis module
This module provides analysis tools for differential equations.
The symbolic module uses SymPy symbolic inference to analyze low-dimensional dynamical systems (it only supports ODEs).
The numeric module uses numerical optimization functions to analyze high-dimensional dynamical systems (it supports ODEs and discrete systems).
The continuation module is the analysis package with numerical continuation methods.
Moreover, we provide several useful functions in the stability module which may help your dynamical system analysis, like:
>>> get_1d_stability_types()
['saddle node', 'stable point', 'unstable point']
Details in the following.
Dynamics Analysis (Symbolic)
Dynamics analysis with the aid of SymPy symbolic inference.
This module provides basic dynamics analysis for low-dimensional dynamical systems, including
phase plane analysis (1d or 2d systems)
bifurcation analysis (1d or 2d systems)
fast slow bifurcation analysis (2d or 3d systems)
- BaseSymAnalyzer: Dynamics analyzer for neuron models.
- Base1DSymAnalyzer: Analyzer for 1D dynamical systems.
- Base2DSymAnalyzer: Analyzer for 2D dynamical systems.
- Bifurcation: A tool class for bifurcation analysis.
- FastSlowBifurcation: Fast-slow bifurcation analysis.
- PhasePlane: A tool class for phase plane analysis.
- class brainpy.analysis.symbolic.BaseSymAnalyzer(model_or_integrals, target_vars, fixed_vars=None, target_pars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]
Dynamics Analyzer for Neuron Models.
This is a base class for analyzing the dynamics of neuron models. A neuron model is characterized by a series of dynamical variables and parameters:
\[{dF \over dt} = F(v_1, v_2, ..., p_1, p_2, ...)\]
where \(v_1, v_2\) are variables and \(p_1, p_2\) are parameters.
- Parameters
model_or_integrals (Any) – A model of the population, the integrator function, or a list/tuple of integrator functions.
target_vars (dict) – The target/dynamical variables.
fixed_vars (dict) – The fixed variables.
target_pars (dict, optional) – The parameters which can be dynamically varied.
pars_update (dict, optional) – The parameters to update.
numerical_resolution (float, dict) –
The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example,
Setting numerical_resolution=0.1 generalizes it to all variables and parameters.
Setting numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} specifies a particular resolution for each variable and parameter.
Moreover, you can set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the exact search points to explore for variable var1. This is useful for placing sensible search points at some inflection points.
options (dict, optional) –
The other setting parameters, which includes:
perturbation: float. The small perturbation used to solve the function derivative.
sympy_solver_timeout: float, with the unit of second. The maximum time allowed for the sympy solver to get the variable relationship.
escape_sympy_solver: bool. Whether to escape the sympy solver and directly use the numerical optimization method to solve the nullclines and fixed points.
lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].
- class brainpy.analysis.symbolic.Base1DSymAnalyzer(*args, **kwargs)[source]
Analyzer for 1D neuron dynamical systems.
It supports the analysis of 1D dynamical systems.
\[{dx \over dt} = f(x, t)\]
- class brainpy.analysis.symbolic.Base2DSymAnalyzer(*args, **kwargs)[source]
Analyzer for 2D neuron dynamical systems.
It supports the analysis of 2D dynamical systems.
\[\begin{aligned}{dx \over dt} &= f(x, t, y)\\{dy \over dt} &= g(y, t, x)\end{aligned}\]
- Parameters
options (dict, optional) –
The other setting parameters, which includes:
shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.
show_shgo: bool. Whether to print shgo's value.
fl_tol: float. The tolerance of the function value to recognize it as a candidate function root point.
xl_tol: float. The tolerance of the L2-norm distance between this point and previous points. If the distances are all bigger than xl_tol, this point belongs to a new function root point.
- get_f_optimize_x_nullcline(coords=None)[source]
Get the function to solve X nullcline by using numerical optimization method.
- Parameters
coords (str) – The coordination.
- get_f_optimize_y_nullcline(coords=None)[source]
Get the function to solve Y nullcline by using numerical optimization method.
- Parameters
coords (str) – The coordination.
- get_x_by_y_in_x_eq()[source]
Get the expression of “x_by_y_in_x_eq”.
Specifically,
self.analyzed_results['x_by_y_in_x_eq'] is a dict with the following keys:
- status: 'sympy_success', 'sympy_failed', or 'escape'
- subs: substituted expressions (relationship) of x_by_y
- f: function of x_by_y
- get_x_by_y_in_y_eq()[source]
Get the expression of “x_by_y_in_y_eq”.
Specifically,
self.analyzed_results['x_by_y_in_y_eq'] is a dict with the following keys:
- status: 'sympy_success', 'sympy_failed', or 'escape'
- subs: substituted expressions (relationship) of x_by_y
- f: function of x_by_y
- class brainpy.analysis.symbolic.Bifurcation(integrals, target_pars, target_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]
A tool class for bifurcation analysis.
The bifurcation analyzer is restricted to analyze the bifurcation relation between membrane potential and a given model parameter (co-dimension-1 case) or two model parameters (co-dimension-2 case).
Externally injected current is also treated as a model parameter in this class, instead of a model state.
Examples
Tutorials please see: Dynamics Analysis (Symbolic)
- Parameters
integrals (function, functions) – The integral functions defined with brainpy.odeint or brainpy.sdeint or brainpy.ddeint, or brainpy.fdeint.
target_vars (dict) – The target dynamical variables. It must be a dictionary which specifies the boundary of each variable: {'var1': [min, max]}.
fixed_vars (dict, optional) – The fixed variables. Each must be a fixed value, with the format of {'var1': value}.
target_pars (dict, optional) – The parameters which can be dynamically varied. It must be a dictionary which specifies the boundary of each parameter: {'par1': [min, max]}.
pars_update (dict, optional) – The parameters to update. They can be treated as static parameters. As with fixed_vars, they must be fixed values with the format of {'par1': value}.
numerical_resolution (float, dict, optional) –
The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example,
Setting numerical_resolution=0.1 generalizes it to all variables and parameters.
Setting numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} specifies a particular resolution for each variable and parameter.
Moreover, you can set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the exact search points to explore for variable var1. This is useful for placing sensible search points at some inflection points.
options (dict, optional) –
The other setting parameters, which includes:
perturbation: float. The small perturbation used to solve the function derivatives.
sympy_solver_timeout: float, with the unit of second. The maximum time allowed for the sympy solver to get the variable relationship.
escape_sympy_solver: bool. Whether to escape the sympy solver and directly use the numerical optimization method to solve the nullclines and fixed points.
lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].
The parameters which are useful for two-dimensional bifurcation analysis:
shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.
show_shgo: bool. Whether to print shgo's value.
fl_tol: float. The tolerance of the function value to recognize it as a candidate function root point.
xl_tol: float. The tolerance of the L2-norm distance between this point and previous points. If the distances are all bigger than xl_tol, this point belongs to a new function root point.
- plot_bifurcation(*args, **kwargs)[source]
Plot the bifurcation, which supports bifurcation analysis of co-dimension 1 and co-dimension 2.
- Parameters
show (bool) – Whether to show the bifurcation figure.
- Returns
points – The bifurcation points, which specify the fixed points and their corresponding stability.
- Return type
dict
- plot_limit_cycle_by_sim(var, duration=100, inputs=(), plot_style=None, tol=0.001, show=False)[source]
Plot limit cycles by the simulation results.
This function helps users plot limit cycles from the simulation results: periodic signals are automatically detected and treated as candidate limit cycles.
- Parameters
var (str) – The target variable whose limit cycles are to be found.
duration (int, float, tuple, list) – The simulation duration.
inputs (tuple, list) – The simulation inputs.
plot_style (dict) – The limit cycle plotting style settings.
tol (float) – The tolerance for detecting periodic signals.
show (bool) – Whether to show the figure.
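As a minimal self-contained sketch, consider a one-variable ODE built with brainpy.odeint and a co-dimension-1 bifurcation over the external input (the toy model and parameter ranges are purely illustrative):
>>> import brainpy as bp
>>> @bp.odeint
... def int_x(x, t, Iext):
...     return x - x ** 3 + Iext   # toy 1D system with a fold bifurcation
>>> analyzer = bp.analysis.symbolic.Bifurcation(
...     integrals=int_x,
...     target_vars={'x': [-2., 2.]},
...     target_pars={'Iext': [-1., 1.]},
...     numerical_resolution=0.01)
>>> analyzer.plot_bifurcation(show=True)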
- class brainpy.analysis.symbolic.FastSlowBifurcation(integrals, fast_vars, slow_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]
Fast-slow bifurcation analysis proposed by John Rinzel 1 2 3.
(J. Rinzel, 1985, 1986, 1987) proposed that in a fast-slow dynamical system, we can treat the slow variables as bifurcation parameters, and then study how different values of the slow variables affect the bifurcation of the fast subsystem.
Examples
Tutorials please see: Dynamics Analysis (Symbolic)
- Parameters
integrals (function, functions) – The integral functions defined with brainpy.odeint or brainpy.sdeint or brainpy.ddeint, or brainpy.fdeint.
fast_vars (dict) – The fast dynamical variables. It must be a dictionary which specifies the boundary of each variable: {'var1': [min, max]}.
slow_vars (dict) – The slow dynamical variables. It must be a dictionary which specifies the boundary of each variable: {'var1': [min, max]}.
fixed_vars (dict) – The fixed variables. Each must be a fixed value, with the format of {'var1': value}.
pars_update (dict, optional) – The parameters to update. They can be treated as static parameters. As with fixed_vars, they must be fixed values with the format of {'par1': value}.
numerical_resolution (float, dict) – The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example, setting numerical_resolution=0.1 generalizes it to all variables and parameters; setting numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} specifies a particular resolution for each variable and parameter. Moreover, you can set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the exact search points to explore for variable var1, which is useful for placing sensible search points at some inflection points.
options (dict, optional) –
The other setting parameters, which include:
- perturbation: float. The small perturbation used to solve the function derivatives.
- sympy_solver_timeout: float, with the unit of second. The maximum time allowed for the sympy solver to get the variable relationship.
- escape_sympy_solver: bool. Whether to escape the sympy solver and directly use the numerical optimization method to solve the nullclines and fixed points.
- lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].
References
- 1(1,2)
Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.
- 2(1,2)
Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.
- 3(1,2)
Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.
- plot_bifurcation(*args, **kwargs)[source]
Plot bifurcation.
- Parameters
show (bool) – Whether to show the bifurcation figure.
- Returns
points – The bifurcation points, which specify the fixed points and their corresponding stability.
- Return type
dict
- plot_limit_cycle_by_sim(*args, **kwargs)[source]
Plot limit cycles by the simulation results.
This function helps users plot limit cycles from the simulation results: periodic signals are automatically detected and treated as candidate limit cycles.
- Parameters
var (str) – The target variable whose limit cycles are to be found.
duration (int, float, tuple, list) – The simulation duration.
inputs (tuple, list) – The simulation inputs.
plot_style (dict) – The limit cycle plotting style settings.
tol (float) – The tolerance for detecting periodic signals.
show (bool) – Whether to show the figure.
- plot_trajectory(*args, **kwargs)[source]
Plot trajectory.
This function helps users to plot specific trajectories.
- Parameters
initials (list, tuple) – The initial value setting of the targets. It can be a tuple/list of floats to specify the value of each dynamical variable (for example, (a, b)). It can also be a tuple/list of tuples to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).
duration (int, float, tuple, list) – The running duration. Same as the duration in NeuGroup.run(). It can be an int/float (t_end) to specify the same running end time, a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time, or a list of tuples ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the start and end simulation time for each initial value.
plot_duration (tuple/list of tuple, optional) – The duration to plot. It can be a tuple of (start, end). It can also be a list of tuples [(start1, end1), (start2, end2)] to specify the plot duration for each initial value.
show (bool) – Whether to show the figure.
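A usage sketch with a toy fast-slow system (the equations and parameter values are illustrative and do not reproduce any published bursting model):
>>> import brainpy as bp
>>> @bp.odeint
... def int_V(V, t, w, z, Iext=1.):
...     return V - V ** 3 / 3. - w + z + Iext     # fast variable
>>> @bp.odeint
... def int_w(w, t, V):
...     return (V + 0.7 - 0.8 * w) / 12.5          # fast recovery variable
>>> @bp.odeint
... def int_z(z, t, V):
...     return 0.005 * (-V - z)                    # slow adaptation variable
>>> analyzer = bp.analysis.symbolic.FastSlowBifurcation(
...     integrals=[int_V, int_w, int_z],
...     fast_vars={'V': [-3., 3.], 'w': [-1., 3.]},
...     slow_vars={'z': [-1., 1.]},
...     numerical_resolution=0.01)
>>> analyzer.plot_bifurcation(show=True)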
- class brainpy.analysis.symbolic.PhasePlane(model, target_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]
A tool class for phase plane analysis.
PhasePlane is used to analyze the phase portrait of 1D or 2D dynamical systems. It can also be used to analyze the phase portrait of a high-dimensional system by fixing the other variables so that only one or two variables remain dynamical.
Examples
Tutorials please see: Dynamics Analysis (Symbolic)
- Parameters
model (DynamicalSystem, Integrator, list of Integrator, tuple of Integrator) – The neuron model which defines the differential equations.
target_vars (dict) – The target variables to analyze, with the format of {‘var1’: [var_min, var_max], ‘var2’: [var_min, var_max]}.
fixed_vars (dict, optional) – The fixed variables, which means the variables will not be updated.
pars_update (dict, optional) – The parameters in the differential equations to update.
numerical_resolution (float, dict, optional) –
The variable resolution for numerical iterative solvers. This setting is useful when solving nullclines and fixed points with the iterative optimization method.
It can be a float, which will be used as numpy.arange(var_min, var_max, resolution).
Or, it can be a dict, with the format of {'var1': resolution1, 'var2': resolution2}.
Or, it can be a dict with the format of {'var1': np.arange(x, x, x), 'var2': np.arange(x, x, x)}.
options (dict, optional) –
The other setting parameters, which includes:
lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].
sympy_solver_timeout: float, with the unit of second. The maximum time allowed for the sympy solver to get the variable relationship.
escape_sympy_solver: bool. Whether to escape the sympy solver and directly use the numerical optimization method to solve the nullclines and fixed points.
shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.
show_shgo: bool. Whether to print shgo's value.
perturbation: float. The small perturbation used to solve the function derivative.
fl_tol: float. The tolerance of the function value to recognize it as a candidate function root point.
xl_tol: float. The tolerance of the L2-norm distance between this point and previous points. If the distances are all bigger than xl_tol, this point belongs to a new function root point.
- plot_limit_cycle_by_sim(initials, duration, tol=0.001, show=False)[source]
Plot limit cycles according to the settings.
- Parameters
initials (list, tuple) – The initial value setting of the targets. It can be a tuple/list of floats to specify the value of each dynamical variable (for example, (a, b)). It can also be a tuple/list of tuples to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).
duration (int, float, tuple, list) – The running duration. Same as the duration in NeuGroup.run(). It can be an int/float (t_end) to specify the same running end time, a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time, or a list of tuples ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the start and end simulation time for each initial value.
show (bool) – Whether to show the figure.
- plot_trajectory(initials, duration, plot_duration=None, axes='v-v', show=False)[source]
Plot trajectories according to the settings.
- Parameters
initials (list, tuple, dict) – The initial value setting of the targets. It can be a tuple/list of floats to specify the value of each dynamical variable (for example, (a, b)). It can also be a tuple/list of tuples to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).
duration (int, float, tuple, list) – The running duration. Same as the duration in NeuGroup.run(). It can be an int/float (t_end) to specify the same running end time, a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time, or a list of tuples ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the start and end simulation time for each initial value.
plot_duration (tuple, list, optional) – The duration to plot. It can be a tuple of (start, end). It can also be a list of tuples [(start1, end1), (start2, end2)] to specify the plot duration for each initial value.
axes (str) –
to specify the plot duration for each initial value running.axes (str) –
The axes to plot. It can be:
- ’v-v’
Plot the trajectory in the ‘x_var’-‘y_var’ axis.
- ’t-v’
Plot the trajectory in the ‘time’-‘var’ axis.
show (bool) – Whether to show the figure.
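As a minimal self-contained sketch, consider a FitzHugh–Nagumo-style 2D system (the equations, parameter values and initial points are illustrative):
>>> import brainpy as bp
>>> @bp.odeint
... def int_V(V, t, w, Iext=0.8):
...     return V - V ** 3 / 3. - w + Iext
>>> @bp.odeint
... def int_w(w, t, V, a=0.7, b=0.8):
...     return (V + a - b * w) / 12.5
>>> pp = bp.analysis.symbolic.PhasePlane(
...     model=[int_V, int_w],
...     target_vars={'V': [-3., 3.], 'w': [-1., 3.]},
...     numerical_resolution=0.05)
>>> pp.plot_trajectory(initials=[(-1., 0.), (1., 1.)], duration=100., show=True)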
Dynamics Analysis (Numeric)
- FixedPointFinder: Find fixed points by numerical optimization.
- class brainpy.analysis.numeric.FixedPointFinder(f_cell, f_type='df', f_loss_batch=None, verbose=True, num_opt_batch=100, num_opt_max=10000, opt_setting=None, noise=0.0, tol_opt=1e-05, tol_speed=1e-05, tol_unique=0.025, tol_outlier=1.0)[source]
Find fixed points by numerical optimization.
- Parameters
f_cell (callable, function) – The function to compute the recurrent units.
f_type (str) –
The system’s type: continuous system or discrete system.
’df’: continuous derivative function, denotes this is a continuous system, or
’F’: discrete update function, denotes this is a discrete system.
f_loss_batch (callable, function) – The function to compute the loss.
verbose (bool) – Whether to print the optimization progress.
num_opt_max (int) – The maximum number of optimization steps.
num_opt_batch (int) – Print training information during optimization every so often.
noise (float) – Gaussian noise added to fixed point candidates before optimization.
tol_opt (float) – Stop optimizing when the average value of the batch is below this value.
tol_speed (float) – Discard fixed points with squared speed larger than this value.
tol_unique (float) – Tolerance for determination of identical fixed points.
tol_outlier (float) – Any point whose closest fixed point is greater than tol is an outlier.
- compute_jacobians(points)[source]
Compute the jacobian matrices at the points.
- Parameters
points (np.ndarray, bm.JaxArray, jax.ndarray) – The fixed points with the shape of (num_point, num_dim).
- Returns
jacobians – An array of npoints Jacobians, with shape (npoints, dim, dim).
- Return type
bm.JaxArray
- decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=True)[source]
Compute the eigenvalues of the matrices.
- Parameters
matrices (np.ndarray, bm.JaxArray, jax.ndarray) – A 3D array with the shape of (num_matrices, dim, dim).
do_compute_lefts (bool) – Compute the left eigenvectors? Requires a pseudo-inverse call.
- Returns
decompositions – A list of dictionaries with sorted eigenvalues components: (eigenvalues, right eigenvectors, and left eigenvectors).
- Return type
list
- exclude_outliers(fixed_points, metric='euclidean')[source]
Exclude points whose closest neighbor is further than threshold.
- Parameters
fixed_points (np.ndarray) – The fixed points with the shape of (num_point, num_dim).
metric (str) – The distance metric passed to scipy.spatial.pdist. Defaults to “euclidean”
- Returns
fps_and_ids – A 2-tuple of (kept fixed points, ids of kept fixed points).
- Return type
tuple
- find_fixed_points(candidates)[source]
Top-level routine to find fixed points, keeping only valid fixed points.
This function will:
Add noise to the fixed point candidates
Optimize to find the closest fixed points / slow points
Exclude any fixed points whose fixed point loss is above threshold (‘fp_tol’)
Exclude any non-unique fixed points according to a tolerance (‘unique_tol’)
Exclude any far-away “outlier” fixed points (‘outlier_tol’)
- Parameters
rnn_fun – one-step update function as a function of hidden state
candidates – ndarray with shape npoints x ndims
hyper_params – dict of hyper parameters for fp optimization, including tolerances related to keeping fixed points
- Returns
A 4-tuple of (kept fixed points sorted with the slowest points first, fixed point losses, indices of kept fixed points, details of optimization).
- keep_unique(fixed_points)[source]
Filter unique fixed points by choosing a representative within tolerance.
- Parameters
fixed_points (np.ndarray) – The fixed points with the shape of (num_point, num_dim).
- Returns
fps_and_ids – A 2-tuple of (kept fixed points, ids of kept fixed points).
- Return type
tuple
- optimize_fixed_points(candidates)[source]
Find fixed points via optimization.
- Parameters
candidates (jax.ndarray, JaxArray) – The array with the shape of (batch size, state dim) of hidden states of RNN to start training for fixed points.
- Returns
fps_and_losses – A tuple of (the fixed points, the optimization losses).
- Return type
tuple
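A rough sketch of finding the fixed points of a simple continuous recurrent system. The use of brainpy.math operations inside f_cell, the handling of batched candidate states, and the exact return format of find_fixed_points are assumptions here, so treat this as an outline rather than a verified recipe:
>>> import numpy as np
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> rng = np.random.RandomState(0)
>>> W = 0.9 * rng.randn(16, 16) / np.sqrt(16)       # random recurrent weights
>>> def f_cell(h):
...     return -h + bm.tanh(bm.dot(h, W))            # dh/dt of the toy RNN
>>> finder = bp.analysis.numeric.FixedPointFinder(f_cell=f_cell, f_type='df')
>>> candidates = rng.uniform(-1., 1., size=(128, 16))  # initial guesses for the states
>>> results = finder.find_fixed_points(candidates)   # documented to return kept points, losses, ids, details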
Continuation Analysis
Stability Analysis
- Get the stability types of a 1D system (e.g., get_1d_stability_types).
- Get the stability types of a 2D system.
- Get the stability types of a 3D system.
- Stability analysis for fixed points.
brainpy.visualization module
Visualization toolkit.
- Get the constrained_layout figure.
- Show the specified value in the given object (Neurons or Synapses).
- Show the raster plot of the spikes.
- Animate the potentials of the neuron group.
- Animation of one-dimensional data.
- Plot style for publication.
brainpy.tools module
AST-to-Code
- Decompiles an AST into Python code.
Code Tools
- Return all the identifiers in a given string.
- Applies a dict of word substitutions.
- Check whether the function is a …
- Get the main function _code string.
New Dict
- DictPlus: Python dictionaries with advanced dot notation access.
- class brainpy.tools.DictPlus(*args, **kwargs)[source]
Python dictionaries with advanced dot notation access.
For example:
>>> d = DictPlus({'a': 10, 'b': 20})
>>> d.a
10
>>> d['a']
10
>>> d.c  # this will raise a KeyError
KeyError: 'c'
>>> d.c = 30  # but you can assign a value to a non-existing item
>>> d.c
30
Name Checking
Other Tools
brainpy.jaxsetting module
- set_platform: Changes platform to CPU, GPU, or TPU.
- set_host_device_count: By default, XLA considers all CPU cores as one device.
Release notes
Version 1.1.5
API changes:
fix bugs on ndarray import in brainpy.base.function.py
convenient ‘get_param’ interface in brainpy.simulation.layers
add more weight initialization methods
Doc changes:
add more examples in README
Version 1.1.4
API changes:
add .struct_run() in DynamicalSystem
add numpy_array() conversion in the brainpy.math.utils module
add Adagrad, Adadelta, RMSProp optimizers
remove setting methods in the brainpy.math.jax module
remove import jax in brainpy.__init__.py and enable jax settings, including enable_x64(), set_platform(), set_host_device_count()
enable b=None as no bias in brainpy.simulation.layers
set int_ and float_ as default 32 bits
remove dtype setting in the Initializer constructor
Doc changes:
add optimizer in “Math Foundation”
add dynamics training docs
improve others
Version 1.1.3
fix bugs of JAX parallel API imports
fix bugs of post_slice structure construction
update docs
Version 1.1.2
add pre2syn and syn2post operators
add verbose and check options to Base.load_states()
fix bugs on JIT DynamicalSystem (numpy backend)
Version 1.1.1
fix bugs on symbolic analysis: model trajectory
change absolute access in the variable saving and loading to the relative access
add UnexpectedTracerError hints in JAX transformation functions
Version 1.1.0
This package releases a new version of BrainPy.
Highlights of core changes:
math module
support numpy backend
support JAX backend
support jit, vmap and pmap on class objects on JAX backend
support grad, jacobian, hessian on class objects on JAX backend
support make_loop, make_while, and make_cond on JAX backend
support jit (based on numba) on class objects on numpy backend
unified numpy-like ndarray operation APIs
numpy-like random sampling APIs
FFT functions
gradient descent optimizers
activation functions
loss function
backend settings
base module
Base for the whole ecosystem
Function to wrap functions
Collector and TensorCollector to collect variables, integrators, nodes and others
integrators module
class integrators for ODE numerical methods
class integrators for SDE numerical methods
simulation module
support modular and composable programming
support multi-scale modeling
support large-scale modeling
support simulation on GPUs
fix bugs on firing_rate()
remove _i in the update() function, and replace _i with _dt, meaning the dynamic system has the canonical equation form of \(dx/dt = f(x, t, dt)\)
reimplement the input_step and monitor_step in a more intuitive way
support setting dt at the level of a single object (i.e., a single instance of DynamicSystem)
commonly used DNN layers
weight initializations
refine synaptic connections
Version 1.0.3
Fix bugs on
firing rate measurement
stability analysis
Version 1.0.2
This release continues to improve the user-friendliness.
Highlights of core changes:
Remove support for Numba-CUDA backend
Super initialization super(XXX, self).__init__() can be done anywhere (it is not required to be at the bottom of the __init__() function).
Add the output message of the step function running error.
More powerful support for Monitoring
More powerful support for running order scheduling
Remove unsqueeze() and squeeze() operations in brainpy.ops
Add reshape() operation in brainpy.ops
Improve docs for numerical solvers
Improve tests for numerical solvers
Add keywords checking in ODE numerical solvers
Add more unified operations in brainpy.ops
Support “@every” in steps and monitor functions
Fix ODE solver bugs for class bounded function
Add build phase in Monitor
Version 1.0.1
Fix bugs
Version 1.0.0
NEW VERSION OF BRAINPY
Change the coding style into the object-oriented programming
Systematically improve the documentation
Version 0.3.5
Add ‘timeout’ in sympy solver in neuron dynamics analysis
Reconstruct and generalize phase plane analysis
Generalize the repeat mode of Network to support different running durations between two runs
Update benchmarks
Update detailed documentation
Version 0.3.1
Add a more flexible way for NeuState/SynState initialization
Fix bugs of “is_multi_return”
Add “hand_overs”, “requires” and “satisfies”.
Update documentation
Auto-transform range to numba.prange
Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models
Version 0.3.0
Computation API
Rename “brainpy.numpy” to “brainpy.backend”
Delete “pytorch”, “tensorflow” backends
Add “numba” requirement
Add GPU support
Profile setting
Delete “backend” profile setting, add “jit”
Core systems
Delete “autopep8” requirement
Delete the format code prefix
Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”
Change the “ST” declaration out of “requires”
Add “repeat” mode run in Network
Change “vector-based” to “mode” in NeuType and SynType definition
Package installation
Remove “pypi” installation; installation now relies only on “conda”
Version 0.2.4
API changes
Fix bugs
Version 0.2.3
API changes
Add “animate_1D” in the visualization module
Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in the inputs module
Update phase_portrait_analyzer.py
Models and examples
Add CANN examples
Version 0.2.2
API changes
Redesign visualization
Redesign connectivity
Update docs
Version 0.2.1
API changes
Fix bugs in numba import
Fix bugs in numpy mode with scalar model
Version 0.2.0
API changes
For computation: numpy, numba
For model definition: NeuType, SynConn
For model running: Network, NeuGroup, SynConn, Runner
For numerical integration: integrate, Integrator, DiffEquation
For connectivity: One2One, All2All, GridFour, grid_four, GridEight, grid_eight, GridN, FixedPostNum, FixedPreNum, FixedProb, GaussianProb, GaussianWeight, DOG
For visualization: plot_value, plot_potential, plot_raster, animation_potential
For measurement: cross_correlation, voltage_fluctuation, raster_plot, firing_rate
For inputs: constant_current, spike_current, ramp_current
Models and examples
Neuron models: HH model, LIF model, Izhikevich model
Synapse models: AMPA, GABA, NMDA, STP, GapJunction
Network models: gamma oscillation