BrainPy documentation

BrainPy is a highly flexible and extensible framework for high-performance brain modeling. Among its key features, BrainPy supports:

  1. JIT compilation for functions and class objects.

  2. Numerical solvers for ODEs, SDEs and others.

  3. Dynamics simulation tools for various brain objects, like neurons, synapses, networks, soma, dendrites, channels, and even more.

  4. Dynamics analysis tools for differential equations, including phase plane analysis, bifurcation analysis, and linearization analysis.

  5. Seamless integration with deep learning models, accelerated by JIT compilation.

  6. And more.

Note

For comprehensive examples of BrainPy, please see:

Installation

BrainPy is designed to run cross-platform, including on Windows, GNU/Linux and macOS. It relies only on Python libraries.

Installation with pip

You can install BrainPy from PyPI. To do so, use:

pip install brain-py

To update BrainPy to the latest version, use:

pip install -U brain-py

If you want to install the pre-release version (the latest development version) of BrainPy, you can use:

pip install --pre brain-py

Installation from source

If you prefer not to install a packaged release, you can install BrainPy from its source repository on GitHub or OpenI.

To do so, use:

pip install git+https://github.com/PKU-NIP-Lab/BrainPy

# or

pip install git+https://git.openi.org.cn/OpenI/BrainPy

Package Dependency

For BrainPy to work properly, users should install several Python packages that it depends on.

NumPy & Matplotlib

The basic functions of BrainPy rely only on NumPy and Matplotlib. Both packages are easy to install with pip or conda:

pip install numpy matplotlib
# or
conda install numpy matplotlib

JAX

We highly recommend installing JAX. JAX is a high-performance JIT compiler that lets users run Python code on CPU, GPU, or TPU devices. Most of BrainPy's functionality is based on JAX.

Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. The provided binary releases of JAX for Linux and macOS systems are available at https://storage.googleapis.com/jax-releases/jax_releases.html .

To install a CPU-only version of JAX, you can run

pip install --upgrade "jax[cpu]"

If you want to install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN, if they have not already been installed. Next, run

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Alternatively, you can download the preferred release “.whl” file, and install it via pip:

pip install xxxx.whl

For Windows users, JAX can be installed by the following methods:

Method 1: On Windows 10 or later, you can use the Windows Subsystem for Linux (WSL). The installation guide can be found in WSL Installation Guide for Windows 10. You can then install JAX in WSL by following the Linux installation steps.

Method 2: There are several community-supported Windows builds of JAX; please refer to https://github.com/cloudhan/jax-windows-builder for more details. In short, you can run:

# for only CPU
pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html

# for GPU support
pip install <downloaded jaxlib>

Method 3: You can also build JAX from source.

Numba

Numba is also an excellent JIT compiler, which can accelerate Python code to approach the speed of C or Fortran. Numba works best with NumPy. Many BrainPy modules rely on Numba for acceleration, such as connectivity, simulation, analysis, and measurements. Numba is also a suitable framework for computing the sparse synaptic connections commonly used in computational neuroscience projects.

Numba is a cross-platform package which can be installed on Windows, Linux, and macOS. Installing Numba is straightforward. Just type the following commands in your terminal:

pip install numba
# or
conda install numba

SymPy

In BrainPy, several modules rely on symbolic inference by SymPy. For example, the Exponential Euler numerical solver needs SymPy to compute the linear part of your defined Python code, and phase plane and bifurcation analysis in the dynamics analysis module need symbolic computation from SymPy. Therefore, we highly recommend installing SymPy:

pip install sympy
# or
conda install sympy
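
After installation, it can be useful to confirm which of the packages listed above are importable in the current environment. The following is a minimal sketch using only the standard library; the helper name check_optional_packages is hypothetical and not part of BrainPy, and the import names are assumed to match the package names above.

```python
import importlib.util

# Import names of the packages discussed above (assumed to match the pip names).
OPTIONAL_PACKAGES = ("numpy", "matplotlib", "jax", "numba", "sympy")


def check_optional_packages(names=OPTIONAL_PACKAGES):
    """Return a dict mapping each package name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_optional_packages().items():
        print(f"{name:<12} {'installed' if found else 'missing'}")
```

A package reported as missing can then be installed with the corresponding pip or conda command from this section.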

JIT Compilation

@Chaoming Wang

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just-in-time” for execution. The compiled code can then run at native machine code speed.

Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions. In computational neuroscience, most models have many parameters and variables, and it is hard to manage and control model logic using functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes code more readable, controllable, flexible and modular. Therefore, brain modeling needs JIT compilation that supports class objects.

BrainPy provides a JIT compilation interface for class objects, built on top of JAX and Numba. This section introduces it.

import brainpy as bp
import brainpy.math as bm

JIT in Numba and JAX

Numba specializes in optimizing native NumPy code, including NumPy arrays, loops and conditional control flow. It is a cross-platform library which runs on Windows, Linux, macOS, etc. Most notably, Numba can just-in-time compile native Python loops (for or while statements) and conditionals (if ... else ...). This means that it supports intuitive Python programming.

However, Numba is a lightweight JIT compiler and is best suited to small network models. For large networks, its parallel performance is poor. Furthermore, Numba does not let one piece of code run on multiple devices: the same code cannot run on GPU targets.

JAX is a rising star among JIT compilers for scientific computing in Python. It uses XLA to JIT compile and run your NumPy programs. The same code can be deployed on CPUs, GPUs and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX is well suited to large network models and has excellent parallel performance.

However, JAX has intrinsic overhead and is not well suited to running small networks. Moreover, JAX only supports Linux and macOS platforms. Windows users must install JAX on WSL or compile JAX from source. Further, coding in JAX is not very intuitive. For example, JAX

  • Doesn’t support in-place mutating updates of arrays, like x[i] += y; instead, you should use x = jax.ops.index_add(x, i, y) (or x = x.at[i].add(y) in newer JAX releases)

  • Doesn’t support JIT compilation of your native loops and conditions, like

arr = np.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.

instead you should use

import jax
import jax.numpy as jnp

arr = jnp.zeros(5)
def loop_body(i, acc_arr):
    # functional update: returns a new array holding acc_arr[i] + 2.
    arr1 = acc_arr.at[i].add(2.)
    # add 1 more only at even indices
    return jax.lax.cond(i % 2 == 0,
                        lambda a: a.at[i].add(1.),
                        lambda a: a,
                        arr1)
arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)

What’s more, both frameworks have poor support for class objects.

JIT compilation in BrainPy

To obtain an intuitive, flexible and high-performance framework for brain modeling, BrainPy aims to combine the advantages of both compilers and to overcome the gotchas of each framework as much as possible (although this work is not yet finished).

Specifically, the BrainPy math module provides

  • flexible switch between NumPy (Numba) and JAX backends

  • unified numpy-like array operations

  • unified ndarray data structure which supports in-place update

  • unified random APIs

  • powerful jit() compilation which supports both functions and class objects

Backend Switch

To switch between backends, you can use:

# switch to NumPy backend
bm.use_backend('numpy')

bm.get_backend_name()
'numpy'
# switch to JAX backend
bm.use_backend('jax')

bm.get_backend_name()
'jax'

In BrainPy, the “numpy” and “jax” backends are interchangeable. Both backends have nearly the same APIs, and the same code can run on both backends.

Math Operations

The APIs of the brainpy.math module in each backend are very similar to those of the original NumPy. For a detailed comparison, please see the Comparison Table.

For example, the array creation functions,

bm.zeros((10, 3))
JaxArray(DeviceArray([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]], dtype=float32))
bm.arange(10)
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
x = bm.array([[1,2], [3,4]])

x
JaxArray(DeviceArray([[1, 2],
                      [3, 4]], dtype=int32))

The array manipulation functions:

bm.max(x)
DeviceArray(4, dtype=int32)
bm.repeat(x, 2)
JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))
bm.repeat(x, 2, axis=1)
JaxArray(DeviceArray([[1, 1, 2, 2],
                      [3, 3, 4, 4]], dtype=int32))

The random numbers generation functions:

bm.random.random((3, 5))
JaxArray(DeviceArray([[0.77751863, 0.45367956, 0.09705102, 0.35475922, 0.6678729 ],
                      [0.81603515, 0.83256996, 0.29231536, 0.8294617 , 0.41171193],
                      [0.37354684, 0.8089644 , 0.4921714 , 0.93098676, 0.4895897 ]],            dtype=float32))
y = bm.random.normal(loc=0.0, scale=2.0, size=(2, 5))

y
JaxArray(DeviceArray([[ 0.24033605,  1.6801527 ,  1.6716577 ,  0.19926837,
                       -2.3778987 ],
                      [ 0.58184516,  1.3625289 , -2.3364332 ,  2.082281  ,
                        0.23409791]], dtype=float32))

The linear algebra functions:

bm.dot(x, y)
JaxArray(DeviceArray([[ 1.4040264,  4.4052105, -3.0012088,  4.3638306, -1.9097029],
                      [ 3.0483887, 10.490574 , -4.3307595,  8.926929 , -6.1973042]],            dtype=float32))
bm.linalg.eig(x)
(JaxArray(DeviceArray([-0.37228107+0.j,  5.3722816 +0.j], dtype=complex64)),
 JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
                       [ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))

The discrete fourier transform functions:

bm.fft.fft(bm.exp(2j * bm.pi * bm.arange(8) / 8))
JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j,  8.0000000e+00+4.8023384e-07j,
                      -3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
                      -3.8941437e-07-2.0663207e-07j,  2.3841858e-07-1.9411573e-07j,
                       3.8941437e-07-2.0663207e-07j,  1.6858739e-07+3.1786506e-08j],            dtype=complex64))
bm.fft.ifft(bm.array([0, 4, 0, 0]))
JaxArray(DeviceArray([ 1.+0.j,  0.+1.j, -1.+0.j,  0.-1.j], dtype=complex64))

For the full list of implemented APIs, please see the Comparison Table.

JIT for Functions

To take advantage of JIT compilation, users just need to wrap their customized functions or objects in bm.jit() to instruct BrainPy to transform the code into machine code.

Take a pure function as an example. Here we implement the Gaussian Error Linear Unit (GELU) function:

bm.use_backend('jax')


def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y

Let’s first try the running speed without JIT.

# jax backend, without JIT

x = bm.random.random(100000)
%timeit gelu(x)
332 µs ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

After JIT compilation, the function runs significantly faster.

# jax backend, with JIT

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66.6 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

After switching to the ‘numpy’ backend, JIT compilation also speeds up the function.

bm.use_backend('numpy')
# numpy backend, without JIT

x = bm.random.random(100000)
%timeit gelu(x)
3.52 ms ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# numpy backend, with JIT

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
1.91 ms ± 55.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
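
The %timeit command used above is an IPython magic and is not available in a plain Python script. A rough equivalent with the standard-library timeit module is sketched below. To stay self-contained it times a pure-Python scalar version of the same GELU approximation (gelu_scalar and time_call are hypothetical names introduced here), so the absolute numbers are not comparable with the timings above.

```python
import math
import timeit


def gelu_scalar(x):
    """Tanh approximation of GELU for a single float (same formula as gelu above)."""
    sqrt = math.sqrt(2 / math.pi)
    cdf = 0.5 * (1.0 + math.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    return x * cdf


def time_call(func, *args, repeat=5, number=1000):
    """Return the best per-call time in seconds over `repeat` runs."""
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    print(f"gelu_scalar: {time_call(gelu_scalar, 0.5) * 1e6:.3f} µs per call")
```

When timing a JIT-compiled function this way, it is usually sensible to call it once before measuring so that the one-off compilation cost is not included.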

JIT for Objects

Beyond functions, JIT compilation in BrainPy can also be applied to class objects. Specifically, any instance of a brainpy.Base object can be just-in-time compiled into machine code.

Let’s try a simple example, which trains a logistic regression classifier.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters    
        self.dimension = dimension
    
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u
        
num_dim, num_points = 10, 20000000

In this example, the model weights (self.w) are updated repeatedly during training. So, we mark the weights as bm.Variable.

import time

def benckmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')

Let’s try to benchmark the ‘numpy’ backend.

bm.use_backend('numpy')

points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# numpy backend, without JIT

lr1 = LogisticRegression(num_dim)

benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 18.891679763793945 s
# numpy backend, with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 12.473472833633423 s

A similar speedup can be observed with the ‘jax’ backend.

bm.use_backend('jax')

points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# jax backend, without JIT

lr1 = LogisticRegression(num_dim)

benckmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 6.245944976806641 s
# jax backend, with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benckmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 4.174307346343994 s

As you can see, the JIT-compiled object shows its speed advantage.

However, it is worth noting that:

  1. The dynamically changed variable (weight w) is marked as a bm.Variable (in __init__() function).

  2. The variable w is in-place updated with [:] indexing (in __call__() function).

The above two points are the constraints BrainPy imposes on the JIT compilation of class objects. Other operations and coding styles are the same as in normal object-oriented programming in Python.

Mechanism of JIT in NumPy backend
bm.use_backend('numpy')

So, why must we in-place update the dynamically changed variables?

  • First of all, in the compilation phase, a variable accessed through self. that is not an instance of bm.Variable will be compiled as a static constant. For example, self.a = 1. will be compiled as the constant 1.. If you try to change the value of self.a, it will not work.

class Demo1(bp.Base):
    def __init__(self):
        super(Demo1, self).__init__()
        
        self.a = 1.
    
    def update(self, b):
        self.a = b
        

d1 = Demo1()
bm.jit(d1.update)(2.)
print(d1.a)
1.0
  • Second, all the variables you want to change during the function call must be labeled as bm.Variable. During JIT compilation, these variables are then recompiled as arguments of the jitted function.

class Demo2(bp.Base):
    def __init__(self):
        super(Demo2, self).__init__()
        
        self.a = bm.Variable(1.)
    
    def update(self, b):
        self.a = b
        

bm.jit(Demo2().update, show_code=True)
The recompiled function:
-------------------------

def update(b, Demo20_a=None):
    Demo20_a = b


The namespace of the above function:
{}

The recompiled function:
-------------------------
def new_update(b):
  update(b, 
		Demo20_a=Demo20.a.value,)

The namespace of the above function:
{'Demo20': <__main__.Demo2 object at 0x7fa9b055e2e0>,
 'update': CPUDispatcher(<function update at 0x7fa9d01ff160>)}
<function new_update(b)>

The original Demo2.update function is recompiled as the update() function, with the dynamical variable a compiled as an argument Demo20_a. Then, during the function call (in the new_update() function), Demo20.a.value is passed to Demo20_a of the jitted update() function.

  • Third, as you can see in the above source code of the recompiled function, the recompiled variable Demo20_a is not returned. This means that once the function finishes running, the computed value is lost. Therefore, dynamically changed variables must be updated in place to hold their updated values.

class Demo3(bp.Base):
    def __init__(self):
        super(Demo3, self).__init__()
        
        self.a = bm.Variable(1.)
    
    def update(self, b):
        self.a[...] = b
        

d3 = Demo3()
bm.jit(d3.update)(2.)
d3.a
Variable(2.)

The simple demonstrations above illustrate the core mechanism of JIT compilation in the NumPy backend. bm.jit() in the NumPy backend can recursively compile your class objects. So, please try your own models and run them under JIT acceleration.

However, the mechanism of JIT compilation in the JAX backend is quite different. We will detail this in upcoming tutorials.
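
The difference between rebinding a name and updating a buffer in place, which underlies the NumPy-backend mechanism described in this section, can be reproduced without BrainPy. The sketch below is a plain-Python analogy, not BrainPy's implementation: ToyVariable, update_rebind and update_inplace are hypothetical names, and a list stands in for the array held by a bm.Variable.

```python
class ToyVariable:
    """Stand-in for a variable container: holds a mutable buffer in `.value`."""

    def __init__(self, value):
        self.value = [value]


def update_rebind(b, a=None):
    # Rebinds the local name `a`; the caller's buffer is untouched.
    a = [b]


def update_inplace(b, a=None):
    # Writes into the buffer that was passed in; the caller sees the change.
    a[:] = [b]


var = ToyVariable(1.0)

update_rebind(2.0, a=var.value)
print(var.value)   # [1.0]: the new value was lost when the function returned

update_inplace(2.0, a=var.value)
print(var.value)   # [2.0]: the in-place write persisted
```

The first call mirrors Demo2 (assignment to the argument is lost on return), and the second mirrors Demo3 (the write lands in the buffer shared with the caller).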

In-place operators

Next, it is important to answer: what are in-place operators?

v = bm.arange(10)

id(v)
140366785129408

Actually, in-place operators include the following operations:

  1. Indexing and slicing (for more details, please refer to Array Objects Indexing), such as:

  • Index: v[i] = a

  • Slice: v[i:j] = b

  • Slice the specific values: v[[1, 3]] = c

  • Slice all values: v[:] = d, v[...] = e

v[0] = 1

id(v)
140366785129408
v[1: 2] = 1

id(v)
140366785129408
v[[1, 3]] = 2

id(v)
140366785129408
v[:] = 0

id(v)
140366785129408
v[...] = bm.arange(10)

id(v)
140366785129408
  2. Augmented assignment. All augmented assignments are in-place operations, which include

  • += (add)

  • -= (subtract)

  • /= (divide)

  • *= (multiply)

  • //= (floor divide)

  • %= (modulo)

  • **= (power)

  • &= (and)

  • |= (or)

  • ^= (xor)

  • <<= (left shift)

  • >>= (right shift)

v += 1

id(v)
140366785129408
v *= 2

id(v)
140366785129408
v |= bm.random.randint(0, 2, 10)

id(v)
140366785129408
v **= 2.

id(v)
140366785129408
v >>= 2

id(v)
140366785129408
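
The id stays the same under augmented assignment because Python first tries the in-place special method (for example __iadd__ for +=), which may modify the object and return it, whereas v = v + 1 calls __add__ and binds the name to a new object. A minimal illustration with a hypothetical Vec class (not a BrainPy type):

```python
class Vec:
    """Tiny list-backed vector used only to illustrate operator dispatch."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        # Out-of-place: build and return a new object.
        return Vec(x + other for x in self.data)

    def __iadd__(self, other):
        # In-place: modify the existing buffer and return the same object.
        for i in range(len(self.data)):
            self.data[i] += other
        return self


v = Vec([0, 1, 2])
before = id(v)

v += 1                    # dispatches to __iadd__
print(id(v) == before)    # True

v = v + 1                 # dispatches to __add__, then rebinds the name
print(id(v) == before)    # False
```

This is why a statement such as self.w = self.w - u would not count as an in-place update, while self.w -= u or self.w[:] = self.w - u would.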

For more advanced usage, please see our forthcoming tutorials.

Dynamics Introduction

@Chaoming Wang

What I cannot create, I do not understand. — Richard Feynman

The brain is a complex dynamical system. To simulate it, we provide brainpy.DynamicalSystem, which can be used to define any brain object that has dynamics. Various child classes are implemented to model these brain elements, such as brainpy.Channel for neuron channels, brainpy.NeuGroup for neuron groups, brainpy.TwoEndConn for synaptic connections, and brainpy.Network for networks. An arbitrary composition of these objects is also an instance of brainpy.DynamicalSystem. Therefore, brainpy.DynamicalSystem is the universal language for defining dynamical models in BrainPy.

import brainpy as bp
import brainpy.math as bm

brainpy.DynamicalSystem

In this section, let’s try to understand the mechanism and the function of brainpy.DynamicalSystem.

What is DynamicalSystem?

First, what can be defined as a DynamicalSystem?

Intuitively, a dynamical system is a system whose state depends on time.

Mathematically, it can be expressed as

$$ \dot{X} = f(X, t) $$

where $X$ is the state of the system, $t$ is the time, and $f$ is a function describing the time dependence of the system state.

Alternatively, the evolution of the system over time can be given by

$$ X(t+dt) = F\left(X(t), t, dt\right) $$

where $dt$ is the time step, and $F$ is the evolution rule to update the system’s state.

Accordingly, in BrainPy, any subclass of brainpy.DynamicalSystem must implement this updating rule in the update function (def update(self, _t, _dt)). One dynamical system may have multiple updating rules, so users can define multiple update functions. All update functions are collected in an inner data structure, self.steps (a Python dictionary that maps the name of each updating rule to its function). Let’s take a look.

class FitzHughNagumoModel(bp.DynamicalSystem):
    def __init__(self, a=0.8, b=0.7, tau=12.5, **kwargs):
        super(FitzHughNagumoModel, self).__init__(**kwargs)
        
        # parameters
        self.a = a
        self.b = b
        self.tau = tau
        
        # variables
        self.v = bm.Variable([0.])
        self.w = bm.Variable([0.])
        self.I = bm.Variable([0.])
        
    def update(self, _t, _dt):
        # _t : the current time, the system keyword 
        # _dt : the time step, the system keyword 
        
        self.w += (self.v + self.a - self.b * self.w) / self.tau * _dt
        self.v += (self.v - self.v ** 3 / 3 - self.w + self.I) * _dt
        self.I[:] = 0.

Here, we have defined a dynamical system called the FitzHugh–Nagumo neuron model, whose dynamics are given by:

$$ \begin{aligned} \dot{v} &= v - \frac{v^{3}}{3} - w + I, \\ \tau \dot{w} &= v + a - b w. \end{aligned} $$

By using the Euler method, this system can be updated by the following rule:

$$ \begin{aligned} v(t+dt) &= v(t) + \left[v(t) - \frac{v(t)^{3}}{3} - w(t) + I\right] dt, \\ w(t+dt) &= w(t) + \frac{v(t) + a - b w(t)}{\tau} \, dt. \end{aligned} $$
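
A minimal plain-Python sketch of the Euler rule for this model, independent of BrainPy, is given below. The function names fhn_euler_step and simulate_fhn are hypothetical, and the parameter defaults are copied from the class above, including the 1/tau factor on the w equation. Both variables are advanced from the state at time t, whereas the update method above advances w first and then uses the new w when advancing v.

```python
def fhn_euler_step(v, w, I, dt, a=0.8, b=0.7, tau=12.5):
    """Advance the FitzHugh-Nagumo state (v, w) by one Euler step of size dt."""
    dv = v - v ** 3 / 3 - w + I
    dw = (v + a - b * w) / tau
    return v + dv * dt, w + dw * dt


def simulate_fhn(duration, dt=0.01, I=1.5, v0=0.0, w0=0.0):
    """Return lists of time points and v, w values for a constant input I."""
    n_steps = int(round(duration / dt))
    ts, vs, ws = [], [], []
    v, w = v0, w0
    for k in range(n_steps):
        v, w = fhn_euler_step(v, w, I, dt)
        ts.append(k * dt)
        vs.append(v)
        ws.append(w)
    return ts, vs, ws


ts, vs, ws = simulate_fhn(duration=100.0)
print(len(ts), min(vs), max(vs))
```

The returned lists can be passed to matplotlib in the same way as the monitor arrays used later in this section.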

We can inspect all update functions in the model by xxx.steps.

fnh = FitzHughNagumoModel()

fnh.steps  # all update functions
{'update': <bound method FitzHughNagumoModel.update of <__main__.FitzHughNagumoModel object at 0x7f7f4c7ceaf0>>}

Why use DynamicalSystem?

So, why should I define my dynamical system as a brainpy.DynamicalSystem?

There are several benefits.

  • brainpy.DynamicalSystem has a systematic naming system.

First, every instance of DynamicalSystem has a unique name.

fnh.name  # name for "fnh" instance
'FitzHughNagumoModel0'
# every instance has its unique name

for _ in range(5):
    print(FitzHughNagumoModel().name)
FitzHughNagumoModel1
FitzHughNagumoModel2
FitzHughNagumoModel3
FitzHughNagumoModel4
FitzHughNagumoModel5
# the model name can be specified by yourself

fnh2 = FitzHughNagumoModel(name='X')

fnh2.name
'X'
# same name will cause error

try:
    FitzHughNagumoModel(name='X')
except bp.errors.UniqueNameError as e:
    print(e)
In BrainPy, each object should have a unique name. However, we detect that <__main__.FitzHughNagumoModel object at 0x7f7f4c7f3850> has a used name "X".

Second, variables, children nodes, etc. inside an instance can be easily accessed by the absolute or relative path.

# All variables can be accessed by 
# 1). the absolute path

fnh2.vars()
{'X.I': Variable([0.]), 'X.v': Variable([0.]), 'X.w': Variable([0.])}
# 2). or, the relative path

fnh2.vars(method='relative')
{'I': Variable([0.]), 'v': Variable([0.]), 'w': Variable([0.])}

If we wrap many instances into a container: brainpy.Network, variables and nodes can also be accessed by absolute or relative path.

fnh_net = bp.Network(f1=fnh, f2=fnh2)
# absolute access of variables

fnh_net.vars()
{'FitzHughNagumoModel0.I': Variable([0.]),
 'FitzHughNagumoModel0.v': Variable([0.]),
 'FitzHughNagumoModel0.w': Variable([0.]),
 'X.I': Variable([0.]),
 'X.v': Variable([0.]),
 'X.w': Variable([0.])}
# relative access of variables

fnh_net.vars(method='relative')
{'f1.I': Variable([0.]),
 'f1.v': Variable([0.]),
 'f1.w': Variable([0.]),
 'f2.I': Variable([0.]),
 'f2.v': Variable([0.]),
 'f2.w': Variable([0.])}
# absolute access of nodes

fnh_net.nodes()
{'FitzHughNagumoModel0': <__main__.FitzHughNagumoModel at 0x7f7f4c7ceaf0>,
 'X': <__main__.FitzHughNagumoModel at 0x7f7f4c7f3700>,
 'Network0': <brainpy.simulation.brainobjects.network.Network at 0x7f7f4c7f3340>}
# relative access of nodes

fnh_net.nodes(method='relative')
{'': <brainpy.simulation.brainobjects.network.Network at 0x7f7f4c7f3340>,
 'f1': <__main__.FitzHughNagumoModel at 0x7f7f4c7ceaf0>,
 'f2': <__main__.FitzHughNagumoModel at 0x7f7f4c7f3700>}
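
How such a naming scheme can work is sketched below in plain Python. This is an illustration only, not BrainPy's implementation: the ToyNode class, its vars method, and the use of ValueError for duplicate names are all assumptions made for the example.

```python
class ToyNode:
    """Toy object with a globally unique name and named children."""

    _used_names = set()
    _counters = {}

    def __init__(self, name=None, variables=None, **children):
        cls_name = type(self).__name__
        if name is None:
            # Auto-generate "<ClassName><index>", e.g. "ToyNode0".
            index = ToyNode._counters.get(cls_name, 0)
            ToyNode._counters[cls_name] = index + 1
            name = f"{cls_name}{index}"
        if name in ToyNode._used_names:
            raise ValueError(f'name "{name}" is already used')
        ToyNode._used_names.add(name)
        self.name = name
        self.variables = dict(variables or {})
        self.children = children

    def vars(self, method="absolute", _prefix=""):
        """Collect variables keyed by absolute (object name) or relative path."""
        if method == "absolute":
            found = {f"{self.name}.{k}": v for k, v in self.variables.items()}
        else:
            found = {f"{_prefix}{k}": v for k, v in self.variables.items()}
        for key, child in self.children.items():
            found.update(child.vars(method, _prefix=f"{_prefix}{key}."))
        return found


n1 = ToyNode(variables={"v": 0.0})
n2 = ToyNode(name="X", variables={"v": 0.0})
net = ToyNode(f1=n1, f2=n2)

print(sorted(net.vars()))                   # ['ToyNode0.v', 'X.v']
print(sorted(net.vars(method="relative")))  # ['f1.v', 'f2.v']
```

In this sketch an absolute path is built from the unique object name, so it is valid anywhere, while a relative path is built from the keyword under which a child was registered in its container.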
  • Automatic monitors. Any instance of brainpy.DynamicalSystem can call .run(duration). During the run, a brainpy.Monitor inside the dynamical system (xxx.mon) can automatically record the history of the variables of interest. For details, please see the Monitors and Inputs tutorial.

# in "fnh3" instance, we try to monitor "v", "w", and "I" variables
fnh3 = FitzHughNagumoModel(monitors=['v', 'w', 'I'])

# in "fnh4" instance, we only monitor "v" variable
fnh4 = FitzHughNagumoModel(monitors=['v'], name='Y')
  • Convenient input operations. While the model is running, users can specify the inputs for each model component in the format (target, value, [type, operation]) (for details, please see the Monitors and Inputs tutorial: ../tutorial_simulation/monitors_and_inputs.ipynb).

    • The target is the variable accessed by the absolute or relative path. Absolute path access will be very useful in a huge network model.

    • The default input type is “fix”, which means the value must be a constant scalar or array over time. The “iter” input type is also allowed, which means the value can be an iterable object (an array, an iterable function, etc.).

    • The default operation is +, which means the input value will be added to the target. Allowed operations include +, -, *, /, and =.
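
The defaults described in these bullets can be made concrete with a small sketch. The helpers normalize_input and apply_input below are hypothetical and only mirror the (target, value, [type, operation]) convention; they are not BrainPy functions.

```python
import operator

# Operations named in the text, mapped to Python functions.
# '=' replaces the target value with the input value.
OPERATIONS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
    "=": lambda current, value: value,
}


def normalize_input(spec):
    """Expand (target, value[, type[, operation]]) to a full 4-tuple."""
    if not 2 <= len(spec) <= 4:
        raise ValueError("expected (target, value, [type, operation])")
    target, value = spec[0], spec[1]
    kind = spec[2] if len(spec) > 2 else "fix"
    op = spec[3] if len(spec) > 3 else "+"
    if kind not in ("fix", "iter"):
        raise ValueError(f"unknown input type: {kind!r}")
    if op not in OPERATIONS:
        raise ValueError(f"unknown operation: {op!r}")
    return target, value, kind, op


def apply_input(current, spec, step):
    """Return the target's new value at simulation step `step`."""
    _, value, kind, op = normalize_input(spec)
    # A "fix" input is the same at every step; an "iter" input is indexed
    # by step here (a real "iter" input could be any iterable).
    step_value = value[step] if kind == "iter" else value
    return OPERATIONS[op](current, step_value)


print(normalize_input(("I", 1.5)))                          # ('I', 1.5, 'fix', '+')
print(apply_input(0.0, ("I", [1.0, 2.0, 3.0], "iter"), 1))  # 2.0
print(apply_input(4.0, ("I", 2.0, "fix", "/"), 0))          # 2.0
```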

%matplotlib inline

import matplotlib.pyplot as plt
bm.set_dt(dt=0.01)

fnh3.run(duration=100, 
         # relative path to access variable 'I'
         inputs=('I', 1.5))

plt.plot(fnh3.mon.ts, fnh3.mon.v, label='v')
plt.plot(fnh3.mon.ts, fnh3.mon.w, label='w')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c761e80>
[Figure: time courses of v and w for fnh3 with constant input I = 1.5]
inputs = bm.linspace(1., 2., 10000)

fnh4.run(duration=100, 
         inputs=('Y.I',     #  specify 'target' by the absolute path access
                 inputs,    #  specify 'value' with an iterable array
                 'iter'))   #  "iter" input 'type' must be explicitly specified

plt.plot(fnh4.mon.ts, fnh4.mon.v, label='v')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c7f3b80>
[Figure: time course of v for fnh4 with an input array ramping from 1 to 2]
def inputs():
    for i in range(10000): 
        yield 1.5

fnh4.run(duration=100, 
         inputs=('Y.I',     # specify 'target' by the absolute path access
                 inputs(),    # specify 'value' with an iterable function
                 'iter'))   # "iter" input 'type' must be explicitly specified

plt.plot(fnh4.mon.ts, fnh4.mon.v, label='v')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f4c082d60>
[Figure: time course of v for fnh4 with a generator input of constant value 1.5]
  • brainpy.DynamicalSystem is a subclass of brainpy.Base; therefore, any instance of brainpy.DynamicalSystem can be just-in-time compiled into efficient machine code targeting CPUs, GPUs, or TPUs.

fnh3_jit = bm.jit(fnh3)

fnh3_jit.run(duration=100, inputs=('I', 1.5))

plt.plot(fnh3_jit.mon.ts, fnh3_jit.mon.v, label='v')
plt.plot(fnh3_jit.mon.ts, fnh3_jit.mon.w, label='w')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f2c595b50>
[Figure: time courses of v and w for the JIT-compiled fnh3]
  • brainpy.DynamicalSystem can be combined arbitrarily. Any composed system can also benefit from the above convenient interfaces.

# compose two FitzHughNagumoModel instances into a Network
net2 = bp.Network(f1=fnh3, f2=fnh4, monitors=['f1.v', 'Y.v'])

net2.run(100, inputs=[
    ('f1.I', 1.5), # relative access variable "I" in 'fnh3'
    ('Y.I', 1.0), # absolute access variable "I" in 'fnh4'
])

plt.plot(net2.mon.ts, net2.mon['f1.v'], label='v1')
plt.plot(net2.mon.ts, net2.mon['Y.v'], label='v2')
plt.legend()
<matplotlib.legend.Legend at 0x7f7f2c2c6ee0>
[Figure: time courses of v for the two models in net2]

In the next sections, we will illustrate how to define common brain objects (specifically, the neuron, the synapse and the network) by subclassing brainpy.DynamicalSystem.

brainpy.NeuGroup

brainpy.NeuGroup is used for modeling neuron groups. User-defined neuron group models should inherit from brainpy.NeuGroup. Let’s take the leaky integrate-and-fire (LIF) model as an illustrative example.

LIF neuron model

The formal equations of a LIF model are given by:

$$ \begin{aligned} \tau_m \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \quad\quad (1) \\ \text{after} \, V(t) \gt V_{th}, V(t) =V_{rest} \, \text{last} \, \tau_{ref} \, \text{ms} \quad\quad (2) \end{aligned} $$

where $V$ is the membrane potential, $V_{rest}$ is the resting membrane potential, $V_{th}$ is the spike threshold, $\tau_m$ is the time constant, $\tau_{ref}$ is the refractory time period, and $I$ is the time-varying synaptic input.

The above two equations mean that when the membrane potential $V$ is below $V_{th}$, the model integrates $V$ according to equation (1); once $V > V_{th}$, according to equation (2), the membrane potential is reset to $V_{rest}$ and the model enters the refractory period, which lasts $\tau_{ref}$ ms. During the refractory period, the membrane potential $V$ no longer changes.

Let’s start to code this LIF neuron model. First, we will define the following items to store the neuron state:

  • V: The membrane potential.

  • input: The synaptic input.

  • spike: Whether the neuron produces a spike.

  • refractory: Whether the neuron is in the refractory state.

  • t_last_spike: The time of the last spike.

Based on these states, the updating logic of LIF model from the current time $t$ to the next time $t+dt$ will be coded as:

class LIF(bp.NeuGroup):
  def __init__(self, size, t_refractory=1., V_rest=0., V_reset=-5.,
               V_th=20., R=1., tau=10., **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)

    # parameters
    self.V_rest = V_rest
    self.V_reset = V_reset
    self.V_th = V_th
    self.R = R
    self.tau = tau
    self.t_refractory = t_refractory

    # variables
    self.V = bm.Variable(bm.random.randn(self.num) * 5. + V_reset)
    self.input = bm.Variable(bm.zeros(self.num))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    
    # functions
    self.integral = bp.odeint(f=self.derivative, method='exponential_euler')

  def derivative(self, V, t, Iext):
    dvdt = (- (V - self.V_rest) + self.R * Iext) / self.tau
    return dvdt

  def update(self, _t, _dt):
    for i in range(self.num):
      if _t - self.t_last_spike[i] <= self.t_refractory:
        self.refractory[i] = True
        self.spike[i] = False
      else:
        V = self.integral(self.V[i], _t, self.input[i])
        if V >= self.V_th:
          self.V[i] = self.V_reset
          self.t_last_spike[i] = _t
          self.spike[i] = True
          self.refractory[i] = True
        else:
          self.V[i] = V
          self.spike[i] = False
          self.refractory[i] = False
      self.input[i] = 0.

That’s all: we have coded a LIF neuron model. Note that here we define equation (1) with brainpy.odeint as an ODEIntegrator. We will illustrate how to define ODE numerical integration in the Numerical Solvers for ODEs tutorial.
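
For a linear equation such as (1), the exponential Euler step has a closed form, which the sketch below uses for a single neuron in plain Python. It is an independent illustration under stated assumptions, not the code that brainpy.odeint generates: lif_exp_euler_step and simulate_lif are hypothetical names, the input is held constant within a step, and the parameter defaults are copied from the LIF class above.

```python
import math


def lif_exp_euler_step(V, I, dt, V_rest=0.0, R=1.0, tau=10.0):
    """One exponential Euler step of tau * dV/dt = -(V - V_rest) + R * I."""
    V_inf = V_rest + R * I            # steady state for a constant input I
    return V_inf + (V - V_inf) * math.exp(-dt / tau)


def simulate_lif(duration, I, dt=0.1, V0=0.0, V_th=20.0, V_reset=-5.0,
                 t_refractory=1.0):
    """Simulate one neuron and return the list of spike times (in ms)."""
    V = V0
    t_last_spike = -1e7
    spike_times = []
    for k in range(int(round(duration / dt))):
        t = k * dt
        if t - t_last_spike <= t_refractory:
            continue                  # refractory: V is held fixed
        V_new = lif_exp_euler_step(V, I, dt)
        if V_new >= V_th:
            V = V_reset               # spike and reset
            t_last_spike = t
            spike_times.append(t)
        else:
            V = V_new
    return spike_times


spikes = simulate_lif(duration=200.0, I=26.0)
print(len(spikes), spikes[:3])
```

Because the step is exact for a constant input, its accuracy does not degrade as dt grows, which is one reason exponential Euler is a common choice for leaky membrane equations.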

Each NeuGroup has a powerful function, .run(), which receives the following arguments:

  • duration: Specifies the simulation duration. It can be a tuple of (start time, end time), or a number specifying the duration length (in which case the default start time is 0).

  • inputs: Specifies the inputs for each model component, in the format (target, value, [type, operation]). For details, please see the Monitors and Inputs tutorial.

  • report: A float specifying the progress interval at which to report (for example, 0.5 reports at every 50%). “0” (the default) means the running progress is not reported.

Now, let’s run the defined model.

group = LIF(100, monitors=['V'])
group.run(duration=200., inputs=('input', 26.), report=0.5)
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
Compilation used 0.0004 s.
Start running ...
Run 50.0% used 3.319 s.
Run 100.0% used 6.479 s.
Simulation is done in 6.480 s.
_images/b594c4d6bbed11496f930f3d8169e04c8d9b91174008c9f7b2b5065fdd8a327c.png
group.run(duration=(200, 400.), report=0.2)
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
Compilation used 0.0003 s.
Start running ...
Run 20.0% used 1.364 s.
Run 40.0% used 2.711 s.
Run 60.0% used 4.024 s.
Run 80.0% used 5.299 s.
Run 100.0% used 6.589 s.
Simulation is done in 6.590 s.
_images/5e1d11e7263074f68e0d8d444dbe16d841ce6ed38198e2505372764a2213c515.png

In the model definition, BrainPy gives you full control over the data and logic flow. You can define models with any data you need and any logic you want. There are only two constraints on your customization:

  1. you should call “super()” to initialize brainpy.NeuGroup, passing the group size as a keyword argument.

  2. you should define the update function.

brainpy.TwoEndConn

For synaptic computations, BrainPy provides brainpy.TwoEndConn to help you construct the connections between pre-synaptic and post-synaptic neuron groups, and provides brainpy.connect.TwoEndConnector for synaptic projections between pre- and post-synaptic groups.

brainpy.TwoEndConn can help to construct automatic delays in synaptic computations. The modeling of synapses usually includes a delay time (typically 0.3–0.5 ms) required for a neurotransmitter to be released from the presynaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane. BrainPy provides register_constant_delay() for automatic state delay.
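A constant delay can be pictured as a fixed-length first-in, first-out buffer. The sketch below is a hypothetical illustration of that idea in plain Python; the class ConstantDelay and its arguments are assumptions and do not describe how register_constant_delay() is implemented.

```python
from collections import deque

class ConstantDelay:
    """Fixed-length FIFO delay line (illustrative, not BrainPy's class)."""

    def __init__(self, delay, dt, initial=0.):
        # Number of whole time steps covered by the delay.
        self.num_step = int(round(delay / dt))
        # Pre-fill so that early pulls return the initial value.
        self.buffer = deque([initial] * (self.num_step + 1),
                            maxlen=self.num_step + 1)

    def push(self, value):
        # Appending to a full deque drops the oldest entry automatically.
        self.buffer.append(value)

    def pull(self):
        # The oldest entry is the value pushed `num_step` steps ago.
        return self.buffer[0]

line = ConstantDelay(delay=0.3, dt=0.1, initial=-1)
pulled = []
for step in range(6):
    line.push(step)              # value produced at this step
    pulled.append(line.pull())   # value seen after the delay
```

With a delay of three steps, the first three pulls return the initial value and later pulls return the value pushed three steps earlier. With a delay of zero, a pull returns the value that was just pushed.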

brainpy.connect.TwoEndConnector provides a convenient interface for constructing connectivity structures. Various synaptic structures, like pre_ids, post_ids, conn_mat, pre2post, post2pre, pre2syn, post2syn, pre_slice, and post_slice, can be constructed. Users just need to request such data structures by calling connector.requires('pre_ids', 'post_ids', ...), as in the example below. We will detail this function in Efficient Synaptic Connections.

Here, let’s illustrate how to use brainpy.TwoEndConn with the Exponential synapse model.

Exponential synapse model

The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state rises instantaneously and then decays with a time constant $\tau_{decay}$. Its dynamics are given by:

$$ \frac{d s}{d t} = -\frac{s}{\tau_{decay}}+\sum_{k} \delta(t-D-t^{k}) $$

where $s$ is the synaptic state, $t^{k}$ is the spike time of the pre-synaptic neuron, and $D$ is the synaptic delay.

Afterward, the current output onto the post-synaptic neuron is given in the conductance-based form

$$ I_{syn}(t) = g_{max} s \left( V(t)-E \right) $$

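where $E$ is the reversal potential of the synapse, $V$ is the post-synaptic membrane potential, and $g_{max}$ is the maximum synaptic conductance. With this sign convention the current enters the post-synaptic membrane equation with a negative sign, which is why the implementation below adds $g_{max} s (E - V)$ to the post-synaptic input.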
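To make the dynamics concrete, here is a minimal plain-Python sketch of a single synapse. The helper simulate_exponential_synapse and the chosen values (delay of 2 ms, post-synaptic potential of -60 mV) are assumptions for illustration only; the decay uses the exact solution of the linear ODE over one step.

```python
import math

def simulate_exponential_synapse(spike_times, delay=2., tau=8., dt=0.1,
                                 duration=20.):
    """Return time points and the synaptic state s(t) (hypothetical helper)."""
    decay = math.exp(-dt / tau)           # exact decay over one step
    # Steps at which a delayed pre-synaptic spike arrives at the synapse.
    arrivals = {int(round((t_k + delay) / dt)) for t_k in spike_times}
    ts, trace, s = [], [], 0.
    for step in range(int(round(duration / dt))):
        s *= decay                         # ds/dt = -s / tau
        if step in arrivals:
            s += 1.                        # delta-function input
        ts.append(step * dt)
        trace.append(s)
    return ts, trace

ts, s_trace = simulate_exponential_synapse(spike_times=[1.])

# Conductance-based current, following the formula as written in the text.
g_max, E, V_post = 1., 0., -60.
I_syn = [g_max * s * (V_post - E) for s in s_trace]
```

The state stays at zero until the delayed spike arrives at t = 3 ms, jumps to 1, and has decayed to 1/e one time constant (8 ms) later.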

So, let’s try to implement this synapse model.

class Exponential(bp.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., **kwargs):
    super(Exponential, self).__init__(pre=pre, post=post, conn=conn, **kwargs)

    # parameters
    self.g_max = g_max
    self.E = E
    self.tau = tau
    self.delay = delay

    # connections
    self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
    self.num = len(self.pre_ids)
    
    # variables
    self.s = bm.Variable(bm.zeros(self.num))
    self.pre_spike = self.register_constant_delay('ps', size=self.pre.num, delay=delay)
    
    # functions
    self.integral = bp.odeint(self.derivative, method='exponential_euler')

  def derivative(self, s, t):
    dsdt = - s / self.tau
    return dsdt

  def update(self, _t, _dt):
    # P1: push the pre-synaptic spikes into the delay
    self.pre_spike.push(self.pre.spike)
    
    # P2: pull the delayed pre-synaptic spikes
    delayed_pre_spike = self.pre_spike.pull()
    
    # P3: update the synaptic state
    self.s[:] = self.integral(self.s, _t)
    
    for syn_i in range(self.num):
      pre_i, post_i = self.pre_ids[syn_i], self.post_ids[syn_i]
    
      # P4: whether pre-synaptic neuron generates a spike
      if delayed_pre_spike[pre_i]:
        self.s[syn_i] += 1.
      
      # P5: output the synapse current onto the post-synaptic neuron
      self.post.input[post_i] += self.g_max * self.s[syn_i] * (self.E - self.post.V[post_i])

Here, we create a synaptic model by using the synaptic structures pre_ids and post_ids, which look like this:

The pre-synaptic neuron index (pre ids) is shown in the green color. The post-synaptic neuron index (post ids) is shown in the red color. Each pair of (pre id, post id) denotes a synapse between two neuron groups. Each synapse connection also has a unique index, called the synapse index, which is shown in the third row (syn ids).
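The relation between a connection matrix and the (pre id, post id) pairs can be sketched with NumPy. This is an illustration under assumed data, not BrainPy's connector code: the 3 x 3 matrix and the per-synapse values are made up for the example.

```python
import numpy as np

# Connection matrix: rows are pre-synaptic, columns are post-synaptic neurons.
conn_mat = np.array([[0, 1, 1],
                     [1, 0, 0],
                     [0, 1, 0]], dtype=bool)

# One entry per synapse; the position in these arrays is the synapse index.
pre_ids, post_ids = np.nonzero(conn_mat)

# Accumulate per-synapse values onto post-synaptic neurons without a loop.
# np.add.at handles repeated indices, i.e. several synapses onto one neuron.
syn_value = np.array([0.5, 1.0, 2.0, 4.0])
post_input = np.zeros(conn_mat.shape[1])
np.add.at(post_input, post_ids, syn_value)
```

Here synapses 0 and 3 both target neuron 1, so their values are summed, which mirrors the `+=` on self.post.input in the synapse loop.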

brainpy.Network

Above, we have illustrated how to define neurons with brainpy.NeuGroup and synapses with brainpy.TwoEndConn. Next, we show how to create a network by using brainpy.Network.

E/I balanced network

Here, we try to create an E/I balanced network according to reference [1].

This EI network has 4000 leaky integrate-and-fire neurons. Each integrate-and-fire neuron is characterized by a time constant, $\tau$ = 20 ms, and a resting membrane potential, $V_{rest}$ = -60 mV. Whenever the membrane potential crosses a spiking threshold of -50 mV, an action potential is generated and the membrane potential is reset to the resting potential, where it remains clamped for a 5 ms refractory period.

num_exc = 3200
num_inh = 800

E = LIF(num_exc, tau=20, V_th=-50, V_rest=-60, V_reset=-60, t_refractory=5., monitors=['spike'])
I = LIF(num_inh, tau=20, V_th=-50, V_rest=-60, V_reset=-60, t_refractory=5.)
E.V[:] = bm.random.randn(num_exc) * 5. - 55.
I.V[:] = bm.random.randn(num_inh) * 5. - 55.

The ratio of excitatory to inhibitory neurons is 4:1. The neurons connect to each other randomly with a connection probability of 2%.
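A connection probability of 2% fixes the expected number of inputs per neuron. The NumPy sketch below draws random connection matrices to show this; it is an assumption-based illustration and does not reproduce how bp.connect.FixedProb samples connections.

```python
import numpy as np

rng = np.random.default_rng(0)
num_exc, num_inh, prob = 3200, 800, 0.02

# Each potential E->E connection exists independently with probability `prob`.
e2e = rng.random((num_exc, num_exc)) < prob
# Same for I->E: rows are inhibitory sources, columns are excitatory targets.
i2e = rng.random((num_inh, num_exc)) < prob

# Average number of inputs received by one excitatory neuron.
mean_exc_inputs = e2e.sum(axis=0).mean()   # expectation: 0.02 * 3200 = 64
mean_inh_inputs = i2e.sum(axis=0).mean()   # expectation: 0.02 * 800 = 16
```

On average, each excitatory neuron therefore receives about 64 excitatory and 16 inhibitory synapses.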

The kinetics of the synapses are governed by the exponential synapse model shown above. Specifically, the synaptic time constants are $\tau_e$ = 5 ms for excitatory synapses and $\tau_i$ = 10 ms for inhibitory synapses. The maximum synaptic conductance is $0.6$ for excitatory synapses and $6.7$ for inhibitory synapses. Reversal potentials are $E_e$ = 0 mV and $E_i$ = -80 mV.

E2E = Exponential(E, E, bp.connect.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5)
E2I = Exponential(E, I, bp.connect.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5)
I2E = Exponential(I, E, bp.connect.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10)
I2I = Exponential(I, I, bp.connect.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10)

After this, we can create a network to wrap these objects together.

net = bp.Network(E2E, E2I, I2I, I2E, E=E, I=I)
net = bm.jit(net)
net.run(100., inputs=[('E.input', 20.), ('I.input', 20.)], report=0.1)
bp.visualize.raster_plot(E.mon.ts, E.mon.spike, show=True)
Compilation used 4.8391 s.
Start running ...
Run 10.0% used 1.485 s.
Run 20.0% used 2.806 s.
Run 30.0% used 4.120 s.
Run 40.0% used 5.506 s.
Run 50.0% used 6.886 s.
Run 60.0% used 8.214 s.
Run 70.0% used 9.513 s.
Run 80.0% used 10.834 s.
Run 90.0% used 12.197 s.
Run 100.0% used 13.507 s.
Simulation is done in 13.507 s.
_images/b2b1da65e317167ece9aa68425d5e955be99623c766e1404fe426f435a9b4456.png

References:

[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98

Tensors

In this section, we are going to understand:

  • what is a tensor?

  • how to create a tensor?

  • what operations are supported on a tensor?

import brainpy.math as bm

What is a tensor?

A tensor is a homogeneous multidimensional array. It is a table of elements (usually numbers), all of the same type, indexed by a tuple of non-negative integers. The dimensions of an array are called axes.

In the following picture, the 1D array ([7, 2, 9, 10]) has only one axis. That axis has 4 elements in it, so we say it has a shape of (4,).

The 2D array

[[5.2, 3.0, 4.5], 
 [9.1, 0.1, 0.3]]

has 2 axes. The first axis has a length of 2, the second axis has a length of 3. So, we say it has a shape of (2, 3).

Similarly, the 3D array has 3 axes, and its shape is (4, 3, 2).
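The three cases can be reproduced with plain NumPy arrays, assuming NumPy semantics for shape, ndim, and size. The contents of the 3D array are not given in the text, so zeros are used as a placeholder.

```python
import numpy as np

a1 = np.array([7, 2, 9, 10])          # one axis
a2 = np.array([[5.2, 3.0, 4.5],
               [9.1, 0.1, 0.3]])      # two axes
a3 = np.zeros((4, 3, 2))              # three axes (placeholder values)

shapes = [a.shape for a in (a1, a2, a3)]
ndims = [a.ndim for a in (a1, a2, a3)]
# The total number of elements is the product of the shape entries.
sizes = [a.size for a in (a1, a2, a3)]
```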

Each tensor has several important attributes:

  • .ndim: the number of axes (dimensions) of the tensor.

  • .shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

  • .size: the total number of elements of the tensor. This is equal to the product of the elements of shape.

  • .dtype: an object describing the type of the elements in the tensor. One can create or specify a dtype using standard Python types.

In the ‘numpy’ backend, a tensor is exactly the same as an array in NumPy. For example:

bm.use_backend('numpy')

a = bm.arange(15).reshape((3, 5))

a
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int64')

In the ‘jax’ backend, however, we wrap the original jax.numpy.ndarray and create a new data structure, JaxArray. Its attributes and operations are the same as those of NumPy tensors. For example:

bm.use_backend('jax')

a = bm.arange(15).reshape((3, 5))

a
JaxArray(DeviceArray([[ 0,  1,  2,  3,  4],
                      [ 5,  6,  7,  8,  9],
                      [10, 11, 12, 13, 14]], dtype=int32))
a.shape
(3, 5)
a.ndim
2
a.dtype
dtype('int32')

How to create a tensor?

There are several ways to create tensors. The methods for tensor creation are the same under the “numpy” and “jax” backends.

array(), zeros() and ones()

The basic method is to convert Python sequences into tensors by bm.array(). For example:

bm.array([2, 3, 4])
JaxArray(DeviceArray([2, 3, 4], dtype=int32))
bm.array([(1.5, 2, 3), (4, 5, 6)])
JaxArray(DeviceArray([[1.5, 2. , 3. ],
                      [4. , 5. , 6. ]], dtype=float32))

Often, the elements of an array are originally unknown, but its size is known. Therefore, you can use placeholder functions to create tensors, like:

# "bm.zeros()" creates an array full of zeros

bm.zeros((3, 4))
JaxArray(DeviceArray([[0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.]], dtype=float32))
# "bm.ones()" creates an array full of ones

bm.ones((3, 4))
JaxArray(DeviceArray([[1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.]], dtype=float32))

linspace() and arange()

Another two commonly used 1D array creation functions are bm.linspace() and bm.arange().

bm.arange() creates arrays with regularly incrementing values. It receives “start”, “end”, and “step” settings.

# if only one argument "A" is provided, the function uses
# "start = 0", "end = A", and "step = 1".

bm.arange(10)  
JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
# if two arguments "A, B" are provided, the function uses
# "start = A", "end = B", and "step = 1".

bm.arange(2, 10, dtype=float)
JaxArray(DeviceArray([2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32))
# if three arguments "A, B, C" are provided, the function uses
# "start = A", "end = B", and "step = C".

bm.arange(2, 3, 0.1)
JaxArray(DeviceArray([2.       , 2.1      , 2.1999998, 2.2999997, 2.3999996,
                      2.4999995, 2.5999994, 2.6999993, 2.7999992, 2.8999991],            dtype=float32))

Due to the finite floating point precision, it is generally not possible to predict the number of elements obtained by bm.arange(). For this reason, it is usually better to use the function bm.linspace() that receives “start”, “end”, and “num” settings.

bm.linspace(2, 3, 10)
JaxArray(DeviceArray([2.       , 2.1111112, 2.2222223, 2.3333333, 2.4444447,
                      2.5555556, 2.6666665, 2.777778 , 2.8888888, 3.       ],            dtype=float32))
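The length issue with a fractional step can be seen in NumPy, whose arange and linspace follow the same conventions. This is a small sketch under that assumption; the exact length returned by arange depends on floating-point rounding, so it is not asserted here.

```python
import numpy as np

# With a fractional step, rounding error in (stop - start) / step can push
# the computed length past the value one would expect (here 3), so the
# result may unexpectedly include a value close to the stop point.
a = np.arange(1, 1.3, 0.1)

# linspace takes the number of points directly, so the length is fixed
# and both end points are included.
b = np.linspace(1, 1.3, 4)
```

Asking for the number of points, rather than the step, removes the ambiguity.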

Random sampling

The brainpy.math module provides convenient random sampling functions. It contains simple random data generation methods, permutation and distribution functions, and random generator functions. Here are several examples.

  • brainpy.math.random.rand(d0, d1, ..., dn)

This function generates random values of a given shape, drawn from a uniform distribution over [0, 1).

bm.random.rand(5, 2)
JaxArray(DeviceArray([[0.99398685, 0.39656162],
                      [0.5161425 , 0.81978667],
                      [0.31676686, 0.083583  ],
                      [0.16560888, 0.40949285],
                      [0.43086028, 0.22965682]], dtype=float32))
  • brainpy.math.random.randn(d0, d1, ..., dn)

This function returns samples from the “standard normal” distribution.

bm.random.randn(5, 2)
JaxArray(DeviceArray([[-0.7701253 ,  0.00965391],
                      [-0.11301948,  0.1409633 ],
                      [-0.11914475,  0.068143  ],
                      [ 1.6409276 ,  1.3378068 ],
                      [ 1.8202178 , -0.37819335]], dtype=float32))
  • brainpy.math.random.randint(low, high[, size, dtype])

This function generates random integers from low (inclusive) to high (exclusive).

bm.random.randint(0, 3, size=10)  
JaxArray(DeviceArray([0, 1, 1, 2, 0, 1, 0, 2, 0, 2], dtype=int32))
  • brainpy.math.random.random([size])

This function generates random floats in the half-open interval [0.0, 1.0).

bm.random.random((3, 2))
JaxArray(DeviceArray([[0.76483357, 0.559957  ],
                      [0.50227726, 0.41693842],
                      [0.65068877, 0.8199152 ]], dtype=float32))
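The value ranges described above can be checked empirically. The sketch below uses NumPy's Generator API as a stand-in (an assumption: it is a different interface from brainpy.math.random, but it follows the same interval conventions).

```python
import numpy as np

rng = np.random.default_rng(42)

u = rng.random((1000,))             # uniform floats in [0.0, 1.0)
k = rng.integers(0, 3, size=1000)   # integers in {0, 1, 2}; 3 is excluded
z = rng.standard_normal(1000)       # standard normal: mean 0, variance 1

u_in_range = bool((u >= 0.0).all() and (u < 1.0).all())
k_values = set(k.tolist())
```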

brainpy.math module also provides permutation functions.

  • brainpy.math.random.shuffle()

This function shuffles the contents of a sequence. As the output below shows, under the ‘jax’ backend the shuffled tensor is returned.

bm.random.shuffle( bm.arange(10) )  
JaxArray(DeviceArray([8, 4, 9, 5, 7, 0, 3, 6, 1, 2], dtype=int32))
  • brainpy.math.random.permutation()

This function randomly permutes a sequence, or returns a permuted range.

bm.random.permutation( bm.arange(10) )  
JaxArray(DeviceArray([2, 1, 0, 9, 5, 4, 7, 6, 8, 3], dtype=int32))

brainpy.math module also provides functions to sample distributions.

  • beta(a, b[, size])

This function is used to draw samples from a Beta distribution.

bm.random.beta(2, 3, 10) 
JaxArray(DeviceArray([0.48173192, 0.09183226, 0.5617174 , 0.4964077 , 0.5717186 ,
                      0.60861576, 0.3472139 , 0.58446443, 0.41256   , 0.07920451],            dtype=float32))
  • exponential([scale, size])

This function is used to draw samples from an exponential distribution.

bm.random.exponential(1, 10) 
JaxArray(DeviceArray([1.5618182 , 0.18306465, 1.0619484 , 1.2519189 , 0.6019476 ,
                      1.0401233 , 0.37211612, 0.06336975, 3.796705  , 0.03766083],            dtype=float32))

For more sampling methods, please see random sampling functions.

And more

Moreover, there are many other methods we can use to create tensors, including:

  • Conversion from other Python structures (i.e. lists and tuples)

  • Intrinsic NumPy array creation functions (e.g. arange, ones, zeros, etc.)

  • Use of special library functions (e.g., random)

  • Replicating, joining, or mutating existing tensors

  • Reading arrays from disk, either from standard or custom formats

  • Creating arrays from raw bytes through the use of strings or buffers

For details of these methods, please see the NumPy tutorial: Array creation. Most of these methods are supported in BrainPy.

Supported operations on tensor

All the operations in BrainPy are based on tensors. Therefore, it is necessary to know which operations are supported by each tensor object.

Basic operations

Arithmetic operators on tensors apply element-wise. Let’s take “+”, “-”, “*”, and “/” as examples.

We first create two tensors:

data = bm.array([1, 2])

data
JaxArray(DeviceArray([1, 2], dtype=int32))
ones = bm.ones(2)

ones
JaxArray(DeviceArray([1., 1.], dtype=float32))

data + ones
JaxArray(DeviceArray([2., 3.], dtype=float32))

data - ones
JaxArray(DeviceArray([0., 1.], dtype=float32))
data * data
JaxArray(DeviceArray([1, 4], dtype=int32))
data / data
JaxArray(DeviceArray([1., 1.], dtype=float32))

Aggregation functions can also be performed on tensors, like:

  • .min(): get the minimum element;

  • .max(): get the maximum element;

  • .sum(): get the summation;

  • .mean(): get the average;

  • .prod(): get the result of multiplying the elements together;

  • .std(): to get the standard deviation.

data = bm.array([1, 2, 3])

data.max()
DeviceArray(3, dtype=int32)
data.min()
DeviceArray(1, dtype=int32)
data.sum()
DeviceArray(6, dtype=int32)

It’s very common to want to aggregate along a row or column. You can specify on which axis you want the aggregation function to be computed. For example, you can find the maximum value within each column by specifying axis=0.

a = bm.array([[1, 2],
              [5, 3],
              [4, 6]])
a.max(axis=0)
JaxArray(DeviceArray([5, 6], dtype=int32))
a.max(axis=1)
JaxArray(DeviceArray([2, 5, 6], dtype=int32))

Broadcasting

Tensor operations are usually done on pairs of arrays on an element-by-element basis. In the simplest case, the two tensors must have exactly the same shape, as in the following example:

a = bm.array([1.0, 2.0, 3.0])
b = bm.array([2.0, 2.0, 2.0])

a * b
JaxArray(DeviceArray([2., 4., 6.], dtype=float32))

However, the broadcasting rule relaxes this constraint when the tensors’ shapes meet certain conditions. The simplest broadcasting example occurs when a tensor and a scalar value are combined in an operation:

a = bm.array([1, 2])
b = 1.6

a * b
JaxArray(DeviceArray([1.6, 3.2], dtype=float32, weak_type=True))

Similarly, broadcasting can happen on matrices. Below is an example.

data = bm.array([[1, 2],
                 [3, 4],
                 [5, 6]])

ones_row = bm.array([[1, 1]])

data + ones_row
JaxArray(DeviceArray([[2, 3],
                      [4, 5],
                      [6, 7]], dtype=int32))

Under certain constraints, the smaller tensor can be “broadcast” across the larger tensor so that they have compatible shapes. Broadcasting provides a means of vectorizing tensor operations so that looping occurs in C instead of Python. It does this without making needless copies of data and usually leads to efficient algorithm implementations.

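Generally, two tensors are compatible for broadcasting when, comparing their shapes dimension by dimension starting from the trailing axis, each pair of dimensions satisfies one of the following:

  • they are equal, or

  • one of them is 1, or

  • one of the tensors has fewer dimensions, so the dimension is missing on that side (it is then treated as 1).

If these conditions are not met, an error is raised.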
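The compatibility rule can be written down as a short function. This is a plain-Python sketch of the rule as stated above; the name broadcast_shape is hypothetical and the function is not part of BrainPy.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError (illustrative)."""
    result = []
    # Compare dimensions from the trailing axis backwards; missing leading
    # dimensions of the shorter shape are treated as 1.
    for da, db in zip_longest(reversed(shape_a), reversed(shape_b),
                              fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible shapes: {shape_a}, {shape_b}")
    return tuple(reversed(result))

r1 = broadcast_shape((3, 1), (1, 4))     # both size-1 axes are stretched
r2 = broadcast_shape((5, 2, 3), (3,))    # missing leading axes count as 1
```

Each result dimension is the larger of the two compared sizes, and any pair that is neither equal nor contains a 1 is rejected.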

For example, according to the broadcast rules, the following two shapes are compatible:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3
image = bm.random.random((256, 256, 3))
scale = bm.random.random(3)

_ = image + scale 
_ = image - scale 
_ = image * scale 
_ = image / scale 

These shapes are also compatible:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
A = bm.random.random((8, 1, 6, 1))
B = bm.random.random((7, 1, 5))

_ = A + B 
_ = A - B 
_ = A * B 
_ = A / B 

However, these examples of shapes do not broadcast:

A      (1d array):  3
B      (1d array):  4 # trailing dimensions do not match

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # second from last dimensions mismatched
A = bm.random.random((3,))
B = bm.random.random((4,))

try:
    _ = A + B
except Exception as e:
    print(e)
add got incompatible shapes for broadcasting: (3,), (4,).
A = bm.random.random((2, 1))
B = bm.random.random((8, 4, 3))

try:
    _ = A + B
except Exception as e:
    print(e)
Incompatible shapes for broadcasting: ((1, 2, 1), (8, 4, 3))

For more details about broadcasting, please see the NumPy documentation: broadcasting.

Indexing, Slicing and Iterating

Any tensor can be indexed, sliced, and iterated over, much like lists and other Python sequences. For example:

a = bm.arange(10) ** 3

a
JaxArray(DeviceArray([  0,   1,   8,  27,  64, 125, 216, 343, 512, 729], dtype=int32))
a[2]
DeviceArray(8, dtype=int32)
a[2:5]
DeviceArray([ 8, 27, 64], dtype=int32)
# from start to position 6, exclusive, set every 2nd element to 1000,
# equivalent to a[0:6:2] = 1000

a[:6:2] = 1000

a
JaxArray(DeviceArray([1000,    1, 1000,   27, 1000,  125,  216,  343,  512,  729], dtype=int32))
a[::-1]  # reversed a
DeviceArray([ 729,  512,  343,  216,  125, 1000,   27, 1000,    1, 1000], dtype=int32)
for i in a:  # iterate a
    print(i**(1 / 3.))
10.000001
1.0
10.000001
3.0
10.000001
5.0000005
6.0000005
7.0000005
8.000001
9.000001

For multi-dimensional tensors, these indices should be given in a tuple separated by commas. For example,

b = bm.arange(20).reshape((5, 4))
b[2, 3]
DeviceArray(11, dtype=int32)
b[0:5, 1]  # each row in the second column of b
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[:, 1]    # equivalent to the previous example
DeviceArray([ 1,  5,  9, 13, 17], dtype=int32)
b[1:3, :]  # each column in the second and third row of b
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)

When fewer indices are provided than the number of axes, the missing indices are considered complete slices:

b[-1]   # the last row. Equivalent to b[-1, :]
DeviceArray([16, 17, 18, 19], dtype=int32)

You can also write this using dots as b[i, ...]. The dots (...) represent as many colons as needed to produce a complete indexing tuple. For example, if x is an array with 5 axes, then

  • x[1, 2, ...] is equivalent to x[1, 2, :, :, :],

  • x[..., 3] to x[:, :, :, :, 3] and

  • x[4, ..., 5, :] to x[4, :, :, 5, :].

c = bm.arange(48).reshape((6, 4, 2))
c[1, ...]  # same as c[1, :, :] or c[1]
DeviceArray([[ 8,  9],
             [10, 11],
             [12, 13],
             [14, 15]], dtype=int32)
c[..., 1]  # same as c[:, :, 1]
DeviceArray([[ 1,  3,  5,  7],
             [ 9, 11, 13, 15],
             [17, 19, 21, 23],
             [25, 27, 29, 31],
             [33, 35, 37, 39],
             [41, 43, 45, 47]], dtype=int32)
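The ellipsis equivalences for an array with 5 axes can be checked directly. The sketch below uses NumPy, assuming the same indexing semantics; the shape (2, 3, 4, 5, 6) and the index values are chosen only for illustration.

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))

# "..." expands to as many ":" as needed to complete the index tuple.
same_leading = np.array_equal(x[1, 2, ...], x[1, 2, :, :, :])
same_trailing = np.array_equal(x[..., 3], x[:, :, :, :, 3])
same_middle = np.array_equal(x[1, ..., 4, :], x[1, :, :, 4, :])

sub_shape = x[1, 2, ...].shape
```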

Iterating over multidimensional tensors is done with respect to the first axis:

for row in b:
    print(row)
[0 1 2 3]
[4 5 6 7]
[ 8  9 10 11]
[12 13 14 15]
[16 17 18 19]

For more methods of advanced indexing and index tricks, please see the NumPy tutorial: Indexing.

Mathematical functions

Tensors support many other mathematical functions. Most of these functions can be found in the brainpy.math module. Let’s take a look at trigonometric, hyperbolic, and rounding functions.

d = bm.linspace(0, 1, 10)

d
JaxArray(DeviceArray([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
                      0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],            dtype=float32))
# trigonometric functions

bm.sin(d)
JaxArray(DeviceArray([0.        , 0.11088263, 0.22039774, 0.32719472, 0.42995638,
                      0.5274154 , 0.6183698 , 0.7016979 , 0.7763719 , 0.84147096],            dtype=float32))
bm.arcsin(d)
JaxArray(DeviceArray([0.        , 0.11134101, 0.2240931 , 0.33983693, 0.46055397,
                      0.589031  , 0.7297277 , 0.8911225 , 1.0949141 , 1.5707964 ],            dtype=float32))
# hyperbolic functions

bm.sinh(d)
JaxArray(DeviceArray([0.        , 0.11133985, 0.22405571, 0.33954054, 0.45922154,
                      0.58457786, 0.7171585 , 0.8586021 , 1.0106566 , 1.1752012 ],            dtype=float32))
# rounding functions

bm.round(d)
JaxArray(DeviceArray([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], dtype=float32))
# sum function

bm.sum(d)
DeviceArray(5., dtype=float32)

Variables

In BrainPy, the JIT compilation for class objects relies on Variable. In this section, we are going to understand:

  • what is a Variable?

  • what are the subtypes of Variable?

import brainpy as bp
import brainpy.math as bm

Variable

brainpy.math.Variable is a pointer that refers to a tensor. It stores the value of the tensor, and the concrete value in a Variable can be changed. Labeling a tensor as a Variable means that it is a dynamical variable whose data can change.

During JIT compilation, tensors that are not marked as Variable are compiled as static data. Changing such a tensor afterwards will either have no effect or cause an error.

  • Create a Variable

Passing a tensor into the brainpy.math.Variable creates a Variable, for example:

bm.use_backend('numpy')

a1 = bm.random.random(5)
a2 = bm.Variable(a1)

a1, a2 
(array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]),
 Variable([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844]))
bm.use_backend('jax')

b1 = bm.random.random(5)
b2 = bm.Variable(b1)

b1, b2
(JaxArray(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)),
 Variable(DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)))
  • Access the value in a Variable

The concrete value of a Variable can be obtained through .value.

a2.value
array([0.33133975, 0.12552793, 0.93629203, 0.77514911, 0.22587844])
(a2.value == a1).all()
True
b2.value
DeviceArray([0.70530474, 0.99841356, 0.815271  , 0.926391  , 0.84018004],            dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
  • Supported operations on a Variable

A Variable supports almost all the operations of a tensor. In fact, brainpy.math.Variable is a subclass of brainpy.math.ndarray.

isinstance(a2, bp.math.numpy.ndarray)
True
isinstance(b2, bp.math.jax.ndarray)
True
isinstance(b2, bp.math.jax.JaxArray)
True
# `bp.math.jax.ndarray` is an alias for `bp.math.jax.JaxArray` in 'jax' backend

bp.math.jax.ndarray is bp.math.jax.JaxArray
True

Note

In the ‘jax’ backend, performing any operation on a Variable returns a JaxArray (bp.math.jax.ndarray is an alias for bp.math.jax.JaxArray in the ‘jax’ backend). This means that a Variable can only be used to refer to a value.

b2 + 1.
JaxArray(DeviceArray([1.7053047, 1.9984136, 1.815271 , 1.926391 , 1.84018  ], dtype=float32))
b2 ** 2
JaxArray(DeviceArray([0.4974548 , 0.9968296 , 0.66466683, 0.8582003 , 0.7059025 ],            dtype=float32))
bp.math.jax.floor(b2)
JaxArray(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
  • Subtypes of Variable

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar and brainpy.math.Parameter. Subtypes can also be customized and extended by the user. They are introduced below.

TrainVar

brainpy.math.TrainVar is a trainable variable (a subclass of brainpy.math.Variable). Usually, trainable variables are those whose gradients are required in order to compute the corresponding update values. However, users can also use TrainVar for other purposes.

bm.use_backend('numpy')

a = bm.random.rand(4)

a, bm.TrainVar(a)
(array([0.81515042, 0.40363449, 0.89924935, 0.29827197]),
 TrainVar([0.81515042, 0.40363449, 0.89924935, 0.29827197]))
bm.use_backend('jax')

b = bm.random.rand(4)

b, bm.TrainVar(b)
(JaxArray(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)),
 TrainVar(DeviceArray([0.4008    , 0.21182728, 0.9596069 , 0.6859863 ], dtype=float32)))

Parameter

brainpy.math.Parameter labels a dynamically changing parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the Collector.subsets method (please see Base class).

bm.use_backend('numpy')

a = bm.random.rand(1)

a, bm.Parameter(a)
(array([0.5776296]), Parameter([0.5776296]))
bm.use_backend('jax')

b = bm.random.rand(1)

b, bm.Parameter(b)
(JaxArray(DeviceArray([0.61128676], dtype=float32)),
 Parameter(DeviceArray([0.61128676], dtype=float32)))

RandomState

In the ‘jax’ backend, brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. This is because the RandomState in the ‘jax’ backend must store the dynamically changing key information. Every time a RandomState performs a random sampling, the “key” changes. For example,

bm.use_backend('jax')

state = bm.random.RandomState(seed=1234)

state
RandomState(DeviceArray([   0, 1234], dtype=uint32))
# perform a "random" sampling 
state.random(1)

# the value changed
state
RandomState(DeviceArray([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling 
state.sample(1)

# the value changed too
state
RandomState(DeviceArray([1076515368, 3893328283], dtype=uint32))

Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()
DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)
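The idea of deriving several independent seeds from one parent seed also exists in NumPy. The sketch below uses numpy.random.SeedSequence.spawn purely as an analogy; it is not how BrainPy or JAX split keys internally, and the variable names are illustrative.

```python
import numpy as np

root = np.random.SeedSequence(1234)

# Derive two independent child seeds, e.g. one per parallel worker.
children = root.spawn(2)
streams = [np.random.default_rng(child) for child in children]

# Each worker draws from its own stream, so the samples differ.
samples = [rng.random(3) for rng in streams]
```

Giving each parallel worker its own derived seed is what keeps the workers from producing identical random numbers.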

There is a default RandomState in the brainpy.math.jax.random module: DEFAULT.

bm.random.DEFAULT
RandomState(DeviceArray([2580684476, 2503630841], dtype=uint32))

The built-in random methods such as randint(), rand(), and shuffle() use this DEFAULT state. To change the default RandomState, please use the seed() method.

bm.random.seed(654321)

bm.random.DEFAULT
RandomState(DeviceArray([     0, 654321], dtype=uint32))

Base Class

In this section, we are going to talk about:

  • Base class for BrainPy ecosystem,

  • Collector to facilitate variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

Base

The foundation of BrainPy is brainpy.Base. A Base instance is an object which has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class whose objects are to be JIT compiled or automatically differentiated must inherit from brainpy.Base.

A Base object can have many variables, children Base objects, integrators, and methods. For example, let’s implement a FitzHugh-Nagumo neuron model.

class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note this model has three variables: self.V, self.w, and self.spike. It also has an integrator self.integral.

Naming system

Every Base object has a unique name. You can specify a unique name when you instantiate a Base class. Reusing a name that is already taken causes an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x7f4a7406bd60> has a used name "Y".

When you instantiate a Base class without specifying “name”, BrainPy assigns a name to the object automatically. The generated name has the form class_name + number_of_instances, for example FHN0, FHN1, etc.

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.
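
The naming rule described above can be illustrated with a minimal pure-Python sketch. It is not BrainPy's implementation: the NamedObject class, its registry, and this UniqueNameError are hypothetical stand-ins that only mimic the behavior shown in the outputs (automatic names of the form class name + index, and an error on a reused name).

```python
import itertools


class UniqueNameError(Exception):
    """Raised when a requested name is already in use (hypothetical stand-in)."""


class NamedObject:
    _used_names = set()   # every name handed out so far
    _counters = {}        # one running counter per class name

    def __init__(self, name=None):
        cls_name = type(self).__name__
        if name is None:
            # Automatic name: class name + running index, skipping taken names.
            counter = NamedObject._counters.setdefault(cls_name, itertools.count())
            name = f"{cls_name}{next(counter)}"
            while name in NamedObject._used_names:
                name = f"{cls_name}{next(counter)}"
        elif name in NamedObject._used_names:
            raise UniqueNameError(f'name "{name}" is already used')
        NamedObject._used_names.add(name)
        self.name = name


class Neuron(NamedObject):
    pass


first = Neuron()             # automatically named
second = Neuron()            # automatically named
custom = Neuron(name="X")    # explicitly named

try:
    Neuron(name="X")         # the name "X" is taken
    duplicate_rejected = False
except UniqueNameError:
    duplicate_rejected = True
```

Keeping a single registry of used names is what makes a lookup by name unambiguous, at the cost of having to choose fresh names for every new instance.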

Collection functions

Three important collection functions are implemented for each Base object. Specifically, they are:

  • nodes(): to collect all instances of Base objects, including children nodes in a node.

  • ints(): to collect all integrators defined in the Base node and in its children nodes.

  • vars(): to collect all variables defined in the Base node and in its children nodes.

All integrators can be collected through one method Base.ints(). The result container is a Collector.

fhn = FHN(10)
ints = fhn.ints()

ints
{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>}
type(ints)
brainpy.base.collector.Collector

Similarly, all variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).

vars = fhn.vars()

vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
type(vars)
brainpy.base.collector.TensorCollector

All nodes in the model can also be collected through one method Base.nodes(). The result container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>,
 'FHN2': <__main__.FHN at 0x7f4a7406b5e0>}
type(nodes)
brainpy.base.collector.Collector

Now, let’s make a more complicated model by using the previously defined model FHN.

class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        conn = bm.ones((num1, num2), dtype=bool)
        self.conn = bm.fill_diagonal(conn, False) * w

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

This model FeedForwardCircuit defines two layers. Each layer is modeled as a FitzHugh-Nagumo model (FHN). The first layer is densely connected to the second layer. The input to the second layer is the first layer’s spike times a connection strength w.

net = FeedForwardCircuit(8, 5)

We can retrieve all integrators in the network with .ints() :

net.ints()
{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

Or, retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN4.V': Variable([0., 0., 0., 0., 0.]),
 'FHN4.spike': Variable([False, False, False, False, False]),
 'FHN4.w': Variable([0., 0., 0., 0., 0.])}

Or, retrieve all nodes (instances of Base class) with .nodes():

net.nodes()
{'FHN3': <__main__.FHN at 0x7f4a74077670>,
 'FHN4': <__main__.FHN at 0x7f4a740771c0>,
 'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x7f4a74077790>}
Absolute path

Note that there are two ways to access variables, integrators, and nodes: by “absolute” path and by “relative” path. The default is the absolute path.

“Absolute” path means that every key in the resulting Collector (for example, the one returned by Base.nodes()) has the format key = node_name [+ field_name].

.nodes() example 1: In the above fhn instance, there are two nodes: “fhn” and its integrator “fhn.integral”.

fhn.integral.name, fhn.name
('RK44', 'FHN2')

Calling .nodes() returns the name and instance of each model.

fhn.nodes().keys()
dict_keys(['RK44', 'FHN2'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all models.

net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])

.vars() example 1: In the above fhn instance, there are three variables: “V”, “w” and “spike”. Calling .vars() returns a dict of <node_name + var_name, var_value>.

fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN4.V': Variable([0., 0., 0., 0., 0.]),
 'FHN4.spike': Variable([False, False, False, False, False]),
 'FHN4.w': Variable([0., 0., 0., 0., 0.])}
Relative path

Variables, integrators, and nodes can also be accessed by relative path. For example, the pre instance in the net can be accessed by

net.pre
<__main__.FHN at 0x7f4a74077670>

Relative path preserves the dependence relationship. For example, all nodes retrieved from the perspective of net are:

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x7f4a74077790>,
 'pre': <__main__.FHN at 0x7f4a74077670>,
 'post': <__main__.FHN at 0x7f4a740771c0>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

However, nodes retrieved from the start point of net.pre will be:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x7f4a74077670>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>}

Variables can also be collected by relative path. For example, all variables reachable from net are:

net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'pre.spike': Variable([False, False, False, False, False, False, False, False]),
 'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'post.V': Variable([0., 0., 0., 0., 0.]),
 'post.spike': Variable([False, False, False, False, False]),
 'post.w': Variable([0., 0., 0., 0., 0.])}

In contrast, the variables reachable from net.post are:

net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.]),
 'spike': Variable([False, False, False, False, False]),
 'w': Variable([0., 0., 0., 0., 0.])}
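
The difference between the two key formats can be sketched in plain Python. This is only an illustration under stated assumptions: the Node class and the collect_nodes() helper are hypothetical and do not reproduce how BrainPy traverses attributes; the node names are copied from the outputs above.

```python
class Node:
    """Hypothetical stand-in for a Base object with a name and child nodes."""

    def __init__(self, name, **children):
        self.name = name
        self.children = children  # attribute name -> child Node


def collect_nodes(root, method="absolute"):
    """Collect `root` and all of its descendants into a plain dict."""
    found = {}

    def visit(node, path):
        # absolute: key is the node's unique name
        # relative: key is the attribute path from `root`
        key = node.name if method == "absolute" else path
        found[key] = node
        for attr, child in node.children.items():
            visit(child, f"{path}.{attr}" if path else attr)

    visit(root, "")
    return found


pre = Node("FHN3", integral=Node("RK45"))
post = Node("FHN4", integral=Node("RK46"))
net = Node("FeedForwardCircuit0", pre=pre, post=post)

absolute_keys = set(collect_nodes(net))
relative_keys = set(collect_nodes(net, method="relative"))
relative_from_pre = set(collect_nodes(pre, method="relative"))
```

Absolute keys stay the same no matter where the traversal starts, while relative keys depend on the starting node, which is why net.pre sees its integrator simply as “integral”.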
Elements in containers

To avoid surprising unintended behaviors, collection functions don’t look for elements in list, dict or any container structure.

class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The above class defines a list of variables and a dict of children nodes. However, they cannot be retrieved by the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x7f4a7401b2b0>}

Fortunately, BrainPy provides implicit_vars and implicit_nodes (each an instance of “dict”) to hold variables and nodes stored in container structures. Any variable registered in implicit_vars, and any integrator or node registered in implicit_nodes, can be retrieved by the collection functions. Let’s give it a try.

class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        self.implicit_vars = {f'v{i}': v for i, v in enumerate(self.all_vars)}  # must be a dict
        self.implicit_nodes = {k: v for k, v in self.sub_nodes.items()}  # must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator. 
# Therefore, there are five Base objects. 

t2.nodes()
{'T1': <__main__.FHN at 0x7f4a740777c0>,
 'T2': <__main__.FHN at 0x7f4a74077a90>,
 'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74077340>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74011040>,
 'AnotherTest0': <__main__.AnotherTest at 0x7f4a740157f0>}
# This model has five Base objects (seen above), 
# each FHN node has three variables, 
# moreover, this model has two implicit variables.

t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'T2.V': Variable([0., 0., 0., 0., 0.]),
 'T2.spike': Variable([False, False, False, False, False]),
 'T2.w': Variable([0., 0., 0., 0., 0.]),
 'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.]),
 'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.])}

Saving and loading

Because Base.vars() returns a Collector, which is a Python dictionary, the variables can easily be saved, updated, altered, and restored, which adds a great deal of modularity to BrainPy models. Each Base object therefore has standard exporting and loading methods (for more details, please see Saving and Loading). Specifically, they are implemented by Base.save_states() and Base.load_states().

Save
Base.save_states(PATH, [vars])

Model exporting in BrainPy supports several common file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
# Unknown file format will cause error

try:
    net.save_states('./data/net.xxx')
except Exception as e:
    print(type(e).__name__, ":", e)
BrainPyError : Unknown file format: ./data/net.xxx. We only supports ['.h5', '.hdf5', '.npz', '.pkl', '.mat']
Load

Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
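
To make the save/load round trip concrete, here is a small sketch that stores a dictionary of variables with Python's pickle module and restores the values in place. It is an assumption-laden illustration, not BrainPy's code: the save_states() and load_states() functions below are hypothetical free functions, only the “.pkl” format is handled, and plain lists stand in for Variables.

```python
import os
import pickle
import tempfile

SUPPORTED = {".pkl"}  # this sketch only handles pickle files


def save_states(path, states):
    """Write a dict of name -> values to `path`, chosen by file extension."""
    ext = os.path.splitext(path)[1]
    if ext not in SUPPORTED:
        raise ValueError(f"Unknown file format: {path}")
    with open(path, "wb") as f:
        pickle.dump(states, f)


def load_states(path, states):
    """Read values from `path` and copy them into the existing containers."""
    with open(path, "rb") as f:
        loaded = pickle.load(f)
    for key, value in loaded.items():
        states[key][:] = value  # in-place update keeps object identity


with tempfile.TemporaryDirectory() as folder:
    states = {"V": [0.5, 0.25], "w": [1.0, 2.0]}
    v_object = states["V"]
    path = os.path.join(folder, "net.pkl")

    save_states(path, states)
    states["V"][:] = [0.0, 0.0]   # simulate the model changing its state
    load_states(path, states)     # restore the saved values
    restored = list(states["V"])
    same_object = states["V"] is v_object

    try:
        save_states(os.path.join(folder, "net.xxx"), states)
        unknown_rejected = False
    except ValueError:
        unknown_rejected = True
```

Restoring in place matters because other objects (for example a jitted function) may hold references to the same variable.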

Collector

Collection functions return a brainpy.Collector. This class is a dictionary that maps names to elements. It has some useful methods.

subset()

Collector.subset(cls) returns the subset of elements whose type is the given cls. For example, Base.nodes() returns all instances of the Base class. If you are only interested in one type, like ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

Under the hood, Collector.subset(cls) traverses all the elements in the collection and keeps those whose type matches the given cls.
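
As a rough sketch of this filtering step, assuming that “type matches” means an isinstance() check (an assumption, since the text does not say whether subclasses count), a plain-dict version could look like the following. The classes and the subset() helper are hypothetical.

```python
class Integrator:
    """Hypothetical base class standing in for an integrator type."""


class RK4(Integrator):
    pass


class Neuron:
    pass


def subset(collection, cls):
    """Keep only the entries whose value is an instance of `cls`."""
    return {key: value for key, value in collection.items()
            if isinstance(value, cls)}


nodes = {"FHN3": Neuron(), "FHN4": Neuron(), "RK45": RK4(), "RK46": RK4()}
integrators = subset(nodes, Integrator)
integrator_names = sorted(integrators)
```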

unique()

It is common in machine learning for weights to be shared among several objects, or for the same weight to be reachable through several dependence relationships. The collection functions of Base therefore usually return a collection in which the same value has multiple keys. Duplicate elements are not excluded automatically. However, it is important not to apply an operation more than once to the same element (e.g., applying gradients and updating weights).

Therefore, Collector provides method Collector.unique() to handle this automatically. Collector.unique() returns a copy of collection in which all elements are unique.

class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.]),
 'A_shared.a': Variable([0., 0., 0., 0., 0.]),
 'A_shared.source.a': Variable([0., 0., 0., 0., 0.])}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.])}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x7f4a74049190>,
 'A': <__main__.ModelA at 0x7f4a740490a0>,
 'A_shared': <__main__.SharedA at 0x7f4a74049550>,
 'A_shared.source': <__main__.ModelA at 0x7f4a740490a0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x7f4a74049190>,
 'A': <__main__.ModelA at 0x7f4a740490a0>,
 'A_shared': <__main__.SharedA at 0x7f4a74049550>}
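
The deduplication shown above can be sketched in a few lines of plain Python. The sketch assumes that “unique” refers to object identity (the same object reached through several keys) and that the first key encountered is kept; both points are inferred from the outputs above, and the unique() helper itself is hypothetical.

```python
def unique(collection):
    """Return a new dict in which each object appears under one key only."""
    seen_ids = set()
    result = {}
    for key, value in collection.items():
        if id(value) not in seen_ids:   # compare identity, not equality
            seen_ids.add(id(value))
            result[key] = value
    return result


shared = [0.0] * 5           # one object reachable through three paths
separate = [0.0] * 5         # equal in value, but a different object

paths = {
    "A.a": shared,
    "A_shared.a": shared,
    "A_shared.source.a": shared,
    "B.b": separate,
}
deduplicated = unique(paths)
kept_keys = list(deduplicated)
```

Comparing by identity rather than by value is the important design choice: two different variables that happen to hold equal numbers must both be kept.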

update()

Collector is a dict, but it can also catch potential conflicts during assignment. Both the bracket assignment of a Collector (collector[key] = value) and Collector.update() check whether an existing key would be mapped to a different value. If so, an error is raised.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].

replace()

To replace the value stored under an existing key, use the Collector.replace(old_key, new_value) function.

tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
tc.replace('a', bm.ones(3))

tc
{'a': array([1., 1., 1.])}
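
The conflict check and the explicit replace() can be mimicked with a small dict subclass. This is a sketch under assumptions: the CheckedDict class is hypothetical, and it treats “a different value” as “a different object” (an identity comparison), which may not be exactly how Collector compares values.

```python
class CheckedDict(dict):
    """A dict that refuses to silently rebind an existing key."""

    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: it is already bound '
                             f'to a different object')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # Route every assignment through the checked __setitem__.
        for key, value in dict(other, **kwargs).items():
            self[key] = value

    def replace(self, key, new_value):
        # Explicit replacement bypasses the conflict check on purpose.
        super().__setitem__(key, new_value)


zeros = [0.0] * 10
tc = CheckedDict({"a": zeros})
tc["a"] = zeros                  # same object again: allowed

errors = []
try:
    tc["a"] = [0.0]              # same key, different object
except ValueError as e:
    errors.append(str(e))
try:
    tc.update({"a": [1.0]})      # same key, different object
except ValueError as e:
    errors.append(str(e))

tc.replace("a", [1.0, 1.0, 1.0])  # intentional overwrite
```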

TensorCollector

TensorCollector is a subclass of Collector that is used specifically to collect tensors.

Compilation

In this section, we are going to talk about code compilation, which accelerates the running performance of your model.

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

jit()

We have talked about the mechanism of JIT compilation for class objects in the NumPy backend. In this section, we try to understand how to apply JIT when you are using the JAX backend.

jax.jit() is excellent, but it only supports pure functions. brainpy.math.jax.jit() is based on jax.jit(), and extends it to just-in-time compile your class objects.

JIT for pure functions

First, brainpy.math.jax.jit() can just-in-time compile your functions.

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * bm.where(x > 0, x, alpha * bm.exp(x) - alpha)

x = bm.random.normal(size=(1000000,))
%timeit selu(x)
2.86 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
selu_jit = bm.jit(selu) # jit acceleration

%timeit selu_jit(x)
346 µs ± 21.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

JIT for class objects

Moreover, brainpy.math.jax.jit() can just-in-time compile your class objects. The constraints for class object JIT are:

  • The JIT target must be a subclass of brainpy.Base.

  • Dynamically changed variables must be labeled as brainpy.math.Variable.

  • Variable changes must be made in-place.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u

num_dim, num_points = 10, 200000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr = LogisticRegression(10)
%timeit lr(points, labels)
2.77 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
lr_jit = bm.jit(lr)

%timeit lr_jit(points, labels)
1.29 ms ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

JIT mechanism

The mechanism of JIT compilation is that BrainPy automatically transforms your class methods into functions.

brainpy.math.jax.jit() receives a dyn_vars argument, which denotes the dynamically changed variables. If you do not provide it, BrainPy will automatically detect them by calling Base.vars(). Once it has the “dyn_vars”, BrainPy treats them as function arguments, which allows them to change dynamically.

import types

isinstance(lr_jit, types.FunctionType)  # "lr" is a class instance, while "lr_jit" is a function
True

Therefore, the secret of brainpy.math.jax.jit() is providing “dyn_vars”. Whether your target is a class object, a method of a class object, or a pure function, if it uses dynamically changed variables, you just pass them to brainpy.math.jax.jit() as “dyn_vars”. BrainPy then handles all the compilation and acceleration automatically.
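
The idea of treating “dyn_vars” as function arguments can be illustrated without JAX. The following pure-Python sketch is conceptual only and is not how BrainPy is implemented: the Variable class and the functionalize() helper are hypothetical. It turns a method that mutates its variables into a function whose state goes in as an argument and comes out as a return value, which is the form a pure-function compiler can handle.

```python
class Variable:
    """Hypothetical mutable container standing in for a dynamic variable."""

    def __init__(self, value):
        self.value = value


def functionalize(fn, dyn_vars):
    """Wrap `fn` so that the state of `dyn_vars` is passed in and returned."""

    def pure_fn(state, *args):
        saved = [var.value for var in dyn_vars]
        for var, value in zip(dyn_vars, state):
            var.value = value                     # load the incoming state
        try:
            out = fn(*args)                       # may mutate dyn_vars
            new_state = [var.value for var in dyn_vars]
        finally:
            for var, value in zip(dyn_vars, saved):
                var.value = value                 # leave no side effects
        return out, new_state

    return pure_fn


class Counter:
    def __init__(self):
        self.total = Variable(0)

    def add(self, x):
        self.total.value = self.total.value + x   # stateful update
        return self.total.value


counter = Counter()
pure_add = functionalize(counter.add, [counter.total])

out, new_state = pure_add([10], 5)   # start from state 10, add 5
untouched = counter.total.value      # the object itself is unchanged
```

A variable that is not listed in dyn_vars never enters this state list, so a compiled function would only ever see the value it had when the function was first traced.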

Example 1: JIT a class method

class Linear(bp.Base):
    def __init__(self, n_in, n_out):
        super(Linear, self).__init__()
        self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
        self.b = bm.TrainVar(bm.zeros(n_out))
    
    def update(self, x):
        return x @ self.w + self.b
x = bm.zeros(10)
l = Linear(10, 3)

This time, we mark “w” and “b” as dynamically changed variables.

update1 = bm.jit(l.update, dyn_vars=[l.w, l.b])  # make 'w' and 'b' dynamically change
update1(x)  # x is 0., b is 0., therefore y is 0.
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
l.b[:] = 1.  # change b to 1, we expect y will be 1 too

update1(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))

This time, we only mark “w” as a dynamically changed variable. We will find that even if we modify “b”, the result does not change.

update2 = bm.jit(l.update, dyn_vars=[l.w])  # make 'w' dynamically change

update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
l.b[:] = 2.  # change b to 2, while y will not be 2
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))

Example 2: JIT a function

Now, we change the above “Linear” object to a function.

n_in = 10;  n_out = 3

w = bm.TrainVar(bm.random.random((n_in, n_out)))
b = bm.TrainVar(bm.zeros(n_out))

def update(x):
    return x @ w + b

If we do not provide dyn_vars, “w” and “b” will be compiled as constant values.

update1 = bm.jit(update)
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))
b[:] = 1.  # modifying the value of 'b' will not
           # change the result, because in the
           # jitted function, 'b' is already
           # a constant
update1(x)
JaxArray(DeviceArray([0., 0., 0.], dtype=float32))

Providing “w” and “b” as dyn_vars makes them dynamically changeable again.

update2 = bm.jit(update, dyn_vars=(w, b))
update2(x)
JaxArray(DeviceArray([1., 1., 1.], dtype=float32))
b[:] = 2.  # change b to 2, and y becomes 2 as well
update2(x)
JaxArray(DeviceArray([2., 2., 2.], dtype=float32))

RandomState

We have talked about RandomState in the Variables section, where we said that it is also a Variable. Therefore, if your function uses the default RandomState (brainpy.math.jax.random.DEFAULT), you should add it to the dyn_vars of the function. Otherwise, it will be treated as a constant and the jitted function will always return the same value.

def function():
    return bm.random.normal(0, 1, size=(10,))
f1 = bm.jit(function)

f1() == f1()
JaxArray(DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True], dtype=bool))

The correct way to make JIT for this function is:

bm.random.seed(1234)

f2 = bm.jit(function, dyn_vars=bm.random.DEFAULT)

f2() == f2()
JaxArray(DeviceArray([False, False, False, False, False, False, False, False,
                      False, False], dtype=bool))
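
Why a constant random state yields identical samples can be seen in a toy model. The sketch below is an analogy, not JAX or BrainPy code: the KeyState class, the linear congruential step, and the two "compile" helpers are all hypothetical. One helper captures the state once (as a compiled constant would be captured), the other reads and writes the state on every call (as a dyn_var would be).

```python
class KeyState:
    """Hypothetical random state holding a single integer."""

    def __init__(self, seed):
        self.value = seed


def next_random(state):
    """One linear congruential step: returns (sample in [0, 1), new state)."""
    new_state = (1103515245 * state + 12345) % 2 ** 31
    return new_state / 2 ** 31, new_state


def compile_with_constant_state(rng):
    frozen = rng.value                 # captured once, never updated

    def fn():
        sample, _ = next_random(frozen)
        return sample

    return fn


def compile_with_dynamic_state(rng):
    def fn():
        sample, rng.value = next_random(rng.value)  # state advances each call
        return sample

    return fn


f1 = compile_with_constant_state(KeyState(1234))
f2 = compile_with_dynamic_state(KeyState(1234))

constant_repeats = f1() == f1()   # same frozen state -> same sample
dynamic_differs = f2() != f2()    # state advanced -> different sample
```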

Example 3: JIT a neural network

Now, let’s use SGD to train a neural network with JIT acceleration. Here we will use the autograd function brainpy.math.jax.grad(), which is described in detail in the next section.

class LinearNet(bp.Base):
    def __init__(self, n_in, n_out):
        super(LinearNet, self).__init__()

        # weights
        self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
        self.b = bm.TrainVar(bm.zeros(n_out))
        self.r = bm.TrainVar(bm.random.random((n_out, 1)))
    
    def update(self, x):
        h = x @ self.w + self.b
        return h @ self.r
    
    def loss(self, x, y):
        predict = self.update(x)
        return bm.mean((predict - y) ** 2)


ln = LinearNet(100, 200)

# provide the variables to update
opt = bm.optimizers.SGD(lr=1e-6, train_vars=ln.vars()) 

# provide the variables that require gradients
f_grad = bm.grad(ln.loss, grad_vars=ln.vars(), return_value=True)  


def train(X, Y):
    grads, loss = f_grad(X, Y)
    opt.update(grads)
    return loss

# JIT the train function 
train_jit = bm.jit(train, dyn_vars=ln.vars() + opt.vars())
xs = bm.random.random((1000, 100))
ys = bm.random.random((1000, 1))

for i in range(30):
    loss  = train_jit(xs, ys)
    print(f'Train {i}, loss = {loss:.2f}')
Train 0, loss = 103.74
Train 1, loss = 63.50
Train 2, loss = 40.54
Train 3, loss = 27.44
Train 4, loss = 19.97
Train 5, loss = 15.71
Train 6, loss = 13.27
Train 7, loss = 11.89
Train 8, loss = 11.09
Train 9, loss = 10.64
Train 10, loss = 10.38
Train 11, loss = 10.24
Train 12, loss = 10.15
Train 13, loss = 10.10
Train 14, loss = 10.08
Train 15, loss = 10.06
Train 16, loss = 10.05
Train 17, loss = 10.05
Train 18, loss = 10.04
Train 19, loss = 10.04
Train 20, loss = 10.04
Train 21, loss = 10.04
Train 22, loss = 10.04
Train 23, loss = 10.04
Train 24, loss = 10.04
Train 25, loss = 10.04
Train 26, loss = 10.04
Train 27, loss = 10.04
Train 28, loss = 10.04
Train 29, loss = 10.04

Static arguments

Static arguments are arguments that are treated as static/constant in the jitted function.

Numerical arguments used in conditional statements (boolean values, or values that produce a boolean), as well as strings, must be marked as static. Otherwise, an error is raised.

@bm.jit
def f(x):
  if x < 3:  # this will cause error
    return 3. * x ** 2
  else:
    return -4 * x
try:
    f(1.)
except Exception as e:
    print(type(e), e)
<class 'jax._src.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at <ipython-input-70-14a993a83941>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Simply speaking, arguments that produce boolean values used in Python control flow must be declared as static arguments. In the brainpy.math.jax.jit() function, you can set the names of static arguments.

def f(x):
  if x < 3:  # fine here, because 'x' is declared static below
    return 3. * x ** 2
  else:
    return -4 * x

f_jit = bm.jit(f, static_argnames=('x', ))
f_jit(x=1.)
DeviceArray(3., dtype=float32, weak_type=True)

However, it is worth noting that calling the jitted function with different values for these static arguments triggers recompilation. Therefore, declaring static arguments is suitable in the following situations:

  1. Boolean arguments.

  2. Arguments that take only a few possible values.

If the argument takes many different values, you had better not declare it as static.

For more information, please refer to jax.jit API.
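
The cost of static arguments can be made visible with a toy cache. The sketch below does not call JAX: the sketch_jit() helper is hypothetical and merely counts how often a new specialization would have to be built, assuming one compiled version per distinct combination of static argument values.

```python
import functools


def sketch_jit(fn, static_argnames=()):
    """Toy stand-in for a JIT: one 'compiled' version per static value."""
    compiled = {}

    def wrapper(**kwargs):
        static = {k: kwargs[k] for k in static_argnames}
        dynamic = {k: v for k, v in kwargs.items() if k not in static}
        key = tuple(sorted(static.items()))
        if key not in compiled:
            wrapper.compilations += 1                 # a new specialization
            compiled[key] = functools.partial(fn, **static)
        return compiled[key](**dynamic)

    wrapper.compilations = 0
    return wrapper


def f(x):
    if x < 3:          # an ordinary Python branch on a static value
        return 3.0 * x ** 2
    return -4 * x


f_jit = sketch_jit(f, static_argnames=("x",))
results = [f_jit(x=1.0), f_jit(x=1.0), f_jit(x=5.0)]
compilations = f_jit.compilations   # two distinct values of x were seen
```

With a boolean or a handful of possible values the cache stays small; with a continuously varying argument almost every call would pay the compilation cost again.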

vmap()

Coming soon.

pmap()

Coming soon.

Differentiation

In this section, we are going to talk about how to apply automatic differentiation to your functions and class objects with the ‘jax’ backend. In modern machine learning systems, computing and using gradients is common in various situations. So, we try to understand

  • how to calculate derivatives of arbitrary complex functions,

  • how to compute high-order gradients.

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

All autodiff functions in BrainPy support pure functions and class objects.

grad()

brainpy.math.jax.grad() takes a function/object and returns a new function which computes the gradient of the original function/object.

Pure functions

For a pure function, the gradient is taken with respect to the first argument:

def f(a, b):
    return a * 2 + b

grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32)

However, this can be controlled via the argnums argument.

grad_f2 = bm.grad(f, argnums=(0, 1))

grad_f2(2., 1.)
(DeviceArray(2., dtype=float32), DeviceArray(1., dtype=float32))
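
The values above can be cross-checked with central finite differences in plain Python. The numerical_grad() helper is hypothetical and only imitates the argnums convention (an int returns one gradient, a tuple returns a tuple of gradients); it approximates derivatives numerically and is not automatic differentiation.

```python
def numerical_grad(fn, argnums=0, eps=1e-6):
    """Approximate d fn / d args[i] for each i in `argnums`."""
    single = isinstance(argnums, int)
    indices = (argnums,) if single else tuple(argnums)

    def grad_fn(*args):
        grads = []
        for i in indices:
            plus, minus = list(args), list(args)
            plus[i] += eps
            minus[i] -= eps
            # central difference around the i-th argument
            grads.append((fn(*plus) - fn(*minus)) / (2 * eps))
        return grads[0] if single else tuple(grads)

    return grad_fn


def f(a, b):
    return a * 2 + b


grad_a = numerical_grad(f)(2.0, 1.0)                        # w.r.t. a only
grad_a2, grad_b = numerical_grad(f, argnums=(0, 1))(2.0, 1.0)
```

Since f is linear, the derivative with respect to a is 2 and with respect to b is 1 everywhere, which matches the outputs of grad_f1 and grad_f2.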

Class objects

For a class object or a bound method of a class, the gradient is taken with respect to the provided grad_vars argument:

class F(bp.Base):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()
    
f = F()

The grad_vars can be a JaxArray, or a list/tuple/dict of JaxArray.

bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': TrainVar(DeviceArray([2.], dtype=float32)),
 'F0.b': TrainVar(DeviceArray([2.], dtype=float32))}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
[TrainVar(DeviceArray([2.], dtype=float32)),
 TrainVar(DeviceArray([2.], dtype=float32))]

If there are values dynamically changed in the gradient function, you can provide them in the dyn_vars argument.

class F2(bp.Base):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
TrainVar(DeviceArray([2.], dtype=float32))

Also, if you are interested in the gradient of the input value, please use the argnums argument. In this situation, calling the gradient function will return (grads_of_grad_vars, *grads_of_args).

class F3(bp.Base):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_arg0 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
f3 = F3()
grads_of_gv, grad_of_arg0, grad_of_arg1 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
print("grad_of_arg1 :", grad_of_arg1)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
grad_of_arg1 : 10.0

Actually, we recommend that you provide all dynamically changed variables (whether or not they are updated in the gradient function) in the dyn_vars argument.

Auxiliary data

Usually, we want to get the value of the loss, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data, and set return_value=True to return the loss value.

# return loss

grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)

print('grad: ', grad)
print('loss: ', loss)
grad:  TrainVar(DeviceArray([2.], dtype=float32))
loss:  12.0
class F4(bp.Base):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)
    

f4 = F4()
    
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)

print('grad: ', grad)
print('aux_data: ', aux_data)
grad:  TrainVar(DeviceArray([1.], dtype=float32))
aux_data:  (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32)))
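
The calling conventions of has_aux and return_value can be imitated with a numerical sketch. Everything here is hypothetical: sketch_grad() uses finite differences instead of automatic differentiation, and the return order when both flags are set, (grad, value, aux), is an assumption rather than something the text above states.

```python
def sketch_grad(fn, has_aux=False, return_value=False, eps=1e-6):
    """Toy gradient of a scalar function of one float argument."""

    def loss_only(x):
        out = fn(x)
        return out[0] if has_aux else out   # with has_aux, fn returns (loss, aux)

    def grad_fn(x):
        grad = (loss_only(x + eps) - loss_only(x - eps)) / (2 * eps)
        out = fn(x)
        results = [grad]
        if return_value:
            results.append(out[0] if has_aux else out)
        if has_aux:
            results.append(out[1])          # assumed order: grad, value, aux
        return results[0] if len(results) == 1 else tuple(results)

    return grad_fn


def loss_with_aux(a, b=1.0, c=10.0):
    ab = a * b
    ab2 = ab * 2
    return ab2 + c, (ab, ab2)               # (loss, auxiliary data)


def loss_only_fn(a):
    return loss_with_aux(a)[0]


grad, aux = sketch_grad(loss_with_aux, has_aux=True)(1.0)
grad2, value = sketch_grad(loss_only_fn, return_value=True)(1.0)
```

Only the first element of the returned pair is differentiated; the auxiliary data is passed through untouched.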

Note: Any function whose gradients are computed through brainpy.math.jax.grad() must return a scalar value. Otherwise, an error is raised.

try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output was [0. 0.].
# this is right
bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray(DeviceArray([0.5, 0.5], dtype=float32))

If you want to take gradients of vector-valued outputs, please use the brainpy.math.jax.jacobian() function.

jacobian()

Coming soon.

hessian()

Coming soon.

Control Flows

In this section, we are going to talk about how to build structured control flows in ‘jax’ backend. These control flows include

  • for loop syntax,

  • while loop syntax,

  • and condition syntax.

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

In JAX, control flow is not easy to express. Users must transform the intuitive Python control flows into structured control flows.

Based on JaxArray provided in BrainPy, we try to present a better syntax to make control flows.

make_loop()

brainpy.math.jax.make_loop() is used to generate a for-loop function when you are using JaxArray.

Let’s imagine your requirement: you are using several JaxArray (grouped as dyn_vars) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars). Sometimes, your body function returns something, and you also want to gather the return values. In Python syntax, your requirement is equivalent to


def for_loop_function(body_fun, dyn_vars, out_vars, xs):
  ys = []
  for x in xs:
    # 'dyn_vars' and 'out_vars' 
    # are updated in 'body_fun()'
    results = body_fun(x)
    ys.append([out_vars, results])
  return ys

In BrainPy, using brainpy.math.jax.make_loop() you can define this logic like:


loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=False)

hist_of_out_vars = loop_fun(xs)

Or,


loop_fun = brainpy.math.jax.make_loop(body_fun, dyn_vars, out_vars, has_return=True)

hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
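
For readers who want to run the equivalent Python logic, here is a self-contained sketch. It is not BrainPy's make_loop(): the Variable class and this make_loop() are hypothetical, dyn_vars is omitted because plain Python needs no tracing, and the history is stored as snapshots (copies) so that later updates do not overwrite earlier entries.

```python
import copy


class Variable:
    """Hypothetical mutable container standing in for a JaxArray variable."""

    def __init__(self, value):
        self.value = value


def make_loop(body_fun, out_vars, has_return=False):
    """Build a function that runs `body_fun` over `xs` and records history."""

    def loop_fun(xs):
        history, returns = [], []
        for x in xs:
            result = body_fun(x)                  # updates the variables
            # snapshot the current values of the watched variables
            history.append([copy.copy(var.value) for var in out_vars])
            returns.append(result)
        return (history, returns) if has_return else history

    return loop_fun


total = Variable(0)


def body(x):
    total.value += x          # running sum kept in the variable
    return total.value * 2    # some per-step return value


loop = make_loop(body, out_vars=[total], has_return=True)
hist_of_out_vars, hist_of_return_vars = loop(range(1, 5))
```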

Let’s implement a recurrent network to illustrate how to use this function.

class RNN(bp.DynamicalSystem):
  def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
    super(RNN, self).__init__(**kwargs)

    # parameters
    self.n_in = n_in
    self.n_h = n_h
    self.n_out = n_out
    self.n_batch = n_batch
    self.g = g

    # weights
    self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
    self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
    self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
    self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
    self.b_ro = bm.TrainVar(bm.zeros((n_out,)))

    # variables
    self.h = bm.Variable(bm.random.random((n_batch, n_h)))

    # function
    self.predict = bm.make_loop(self.cell,
                                dyn_vars=self.vars(),
                                out_vars=self.h,
                                has_return=True)

  def cell(self, x):
    self.h[:] = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
    o = self.h @ self.w_ro + self.b_ro
    return o


rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)

In the above RNN model, we define a body function RNN.cell for later for-loop over input values. The loop function is defined as self.predict with bm.make_loop(). We care about the history values of “self.h” and the readout value “o”, so we set out_vars = self.h and has_return=True.

xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape  # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape  # the shape should be (num_time, ) + o.shape
(100, 5, 3)

If you have multiple input values, wrap them in a container and call the loop function with loop_fun(xs), where “xs” can be a JaxArray or a list/tuple/dict of JaxArray. For example:

a = bm.zeros(10)

def body(x):
    x1, x2 = x  # "x" is a tuple/list of JaxArray
    a.value += (x1 + x2)

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32))
a = bm.zeros(10)

def body(x):  # "x" is a dict of JaxArray
    a.value += x['a'] + x['b']

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
JaxArray(DeviceArray([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
                      [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
                      [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
                      [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
                      [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
                      [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
                      [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
                      [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
                      [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
                      [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32))

dyn_vars, out_vars, xs and the body function returns can be arrays or containers of arrays such as tuple/list/dict. The history output values preserve the container structure of out_vars and of the body function returns. If has_return=True, the loop function returns a tuple of (hist_of_out_vars, hist_of_fun_returns). If you are not interested in the history of any variable, set out_vars=None, and the loop function only returns hist_of_out_vars.

make_while()

brainpy.math.jax.make_while() is used to generate a while-loop function when you are using JaxArray. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.jax.make_while(), condition should be wrapped as a cond_fun function which returns a boolean value, and statements should be packed as a body_fun function which does not support return values:


while cond_fun(x):
    body_fun(x)

where x is the external input which is not iterated. All the iterated variables should be marked as JaxArray. All JaxArray used in cond_fun and body_fun should be declared in a dyn_vars variable.

Let’s look at an example:

i = bm.zeros(1)
counter = bm.zeros(1)

def cond_f(x): 
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])

In the above, we implement a sum from 0 to 10 using two JaxArray variables, i and counter.

loop()
counter
JaxArray(DeviceArray([55.], dtype=float32))
i
JaxArray(DeviceArray([10.], dtype=float32))
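
The requirement to declare every JaxArray in dyn_vars is easier to understand from the functional form that JAX itself uses for loops, where the loop state is passed explicitly from one iteration to the next (as in jax.lax.while_loop(cond_fun, body_fun, init_val)). The pure-Python sketch below mimics that form for the example above. `while_loop_reference` is a hypothetical helper written for illustration; how BrainPy maps dyn_vars onto such a state internally is an assumption here.

```python
def while_loop_reference(cond_fun, body_fun, init_state):
    """Pure-Python analogue of a functional while loop over explicit state."""
    state = init_state
    while cond_fun(state):
        state = body_fun(state)  # each iteration returns a new state
    return state

def cond_f(state):
    i, counter = state
    return i < 10

def body_f(state):
    i, counter = state
    i = i + 1.0
    return i, counter + i

i_final, counter_final = while_loop_reference(cond_f, body_f, (0.0, 0.0))
print(i_final, counter_final)  # 10.0 55.0
```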

make_cond()

brainpy.math.jax.make_cond() is used to generate a condition function when you are using JaxArray. It supports the following condition logic:


if True:
    true statements 
else: 
    false statements

When using brainpy.math.jax.make_cond(), the true statements should be wrapped as a true_fun function which implements the logic of the true branch (no return values), and the false statements should be wrapped as a false_fun function which implements the logic of the false branch (return values are not supported either):


if True:
    true_fun(x)
else:
    false_fun(x)

All the JaxArray used in true_fun and false_fun should be declared in the dyn_vars argument. x is also used to receive the external input value.

Let’s give it a try:

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += 1

def false_f(x): b.value -= 1

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])

Here, we have two tensors. If the predicate is true, 1 is added to tensor a; if it is false, 1 is subtracted from tensor b.

cond(pred=True)

a, b
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(True)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([0., 0.], dtype=float32)))
cond(False)

a, b
(JaxArray(DeviceArray([2., 2.], dtype=float32)),
 JaxArray(DeviceArray([-1., -1.], dtype=float32)))

Or, we define a conditional case which depends on the external input.

a = bm.zeros(2)
b = bm.ones(2)

def true_f(x):  a.value += x

def false_f(x): b.value -= x

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([1., 1.], dtype=float32)))
cond(False, 5.)

a, b
(JaxArray(DeviceArray([10., 10.], dtype=float32)),
 JaxArray(DeviceArray([-4., -4.], dtype=float32)))

Optimizers

Gradient descent is one of the most popular algorithms to perform optimization. Gradient descent optimizers, combined with the loss function, are the key pieces that enable machine learning to work for your data. In this section, we are going to understand:

  • how to use optimizers in BrainPy?

  • how to customize your own optimizer?

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')
import matplotlib.pyplot as plt

Optimizers

The basic optimizer class in BrainPy is

bm.optimizers.Optimizer
brainpy.math.jax.optimizers.Optimizer

Following are some optimizers in BrainPy:

  • SGD

  • Momentum

  • Nesterov momentum

  • Adagrad

  • Adadelta

  • RMSProp

  • Adam

Users can also extend BrainPy with their own optimizers easily.

Generally, an Optimizer initialization receives a learning rate lr, the trainable variables train_vars, and other hyperparameters for the specific optimizer.

  • lr can be a float, or an instance of bm.optimizers.Scheduler.

  • train_vars should be a dict of JaxArray.

Here we create an SGD optimizer.

a = bm.ones((5, 4))
b = bm.zeros((3, 3))

op = bm.optimizers.SGD(lr=0.001, train_vars={'a': a, 'b': b})

When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update() method.

op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.99970317, 0.99958736, 0.9991706 , 0.99929893],
                      [0.99986506, 0.9994412 , 0.9996797 , 0.9995855 ],
                      [0.99980134, 0.999285  , 0.99970514, 0.99927545],
                      [0.99907184, 0.9993837 , 0.99917775, 0.99953413],
                      [0.9999124 , 0.99908406, 0.9995969 , 0.9991523 ]],            dtype=float32))
b: JaxArray(DeviceArray([[-5.8195234e-04, -4.3874790e-04, -3.3398748e-05],
                      [-5.7411409e-04, -7.0666044e-04, -9.4130711e-04],
                      [-7.1995187e-04, -1.1736620e-04, -9.5254736e-04]],            dtype=float32))
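
The update applied above is the plain gradient-descent rule, p ← p − lr · g, applied to every trainable variable. The following is a minimal pure-Python sketch of that rule on a dict of flat lists. The helper `sgd_update` is hypothetical and ignores the step counter and the learning rate scheduler that the BrainPy optimizer maintains.

```python
def sgd_update(params, grads, lr):
    """Apply p <- p - lr * g in place to every entry of `params`."""
    for name, values in params.items():
        g = grads[name]
        for k in range(len(values)):
            values[k] -= lr * g[k]

params = {'a': [1.0, 1.0], 'b': [0.0]}
grads = {'a': [0.5, 0.25], 'b': [1.0]}
sgd_update(params, grads, lr=0.001)
print(params)
```

With lr=0.001 and gradients in [0, 1), each parameter moves by less than 0.001, which matches the size of the changes printed above.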

You can process gradients before applying them. For example, we clip the gradients by the maximum L2-norm.

grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
{'a': JaxArray(DeviceArray([[0.99927866, 0.03028023, 0.8803668 , 0.64568734],
                       [0.64056313, 0.04791141, 0.7399359 , 0.87378347],
                       [0.96773326, 0.7771431 , 0.9618045 , 0.8374212 ],
                       [0.64901245, 0.24517596, 0.06224799, 0.6327405 ],
                       [0.31687486, 0.6385107 , 0.9160483 , 0.67039466]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.14722073, 0.52626574, 0.9817407 ],
                       [0.7333363 , 0.39472723, 0.82928896],
                       [0.7657701 , 0.93165004, 0.88332164]], dtype=float32))}
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
{'a': JaxArray(DeviceArray([[0.31979424, 0.00969043, 0.28173944, 0.20663615],
                       [0.20499626, 0.01533285, 0.23679803, 0.27963263],
                       [0.3096989 , 0.24870528, 0.30780157, 0.26799577],
                       [0.20770025, 0.07846245, 0.01992092, 0.20249282],
                       [0.1014079 , 0.20433943, 0.2931584 , 0.2145431 ]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.0666547 , 0.23826863, 0.4444865 ],
                       [0.33202055, 0.17871413, 0.37546346],
                       [0.34670505, 0.4218078 , 0.39992693]], dtype=float32))}
op.update(grads_post)

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.9993834 , 0.99957764, 0.99888885, 0.9990923 ],
                      [0.9996601 , 0.9994259 , 0.9994429 , 0.9993059 ],
                      [0.99949163, 0.99903625, 0.99939734, 0.99900746],
                      [0.9988641 , 0.99930525, 0.99915785, 0.99933165],
                      [0.999811  , 0.99887973, 0.99930376, 0.9989378 ]],            dtype=float32))
b: JaxArray(DeviceArray([[-0.00064861, -0.00067702, -0.00047789],
                      [-0.00090613, -0.00088537, -0.00131677],
                      [-0.00106666, -0.00053917, -0.00135247]], dtype=float32))
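
Judging from the printed values, each array is rescaled by its own L2 norm (the entries of 'a' and 'b' shrink by different factors). A minimal sketch of such a per-array rule follows. The helper `clip_by_norm_reference` is hypothetical and is not the BrainPy function; leaving gradients whose norm is already below the limit unchanged is an assumption based on the usual definition of norm clipping.

```python
import math

def clip_by_norm_reference(grads, max_norm):
    """Rescale each gradient so that its L2 norm is at most max_norm."""
    clipped = {}
    for name, g in grads.items():
        norm = math.sqrt(sum(v * v for v in g))
        # shrink only when the norm exceeds the limit
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        clipped[name] = [v * scale for v in g]
    return clipped

grads = {'a': [3.0, 4.0], 'b': [0.3, 0.4]}
clipped = clip_by_norm_reference(grads, max_norm=1.0)
print(clipped['a'])  # norm 5 is scaled down to norm 1
print(clipped['b'])  # norm 0.5 is below the limit and left as is
```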

Note

An optimizer usually has its own dynamically changing variables. If you JIT a function whose logic contains an optimizer update, you should add the variables in Optimizer.vars() to the dyn_vars in bm.jit().

op.vars()  # the SGD optimizer only has a `step` variable that records the training step
{'Constant0.step': Variable(DeviceArray([2], dtype=int32))}
bm.optimizers.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
{'Constant1.step': Variable(DeviceArray([0], dtype=int32)),
 'Momentum0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Momentum0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}
bm.optimizers.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
{'Constant2.step': Variable(DeviceArray([0], dtype=int32)),
 'Adam0.a_m': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_m': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32)),
 'Adam0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}

Creating a custom optimizer

If you intend to create your own optimization algorithm, simply inherit from bm.optimizers.Optimizer class and override the following methods:

  • __init__(): init function which receives learning rate (lr) and trainable variables (train_vars).

  • update(grads): update function to compute the updated parameters.

For example,

class CustomizeOp(bm.optimizers.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        
        # customize your initialization
        
    def update(self, grads):
        # customize your update logic
        pass

Schedulers

Scheduler seeks to adjust the learning rate during training by reducing the learning rate according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.

For example, we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.

sc = bm.optimizers.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
_images/64443300e1c2b3e2d79f5bac153d22cec2d3647b157d81cfb04ff3af63474e02.png
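
The curve can be reproduced by hand, assuming the common definition of exponential decay, lr · decay_rate^(step / decay_steps). This formula has not been checked against BrainPy's source, so treat the function below as an illustrative sketch with a hypothetical name.

```python
def exponential_decay(lr, decay_steps, decay_rate, step):
    """Learning rate after `step` training steps (assumed formula)."""
    return lr * decay_rate ** (step / decay_steps)

for step in (0, 2, 200, 1000):
    print(step, exponential_decay(0.1, 2, 0.99, step))
```

Under this definition, the rate is multiplied by decay_rate once every decay_steps steps, so with decay_steps=2 and decay_rate=0.99 it falls by 1% every two training steps.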

After Optimizer initialization, the learning rate self.lr will always be an instance of bm.optimizers.Scheduler. A scalar float learning rate initialization will result in a Constant scheduler.

op.lr
<brainpy.math.jax.optimizers.Constant at 0x2ab375a3700>

One can get the current learning rate value by calling Scheduler.__call__(i=None).

  • If i is not provided, the learning rate value will be evaluated at the built-in training step.

  • Otherwise, the learning rate value will be evaluated at the given step i.

op.lr()
0.001

BrainPy provides several commonly used learning rate schedulers:

  • Constant

  • ExponentialDecay

  • InverseTimeDecay

  • PolynomialDecay

  • PiecewiseConstant

# InverseTimeDecay scheduler

rates = bm.optimizers.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)
_images/b1359c52f506e2691e3d196817e2acb5241dfb605da26bc7110397b70cc86754.png
# PolynomialDecay scheduler

rates = bm.optimizers.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
_images/ab606d7e6eb405f5f38194833ad33c7925e7726f3dae80d165ad7d8d45263964.png
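
For reference, the two schedules plotted above are commonly defined as follows. This is a sketch under the assumption that BrainPy follows these widespread definitions; the function names and the `power` argument are hypothetical.

```python
def inverse_time_decay(lr, decay_steps, decay_rate, step):
    """lr / (1 + decay_rate * step / decay_steps)."""
    return lr / (1 + decay_rate * step / decay_steps)

def polynomial_decay(lr, decay_steps, final_lr, step, power=1.0):
    """Interpolate from lr to final_lr over decay_steps, then stay at final_lr."""
    step = min(step, decay_steps)
    frac = (1 - step / decay_steps) ** power
    return (lr - final_lr) * frac + final_lr

for step in (0, 5, 10, 100):
    print(step,
          inverse_time_decay(0.01, 10, 0.999, step),
          polynomial_decay(0.01, 10, 0.0001, step))
```

With these definitions the polynomial schedule reaches final_lr after decay_steps steps and stays there, whereas the inverse-time schedule keeps decreasing slowly.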

Creating a custom scheduler

To implement your own scheduler, simply inherit from the bm.optimizers.Scheduler class and override the following methods:

  • __init__(): the init function.

  • __call__(i=None): the learning rate value evaluation.

class CustomizeScheduler(bm.optimizers.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        
        # customize your initialization
        
    def __call__(self, i=None):
        # customize your update logic
        pass
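
As an illustration of the __call__(i=None) convention, here is a standalone step-decay schedule. It deliberately does not inherit from bm.optimizers.Scheduler so that it runs without BrainPy; the class name, its arguments and the plain integer `step` attribute are all hypothetical.

```python
class StepDecay:
    """Standalone illustration of the scheduler interface (not a BrainPy subclass)."""

    def __init__(self, lr, drop_every, factor):
        self.lr = lr
        self.drop_every = drop_every
        self.factor = factor
        self.step = 0  # built-in training step, advanced by the training loop

    def __call__(self, i=None):
        # use the built-in step when no explicit step is given
        i = self.step if i is None else i
        return self.lr * self.factor ** (i // self.drop_every)

step_sc = StepDecay(lr=0.1, drop_every=100, factor=0.5)
print(step_sc())     # evaluated at the built-in step (0)
print(step_sc(250))  # evaluated at step 250, after two drops
```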

Numerical Integrator

Numerical Solvers for ODEs

@Chaoming Wang

The brain modeling toolkit provided in BrainPy is focused on differential equations. Solving differential equations is the essence of neurodynamics simulation. Exact algebraic solutions are only available for low-order differential equations. For coupled, high-dimensional, non-linear brain dynamical systems, we need to resort to numerical methods.

In this section, I will illustrate how to define ordinary differential equations (ODEs), and how to define the numerical integration methods for them in BrainPy.

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

%matplotlib inline

How to define ODE functions?

BrainPy provides a convenient and intuitive way to define ODE systems. For the ODE

$$ \begin{aligned} {dx \over dt} &= f_1(x, t, y, p_1) \\ {dy \over dt} &= f_2(y, t, x, p_2) \end{aligned} $$

we can define this system as a Python function:

def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = f2(y, t, x, p2)
    return dx, dy

where t denotes the current time, p1 and p2, passed after t, are the parameters needed by this system, and x and y, passed before t, denote the dynamical variables. In the function body, the derivative of each variable can be customized as needed (f1 and f2 here). Finally, we return the corresponding derivatives dx and dy in the same order as the variables in the function arguments.

Each variable x or y can be a scalar (var_type = bp.integrators.SCALAR_VAR), a vector/matrix (var_type = bp.integrators.POP_VAR), or a system (var_type = bp.integrators.SYSTEM_VAR). Here, “system” means that the argument x denotes an array of variables. Taking the above example again, we can redefine it as:

def diff(xy, t, p1, p2):
    x, y = xy
    dx = f1(x, t, y, p1)
    dy = f2(y, t, x, p2)
    return bm.array([dx, dy])

How to define the numerical integration for ODEs?

After the definition of ODE functions, the numerical integration of these functions is very easy in BrainPy. We just need to add a decorator (bp.odeint).

@bp.odeint
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = f2(y, t, x, p2)
    return dx, dy

After wrapping the derivative function with bp.odeint, the function becomes an instance of ODEIntegrator.

isinstance(diff, bp.ode.ODEIntegrator)
True

bp.odeint receives several arguments:

  • “method”: A string, used to specify the numerical methods to integrate the ODE functions. The default method is Euler.

diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x7fa92852d700>
  • “dt”: A float, used to set the default numerical precision. The default “dt” is 0.1.

diff.dt
0.1
  • “show_code”: A bool, indicating whether to show the generated numerical integration code. Let’s take the Euler method and the RK4 method as examples.

@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = f2(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  x_new = x + dx_k1 * dt * 1
  y_new = y + dy_k1 * dt * 1
  return x_new, y_new

{'f': <function diff at 0x7fa928517e50>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x7fa92853d040>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = f2(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  k2_x_arg = x + dt * dx_k1 * 0.5
  k2_y_arg = y + dt * dy_k1 * 0.5
  k2_t_arg = t + dt * 0.5
  dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
  k3_x_arg = x + dt * dx_k2 * 0.5
  k3_y_arg = y + dt * dy_k2 * 0.5
  k3_t_arg = t + dt * 0.5
  dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
  k4_x_arg = x + dt * dx_k3
  k4_y_arg = y + dt * dy_k3
  k4_t_arg = t + dt
  dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
  x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
  y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
  return x_new, y_new

{'f': <function diff at 0x7fa928517a60>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa92853d910>
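
The generated code follows the textbook Euler and RK4 update rules. The self-contained sketch below applies the same two rules by hand to the test equation dx/dt = −x, whose exact solution is exp(−t), to show the accuracy gap at the same step size. The names `f`, `euler_step`, `rk4_step` and `integrate` are illustrative and are not BrainPy functions.

```python
import math

def f(x, t):
    return -x  # dx/dt = -x, exact solution x(t) = exp(-t) for x(0) = 1

def euler_step(x, t, dt):
    return x + dt * f(x, t)

def rk4_step(x, t, dt):
    k1 = f(x, t)
    k2 = f(x + dt * k1 * 0.5, t + dt * 0.5)
    k3 = f(x + dt * k2 * 0.5, t + dt * 0.5)
    k4 = f(x + dt * k3, t + dt)
    return x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

def integrate(step, dt=0.1, t_end=1.0):
    x, n = 1.0, round(t_end / dt)
    for k in range(n):
        x = step(x, k * dt, dt)
    return x

exact = math.exp(-1.0)
err_euler = abs(integrate(euler_step) - exact)
err_rk4 = abs(integrate(rk4_step) - exact)
print(err_euler, err_rk4)
```

At dt=0.1 the RK4 error is several orders of magnitude smaller than the Euler error, at the cost of four derivative evaluations per step instead of one.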

Two Illustrated Examples

Example 1: FitzHugh–Nagumo model

Now, let’s take the well-known FitzHugh–Nagumo model as an example to illustrate how to define ODE solvers for brain modeling. The FitzHugh–Nagumo (FHN) model has two dynamical variables, which are governed by the following equations:

$$ \begin{split} \tau {\dot {w}}&=v+a-bw \\ {\dot {v}} &=v-{\frac {v^{3}}{3}}-w+I_{\rm {ext}} \end{split} $$

For this FHN model, we can code it in BrainPy like this:

@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
    dw = (V + a - b * w) / tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

After defining the numerical solver, the ODE system can easily be solved over a given time interval. For example, for the given parameters,

a=0.7;   b=0.8;   tau=12.5;   Iext=1.

the solution of the FHN model between 0 and 100 ms can be approximated by

import matplotlib.pyplot as plt

%matplotlib inline
hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
    V, w = integral(V, w, t, Iext, a, b, tau)
    hist_V.append(V)

plt.plot(hist_times, hist_V)
[<matplotlib.lines.Line2D at 0x7fa92846fa30>]
_images/f4a3990a4271e5474771d72c4f49aa99af890ff31b8839032870c93f3be0886e.png
Example 2: Hodgkin–Huxley model

Another, more complex example is the classical Hodgkin–Huxley neuron model. In the HH model, four dynamical variables (V, m, n, h) are used for modeling the initiation and propagation of the action potential. Specifically, they are governed by the following equations:

$$ \begin{aligned} C_{m} \frac{d V}{d t} &=-\bar{g}_{\mathrm{K}} n^{4}\left(V-V_{K}\right)- \bar{g}_{\mathrm{Na}} m^{3} h\left(V-V_{Na}\right)-\bar{g}_{l}\left(V-V_{l}\right)+I_{syn} \\ \frac{d m}{d t} &=\alpha_{m}(V)(1-m)-\beta_{m}(V) m \\ \frac{d h}{d t} &=\alpha_{h}(V)(1-h)-\beta_{h}(V) h \\ \frac{d n}{d t} &=\alpha_{n}(V)(1-n)-\beta_{n}(V) n \end{aligned} $$

In BrainPy, such a dynamical system can be coded as:

@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C

    return dVdt, dmdt, dhdt, dndt

As with the FHN model, we can integrate the HH model with the given parameters over a time interval:

Iext=10.;   ENa=50.;   EK=-77.;   EL=-54.387
C=1.0;      gNa=120.;  gK=36.;    gL=0.03
hist_times = bm.arange(0, 100, 0.01)
hist_V, hist_m, hist_h, hist_n = [], [], [], []
V, m, h, n = 0., 0., 0., 0.
for t in hist_times:
    V, m, h, n = integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C)
    hist_V.append(V)
    hist_m.append(m)
    hist_h.append(h)
    hist_n.append(n)

plt.subplot(211)
plt.plot(hist_times, hist_V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(hist_times, hist_m, label='m')
plt.plot(hist_times, hist_h, label='h')
plt.plot(hist_times, hist_n, label='n')
plt.legend()
plt.show()
_images/b0f33e0ceb7b6f77c6101d60f2c75476bc2f75b0187facb8218d3237372d2ffe.png

Provided ODE Numerical Solvers

BrainPy provides several numerical methods for ordinary differential equations (ODEs). Specifically, we provide explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler method for ODE numerical integration.

Explicit Runge-Kutta methods for ODEs

The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are listed in the following table:

Methods                                               Keywords
Euler                                                 euler
Midpoint                                              midpoint
Heun’s second-order method                            heun2
Ralston’s second-order method                         ralston2
RK2                                                   rk2
RK3                                                   rk3
RK4                                                   rk4
Heun’s third-order method                             heun3
Ralston’s third-order method                          ralston3
Third-order Strong Stability Preserving Runge-Kutta   ssprk3
Ralston’s fourth-order method                         ralston4
Runge-Kutta 3/8-rule fourth-order method              rk4_38rule

Users can utilize these methods by specifying the method option in brainpy.odeint() with the corresponding keyword. For example:

@bp.odeint(method='rk4')
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa9283b6100>

Or, you can directly instantiate your favorite integrator like:

@bp.ode.RK4
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa92852dc70>
def derivative(v, t, p):
    # do something
    return v

int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x7fa9283b63a0>
Adaptive Runge-Kutta methods for ODEs

The second category of ODE numerical support is the adaptive RK methods. Unlike the explicit RK methods, adaptive methods are designed to produce an estimate of the local truncation error in a single Runge-Kutta step, and this error is then used to adaptively control the numerical step size. Specifically, if $error > tol$, then $dt$ is replaced with $dt_{new}$ and the step is repeated. Therefore, adaptive RK methods allow a varying step size. In BrainPy, the following adaptive RK methods are provided:

Methods                      Keywords
Runge–Kutta–Fehlberg 4(5)    rkf45
Runge–Kutta–Fehlberg 1(2)    rkf12
Dormand–Prince method        rkdp
Cash–Karp method             ck
Bogacki–Shampine method      bs
Heun–Euler method            heun_euler

By default, the above methods are not adaptive, unless users provide the keyword adaptive=True in brainpy.odeint(). When the adaptive RK methods are used for numerical integration, the instantaneously adjusted step size dt is appended to the function arguments. Moreover, the tolerance tol for step-size adjustment can also be controlled by users. Let’s take the Lorenz system as an example:

# adaptively adjust stepsize

@bp.odeint(method='rkf45', 
           adaptive=True, # activate the "adaptive" option
           tol=0.001) # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
    # should provide one more argument "dt" when using the adaptive rk method
    x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)  
    hist_x.append(x)
    hist_y.append(y)
    hist_z.append(z)
    hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()
_images/67f82fc13bedd432d01fae0acf93b6350cf907593ef5a25ca6e54cf8521b6c4c.png _images/82f395a134fae7be968d894ccfdee30eed67424744f0728bb25a2da82d044439.png
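
The step-size control behind adaptive methods can be sketched in a few lines with the simplest embedded pair, Heun–Euler: the difference between the first-order (Euler) and second-order (Heun) results serves as the local error estimate. The update rule for dt used here, with a safety factor of 0.9 and a square-root law, is a common textbook choice and an assumption; it is not necessarily what BrainPy implements.

```python
import math

def f(y, t):
    return -y  # dy/dt = -y, exact solution y(t) = exp(-t) for y(0) = 1

def heun_euler_step(y, t, dt, tol):
    """One adaptive step; returns (y_new, t_new, dt_next)."""
    while True:
        k1 = f(y, t)
        k2 = f(y + dt * k1, t + dt)
        y_low = y + dt * k1              # Euler result (order 1)
        y_high = y + dt * (k1 + k2) / 2  # Heun result (order 2)
        error = abs(y_high - y_low)      # local error estimate
        if error <= tol:
            # accept the step and cautiously enlarge dt for the next one
            grow = 0.9 * math.sqrt(tol / max(error, 1e-16))
            return y_high, t + dt, dt * min(2.0, grow)
        # reject the step: shrink dt and retry from the same state
        dt = dt * max(0.1, 0.9 * math.sqrt(tol / error))

t_end = 2.0
y, t, dt = 1.0, 0.0, 0.5
n_steps = 0
while t_end - t > 1e-12:
    dt = min(dt, t_end - t)  # do not step past the end time
    y, t, dt = heun_euler_step(y, t, dt, tol=1e-4)
    n_steps += 1
print(n_steps, y, math.exp(-t_end))
```

The initial dt=0.5 is rejected and shrunk until the error estimate falls below tol; later, as the solution flattens, the accepted step size grows again.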
Exponential Euler methods for ODEs

Finally, BrainPy provides exponential integrators for ODEs. For linear ODE systems, we highly recommend using the Exponential Euler method.

Methods              Keywords
Exponential Euler    exponential_euler

For a linear system,

$$ {dy \over dt} = A - By $$

the exponential Euler scheme is given by:

$$ y(t+dt) = y(t) e^{-Bdt} + {A \over B}(1 - e^{-Bdt}) $$

As you can see, for such linear systems, the exponential Euler scheme coincides with the exact solution when $A$ and $B$ are constant over the step.
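
A short numerical check makes this concrete. The sketch below integrates dy/dt = A − By with a deliberately large step (B·dt = 2.5), comparing the exponential Euler update given above with the explicit Euler update. The constants and function names are chosen only for illustration.

```python
import math

A, B = 2.0, 5.0  # dy/dt = A - B*y, fixed point at y = A/B
y0, dt, n_steps = 0.0, 0.5, 4

def exp_euler_step(y, dt):
    decay = math.exp(-B * dt)
    return y * decay + (A / B) * (1 - decay)

def euler_step(y, dt):
    return y + dt * (A - B * y)

y_exp, y_eul = y0, y0
for _ in range(n_steps):
    y_exp = exp_euler_step(y_exp, dt)
    y_eul = euler_step(y_eul, dt)

t = n_steps * dt
y_exact = A / B + (y0 - A / B) * math.exp(-B * t)
print(y_exact, y_exp, y_eul)
```

The exponential Euler result agrees with the closed-form solution up to rounding error, while explicit Euler oscillates with growing amplitude because |1 − B·dt| > 1.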

In BrainPy, in order to automatically find the linear part, we utilize SymPy to parse user-defined functions. Therefore, one needs to install sympy first when using the exponential Euler method.

Interestingly, the computationally expensive Hodgkin–Huxley neuron model is a linear-like ODE system. In the following, you will find that by using the Exponential Euler method, the numerical step can be enlarged considerably to save computation time.

$$ \begin{aligned} C_{m}{\frac {d V}{dt}}&= -\left[{\bar {g}}_{\text{K}}n^{4} + {\bar {g}}_{\text{Na}}m^{3}h + {\bar {g}}_{l} \right] V +{\bar {g}}_{\text{K}}n^{4} V_{K} + {\bar {g}}_{\text{Na}}m^{3}h V_{Na} + {\bar {g}}_{l} V_{l} + I_{syn} \\ {\frac {dm}{dt}} &= \left[-\alpha _{m}(V)-\beta _{m}(V)\right]m + \alpha _{m}(V) \\ {\frac {dh}{dt}} &= \left[-\alpha _{h}(V)-\beta _{h}(V)\right]h + \alpha _{h}(V) \\ {\frac {dn}{dt}} &= \left[-\alpha _{n}(V)-\beta _{n}(V)\right]n + \alpha _{n}(V) \end{aligned} $$
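
Written in this form, each gating equation is linear in its own variable once V is held fixed over one step, so the exponential Euler formula applies with A = α(V) and B = α(V) + β(V). A hand-written sketch for the m gate, using the same rate functions as the HH code in this section (the helper `exp_euler_m` is hypothetical; BrainPy derives the update automatically):

```python
import math

def alpha_m(V):
    return 0.1 * (V + 40) / (1 - math.exp(-(V + 40) / 10))

def beta_m(V):
    return 4.0 * math.exp(-(V + 65) / 18)

def exp_euler_m(m, V, dt):
    """Advance dm/dt = alpha - (alpha + beta) * m with V held fixed over dt."""
    a, b = alpha_m(V), beta_m(V)
    m_inf = a / (a + b)  # steady state of m for this V
    return m_inf + (m - m_inf) * math.exp(-(a + b) * dt)

m = 0.0
for _ in range(5):
    m = exp_euler_m(m, V=-20.0, dt=0.2)
print(m)
```

Because the update relaxes m toward its steady state by a factor between 0 and 1, m remains inside [0, 1] for any step size, which is one reason larger dt values stay stable.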

Iext=10.;   ENa=50.;   EK=-77.;   EL=-54.387
C=1.0;      gNa=120.;  gK=36.;    gL=0.03
def derivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C

    return dVdt, dmdt, dhdt, dndt
def run(method, Iext=10., dt=0.1):
    hist_times = bm.arange(0, 100, dt)
    hist_V, hist_m, hist_h, hist_n = [], [], [], []
    V, m, h, n = 0., 0., 0., 0.
    for t in hist_times:
        V, m, h, n = method(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C)
        hist_V.append(V)
        hist_m.append(m)
        hist_h.append(h)
        hist_n.append(n)

    plt.subplot(211)
    plt.plot(hist_times, hist_V, label='V')
    plt.legend()
    plt.subplot(212)
    plt.plot(hist_times, hist_m, label='m')
    plt.plot(hist_times, hist_h, label='h')
    plt.plot(hist_times, hist_n, label='n')
    plt.legend()

Euler Method

int1 = bp.odeint(f=derivative, method='euler', dt=0.1)

run(int1, Iext=10, dt=0.1)
<ipython-input-25-35d6bfdac53f>:2: RuntimeWarning: overflow encountered in exp
  alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
<ipython-input-25-35d6bfdac53f>:3: RuntimeWarning: overflow encountered in exp
  beta = 4.0 * bm.exp(-(V + 65) / 18)
<ipython-input-25-35d6bfdac53f>:6: RuntimeWarning: overflow encountered in exp
  alpha = 0.07 * bm.exp(-(V + 65) / 20.)
<ipython-input-25-35d6bfdac53f>:7: RuntimeWarning: overflow encountered in exp
  beta = 1 / (1 + bm.exp(-(V + 35) / 10))
<ipython-input-25-35d6bfdac53f>:10: RuntimeWarning: overflow encountered in exp
  alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
<ipython-input-25-35d6bfdac53f>:11: RuntimeWarning: overflow encountered in exp
  beta = 0.125 * bm.exp(-(V + 65) / 80)
<ipython-input-25-35d6bfdac53f>:4: RuntimeWarning: invalid value encountered in double_scalars
  dmdt = alpha * (1 - m) - beta * m
<ipython-input-25-35d6bfdac53f>:8: RuntimeWarning: invalid value encountered in double_scalars
  dhdt = alpha * (1 - h) - beta * h
<ipython-input-25-35d6bfdac53f>:12: RuntimeWarning: invalid value encountered in double_scalars
  dndt = alpha * (1 - n) - beta * n
_images/9e6ae3dcfe6295d40a4c309b509bc904b7867a022d3680663bceffae94098b4c.png
int2 = bp.odeint(f=derivative, method='euler', dt=0.02)

run(int2, Iext=10, dt=0.02)
_images/f4e4b076bcd1eed5dcbbcc47aac52d4b0381fe932cb6adca29b8a5686df50c09.png

RK4 Method

int3 = bp.odeint(f=derivative, method='rk4', dt=0.1)

run(int3, Iext=10, dt=0.1)
_images/63ae96e8a09f6aff0aa6b0bd1a4cf1876b94fa5e44ed69379c2996018f7fcfca.png
int4 = bp.odeint(f=derivative, method='rk4', dt=0.2)

run(int4, Iext=10, dt=0.2)
<ipython-input-25-35d6bfdac53f>:2: RuntimeWarning: overflow encountered in exp
  alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
<ipython-input-25-35d6bfdac53f>:3: RuntimeWarning: overflow encountered in exp
  beta = 4.0 * bm.exp(-(V + 65) / 18)
<ipython-input-25-35d6bfdac53f>:6: RuntimeWarning: overflow encountered in exp
  alpha = 0.07 * bm.exp(-(V + 65) / 20.)
<ipython-input-25-35d6bfdac53f>:7: RuntimeWarning: overflow encountered in exp
  beta = 1 / (1 + bm.exp(-(V + 35) / 10))
<ipython-input-25-35d6bfdac53f>:10: RuntimeWarning: overflow encountered in exp
  alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
<ipython-input-25-35d6bfdac53f>:11: RuntimeWarning: overflow encountered in exp
  beta = 0.125 * bm.exp(-(V + 65) / 80)
<ipython-input-25-35d6bfdac53f>:4: RuntimeWarning: invalid value encountered in double_scalars
  dmdt = alpha * (1 - m) - beta * m
<ipython-input-25-35d6bfdac53f>:8: RuntimeWarning: invalid value encountered in double_scalars
  dhdt = alpha * (1 - h) - beta * h
<ipython-input-25-35d6bfdac53f>:12: RuntimeWarning: invalid value encountered in double_scalars
  dndt = alpha * (1 - n) - beta * n
<ipython-input-25-35d6bfdac53f>:17: RuntimeWarning: invalid value encountered in double_scalars
  dVdt = (- I_Na - I_K - I_leak + Iext) / C
_images/22d0eadcda2cb0393562cf8d5e1a0edfeca607757b9513850cb5c3198bb6ae97.png

Exponential Euler Method

int5 = bp.odeint(f=derivative, method='exponential_euler', dt=0.2)

run(int5, Iext=10, dt=0.2)
_images/806eeb5857f1756a98d9371a0eb54fd9331d2b98016367d7b39f0387fb5b1b15.png

Numerical Solvers for SDEs

@Chaoming Wang

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.

import brainpy as bp

bp.__version__
'1.1.0'
import matplotlib.pyplot as plt

%matplotlib inline

How to define SDE functions?

A one-dimensional stochastic differential equation (SDE) with scalar Wiener noise is given by

$$ \begin{aligned} d X_{t}&=f\left(X_{t}, t, p_1\right) d t+g\left(X_{t}, t, p_2\right) d W_{t} \quad (1) \end{aligned} $$

where $X_t = X(t)$ is the realization of a stochastic process or random variable, $f(X_t, t)$ is the drift coefficient, $g(X_t, t)$ is the diffusion coefficient, and the stochastic process $W_t$ is a Wiener process.

For this SDE system, we can define two Python functions $f$ and $g$ to represent it.

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df

As with ODE functions, the arguments before $t$ denote the random variables, while the arguments after $t$ represent the parameters. For an SDE function with scalar noise, the returned data $dg$ and $df$ should have the same size, for example $df \in R^d, dg \in R^d$.

However, a more general SDE system usually has a multi-dimensional driving Wiener process:

$$ dX_t=f(X_t)dt+\sum_{\alpha=1}^{m}g_{\alpha }(X_t)dW_t ^{\alpha} $$

For such an $m$-dimensional noise system, the coding scheme is the same as in the scalar case, except that $dg$ has one more dimension. For example, $df \in R^{d}, dg \in R^{d \times m}$.
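
The shape conventions above can be sketched in plain NumPy. This is only an illustration, not BrainPy code: the linear drift, the constant diffusion, and the names `g_scalar`, `g_vector`, and `euler_maruyama_step` are assumptions made for the example. It performs one Euler-Maruyama step for a $d$-dimensional state, once with scalar noise and once with $m$-dimensional noise:

```python
import numpy as np

d, m = 3, 2          # state dimension and number of Wiener processes (illustrative)
dt = 0.01
rng = np.random.default_rng(0)

def f_part(x, t, p1):
    # drift: simple linear decay, returns shape (d,)
    return -p1 * x

def g_scalar(x, t, p2):
    # diffusion for scalar noise, returns shape (d,)
    return p2 * np.ones_like(x)

def g_vector(x, t, p2):
    # diffusion for m-dimensional noise, returns shape (d, m)
    return p2 * np.ones((x.shape[0], m))

def euler_maruyama_step(x, t, p1, p2, g, dW):
    df = f_part(x, t, p1)
    dg = g(x, t, p2)
    # scalar noise: elementwise product; vector noise: sum over the m noise sources
    noise = dg * dW if dg.ndim == 1 else dg @ dW
    return x + df * dt + noise

x0 = np.ones(d)
dW_scalar = rng.normal(0.0, np.sqrt(dt))          # a single scalar increment
dW_vector = rng.normal(0.0, np.sqrt(dt), size=m)  # one increment per noise source

x_scalar = euler_maruyama_step(x0, 0.0, 1.0, 0.5, g_scalar, dW_scalar)
x_vector = euler_maruyama_step(x0, 0.0, 1.0, 0.5, g_vector, dW_vector)
print(x_scalar.shape, x_vector.shape)
```

In both cases the updated state keeps the shape `(d,)`; only the diffusion term changes from shape `(d,)` to `(d, m)`.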

How to define the numerical integration for SDEs?

Before numerically integrating SDE functions, we should distinguish two kinds of SDE integrals. Integrating system (1) gives

$$ \begin{aligned} X_{t}&=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} \quad (2) \end{aligned} $$

In the 1940s, the Japanese mathematician K. Ito defined a type of integral now called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol “$\circ$” to distinguish it from the Ito integral.

$$ \begin{aligned} d X_{t} &=f\left(X_{t}, t\right) d t+g\left(X_{t}, t\right) \circ d W_{t} \\ X_{t} &=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) \circ d W_{s} \quad (3) \end{aligned} $$

The difference between the Ito integral (2) and the Stratonovich integral (3) lies in the second integral term, which can be written in a general form as

$$ \begin{split} \int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} &=\lim_{h \rightarrow 0} \sum_{k=0}^{m-1} g\left(X_{\tau_{k}}, \tau_{k}\right)\left(W\left(t_{k+1}\right)-W\left(t_{k}\right)\right) \\ \mathrm{where} \quad & h = t_{k+1} - t_{k} \\ & \tau_k = (1-\lambda)t_k +\lambda t_{k+1} \end{split} $$

  • In the stochastic integral of the Ito SDE, $\lambda=0$, thus $\tau_k=t_k$;

  • In the definition of the Stratonovich integral, $\lambda=0.5$, thus $\tau_k=(t_{k+1} + t_{k}) / 2$.
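
The effect of $\lambda$ can be checked numerically. As a minimal sketch, assume the simplest non-trivial integrand $g(X_s, s) = W_s$ (an assumption made for illustration, not an example from BrainPy). For this integrand the Ito integral equals $(W_T^2 - T)/2$, while the Stratonovich integral equals $W_T^2/2$, so the two sums should differ by about $T/2$:

```python
import numpy as np

rng = np.random.default_rng(42)
T, n = 1.0, 100_000          # total time and number of intervals [t_k, t_{k+1}]
h = T / n

# Sample W on a grid twice as fine, so that W at each midpoint is available.
increments = rng.normal(0.0, np.sqrt(h / 2), size=2 * n)
W = np.concatenate([[0.0], np.cumsum(increments)])

W_left = W[0:-1:2]           # W(t_k)
W_mid = W[1::2]              # W((t_k + t_{k+1}) / 2)
W_right = W[2::2]            # W(t_{k+1})
dW = W_right - W_left

ito_sum = np.sum(W_left * dW)     # lambda = 0
stra_sum = np.sum(W_mid * dW)     # lambda = 0.5

W_T = W[-1]
print(ito_sum, (W_T ** 2 - T) / 2)   # Ito sum vs. its closed form
print(stra_sum, W_T ** 2 / 2)        # Stratonovich sum vs. its closed form
```

The two printed pairs agree only approximately, because the sums are computed with a finite step size $h$.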

In BrainPy, these two different integrals can be easily implemented. All users need to do is provide the keyword intg_type to bp.sdeint. intg_type can be “bp.integrators.STRA_SDE” or “bp.integrators.ITO_SDE” (default). Likewise, the type of Wiener process is selected by the wiener_type keyword, which can be “bp.integrators.SCALAR_WIENER” (default) or “bp.integrators.VECTOR_WIENER”.

Now, let’s numerically integrate the SDE (1) by the Ito way with the Milstein method:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

Or, it can be expressed as:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, g=g_part, method='milstein')
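
To show what a derivative-free Milstein update computes, here is a minimal NumPy sketch for an Ito SDE with scalar noise. It uses the textbook scheme with a supporting value and is not BrainPy's internal implementation. The test problem (geometric Brownian motion), its parameters, and the helper name `milstein_step` are assumptions chosen because an exact solution is available for comparison:

```python
import numpy as np

mu, sigma = 0.5, 0.3         # illustrative parameters of geometric Brownian motion
dt, n_steps = 1e-3, 1000
rng = np.random.default_rng(1)

def f(x):                    # drift
    return mu * x

def g(x):                    # diffusion
    return sigma * x

def milstein_step(x, dW):
    # the supporting value replaces the derivative dg/dx by a finite difference
    x_bar = x + f(x) * dt + g(x) * np.sqrt(dt)
    correction = (g(x_bar) - g(x)) * (dW ** 2 - dt) / (2 * np.sqrt(dt))
    return x + f(x) * dt + g(x) * dW + correction

x, W = 1.0, 0.0
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt))
    x = milstein_step(x, dW)
    W += dW

# exact solution of dX = mu X dt + sigma X dW with X(0) = 1, driven by the same path
x_exact = np.exp((mu - 0.5 * sigma ** 2) * n_steps * dt + sigma * W)
print(x, x_exact)
```

The numerical and exact values should be close, since both are driven by the same Wiener path.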

However, if you want to numerically integrate an SDE with a multi-dimensional Wiener process in the Stratonovich sense, you can code it like this:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d, m)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, 
                     g=g_part, 
                     method='milstein', 
                     intg_type=bp.integrators.STRA_SDE, 
                     wiener_type=bp.integrators.VECTOR_WIENER)

Example: Noisy Lorenz system

Here, let’s demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:

$$ \begin{aligned} \frac{d x}{dt}&=\sigma(y-x) + p\,x\,\xi_x \\ \frac{d y}{dt}&=x(\rho-z)-y + p\,y\,\xi_y \\ \frac{d z}{dt}&=x y-\beta z + p\,z\,\xi_z \end{aligned} $$

sigma = 10; beta = 8/3; 
rho = 28;   p = 0.1

def lorenz_g(x, y, z, t):
    return p * x, p * y, p * z

def lorenz_f(x, y, z, t):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz

lorenz = bp.sdeint(f=lorenz_f, 
                   g=lorenz_g, 
                   intg_type=bp.integrators.ITO_SDE,
                   wiener_type=bp.integrators.SCALAR_WIENER,
                   dt=0.005)

hist_times = bp.math.arange(0, 50, 0.005)
hist_x, hist_y, hist_z = [], [], []
x, y, z = 1., 1., 1.
for t in hist_times:
    x, y, z = lorenz(x, y, z, t)
    hist_x.append(x)
    hist_y.append(y)
    hist_z.append(z)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
_images/36a1532c2ca610d0600f93a7b2074577a29009fc66f433ae5f0f7650e621bf81.png

Supported SDE Numerical Methods

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.

| Methods | Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support |
|---|---|---|---|---|---|
| Strong SRK scheme: SRI1W1 | srk1w1_scalar | Yes | | Yes | |
| Strong SRK scheme: SRI2W1 | srk2w1_scalar | Yes | | Yes | |
| Strong SRK scheme: KlPl | KlPl_scalar | Yes | | Yes | |
| Euler method | euler | Yes | Yes | Yes | Yes |
| Heun method | heun | | Yes | Yes | Yes |
| Derivative-free Milstein | milstein | Yes | Yes | Yes | Yes |
| Exponential Euler | exponential_euler | Yes | | Yes | Yes |

Dynamics Simulation

Efficient Synaptic Computation

@Chaoming Wang

In a real project, most of the simulation time is spent on computing the synapses. Therefore, figuring out the most efficient way to do synaptic computation is a necessary step to accelerate your computational project. Here, let’s take an E/I balance network as an example to illustrate how to code an efficient synaptic computation.

import brainpy as bp
import brainpy.math as bm

bm.use_backend('numpy')
%matplotlib inline
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

The E/I balance network COBA is adopted from (Vogels & Abbott, 2005) [1].

# Parameters for network structure
num = 4000
num_exc = int(num * 0.75)
num_inh = int(num * 0.25)

Neuron Model

In the COBA network, each integrate-and-fire neuron is characterized by a time constant, $\tau$ = 20 ms, and a resting membrane potential, $V_{rest}$ = -60 mV. Whenever the membrane potential crosses a spiking threshold of -50 mV, an action potential is generated and the membrane potential is reset to the resting potential, where it remains clamped for a 5 ms refractory period. The membrane voltages are calculated as follows:

$$ \tau {dV \over dt} = (V_{rest} - V) + g_{exc}(E_{exc} - V) + g_{inh}(E_{inh} - V) $$

where reversal potentials are $E_{exc} = 0$ mV and $E_{inh} = -80$ mV.

# Parameters for the neuron
tau = 20  # ms
Vt = -50  # mV
Vr = -60  # mV
El = -60  # mV
ref_time = 5.0  # refractory time, ms
I = 20.

class LIF(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)

    # variables
    self.V = bm.Variable(bm.zeros(size))
    self.input = bm.Variable(bm.zeros(size))
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, V, t, Iexc):
    dV = (Iexc + El - V) / tau
    return dV

  def update(self, _t, _dt):
    for i in range(self.num):
      self.spike[i] = 0.
      if (_t - self.t_last_spike[i]) > ref_time:
        V = self.integral(self.V[i], _t, self.input[i])
        if V >= Vt:
          self.V[i] = Vr
          self.spike[i] = 1.
          self.t_last_spike[i] = _t
        else:
          self.V[i] = V
      self.input[i] = I

Synapse Model

In the COBA network, when a neuron fires, the appropriate synaptic variables of its postsynaptic targets are increased: $g_{exc} \gets g_{exc} + \Delta g_{exc}$ for an excitatory presynaptic neuron and $g_{inh} \gets g_{inh} + \Delta g_{inh}$ for an inhibitory presynaptic neuron. Otherwise, these variables obey the following equations:

$$ \begin{aligned} \tau_{exc} {dg_{exc} \over dt} &= -g_{exc} \quad \quad (1) \\ \tau_{inh} {dg_{inh} \over dt} &= -g_{inh} \quad \quad (2) \end{aligned} $$

with synaptic time constants $\tau_{exc} = 5$ ms, $\tau_{inh} = 10$ ms, $\Delta g_{exc} = 0.6$ and $\Delta g_{inh} = 6.7$.

# Parameters for the synapse
tau_exc = 5  # ms
tau_inh = 10  # ms
E_exc = 0.  # mV
E_inh = -80.  # mV
delta_exc = 0.6  # excitatory synaptic weight
delta_inh = 6.7  # inhibitory synaptic weight

def run_net(neu_model, syn_model):
  E = neu_model(num_exc, monitors=['spike'])
  E.V[:] = bm.random.randn(num_exc) * 5. + Vr
    
  I = neu_model(num_inh, monitors=['spike'])
  I.V[:] = bm.random.randn(num_inh) * 5. + Vr
  
  E2E = syn_model(pre=E, post=E, conn=bp.connect.FixedProb(0.02),
                  tau=tau_exc, weight=delta_exc, E=E_exc)
  E2I = syn_model(pre=E, post=I, conn=bp.connect.FixedProb(0.02),
                  tau=tau_exc, weight=delta_exc, E=E_exc)
  I2E = syn_model(pre=I, post=E, conn=bp.connect.FixedProb(0.02),
                  tau=tau_inh, weight=delta_inh, E=E_inh)
  I2I = syn_model(pre=I, post=I, conn=bp.connect.FixedProb(0.02),
                  tau=tau_inh, weight=delta_inh, E=E_inh)

  net = bp.Network(E, I, E2E, E2I, I2E, I2I)
  net = bm.jit(net)
  t = net.run(100., report=0.1)

  fig, gs = bp.visualize.get_figure(row_num=5, col_num=1, row_len=1, col_len=10)
  fig.add_subplot(gs[:4, 0])
  bp.visualize.raster_plot(E.mon.ts, E.mon.spike, xlim=(0, 100.), ylabel='E Group', xlabel='')
  fig.add_subplot(gs[4, 0])
  bp.visualize.raster_plot(I.mon.ts, I.mon.spike, xlim=(0, 100.), ylabel='I Group', show=True)

  return t

Matrix-based connection

The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through connector.requires('conn_mat') (for details, please see Synaptic Connectivity). Each connection matrix is an array with the shape of (num_pre, num_post).
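
As a small illustration of the idea (plain NumPy, with a hypothetical 3-by-4 connectivity and spike pattern invented for this example), a connection matrix and the matrix-based propagation of spikes can be sketched as:

```python
import numpy as np

# hypothetical connectivity: 3 pre-synaptic neurons, 4 post-synaptic neurons
conn_mat = np.array([[1, 0, 1, 0],
                     [0, 1, 0, 0],
                     [1, 1, 0, 1]], dtype=float)

pre_spike = np.array([True, False, True])    # neurons 0 and 2 fire
weight = 0.6

g = np.zeros(conn_mat.shape)                 # one state per (pre, post) pair
g += pre_spike[:, None] * conn_mat * weight  # spikes propagated along existing connections
g_post = g.sum(axis=0)                       # total conductance per post-synaptic neuron
print(g_post)
```

Note that `g` stores a value for every (pre, post) pair, including the pairs that are not connected.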

Based on conn_mat, the updating logic of the above synapses can be coded as:

class SynMat1(bp.TwoEndConn):

  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynMat1, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # p1: connections
    self.conn = conn(pre.size, post.size)
    self.conn_mat = self.conn.requires('conn_mat')

    # variables
    self.g = bm.Variable(bm.zeros(self.conn_mat.shape))
    
    # function
    self.integral = bp.odeint(self.derivative)
  
  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    spike_on_syn = bm.expand_dims(self.pre.spike, 1) * self.conn_mat  # p2
    self.g[:] += spike_on_syn * self.weight  # p3
    self.post.input[:] += bm.sum(self.g, axis=0) * (self.E - self.post.V)  # p4

In the SynMat1 class defined above, the “p1” line requires a “conn_mat” structure for the later synaptic computation; at “p2” we get the spikes for each synaptic connection according to “conn_mat” and the presynaptic spikes; then at “p3”, the spike-triggered synaptic variables are added onto their postsynaptic targets; finally, at “p4”, all connected synaptic values are summed by bm.sum(self.g, axis=0) to get the current effective conductance.

Now, let’s inspect the performance of this matrix-based synapse.

t_syn_mat1 = run_net(neu_model=LIF, syn_model=SynMat1)
Compilation used 8.9806 s.
Start running ...
Run 10.0% used 11.922 s.
Run 20.0% used 24.164 s.
Run 30.0% used 36.259 s.
Run 40.0% used 48.411 s.
Run 50.0% used 60.550 s.
Run 60.0% used 72.775 s.
Run 70.0% used 85.545 s.
Run 80.0% used 98.326 s.
Run 90.0% used 110.973 s.
Run 100.0% used 123.404 s.
Simulation is done in 123.404 s.
_images/fd57f5afd8eaa0ae913236fb0469db8f068358aa61d8f257d6805425d9083b09.png

This matrix-based synapse structure is very inefficient, because more than 99.9% of the time is spent on synaptic computation. We can verify this by running only the neuron group models.

group = bm.jit(LIF(num, monitors=['spike']))
group.V[:] = bm.random.randn(num) * 5. + Vr

group.run(100., inputs=('input', 5.), report=True)
Compilation used 0.1588 s.
Start running ...
Run 100.0% used 0.027 s.
Simulation is done in 0.027 s.
0.02666616439819336

As you can see, the neuron group takes only about 0.027 s to run. Relative to the total running time of 120+ s, the neuron group accounts for only about 0.02% of the time.

Event-based updating

The inefficiency of the above matrix-based computation comes from wasted synaptic computation. First, it is uncommon for a neuron to generate a spike; second, in a group of neurons, the generated spikes (self.pre.spike) are usually sparse. Therefore, at most time points self.pre.spike contains many zeros, which means many unnecessary zeros are added to self.g (self.g += spike_on_syn * self.weight).

Alternatively, we can update self.g only when the pre-synaptic neuron produces a spike event (this is called the event-based updating method):

class SynMat2(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynMat2, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.conn_mat = self.conn.requires('conn_mat')

    # variables
    self.g = bm.Variable(bm.zeros(self.conn_mat.shape))

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    # p1
    for pre_i, spike in enumerate(self.pre.spike):
      if spike:
        self.g[pre_i] += self.conn_mat[pre_i] * self.weight
    self.post.input[:] += bm.sum(self.g, axis=0) * (self.E - self.post.V)

Compared to SynMat1, we replace “p2” and “p3” in SynMat1 with “p1” in SynMat2. Now the connected post-synaptic state g is updated (self.g[pre_i] += self.conn_mat[pre_i] * self.weight) only when the pre-synaptic neuron emits a spike (if spike).

t_syn_mat2 = run_net(neu_model=LIF, syn_model=SynMat2)
Compilation used 8.1212 s.
Start running ...
Run 10.0% used 5.830 s.
Run 20.0% used 11.736 s.
Run 30.0% used 17.713 s.
Run 40.0% used 23.714 s.
Run 50.0% used 29.624 s.
Run 60.0% used 35.564 s.
Run 70.0% used 41.559 s.
Run 80.0% used 47.508 s.
Run 90.0% used 53.456 s.
Run 100.0% used 59.436 s.
Simulation is done in 59.436 s.
_images/104e7dfd47dbd9d968ecffb91e5feaa310d641b5dea2dab7063f889d8eaf8b4d.png

Such an event-based matrix connection roughly doubles the running speed, but it’s not good enough.

Vector-based connection

Matrix-based synaptic computation may be straightforward, but it can waste a lot of memory and computation. Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need $10^8$ floats to store the synaptic state, and at each update step you need to compute on $10^8$ floats. However, the number of values you actually need is only $10^7$, so both memory and computing resources are largely wasted.
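
The numbers above can be reproduced with a few lines of arithmetic. The 8 bytes per value is an assumption (64-bit floats); with 32-bit floats the sizes would be halved:

```python
num_pre = num_post = 10_000
prob = 0.1
bytes_per_float = 8   # assumption: float64

dense_values = num_pre * num_post            # every (pre, post) pair is stored
sparse_values = round(dense_values * prob)   # expected number of existing synapses

print(dense_values, dense_values * bytes_per_float / 1e6, 'MB')
print(sparse_values, sparse_values * bytes_per_float / 1e6, 'MB')
```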

pre_ids and post_ids

An effective way to solve this problem is to use vectors to store the connectivity between neuron groups and the corresponding synaptic states. For the connectivity conn_mat defined above, we can align the connected pre-synaptic neurons and post-synaptic neurons in two one-dimensional arrays: pre_ids and post_ids.

In this way, we only need two vectors (pre_ids and post_ids, each with $10^7$ elements) to store the synaptic connectivity. And, at each time step, we just need to update a synaptic state vector with $10^7$ floats.
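
As a minimal sketch of this representation (plain NumPy; the small connection matrix is a hypothetical example, and the ordering BrainPy uses internally may differ), the two index vectors can be read off a connection matrix with `np.nonzero`:

```python
import numpy as np

conn_mat = np.array([[1, 0, 1, 0],
                     [0, 1, 0, 0],
                     [1, 1, 0, 1]])

# indices of the nonzero entries, listed row by row (i.e. sorted by pre-synaptic id)
pre_ids, post_ids = np.nonzero(conn_mat)
print(pre_ids)    # [0 0 1 2 2 2]
print(post_ids)   # [0 2 1 0 1 3]

g = np.zeros(len(pre_ids))              # one state per existing synapse
pre_spike = np.array([True, False, True])
g += pre_spike[pre_ids] * 0.6           # only synapses whose source fired are incremented
```

Synapse `k` connects neuron `pre_ids[k]` to neuron `post_ids[k]`, so the state vector `g` has six entries here instead of the twelve entries of the full matrix.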

class SynVec1(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVec1, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
    self.num = len(self.pre_ids)

    # variables
    self.g = bm.Variable(bm.zeros(self.num))
    
    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    for syn_i in range(self.num):
      # p1: update
      pre_i = self.pre_ids[syn_i]
      if self.pre.spike[pre_i]:
        self.g[syn_i] += self.weight
      # p2: output
      post_i = self.post_ids[syn_i]
      self.post.input[post_i] += self.g[syn_i] * (self.E - self.post.V[post_i])

In SynVec1 class, we first update the synaptic state with “p1” code block, in which the synaptic state self.g[syn_i] is updated when the pre-synaptic neuron generates a spike (if self.pre.spike[pre_i]); then, at “p2” code block, we output the synaptic states onto the post-synaptic neurons.

t_syn_vec1 = run_net(neu_model=LIF, syn_model=SynVec1)
Compilation used 2.8904 s.
Start running ...
Run 10.0% used 0.124 s.
Run 20.0% used 0.240 s.
Run 30.0% used 0.358 s.
Run 40.0% used 0.473 s.
Run 50.0% used 0.586 s.
Run 60.0% used 0.698 s.
Run 70.0% used 0.812 s.
Run 80.0% used 0.928 s.
Run 90.0% used 1.040 s.
Run 100.0% used 1.150 s.
Simulation is done in 1.150 s.
_images/3f822ca1e44cffebfd78fe8f2a034ece055a38a5e0490ff31cd5b1a97d64b87f.png

Great! Transforming the matrix-based connection into the vector-based connection gives us a huge speed boost. However, the SynVec1 class still contains redundancy: a pre-synaptic neuron may connect to many post-synaptic neurons, so at each update step we check many times whether the same pre-synaptic neuron has generated a spike (self.pre.spike[pre_i]).

pre2syn and post2syn

To solve the above problem, we create two more synaptic structures, pre2syn and post2syn, which help us retrieve the synapse states connected to the pre-synaptic neuron $i$ and to the post-synaptic neuron $j$.

In a pre2syn list, each pre2syn[i] stores the indexes of the synaptic states projected from the pre-synaptic neuron $i$.

Similarly, we can create a post2syn list to indicate the connections between synapses and post-synaptic neurons. For each post-synaptic neuron $j$, post2syn[j] stores the indexes of the synaptic elements connected to the post neuron $j$.
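
A minimal NumPy sketch of how such lists could be derived from pre_ids and post_ids follows. The index vectors are a hypothetical example, and this is not how BrainPy builds the structures internally:

```python
import numpy as np

pre_ids = np.array([0, 0, 1, 2, 2, 2])
post_ids = np.array([0, 2, 1, 0, 1, 3])
num_pre, num_post = 3, 4

# for every neuron, collect the positions of its synapses in the synapse vector
pre2syn = [np.where(pre_ids == i)[0] for i in range(num_pre)]
post2syn = [np.where(post_ids == j)[0] for j in range(num_post)]

print([s.tolist() for s in pre2syn])    # [[0, 1], [2], [3, 4, 5]]
print([s.tolist() for s in post2syn])   # [[0, 3], [2, 4], [1], [5]]
```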

Based on these connectivity mappings, we can define another version of synapse model by using pre2syn and post2syn:

class SynVec2(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVec2, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.pre2syn, self.post2syn = self.conn.requires('pre_ids', 'pre2syn', 'post2syn')
    self.num = len(self.pre_ids)

    # variables
    self.g = bm.Variable(bm.zeros(self.num))

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    # p1: update
    for pre_i in range(self.pre.num):
      if self.pre.spike[pre_i]:
        for syn_i in self.pre2syn[pre_i]:
          self.g[syn_i] += self.weight
    # p2: output
    for post_i in range(self.post.num):
      for syn_i in self.post2syn[post_i]:
        self.post.input[post_i] += self.g[syn_i] * (self.E - self.post.V[post_i])

In this SynVec2 class, the “p1” code block updates the synaptic states with a for-loop over the pre-synaptic neurons. If the pre-synaptic neuron elicits a spike (self.pre.spike[pre_i]), we loop over its connected synaptic states with for syn_i in self.pre2syn[pre_i]. In this way, we only need to check the spike state of the pre-synaptic neuron pre_i once. Similarly, in the “p2” code block, the synaptic output is implemented with a for-loop over the post-synaptic neurons.

t_syn_vec2 = run_net(neu_model=LIF, syn_model=SynVec2)
Compilation used 3.2760 s.
Start running ...
Run 10.0% used 0.125 s.
Run 20.0% used 0.252 s.
Run 30.0% used 0.385 s.
Run 40.0% used 0.513 s.
Run 50.0% used 0.640 s.
Run 60.0% used 0.780 s.
Run 70.0% used 0.919 s.
Run 80.0% used 1.049 s.
Run 90.0% used 1.181 s.
Run 100.0% used 1.308 s.
Simulation is done in 1.308 s.
_images/9b37507c5dceea25e796f6c91573d39db3e5ec7440a30bb43b6300ec733999aa.png

The speed is similar to that of SynVec1. This is because the optimization of the “update” block has run its course; most of the running cost is now spent in the “output” block.

pre2post and post2pre

Notice that for this kind of synapse model, the synaptic states $g$ onto a post-synaptic neuron can be modeled together. This is because the differential equations (1) and (2) are linear, so the evolutions of the synaptic states after the pre-synaptic spikes can be superposed. This means that we can declare a synaptic state self.g with the shape of post.num instead of the number of synapses.
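
This superposition property can be checked numerically. The sketch below (plain NumPy with explicit Euler steps; the connectivity, the random spike trains, and the variable names are assumptions made for illustration) evolves one state per synapse and one state per post-synaptic neuron side by side, then compares the per-synapse states summed over each target with the per-neuron states:

```python
import numpy as np

tau, dt, weight = 5.0, 0.1, 0.6
pre_ids = np.array([0, 0, 1, 2, 2, 2])     # hypothetical source of each synapse
post_ids = np.array([0, 2, 1, 0, 1, 3])    # hypothetical target of each synapse
num_pre, num_post = 3, 4
rng = np.random.default_rng(0)

g_syn = np.zeros(len(pre_ids))     # one state per synapse
g_post = np.zeros(num_post)        # one state per post-synaptic neuron

for _ in range(200):
    pre_spike = rng.random(num_pre) < 0.1      # random pre-synaptic spikes
    increment = pre_spike[pre_ids] * weight    # increment arriving at each synapse
    # per-synapse version: decay, then increment
    g_syn += -g_syn / tau * dt
    g_syn += increment
    # per-neuron version: decay, then add the increments of all incoming synapses
    g_post += -g_post / tau * dt
    np.add.at(g_post, post_ids, increment)

g_syn_summed = np.bincount(post_ids, weights=g_syn, minlength=num_post)
print(np.allclose(g_syn_summed, g_post))
```

Because the decay is linear, both versions give the same total conductance for every post-synaptic neuron, while the second one stores far fewer values.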

To achieve this, we create two more synaptic structures (pre2post and post2pre) which establish a direct mapping between the pre-synaptic neurons and the post-synaptic neurons. pre2post contains the indexes of the connected post-synaptic neurons: pre2post[i] retrieves the ids of the post neurons that pre-synaptic neuron $i$ projects to. post2pre contains the indexes of the connected pre-synaptic neurons: post2pre[j] retrieves the ids of the pre-synaptic neurons which project to post-synaptic neuron $j$.
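
As a minimal sketch (plain NumPy, hypothetical index vectors, not BrainPy's internal construction), both mappings can be derived from pre_ids and post_ids, and pre2post can then drive an event-based update of a per-neuron conductance:

```python
import numpy as np

pre_ids = np.array([0, 0, 1, 2, 2, 2])
post_ids = np.array([0, 2, 1, 0, 1, 3])
num_pre, num_post = 3, 4

pre2post = [post_ids[pre_ids == i] for i in range(num_pre)]
post2pre = [pre_ids[post_ids == j] for j in range(num_post)]
print([p.tolist() for p in pre2post])   # [[0, 2], [1], [0, 1, 3]]
print([p.tolist() for p in post2pre])   # [[0, 2], [1, 2], [0], [2]]

# event-based update: only spiking pre-synaptic neurons touch their targets
g_post = np.zeros(num_post)
pre_spike = np.array([True, False, True])
for pre_i in np.nonzero(pre_spike)[0]:
    g_post[pre2post[pre_i]] += 0.6
```

The fancy-indexed `+=` is safe here only because each pre-synaptic neuron connects to a given post-synaptic neuron at most once.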

class SynVec3(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVec3, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre2post = self.conn.requires('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(post.num))

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    # p1: update
    for pre_i in range(self.pre.num):
      if self.pre.spike[pre_i]:
        for post_i in self.pre2post[pre_i]:
          self.g[post_i] += self.weight
    # p2: output
    self.post.input[:] += self.g * (self.E - self.post.V)

In the SynVec3 class, we require a pre2post structure. Then, in the “p1” code block, when the pre-synaptic neuron pre_i emits a spike, the conductance self.g[post_i] of each connected post-synaptic neuron is increased.

t_syn_vec3 = run_net(neu_model=LIF, syn_model=SynVec3)
Compilation used 4.1941 s.
Start running ...
Run 10.0% used 0.006 s.
Run 20.0% used 0.014 s.
Run 30.0% used 0.020 s.
Run 40.0% used 0.029 s.
Run 50.0% used 0.038 s.
Run 60.0% used 0.045 s.
Run 70.0% used 0.051 s.
Run 80.0% used 0.059 s.
Run 90.0% used 0.067 s.
Run 100.0% used 0.075 s.
Simulation is done in 0.075 s.
_images/b32c1c060b8c74bbba0478a57a9268508403823e5351bf838fafb4a96b399a0a.png

The running speed gets a huge boost (from about 1.3 s to 0.075 s), which demonstrates the effectiveness of this kind of synaptic computation.

pre_slice and post_slice

However, it is not perfect. This is because pre2syn, post2syn, pre2post and post2pre are all lists, which cannot be directly deployed to GPU devices. GPU devices work best with arrays.

To solve this problem, we can instead create a post_slice connection structure which stores the start and end positions in the synapse state for each connected post-synaptic neuron $j$. post_slice can be implemented by aligning the pre ids according to the sequential post ids $0, 1, 2, \ldots$ (see the following figure). For each post neuron $j$, start, end = post_slice[j] retrieves the start/end positions of the connected synapse states.

Therefore, the updating logic based on pre2syn and post2syn (in the SynVec2 class) can alternatively be implemented with post_slice and pre_ids:
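
One way such a structure could be built is sketched below in plain NumPy. The index vectors are hypothetical, and the convention that `end` is exclusive is an assumption chosen to match the slicing `self.g[start: end]` used in this tutorial:

```python
import numpy as np

pre_ids = np.array([0, 0, 1, 2, 2, 2])
post_ids = np.array([0, 2, 1, 0, 1, 3])
num_post = 4

# reorder the synapses so that those targeting the same post neuron are contiguous
order = np.argsort(post_ids, kind='stable')
pre_ids_sorted = pre_ids[order]       # [0 2 1 2 0 2]
post_ids_sorted = post_ids[order]     # [0 0 1 1 2 3]

counts = np.bincount(post_ids_sorted, minlength=num_post)   # synapses per post neuron
ends = np.cumsum(counts)
starts = ends - counts
post_slice = np.stack([starts, ends], axis=1)   # shape (num_post, 2)
print(post_slice.tolist())   # [[0, 2], [2, 4], [4, 5], [5, 6]]
```

With this layout, `pre_ids_sorted[start:end]` lists the pre-synaptic neurons that project to post neuron $j$, and everything is stored in fixed-size arrays.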

class SynVec4(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVec4, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.post_slice = self.conn.requires('pre_ids', 'post_slice')
    self.num = len(self.pre_ids)

    # variables
    self.g = bm.Variable(bm.zeros(self.num))

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    # p1: update
    for syn_i in range(self.num):
      pre_i = self.pre_ids[syn_i]
      if self.pre.spike[pre_i]:
        self.g[syn_i] += self.weight
    # p2: output
    for post_i in range(self.post.num):
      start, end = self.post_slice[post_i]
      # for syn_i in range(start, end):
      self.post.input[post_i] += self.g[start: end].sum() * (self.E - self.post.V[post_i])

t_syn_vec4 = run_net(neu_model=LIF, syn_model=SynVec4)
Compilation used 3.3534 s.
Start running ...
Run 10.0% used 0.103 s.
Run 20.0% used 0.207 s.
Run 30.0% used 0.314 s.
Run 40.0% used 0.432 s.
Run 50.0% used 0.536 s.
Run 60.0% used 0.635 s.
Run 70.0% used 0.739 s.
Run 80.0% used 0.842 s.
Run 90.0% used 0.946 s.
Run 100.0% used 1.050 s.
Simulation is done in 1.050 s.
_images/2b14e006e709eee9a956588f1ee993ef05b597935fa74b4d765cffff44e69f74.png

Similarly, a connection mapping pre_slice can also be implemented, in which for each pre-synaptic neuron $i$, start, end = pre_slice[i] retrieves the start/end position of the connected synapse states.

Moreover, the updating logic based on pre2post (in the SynVec3 class) can alternatively be implemented with pre_slice and post_ids:

class SynVec5(bp.TwoEndConn):
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVec5, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_slice, self.post_ids = self.conn.requires('pre_slice', 'post_ids')

    # variables
    self.g = bm.Variable(bm.zeros(post.num))

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    self.g[:] = self.integral(self.g, _t)
    # p1: update
    for pre_i in range(self.pre.num):
      if self.pre.spike[pre_i]:
        start, end = self.pre_slice[pre_i]
        for post_i in self.post_ids[start: end]:
          self.g[post_i] += self.weight
    # p2: output
    self.post.input[:] += self.g * (self.E - self.post.V)

t_syn_vec5 = run_net(neu_model=LIF, syn_model=SynVec5)
Compilation used 4.2100 s.
Start running ...
Run 10.0% used 0.005 s.
Run 20.0% used 0.011 s.
Run 30.0% used 0.017 s.
Run 40.0% used 0.024 s.
Run 50.0% used 0.031 s.
Run 60.0% used 0.038 s.
Run 70.0% used 0.045 s.
Run 80.0% used 0.051 s.
Run 90.0% used 0.057 s.
Run 100.0% used 0.064 s.
Simulation is done in 0.064 s.
_images/10f7f808100ee1022d2bbe1d83d23bdc7508b5653da2e75ba82f0919e2ed00b2.png

Speed comparison

In this tutorial, we introduce nine different synaptic connection structures:

  1. conn_mat : The connection matrix with the shape of (pre_num, post_num).

  2. pre_ids: The connected pre-synaptic neuron indexes, a vector with the shape of syn_num.

  3. post_ids: The connected post-synaptic neuron indexes, a vector with the shape of syn_num.

  4. pre2syn: A list (with the length of pre_num) containing the synaptic indexes connected to each pre-synaptic neuron. pre2syn[i] denotes the ids of the synapses connected to the pre-synaptic neuron $i$.

  5. post2syn: A list (with the length of post_num) containing the synaptic indexes connected to each post-synaptic neuron. post2syn[j] denotes the ids of the synapses connected to the post-synaptic neuron $j$.

  6. pre2post: A list (with the length of pre_num) containing the post-synaptic indexes connected to each pre-synaptic neuron. pre2post[i] retrieves the post neurons connected to the pre neuron $i$.

  7. post2pre: A list (with the length of post_num) containing the pre-synaptic indexes connected to each post-synaptic neuron. post2pre[j] retrieves the pre neurons connected to the post neuron $j$.

  8. pre_slice: A two-dimensional matrix with the shape of (pre_num, 2) which stores the start and end positions in the synapse state for each connected pre-synaptic neuron $i$.

  9. post_slice: A two-dimensional matrix with the shape of (post_num, 2) which stores the start and end positions in the synapse state for each connected post-synaptic neuron $j$.

We illustrate their efficiency with a sparse, randomly connected E/I balance network COBA [1]. Their speeds are summarized in the following comparison figure:

names = ['mat 1', 'mat 2', 'pre_ids', 'pre2syn', 'pre2post', 'post_slice', 'pre_slice']
times = [t_syn_mat1, t_syn_mat2, t_syn_vec1, t_syn_vec2, t_syn_vec3, t_syn_vec4, t_syn_vec5]
xs = list(range(len(times)))

def autolabel(rects):
  """Attach a text label above each bar in *rects*, displaying its height."""
  for rect in rects:
    height = rect.get_height()
    ax.annotate(f'{height:.3f}',
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=(0, 0.5),  # 0.5 points vertical offset
                textcoords="offset points",
                ha='center', va='bottom')


fig, gs = bp.visualize.get_figure(1, 1, 5, 8)

ax = fig.add_subplot(gs[0, 0])
rects = ax.bar(xs, times)
ax.set_xticks(xs)
ax.set_xticklabels(names)
ax.set_yscale('log')
plt.ylabel('Running Time [s]')
autolabel(rects)

However, the speed comparison presented here does not mean that the vector-based connection is always better than the matrix-based connection. The vector-based synaptic model is well suited to simulating sparse connections, whereas the matrix-based synaptic model is best for dense connections, such as all-to-all connection.

Synaptic computing in JAX backend

The above examples are all illustrated under the ‘numpy’ backend in BrainPy. However, the JIT compilation in the ‘jax’ backend is much more constrained. Specifically, JAX transformations only support tensors. Therefore, for sparse synaptic connections, we highly recommend using the pre_ids and post_ids connectivities (see the above illustration).

The core problem of synaptic computation on the JAX backend is how to convert values among tensors of different shapes. Specifically, in the above exponential synapse model, we have three kinds of tensor shapes (see the following figure): tensors with the dimension of the pre-synaptic group, tensors with the dimension of the post-synaptic group, and tensors with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state and grouping the synaptic variable into the post-synaptic current value are the central problems of synaptic computation.

Here BrainPy provides two operators, brainpy.math.pre2syn and brainpy.math.syn2post, to convert vectors among different dimensions.

  • brainpy.math.pre2syn() receives two arguments: pre_values (the variable of the pre-synaptic dimension) and pre_ids (the connected pre-synaptic neuron index).

  • brainpy.math.syn2post() receives three arguments: syn_values (the variable with the synaptic size), post_ids (the connected post-synaptic neuron index) and post_num (the number of the post-synaptic neurons).
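
The semantics of these two operators can be illustrated with plain NumPy. The sketch below is not BrainPy's implementation; it only assumes that pre2syn is a gather by pre_ids and that syn2post is a sum over the synapses sharing the same post_ids. The function bodies and the toy connectivity are assumptions made for illustration.

```python
import numpy as np

def pre2syn(pre_values, pre_ids):
    """Gather: every synapse takes the value of its pre-synaptic neuron."""
    return np.asarray(pre_values)[pre_ids]

def syn2post(syn_values, post_ids, post_num):
    """Scatter-add: sum the synaptic values that target the same post neuron."""
    post_values = np.zeros(post_num, dtype=np.asarray(syn_values).dtype)
    # np.add.at is unbuffered, so repeated post ids accumulate correctly
    np.add.at(post_values, post_ids, syn_values)
    return post_values

# toy connectivity: three pre neurons, two post neurons, four synapses
pre_ids = np.array([0, 0, 1, 2])
post_ids = np.array([0, 1, 1, 1])
pre_spike = np.array([1.0, 0.0, 1.0])

syn_spike = pre2syn(pre_spike, pre_ids)               # [1., 1., 0., 1.]
post_sum = syn2post(syn_spike, post_ids, post_num=2)  # [1., 2.]
```

Both operations work on fixed-shape arrays only, which is why this pair of index vectors fits the tensor-only constraint of JIT compilation.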

Based on these two operators, we can define the Exponential Synapse in JAX backend as:

class SynVecJax(bp.TwoEndConn):
  target_backend = 'jax'  
    
  def __init__(self, pre, post, conn, tau, weight, E, **kwargs):
    super(SynVecJax, self).__init__(pre=pre, post=post, **kwargs)

    # parameters
    self.tau = tau
    self.weight = weight
    self.E = E

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.post_ids = self.conn.requires('pre_ids', 'post_ids')
    self.num = len(self.pre_ids)

    # variables
    self.g = bm.Variable(bm.zeros(self.num))
    
    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, g, t):
    dg = - g / self.tau
    return dg

  def update(self, _t, _dt):
    syn_sps = bm.pre2syn(self.pre.spike, self.pre_ids)
    self.g[:] = self.integral(self.g, _t) + syn_sps
    post_g = bm.syn2post(self.g, self.post_ids, self.post.num)
    self.post.input += post_g * (self.E - self.post.V)

For more synapse models, please see BrainModels.


References:

[1] Vogels, T. P. and Abbott, L. F. (2005). Signal propagation and logic gating in networks of integrate-and-fire neurons. J. Neurosci., 25(46), 10786–10795.

Synaptic Connectivity

Contents

  • Built-in regular connections

  • Built-in random connections

  • Customize your connections

BrainPy provides several commonly used connection methods in the brainpy.connect module (see below). They all inherit from the base class brainpy.connect.Connector. Users can also customize their synaptic connectivity by class inheritance.

import brainpy as bp

import numpy as np
import matplotlib.pyplot as plt

Built-in regular connections

brainpy.connect.One2One

The neurons in the pre-synaptic neuron group only connect to the neurons in the same position of the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size. Otherwise, an error will occur.

conn = bp.connect.One2One()
brainpy.connect.All2All

All neurons of the post-synaptic population form connections with all neurons of the pre-synaptic population (dense connectivity). Users can choose whether to connect the neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
brainpy.connect.GridFour

GridFour is the four-nearest-neighbors connection. Each neuron connects to its nearest four neurons.

conn = bp.connect.GridFour(include_self=False)
brainpy.connect.GridEight

GridEight is the eight-nearest-neighbors connection. Each neuron connects to its nearest eight neurons.

conn = bp.connect.GridEight(include_self=False)
brainpy.connect.GridN

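GridN is also a nearest-neighbors connection. Each neuron connects to the neurons in the $(2N+1) \times (2N+1)$ square centered on it, that is, to its nearest $(2N+1)^2 - 1$ neurons when self-connections are excluded (GridEight corresponds to $N=1$).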

conn = bp.connect.GridN(N=2, include_self=False)
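
To make the size of a GridN-style neighborhood concrete, the following NumPy sketch enumerates the square neighborhood of one neuron on a two-dimensional grid. The helper grid_n_neighbors is hypothetical and is not part of BrainPy; it assumes that the neighborhood is simply clipped at the border (no periodic boundary).

```python
import numpy as np

def grid_n_neighbors(height, width, row, col, N=1, include_self=False):
    """Return the (row, col) pairs within N steps of (row, col), clipped at the border."""
    rows = np.arange(max(row - N, 0), min(row + N, height - 1) + 1)
    cols = np.arange(max(col - N, 0), min(col + N, width - 1) + 1)
    rr, cc = np.meshgrid(rows, cols, indexing='ij')
    pairs = np.stack([rr.ravel(), cc.ravel()], axis=1)
    if not include_self:
        # drop the center neuron itself
        pairs = pairs[~((pairs[:, 0] == row) & (pairs[:, 1] == col))]
    return pairs

n_inner_1 = len(grid_n_neighbors(10, 10, 5, 5, N=1))  # 8, the GridEight case
n_inner_2 = len(grid_n_neighbors(10, 10, 5, 5, N=2))  # 24 = (2*2+1)**2 - 1
n_corner = len(grid_n_neighbors(10, 10, 0, 0, N=2))   # 8, clipped at the corner
```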

Built-in random connections

brainpy.connect.FixedProb

For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all-to-all projection, except that some synapses are not created, making the projection sparser.

conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=1234)
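
The idea behind a fixed-probability connection can be sketched in NumPy as an independent random draw for every pre/post pair. This is only an illustration under that assumption; the helper fixed_prob_ids is hypothetical, and BrainPy's own sampling procedure and random number generator may differ.

```python
import numpy as np

def fixed_prob_ids(pre_num, post_num, prob, include_self=True, seed=None):
    """Draw each pre/post pair independently with probability `prob`."""
    rng = np.random.default_rng(seed)
    conn_mat = rng.random((pre_num, post_num)) < prob
    if not include_self:
        np.fill_diagonal(conn_mat, False)  # remove i -> i connections
    pre_ids, post_ids = np.nonzero(conn_mat)
    return conn_mat, pre_ids, post_ids

conn_mat, pre_ids, post_ids = fixed_prob_ids(100, 100, prob=0.5,
                                             include_self=False, seed=1234)
density = conn_mat.mean()  # close to prob, slightly lower without self-connections
```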
brainpy.connect.FixedPreNum

Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

conn = bp.connect.FixedPreNum(num=10, include_self=True, seed=1234)
brainpy.connect.FixedPostNum

Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

conn = bp.connect.FixedPostNum(num=10, include_self=True, seed=1234)
brainpy.connect.GaussianProb

Builds a Gaussian connection pattern between the two populations, where the connection probability decays according to the Gaussian function.

Specifically,

$$ p=\exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right) $$

where $(x, y)$ is the position of the pre-synaptic neuron and $(x_c,y_c)$ is the position of the post-synaptic neuron.

For example, in a $30 \times 30$ two-dimensional network, when $\beta = \frac{1}{2\sigma^2} = 0.1$, the connection pattern is shown as follows:

conn = bp.connect.GaussianProb(sigma=0.2, p_min=0.01, normalize=True, include_self=True, seed=1234)
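
The probability formula above can also be evaluated directly. The NumPy sketch below computes $p$ for every pre/post pair of a $30 \times 30$ grid with $\beta = 0.1$; it measures positions in grid units, which is an assumption made for this illustration, and the helper gaussian_prob is hypothetical rather than a BrainPy function.

```python
import numpy as np

def gaussian_prob(pre_pos, post_pos, sigma):
    """p = exp(-((x - x_c)^2 + (y - y_c)^2) / (2 sigma^2)) for every pre/post pair."""
    diff = pre_pos[:, None, :] - post_pos[None, :, :]  # (pre_num, post_num, 2)
    dist2 = np.sum(diff ** 2, axis=-1)
    return np.exp(-dist2 / (2.0 * sigma ** 2))

# (x, y) positions of a 30 x 30 grid, in grid units
ys, xs = np.meshgrid(np.arange(30), np.arange(30), indexing='ij')
pos = np.stack([xs.ravel(), ys.ravel()], axis=1).astype(float)

beta = 0.1
sigma = np.sqrt(1.0 / (2.0 * beta))  # from beta = 1 / (2 sigma^2)
p = gaussian_prob(pos, pos, sigma)   # shape (900, 900)
```

A connection matrix could then be sampled by comparing p with uniform random numbers, in the same spirit as a fixed-probability connection.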
brainpy.connect.GaussianWeight

Builds a Gaussian connection pattern between the two populations, where the weights decay with the Gaussian function.

Specifically,

$$w(x, y) = w_{max} \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right)$$

where $(x, y)$ is the position of the pre-synaptic neuron (normalized to [0,1]), $(x_c,y_c)$ is the position of the post-synaptic neuron (normalized to [0,1]), and $w_{max}$ is the maximum weight. In order to avoid creating useless synapses, $w_{min}$ can be set to restrict the creation of synapses to the cases where the value of the weight would be greater than $w_{min}$. Default is $0.01 w_{max}$.

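def show_weight(pre_ids, post_ids, weights, geometry, neu_id):
    """Plot the outgoing weights of pre-synaptic neuron `neu_id` as a 3D surface."""
    height, width = geometry
    # select the synapses that start from the chosen pre-synaptic neuron
    ids = np.where(pre_ids == neu_id)[0]
    post_ids = post_ids[ids]
    weights = weights[ids]

    # build the grid with the shape (height, width) so that it matches Z
    X, Y = np.meshgrid(np.arange(width), np.arange(height))
    Z = np.zeros(geometry)
    for id_, weight in zip(post_ids, weights):
        h, w = id_ // width, id_ % width
        Z[h, w] = weight

    fig = plt.figure()
    # fig.gca(projection='3d') is not accepted by recent Matplotlib releases
    ax = fig.add_subplot(projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap=plt.cm.coolwarm, linewidth=0, antialiased=False)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()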
conn = bp.connect.GaussianWeight(sigma=0.1, w_max=1., w_min=0.01,
                                 normalize=True, include_self=True)
pre_geom = post_geom = (40, 40)
conn(pre_geom, post_geom)

pre_ids = conn.pre_ids
post_ids = conn.post_ids
weights = conn.weights
show_weight(pre_ids, post_ids, weights, pre_geom, 820)
brainpy.connect.DOG

Builds a Difference-of-Gaussians (DoG) connection pattern between the two populations.

Mathematically,

$$ w(x, y) = w_{max}^+ \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma_+^2}\right) - w_{max}^- \cdot \exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma_-^2}\right) $$

where weights smaller than $0.01 \cdot |w_{max} - w_{min}|$ are not created, and self-connections are controlled by the parameter include_self.

dog = bp.connect.DOG(sigmas=(0.08, 0.15), ws_max=(1.0, 0.7), w_min=0.01,
                     normalize=True, include_self=True)
h = 40
pre_geom = post_geom = (h, h)
dog(pre_geom, post_geom)

pre_ids = dog.pre_ids
post_ids = dog.post_ids
weights = dog.weights
show_weight(pre_ids, post_ids, weights, (h, h), h * h // 2 + h // 2)
brainpy.connect.SmallWorld

SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined to be a network where the typical distance L between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes N in the network, that is:

$$ L\propto \log N $$

[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.

Currently, SmallWorld only supports a one-dimensional network with the ring structure. It receives four settings:

  • num_neighbor: the number of the nearest neighbors to connect.

  • prob: the probability of rewiring each edge.

  • directed: whether the edge is the directed (“directed=True”) or undirected (“directed=False”) connection.

  • include_self: whether a neuron is allowed to connect to itself.

conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
brainpy.connect.ScaleFreeBA

ScaleFreeBA is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA receives the following settings:

  • m: Number of edges to attach from a new node to existing nodes.

  • directed: whether the edge is the directed (“directed=True”) or undirected (“directed=False”) connection.

  • seed: Indicator of random number generation state.

[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.

conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
brainpy.connect.ScaleFreeBADual

ScaleFreeBADual is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:

  • p: The probability of attaching $m_1$ edges (as opposed to $m_2$ edges).

  • m1 : Number of edges to attach from a new node to existing nodes with probability $p$.

  • m2: Number of edges to attach from a new node to existing nodes with probability $1-p$.

  • directed: whether the edge is the directed (“directed=True”) or undirected (“directed=False”) connection.

  • seed: Indicator of random number generation state.

[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.

conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
brainpy.connect.PowerLaw

PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:

  • m : the number of random edges to add for each new node

  • p : Probability of adding a triangle after adding a random edge

  • directed: whether the edge is the directed (“directed=True”) or undirected (“directed=False”) connection.

  • seed : Indicator of random number generation state.

[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.

conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)

Customize your connections

BrainPy also allows you to customize your model connections. Users only need to do three things:

  • Your connection class should inherit brainpy.connect.Connector.

  • Initialize the conn_mat or pre_ids + post_ids synaptic structures.

  • Provide num_pre and num_post information.

In this way, based on the customized connection class, users can easily generate any other synaptic structures (such as pre2post, pre2syn, pre_slice_syn, etc.).

Here, let’s take a simple connection as an example. In this example, we create a connection method which receives a user-specified index projection.

class IndexConn(bp.connect.Connector):
    def __init__(self, i, j):
        super(IndexConn, self).__init__()
        
        # initialize the class via "pre_ids" and "post_ids"
        self.pre_ids = bp.ops.as_tensor(i)
        self.post_ids = bp.ops.as_tensor(j)
    
    def __call__(self, pre_size, post_size):
        self.num_pre = bp.size2len(pre_size)  # this is necessary when creating "pre2post",
                                              # "pre2syn", etc. structures
        self.num_post = bp.size2len(post_size) # this is necessary when creating "post2pre",
                                               # "post2syn", etc. structures
        return self

Let’s try to use it.

conn = IndexConn(i=[0, 1, 2], j=[0, 0, 0])
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
array([[1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])
conn.requires('pre2post')
[array([0]),
 array([0]),
 array([0]),
 array([], dtype=int32),
 array([], dtype=int32)]
conn.requires('pre2syn')
[array([0]),
 array([1]),
 array([2]),
 array([], dtype=int32),
 array([], dtype=int32)]
conn.requires('pre_slice_syn')
array([[0, 1],
       [1, 2],
       [2, 3],
       [3, 3],
       [3, 3]])
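
How the other structures follow from pre_ids and post_ids can be sketched in a few lines of NumPy. This is not the code BrainPy runs internally; it is an illustration that reproduces the outputs shown above for the same toy connection, assuming the synapses are ordered by their pre-synaptic index.

```python
import numpy as np

pre_ids = np.array([0, 1, 2])
post_ids = np.array([0, 0, 0])
num_pre, num_post = 5, 3

# connection matrix with the shape (num_pre, num_post)
conn_mat = np.zeros((num_pre, num_post))
conn_mat[pre_ids, post_ids] = 1.

# post-synaptic neurons and synapse indexes of every pre-synaptic neuron
pre2post = [post_ids[pre_ids == i] for i in range(num_pre)]
pre2syn = [np.nonzero(pre_ids == i)[0] for i in range(num_pre)]

# with synapses sorted by pre id, each pre neuron owns one contiguous slice
counts = np.bincount(pre_ids, minlength=num_pre)
ends = np.cumsum(counts)
pre_slice = np.stack([ends - counts, ends], axis=1)
```

Neurons without outgoing synapses get empty arrays in pre2post and pre2syn, and an empty slice such as [3, 3] in pre_slice.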

  • Chaoming Wang (adaduo@outlook.com)

  • Update at 2021.04.16


Monitors and Inputs

@Chaoming Wang

BrainPy has a systematic naming system. Any model in BrainPy has a unique name. Thus, nodes, integrators, and variables can be easily accessed in a huge network. Based on this naming system, BrainPy provides a set of convenient monitoring and input supports. In this section, we are going to talk about this.

import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt

Monitors

In BrainPy, any instance of brainpy.DynamicalSystem has a built-in monitor. Users can set up the monitor when initializing the brain object. For example, if you have the following HH neuron model,

class HH(bp.NeuGroup):
  def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0, 
               gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
    super(HH, self).__init__(size=size, **kwargs)

    # parameters
    self.ENa = ENa
    self.EK = EK
    self.EL = EL
    self.C = C
    self.gNa = gNa
    self.gK = gK
    self.gL = gL
    self.V_th = V_th

    # variables
    self.V = bm.Variable(bm.ones(self.num) * -65.)
    self.m = bm.Variable(bm.ones(self.num) * 0.5)
    self.h = bm.Variable(bm.ones(self.num) * 0.6)
    self.n = bm.Variable(bm.ones(self.num) * 0.32)
    self.input = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    
    # functions
    self.integral = bp.odeint(self.derivative, method='exponential_euler')

  def derivative(self, V, m, h, n, t, Iext):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    I_Na = (self.gNa * m ** 3 * h) * (V - self.ENa)
    I_K = (self.gK * n ** 4) * (V - self.EK)
    I_leak = self.gL * (V - self.EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / self.C

    return dVdt, dmdt, dhdt, dndt

  def update(self, _t, _dt):
    V, m, h, n = self.integral(self.V, self.m, self.h, self.n, _t, self.input)
    self.spike[:] = bm.logical_and(self.V < self.V_th, V >= self.V_th)
    self.V[:] = V
    self.m[:] = m
    self.h[:] = h
    self.n[:] = n
    self.input[:] = 0

The monitor can be set up when users create a HH neuron group.

The first method is to initialize a monitor using a list/tuple of strings.

# set up a monitor using a list of str
group1 = HH(size=10, monitors=['V', 'spike'])

type(group1.mon)
brainpy.simulation.monitor.Monitor

The initialized monitor is an instance of brainpy.Monitor. Therefore, users can also directly use Monitor class to initialize a monitor.

# set up a monitor using brainpy.Monitor
group2 = HH(size=10, monitors=bp.Monitor(variables=['V', 'spike']))

Once we call the .run() function in the model, the monitor will automatically record the variable evolutions in the corresponding models. Afterwards, users can access these variable trajectories by using [model_name].mon.[variable_name]. The history time [model_name].mon.ts will also be generated after the model finishes its running. Let’s see an example.

group1.run(100., inputs=('input', 10))

bp.visualize.line_plot(group1.mon.ts, group1.mon.V, show=True)

The monitor in group1 has recorded the evolution of V. Therefore, it can be accessed by group1.mon.V or equivalently group1.mon['V']. Similarly, the recorded trajectory of variable spike can also be obtained through group1.mon.spike.

group1.mon.spike
array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])
The mechanism of monitors

We want to record HH.V and HH.spike. Why does defining monitors=['V', 'spike'] during HH initialization work? How does brainpy.Monitor recognize which variables we want to trace?

Given the monitor targets, BrainPy first checks whether each key is an attribute of the node which defines the monitor. The monitor targets 'V' and 'spike' are indeed attributes of the HH model. If the key is not such an attribute, BrainPy checks whether the key’s host (a brainpy.DynamicalSystem object) is accessible in .nodes(), and then checks whether the host has the specified variable. For example, we define a network, and define the monitor targets by the absolute path.

net = bp.Network(HH(size=10, name='X'), 
                 HH(size=20, name='Y'), 
                 HH(size=30), 
                 monitors=['X.V', 'Y.spike'])

net.build()  # it's ok

In the above net, there are HH instances named as “X” and “Y”. Therefore, trying to monitor “X.V” and “Y.spike” is successful.

However, in the following example, the node named “Z” is not accessible in the generated net. Therefore, the monitoring setup fails.

z = HH(size=30, name='Z')
net = bp.Network(HH(size=10), HH(size=20), monitors=['Z.V'])

# node "Z" can not be accessed in 'net.nodes()'
try:
    net.build()
except Exception as e:
    print(type(e).__name__, ":", e)
BrainPyError : Cannot find target Z.V in monitor of <brainpy.simulation.brainobjects.network.Network object at 0x7fc14c454610>, please check.

Note

BrainPy only supports monitoring Variable. This is because BrainPy assumes that only the trajectory of a Variable is meaningful to monitor: a Variable changes dynamically, while anything not marked as Variable will be compiled as a constant.

try:
    HH(size=1, monitors=['gNa']).build()
except Exception as e:
    print(type(e).__name__, ":", e)
BrainPyError : "gNa" in <__main__.HH object at 0x7fc14c3d1f40> is not a dynamically changed Variable, its value will not change, we cannot monitor its trajectory.

Note

The monitors in BrainPy only record the flattened tensor values. This means if your target variable is a matrix with the shape of (N, M), the resulting trajectory value in the monitor after running T times will be a tensor with the shape of (T, N x M).

class MatrixVarModel(bp.DynamicalSystem):
    def __init__(self, **kwargs):
        super(MatrixVarModel, self).__init__(**kwargs)
        
        self.a = bm.Variable(bm.zeros((4, 4)))
    
    def update(self, _t, _dt):
        self.a += 0.01
        

duration = 10
model = MatrixVarModel(monitors=['a'])
model.run(duration)

print(f'The expected shape of "model.mon.a" is: {(int(duration/bm.get_dt()), model.a.size)}')
print(f'The actual shape of "model.mon.a" is: {model.mon.a.shape}')
The expected shape of "model.mon.a" is: (100, 16)
The actual shape of "model.mon.a" is: (100, 16)
Monitor variables at the selected index

Sometimes, we do not care about all the content in a variable. We may only be interested in the values at the selected indices. Moreover, for a huge network with a long simulation, monitors will consume a large part of the RAM. So, monitoring variables only at the selected indices is a good solution. Fortunately, BrainPy supports monitoring a part of the elements in a Variable with the format of tuple/dict like this:

group3 = HH(
    size=10,
    monitors=[
       'V',  # monitor all values of Variable 'V' 
      ('spike', [1, 2, 3]), # monitor values of Variable at index of [1, 2, 3]
    ]
)

group3.run(100., inputs=('input', 10.))

print(f'The monitor shape of "V" is (run length, variable size) = {group3.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group3.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)

Or, we can use a dictionary to specify the interested index of the variable:

group4 = HH(
  size=10,
  monitors={'V': None,  # 'None' means all values will be monitored
            'spike': [1, 2, 3]}  # specify the interested index 
)

group4.run(100., inputs=('input', 10.))

print(f'The monitor shape of "V" is (run length, variable size) = {group4.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group4.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)

Also, we can directly instantiate brainpy.Monitor class:

group5 = HH(
  size=10,
  monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])])
)
group5.run(100., inputs=('input', 10.))

print(f'The monitor shape of "V" is (run length, variable size) = {group5.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group5.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)
group6 = HH(
  size=10,
  monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]})
)
group6.run(100., inputs=('input', 10.))

print(f'The monitor shape of "V" is (run length, variable size) = {group6.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {group6.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (1000, 10)
The monitor shape of "spike" is (run length, index size) = (1000, 3)

Note

When users want to record a small part of a variable whose dimension > 1, they must provide the index positions in the flattened tensor, because brainpy.Monitor records a flattened tensor variable.
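
One way to obtain such flattened positions is NumPy's ravel_multi_index. The sketch below assumes that the monitor flattens variables in row-major (C) order, which is NumPy's default but is an assumption about BrainPy here; the chosen row and column indices are only an example.

```python
import numpy as np

shape = (4, 4)               # shape of the monitored variable
rows = np.array([0, 1, 3])   # elements of interest: (0, 2), (1, 2), (3, 0)
cols = np.array([2, 2, 0])

# positions of these elements in the row-major flattened tensor
flat_ids = np.ravel_multi_index((rows, cols), shape)  # [2, 6, 12]

# check against an explicit flattening
a = np.arange(16).reshape(shape)
same = np.array_equal(a.ravel()[flat_ids], a[rows, cols])
```

Under this assumption, flat_ids is what would be given as the index part of a monitor item such as ('a', flat_ids).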

Monitor variables with a customized period

In a long simulation with a small time step dt, what we care about is the trend of the variable evolution, not the exact values at each time point (especially when dt is very small). For this scenario, we can initialize the monitors with the intervals specification:

group7 = HH(
  size=10,
  monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]},
                      intervals={'V': None, 'spike': 1.})  # in 1 ms, we record 'spike' only once
)

# The above instantiation is equivalent to:
# 
# group7 = HH(
#   size=10, monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])],
#                                intervals=[None, 1.])
# )

In this example, we monitor “spike” variables at the index of [1, 2, 3] for each 1 ms.

group7.run(100., inputs=('input', 10.))

print(f'The monitor shape of "V" = {group7.mon.V.shape}')
print(f'The monitor shape of "spike" = {group7.mon.spike.shape}')
The monitor shape of "V" = (1000, 10)
The monitor shape of "spike" = (99, 3)

It is worth noting that for a monitor variable [variable_name] with a non-None intervals specification, a corresponding time item [variable_name].t will be generated in the monitor. This is because its time trajectory differs from the default time trajectory.

print('The shape of ["spike"]: ', group7.mon['spike'].shape)
print('The shape of ["spike.t"]: ', group7.mon['spike.t'].shape)

print('group7.mon["spike.t"]: ', group7.mon["spike.t"])
The shape of ["spike"]:  (99, 3)
The shape of ["spike.t"]:  (99,)
group7.mon["spike.t"]:  [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. 89. 90.
 91. 92. 93. 94. 95. 96. 97. 98. 99.]

Inputs

BrainPy also provides the inputs operation for each instance of brainpy.DynamicalSystem. It is specified when calling the .run(..., inputs=xxx) function.

The aim of inputs is to mimic the input operations in experiments like Transcranial Magnetic Stimulation (TMS) and patch clamp recording. inputs should have the format (target, value, [type, operation]), where

  • target is the target variable to inject the input.

  • value is the input value. It can be a scalar, a tensor, or an iterable object/function.

  • type is the type of the input value. It supports two types of input: fix and iter.

  • operation is the input operation on the target variable. It should be set as one of { + , - , * , / , = }. If users do not provide this item explicitly, it will be set to ‘+’ by default, which means that the target variable will be updated as val = val + input.
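
The meaning of the type and operation items can be illustrated with a small stand-alone sketch. The helpers apply_input and input_stream are hypothetical and only mimic the semantics described above on a NumPy array; they are not how BrainPy implements inputs.

```python
import itertools
import operator
import numpy as np

# the five supported operations, applied as target = target <op> value
OPS = {'+': operator.add, '-': operator.sub, '*': operator.mul,
       '/': operator.truediv, '=': lambda old, new: new}

def apply_input(target, value, op='+'):
    """Update `target` in place with one input value."""
    target[:] = OPS[op](target, value)

def input_stream(value, type_='fix'):
    """'fix' repeats the same value at every step; 'iter' yields one value per step."""
    return iter(value) if type_ == 'iter' else itertools.repeat(value)

V = np.full(3, -65.0)
apply_input(V, 10.0)         # default '+': V becomes -55
after_add = V.copy()
apply_input(V, 2.0, '*')     # V becomes -110
after_mul = V.copy()
apply_input(V, -65.0, '=')   # V is overwritten with -65

x = np.zeros(2)
for _, val in zip(range(3), input_stream([1.0, 2.0, 3.0], 'iter')):
    apply_input(x, val)      # accumulates 1 + 2 + 3
```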

You can also give multiple inputs for different target variables, like:


inputs=[(target1, value1, [type1, op1]),  
        (target2, value2, [type2, op2]),
              ... ]
The mechanism of inputs

The mechanism of inputs is the same as that of monitors (see The mechanism of monitors). BrainPy first checks whether the user-specified target can be accessed by the relative path.

If not, BrainPy separates the host name and the variable name, and further checks whether the host name is defined in .nodes() and whether the variable name can be accessed by the retrieved host. Therefore, in an input setting, the target can be set with the absolute or relative path. For example, in the network model below,

class Model(bp.DynamicalSystem):
    def __init__(self, num_sizes, **kwargs):
        super(Model, self).__init__(**kwargs)
        
        self.l1 = HH(num_sizes[0], name='L')
        self.l2 = HH(num_sizes[1])
        self.l3 = HH(num_sizes[2])
        
    def update(self, _t, _dt):
        self.l1.update(_t, _dt)
        self.l2.update(_t, _dt)
        self.l3.update(_t, _dt)
model = Model([10, 20, 30])

model.run(100, inputs=[('L.V', 2.0),  # access with the absolute path
                       ('l2.V', 1),  # access with the relative path
                       ])
0.7689826488494873

inputs supports two types of data: fix and iter. The first one means that the data is static; the second one denotes that the data is iterable, whether the input value is a tensor or a function. Note that the ‘iter’ type must be explicitly stated.

# a tensor

model.run(100, inputs=('L.V', bm.ones(1000) * 2., 'iter'))
0.7576150894165039
# a function

def current():
    while True: yield 2.

model.run(100, inputs=('L.V', current(), 'iter'))
0.767667293548584
Current construction functions

Inputs are common in computational experiments, and various kinds of inputs are needed. BrainPy provides several convenient input functions to help users construct input currents.

section_input()

brainpy.inputs.section_input() is an updated version of the previous brainpy.inputs.constant_input() (see below).

Sometimes, we need input currents with different values in different periods. For example, if you want an input which is zero during 0-100 ms, 1. during 100-400 ms, and zero during 400-500 ms, you can define:

current, duration = bp.inputs.section_input(values=[0, 1., 0.],
                                            durations=[100, 300, 100],
                                            return_length=True)
def show(current, duration, title):
    ts = np.arange(0, duration, 0.1)
    plt.plot(ts, current)
    plt.title(title)
    plt.xlabel('Time [ms]')
    plt.ylabel('Current Value')
    plt.show()
show(current, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')
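
Conceptually, such a sectioned current is a concatenation of constant segments. The NumPy sketch below illustrates this for scalar values only; the helper section_current is hypothetical, and the default dt=0.1 is taken from the time axis used in show() above.

```python
import numpy as np

def section_current(values, durations, dt=0.1):
    """Concatenate constant segments: values[k] is held for durations[k] ms."""
    parts = [np.full(int(round(d / dt)), v, dtype=float)
             for v, d in zip(values, durations)]
    return np.concatenate(parts), float(sum(durations))

sketch_current, sketch_duration = section_current([0, 1., 0.], [100, 300, 100])
# 5000 steps in total: 1000 zeros, 3000 ones, 1000 zeros
```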
constant_input()

brainpy.inputs.constant_input() function helps you to format constant currents in several periods.

For the input created above, we can define it again with constant_input() by:

current, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])
show(current, duration, '[(0, 100), (1, 300), (0, 100)]')

Another example is this:

current, duration = bp.inputs.constant_input([(-1, 10), (1, 3), (3, 30), (-0.5, 10)], dt=0.1)
show(current, duration, '[(-1, 10), (1, 3), (3, 30), (-0.5, 10)]')
spike_input()

brainpy.inputs.spike_input() helps you to construct an input like a series of short-time spikes. It receives the following settings:

  • sp_times : The spike time-points. Must be an iterable object. For example, list, tuple, or arrays.

  • sp_lens : The length of each point-current, mimicking the spike durations. It can be a scalar float to specify the unified duration. Or, it can be a list/tuple/array of time lengths with the same length as sp_times.

  • sp_sizes : The current sizes. It can be a scalar value. Or, it can be a list/tuple/array of spike current sizes with the same length as sp_times.

  • duration : The total current duration.

  • dt : The time step precision. The default is None (will be initialized as the default dt step).

For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, and each spike lasts 1 ms and the spike current is 0.5, then you can use the following function:

current = bp.inputs.spike_input(
    sp_times=[10, 20, 30, 200, 300],
    sp_lens=1.,  # can be a list to specify the spike length at each point
    sp_sizes=0.5,  # can be a list to specify the spike current size at each point
    duration=400.)
show(current, 400, 'Spike Input Example')
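
What such a spike-like current looks like as an array can be sketched with NumPy. The helper spike_current below is hypothetical and only mirrors the parameters listed above; rounding each time point to the nearest step and the default dt=0.1 are assumptions of this sketch.

```python
import numpy as np

def spike_current(sp_times, sp_lens, sp_sizes, duration, dt=0.1):
    """Zero current with a rectangular pulse of size sp_sizes[k] at every sp_times[k]."""
    # scalars are expanded to one entry per spike
    sp_lens = np.broadcast_to(sp_lens, len(sp_times))
    sp_sizes = np.broadcast_to(sp_sizes, len(sp_times))
    current = np.zeros(int(round(duration / dt)))
    for t, length, size in zip(sp_times, sp_lens, sp_sizes):
        start = int(round(t / dt))
        stop = int(round((t + length) / dt))
        current[start:stop] = size
    return current

sketch_spikes = spike_current([10, 20, 30, 200, 300], sp_lens=1.,
                              sp_sizes=0.5, duration=400.)
# five pulses, each 10 steps (1 ms) long and 0.5 high
```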
ramp_input()

brainpy.inputs.ramp_input() mimics a ramp or a step current to the input of the circuit. It receives the following settings:

  • c_start : The minimum (or maximum) current size.

  • c_end : The maximum (or minimum) current size.

  • duration : The total duration.

  • t_start : The ramped current start time-point.

  • t_end : The ramped current end time-point. Default is None.

  • dt : The current precision.

We illustrate the usage of brainpy.inputs.ramp_input() by two examples.

In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (1000 ms).

duration = 1000
current = bp.inputs.ramp_input(0, 1, duration)

show(current, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=0, $t_{end}$=None' % duration)
_images/8fd9d8c49ffc0209028bba3b7881b8ea383ae5fe26ae9a2dff25b9c8326c87fe.png

In the second example, we increase the current size from 0. to 1. between 200 ms and 800 ms.

duration, t_start, t_end = 1000, 200, 800
current = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)

show(current, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))
_images/6051b205a260a52ecd12de6177c4808896effd103fe37ecf5388f099b9a8ed91.png
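
The second example can be mimicked with a few lines of plain Python. This is only a sketch: the helper name ramp, the zero value outside the ramp window, and the choice to include both endpoints are assumptions, not a description of how ramp_input() is implemented.

```python
def ramp(c_start, c_end, duration, t_start=0.0, t_end=None, dt=0.1):
    """Hypothetical helper: linear ramp from c_start to c_end on [t_start, t_end)."""
    t_end = duration if t_end is None else t_end
    n_total = int(round(duration / dt))
    i_start = int(round(t_start / dt))
    i_end = int(round(t_end / dt))
    n_ramp = i_end - i_start

    current = [0.0] * n_total  # assumed to be zero outside the ramp window
    for k in range(n_ramp):
        # Evenly spaced values that include both endpoints (assumption).
        frac = k / (n_ramp - 1) if n_ramp > 1 else 0.0
        current[i_start + k] = c_start + (c_end - c_start) * frac
    return current


current = ramp(0.0, 1.0, duration=1000, t_start=200, t_end=800)
```

The result has one value per 0.1 ms step: zeros before 200 ms, a linear rise from 0 to 1 between 200 ms and 800 ms, and zeros afterwards.
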
General properties of current functions

There are several general properties for input construction functions.

Property 1: All input functions automatically broadcast the current shapes if they are heterogeneous across periods. For example, during period 1 we give a scalar input, during period 2 a vector input, and during period 3 a matrix input. The input functions broadcast them to the maximum shape. For example,

current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
                                  durations=[100, 300, 100])

current.shape
(5000, 3, 10)

Property 2: Every input function receives a dt specification. If dt is not provided, input functions will use the default dt in the whole BrainPy system.

bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.02).shape
(3000,)
bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.2).shape
(300,)
# the default 'dt' is 0.1

bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30]).shape
(600,)
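
Both properties can be reproduced with simple arithmetic: the leading dimension is the total duration divided by dt, and the trailing dimensions follow NumPy-style broadcasting of the per-period value shapes. The helper names below are hypothetical and only model the resulting shape, not the values.

```python
from itertools import zip_longest


def num_steps(durations, dt=0.1):
    """Number of samples needed to cover all periods at time step dt."""
    return sum(int(round(d / dt)) for d in durations)


def broadcast_shape(*shapes):
    """NumPy-style broadcast of several shapes, aligned from the last axis."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def section_shape(value_shapes, durations, dt=0.1):
    """Shape of the full input: (time steps,) + broadcast value shape."""
    return (num_steps(durations, dt),) + broadcast_shape(*value_shapes)


print(section_shape([(), (10,), (3, 10)], [100, 300, 100]))  # (5000, 3, 10)
print(section_shape([(), (), ()], [10, 20, 30], dt=0.02))    # (3000,)
print(section_shape([(), (), ()], [10, 20, 30], dt=0.2))     # (300,)
print(section_shape([(), (), ()], [10, 20, 30]))             # (600,)
```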

Dynamics Training

Build Artificial Neural Networks

Artificial neural networks in BrainPy are used to build dynamical systems. Here we only talk about how to build a neural network and how to train it.

The brainpy.simulation.layers module provides various classes representing the layers of a neural network. All of them are subclasses of the brainpy.simulation.layers.Module base class.

import brainpy as bp
bp.set_platform('cpu')

import brainpy.simulation.layers as nn
import brainpy.math.jax as bm
bp.math.use_backend('jax')

Creating a layer

A layer can be created as an instance of a brainpy.simulation.layers.Module subclass. For example, a dense layer can be created as follows:

l = nn.Dense(num_hidden=100, num_input=128) 
type(l)
brainpy.simulation.layers.dense.Dense

This creates a dense layer with 100 units that receives a 128-dimensional input.

Creating a network

Chaining layer instances together lets you specify the desired network structure.

This can be done by inheriting from brainpy.simulation.layers.Module:

class MLP(nn.Module):
    def __init__(self, n_in, n_l1, n_l2, n_out):
        super(MLP, self).__init__()
        
        self.l1 = nn.Dense(num_hidden=n_l1, num_input=n_in)
        self.l2 = nn.Dense(num_hidden=n_l2, num_input=n_l1)
        self.l3 = nn.Dense(num_hidden=n_out, num_input=n_l2)
        
    def update(self, x):
        x = bm.relu(self.l1(x))
        x = bm.relu(self.l2(x))
        x = self.l3(x)
        return x
mlp1 = MLP(10, 50, 100, 2)

Or by using brainpy.simulation.layers.Sequential:

mlp2 = nn.Sequential(
    l1=nn.Dense(num_hidden=50, num_input=10),
    r1=nn.Activation('relu'), 
    l2=nn.Dense(num_hidden=100, num_input=50),
    r2=nn.Activation('relu'), 
    l3=nn.Dense(num_hidden=2, num_input=100),
)

Naming a layer

For convenience, you can name a layer by specifying the name keyword argument:

l_hidden = nn.Dense(num_hidden=50, num_input=10, name='hidden_layer')

Initializing parameters

Many types of layers, such as brainpy.simulation.layers.Dense, have trainable parameters. These are referred to by short names that match the conventions used in modern deep learning literature. For example, a weight matrix will usually be called w, and a bias vector will usually be b.

When you create a layer with trainable parameters, a TrainVar is created for each of them and initialized automatically. You can optionally specify your own initialization strategy by using keyword arguments that match the parameter variable names. For example:

l = nn.Dense(num_hidden=50, num_input=10, w=bp.initialize.Normal(0.01))

The weight matrix w of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).

There are several ways to manually initialize parameters:

  • Tensors

If a tensor variable instance is provided, this is used unchanged as the parameter variable. For example:

w = bm.random.normal(0, 0.01, size=(10, 50))
nn.Dense(num_hidden=50, num_input=10, w=w)
<brainpy.simulation.layers.dense.Dense at 0x23cff9bb910>
  • Callables

If a callable is provided (e.g. a function or a brainpy.initialize.Initializer instance), the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:

nn.Dense(num_hidden=50, num_input=10, w=bp.initialize.Normal(0.01))
<brainpy.simulation.layers.dense.Dense at 0x23cff9bf2b0>

Or, using a custom initialization function:

def init_w(shape):
    return bm.random.normal(0, 0.01, shape)

nn.Dense(num_hidden=50, num_input=10, w=init_w)
<brainpy.simulation.layers.dense.Dense at 0x23cff9ac670>

Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:

nn.Dense(num_hidden=50, num_input=10, b=None)
<brainpy.simulation.layers.dense.Dense at 0x23cff99fa30>
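
The three cases described above (ready-made data, a callable, or None) amount to a small dispatch on the argument type. The following sketch shows that logic in plain Python with nested lists standing in for tensors. The names init_param and normal_init are hypothetical, and the actual checks performed by the Dense layer may differ.

```python
import random


def init_param(param, shape):
    """Hypothetical helper mirroring the three initialization options."""
    if param is None:
        return None              # parameter omitted, e.g. a layer without bias
    if callable(param):
        return param(shape)      # initializer or function called with the shape
    # Otherwise treat the argument as ready-made data and only check its shape.
    rows, cols = shape
    if len(param) != rows or any(len(row) != cols for row in param):
        raise ValueError(f"expected a parameter of shape {shape}")
    return param


def normal_init(shape, scale=0.01, seed=0):
    """Stand-in for an initializer: Gaussian samples with standard deviation scale."""
    rng = random.Random(seed)
    rows, cols = shape
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


w_from_callable = init_param(normal_init, (10, 50))  # callable: invoked with the shape
w_given = init_param(w_from_callable, (10, 50))      # data: used unchanged
b_omitted = init_param(None, (50,))                  # None: no parameter is created
```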

Set up training

Here we train the MLP defined above to classify MNIST images.

import numpy as np
import tensorflow as tf

# Data
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
num_train, num_test = X_train.shape[0], X_test.shape[0]
num_dim = bp.tools.size2num(X_train.shape[1:])
X_train = np.asarray(X_train.reshape((num_train, num_dim)) / 255.0, dtype=bm.float_)
X_test = np.asarray(X_test.reshape((num_test, num_dim)) / 255.0, dtype=bm.float_)
Y_train = np.asarray(Y_train.flatten(), dtype=bm.float_)
Y_test = np.asarray(Y_test.flatten(), dtype=bm.float_)
model = MLP(n_in=num_dim, n_l1=256, n_l2=128, n_out=10)
opt = bm.optimizers.Momentum(lr=1e-3, train_vars=model.train_vars())
gv = bm.grad(lambda X, Y: bm.losses.cross_entropy_loss(model(X), Y),
             dyn_vars=model.vars(),
             grad_vars=model.train_vars(),
             return_value=True)
@bm.jit
@bm.function(nodes=(model, opt))
def train(x, y):
    grads, loss = gv(x, y)
    opt.update(grads=grads)
    return loss
predict = bm.jit(lambda X: bm.softmax(model(X)), dyn_vars=model.vars())
# Training
num_batch = 128
for epoch in range(30):
  # Train
  loss = []
  sel = np.arange(len(X_train))
  np.random.shuffle(sel)
  for it in range(0, X_train.shape[0], num_batch):
    l = train(X_train[sel[it:it + num_batch]], Y_train[sel[it:it + num_batch]])
    loss.append(l)

  # Eval
  test_predictions = predict(X_test).argmax(1)
  accuracy = np.array(test_predictions).flatten() == Y_test
  print(f'Epoch {epoch + 1:4d}  Train Loss {np.mean(loss):.3f}  Test Accuracy {100 * np.mean(accuracy):.3f}')
Epoch    1  Train Loss 1.212  Test Accuracy 86.410
Epoch    2  Train Loss 0.467  Test Accuracy 89.810
Epoch    3  Train Loss 0.367  Test Accuracy 90.670
Epoch    4  Train Loss 0.325  Test Accuracy 91.470
Epoch    5  Train Loss 0.298  Test Accuracy 92.220
Epoch    6  Train Loss 0.278  Test Accuracy 92.890
Epoch    7  Train Loss 0.261  Test Accuracy 93.140
Epoch    8  Train Loss 0.248  Test Accuracy 93.530
Epoch    9  Train Loss 0.235  Test Accuracy 93.810
Epoch   10  Train Loss 0.224  Test Accuracy 94.010
Epoch   11  Train Loss 0.214  Test Accuracy 94.100
Epoch   12  Train Loss 0.205  Test Accuracy 94.350
Epoch   13  Train Loss 0.196  Test Accuracy 94.540
Epoch   14  Train Loss 0.189  Test Accuracy 94.680
Epoch   15  Train Loss 0.182  Test Accuracy 94.910
Epoch   16  Train Loss 0.175  Test Accuracy 95.070
Epoch   17  Train Loss 0.169  Test Accuracy 95.190
Epoch   18  Train Loss 0.163  Test Accuracy 95.280
Epoch   19  Train Loss 0.157  Test Accuracy 95.410
Epoch   20  Train Loss 0.153  Test Accuracy 95.570
Epoch   21  Train Loss 0.148  Test Accuracy 95.760
Epoch   22  Train Loss 0.143  Test Accuracy 95.760
Epoch   23  Train Loss 0.139  Test Accuracy 95.930
Epoch   24  Train Loss 0.135  Test Accuracy 95.910
Epoch   25  Train Loss 0.131  Test Accuracy 96.150
Epoch   26  Train Loss 0.128  Test Accuracy 96.110
Epoch   27  Train Loss 0.124  Test Accuracy 96.310
Epoch   28  Train Loss 0.121  Test Accuracy 96.330
Epoch   29  Train Loss 0.118  Test Accuracy 96.410
Epoch   30  Train Loss 0.115  Test Accuracy 96.410

Create Custom Layers

To implement a custom layer in BrainPy, you will have to write a Python class that subclasses brainpy.simulation.layers.Module and implement at least one method: update(). This method computes the output of the module given its input.

import brainpy as bp
bp.set_platform('cpu')

import brainpy.simulation.layers as nn
import brainpy.math.jax as bm
bp.math.use_backend('jax')

The following is an example implementation of a layer that multiplies its input by 2:

class DoubleLayer(nn.Module):
    def update(self, x):
        return 2 * x

This is all that’s required to implement a functioning custom module class in BrainPy.

A layer with parameters

If the layer has parameters, these should be initialized in the constructor. In BrainPy, we recommend marking parameters as brainpy.math.TrainVar.

To show how this can be used, here is a layer that multiplies its input by a matrix W (much like a typical fully connected layer in a neural network would). This matrix is a parameter of the layer. The shape of the matrix will be (num_input, num_hidden), where num_input is the number of input features and num_hidden has to be specified when the layer is created.

class DotLayer(nn.Module):
    def __init__(self, num_input, num_hidden, W=bp.initialize.Normal(), **kwargs):
        super(DotLayer, self).__init__(**kwargs)
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.W = bm.TrainVar(W([num_input, num_hidden]))

    def update(self, x):
        return bm.dot(x, self.W)

A few things are worth noting here: when overriding the constructor, we need to call the superclass constructor on the first line. This is important to ensure the layer functions properly. Note that we pass **kwargs - although this is not strictly necessary, it enables some other cool features, such as making it possible to give the layer a name:

l_dot = DotLayer(10, 50, name='my_dot_layer')

A layer with multiple behaviors

Some layers can have multiple behaviors. For example, a layer implementing dropout should be able to be switched on or off. During training, we want it to apply dropout noise to its input and scale up the remaining values, but during evaluation we don’t want it to do anything.

For this purpose, the update() method takes optional keyword arguments (kwargs). When update() is called to compute an expression for the output of a network, all specified keyword arguments are passed to the update() methods of all layers in the network.

class Dropout(nn.Module):
    def __init__(self, prob, seed=None, **kwargs):
        super(Dropout, self).__init__(**kwargs)
        self.prob = prob
        self.rng = bm.random.RandomState(seed=seed)

    def update(self, x, **kwargs):
        if kwargs.get('train', True):
            keep_mask = self.rng.bernoulli(self.prob, x.shape)
            return bm.where(keep_mask, x / self.prob, 0.)
        else:
            return x
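
The scaling by 1 / prob is what keeps the expected output equal to the input during training. The library-free sketch below illustrates this with a function named dropout and a keep_prob argument, both chosen here for illustration. As in the layer above, the probability is interpreted as the probability of keeping a value.

```python
import random


def dropout(x, keep_prob, train=True, rng=random):
    """Inverted dropout on a list of floats."""
    if not train:
        return list(x)  # evaluation mode: the input passes through unchanged
    # Keep each value with probability keep_prob and rescale the survivors.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


rng = random.Random(0)
x = [1.0] * 100_000
y_train = dropout(x, keep_prob=0.8, train=True, rng=rng)
y_eval = dropout(x, keep_prob=0.8, train=False)

mean_train = sum(y_train) / len(y_train)  # close to 1.0, the mean of x
```

Each training output is either 0 or 1.25, yet the average stays close to 1, so downstream layers see inputs of the same scale in both modes.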

Dynamics Analysis

Dynamics Analysis (Symbolic)

@Chaoming Wang

Dynamics analysis is an essential tool in neurodynamics, because blind simulation of nonlinear systems is likely to produce few or misleading results. For example, attractors and repellors can be found easily by simulating forward and backward in time, whereas saddles are hard to find this way.

Currently, BrainPy supports two kinds of analysis methods (see the brainpy.analysis documentation):

The first class of analysis method supports neurodynamics analysis for low-dimensional dynamical systems. Specifically, BrainPy provides the following methods for dynamics analysis:

  1. phase plane analysis for one-dimensional and two-dimensional systems;

  2. codimension one and codimension two bifurcation analysis;

  3. bifurcation analysis of the fast-slow system.

In this section, I will illustrate how to do neuron dynamics analysis in BrainPy and how BrainPy implements it.

import brainpy as bp

bp.__version__
'1.1.0'

Phase Plane Analysis

We provide a fundamental class PhasePlane to help users perform phase plane analysis for 1D/2D dynamical systems. Five methods are provided, which help you plot:

  • Fixed points

  • Nullcline (zero-growth isoclines)

  • Vector field

  • Limit cycles

  • Trajectory

Here, I will illustrate how to do phase plane analysis using a well-known neuron model, the FitzHugh-Nagumo model.

FitzHugh-Nagumo model

The FitzHugh-Nagumo model is given by:

$$ \begin{aligned} \frac{dV}{dt} &= V \left(1 - \frac{V^2}{3}\right) - w + I_{ext} \\ \tau \frac{dw}{dt} &= V + a - b w \end{aligned} $$

There are two variables $V$ and $w$, so this is a two-dimensional system with three parameters $a, b$ and $\tau$.

a = 0.7
b = 0.8
tau = 12.5
Vth = 1.9


@bp.odeint
def int_fhn(V, w, t, Iext):
  dw = (V + a - b * w) / tau
  dV = V - V * V * V / 3 - w + Iext
  return dV, dw
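
As a rough picture of what a numerical integrator does with these derivatives, the sketch below advances the same equations with a classical fourth-order Runge-Kutta step in plain Python. The choice of RK4, the step size dt=0.01, and the helper names are assumptions of this sketch. It says nothing about which scheme bp.odeint uses by default.

```python
a, b, tau = 0.7, 0.8, 12.5


def fhn_derivatives(V, w, Iext):
    """Right-hand side of the FitzHugh-Nagumo equations."""
    dV = V - V ** 3 / 3 - w + Iext
    dw = (V + a - b * w) / tau
    return dV, dw


def rk4_step(V, w, Iext, dt):
    """Advance (V, w) by one classical fourth-order Runge-Kutta step."""
    k1V, k1w = fhn_derivatives(V, w, Iext)
    k2V, k2w = fhn_derivatives(V + 0.5 * dt * k1V, w + 0.5 * dt * k1w, Iext)
    k3V, k3w = fhn_derivatives(V + 0.5 * dt * k2V, w + 0.5 * dt * k2w, Iext)
    k4V, k4w = fhn_derivatives(V + dt * k3V, w + dt * k3w, Iext)
    V_next = V + dt / 6 * (k1V + 2 * k2V + 2 * k3V + k4V)
    w_next = w + dt / 6 * (k1w + 2 * k2w + 2 * k3w + k4w)
    return V_next, w_next


V, w, dt = -2.8, -1.8, 0.01
V_trace = []
for _ in range(20000):  # 20000 steps of 0.01 cover 200 time units
    V, w = rk4_step(V, w, Iext=0.8, dt=dt)
    V_trace.append(V)
```

With Iext = 0.8 the trace keeps swinging between negative and positive values of V instead of settling down, which is the periodic activity examined in the phase plane analysis.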

Phase plane analysis is implemented in brainpy.symbolic.PhasePlane. It receives the following parameters:

  • integrals: The integral functions, or the instance of brainpy.DynamicalSystem, to be analyzed.

  • target_vars: The variables to be analyzed. It must be a dictionary with the format of {var: variable range}.

  • fixed_vars: The variables to be fixed (optional).

  • pars_update: Parameters to update (optional).

brainpy.symbolic.PhasePlane provides an interface to analyze the system’s

  • nullcline: The zero-growth isoclines, such as $f(x, y)=0$ and $g(x, y)=0$.

  • fixed points: The equilibrium points of the system, located where the nullclines intersect.

  • vector field: The vector field of the system.

  • Trajectory: A given simulation trajectory with the fixed variables.

  • Limit cycles: The limit cycles.

Here we perform a phase plane analysis with parameters $a=0.7, b=0.8, \tau=12.5$, and input $I_{ext} = 0.8$.

analyzer = bp.symbolic.PhasePlane(
  int_fhn,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  pars_update={'Iext': 0.8})
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory([{'V': -2.8, 'w': -1.8}],
                         duration=100.,
                         show=True)
_images/2b5cf437c616256ec5d73e598c5235d8278f3f3e959312a8c1de17b86b1dda5a.png

We can see an unstable node at the point (V=-0.27, w=0.53) inside a limit cycle. We can then run a simulation with the same parameters and initial values to see the periodic activity that corresponds to the limit cycle.

class FHN(bp.NeuGroup):
  def __init__(self, num, **kwargs):
    super(FHN, self).__init__(size=num, **kwargs)
    self.V = bp.math.Variable(bp.math.ones(num) * -2.8)
    self.w = bp.math.Variable(bp.math.ones(num) * -1.8)
    self.Iext = bp.math.Variable(bp.math.zeros(num))

  def update(self, _t, _dt):
    self.V[:], self.w[:] = int_fhn(self.V, self.w, _t, self.Iext)
    self.Iext[:] = 0.


group = FHN(1, monitors=['V', 'w'])
group.run(100., inputs=('Iext', 0.8))
bp.visualize.line_plot(group.mon.ts, group.mon.V, legend='v', )
bp.visualize.line_plot(group.mon.ts, group.mon.w, legend='w', show=True)
_images/79d269ea59319568122b2e246a13d093db004647cb9fdc24da2bd33f48a1e25c.png
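
The location and type of the fixed point reported by the phase plane analysis can be cross-checked by hand. On the w-nullcline, w = (V + a) / b, so the fixed point is the root of dV/dt along that curve, and its type follows from the trace and determinant of the Jacobian. The sketch below does this with bisection in plain Python. It is an independent check, not the procedure PhasePlane uses internally.

```python
a, b, tau, Iext = 0.7, 0.8, 12.5, 0.8


def dV_on_w_nullcline(V):
    """dV/dt evaluated on the w-nullcline w = (V + a) / b."""
    w = (V + a) / b
    return V - V ** 3 / 3 - w + Iext


# Bisection: the function is positive at V = -3 and negative at V = 3.
lo, hi = -3.0, 3.0
for _ in range(60):
    mid = 0.5 * (lo + hi)
    if dV_on_w_nullcline(mid) > 0:
        lo = mid
    else:
        hi = mid
V_fp = 0.5 * (lo + hi)
w_fp = (V_fp + a) / b

# Jacobian of (dV/dt, dw/dt) with respect to (V, w) at the fixed point.
j11, j12 = 1 - V_fp ** 2, -1.0
j21, j22 = 1 / tau, -b / tau
trace = j11 + j22
det = j11 * j22 - j12 * j21
disc = trace ** 2 - 4 * det  # a positive value means real eigenvalues (a node)

print(f"fixed point: V={V_fp:.2f}, w={w_fp:.2f}")
print("unstable node" if trace > 0 and det > 0 and disc > 0 else "other type")
```

A positive trace together with a positive determinant and a positive discriminant means two real, positive eigenvalues, which is the unstable node shown in the figure.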

Note that fixed_vars can be used to specify the neuron model’s state ST. It can also be used to specify the functional arguments in integrators (like the Iext in int_fhn()).

Bifurcation Analysis

Bifurcation analysis is implemented in brainpy.symbolic.Bifurcation, which supports codimension-1 and codimension-2 bifurcation analysis. Specifically, it receives the following parameter settings:

  • integrals: The integral functions, or the instance of brainpy.DynamicalSystem, to be analyzed.

  • target_pars: The target parameters. Must be a dictionary with the format of {par: parameter range}.

  • target_vars: The target variables. Must be a dictionary with the format of {var: variable range}.

  • fixed_vars: The fixed variables.

  • pars_update: The parameters to update.

Codimension 1 bifurcation analysis

We first look at the codimension 1 bifurcation analysis of the model. For example, we vary the input $I_{ext}$ between 0 and 1 and see how the system changes its stability.

analyzer = bp.symbolic.Bifurcation(
  int_fhn,
  target_pars={'Iext': [0., 1.]},
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  numerical_resolution=0.001,
)
res = analyzer.plot_bifurcation(show=True)
_images/df4e0d85c272bcab7338c576131a8068fa513a226f83b54a67b9f3fa33d009c3.png _images/bbfa3a14f729b4ccc653f082e536ad827b028622791b6867970c3b07c6a6835a.png
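
The idea behind a codimension 1 analysis can be sketched by hand: sweep the parameter, locate the fixed point for each value, and classify its stability from the Jacobian. The plain-Python version below does this for the FitzHugh-Nagumo model on a grid of 101 values of Iext. The helper names and the grid are assumptions of this sketch, and the method used by Bifurcation itself may differ.

```python
a, b, tau = 0.7, 0.8, 12.5


def fixed_point(Iext):
    """Unique fixed point, found by bisection along the w-nullcline."""
    lo, hi = -3.0, 3.0
    for _ in range(60):
        mid = 0.5 * (lo + hi)
        w = (mid + a) / b
        if mid - mid ** 3 / 3 - w + Iext > 0:
            lo = mid
        else:
            hi = mid
    V = 0.5 * (lo + hi)
    return V, (V + a) / b


def is_stable(V):
    """Stable if the Jacobian has negative trace and positive determinant."""
    trace = (1 - V ** 2) - b / tau
    det = (1 - b * (1 - V ** 2)) / tau
    return trace < 0 and det > 0


sweep = []
for i in range(101):
    Iext = i / 100
    V, w = fixed_point(Iext)
    sweep.append((Iext, V, is_stable(V)))

# First grid value of Iext at which the fixed point is no longer stable.
switch = next(Iext for Iext, V, stable in sweep if not stable)
print(f"stability is lost near Iext = {switch:.2f}")
```

For these parameters the fixed point is stable at small input and loses stability between Iext = 0.3 and Iext = 0.4, consistent with the unstable node found at Iext = 0.8 in the phase plane analysis.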

Codimension 2 bifurcation analysis

We simultaneously change $I_{ext}$ and parameter $a$.

analyzer = bp.symbolic.Bifurcation(
  int_fhn,
  target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
  target_vars=dict(V=[-3, 3], w=[-3., 3.]),
  numerical_resolution=0.01,
)
res = analyzer.plot_bifurcation(show=True)
_images/52c7f6236571903e4c19f726fef55a75fcb0e1546c9bc8d41c8a9f1d2e4487bb.png _images/c3bb364c58afb8208661b010216073aba33d572ff98f7dc85b77e9133792ac08.png

Fast-Slow System Bifurcation

BrainPy also provides a tool for fast-slow system bifurcation analysis: brainpy.symbolic.FastSlowBifurcation. This method was proposed by John Rinzel [1, 2, 3], who observed that in a fast-slow dynamical system the slow variables can be treated as bifurcation parameters, so that one can study how different values of the slow variables affect the bifurcation of the fast subsystem.

brainpy.symbolic.FastSlowBifurcation is very useful for analyzing bursting neurons. I will illustrate this using the Hindmarsh-Rose model, which aims to describe the spiking-bursting behavior of the membrane potential observed in experiments on a single neuron. Its dynamics are governed by:

$$ \begin{aligned} \frac{d V}{d t} &= y - a V^3 + b V^2 - z + I \\ \frac{d y}{d t} &= c - d V^2 - y \\ \frac{d z}{d t} &= r (s (V - V_{rest}) - z) \end{aligned} $$

First of all, let’s define the Hindmarsh–Rose model with BrainPy.

a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9


@bp.odeint(method='rk4', dt=0.02)
def int_hr(x, y, z, t, Isyn):
  dx = y - a * x ** 3 + b * x * x - z + Isyn
  dy = c - d * x * x - y
  dz = r * (s * (x - x_r) - z)
  return dx, dy, dz

We can now analyze the underlying bifurcation mechanism.

analyzer = bp.symbolic.FastSlowBifurcation(
  int_hr,
  fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
  slow_vars={'z': [-5., 5.]},
  pars_update={'Isyn': 0.5},
  numerical_resolution=0.001
)
analyzer.plot_bifurcation()
analyzer.plot_trajectory([{'x': 1., 'y': 0., 'z': -0.0}],
                         duration=100.,
                         show=True)
_images/fe0a678e899f1a27cc14b6e9b661966166621c088cc8f81793206dd80cf6c71d.png _images/f1a2848565f916c47fb5626685f1d416bac5659363acf3e52a44a7970ea4b1b6.png
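
The core of the fast-slow idea is to freeze the slow variable z and ask how many fixed points the fast (x, y) subsystem has for that value. The sketch below counts them in plain Python by scanning for sign changes of dx/dt along the y-nullcline y = c - d x^2. The grid scan and the helper names are assumptions of this sketch, not the algorithm used by FastSlowBifurcation.

```python
a, b, c, d, Isyn = 1.0, 3.0, 1.0, 5.0, 0.5


def dx_on_y_nullcline(x, z):
    """dx/dt of the fast subsystem evaluated on the y-nullcline, with z frozen."""
    y = c - d * x * x
    return y - a * x ** 3 + b * x * x - z + Isyn


def fast_fixed_points(z, x_min=-3.0, x_max=3.0, n=6000):
    """x-coordinates of the fast-subsystem fixed points for a frozen z."""
    xs = [x_min + (x_max - x_min) * i / n for i in range(n + 1)]
    roots = []
    for x0, x1 in zip(xs[:-1], xs[1:]):
        f0, f1 = dx_on_y_nullcline(x0, z), dx_on_y_nullcline(x1, z)
        if f0 * f1 < 0:
            # Refine the crossing by linear interpolation inside the grid cell.
            roots.append(x0 - f0 * (x1 - x0) / (f1 - f0))
    return roots


counts = {z: len(fast_fixed_points(z)) for z in (-2.0, 1.0, 3.0)}
print(counts)  # {-2.0: 1, 1.0: 3, 3.0: 1}
```

For intermediate values of z the fast subsystem has three fixed points, while for small or large z it has only one. As z drifts slowly, the fast subsystem therefore moves between qualitatively different regimes.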

References:

[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.

[2] Rinzel, John, and Y. S. Lee. “On different mechanisms for membrane potential bursting.” In Nonlinear oscillations in biology and chemistry. Springer, Berlin, Heidelberg, 1986.

[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.

Dynamics Analysis (Numeric)

brainpy.base module

The base module for the whole BrainPy ecosystem.

  • This module provides the most fundamental class Base, and its associated helper classes Collector and TensorCollector.

  • For each instance of the Base class, users can retrieve all the variables (or trainable variables), integrators, and nodes.

  • This module also provides a Function class to wrap user-defined functions. A function may use several nodes, and users can initialize a Function by providing the nodes used in it. Unfortunately, the Function class cannot gather nodes automatically.

  • This module provides io helper functions to help users save/load model states, or share their customized models with others.

See the details below.

Base Class

Base([name])

The Base class for the whole BrainPy ecosystem.

class brainpy.base.Base(name=None)[source]

The Base class for the whole BrainPy ecosystem.

The subclasses of Base include:

  • DynamicalSystem in brainpy.simulation.brainobjects.base.py

  • Function in brainpy.base.function.py

  • AutoGrad in brainpy.math.jax.autograd.py

  • Optimizer in brainpy.math.jax.optimizers.py

  • Scheduler in brainpy.math.jax.optimizers.py

implicit_nodes = None

Used to wrap the implicit children nodes which cannot be accessed by self.xxx

implicit_vars = None

Used to wrap the implicit variables which cannot be accessed by self.xxx

ints(method='absolute')[source]

Collect all integrators in this node and the children nodes.

Parameters

method (str) – The method to access the integrators.

Returns

collector – The collection containing (path, integrator) pairs.

Return type

Collector

load_states(filename, verbose=False, check=False)[source]

Load the model states.

Parameters

filename (str) – The filename which stores the model states.

nodes(method='absolute', _paths=None)[source]

Collect all children nodes.

Parameters
  • method (str) – The method to access the nodes.

  • _paths (set, Optional) – The data structure to solve the circular reference.

Returns

gather – The collection containing (path, node) pairs.

Return type

Collector

save_states(filename, all_vars=None, **setting)[source]

Save the model states.

Parameters

filename (str) – The name of the file in which to store the model states.

target_backend = None

Used to specify the target backend on which the model runs.

train_vars(method='absolute')[source]

The shortcut for retrieving all trainable variables.

Parameters

method (str) – The method to access the variables. Support ‘absolute’ and ‘relative’.

Returns

gather – The collection containing (path, trainable variable) pairs.

Return type

TensorCollector

unique_name(name=None, type=None)[source]

Get the unique name for this object.

Parameters
  • name (str, optional) – The expected name. If None, the default unique name will be returned. Otherwise, the provided name will be checked to guarantee its uniqueness.

  • type (str, optional) – The type of this class, used for object naming.

Returns

name – The unique name for this object.

Return type

str

vars(method='absolute')[source]

Collect all variables in this node and the children nodes.

Parameters

method (str) – The method to access the variables.

Returns

gather – The collection containing (path, variable) pairs.

Return type

TensorCollector

Function Wrapper

Function(f[, nodes, dyn_vars, name])

The wrapper for Python functions.

class brainpy.base.Function(f, nodes=None, dyn_vars=None, name=None)[source]

The wrapper for Python functions.

Parameters
  • f (function) – The function to wrap.

  • nodes (optional, Base, sequence of Base, dict) – The nodes in the defined function f.

  • dyn_vars (optional, ndarray, sequence of ndarray, dict) – The dynamically changed variables.

  • name (optional, str) – The function name.

Collectors

Collector

A Collector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

TensorCollector

A TensorCollector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

class brainpy.base.Collector[source]

A Collector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A Collector is ordered by insertion order. It is the object returned by Base.vars() and used as input to many other components, such as optimizers and jit.

replace(key, new_value)[source]

Replace the original key with the new value.

subset(var_type, judge_func=None)[source]

Get the subset of the (key, value) pair.

subset() can be used to get a subset of some class:

>>> import brainpy as bp
>>>
>>> some_collector = bp.base.Collector()
>>>
>>> # get all trainable variables
>>> some_collector.subset(bp.math.TrainVar)
>>>
>>> # get all variables
>>> some_collector.subset(bp.math.Variable)

or, it can be used to get a subset of integrators:

>>> # get all ODE integrators
>>> some_collector.subset(bp.ode.ODEIntegrator)
Parameters
  • var_type (Any) – The type/class to match.

  • judge_func (optional, callable) –

unique()[source]

Get a new type of collector with unique values.

If one value is assigned to two or more keys, then only one pair of (key, value) will be returned.
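
The behavior of subset() and unique() described above can be pictured with a small dictionary subclass. The sketch below is illustrative only: SketchCollector, Variable, and TrainVar are stand-in classes defined here, not the BrainPy classes, and the real Collector may treat details such as judge_func differently.

```python
class Variable:
    """Stand-in for a variable type."""


class TrainVar(Variable):
    """Stand-in for a trainable variable type."""


class SketchCollector(dict):
    """Dictionary of (name, value) pairs with two helper methods."""

    def subset(self, var_type):
        """Keep only the entries whose value is an instance of var_type."""
        out = type(self)()
        for key, value in self.items():
            if isinstance(value, var_type):
                out[key] = value
        return out

    def unique(self):
        """Keep one (key, value) pair per distinct object, the first one seen."""
        out, seen = type(self)(), set()
        for key, value in self.items():
            if id(value) not in seen:
                seen.add(id(value))
                out[key] = value
        return out


shared = TrainVar()
collector = SketchCollector({
    'net.w': shared,
    'net.alias_w': shared,  # the same object reachable under a second name
    'net.v': Variable(),
    'net.step': 3,
})

print(list(collector.subset(TrainVar)))  # ['net.w', 'net.alias_w']
print(list(collector.unique()))          # ['net.w', 'net.v', 'net.step']
```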

update([E, ]**F) -> None[source]

Update D from dict/iterable E and F. If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

class brainpy.base.TensorCollector[source]

A TensorCollector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A TensorCollector is ordered by insertion order. It is the object returned by DynamicalSystem.vars() and used as input to many other components, such as optimizers and jit.

assign(inputs)[source]

Assign data to all values.

Parameters

inputs (dict) – The data for each value in this collector.

data()[source]

Get all data in each value.

dict()[source]

Get a dict with the key and the value data.

replicate()[source]

A context manager to use in a with statement that replicates the variables in this collection to multiple devices.

Important: replicating also updates the random state in order to have a new one per device.

Exporting and Loading

save_h5(filename, all_vars)

save_npz(filename, all_vars[, compressed])

save_pkl(filename, all_vars)

save_mat(filename, all_vars)

load_h5(filename, target[, verbose, check])

load_npz(filename, target[, verbose, check])

load_pkl(filename, target[, verbose, check])

load_mat(filename, target[, verbose, check])

brainpy.math module

The math module for the whole BrainPy ecosystem. This module provides basic mathematical operations, including:

  • numpy-like array operations

  • linear algebra functions

  • random sampling functions

  • discrete Fourier transform functions

  • compilations of jit, vmap, pmap for class objects

  • automatic differentiation of grad, jacobian, hessian, etc. for class objects

  • loss functions

  • activation functions

  • optimization classes

See the details below.

General Functions

use_backend(name[, module])

get_backend_name()

Get the current backend name.

set_dt(dt)

Set the numerical integrator precision.

get_dt()

Get the numerical integrator precision.

set_int_(int_type)

Set the default int type.

set_float_(float_type)

Set the default float type.

set_complex_(complex_type)

Set the default complex type.

JAX backend Supports

Compilations

jit(obj_or_func[, dyn_vars, ...])

JIT (Just-In-Time) Compilation for JAX backend.

Variables

Variable(value[, type, replicate])

The pointer to specify the dynamical variable.

TrainVar(value[, replicate])

The pointer to specify the trainable variable.

Parameter(value[, replicate])

The pointer to specify the parameter.

Functions

function([f, nodes, dyn_vars, name])

NumPy backend Supports

Compilations

jit(obj_or_fun[, nopython, fastmath, ...])

Just-In-Time (JIT) Compilation in NumPy backend.

Variables

Variable(value[, type, replicate])

Variable.

TrainVar(value[, replicate])

Trainable Variable.

Parameter(value[, replicate])

Parameter.

Functions

function([f, nodes, dyn_vars, name])

JAX Special Supports

Parallel Compilation

The parallel compilation tools for JAX backend.

  1. Vectorize compilation is implemented by the ‘vmap()’ function

  2. Parallel compilation is implemented by the ‘pmap()’ function

vmap(obj_or_func[, dyn_vars, vars_batched, ...])

Vectorization compilation in JAX backend.

pmap(obj_or_func[, dyn_vars, axis_name, ...])

Parallel compilation in JAX backend.

Operators

pre2syn(pre_values, pre_ids)

syn2post(syn_values, post_ids, post_num)

segment_sum(data, segment_ids, num_segments)

Computes the sum within segments of an array.

segment_prod(data, segment_ids, num_segments)

Computes the product within segments of an array.

segment_max(data, segment_ids, num_segments)

Computes the maximum within segments of an array.

segment_min(data, segment_ids, num_segments)

Computes the minimum within segments of an array.
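
A segment reduction groups the entries of data by their segment id and reduces each group to one value. The plain-Python sketch below shows the idea for all four reductions. The helper segment_reduce is hypothetical, and filling empty segments with the identity of the reduction is a choice made in this sketch.

```python
import operator


def segment_reduce(data, segment_ids, num_segments, op, init):
    """Reduce the entries of data that share a segment id with op."""
    out = [init] * num_segments
    for value, seg in zip(data, segment_ids):
        out[seg] = op(out[seg], value)
    return out


data = [1.0, 2.0, 3.0, 4.0, 5.0]
segment_ids = [0, 0, 1, 1, 1]  # segment 2 receives no entries

sums = segment_reduce(data, segment_ids, 3, operator.add, 0.0)
prods = segment_reduce(data, segment_ids, 3, operator.mul, 1.0)
maxs = segment_reduce(data, segment_ids, 3, max, float('-inf'))
mins = segment_reduce(data, segment_ids, 3, min, float('inf'))

print(sums)   # [3.0, 12.0, 0.0]
print(prods)  # [2.0, 60.0, 1.0]
print(maxs)   # [2.0, 5.0, -inf]
print(mins)   # [1.0, 3.0, inf]
```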

Control Flows

make_loop(body_fun, dyn_vars[, out_vars, ...])

Make a for-loop function, which iterates over inputs.

make_while(cond_fun, body_fun, dyn_vars)

Make a while-loop function.

make_cond(true_fun, false_fun[, dyn_vars])

Make a condition (if-else) function.

Automatic Differentiation

grad(func[, dyn_vars, grad_vars, argnums, ...])

Automatic Gradient Computation in JAX backend.

jacobian(func[, dyn_vars, grad_vars, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

jacrev(func[, dyn_vars, grad_vars, argnums, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

jacfwd(func[, dyn_vars, grad_vars, argnums, ...])

Jacobian of fun evaluated column-by-column using forward-mode AD.

hessian(fun[, vars, grad_vars, argnums, ...])

Hessian of fun as a dense array.

Grad(fun, grad_vars, grad_tree, vars[, ...])

Compute the gradients of trainable variables for the given object.

Jacobian(fun, vars, grad_vars, grad_tree[, ...])

Base Class to Compute Jacobian Matrix.

class brainpy.math.jax.autograd.Grad(fun, grad_vars, grad_tree, vars, argnums=None, has_aux=None, holomorphic=False, allow_int=False, reduce_axes=(), return_value=False, name=None)[source]

Compute the gradients of trainable variables for the given object.

Examples

In this example, the function returns two auxiliary outputs, i.e., has_aux=True.

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> from brainpy.math.jax.autograd import Grad
>>>
>>> class Test(bp.Base):
...   def __init__(self):
...     super(Test, self).__init__()
...     self.a = bm.TrainVar(bp.math.ones(1))
...     self.b = bm.TrainVar(bp.math.ones(1))
...
...   def __call__(self, c):
...     ab = self.a * self.b
...     ab2 = ab * 2
...     vv = ab2 + c
...     return vv, (ab, ab2)
...
>>> test = Test()
>>> test_grad = Grad(test, test.vars(), argnums=0, has_aux=True)
>>> grads, outputs = test_grad(10.)
>>> grads
(DeviceArray(1., dtype=float32),
 {'Test3.a': DeviceArray([2.], dtype=float32), 'Test3.b': DeviceArray([2.], dtype=float32)})
>>> outputs
(JaxArray(DeviceArray([1.], dtype=float32)),
 JaxArray(DeviceArray([2.], dtype=float32)))

This example also returns two auxiliary values (has_aux=True), but uses ValueAndGrad so that the function value is returned together with the gradients.

>>> import brainpy as bp
>>>
>>> class Test(bp.dnn.Module):
...   def __init__(self):
...     super(Test, self).__init__()
...     self.a = bp.TrainVar(bp.math.ones(1))
...     self.b = bp.TrainVar(bp.math.ones(1))
...
...   def __call__(self, c):
...     ab = self.a * self.b
...     ab2 = ab * 2
...     vv = ab2 + c
...     return vv, (ab, ab2)
...
>>> test = Test()
>>> test_grad = ValueAndGrad(test, argnums=0, has_aux=True)
>>> outputs, grads = test_grad(10.)
>>> grads
(DeviceArray(1., dtype=float32),
 {'Test3.a': DeviceArray([2.], dtype=float32), 'Test3.b': DeviceArray([2.], dtype=float32)})
>>> outputs
(JaxArray(DeviceArray(12., dtype=float32)),
 (JaxArray(DeviceArray([1.], dtype=float32)),
  JaxArray(DeviceArray([2.], dtype=float32))))
class brainpy.math.jax.autograd.Jacobian(fun, vars, grad_vars, grad_tree, argnums=None, holomorphic=False, name=None, allow_int=False, has_aux=None, return_value=False, method='rev')[source]

Base Class to Compute Jacobian Matrix.

Activation Functions

This module provides commonly used activation functions.

Activation functions are a critical part of the design of a neural network. The choice of activation function in the hidden layer will control how well the network model learns the training dataset. The choice of activation function in the output layer will define the type of predictions the model can make.

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU activation function.

hard_swish(x)

Hard SiLU activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis])

Log-Softmax function.

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.

normalize(x[, axis, mean, variance, epsilon])

Normalizes an array by subtracting mean and dividing by sqrt(var).

relu(x)

Rectified linear unit activation function.

relu6(x)

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis])

Softmax function.

softplus(x)

Softplus activation function.

silu(x)

SiLU activation function.

swish(x)

SiLU activation function.

selu(x)

Scaled exponential linear unit activation.
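
For orientation, a few of the activations listed above can be written out in plain Python using their standard textbook definitions. This is only a sketch on scalars and lists; the BrainPy functions operate on arrays, and the default `negative_slope=0.01` used here is an assumption rather than something stated in this reference.

```python
import math


def relu(x):
    return max(x, 0.0)


def leaky_relu(x, negative_slope=0.01):  # default slope is an assumption
    return x if x >= 0 else negative_slope * x


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    # SiLU (also called swish): x * sigmoid(x)
    return x * sigmoid(x)


def softplus(x):
    return math.log1p(math.exp(x))


def softmax(xs):
    # Subtract the maximum first for numerical stability.
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(xs):
    m = max(xs)
    log_norm = m + math.log(sum(math.exp(v - m) for v in xs))
    return [v - log_norm for v in xs]


probs = softmax([1.0, 2.0, 3.0])
log_probs = log_softmax([1.0, 2.0, 3.0])
```

The softmax outputs sum to one, and `log_softmax` agrees with taking the logarithm of `softmax` while avoiding the intermediate division.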

Loss Functions

This module implements many commonly used loss functions.

cross_entropy_loss(logits, targets[, ...])

This criterion combines LogSoftmax and NLLLoss in a single class.

l1_loos(logits, targets[, reduction])

Creates a criterion that measures the mean absolute error (MAE) between each element in logits and targets.

l2_loss(predicts, targets)

Computes the L2 loss.

l2_norm(x)

Computes the L2 norm.

huber_loss(predicts, targets[, delta])

Huber loss.

mean_absolute_error(x, y[, axis])

Computes the mean absolute error between x and y.

mean_squared_error(predicts, targets[, axis])

Computes the mean squared error between predicts and targets.

mean_squared_log_error(y_true, y_pred[, axis])

Computes the mean squared logarithmic error between y_true and y_pred.
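
The losses above are listed without formulas. The sketch below writes out the standard definitions in plain Python on lists; that BrainPy's array implementations match these definitions exactly (for example the Huber default `delta=1.0`, or how reductions are applied) is an assumption.

```python
import math


def mean_absolute_error(x, y):
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def mean_squared_error(predicts, targets):
    return sum((p - t) ** 2 for p, t in zip(predicts, targets)) / len(predicts)


def mean_squared_log_error(y_true, y_pred):
    return sum((math.log1p(t) - math.log1p(p)) ** 2
               for t, p in zip(y_true, y_pred)) / len(y_true)


def huber_loss(predicts, targets, delta=1.0):
    """Element-wise Huber loss: quadratic for small errors, linear for large ones."""
    losses = []
    for p, t in zip(predicts, targets):
        err = abs(p - t)
        losses.append(0.5 * err ** 2 if err <= delta else delta * (err - 0.5 * delta))
    return losses


def cross_entropy_loss(logits, target):
    """Negative log-softmax of the target class (LogSoftmax followed by NLLLoss)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[target]


mae = mean_absolute_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
mse = mean_squared_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
hub = huber_loss([0.5, 3.0], [0.0, 1.0])  # [0.125, 1.5]
ce = cross_entropy_loss([0.0, 0.0], 0)    # log(2) for two equally likely classes
```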

Optimizers

make_schedule(scalar_or_schedule)

Optimizer(train_vars, lr, name)

Base Optimizer Class.

SGD(lr, train_vars[, name])

Stochastic gradient descent optimizer.

Momentum(lr, train_vars[, momentum, name])

Momentum optimizer.

MomentumNesterov(lr, train_vars[, momentum, ...])

Nesterov accelerated gradient optimizer [2].

Adagrad(lr, train_vars[, epsilon, name])

Optimizer that implements the Adagrad algorithm.

Adadelta(train_vars[, lr, epsilon, rho, name])

Optimizer that implements the Adadelta algorithm.

RMSProp(lr, train_vars[, epsilon, rho, name])

Optimizer that implements the RMSprop algorithm.

Adam(lr, train_vars[, beta1, beta2, eps, name])

Optimizer that implements the Adam algorithm.

Scheduler(lr)

The learning rate scheduler.

Constant(lr)

ExponentialDecay(lr, decay_steps, decay_rate)

InverseTimeDecay(lr, decay_steps, decay_rate)

PolynomialDecay(lr, decay_steps, final_lr[, ...])

PiecewiseConstant(boundaries, values)

class brainpy.math.jax.optimizers.Optimizer(train_vars, lr, name)[source]

Base Optimizer Class.

target_backend = 'jax'

Specifies the target backend on which the model runs.

class brainpy.math.jax.optimizers.SGD(lr, train_vars, name=None)[source]

Stochastic gradient descent optimizer.

SGD performs a parameter update for each training example \(x\) and label \(y\):

\[\theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)\]
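
The update rule above can be sketched in a few lines of plain Python. `sgd_step` is a hypothetical helper for illustration, not the BrainPy `SGD` class, and the quadratic objective is chosen only because its gradient is easy to write by hand.

```python
def sgd_step(theta, grad, lr):
    """theta <- theta - lr * grad, applied element-wise."""
    return [t - lr * g for t, g in zip(theta, grad)]


# Minimise J(theta) = theta^2, whose gradient is 2 * theta.
theta = [1.0]
for _ in range(100):
    grad = [2.0 * theta[0]]
    theta = sgd_step(theta, grad, lr=0.1)
```

Each step multiplies `theta` by 0.8 here, so the parameter shrinks geometrically toward the minimum at 0.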
class brainpy.math.jax.optimizers.Momentum(lr, train_vars, momentum=0.9, name=None)[source]

Momentum optimizer.

Momentum [1] is a method that helps accelerate SGD in the relevant direction and dampens oscillations. It does this by adding a fraction \(\gamma\) of the update vector of the past time step to the current update vector:

\[\begin{align} \begin{split} v_t &= \gamma v_{t-1} + \eta \nabla_\theta J(\theta) \\ \theta &= \theta - v_t \end{split} \end{align}\]

[1] Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks: The Official Journal of the International Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
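
A minimal scalar sketch of the momentum update described above, where `momentum` plays the role of \(\gamma\) and `lr` the role of \(\eta\). `momentum_step` is a hypothetical helper, not the BrainPy `Momentum` class.

```python
def momentum_step(theta, v, grad, lr, momentum=0.9):
    """v_t = momentum * v_{t-1} + lr * grad, then theta = theta - v_t."""
    v = momentum * v + lr * grad
    return theta - v, v


# Minimise J(theta) = theta^2, whose gradient is 2 * theta.
theta, v = 1.0, 0.0
first_theta, first_v = momentum_step(theta, v, grad=2.0 * theta, lr=0.1)

for _ in range(300):
    theta, v = momentum_step(theta, v, grad=2.0 * theta, lr=0.1)
```

With `v` starting at zero, the first step is identical to plain SGD; later steps also carry a fraction of the previous update.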

class brainpy.math.jax.optimizers.MomentumNesterov(lr, train_vars, momentum=0.9, name=None)[source]

Nesterov accelerated gradient optimizer [2].

\[\begin{align} \begin{split} v_t &= \gamma v_{t-1} + \eta \nabla_\theta J(\theta - \gamma v_{t-1}) \\ \theta &= \theta - v_t \end{split} \end{align}\]

[2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k^2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543–547.
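
The only difference from plain momentum is where the gradient is evaluated: at the look-ahead point rather than at the current parameters. A scalar sketch, with `nesterov_step` as a hypothetical helper (not the BrainPy `MomentumNesterov` class) that takes a gradient function so it can evaluate it at the look-ahead point:

```python
def nesterov_step(theta, v, grad_fn, lr, momentum=0.9):
    """Evaluate the gradient at theta - momentum * v, then update v and theta."""
    lookahead = theta - momentum * v
    v = momentum * v + lr * grad_fn(lookahead)
    return theta - v, v


def grad_fn(theta):
    # Gradient of J(theta) = theta^2.
    return 2.0 * theta


one_theta, one_v = nesterov_step(1.0, 0.5, grad_fn, lr=0.1)  # v = 0.56, theta = 0.44

theta, v = 1.0, 0.0
for _ in range(300):
    theta, v = nesterov_step(theta, v, grad_fn, lr=0.1)
```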

class brainpy.math.jax.optimizers.Adagrad(lr, train_vars, epsilon=1e-06, name=None)[source]

Optimizer that implements the Adagrad algorithm.

Adagrad [3] is an optimizer with parameter-specific learning rates, which are adapted relative to how frequently a parameter gets updated during training. The more updates a parameter receives, the smaller the updates.

\[\theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}\]

where \(G_{t}\) contains the sum of the squares of the past gradients.

One of Adagrad’s main benefits is that it eliminates the need to manually tune the learning rate. Most implementations use a default value of 0.01 and leave it at that. Adagrad’s main weakness is its accumulation of the squared gradients in the denominator: Since every added term is positive, the accumulated sum keeps growing during training. This in turn causes the learning rate to shrink and eventually become infinitesimally small, at which point the algorithm is no longer able to acquire additional knowledge.

References

[3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
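
The shrinking step size described above is easy to see in a scalar sketch of the Adagrad rule. `adagrad_step` is a hypothetical helper, not the BrainPy `Adagrad` class; the gradient is held constant only to make the effect of the accumulator visible.

```python
import math


def adagrad_step(theta, G, grad, lr, epsilon=1e-6):
    """Accumulate squared gradients in G and scale the step by 1 / sqrt(G + epsilon)."""
    G = G + grad * grad
    theta = theta - lr / math.sqrt(G + epsilon) * grad
    return theta, G


theta, G = 0.0, 0.0
step_sizes = []
for _ in range(5):
    new_theta, G = adagrad_step(theta, G, grad=1.0, lr=0.1)
    step_sizes.append(theta - new_theta)
    theta = new_theta
```

Even though the gradient never changes, each recorded step is smaller than the previous one, because `G` only ever grows.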

class brainpy.math.jax.optimizers.Adadelta(train_vars, lr=0.01, epsilon=1e-06, rho=0.95, name=None)[source]

Optimizer that implements the Adadelta algorithm.

Adadelta [4] optimization is a stochastic gradient descent method based on a per-dimension adaptive learning rate that addresses two drawbacks:

  • The continual decay of learning rates throughout training.

  • The need for a manually selected global learning rate.

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving window of gradient updates, instead of accumulating all past gradients. This way, Adadelta continues learning even when many updates have been done. Compared to Adagrad, in the original version of Adadelta you don’t have to set an initial learning rate.

\[\begin{align}
\boldsymbol{s}_t &\leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
\boldsymbol{g}_t' &\leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
\boldsymbol{x}_t &\leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
\Delta\boldsymbol{x}_t &\leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.
\end{align}\]

\(\rho\) should be between 0 and 1. A value of rho close to 1 will decay the moving average slowly and a value close to 0 will decay the moving average fast.

\(\rho\) = 0.95 and \(\epsilon\) = 1e-6 are suggested in the paper and reported to work for multiple datasets (MNIST, speech).

In the paper, no learning rate is considered (which corresponds to lr=1.0), and it is probably best to keep it at this value; note that the default in the signature above is lr=0.01. epsilon is important for the very first update (so that the numerator does not become 0).

[4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701
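
A scalar sketch of the Adadelta update as described above, with no learning rate (the paper's formulation). `adadelta_step` is a hypothetical helper, not the BrainPy `Adadelta` class, and how that class applies its `lr` argument is not shown here.

```python
import math


def adadelta_step(x, s, delta_x, grad, rho=0.95, epsilon=1e-6):
    """One Adadelta update on a scalar parameter x."""
    # Moving average of squared gradients.
    s = rho * s + (1 - rho) * grad * grad
    # Rescale the gradient by the ratio of the two running averages.
    g_prime = math.sqrt((delta_x + epsilon) / (s + epsilon)) * grad
    x = x - g_prime
    # Moving average of squared updates.
    delta_x = rho * delta_x + (1 - rho) * g_prime * g_prime
    return x, s, delta_x


x, s, delta_x = adadelta_step(1.0, 0.0, 0.0, grad=2.0)
```

On the first call `delta_x` is 0, so the size of the step is set entirely by `epsilon` in the numerator, which is why that constant matters for the very first update.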

class brainpy.math.jax.optimizers.RMSProp(lr, train_vars, epsilon=1e-06, rho=0.9, name=None)[source]

Optimizer that implements the RMSprop algorithm.

RMSprop [5] and Adadelta were developed independently around the same time, both stemming from the need to resolve Adagrad’s radically diminishing learning rates.

The gist of RMSprop is to:

  • Maintain a moving (discounted) average of the square of gradients

  • Divide the gradient by the root of this average

\[\begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2 \\ p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}\]

The centered version additionally maintains a moving average of the gradients, and uses that average to estimate the variance.

[5] Tieleman, T. and Hinton, G. (2012): Neural Networks for Machine Learning, Lecture 6.5 - rmsprop. Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
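
A scalar sketch of the two RMSprop steps listed above. `rmsprop_step` is a hypothetical helper, not the BrainPy `RMSProp` class; subtracting the scaled gradient `p` from the parameter is the usual convention and is assumed here, since the formula above only defines `p`.

```python
import math


def rmsprop_step(theta, c, grad, lr, rho=0.9, epsilon=1e-6):
    """Update the moving average c of squared gradients, then take a scaled step."""
    c = rho * c + (1 - rho) * grad * grad
    p = lr / math.sqrt(c + epsilon) * grad
    return theta - p, c


# First step on J(theta) = theta^2 starting from theta = 1 (gradient 2).
first_theta, first_c = rmsprop_step(1.0, 0.0, grad=2.0, lr=0.01)

theta, c = 1.0, 0.0
for _ in range(10):
    theta, c = rmsprop_step(theta, c, grad=2.0 * theta, lr=0.01)
```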

class brainpy.math.jax.optimizers.Adam(lr, train_vars, beta1=0.9, beta2=0.999, eps=1e-08, name=None)[source]

Optimizer that implements the Adam algorithm.

Adam [6] is a stochastic gradient descent (SGD) method that computes individual adaptive learning rates for different parameters from estimates of first- and second-order moments of the gradients.

Parameters
  • beta1 (optional, float) – A positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • beta2 (optional, float) – A positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps (optional, float) – A positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

  • name (optional, str) – The optimizer name.

References

[6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
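
The update rule itself is not spelled out above. The scalar sketch below follows the standard formulation from Kingma & Ba (2014), using the parameter names `beta1`, `beta2` and `eps` from the signature; whether BrainPy's implementation matches it in every detail (for example where `eps` is added) is an assumption, and `adam_step` is a hypothetical helper.

```python
import math


def adam_step(theta, m, v, grad, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update at step t (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Three steps on J(theta) = theta^2, whose gradient is 2 * theta.
theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    theta, m, v = adam_step(theta, m, v, grad=2.0 * theta, t=t, lr=0.001)
```

Because the step is normalised by the gradient magnitude, each of these early updates moves `theta` by roughly `lr`, regardless of how large the gradient is.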

class brainpy.math.jax.optimizers.Scheduler(lr)[source]

The learning rate scheduler.

class brainpy.math.jax.optimizers.Constant(lr)[source]
class brainpy.math.jax.optimizers.ExponentialDecay(lr, decay_steps, decay_rate)[source]
class brainpy.math.jax.optimizers.InverseTimeDecay(lr, decay_steps, decay_rate, staircase=False)[source]
class brainpy.math.jax.optimizers.PolynomialDecay(lr, decay_steps, final_lr, power=1.0)[source]
class brainpy.math.jax.optimizers.PiecewiseConstant(boundaries, values)[source]
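
The scheduler classes above are listed without formulas. The sketch below shows common definitions of these decay rules, following the schedules in JAX's example optimizers library; that BrainPy's classes use the same formulas is an assumption, and the function names here are hypothetical. Each function returns a schedule mapping a step index `i` to a learning rate.

```python
import math


def exponential_decay(lr, decay_steps, decay_rate):
    return lambda i: lr * decay_rate ** (i / decay_steps)


def inverse_time_decay(lr, decay_steps, decay_rate, staircase=False):
    if staircase:
        return lambda i: lr / (1 + decay_rate * math.floor(i / decay_steps))
    return lambda i: lr / (1 + decay_rate * i / decay_steps)


def polynomial_decay(lr, decay_steps, final_lr, power=1.0):
    def schedule(i):
        i = min(i, decay_steps)  # hold final_lr once decay_steps is reached
        return (lr - final_lr) * (1 - i / decay_steps) ** power + final_lr
    return schedule


def piecewise_constant(boundaries, values):
    # values needs one more entry than boundaries.
    def schedule(i):
        return values[sum(i > b for b in boundaries)]
    return schedule


exp_lr = exponential_decay(0.1, 10, 0.5)
inv_lr = inverse_time_decay(0.1, 10, 1.0)
poly_lr = polynomial_decay(0.1, 100, 0.01)
step_lr = piecewise_constant([10, 20], [1.0, 0.5, 0.1])
```

For instance, `exp_lr(10)` halves the initial rate, and `step_lr` switches from 1.0 to 0.5 after step 10 and to 0.1 after step 20.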

Comparison Table

Here is a list of NumPy APIs and their corresponding BrainPy implementations.

A "-" in a BrainPy column denotes that the implementation is not provided yet. We welcome contributions for these functions.

Multi-dimensional Array

NumPy | brainpy.math.numpy | brainpy.math.jax
numpy.ndarray.all() | brainpy.math.numpy.ndarray.all() | brainpy.math.jax.ndarray.all()
numpy.ndarray.any() | brainpy.math.numpy.ndarray.any() | brainpy.math.jax.ndarray.any()
numpy.ndarray.argmax() | brainpy.math.numpy.ndarray.argmax() | brainpy.math.jax.ndarray.argmax()
numpy.ndarray.argmin() | brainpy.math.numpy.ndarray.argmin() | brainpy.math.jax.ndarray.argmin()
numpy.ndarray.argpartition() | brainpy.math.numpy.ndarray.argpartition() | brainpy.math.jax.ndarray.argpartition()
numpy.ndarray.argsort() | brainpy.math.numpy.ndarray.argsort() | brainpy.math.jax.ndarray.argsort()
numpy.ndarray.astype() | brainpy.math.numpy.ndarray.astype() | brainpy.math.jax.ndarray.astype()
numpy.ndarray.byteswap() | brainpy.math.numpy.ndarray.byteswap() | brainpy.math.jax.ndarray.byteswap()
numpy.ndarray.choose() | brainpy.math.numpy.ndarray.choose() | brainpy.math.jax.ndarray.choose()
numpy.ndarray.clip() | brainpy.math.numpy.ndarray.clip() | brainpy.math.jax.ndarray.clip()
numpy.ndarray.compress() | brainpy.math.numpy.ndarray.compress() | brainpy.math.jax.ndarray.compress()
numpy.ndarray.conj() | brainpy.math.numpy.ndarray.conj() | brainpy.math.jax.ndarray.conj()
numpy.ndarray.conjugate() | brainpy.math.numpy.ndarray.conjugate() | brainpy.math.jax.ndarray.conjugate()
numpy.ndarray.copy() | brainpy.math.numpy.ndarray.copy() | brainpy.math.jax.ndarray.copy()
numpy.ndarray.cumprod() | brainpy.math.numpy.ndarray.cumprod() | brainpy.math.jax.ndarray.cumprod()
numpy.ndarray.cumsum() | brainpy.math.numpy.ndarray.cumsum() | brainpy.math.jax.ndarray.cumsum()
numpy.ndarray.diagonal() | brainpy.math.numpy.ndarray.diagonal() | brainpy.math.jax.ndarray.diagonal()
numpy.ndarray.dot() | brainpy.math.numpy.ndarray.dot() | brainpy.math.jax.ndarray.dot()
numpy.ndarray.dump() | brainpy.math.numpy.ndarray.dump() | -
numpy.ndarray.dumps() | brainpy.math.numpy.ndarray.dumps() | -
numpy.ndarray.fill() | brainpy.math.numpy.ndarray.fill() | brainpy.math.jax.ndarray.fill()
numpy.ndarray.flatten() | brainpy.math.numpy.ndarray.flatten() | brainpy.math.jax.ndarray.flatten()
numpy.ndarray.getfield() | brainpy.math.numpy.ndarray.getfield() | -
numpy.ndarray.item() | brainpy.math.numpy.ndarray.item() | brainpy.math.jax.ndarray.item()
numpy.ndarray.itemset() | brainpy.math.numpy.ndarray.itemset() | -
numpy.ndarray.max() | brainpy.math.numpy.ndarray.max() | brainpy.math.jax.ndarray.max()
numpy.ndarray.mean() | brainpy.math.numpy.ndarray.mean() | brainpy.math.jax.ndarray.mean()
numpy.ndarray.min() | brainpy.math.numpy.ndarray.min() | brainpy.math.jax.ndarray.min()
numpy.ndarray.newbyteorder() | brainpy.math.numpy.ndarray.newbyteorder() | -
numpy.ndarray.nonzero() | brainpy.math.numpy.ndarray.nonzero() | brainpy.math.jax.ndarray.nonzero()
numpy.ndarray.partition() | brainpy.math.numpy.ndarray.partition() | -
numpy.ndarray.prod() | brainpy.math.numpy.ndarray.prod() | brainpy.math.jax.ndarray.prod()
numpy.ndarray.ptp() | brainpy.math.numpy.ndarray.ptp() | brainpy.math.jax.ndarray.ptp()
numpy.ndarray.put() | brainpy.math.numpy.ndarray.put() | -
numpy.ndarray.ravel() | brainpy.math.numpy.ndarray.ravel() | brainpy.math.jax.ndarray.ravel()
numpy.ndarray.repeat() | brainpy.math.numpy.ndarray.repeat() | brainpy.math.jax.ndarray.repeat()
numpy.ndarray.reshape() | brainpy.math.numpy.ndarray.reshape() | brainpy.math.jax.ndarray.reshape()
numpy.ndarray.resize() | brainpy.math.numpy.ndarray.resize() | -
numpy.ndarray.round() | brainpy.math.numpy.ndarray.round() | brainpy.math.jax.ndarray.round()
numpy.ndarray.searchsorted() | brainpy.math.numpy.ndarray.searchsorted() | brainpy.math.jax.ndarray.searchsorted()
numpy.ndarray.setfield() | brainpy.math.numpy.ndarray.setfield() | -
numpy.ndarray.setflags() | brainpy.math.numpy.ndarray.setflags() | -
numpy.ndarray.sort() | brainpy.math.numpy.ndarray.sort() | brainpy.math.jax.ndarray.sort()
numpy.ndarray.squeeze() | brainpy.math.numpy.ndarray.squeeze() | brainpy.math.jax.ndarray.squeeze()
numpy.ndarray.std() | brainpy.math.numpy.ndarray.std() | brainpy.math.jax.ndarray.std()
numpy.ndarray.sum() | brainpy.math.numpy.ndarray.sum() | brainpy.math.jax.ndarray.sum()
numpy.ndarray.swapaxes() | brainpy.math.numpy.ndarray.swapaxes() | brainpy.math.jax.ndarray.swapaxes()
numpy.ndarray.take() | brainpy.math.numpy.ndarray.take() | brainpy.math.jax.ndarray.take()
numpy.ndarray.tobytes() | brainpy.math.numpy.ndarray.tobytes() | brainpy.math.jax.ndarray.tobytes()
numpy.ndarray.tofile() | brainpy.math.numpy.ndarray.tofile() | -
numpy.ndarray.tolist() | brainpy.math.numpy.ndarray.tolist() | brainpy.math.jax.ndarray.tolist()
numpy.ndarray.tostring() | brainpy.math.numpy.ndarray.tostring() | -
numpy.ndarray.trace() | brainpy.math.numpy.ndarray.trace() | brainpy.math.jax.ndarray.trace()
numpy.ndarray.transpose() | brainpy.math.numpy.ndarray.transpose() | brainpy.math.jax.ndarray.transpose()
numpy.ndarray.var() | brainpy.math.numpy.ndarray.var() | brainpy.math.jax.ndarray.var()
numpy.ndarray.view() | brainpy.math.numpy.ndarray.view() | brainpy.math.jax.ndarray.view()
- | - | brainpy.math.jax.ndarray.block_host_until_ready()
- | - | brainpy.math.jax.ndarray.block_until_ready()
- | - | brainpy.math.jax.ndarray.numpy()
- | - | brainpy.math.jax.ndarray.split()
- | - | brainpy.math.jax.ndarray.tile()

Summary

  • Number of NumPy functions: 56

  • Number of functions covered by brainpy.math.numpy: 56

  • Number of functions covered by brainpy.math.jax: 44

Array Operations

NumPy

brainpy.math.numpy

brainpy.math.jax

numpy.abs

brainpy.math.numpy.abs

brainpy.math.jax.abs

numpy.absolute

brainpy.math.numpy.absolute

brainpy.math.jax.absolute

numpy.add

brainpy.math.numpy.add

brainpy.math.jax.add

numpy.add_docstring

-

-

numpy.add_newdoc

-

-

numpy.add_newdoc_ufunc

-

-

numpy.alen

-

-

numpy.all

brainpy.math.numpy.all

brainpy.math.jax.all

numpy.allclose

brainpy.math.numpy.allclose

brainpy.math.jax.allclose

numpy.alltrue

-

-

numpy.amax

-

-

numpy.amin

-

-

numpy.angle

brainpy.math.numpy.angle

brainpy.math.jax.angle

numpy.any

brainpy.math.numpy.any

brainpy.math.jax.any

numpy.append

brainpy.math.numpy.append

brainpy.math.jax.append

numpy.apply_along_axis

-

-

numpy.apply_over_axes

-

-

numpy.arange

brainpy.math.numpy.arange

brainpy.math.jax.arange

numpy.arccos

brainpy.math.numpy.arccos

brainpy.math.jax.arccos

numpy.arccosh

brainpy.math.numpy.arccosh

brainpy.math.jax.arccosh

numpy.arcsin

brainpy.math.numpy.arcsin

brainpy.math.jax.arcsin

numpy.arcsinh

brainpy.math.numpy.arcsinh

brainpy.math.jax.arcsinh

numpy.arctan

brainpy.math.numpy.arctan

brainpy.math.jax.arctan

numpy.arctan2

brainpy.math.numpy.arctan2

brainpy.math.jax.arctan2

numpy.arctanh

brainpy.math.numpy.arctanh

brainpy.math.jax.arctanh

numpy.argmax

brainpy.math.numpy.argmax

brainpy.math.jax.argmax

numpy.argmin

brainpy.math.numpy.argmin

brainpy.math.jax.argmin

numpy.argpartition

-

-

numpy.argsort

brainpy.math.numpy.argsort

brainpy.math.jax.argsort

numpy.argwhere

brainpy.math.numpy.argwhere

brainpy.math.jax.argwhere

numpy.around

brainpy.math.numpy.around

brainpy.math.jax.around

numpy.array

brainpy.math.numpy.array

brainpy.math.jax.array

numpy.array2string

-

-

numpy.array_equal

brainpy.math.numpy.array_equal

brainpy.math.jax.array_equal

numpy.array_equiv

-

-

numpy.array_repr

-

-

numpy.array_split

-

-

numpy.array_str

-

-

numpy.asanyarray

-

-

numpy.asarray

brainpy.math.numpy.asarray

brainpy.math.jax.asarray

numpy.asarray_chkfinite

-

-

numpy.ascontiguousarray

-

-

numpy.asfarray

-

-

numpy.asfortranarray

-

-

numpy.asmatrix

-

-

numpy.asscalar

-

-

numpy.atleast_1d

brainpy.math.numpy.atleast_1d

brainpy.math.jax.atleast_1d

numpy.atleast_2d

brainpy.math.numpy.atleast_2d

brainpy.math.jax.atleast_2d

numpy.atleast_3d

brainpy.math.numpy.atleast_3d

brainpy.math.jax.atleast_3d

numpy.average

brainpy.math.numpy.average

brainpy.math.jax.average

numpy.bartlett

brainpy.math.numpy.bartlett

brainpy.math.jax.bartlett

numpy.base_repr

-

-

numpy.binary_repr

-

-

numpy.bincount

brainpy.math.numpy.bincount

brainpy.math.jax.bincount

numpy.bitwise_and

brainpy.math.numpy.bitwise_and

brainpy.math.jax.bitwise_and

numpy.bitwise_not

brainpy.math.numpy.bitwise_not

brainpy.math.jax.bitwise_not

numpy.bitwise_or

brainpy.math.numpy.bitwise_or

brainpy.math.jax.bitwise_or

numpy.bitwise_xor

brainpy.math.numpy.bitwise_xor

brainpy.math.jax.bitwise_xor

numpy.blackman

brainpy.math.numpy.blackman

brainpy.math.jax.blackman

numpy.block

-

-

numpy.bmat

-

-

numpy.broadcast_arrays

-

-

numpy.broadcast_shapes

-

-

numpy.broadcast_to

-

-

numpy.busday_count

-

-

numpy.busday_offset

-

-

numpy.byte_bounds

-

-

numpy.can_cast

-

-

numpy.cbrt

brainpy.math.numpy.cbrt

brainpy.math.jax.cbrt

numpy.ceil

brainpy.math.numpy.ceil

brainpy.math.jax.ceil

numpy.choose

-

-

numpy.clip

brainpy.math.numpy.clip

brainpy.math.jax.clip

numpy.column_stack

brainpy.math.numpy.column_stack

brainpy.math.jax.column_stack

numpy.common_type

-

-

numpy.compare_chararrays

-

-

numpy.compress

-

-

numpy.concatenate

brainpy.math.numpy.concatenate

brainpy.math.jax.concatenate

numpy.conj

brainpy.math.numpy.conj

brainpy.math.jax.conj

numpy.conjugate

brainpy.math.numpy.conjugate

brainpy.math.jax.conjugate

numpy.convolve

brainpy.math.numpy.convolve

brainpy.math.jax.convolve

numpy.copy

-

-

numpy.copysign

brainpy.math.numpy.copysign

brainpy.math.jax.copysign

numpy.copyto

-

-

numpy.corrcoef

brainpy.math.numpy.corrcoef

brainpy.math.jax.corrcoef

numpy.correlate

brainpy.math.numpy.correlate

brainpy.math.jax.correlate

numpy.cos

brainpy.math.numpy.cos

brainpy.math.jax.cos

numpy.cosh

brainpy.math.numpy.cosh

brainpy.math.jax.cosh

numpy.count_nonzero

brainpy.math.numpy.count_nonzero

brainpy.math.jax.count_nonzero

numpy.cov

brainpy.math.numpy.cov

brainpy.math.jax.cov

numpy.cross

brainpy.math.numpy.cross

brainpy.math.jax.cross

numpy.cumprod

brainpy.math.numpy.cumprod

brainpy.math.jax.cumprod

numpy.cumproduct

-

-

numpy.cumsum

brainpy.math.numpy.cumsum

brainpy.math.jax.cumsum

numpy.datetime_as_string

-

-

numpy.datetime_data

-

-

numpy.deg2rad

brainpy.math.numpy.deg2rad

brainpy.math.jax.deg2rad

numpy.degrees

brainpy.math.numpy.degrees

brainpy.math.jax.degrees

numpy.delete

-

-

numpy.deprecate

-

-

numpy.deprecate_with_doc

-

-

numpy.diag

brainpy.math.numpy.diag

brainpy.math.jax.diag

numpy.diag_indices

-

-

numpy.diag_indices_from

-

-

numpy.diagflat

-

-

numpy.diagonal

-

-

numpy.diff

brainpy.math.numpy.diff

brainpy.math.jax.diff

numpy.digitize

brainpy.math.numpy.digitize

brainpy.math.jax.digitize

numpy.disp

-

-

numpy.divide

brainpy.math.numpy.divide

brainpy.math.jax.divide

numpy.divmod

brainpy.math.numpy.divmod

brainpy.math.jax.divmod

numpy.dot

brainpy.math.numpy.dot

brainpy.math.jax.dot

numpy.dsplit

brainpy.math.numpy.dsplit

brainpy.math.jax.dsplit

numpy.dstack

brainpy.math.numpy.dstack

brainpy.math.jax.dstack

numpy.ediff1d

brainpy.math.numpy.ediff1d

brainpy.math.jax.ediff1d

numpy.einsum

-

-

numpy.einsum_path

-

-

numpy.empty

brainpy.math.numpy.empty

brainpy.math.jax.empty

numpy.empty_like

brainpy.math.numpy.empty_like

brainpy.math.jax.empty_like

numpy.equal

brainpy.math.numpy.equal

brainpy.math.jax.equal

numpy.exp

brainpy.math.numpy.exp

brainpy.math.jax.exp

numpy.exp2

brainpy.math.numpy.exp2

brainpy.math.jax.exp2

numpy.expand_dims

brainpy.math.numpy.expand_dims

brainpy.math.jax.expand_dims

numpy.expm1

brainpy.math.numpy.expm1

brainpy.math.jax.expm1

numpy.extract

brainpy.math.numpy.extract

brainpy.math.jax.extract

numpy.eye

brainpy.math.numpy.eye

brainpy.math.jax.eye

numpy.fabs

brainpy.math.numpy.fabs

brainpy.math.jax.fabs

numpy.fastCopyAndTranspose

-

-

numpy.fill_diagonal

brainpy.math.numpy.fill_diagonal

brainpy.math.jax.fill_diagonal

numpy.find_common_type

-

-

numpy.fix

brainpy.math.numpy.fix

brainpy.math.jax.fix

numpy.flatnonzero

brainpy.math.numpy.flatnonzero

brainpy.math.jax.flatnonzero

numpy.flip

brainpy.math.numpy.flip

brainpy.math.jax.flip

numpy.fliplr

brainpy.math.numpy.fliplr

brainpy.math.jax.fliplr

numpy.flipud

brainpy.math.numpy.flipud

brainpy.math.jax.flipud

numpy.float_power

brainpy.math.numpy.float_power

brainpy.math.jax.float_power

numpy.floor

brainpy.math.numpy.floor

brainpy.math.jax.floor

numpy.floor_divide

brainpy.math.numpy.floor_divide

brainpy.math.jax.floor_divide

numpy.fmax

brainpy.math.numpy.fmax

brainpy.math.jax.fmax

numpy.fmin

brainpy.math.numpy.fmin

brainpy.math.jax.fmin

numpy.fmod

brainpy.math.numpy.fmod

brainpy.math.jax.fmod

numpy.format_float_positional

-

-

numpy.format_float_scientific

-

-

numpy.frexp

brainpy.math.numpy.frexp

brainpy.math.jax.frexp

numpy.frombuffer

-

-

numpy.fromfile

-

-

numpy.fromfunction

-

-

numpy.fromiter

-

-

numpy.frompyfunc

-

-

numpy.fromregex

-

-

numpy.fromstring

-

-

numpy.full

brainpy.math.numpy.full

brainpy.math.jax.full

numpy.full_like

brainpy.math.numpy.full_like

brainpy.math.jax.full_like

numpy.gcd

brainpy.math.numpy.gcd

brainpy.math.jax.gcd

numpy.genfromtxt

-

-

numpy.geomspace

-

-

numpy.get_array_wrap

-

-

numpy.get_include

-

-

numpy.get_printoptions

-

-

numpy.getbufsize

-

-

numpy.geterr

-

-

numpy.geterrcall

-

-

numpy.geterrobj

-

-

numpy.gradient

-

-

numpy.greater

brainpy.math.numpy.greater

brainpy.math.jax.greater

numpy.greater_equal

brainpy.math.numpy.greater_equal

brainpy.math.jax.greater_equal

numpy.hamming

brainpy.math.numpy.hamming

brainpy.math.jax.hamming

numpy.hanning

brainpy.math.numpy.hanning

brainpy.math.jax.hanning

numpy.heaviside

brainpy.math.numpy.heaviside

brainpy.math.jax.heaviside

numpy.histogram

brainpy.math.numpy.histogram

brainpy.math.jax.histogram

numpy.histogram2d

-

-

numpy.histogram_bin_edges

-

-

numpy.histogramdd

-

-

numpy.hsplit

brainpy.math.numpy.hsplit

brainpy.math.jax.hsplit

numpy.hstack

brainpy.math.numpy.hstack

brainpy.math.jax.hstack

numpy.hypot

brainpy.math.numpy.hypot

brainpy.math.jax.hypot

numpy.i0

-

-

numpy.identity

brainpy.math.numpy.identity

brainpy.math.jax.identity

numpy.imag

brainpy.math.numpy.imag

brainpy.math.jax.imag

numpy.in1d

-

-

numpy.indices

-

-

numpy.info

-

-

numpy.inner

brainpy.math.numpy.inner

brainpy.math.jax.inner

numpy.insert

-

-

numpy.interp

brainpy.math.numpy.interp

brainpy.math.jax.interp

numpy.intersect1d

-

-

numpy.invert

brainpy.math.numpy.invert

brainpy.math.jax.invert

numpy.is_busday

-

-

numpy.isclose

brainpy.math.numpy.isclose

brainpy.math.jax.isclose

numpy.iscomplex

-

-

numpy.iscomplexobj

-

-

numpy.isfinite

brainpy.math.numpy.isfinite

brainpy.math.jax.isfinite

numpy.isfortran

-

-

numpy.isin

-

-

numpy.isinf

brainpy.math.numpy.isinf

brainpy.math.jax.isinf

numpy.isnan

brainpy.math.numpy.isnan

brainpy.math.jax.isnan

numpy.isnat

-

-

numpy.isneginf

-

-

numpy.isposinf

-

-

numpy.isreal

brainpy.math.numpy.isreal

brainpy.math.jax.isreal

numpy.isrealobj

-

-

numpy.isscalar

brainpy.math.numpy.isscalar

brainpy.math.jax.isscalar

numpy.issctype

-

-

numpy.issubclass_

-

-

numpy.issubdtype

-

-

numpy.issubsctype

-

-

numpy.iterable

-

-

numpy.ix_

-

-

numpy.kaiser

brainpy.math.numpy.kaiser

brainpy.math.jax.kaiser

numpy.kron

brainpy.math.numpy.kron

brainpy.math.jax.kron

numpy.lcm

brainpy.math.numpy.lcm

brainpy.math.jax.lcm

numpy.ldexp

brainpy.math.numpy.ldexp

brainpy.math.jax.ldexp

numpy.left_shift

brainpy.math.numpy.left_shift

brainpy.math.jax.left_shift

numpy.less

brainpy.math.numpy.less

brainpy.math.jax.less

numpy.less_equal

brainpy.math.numpy.less_equal

brainpy.math.jax.less_equal

numpy.lexsort

-

-

numpy.linspace

brainpy.math.numpy.linspace

brainpy.math.jax.linspace

numpy.load

-

-

numpy.loads

-

-

numpy.loadtxt

-

-

numpy.log

brainpy.math.numpy.log

brainpy.math.jax.log

numpy.log10

brainpy.math.numpy.log10

brainpy.math.jax.log10

numpy.log1p

brainpy.math.numpy.log1p

brainpy.math.jax.log1p

numpy.log2

brainpy.math.numpy.log2

brainpy.math.jax.log2

numpy.logaddexp

brainpy.math.numpy.logaddexp

brainpy.math.jax.logaddexp

numpy.logaddexp2

brainpy.math.numpy.logaddexp2

brainpy.math.jax.logaddexp2

numpy.logical_and

brainpy.math.numpy.logical_and

brainpy.math.jax.logical_and

numpy.logical_not

brainpy.math.numpy.logical_not

brainpy.math.jax.logical_not

numpy.logical_or

brainpy.math.numpy.logical_or

brainpy.math.jax.logical_or

numpy.logical_xor

brainpy.math.numpy.logical_xor

brainpy.math.jax.logical_xor

numpy.logspace

brainpy.math.numpy.logspace

brainpy.math.jax.logspace

numpy.lookfor

-

-

numpy.mafromtxt

-

-

numpy.mask_indices

-

-

numpy.mat

-

-

numpy.matmul

brainpy.math.numpy.matmul

brainpy.math.jax.matmul

numpy.max

brainpy.math.numpy.max

brainpy.math.jax.max

numpy.maximum

brainpy.math.numpy.maximum

brainpy.math.jax.maximum

numpy.maximum_sctype

-

-

numpy.may_share_memory

-

-

numpy.mean

brainpy.math.numpy.mean

brainpy.math.jax.mean

numpy.median

brainpy.math.numpy.median

brainpy.math.jax.median

numpy.meshgrid

brainpy.math.numpy.meshgrid

brainpy.math.jax.meshgrid

numpy.min

brainpy.math.numpy.min

brainpy.math.jax.min

numpy.min_scalar_type

-

-

numpy.minimum

brainpy.math.numpy.minimum

brainpy.math.jax.minimum

numpy.mintypecode

-

-

numpy.mod

brainpy.math.numpy.mod

brainpy.math.jax.mod

numpy.modf

brainpy.math.numpy.modf

brainpy.math.jax.modf

numpy.moveaxis

brainpy.math.numpy.moveaxis

brainpy.math.jax.moveaxis

numpy.msort

-

-

numpy.multiply

brainpy.math.numpy.multiply

brainpy.math.jax.multiply

numpy.nan_to_num

-

-

numpy.nanargmax

-

-

numpy.nanargmin

-

-

numpy.nancumprod

brainpy.math.numpy.nancumprod

brainpy.math.jax.nancumprod

numpy.nancumsum

brainpy.math.numpy.nancumsum

brainpy.math.jax.nancumsum

numpy.nanmax

brainpy.math.numpy.nanmax

brainpy.math.jax.nanmax

numpy.nanmean

brainpy.math.numpy.nanmean

brainpy.math.jax.nanmean

numpy.nanmedian

brainpy.math.numpy.nanmedian

brainpy.math.jax.nanmedian

numpy.nanmin

brainpy.math.numpy.nanmin

brainpy.math.jax.nanmin

numpy.nanpercentile

brainpy.math.numpy.nanpercentile

brainpy.math.jax.nanpercentile

numpy.nanprod

brainpy.math.numpy.nanprod

brainpy.math.jax.nanprod

numpy.nanquantile

brainpy.math.numpy.nanquantile

brainpy.math.jax.nanquantile

numpy.nanstd

brainpy.math.numpy.nanstd

brainpy.math.jax.nanstd

numpy.nansum

brainpy.math.numpy.nansum

brainpy.math.jax.nansum

numpy.nanvar

brainpy.math.numpy.nanvar

brainpy.math.jax.nanvar

numpy.ndfromtxt

-

-

numpy.ndim

brainpy.math.numpy.ndim

brainpy.math.jax.ndim

numpy.negative

brainpy.math.numpy.negative

brainpy.math.jax.negative

numpy.nested_iters

-

-

numpy.nextafter

brainpy.math.numpy.nextafter

brainpy.math.jax.nextafter

numpy.nonzero

brainpy.math.numpy.nonzero

brainpy.math.jax.nonzero

numpy.not_equal

brainpy.math.numpy.not_equal

brainpy.math.jax.not_equal

numpy.obj2sctype

-

-

numpy.ones

brainpy.math.numpy.ones

brainpy.math.jax.ones

numpy.ones_like

brainpy.math.numpy.ones_like

brainpy.math.jax.ones_like

numpy.outer

brainpy.math.numpy.outer

brainpy.math.jax.outer

numpy.packbits

-

-

numpy.pad

-

-

numpy.partition

-

-

numpy.percentile

brainpy.math.numpy.percentile

brainpy.math.jax.percentile

numpy.piecewise

-

-

numpy.place

-

-

numpy.poly

-

-

numpy.polyadd

-

-

numpy.polyder

-

-

numpy.polydiv

-

-

numpy.polyfit

-

-

numpy.polyint

-

-

numpy.polymul

-

-

numpy.polysub

-

-

numpy.polyval

-

-

numpy.positive

brainpy.math.numpy.positive

brainpy.math.jax.positive

numpy.power

brainpy.math.numpy.power

brainpy.math.jax.power

numpy.printoptions

-

-

numpy.prod

brainpy.math.numpy.prod

brainpy.math.jax.prod

numpy.product

-

-

numpy.promote_types

-

-

numpy.ptp

brainpy.math.numpy.ptp

brainpy.math.jax.ptp

numpy.put

-

-

numpy.put_along_axis

-

-

numpy.putmask

-

-

numpy.quantile

brainpy.math.numpy.quantile

brainpy.math.jax.quantile

numpy.rad2deg

brainpy.math.numpy.rad2deg

brainpy.math.jax.rad2deg

numpy.radians

brainpy.math.numpy.radians

brainpy.math.jax.radians

numpy.ravel

brainpy.math.numpy.ravel

brainpy.math.jax.ravel

numpy.ravel_multi_index

-

-

numpy.real

brainpy.math.numpy.real

brainpy.math.jax.real

numpy.real_if_close

-

-

numpy.recfromcsv

-

-

numpy.recfromtxt

-

-

numpy.reciprocal

brainpy.math.numpy.reciprocal

brainpy.math.jax.reciprocal

numpy.remainder

brainpy.math.numpy.remainder

brainpy.math.jax.remainder

numpy.repeat

brainpy.math.numpy.repeat

brainpy.math.jax.repeat

numpy.require

-

-

numpy.reshape

brainpy.math.numpy.reshape

brainpy.math.jax.reshape

numpy.resize

-

-

numpy.result_type

-

-

numpy.right_shift

brainpy.math.numpy.right_shift

brainpy.math.jax.right_shift

numpy.rint

brainpy.math.numpy.rint

brainpy.math.jax.rint

numpy.roll

brainpy.math.numpy.roll

brainpy.math.jax.roll

numpy.rollaxis

-

-

numpy.roots

-

-

numpy.rot90

-

-

numpy.round

brainpy.math.numpy.round

brainpy.math.jax.round

numpy.round_

brainpy.math.numpy.round_

brainpy.math.jax.round_

numpy.row_stack

-

-

numpy.safe_eval

-

-

numpy.save

-

-

numpy.savetxt

-

-

numpy.savez

-

-

numpy.savez_compressed

-

-

numpy.sctype2char

-

-

numpy.searchsorted

brainpy.math.numpy.searchsorted

brainpy.math.jax.searchsorted

numpy.select

brainpy.math.numpy.select

brainpy.math.jax.select

numpy.set_numeric_ops

-

-

numpy.set_printoptions

-

-

numpy.set_string_function

-

-

numpy.setbufsize

-

-

numpy.setdiff1d

-

-

numpy.seterr

-

-

numpy.seterrcall

-

-

numpy.seterrobj

-

-

numpy.setxor1d

-

-

numpy.shape

brainpy.math.numpy.shape

brainpy.math.jax.shape

numpy.shares_memory

-

-

numpy.show_config

-

-

numpy.sign

brainpy.math.numpy.sign

brainpy.math.jax.sign

numpy.signbit

brainpy.math.numpy.signbit

brainpy.math.jax.signbit

numpy.sin

brainpy.math.numpy.sin

brainpy.math.jax.sin

numpy.sinc

brainpy.math.numpy.sinc

brainpy.math.jax.sinc

numpy.sinh

brainpy.math.numpy.sinh

brainpy.math.jax.sinh

numpy.size

brainpy.math.numpy.size

brainpy.math.jax.size

numpy.sometrue

-

-

numpy.sort

brainpy.math.numpy.sort

brainpy.math.jax.sort

numpy.sort_complex

-

-

numpy.source

-

-

numpy.spacing

-

-

numpy.split

brainpy.math.numpy.split

brainpy.math.jax.split

numpy.sqrt

brainpy.math.numpy.sqrt

brainpy.math.jax.sqrt

numpy.square

brainpy.math.numpy.square

brainpy.math.jax.square

numpy.squeeze

brainpy.math.numpy.squeeze

brainpy.math.jax.squeeze

numpy.stack

brainpy.math.numpy.stack

brainpy.math.jax.stack

numpy.std

brainpy.math.numpy.std

brainpy.math.jax.std

numpy.subtract

brainpy.math.numpy.subtract

brainpy.math.jax.subtract

numpy.sum

brainpy.math.numpy.sum

brainpy.math.jax.sum

numpy.swapaxes

brainpy.math.numpy.swapaxes

brainpy.math.jax.swapaxes

numpy.take

brainpy.math.numpy.take

brainpy.math.jax.take

numpy.take_along_axis

brainpy.math.numpy.take_along_axis

brainpy.math.jax.take_along_axis

numpy.tan

brainpy.math.numpy.tan

brainpy.math.jax.tan

numpy.tanh

brainpy.math.numpy.tanh

brainpy.math.jax.tanh

numpy.tensordot

-

-

numpy.tile

brainpy.math.numpy.tile

brainpy.math.jax.tile

numpy.trace

brainpy.math.numpy.trace

brainpy.math.jax.trace

numpy.transpose

brainpy.math.numpy.transpose

brainpy.math.jax.transpose

numpy.trapz

brainpy.math.numpy.trapz

brainpy.math.jax.trapz

numpy.tri

brainpy.math.numpy.tri

brainpy.math.jax.tri

numpy.tril

brainpy.math.numpy.tril

brainpy.math.jax.tril

numpy.tril_indices

brainpy.math.numpy.tril_indices

brainpy.math.jax.tril_indices

numpy.tril_indices_from

brainpy.math.numpy.tril_indices_from

brainpy.math.jax.tril_indices_from

numpy.trim_zeros

-

-

numpy.triu

brainpy.math.numpy.triu

brainpy.math.jax.triu

numpy.triu_indices

brainpy.math.numpy.triu_indices

brainpy.math.jax.triu_indices

numpy.triu_indices_from

brainpy.math.numpy.triu_indices_from

brainpy.math.jax.triu_indices_from

numpy.true_divide

brainpy.math.numpy.true_divide

brainpy.math.jax.true_divide

numpy.trunc

brainpy.math.numpy.trunc

brainpy.math.jax.trunc

numpy.typename

-

-

numpy.union1d

-

-

numpy.unique

brainpy.math.numpy.unique

brainpy.math.jax.unique

numpy.unpackbits

-

-

numpy.unravel_index

-

-

numpy.unwrap

-

-

numpy.vander

brainpy.math.numpy.vander

brainpy.math.jax.vander

numpy.var

brainpy.math.numpy.var

brainpy.math.jax.var

numpy.vdot

brainpy.math.numpy.vdot

brainpy.math.jax.vdot

numpy.vsplit

brainpy.math.numpy.vsplit

brainpy.math.jax.vsplit

numpy.vstack

brainpy.math.numpy.vstack

brainpy.math.jax.vstack

numpy.where

brainpy.math.numpy.where

brainpy.math.jax.where

numpy.who

-

-

numpy.zeros

brainpy.math.numpy.zeros

brainpy.math.jax.zeros

numpy.zeros_like

brainpy.math.numpy.zeros_like

brainpy.math.jax.zeros_like

-

brainpy.math.numpy.clip_by_norm

brainpy.math.jax.clip_by_norm

-

brainpy.math.numpy.function

brainpy.math.jax.function

-

brainpy.math.numpy.set_complex_

brainpy.math.jax.set_complex_

-

brainpy.math.numpy.set_float_

brainpy.math.jax.set_float_

-

brainpy.math.numpy.set_int_

brainpy.math.jax.set_int_

Summary

  • Number of NumPy functions: 401

  • Number of functions covered by brainpy.math.numpy: 225

  • Number of functions covered by brainpy.math.jax: 225

Linear Algebra

NumPy

brainpy.math.numpy

brainpy.math.jax

numpy.linalg.cholesky

brainpy.math.numpy.linalg.cholesky

brainpy.math.jax.linalg.cholesky

numpy.linalg.cond

brainpy.math.numpy.linalg.cond

brainpy.math.jax.linalg.cond

numpy.linalg.det

brainpy.math.numpy.linalg.det

brainpy.math.jax.linalg.det

numpy.linalg.eig

brainpy.math.numpy.linalg.eig

brainpy.math.jax.linalg.eig

numpy.linalg.eigh

brainpy.math.numpy.linalg.eigh

brainpy.math.jax.linalg.eigh

numpy.linalg.eigvals

brainpy.math.numpy.linalg.eigvals

brainpy.math.jax.linalg.eigvals

numpy.linalg.eigvalsh

brainpy.math.numpy.linalg.eigvalsh

brainpy.math.jax.linalg.eigvalsh

numpy.linalg.inv

brainpy.math.numpy.linalg.inv

brainpy.math.jax.linalg.inv

numpy.linalg.lstsq

brainpy.math.numpy.linalg.lstsq

brainpy.math.jax.linalg.lstsq

numpy.linalg.matrix_power

brainpy.math.numpy.linalg.matrix_power

brainpy.math.jax.linalg.matrix_power

numpy.linalg.matrix_rank

brainpy.math.numpy.linalg.matrix_rank

brainpy.math.jax.linalg.matrix_rank

numpy.linalg.multi_dot

-

-

numpy.linalg.norm

brainpy.math.numpy.linalg.norm

brainpy.math.jax.linalg.norm

numpy.linalg.pinv

brainpy.math.numpy.linalg.pinv

brainpy.math.jax.linalg.pinv

numpy.linalg.qr

brainpy.math.numpy.linalg.qr

brainpy.math.jax.linalg.qr

numpy.linalg.slogdet

-

-

numpy.linalg.solve

-

-

numpy.linalg.svd

brainpy.math.numpy.linalg.svd

brainpy.math.jax.linalg.svd

numpy.linalg.tensorinv

-

-

numpy.linalg.tensorsolve

-

-

Summary

  • Number of NumPy functions: 20

  • Number of functions covered by brainpy.math.numpy: 15

  • Number of functions covered by brainpy.math.jax: 15

Discrete Fourier Transform

NumPy

brainpy.math.numpy

brainpy.math.jax

numpy.fft.fft

brainpy.math.numpy.fft.fft

brainpy.math.jax.fft.fft

numpy.fft.fft2

brainpy.math.numpy.fft.fft2

brainpy.math.jax.fft.fft2

numpy.fft.fftfreq

brainpy.math.numpy.fft.fftfreq

brainpy.math.jax.fft.fftfreq

numpy.fft.fftn

brainpy.math.numpy.fft.fftn

brainpy.math.jax.fft.fftn

numpy.fft.fftshift

brainpy.math.numpy.fft.fftshift

brainpy.math.jax.fft.fftshift

numpy.fft.hfft

brainpy.math.numpy.fft.hfft

brainpy.math.jax.fft.hfft

numpy.fft.ifft

brainpy.math.numpy.fft.ifft

brainpy.math.jax.fft.ifft

numpy.fft.ifft2

brainpy.math.numpy.fft.ifft2

brainpy.math.jax.fft.ifft2

numpy.fft.ifftn

brainpy.math.numpy.fft.ifftn

brainpy.math.jax.fft.ifftn

numpy.fft.ifftshift

brainpy.math.numpy.fft.ifftshift

brainpy.math.jax.fft.ifftshift

numpy.fft.ihfft

brainpy.math.numpy.fft.ihfft

brainpy.math.jax.fft.ihfft

numpy.fft.irfft

brainpy.math.numpy.fft.irfft

brainpy.math.jax.fft.irfft

numpy.fft.irfft2

brainpy.math.numpy.fft.irfft2

brainpy.math.jax.fft.irfft2

numpy.fft.irfftn

brainpy.math.numpy.fft.irfftn

brainpy.math.jax.fft.irfftn

numpy.fft.rfft

brainpy.math.numpy.fft.rfft

brainpy.math.jax.fft.rfft

numpy.fft.rfft2

brainpy.math.numpy.fft.rfft2

brainpy.math.jax.fft.rfft2

numpy.fft.rfftfreq

brainpy.math.numpy.fft.rfftfreq

brainpy.math.jax.fft.rfftfreq

numpy.fft.rfftn

brainpy.math.numpy.fft.rfftn

brainpy.math.jax.fft.rfftn

Summary

  • Number of NumPy functions: 18

  • Number of functions covered by brainpy.math.numpy: 18

  • Number of functions covered by brainpy.math.jax: 18

Random Sampling

NumPy

brainpy.math.numpy

brainpy.math.jax

numpy.random.beta

brainpy.math.numpy.random.beta

brainpy.math.jax.random.beta

numpy.random.binomial

-

-

numpy.random.bytes

-

-

numpy.random.chisquare

-

-

numpy.random.choice

brainpy.math.numpy.random.choice

brainpy.math.jax.random.choice

numpy.random.default_rng

-

-

numpy.random.dirichlet

-

-

numpy.random.exponential

brainpy.math.numpy.random.exponential

brainpy.math.jax.random.exponential

numpy.random.f

-

-

numpy.random.gamma

brainpy.math.numpy.random.gamma

brainpy.math.jax.random.gamma

numpy.random.geometric

-

-

numpy.random.get_state

-

-

numpy.random.gumbel

brainpy.math.numpy.random.gumbel

brainpy.math.jax.random.gumbel

numpy.random.hypergeometric

-

-

numpy.random.laplace

brainpy.math.numpy.random.laplace

brainpy.math.jax.random.laplace

numpy.random.logistic

brainpy.math.numpy.random.logistic

brainpy.math.jax.random.logistic

numpy.random.lognormal

-

-

numpy.random.logseries

-

-

numpy.random.multinomial

-

-

numpy.random.multivariate_normal

-

-

numpy.random.negative_binomial

-

-

numpy.random.noncentral_chisquare

-

-

numpy.random.noncentral_f

-

-

numpy.random.normal

brainpy.math.numpy.random.normal

brainpy.math.jax.random.normal

numpy.random.pareto

brainpy.math.numpy.random.pareto

brainpy.math.jax.random.pareto

numpy.random.permutation

brainpy.math.numpy.random.permutation

brainpy.math.jax.random.permutation

numpy.random.poisson

brainpy.math.numpy.random.poisson

brainpy.math.jax.random.poisson

numpy.random.power

-

-

numpy.random.rand

brainpy.math.numpy.random.rand

brainpy.math.jax.random.rand

numpy.random.randint

brainpy.math.numpy.random.randint

brainpy.math.jax.random.randint

numpy.random.randn

brainpy.math.numpy.random.randn

brainpy.math.jax.random.randn

numpy.random.random

brainpy.math.numpy.random.random

brainpy.math.jax.random.random

numpy.random.random_integers

-

-

numpy.random.random_sample

brainpy.math.numpy.random.random_sample

brainpy.math.jax.random.random_sample

numpy.random.ranf

brainpy.math.numpy.random.ranf

brainpy.math.jax.random.ranf

numpy.random.rayleigh

-

-

numpy.random.sample

brainpy.math.numpy.random.sample

brainpy.math.jax.random.sample

numpy.random.seed

brainpy.math.numpy.random.seed

brainpy.math.jax.random.seed

numpy.random.set_state

-

-

numpy.random.shuffle

brainpy.math.numpy.random.shuffle

brainpy.math.jax.random.shuffle

numpy.random.standard_cauchy

brainpy.math.numpy.random.standard_cauchy

brainpy.math.jax.random.standard_cauchy

numpy.random.standard_exponential

brainpy.math.numpy.random.standard_exponential

brainpy.math.jax.random.standard_exponential

numpy.random.standard_gamma

brainpy.math.numpy.random.standard_gamma

brainpy.math.jax.random.standard_gamma

numpy.random.standard_normal

brainpy.math.numpy.random.standard_normal

brainpy.math.jax.random.standard_normal

numpy.random.standard_t

brainpy.math.numpy.random.standard_t

brainpy.math.jax.random.standard_t

numpy.random.triangular

-

-

numpy.random.uniform

brainpy.math.numpy.random.uniform

brainpy.math.jax.random.uniform

numpy.random.vonmises

-

-

numpy.random.wald

-

-

numpy.random.weibull

-

-

numpy.random.zipf

-

-

-

brainpy.math.numpy.random.bernoulli

brainpy.math.jax.random.bernoulli

-

-

brainpy.math.jax.random.copy_doc

-

brainpy.math.numpy.random.numba_seed

-

-

brainpy.math.numpy.random.truncated_normal

brainpy.math.jax.random.truncated_normal

Summary

  • Number of NumPy functions: 51

  • Number of functions covered by brainpy.math.numpy: 26

  • Number of functions covered by brainpy.math.jax: 26

brainpy.integrators module

This module provides numerical solvers for various differential equations, including:

  • ordinary differential equations (ODEs)

  • stochastic differential equations (SDEs)

See the following sections for details.

General Functions

odeint([f, method])

Numerical integration for ODEs.

sdeint([f, g, method])

Numerical integration for SDEs.

set_default_odeint(method)

Set the default ODE numerical integrator method for differential equations.

get_default_odeint()

Get the default ODE numerical integrator method.

set_default_sdeint(method)

Set the default SDE numerical integrator method for differential equations.

get_default_sdeint()

Get the default SDE numerical integrator method.

ODEIntegrator(f[, var_type, dt, name, show_code])

ODE Integrator.

SDEIntegrator(f, g[, dt, name, show_code, ...])

SDE Integrator.

class brainpy.integrators.ODEIntegrator(f, var_type=None, dt=None, name=None, show_code=False)[source]

ODE Integrator.

class brainpy.integrators.SDEIntegrator(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]

SDE Integrator.

Numerical Methods for ODEs

Numerical methods for ordinary differential equations.

Explicit Runge-Kutta Methods

This module provides explicit Runge-Kutta methods for ODEs.

Given an initial value problem specified as:

\[\frac{dy}{dt}=f(t,y),\quad y(t_{0})=y_{0}.\]

Let the step-size \(h > 0\).

Then, the general scheme of explicit Runge–Kutta methods is [1]:

\[y_{n+1}=y_{n}+h\sum _{i=1}^{s}b_{i}k_{i},\]

where

\[\begin{split}\begin{aligned} k_{1}&=f(t_{n},y_{n}),\\ k_{2}&=f(t_{n}+c_{2}h,y_{n}+h(a_{21}k_{1})),\\ k_{3}&=f(t_{n}+c_{3}h,y_{n}+h(a_{31}k_{1}+a_{32}k_{2})),\\ &\\ \vdots \\ k_{s}&=f(t_{n}+c_{s}h,y_{n}+h(a_{s1}k_{1}+a_{s2}k_{2}+\cdots +a_{s,s-1}k_{s-1})). \end{aligned}\end{split}\]

To specify a particular method, one needs to provide the integer \(s\) (the number of stages), and the coefficients \(a_{ij}\) (for \(1 \le j < i \le s\)), \(b_i\) (for \(i = 1, 2, \cdots, s\)) and \(c_i\) (for \(i = 2, 3, \cdots, s\)).

The matrix \([a_{ij}]\) is called the Runge–Kutta matrix, while the \(b_i\) and \(c_i\) are known as the weights and the nodes. These data are usually arranged in a mnemonic device, known as a Butcher tableau (named after John C. Butcher):

\[\begin{split}\begin{array}{c|lllll} 0 & & & & & \\ c_{2} & a_{21} & & & & \\ c_{3} & a_{31} & a_{32} & & & \\ \vdots & \vdots & & \ddots & & \\ c_{s} & a_{s 1} & a_{s 2} & \cdots & a_{s, s-1} & \\ \hline & b_{1} & b_{2} & \cdots & b_{s-1} & b_{s} \end{array}\end{split}\]

A Taylor series expansion shows that the Runge–Kutta method is consistent if and only if

\[\sum _{i=1}^{s}b_{i}=1.\]

Another popular condition for determining coefficients is:

\[\sum_{j=1}^{i-1}a_{ij}=c_{i}{\text{ for }}i=2,\ldots ,s.\]

For more details, see references [2], [3], and [4].
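
The general scheme above can be sketched in a few lines of plain Python. This is a minimal illustration, not BrainPy's implementation: the function name `explicit_rk_step` and the list-based tableau layout are assumptions made for this example. It uses the classical fourth-order tableau (listed for RK4 on this page) and checks the two coefficient conditions stated above.

```python
import math

def explicit_rk_step(f, t, y, h, a, b, c):
    """One explicit Runge-Kutta step defined by the tableau (a, b, c)."""
    k = []
    for i in range(len(b)):
        # Stage i only uses slopes k[0..i-1], which is what makes the method explicit.
        y_stage = y + h * sum(a[i][j] * k[j] for j in range(i))
        k.append(f(t + c[i] * h, y_stage))
    return y + h * sum(bi * ki for bi, ki in zip(b, k))

# Tableau of the classical fourth-order method (see RK4 on this page).
A = [[0.0, 0.0, 0.0, 0.0],
     [0.5, 0.0, 0.0, 0.0],
     [0.0, 0.5, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]
B = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
C = [0.0, 0.5, 0.5, 1.0]

# The two conditions stated above: sum(b_i) = 1 and sum_j a_ij = c_i.
consistent = abs(sum(B) - 1.0) < 1e-12
row_sums_match = all(abs(sum(A[i][:i]) - C[i]) < 1e-12 for i in range(len(C)))

def decay(t, y):
    return -y  # dy/dt = -y, exact solution exp(-t)

t, y, h = 0.0, 1.0, 0.1
for _ in range(10):
    y = explicit_rk_step(decay, t, y, h, A, B, C)
    t += h

print(consistent, row_sums_match)
print(y, math.exp(-1.0))  # numerical vs exact value at t = 1
```

Any other tableau on this page can be passed to the same function by changing `A`, `B`, and `C`.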

1

Press, W. H., B. P. Flannery, S. A. Teukolsky, and W. T. Vetterling. “Section 17.1 Runge-Kutta Method.” Numerical Recipes: The Art of Scientific Computing (2007).

2

https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods

3

Butcher, John Charles. Numerical methods for ordinary differential equations. John Wiley & Sons, 2016.

4

Iserles, A., 2009. A first course in the numerical analysis of differential equations (No. 44). Cambridge university press.

ExplicitRKIntegrator(f[, var_type, dt, ...])

Explicit Runge–Kutta methods for ordinary differential equations.

Euler(f[, var_type, dt, name, show_code])

The Euler method for ODEs.

MidPoint(f[, var_type, dt, name, show_code])

Explicit midpoint method for ODEs.

Heun2(f[, var_type, dt, name, show_code])

Heun's method for ODEs.

Ralston2(f[, var_type, dt, name, show_code])

Ralston's method for ODEs.

RK2(f[, beta, var_type, dt, name, show_code])

Generic second order Runge-Kutta method for ODEs.

RK3(f[, var_type, dt, name, show_code])

Classical third-order Runge-Kutta method for ODEs.

Heun3(f[, var_type, dt, name, show_code])

Heun's third-order method for ODEs.

Ralston3(f[, var_type, dt, name, show_code])

Ralston's third-order method for ODEs.

SSPRK3(f[, var_type, dt, name, show_code])

Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

RK4(f[, var_type, dt, name, show_code])

Classical fourth-order Runge-Kutta method for ODEs.

Ralston4(f[, var_type, dt, name, show_code])

Ralston's fourth-order method for ODEs.

RK4Rule38(f[, var_type, dt, name, show_code])

3/8-rule fourth-order method for ODEs.

class brainpy.integrators.ode.explicit_rk.ExplicitRKIntegrator(f, var_type=None, dt=None, name=None, show_code=False)[source]

Explicit Runge–Kutta methods for ordinary differential equations.

For the system,

\[\frac{d y}{d t}=f(t, y)\]

Explicit Runge-Kutta methods take the form

\[\begin{split}k_{i}=f\left(t_{n}+c_{i}h,y_{n}+h\sum _{j=1}^{i-1}a_{ij}k_{j}\right) \\ y_{n+1}=y_{n}+h \sum_{i=1}^{s} b_{i} k_{i}\end{split}\]

Each method listed on this page is defined by its Butcher tableau, which puts the coefficients of the method in a table as follows (for explicit methods, \(a_{ij}=0\) whenever \(j \ge i\), so only the strictly lower-triangular entries are nonzero):

\[\begin{split}\begin{array}{c|cccc} c_{1} & a_{11} & a_{12} & \ldots & a_{1 s} \\ c_{2} & a_{21} & a_{22} & \ldots & a_{2 s} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ c_{s} & a_{s 1} & a_{s 2} & \ldots & a_{s s} \\ \hline & b_{1} & b_{2} & \ldots & b_{s} \end{array}\end{split}\]
Parameters
  • f (callable) – The derivative function.

  • show_code (bool) – Whether to show the formatted code.

  • dt (float) – The numerical integration step size.

class brainpy.integrators.ode.explicit_rk.Euler(f, var_type=None, dt=None, name=None, show_code=False)[source]

The Euler method for ODEs.

Also known as the forward Euler method or the explicit Euler method.

Given an ODE system,

\[y'(t)=f(t,y(t)),\qquad y(t_{0})=y_{0},\]

the Euler method [1] chooses a step size \(h\) and sets \(t_{n}=t_{0}+nh\). One step of the Euler method from \(t_{n}\) to \(t_{n+1}=t_{n}+h\) is then:

\[y_{n+1}=y_{n}+hf(t_{n},y_{n}).\]

Note that the method advances the solution through an interval \(h\) while using derivative information from only the beginning of the interval. As a result, the local error of a single step is \(O(h^2)\).
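
The update rule and the error estimate above can be checked numerically. The following is a minimal plain-Python sketch, not BrainPy's implementation; `euler_step`, `growth`, and `local_error` are names chosen for this example. For \(y'=y\) with \(y(0)=1\), the exact value after one step is \(e^{h}\), so halving \(h\) should shrink the one-step error by a factor close to 4.

```python
import math

def euler_step(f, t, y, h):
    """One forward Euler step: y_{n+1} = y_n + h * f(t_n, y_n)."""
    return y + h * f(t, y)

def growth(t, y):
    return y  # dy/dt = y, exact solution exp(t)

def local_error(h):
    # Error of a single step starting from the exact initial value y(0) = 1.
    return abs(euler_step(growth, 0.0, 1.0, h) - math.exp(h))

# Halving h should divide an O(h^2) error by roughly 4.
ratio = local_error(0.1) / local_error(0.05)
print(ratio)
```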

Geometric interpretation

Figure (ode_Euler_method.svg): Illustration of the Euler method. The unknown curve is in blue, and its polygonal approximation is in red [2].

Derivation

There are several ways to derive the Euler method [2].

The first is to consider the Taylor expansion of the function \(y\) around \(t_{0}\):

\[y(t_{0}+h)=y(t_{0})+hy'(t_{0})+{\frac {1}{2}}h^{2}y''(t_{0})+O(h^{3}).\]

where \(y'(t_0)=f(t_0,y(t_0))\). Ignoring the quadratic and higher-order terms gives the Euler method. The Taylor expansion can also be used to analyze the error committed by the Euler method, and it can be extended to produce Runge–Kutta methods.

The second way is to replace the derivative with the forward finite difference formula:

\[y'(t_{0})\approx {\frac {y(t_{0}+h)-y(t_{0})}{h}}.\]

The third way is to integrate the differential equation from \(t_{0}\) to \(t_{0}+h\) and apply the fundamental theorem of calculus to get:

\[y(t_{0}+h)-y(t_{0})=\int _{t_{0}}^{t_{0}+h}f(t,y(t))\,\mathrm {d} t \approx hf(t_{0},y(t_{0})).\]

Note

The Euler method is a first-order numerical procedure for solving ODEs with a given initial value. Its limited stability and accuracy mean that it is mainly used as a simple introductory example of a numerical solution method.

References

1

Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. Numerical Recipes in FORTRAN: The Art of Scientific Computing, 2nd ed. Cambridge, England: Cambridge University Press, p. 710, 1992.

2

https://en.wikipedia.org/wiki/Euler_method

class brainpy.integrators.ode.explicit_rk.MidPoint(f, var_type=None, dt=None, name=None, show_code=False)[source]

Explicit midpoint method for ODEs.

Also known as the modified Euler method [1].

The midpoint method is a one-step method for numerically solving the differential equation given by:

\[y'(t) = f(t, y(t)), \quad y(t_0) = y_0 .\]

The formula of the explicit midpoint method is:

\[y_{n+1} = y_n + hf\left(t_n+\frac{h}{2},y_n+\frac{h}{2}f(t_n, y_n)\right).\]

Therefore, the Butcher tableau of the midpoint method is:

\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 \\ \hline & 0 & 1 \end{array}\end{split}\]

Derivation

Compared to the slope formula of the Euler method \(y'(t) \approx \frac{y(t+h) - y(t)}{h}\), the midpoint method uses

\[y'\left(t+\frac{h}{2}\right) \approx \frac{y(t+h) - y(t)}{h},\]

The geometric interpretation in the following section explains this choice. Then, we get

\[y(t+h) \approx y(t) + hf\left(t+\frac{h}{2},y\left(t+\frac{h}{2}\right)\right).\]

However, we do not know \(y(t+h/2)\). The solution is to estimate it with a Taylor series expansion, exactly as in the Euler method:

\[y\left(t + \frac{h}{2}\right) \approx y(t) + \frac{h}{2}y'(t)=y(t) + \frac{h}{2}f(t, y(t)),\]

Finally, we can get the final step function:

\[y(t + h) \approx y(t) + hf\left(t + \frac{h}{2}, y(t) + \frac{h}{2}f(t, y(t))\right).\]
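
The final step function can be sketched in plain Python and compared with the Euler method. This is a minimal illustration under stated assumptions, not BrainPy's implementation; the helper names (`midpoint_step`, `euler_step`, `integrate`) are chosen for this example. The test problem is \(y'=y\), \(y(0)=1\), integrated to \(t=1\), where the exact answer is \(e\).

```python
import math

def euler_step(f, t, y, h):
    return y + h * f(t, y)

def midpoint_step(f, t, y, h):
    # Euler estimate of y(t + h/2), then a full step using the midpoint slope.
    y_half = y + 0.5 * h * f(t, y)
    return y + h * f(t + 0.5 * h, y_half)

def integrate(step, f, y0, t0, t1, n):
    """Apply `step` n times with a fixed step size from t0 to t1."""
    h = (t1 - t0) / n
    t, y = t0, y0
    for _ in range(n):
        y = step(f, t, y, h)
        t += h
    return y

def growth(t, y):
    return y  # dy/dt = y, exact solution exp(t)

err_euler = abs(integrate(euler_step, growth, 1.0, 0.0, 1.0, 20) - math.e)
err_midpoint = abs(integrate(midpoint_step, growth, 1.0, 0.0, 1.0, 20) - math.e)
print(err_euler, err_midpoint)  # the midpoint error is much smaller
```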

Geometric interpretation

In the basic Euler’s method, the tangent of the curve at \((t_{n},y_{n})\) is computed using \(f(t_{n},y_{n})\). The next value \(y_{n+1}\) is found where the tangent intersects the vertical line \(t=t_{n+1}\). However, if the second derivative is only positive between \(t_{n}\) and \(t_{n+1}\), or only negative, the curve will increasingly veer away from the tangent, leading to larger errors as \(h\) increases.

Compared with the Euler method, the midpoint method uses the tangent at the midpoint (the upper, green line segment in the figure [2]), which most likely gives a more accurate approximation of the curve in that interval.

[Figure: ../_static/ode_Midpoint_method_illustration.png]

Although this midpoint tangent cannot be calculated exactly, we can estimate the midpoint value of \(y(t)\) by using the original Euler’s method. Finally, the improved tangent is used to calculate the value of \(y_{n+1}\) from \(y_{n}\). This last step is represented by the red chord in the diagram.

Note

Note that the red chord is not exactly parallel to the green segment (the true tangent), due to the error in estimating the value of \(y(t)\) at the midpoint.

References

1

Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003.

2

https://en.wikipedia.org/wiki/Midpoint_method

class brainpy.integrators.ode.explicit_rk.Heun2(f, var_type=None, dt=None, name=None, show_code=False)[source]

Heun’s method for ODEs.

This method is named after Karl Heun [1]. It is also known as the explicit trapezoid rule, improved Euler’s method, or modified Euler’s method.

Given an ODE with an initial value,

\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]

the two-stage Heun’s method is formulated as:

\[\tilde{y}_{n+1} = y_n + h f(t_n,y_n)\]
\[y_{n+1} = y_n + \frac{h}{2}[f(t_n, y_n) + f(t_{n+1},\tilde{y}_{n+1})],\]

where \(h\) is the step size and \(t_{n+1}=t_n+h\).

Therefore, the Butcher tableau of the two-stage Heun’s method is:

\[\begin{split}\begin{array}{c|cc} 0.0 & 0.0 & 0.0 \\ 1.0 & 1.0 & 0.0 \\ \hline & 0.5 & 0.5 \end{array}\end{split}\]

Geometric interpretation

As noted for brainpy.integrators.ode.midpoint(), the Euler method has a large estimation error because it uses the line tangent to the function at the beginning of the interval \(t_n\) as an estimate of the slope of the function over the interval \((t_n, t_{n+1})\).

To address this problem, Heun’s method considers the tangent lines to the solution curve at both ends of the interval (\(t_n\) and \(t_{n+1}\)): one (\(f(t_n, y_n)\)) which underestimates, and one (\(f(t_{n+1},\tilde{y}_{n+1})\), approximated using Euler’s method) which overestimates the ideal vertical coordinates. The ideal point lies approximately halfway between the overestimate and the underestimate, so the method uses the average of the two slopes.

[Figure: ../_static/ode_Heun2_Method_Diagram.jpg]
\[\begin{split}\begin{aligned} {\text{Slope}}_{\text{left}}=&f(t_{n},y_{n}) \\ {\text{Slope}}_{\text{right}}=&f(t_{n}+h,y_{n}+hf(t_{n},y_{n})) \\ {\text{Slope}}_{\text{ideal}}=&{\frac {1}{2}}({\text{Slope}}_{\text{left}}+{\text{Slope}}_{\text{right}}) \end{aligned}\end{split}\]
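
The three slopes above translate directly into code. The following plain-Python sketch is an illustration only, not BrainPy's implementation; `heun2_step` and the variable names are chosen to mirror the formulas. Because the method averages the slopes at both ends, it is exact when \(f\) depends linearly on \(t\) alone, which the example uses as a sanity check.

```python
def heun2_step(f, t, y, h):
    slope_left = f(t, y)                 # slope at the start of the interval
    y_pred = y + h * slope_left          # Euler predictor for y(t + h)
    slope_right = f(t + h, y_pred)       # slope at the predicted end point
    slope_ideal = 0.5 * (slope_left + slope_right)
    return y + h * slope_ideal

def ramp(t, y):
    return t  # dy/dt = t, exact solution y = t**2 / 2 for y(0) = 0

y_one_step = heun2_step(ramp, 0.0, 0.0, 0.5)
print(y_one_step, 0.5 ** 2 / 2)  # both values should agree
```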

References

1

Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003.

class brainpy.integrators.ode.explicit_rk.Ralston2(f, var_type=None, dt=None, name=None, show_code=False)[source]

Ralston’s method for ODEs.

Ralston’s method is a second-order method with two stages and a minimum local error bound.

Given an ODE with an initial value,

\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]

Ralston’s second-order method is given by

\[y_{n+1}=y_{n}+\frac{h}{4} f\left(t_{n}, y_{n}\right)+ \frac{3 h}{4} f\left(t_{n}+\frac{2 h}{3}, y_{n}+\frac{2 h}{3} f\left(t_{n}, y_{n}\right)\right)\]

Therefore, the corresponding Butcher tableau is:

\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ 2 / 3 & 2 / 3 & 0 \\ \hline & 1 / 4 & 3 / 4 \end{array}\end{split}\]
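
As a minimal plain-Python sketch of the formula above (not BrainPy's implementation; `ralston2_step` and the test problem are assumptions chosen for illustration), the step can be applied to \(y'=-y^{2}\), \(y(0)=1\), whose exact solution is \(y(t)=1/(1+t)\).

```python
def ralston2_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + 2.0 * h / 3.0, y + 2.0 * h / 3.0 * k1)
    # Weights 1/4 and 3/4 from the bottom row of the Butcher tableau.
    return y + h * (0.25 * k1 + 0.75 * k2)

def rhs(t, y):
    return -y * y  # dy/dt = -y^2, exact solution 1 / (1 + t) for y(0) = 1

n = 100
h = 1.0 / n
t, y = 0.0, 1.0
for _ in range(n):
    y = ralston2_step(rhs, t, y, h)
    t += h

error = abs(y - 0.5)  # exact value at t = 1 is 1 / 2
print(y, error)
```
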
class brainpy.integrators.ode.explicit_rk.RK2(f, beta=0.6666666666666666, var_type=None, dt=None, name=None, show_code=False)[source]

Generic second order Runge-Kutta method for ODEs.

Derivation

As discussed in brainpy.integrators.ode.midpoint(), brainpy.integrators.ode.heun2(), and brainpy.integrators.ode.ralston2(), the first-order Euler method brainpy.integrators.ode.euler() has a large estimation error.

Here, we derive a generic second-order Runge-Kutta method [1]. For the ODE system with a given initial value,

\[y'(t) = f(t,y(t)), \qquad y(t_0)=y_0,\]

we seek a generic update of the form:

\[\begin{align} y_{n+1} &= y_{n} + h \left ( a_1 K_1 + a_2 K_2 \right ) \tag{1} \end{align}\]

where \(a_1\) and \(a_2\) are weights to be determined, and \(K_1\) and \(K_2\) are derivatives of the form:

\[\begin{align} K_1 & = f(t_n,y_n) \qquad \text{and} \qquad K_2 = f(t_n + p_1 h,y_n + p_2 K_1 h ) \tag{2} \end{align}\]

By substitution of (2) in (1) we get:

\[\begin{align} y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h f(t_n + p_1 h,y_n + p_2 K_1 h) \tag{3} \end{align}\]

Now, we may find a Taylor-expansion of \(f(t_n + p_1 h, y_n + p_2 K_1 h )\)

\[\begin{split}\begin{align} f(t_n + p_1 h, y_n + p_2 K_1 h ) &= f + p_1 h f_t + p_2 K_1 h f_y + \text{h.o.t.} \nonumber \\ & = f + p_1 h f_t + p_2 h f f_y + \text{h.o.t.} \tag{4} \end{align}\end{split}\]

where \(f_t \equiv \frac{\partial f}{\partial t}\) and \(f_y \equiv \frac{\partial f}{\partial y}\).

Substituting (4) into (3) and dropping the higher-order terms gives:

\[\begin{split}\begin{align} y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h \left (f + p_1 h f_t + p_2 h f f_y \right ) \nonumber \\ &= y_{n} + (a_1 + a_2) h f + \left (a_2 p_1 f_t + a_2 p_2 f f_y \right) h^2 \tag{5} \end{align}\end{split}\]

Next, we write the second-order Taylor expansion of the solution:

\[\begin{align} y(t_n+h) = y_n + h y' + \frac{h^2}{2} y'' + O(h^3) \tag{6} \end{align}\]

where the second order derivative is given by

\[\begin{align} y'' = \frac{d^2 y}{dt^2} = \frac{df}{dt} = \frac{\partial{f}}{\partial{t}} \frac{dt}{dt} + \frac{\partial{f}}{\partial{y}} \frac{dy}{dt} = f_t + f f_y \tag{7} \end{align}\]

Substitution of (7) into (6) yields:

\[\begin{align} y(t_n+h) = y_n + h f + \frac{h^2}{2} \left (f_t + f f_y \right ) + O(h^3) \tag{8} \end{align}\]

Finally, matching the coefficients of (5) with those of (8) gives the generic second-order Runge-Kutta method, where

\[\begin{split}\begin{aligned} a_1 + a_2 = 1 \\ a_2 p_1 = \frac{1}{2} \\ a_2 p_2 = \frac{1}{2}. \end{aligned}\end{split}\]

Furthermore, setting \(p_1=\beta\), we get

\[\begin{split}\begin{aligned} p_1 = & \beta \\ p_2 = & \beta \\ a_2 = &\frac{1}{2\beta} \\ a_1 = &1 - \frac{1}{2\beta} . \end{aligned}\end{split}\]

Therefore, the corresponding Butcher tableau is:

\[\begin{split}\begin{array}{c|cc} 0 & 0 & 0 \\ \beta & \beta & 0 \\ \hline & 1 - \frac{1}{2\beta} & \frac{1}{2\beta} \end{array}\end{split}\]
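
The \(\beta\)-parameterized family can be sketched in plain Python as follows. This is an illustration, not BrainPy's implementation; `rk2_step` and the test function are assumptions for this example. With the coefficients derived above, \(\beta=1/2\) reproduces the midpoint method, \(\beta=1\) reproduces Heun's method, and \(\beta=2/3\) reproduces Ralston's method, which the code checks on a single step.

```python
import math

def rk2_step(f, t, y, h, beta=2.0 / 3.0):
    a2 = 1.0 / (2.0 * beta)   # weight of K2
    a1 = 1.0 - a2             # weight of K1
    k1 = f(t, y)
    k2 = f(t + beta * h, y + beta * h * k1)   # p1 = p2 = beta
    return y + h * (a1 * k1 + a2 * k2)

def rhs(t, y):
    return math.cos(t) - y * y  # arbitrary smooth test function

t, y, h = 0.3, 0.7, 0.1

# Reference formulas of the three named methods, written out directly.
midpoint = y + h * rhs(t + h / 2, y + h / 2 * rhs(t, y))
heun = y + h / 2 * (rhs(t, y) + rhs(t + h, y + h * rhs(t, y)))
ralston = (y + h / 4 * rhs(t, y)
           + 3 * h / 4 * rhs(t + 2 * h / 3, y + 2 * h / 3 * rhs(t, y)))

print(rk2_step(rhs, t, y, h, beta=0.5) - midpoint)
print(rk2_step(rhs, t, y, h, beta=1.0) - heun)
print(rk2_step(rhs, t, y, h, beta=2.0 / 3.0) - ralston)
```

All three differences should be zero up to floating-point rounding.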

References

1

Chapra, Steven C., and Raymond P. Canale. Numerical methods for engineers. Vol. 1221. New York: Mcgraw-hill, 2011.

class brainpy.integrators.ode.explicit_rk.RK3(f, var_type=None, dt=None, name=None, show_code=False)[source]

Classical third-order Runge-Kutta method for ODEs.

For the given initial value problem \(y'(t) = f(t,y);\, y(t_0) = y_0\), the third-order Runge-Kutta method is given by:

\[y_{n+1} = y_n + 1/6 ( k_1 + 4 k_2 + k_3),\]

where

\[\begin{split}k_1 = h f(t_n, y_n), \\ k_2 = h f(t_n + h / 2, y_n + k_1 / 2), \\ k_3 = h f(t_n + h, y_n - k_1 + 2 k_2 ),\end{split}\]

where \(t_n = t_0 + n h.\)

The error term is \(O(h^4)\); the method is correct up to the third-order term of the Taylor series expansion.

The Taylor series expansion is \(y(t+h)=y(t)+\frac{k_{1}}{6}+\frac{2 k_{2}}{3}+\frac{k_{3}}{6}+O\left(h^{4}\right)\).

The corresponding Butcher tableau is:

\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 \\ 1 & -1 & 2 & 0 \\ \hline & 1 / 6 & 2 / 3 & 1 / 6 \end{array}\end{split}\]
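
A minimal plain-Python sketch of the formulas above follows (an illustration, not BrainPy's implementation; `rk3_step` is a name chosen for this example). Note that here each \(k_i\) already includes the factor \(h\), as in the definition above. For \(y'=y\), one step reproduces the Taylor polynomial \(1+h+h^{2}/2+h^{3}/6\), so the one-step error against \(e^{h}\) is of order \(h^{4}\).

```python
import math

def rk3_step(f, t, y, h):
    # Each k_i includes the step size h, following the formulas above.
    k1 = h * f(t, y)
    k2 = h * f(t + h / 2, y + k1 / 2)
    k3 = h * f(t + h, y - k1 + 2 * k2)
    return y + (k1 + 4 * k2 + k3) / 6

def growth(t, y):
    return y  # dy/dt = y, exact solution exp(t)

h = 0.1
y1 = rk3_step(growth, 0.0, 1.0, h)
taylor3 = 1 + h + h ** 2 / 2 + h ** 3 / 6
print(y1 - taylor3)            # zero up to rounding
print(abs(y1 - math.exp(h)))   # of order h**4
```
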
class brainpy.integrators.ode.explicit_rk.Heun3(f, var_type=None, dt=None, name=None, show_code=False)[source]

Heun’s third-order method for ODEs.

It has the characteristics of:

  • method stage = 3

  • method order = 3

  • Butcher tableau:

\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 3 & 1 / 3 & 0 & 0 \\ 2 / 3 & 0 & 2 / 3 & 0 \\ \hline & 1 / 4 & 0 & 3 / 4 \end{array}\end{split}\]
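
The claim "method order = 3" can be checked against the tableau with exact rational arithmetic. The sketch below is an illustration only: it uses the standard third-order conditions for Runge-Kutta methods, which are not stated in the text above, and the function name `third_order_conditions` is chosen for this example.

```python
from fractions import Fraction as F

def third_order_conditions(a, b, c):
    """Evaluate the four standard order-3 conditions for a tableau."""
    s = len(b)
    return {
        "sum b_i = 1": sum(b) == 1,
        "sum b_i c_i = 1/2": sum(b[i] * c[i] for i in range(s)) == F(1, 2),
        "sum b_i c_i^2 = 1/3": sum(b[i] * c[i] ** 2 for i in range(s)) == F(1, 3),
        "sum b_i a_ij c_j = 1/6": sum(
            b[i] * a[i][j] * c[j] for i in range(s) for j in range(s)
        ) == F(1, 6),
    }

# Heun's third-order tableau, copied from the table above.
a = [[F(0), F(0), F(0)],
     [F(1, 3), F(0), F(0)],
     [F(0), F(2, 3), F(0)]]
b = [F(1, 4), F(0), F(3, 4)]
c = [F(0), F(1, 3), F(2, 3)]

checks = third_order_conditions(a, b, c)
for name, ok in checks.items():
    print(name, ok)
```
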
class brainpy.integrators.ode.explicit_rk.Ralston3(f, var_type=None, dt=None, name=None, show_code=False)[source]

Ralston’s third-order method for ODEs.

It has the characteristics of:

  • method stage = 3

  • method order = 3

  • Butcher tableau:

\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 \\ 3 / 4 & 0 & 3 / 4 & 0 \\ \hline & 2 / 9 & 1 / 3 & 4 / 9 \end{array}\end{split}\]

References

1

Ralston, Anthony (1962). “Runge-Kutta Methods with Minimum Error Bounds”. Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0

class brainpy.integrators.ode.explicit_rk.SSPRK3(f, var_type=None, dt=None, name=None, show_code=False)[source]

Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

It has the characteristics of:

  • method stage = 3

  • method order = 3

  • Butcher tableau:

\[\begin{split}\begin{array}{c|ccc} 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 / 2 & 1 / 4 & 1 / 4 & 0 \\ \hline & 1 / 6 & 1 / 6 & 2 / 3 \end{array}\end{split}\]
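
SSPRK3 is often written as a convex combination of forward Euler steps (the Shu-Osher form) rather than through its tableau. The plain-Python sketch below is an illustration, not BrainPy's implementation, and both function names are assumptions for this example. It implements the tableau form and the convex-combination form and shows that they produce the same step.

```python
import math

def ssprk3_tableau_step(f, t, y, h):
    # Direct use of the Butcher tableau above.
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    k3 = f(t + h / 2, y + h * (k1 / 4 + k2 / 4))
    return y + h * (k1 / 6 + k2 / 6 + 2 * k3 / 3)

def ssprk3_convex_step(f, t, y, h):
    # Each stage is a convex combination of y and a forward Euler step.
    u1 = y + h * f(t, y)
    u2 = 0.75 * y + 0.25 * (u1 + h * f(t + h, u1))
    return y / 3 + 2 / 3 * (u2 + h * f(t + h / 2, u2))

def rhs(t, y):
    return math.cos(t) - y * y  # arbitrary smooth test function

diff = ssprk3_tableau_step(rhs, 0.3, 0.7, 0.1) - ssprk3_convex_step(rhs, 0.3, 0.7, 0.1)
print(diff)  # zero up to floating-point rounding
```
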
class brainpy.integrators.ode.explicit_rk.RK4(f, var_type=None, dt=None, name=None, show_code=False)[source]

Classical fourth-order Runge-Kutta method for ODEs.

For the given initial value problem of

\[{\frac {dy}{dt}}=f(t,y),\quad y(t_{0})=y_{0}.\]

The fourth-order RK method is formulated as:

\[\begin{split}\begin{aligned} y_{n+1}&=y_{n}+{\frac {1}{6}}h\left(k_{1}+2k_{2}+2k_{3}+k_{4}\right),\\ t_{n+1}&=t_{n}+h\\ \end{aligned}\end{split}\]

for \(n = 0, 1, 2, 3, \ldots\), using

\[\begin{split}\begin{aligned} k_{1}&=\ f(t_{n},y_{n}),\\ k_{2}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{1}}{2}}\right),\\ k_{3}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{2}}{2}}\right),\\ k_{4}&=\ f\left(t_{n}+h,y_{n}+hk_{3}\right). \end{aligned}\end{split}\]

Here \(y_{n+1}\) is the RK4 approximation of \(y(t_{n+1})\), and the next value (\(y_{n+1}\)) is determined by the present value (\(y_{n}\)) plus the weighted average of four increments, where each increment is the product of the size of the interval, \(h\), and an estimated slope specified by function \(f\) on the right-hand side of the differential equation.

  • \(k_{1}\) is the slope at the beginning of the interval, using \(y\) (Euler’s method);

  • \(k_{2}\) is the slope at the midpoint of the interval, using \(y\) and \(k_{1}\);

  • \(k_{3}\) is again the slope at the midpoint, but now using \(y\) and \(k_{2}\);

  • \(k_{4}\) is the slope at the end of the interval, using \(y\) and \(k_{3}\).

The RK4 method is a fourth-order method, meaning that the local truncation error is on the order of \(O(h^{5})\), while the total accumulated error is on the order of \(O(h^{4})\).

The corresponding Butcher tableau is:

\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ 1 / 2 & 1 / 2 & 0 & 0 & 0 \\ 1 / 2 & 0 & 1 / 2 & 0 & 0 \\ 1 & 0 & 0 & 1 & 0 \\ \hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6 \end{array}\end{split}\]
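The formulas above translate almost line by line into code. The following minimal sketch (function names and the test problem \(y' = y\) are assumptions for illustration) also checks the fourth-order claim empirically: halving \(h\) should reduce the accumulated error by a factor of roughly \(2^4 = 16\).

```python
import math


def rk4_step(f, t, y, h):
    """One classical RK4 step, following the k1..k4 definitions above."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h * k1 / 2)
    k3 = f(t + h / 2, y + h * k2 / 2)
    k4 = f(t + h, y + h * k3)
    return y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def rk4_solve(f, y0, t0, t1, n_steps):
    """Integrate from t0 to t1 with n_steps equal steps."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y


def growth(t, y):
    # Test problem y' = y with exact solution exp(t).
    return y


err_h = abs(rk4_solve(growth, 1.0, 0.0, 1.0, 10) - math.e)
err_half_h = abs(rk4_solve(growth, 1.0, 0.0, 1.0, 20) - math.e)
error_ratio = err_h / err_half_h  # expected to be near 16
```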

References

1

Lambert, J. D. and Lambert, D. Ch. 5 in Numerical Methods for Ordinary Differential Systems: The Initial Value Problem. New York: Wiley, 1991.

2

Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. “Runge-Kutta Method” and “Adaptive Step Size Control for Runge-Kutta.” §16.1 and 16.2 in Numerical Recipes in FORTRAN: The Art of Scientific Computing, 2nd ed. Cambridge, England: Cambridge University Press, pp. 704-716, 1992.

class brainpy.integrators.ode.explicit_rk.Ralston4(f, var_type=None, dt=None, name=None, show_code=False)[source]

Ralston’s fourth-order method for ODEs.

It has the characteristics of:

  • method stage = 4

  • method order = 4

  • Butcher Tables:

\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ .4 & .4 & 0 & 0 & 0 \\ .45573725 & .29697761 & .15875964 & 0 & 0 \\ 1 & .21810040 & -3.05096516 & 3.83286476 & 0 \\ \hline & .17476028 & -.55148066 & 1.20553560 & .17118478 \end{array}\end{split}\]

References

1

Ralston, Anthony (1962). “Runge-Kutta Methods with Minimum Error Bounds”. Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0

class brainpy.integrators.ode.explicit_rk.RK4Rule38(f, var_type=None, dt=None, name=None, show_code=False)[source]

3/8-rule fourth-order method for ODEs.

A slight variation of “the” Runge–Kutta method is also due to Kutta in 1901 [1] and is called the 3/8-rule. The primary advantage of this method is that almost all of its error coefficients are smaller than those of the classical method, but it requires slightly more FLOPs (floating-point operations) per time step.

It has the characteristics of:

  • method stage = 4

  • method order = 4

  • Butcher Tables:

\[\begin{split}\begin{array}{c|cccc} 0 & 0 & 0 & 0 & 0 \\ 1 / 3 & 1 / 3 & 0 & 0 & 0 \\ 2 / 3 & -1 / 3 & 1 & 0 & 0 \\ 1 & 1 & -1 & 1 & 0 \\ \hline & 1 / 8 & 3 / 8 & 3 / 8 & 1 / 8 \end{array}\end{split}\]

References

1

Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving ordinary differential equations I: Nonstiff problems, Berlin, New York: Springer-Verlag, ISBN 978-3-540-56670-0.

Adaptive Runge-Kutta Methods

This module provides adaptive Runge-Kutta methods for ODEs.

Adaptive methods are designed to produce an estimate of the local truncation error of a single Runge–Kutta step. This is done by having two methods, one with order \(p\) and one with order \(p-1\). These methods are interwoven, i.e., they have common intermediate steps. Thanks to this, estimating the error has little or negligible computational cost compared to a step with the higher-order method.

During the integration, the step size is adapted such that the estimated error stays below a user-defined threshold: If the error is too high, a step is repeated with a lower step size; if the error is much smaller, the step size is increased to save time. This results in an (almost) optimal step size, which saves computation time. Moreover, the user does not have to spend time on finding an appropriate step size.

The lower-order step is given by

\[y_{n+1}^{*}=y_{n}+h\sum _{i=1}^{s}b_{i}^{*}k_{i},\]

where \(k_{i}\) are the same as for the higher-order method. Then the error is

\[e_{n+1}=y_{n+1}-y_{n+1}^{*}=h\sum _{i=1}^{s}(b_{i}-b_{i}^{*})k_{i},\]

which is \(O(h^{p})\).

The Butcher tableau for this kind of method is extended to give the values of \(b_{i}^{*}\):

\[\begin{split}\begin{array}{c|lllll} 0 & & & & & \\ c_{2} & a_{21} & & & & \\ c_{3} & a_{31} & a_{32} & & & \\ \vdots & \vdots & & \ddots & \\ c_{s} & a_{s 1} & a_{s 2} & \cdots & a_{s, s-1} \\ \hline & b_{1} & b_{2} & \cdots & b_{s-1} & b_{s} \\ & b_{1}^{*} & b_{2}^{*} & \cdots & b_{s-1}^{*} & b_{s}^{*} \end{array}\end{split}\]
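The accept/reject logic described above can be sketched with the simplest embedded pair (Heun with order 2, Euler with order 1). The controller constants (safety factor 0.9, growth limits 0.2 and 2.0), the function names, and the test problem are assumptions for this example; they are not taken from BrainPy.

```python
import math


def heun_euler_step(f, t, y, h):
    """Return the order-2 result and the embedded error estimate."""
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    y_high = y + h * (k1 / 2 + k2 / 2)  # weights b (Heun, order 2)
    y_low = y + h * k1                  # weights b* (Euler, order 1)
    return y_high, abs(y_high - y_low)


def adaptive_solve(f, y0, t0, t1, tol=1e-6, h0=0.1):
    """Integrate from t0 to t1, adapting h so the estimate stays below tol."""
    t, y, h = t0, y0, h0
    n_accepted = n_rejected = 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)              # do not step past the end point
        y_new, err = heun_euler_step(f, t, y, h)
        if err <= tol:                  # accept: advance time and state
            t, y = t + h, y_new
            n_accepted += 1
        else:                           # reject: retry with a smaller h
            n_rejected += 1
        # The estimate scales like h**2 for this pair, hence the square root.
        scale = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, scale))
    return y, n_accepted, n_rejected


y_end, n_accepted, n_rejected = adaptive_solve(lambda t, y: -y, 1.0, 0.0, 2.0)
```

With the deliberately large initial step `h0=0.1`, the first few attempts are rejected before the step size settles.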

For more details, see [1], [2], and [3].

1

https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods

2

Press, W.H., Flannery, B.P., Teukolsky, S.A. and Vetterling, W.T., 1989. Numerical recipes in Pascal: the art of scientific computing (Vol. 1). Cambridge University Press.

3

Press, W. H., & Teukolsky, S. A. (1992). Adaptive Stepsize Runge‐Kutta Integration. Computers in Physics, 6(2), 188-191.

AdaptiveRKIntegrator(f[, var_type, dt, ...])

Adaptive Runge-Kutta method for ordinary differential equations.

RKF12(f[, var_type, dt, name, adaptive, ...])

The Fehlberg RK1(2) method for ODEs.

RKF45(f[, var_type, dt, name, adaptive, ...])

The Runge–Kutta–Fehlberg method for ODEs.

DormandPrince(f[, var_type, dt, name, ...])

The Dormand–Prince method for ODEs.

CashKarp(f[, var_type, dt, name, adaptive, ...])

The Cash–Karp method for ODEs.

BogackiShampine(f[, var_type, dt, name, ...])

The Bogacki–Shampine method for ODEs.

HeunEuler(f[, var_type, dt, name, adaptive, ...])

The Heun–Euler method for ODEs.

class brainpy.integrators.ode.adaptive_rk.AdaptiveRKIntegrator(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

Adaptive Runge-Kutta method for ordinary differential equations.

The embedded methods are designed to produce an estimate of the local truncation error of a single Runge-Kutta step and, as a result, allow the error to be controlled with an adaptive step size. This is done by having two methods in the tableau, one with order \(p\) and one with order \(p-1\).

The lower-order step is given by

\[y^*_{n+1} = y_n + h\sum_{i=1}^s b^*_i k_i,\]

where the \(k_{i}\) are the same as for the higher order method. Then the error is

\[e_{n+1} = y_{n+1} - y^*_{n+1} = h\sum_{i=1}^s (b_i - b^*_i) k_i,\]

which is \(O(h^{p})\). The Butcher Tableau for this kind of method is extended to give the values of \(b_{i}^{*}\)

\[\begin{split}\begin{array}{c|cccc} c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ \vdots & \vdots & \vdots& \ddots& \vdots\\ c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ \hline & b_1 & b_2 & \dots & b_s\\ & b_1^* & b_2^* & \dots & b_s^*\\ \end{array}\end{split}\]
Parameters
  • f (callable) – The derivative function.

  • show_code (bool) – Whether to show the formatted code.

  • dt (float) – The numerical integration step size.

  • adaptive (bool) – Whether to use adaptive step-size updating.

  • tol (float) – The error tolerance.

  • var_type (str) – The variable type.

class brainpy.integrators.ode.adaptive_rk.RKF12(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Fehlberg RK1(2) method for ODEs.

The Fehlberg method has two methods of orders 1 and 2.

It has the characteristics of:

  • method stage = 2

  • method order = 1

  • Butcher Tables:

\[\begin{split}\begin{array}{l|lll} 0 & & \\ 1 / 2 & 1 / 2 & \\ 1 & 1 / 256 & 255 / 256 & \\ \hline & 1 / 512 & 255 / 256 & 1 / 512 \\ & 1 / 256 & 255 / 256 & 0 \end{array}\end{split}\]

References

1

Fehlberg, E. (1969-07-01). “Low-order classical Runge-Kutta formulas with stepsize control and their application to some heat transfer problems”

class brainpy.integrators.ode.adaptive_rk.RKF45(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Runge–Kutta–Fehlberg method for ODEs.

The method presented in Fehlberg’s 1969 paper has been dubbed the RKF45 method, and is a method of order \(O(h^4)\) with an error estimator of order \(O(h^5)\). The novelty of Fehlberg’s method is that it is an embedded method from the Runge–Kutta family, meaning that identical function evaluations are used in conjunction with each other to create methods of varying order and similar error constants.

Its Butcher table is:

\[\begin{split}\begin{array}{l|llllll} 0 & & & & & & \\ 1 / 4 & 1 / 4 & & & & \\ 3 / 8 & 3 / 32 & 9 / 32 & & \\ 12 / 13 & 1932 / 2197 & -7200 / 2197 & 7296 / 2197 & \\ 1 & 439 / 216 & -8 & 3680 / 513 & -845 / 4104 & & \\ 1 / 2 & -8 / 27 & 2 & -3544 / 2565 & 1859 / 4104 & -11 / 40 & \\ \hline & 16 / 135 & 0 & 6656 / 12825 & 28561 / 56430 & -9 / 50 & 2 / 55 \\ & 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0 \end{array}\end{split}\]
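One way to sanity-check a table like the one above is to verify the usual consistency conditions in exact arithmetic: each row of \(a_{ij}\) should sum to the corresponding \(c_i\), and each weight row should sum to 1. The sketch below does this with Python's `fractions` module; the variable names are assumptions for this example.

```python
from fractions import Fraction as F

# Coefficients copied from the RKF45 table above.
c = [F(0), F(1, 4), F(3, 8), F(12, 13), F(1), F(1, 2)]
A = [
    [],
    [F(1, 4)],
    [F(3, 32), F(9, 32)],
    [F(1932, 2197), F(-7200, 2197), F(7296, 2197)],
    [F(439, 216), F(-8), F(3680, 513), F(-845, 4104)],
    [F(-8, 27), F(2), F(-3544, 2565), F(1859, 4104), F(-11, 40)],
]
b_first = [F(16, 135), F(0), F(6656, 12825), F(28561, 56430), F(-9, 50), F(2, 55)]
b_second = [F(25, 216), F(0), F(1408, 2565), F(2197, 4104), F(-1, 5), F(0)]

# Row-sum condition: sum_j a_ij == c_i for every stage.
row_sums_match = all(sum(row, F(0)) == c_i for row, c_i in zip(A, c))

# Both weight rows must sum to one for the methods to be consistent.
weights_sum_to_one = sum(b_first) == 1 and sum(b_second) == 1

# Weights of the embedded error estimate: b_i - b*_i.
error_weights = [bf - bs for bf, bs in zip(b_first, b_second)]
```

Because the two weight rows each sum to 1, the error weights sum to 0, so a constant slope produces a zero error estimate.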

References

1

https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method

2

Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step size control and their application to some heat transfer problems . NASA Technical Report 315. https://ntrs.nasa.gov/api/citations/19690021375/downloads/19690021375.pdf

class brainpy.integrators.ode.adaptive_rk.DormandPrince(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Dormand–Prince method for ODEs.

The DOPRI method is an explicit method for solving ordinary differential equations (Dormand & Prince 1980). The Dormand–Prince method has seven stages, but it uses only six function evaluations per step because it has the FSAL (First Same As Last) property: the last stage is evaluated at the same point as the first stage of the next step. Dormand and Prince chose the coefficients of their method to minimize the error of the fifth-order solution. This is the main difference from the Fehlberg method, which was constructed so that the fourth-order solution has a small error. For this reason, the Dormand–Prince method is more suitable when the higher-order solution is used to continue the integration, a practice known as local extrapolation (Shampine 1986; Hairer, Nørsett & Wanner 2008, pp. 178–179).

Its Butcher table is:

\[\begin{split}\begin{array}{l|lllllll} 0 & \\ 1 / 5 & 1 / 5 & & & \\ 3 / 10 & 3 / 40 & 9 / 40 & & & \\ 4 / 5 & 44 / 45 & -56 / 15 & 32 / 9 & & \\ 8 / 9 & 19372 / 6561 & -25360 / 2187 & 64448 / 6561 & -212 / 729 & \\ 1 & 9017 / 3168 & -355 / 33 & 46732 / 5247 & 49 / 176 & -5103 / 18656 & \\ 1 & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & \\ \hline & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & 0 \\ & 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40 \end{array}\end{split}\]

References

1

https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method

2

Dormand, J. R.; Prince, P. J. (1980), “A family of embedded Runge-Kutta formulae”, Journal of Computational and Applied Mathematics, 6 (1): 19–26, doi:10.1016/0771-050X(80)90013-3.

class brainpy.integrators.ode.adaptive_rk.CashKarp(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Cash–Karp method for ODEs.

The Cash–Karp method was proposed by Professor Jeff R. Cash from Imperial College London and Alan H. Karp from IBM Scientific Center. It uses six function evaluations to calculate fourth- and fifth-order accurate solutions. The difference between these solutions is then taken to be the error of the (fourth-order) solution. This error estimate is very convenient for adaptive step-size integration algorithms.

It has the characteristics of:

  • method stage = 6

  • method order = 4

  • Butcher Tables:

\[\begin{split}\begin{array}{l|llllll} 0 & & & & & & \\ 1 / 5 & 1 / 5 & & & & & \\ 3 / 10 & 3 / 40 & 9 / 40 & & & \\ 3 / 5 & 3 / 10 & -9 / 10 & 6 / 5 & & \\ 1 & -11 / 54 & 5 / 2 & -70 / 27 & 35 / 27 & & \\ 7 / 8 & 1631 / 55296 & 175 / 512 & 575 / 13824 & 44275 / 110592 & 253 / 4096 & \\ \hline & 37 / 378 & 0 & 250 / 621 & 125 / 594 & 0 & 512 / 1771 \\ & 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4 \end{array}\end{split}\]

References

1

https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method

2

J. R. Cash, A. H. Karp. “A variable order Runge-Kutta method for initial value problems with rapidly varying right-hand sides”, ACM Transactions on Mathematical Software 16: 201-222, 1990. doi:10.1145/79505.79507

class brainpy.integrators.ode.adaptive_rk.BogackiShampine(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Bogacki–Shampine method for ODEs.

The Bogacki–Shampine method was proposed by Przemysław Bogacki and Lawrence F. Shampine in 1989 (Bogacki & Shampine 1989). The Bogacki–Shampine method is a Runge–Kutta method of order three with four stages with the First Same As Last (FSAL) property, so that it uses approximately three function evaluations per step. It has an embedded second-order method which can be used to implement adaptive step size.

It has the characteristics of:

  • method stage = 4

  • method order = 3

  • Butcher Tables:

\[\begin{split}\begin{array}{l|llll} 0 & & & & \\ 1 / 2 & 1 / 2 & & & \\ 3 / 4 & 0 & 3 / 4 & & \\ 1 & 2 / 9 & 1 / 3 & 4 / 9 & \\ \hline & 2 / 9 & 1 / 3 & 4 / 9 & 0 \\ & 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8 \end{array}\end{split}\]
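The FSAL property can be made concrete by counting right-hand-side evaluations. In the sketch below (names and test problem are assumptions, and the weights are the standard Bogacki–Shampine values \(2/9, 1/3, 4/9, 0\) and \(7/24, 1/4, 1/3, 1/8\)), the fourth slope of one step is passed on as the first slope of the next, so only the first step costs four evaluations.

```python
def bogacki_shampine_step(f, t, y, h, k1=None):
    """One step; returns (y_next, error_estimate, k4).

    Pass the k4 of the previous step as k1 to reuse it (FSAL).
    """
    if k1 is None:
        k1 = f(t, y)
    k2 = f(t + h / 2, y + h * k1 / 2)
    k3 = f(t + 3 * h / 4, y + 3 * h * k2 / 4)
    y_next = y + h * (2 * k1 / 9 + k2 / 3 + 4 * k3 / 9)       # order 3
    k4 = f(t + h, y_next)                                      # slope at the new point
    y_low = y + h * (7 * k1 / 24 + k2 / 4 + k3 / 3 + k4 / 8)   # order 2
    return y_next, abs(y_next - y_low), k4


counter = {"calls": 0}


def rhs(t, y):
    counter["calls"] += 1  # count right-hand-side evaluations
    return -y


t, y, h = 0.0, 1.0, 0.1
y, err1, k_last = bogacki_shampine_step(rhs, t, y, h)              # 4 evaluations
y, err2, k_last = bogacki_shampine_step(rhs, t + h, y, h, k_last)  # 3 evaluations
```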

References

1

https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method

2

Bogacki, Przemysław; Shampine, Lawrence F. (1989), “A 3(2) pair of Runge–Kutta formulas”, Applied Mathematics Letters, 2 (4): 321–325, doi:10.1016/0893-9659(89)90079-7

class brainpy.integrators.ode.adaptive_rk.HeunEuler(f, var_type=None, dt=None, name=None, adaptive=None, tol=None, show_code=False)[source]

The Heun–Euler method for ODEs.

The simplest adaptive Runge–Kutta method involves combining Heun’s method, which is order 2, with the Euler method, which is order 1.

It has the characteristics of:

  • method stage = 2

  • method order = 1

  • Butcher Tables:

\[\begin{split}\begin{array}{c|cc} 0&\\ 1& 1 \\ \hline & 1/2& 1/2\\ & 1 & 0 \end{array}\end{split}\]

Exponential Integrators

This module provides exponential integrators for ODEs.

Exponential integrators are a large class of methods from numerical analysis that are based on the exact integration of the linear part of the initial value problem. Because the linear part is integrated exactly, this can help to mitigate the stiffness of a differential equation.

We consider initial value problems of the form,

\[u'(t)=f(u(t)),\qquad u(t_{0})=u_{0},\]

which can be decomposed into

\[u'(t)=Lu(t)+N(u(t)),\qquad u(t_{0})=u_{0},\]

where \(L={\frac {\partial f}{\partial u}}\) (the Jacobian of f) is composed of linear terms, and \(N=f(u)-Lu\) is composed of the non-linear terms.

This procedure enjoys the advantage, in each step, that \({\frac {\partial N_{n}}{\partial u}}(u_{n})=0\). This considerably simplifies the derivation of the order conditions and improves the stability when integrating the nonlinearity \(N(u(t))\).

Exact integration of this problem from time 0 to a later time \(t\) can be performed using matrix exponentials to define an integral equation for the exact solution:

\[u(t)=e^{Lt}u_{0}+\int _{0}^{t}e^{L(t-\tau )}N\left(u\left(\tau \right)\right)\,d\tau .\]

This representation of the exact solution is also called the variation-of-constants formula. In the case of \(N\equiv 0\), this formulation is the exact solution to the linear differential equation.

Exponential Rosenbrock methods

Exponential Rosenbrock methods were shown to be very efficient in solving large systems of stiff ODEs. Applying the variation-of-constants formula gives the exact solution at time \(t_{n+1}\) with the numerical solution \(u_n\) as

\[u(t_{n+1})=e^{h_{n}L}u(t_{n})+\int _{0}^{h_{n}}e^{(h_{n}-\tau )L}N(u(t_{n}+\tau ))\,d\tau , \tag{1}\]

where \(h_n=t_{n+1}-t_n\).

The idea now is to approximate the integral in (1) by some quadrature rule with nodes \(c_{i}\) and weights \(b_{i}(h_{n}L)\) (\(1\leq i\leq s\)). This yields the following class of s-stage explicit exponential Rosenbrock methods:

\[\begin{split}\begin{align} U_{ni}=&e^{c_{i}h_{n}L}u_n+h_{n}\sum_{j=1}^{i-1}a_{ij}(h_{n}L)N(U_{nj}), \\ u_{n+1}=&e^{h_{n}L}u_n+h_{n}\sum_{i=1}^{s}b_{i}(h_{n}L)N(U_{ni}) \end{align}\end{split}\]

where \(U_{ni}\approx u(t_{n}+c_{i}h_{n})\).

The coefficients \(a_{ij}(z),b_{i}(z)\) are usually chosen as linear combinations of the entire functions \(\varphi _{k}(c_{i}z),\varphi _{k}(z)\), respectively, where

\[\begin{split}\begin{align} \varphi _{k}(z)=&\int _{0}^{1}e^{(1-\theta )z}{\frac {\theta ^{k-1}}{(k-1)!}}d\theta ,\quad k\geq 1, \\ \varphi _{0}(z)=&e^{z},\\ \varphi _{k+1}(z)=&{\frac {\varphi_{k}(z)-\varphi _{k}(0)}{z}},\ k\geq 0. \end{align}\end{split}\]
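The recurrence above gives a direct way to evaluate \(\varphi_k\), using \(\varphi_k(0) = 1/k!\). The following sketch (function names are assumptions for this example) applies the recurrence for scalar \(z\) and cross-checks it against a midpoint-rule approximation of the integral definition. Note that the recurrence suffers from cancellation for very small \(|z|\); a production implementation would likely switch to a series expansion there.

```python
import math


def phi(k, z):
    """phi_k(z) from phi_0(z) = exp(z) and the recurrence above."""
    if z == 0:
        return 1.0 / math.factorial(k)  # phi_k(0) = 1/k!
    value = math.exp(z)                 # phi_0(z)
    for j in range(k):
        # phi_{j+1}(z) = (phi_j(z) - phi_j(0)) / z
        value = (value - 1.0 / math.factorial(j)) / z
    return value


def phi_by_quadrature(k, z, n=2000):
    """Midpoint-rule approximation of the integral definition (k >= 1)."""
    total = 0.0
    for m in range(n):
        theta = (m + 0.5) / n
        total += (math.exp((1 - theta) * z)
                  * theta ** (k - 1) / math.factorial(k - 1))
    return total / n


phi1_at_1 = phi(1, 1.0)          # equals e - 1
phi2_at_1 = phi(2, 1.0)          # equals e - 2
phi2_recurrence = phi(2, -0.5)
phi2_integral = phi_by_quadrature(2, -0.5)
```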

By introducing the difference \(D_{ni}=N(U_{ni})-N(u_{n})\), they can be reformulated in a more efficient way for implementation as

\[\begin{split}\begin{align} U_{ni}=&u_{n}+c_{i}h_{n}\varphi _{1}(c_{i}h_{n}L)f(u_{n})+h_{n}\sum _{j=2}^{i-1}a_{ij}(h_{n}L)D_{nj}, \\ u_{n+1}=&u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}b_{i}(h_{n}L)D_{ni}. \end{align}\end{split}\]

where \(\varphi_{1}(z)=\frac{e^z-1}{z}\).

In order to implement this scheme with adaptive step size, one can consider, for the purpose of local error estimation, the following embedded methods

\[{\bar {u}}_{n+1}=u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}{\bar {b}}_{i}(h_{n}L)D_{ni},\]

which use the same stages \(U_{ni}\) but with weights \({\bar {b}}_{i}\).

For convenience, the coefficients of the explicit exponential Rosenbrock methods together with their embedded methods can be represented by using the so-called reduced Butcher tableau as follows:

\[\begin{split}\begin{array}{c|ccccc} c_{2} & & & & & \\ c_{3} & a_{32} & & & & \\ \vdots & \vdots & & \ddots & & \\ c_{s} & a_{s 2} & a_{s 3} & \cdots & a_{s, s-1} \\ \hline & b_{2} & b_{3} & \cdots & b_{s-1} & b_{s} \\ & \bar{b}_{2} & \bar{b}_{3} & \cdots & \bar{b}_{s-1} & \bar{b}_{s} \end{array}\end{split}\]
1

https://en.wikipedia.org/wiki/Exponential_integrator

2

Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.

ExponentialEuler(f[, var_type, dt, name, ...])

The exponential Euler method for ODEs.

class brainpy.integrators.ode.exponential.ExponentialEuler(f, var_type=None, dt=None, name=None, show_code=False)[source]

The exponential Euler method for ODEs.

The simplest exponential Rosenbrock method is the exponential Rosenbrock–Euler scheme, which has order 2.

For an ODE equation of the form

\[u^{\prime}=f(u), \quad u(0)=u_{0}\]

its scheme is given by

\[u_{n+1}= u_{n}+h \varphi(hL) f (u_{n})\]

where \(L=f^{\prime}(u_{n})\) and \(\varphi(z)=\frac{e^{z}-1}{z}\).

For a linear ODE system \(u^{\prime} = Au + B\), the above equation becomes \(u_{n+1}= u_{n}e^{hA}-B/A(1-e^{hA})\), which is the exact solution for this ODE system.
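The exactness claim for linear systems can be checked numerically. The sketch below implements one exponential Euler step for a scalar ODE, with the derivative \(f'(u)\) supplied by hand as `df` (an assumption of this example; how BrainPy obtains the linear part is not shown here), and compares it with the closed-form expression above.

```python
import math


def exp_euler_step(f, df, u, h):
    """One exponential Euler step for a scalar ODE u' = f(u)."""
    L = df(u)  # linearization L = f'(u_n)
    # phi(z) = (exp(z) - 1) / z, with the limit phi(0) = 1.
    phi = math.expm1(h * L) / (h * L) if L != 0 else 1.0
    return u + h * phi * f(u)


# Illustrative linear problem u' = A*u + B (values chosen arbitrarily).
A, B = -3.0, 1.5
u0, h = 2.0, 0.5

u1 = exp_euler_step(lambda u: A * u + B, lambda u: A, u0, h)
u1_closed_form = u0 * math.exp(h * A) - B / A * (1 - math.exp(h * A))
```

Even with the fairly large step `h = 0.5`, the two values agree to rounding error, which is what makes the scheme attractive for stiff, nearly linear dynamics.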

Parameters
  • f (function) – The derivative function.

  • dt (optional, float) – The numerical integration step size.

  • var_type (optional, str) – The variable type.

  • show_code (bool) – Whether to show the code.

Error Analysis of Numerical Methods

In order to identify the essential properties of numerical methods, we define some basic notions [1].

For the given ODE system

\[\frac{dy}{dt}=f(t,y),\quad y(t_{0})=y_{0},\]

we define \(y(t_n)\) as the solution of the IVP evaluated at \(t=t_n\), and \(y_n\) as a numerical approximation of \(y(t_n)\) at the same location, produced by a generic numerical scheme (whether explicit, implicit, or multi-step):

\[\begin{align} y_{n+1} = y_n + h \phi(t_n,y_n,h), \tag{2} \end{align}\]

where \(h\) is the discretization step for \(t\), i.e., \(h=t_{n+1}-t_n\), and \(\phi(t_n,y_n,h)\) is the increment function. We say that the defined numerical scheme is consistent if \(\lim_{h\to0} \phi(t,y,h) = \phi(t,y,0) = f(t,y)\).

Then, the approximation error is defined as

\[e_n = y(t_n) - y_n.\]

The absolute error is defined as

\[|e_n| = |y(t_n) - y_n|.\]

The relative error is defined as

\[r_n =\frac{|y(t_n) - y_n|}{|y(t_n)|}.\]

The exact differential operator is defined as

\[\begin{align} L_e(y) = y' - f(t,y) = 0 \end{align}\]

The approximate differential operator is defined as

\[\begin{align} L_a(y_n) = y(t_{n+1}) - [y_n + h \phi(t_n,y_n,h)]. \end{align}\]

Finally, the local truncation error (LTE) is defined as

\[\begin{align} \tau_n = \frac{1}{h} L_a(y(t_n)). \end{align}\]

In practice, the evaluation of the exact solution for different \(t\) around \(t_n\) (required by \(L_a\)) is performed using a Taylor series expansion.

Finally, we can state that a scheme is \(p\)-th order accurate by examining its LTE and observing its leading term

\[\begin{align} \tau_n = C h^p + H.O.T., \end{align}\]

where \(C\) is a constant, independent of \(h\), and \(H.O.T.\) are the higher order terms of the LTE.

Example: LTE for Euler’s scheme

Consider the IVP defined by \(y' = \lambda y\), with initial condition \(y(0)=1\).

The approximation operator for Euler’s scheme is

\[\begin{align} L^{euler}_a = y(t_{n+1}) - [y_n + h \lambda y_n], \end{align}\]

then the LTE can be computed by

\[\begin{split}\begin{align} \tau_n = & \frac{1}{h}\left\{ L_a(y(t_n))\right\} = \frac{1}{h}\left\{ y(t_{n+1}) - [y(t_n) + h \lambda y(t_n)]\right\}, \\ = & \frac{1}{h}\left\{ y(t_n) + h y'(t_n) + \frac{h^2}{2} y''(t_n) + \ldots + \frac{1}{p!} h^p y^{(p)}(t_n) - y(t_n) - h \lambda y(t_n) \right\} \\ = & \frac{1}{2} h y''(t_n) + \ldots + \frac{1}{p!} h^{p-1} y^{(p)}(t_n) \\ \approx & \frac{1}{2} h y''(t_n), \end{align}\end{split}\]

where we assume \(y_n = y(t_n)\).
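The leading term \(\frac{1}{2} h y''(t_n)\) can be confirmed numerically, since the exact solution \(y(t) = e^{\lambda t}\) is known. In the sketch below, the values of \(\lambda\), \(t_n\), and \(h\) are arbitrary choices for illustration.

```python
import math

lam, t_n = -2.0, 0.5  # illustrative values, not taken from the text


def exact(t):
    """Exact solution of y' = lam * y with y(0) = 1."""
    return math.exp(lam * t)


def euler_lte(h):
    """tau_n = (1/h) * (y(t_{n+1}) - [y(t_n) + h * lam * y(t_n)])."""
    y_n = exact(t_n)
    return (exact(t_n + h) - (y_n + h * lam * y_n)) / h


def leading_term(h):
    """(1/2) * h * y''(t_n), using y'' = lam**2 * y."""
    return 0.5 * h * lam ** 2 * exact(t_n)


tau_h = euler_lte(0.01)
tau_half_h = euler_lte(0.005)
lte_ratio = tau_h / tau_half_h  # close to 2 for a first-order scheme
relative_gap = abs(tau_h - leading_term(0.01)) / abs(tau_h)
```

Halving \(h\) roughly halves the LTE, consistent with \(p = 1\) for Euler's scheme.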

1

https://folk.ntnu.no/leifh/teaching/tkt4140/._main022.html

Numerical Methods for SDEs

Numerical methods for stochastic differential equations.

Euler(f, g[, dt, name, show_code, var_type, ...])

Heun(f, g[, dt, name, show_code, var_type, ...])

Milstein(f, g[, dt, name, show_code, ...])

ExponentialEuler(f, g[, dt, name, ...])

First order, explicit exponential Euler method.

SRK1W1(f, g[, dt, name, show_code, ...])

Order 2.0 weak SRK methods for SDEs with scalar Wiener process.

SRK2W1(f, g[, dt, name, show_code, ...])

Order 1.5 Strong SRK Methods for SDEs with Scalar Noise.

KlPl(f, g[, dt, name, show_code, var_type, ...])

class brainpy.integrators.sde.Euler(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
class brainpy.integrators.sde.Heun(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
class brainpy.integrators.sde.Milstein(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]
class brainpy.integrators.sde.ExponentialEuler(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]

First order, explicit exponential Euler method.

For an SDE of the form

\[d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \quad y(0)=y_{0}\]

its scheme is given by [1]

\[\begin{split}y_{n+1} & =e^{\Delta t A}(y_{n}+ g(y_n)\Delta W_{n})+\varphi(\Delta t A) F(y_{n}) \Delta t \\ &= y_n + \Delta t \varphi(\Delta t A) f(y_n) + e^{\Delta t A}g(y_n)\Delta W_{n}\end{split}\]

where \(\varphi(z)=\frac{e^{z}-1}{z}\).
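A scalar sketch of this scheme is given below, using the second form of the update. The function names, the parameter values, and the use of Python's `random` module for the Wiener increments \(\Delta W_n \sim \mathcal{N}(0, \Delta t)\) are assumptions for illustration. With \(g \equiv 0\) and constant \(F\), the scheme should reproduce the exact solution of the linear ODE.

```python
import math
import random


def sde_exp_euler_step(y, dt, A, F, g, dW):
    """One step of the scheme above for scalar y, with f(y) = A*y + F(y)."""
    # phi(z) = (exp(z) - 1) / z, with the limit phi(0) = 1.
    phi = math.expm1(dt * A) / (dt * A) if A != 0 else 1.0
    drift = dt * phi * (A * y + F(y))
    diffusion = math.exp(dt * A) * g(y) * dW
    return y + drift + diffusion


def simulate(y0, dt, n_steps, A, F, g, seed=0):
    """Simulate one sample path with a seeded random generator."""
    rng = random.Random(seed)
    y = y0
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(dt))  # Wiener increment ~ N(0, dt)
        y = sde_exp_euler_step(y, dt, A, F, g, dW)
    return y


# Zero noise: dy = (-y + 0.5) dt, whose exact solution at t = 1 is known.
y_no_noise = simulate(1.0, 0.1, 10, A=-1.0, F=lambda y: 0.5, g=lambda y: 0.0)
y_exact = 0.5 + 0.5 * math.exp(-1.0)

# Same drift with a constant diffusion term.
y_noisy = simulate(1.0, 0.1, 10, A=-1.0, F=lambda y: 0.5, g=lambda y: 0.2)
```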

References

1

Erdoğan, Utku, and Gabriel J. Lord. “A new class of exponential integrators for stochastic differential equations with multiplicative noise.” arXiv preprint arXiv:1608.07096 (2016).

class brainpy.integrators.sde.SRK1W1(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]

Order 2.0 weak SRK methods for SDEs with scalar Wiener process.

This method has strong orders \((p_d, p_s) = (2.0, 1.5)\).

The Butcher table is:

\[\begin{split}\begin{array}{l|llll|llll|llll} 0 &&&&& &&&& &&&& \\ 3/4 &3/4&&&& 3/2&&& &&&& \\ 0 &0&0&0&& 0&0&0&& &&&&\\ \hline 0 \\ 1/4 & 1/4&&& & 1/2&&&\\ 1 & 1&0&&& -1&0&\\ 1/4& 0&0&1/4&& -5&3&1/2\\ \hline & 1/3& 2/3& 0 & 0 & -1 & 4/3 & 2/3&0 & -1 &4/3 &-1/3 &0 \\ \hline & &&&& 2 &-4/3 & -2/3 & 0 & -2 & 5/3 & -2/3 & 1 \end{array}\end{split}\]

References

1

Rößler, Andreas. “Strong and weak approximation methods for stochastic differential equations—some recent developments.” Recent developments in applied probability and statistics. Physica-Verlag HD, 2010. 127-153.

2

Rößler, Andreas. “Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations.” SIAM Journal on Numerical Analysis 48.3 (2010): 922-952.

class brainpy.integrators.sde.SRK2W1(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]

Order 1.5 Strong SRK Methods for SDEs with Scalar Noise.

This method has strong orders \((p_d, p_s) = (3.0, 1.5)\).

The Butcher table is:

\[\begin{array}{c|cccc|cccc|cccc|} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 / 2 & 1 / 4 & 1 / 4 & 0 & 0 & 1 & 1 / 2 & 0 & 0 & & & & \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ \hline 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ 1 / 4 & 1 / 4 & 0 & 0 & 0 & -1 / 2 & 0 & 0 & 0 & & & & \\ 1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & & & & \\ 1 / 4 & 0 & 0 & 1 / 4 & 0 & 2 & -1 & 1 / 2 & 0 & & & & \\ \hline & 1 / 6 & 1 / 6 & 2 / 3 & 0 & -1 & 4 / 3 & 2 / 3 & 0 & -1 & -4 / 3 & 1 / 3 & 0 \\ \hline & & & & & 2 & -4 / 3 & -2 / 3 & 0 & -2 & 5 / 3 & -2 / 3 & 1 \end{array}\]

References

1

Rößler, Andreas. “Strong and weak approximation methods for stochastic differential equations—some recent developments.” Recent developments in applied probability and statistics. Physica-Verlag HD, 2010. 127-153.

2

Rößler, Andreas. “Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations.” SIAM Journal on Numerical Analysis 48.3 (2010): 922-952.

class brainpy.integrators.sde.KlPl(f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None)[source]

brainpy.simulation module

This module provides APIs for brain simulations.

Brain Objects

This module provides various interfaces to model brain objects. You can access them through brainpy.XXX or brainpy.brainobjects.XXX.

BrainArea(*ds_tuple[, steps, monitors, name])

DynamicalSystem([steps, monitors, name])

Base Dynamical System class.

Container(*ds_tuple[, steps, monitors, name])

Container object which is designed to add other instances of DynamicalSystem.

Delay([steps, name, monitors])

Base class to model delay variables.

ConstantDelay(size, delay[, dtype, dt])

Class used to model constant delay variables.

SpikeTimeInput(size, times, indices[, need_sort])

The input neuron group characterized by spikes emitting at given times.

PoissonInput(size, freqs[, seed])

Poisson Neuron Group.

Molecular(name, **kwargs)

Base class to model molecular objects.

Network(*ds_tuple[, monitors, name])

Base class to model network objects, an alias of Container.

NeuGroup(size[, name, steps])

Base class to model neuronal groups.

Channel(**kwargs)

Base class to model ion channels.

Soma(name, **kwargs)

Base class to model soma in multi-compartment neuron models.

Dendrite(name, **kwargs)

Base class to model dendrites.

TwoEndConn(pre, post[, conn, name, steps])

Base class to model two-end synaptic connections.

class brainpy.simulation.brainobjects.BrainArea(*ds_tuple, steps=None, monitors=None, name=None, **ds_dict)[source]
class brainpy.simulation.brainobjects.DynamicalSystem(steps=None, monitors=None, name=None)[source]

Base Dynamical System class.

Any object that has step functions is a dynamical system. That is to say, in BrainPy, the essence of a dynamical system is its “step functions”.

Parameters
  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

register_constant_delay(key, size, delay, dtype=None)[source]

Register a constant delay.

Parameters
  • key (str) – The delay name.

  • size (int, list of int, tuple of int) – The delay data size.

  • delay (int, float, ndarray) – The delay time, in the same unit as brainpy.math.get_dt().

  • dtype (optional) – The data type.

Returns

delay – An instance of ConstantDelay.

Return type

ConstantDelay

run(duration, dt=None, report=0.0, inputs=(), extra_func=None)[source]

The running function.

Parameters
  • inputs (list, tuple) –

    The inputs for this instance of DynamicalSystem. It should have the format [(target, value, [type, operation])], where target is the input target, value is the input value, type is the input type (such as “fix” or “iter”), and operation is the operation applied to the input (such as “+”, “-”, “*”, “/”, “=”).

    • target: should be a string. Can be specified by the absolute access or relative access.

    • value: should be a scalar, vector, matrix, iterable function or objects.

    • type: should be a string. “fix” means the input value is a constant. “iter” means the input value can be changed over time.

    • operation: should be a string, support +, -, *, /, =.

    • Also, if you want to specify multiple inputs, just give multiple (target, value, [type, operation]), for example [(target1, value1), (target2, value2)].

  • duration (float, int, tuple, list) – The running duration.

  • report (float) – The fraction of progress at which to report, in the range [0, 1]. If zero, the model will not report progress.

  • dt (float, optional) – The numerical integration step size.

  • extra_func (function, callable) – The extra function to run during each time step.

  • method (optional, str) – The method to run the model.

Returns

running_time – The total running time.

Return type

float

update(*args, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.brainobjects.Container(*ds_tuple, steps=None, monitors=None, name=None, **ds_dict)[source]

Container object which is designed to add other instances of DynamicalSystem.

Parameters
  • steps (tuple of function, tuple of str, dict of (str, function), optional) – The step functions.

  • monitors (tuple, list, Monitor, optional) – The monitor object.

  • name (str, optional) – The object name.

  • show_code (bool) – Whether to show the formatted code.

  • ds_dict (dict of (str, )) – The instance of DynamicalSystem with the format of “key=dynamic_system”.

update(_t, _dt)[source]

Step function of a network.

In this update function, the step functions in children systems are iteratively called.

class brainpy.simulation.brainobjects.Delay(steps=('update',), name=None, monitors=None)[source]

Base class to model delay variables.

Parameters
  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

update(_t, _dt, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.brainobjects.ConstantDelay(size, delay, dtype=None, dt=None, **kwargs)[source]

Class used to model constant delay variables.

This class automatically supports batching on the last axis. For example, if the delayed data has the shape (10, 100), where 100 is the batch size, this class handles the batched data automatically.

For example:

>>> import brainpy as bp
>>>
>>> bp.ConstantDelay(size=10, delay=10.)
>>> bp.ConstantDelay(size=100, delay=bp.math.random.random(100) * 4 + 10)
Parameters
  • size (int, list of int, tuple of int) – The delay data size.

  • delay (int, float, function, ndarray) – The delay time. With the unit of dt.

  • num_batch (optional, int) – The batch size.

  • steps (optional, tuple of str, tuple of function, dict of (str, function)) – The callable function, or a list of callable functions.

  • monitors (optional, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (optional, str) – The name of the dynamic system.
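One common way to realize a constant delay is a ring buffer whose read and write indices advance by one slot per time step. The following is a minimal NumPy sketch of that idea under stated assumptions, not BrainPy's implementation: the class name `RingDelay` and its `push`/`pull` methods are hypothetical, and a single scalar delay is assumed.

```python
import numpy as np


class RingDelay:
    """Hypothetical ring-buffer delay (illustration only, not the BrainPy class)."""

    def __init__(self, size, delay, dt=0.1):
        # Number of time steps spanned by the delay.
        self.num_step = int(round(delay / dt))
        # One extra slot so the newest and the oldest value never share a slot.
        self.data = np.zeros((self.num_step + 1, size))
        self.in_idx = self.num_step   # slot written at the current step
        self.out_idx = 0              # slot read at the current step

    def push(self, value):
        self.data[self.in_idx] = value

    def pull(self):
        return self.data[self.out_idx]

    def update(self):
        # Advance both indices by one step, wrapping around the buffer.
        n = self.num_step + 1
        self.in_idx = (self.in_idx + 1) % n
        self.out_idx = (self.out_idx + 1) % n


delay = RingDelay(size=2, delay=0.3, dt=0.1)
outputs = []
for step in range(6):
    delay.push(np.full(2, step + 1.0))   # value produced at this step
    outputs.append(delay.pull().copy())  # value produced three steps earlier
    delay.update()
```

In this sketch a value pushed at step t is returned by `pull` at step t + 3, and the buffer returns zeros until the first delayed value arrives.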

reset()[source]

Reset the variables.

update(_t, _dt, **kwargs)[source]

Update the delay index.

class brainpy.simulation.brainobjects.SpikeTimeInput(size, times, indices, need_sort=True, **kwargs)[source]

The input neuron group characterized by spikes emitting at given times.

>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
>>> SpikeTimeInput(2, times=[10, 20])
>>> # or
>>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
>>> SpikeTimeInput(2, times=[10, 20], indices=[0, 0])
>>> # or
>>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
>>> SpikeTimeInput(2, times=[10, 20, 30], indices=[0, 1, 0])
>>> # or
>>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
>>> # at 30 ms, neuron 1 fires.
>>> SpikeTimeInput(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
Parameters
  • size (int, tuple, list) – The neuron group geometry.

  • indices (int, list, tuple) – The neuron indices at each time point to emit spikes.

  • times (list, np.ndarray) – The time points which generate the spikes.

  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.
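To make the relationship between `times` and `indices` concrete, the sketch below expands paired event lists into a boolean spike raster. It is an illustration under assumptions, not the library's code: the helper name `spike_raster`, the `dt` and `duration` arguments, and the rule "an event fires at the first step whose time has reached it" are all hypothetical.

```python
import numpy as np


def spike_raster(size, times, indices, dt=1.0, duration=40.0):
    """Hypothetical helper: return a (num_step, size) boolean spike raster."""
    times = np.asarray(times, dtype=float)
    indices = np.asarray(indices, dtype=int)
    # Sort the events by time so they can be consumed in order.
    order = np.argsort(times, kind="stable")
    times, indices = times[order], indices[order]
    num_step = int(round(duration / dt))
    raster = np.zeros((num_step, size), dtype=bool)
    k = 0  # position of the next pending event
    for step in range(num_step):
        t = step * dt
        # Emit every pending event whose time has been reached.
        while k < len(times) and times[k] <= t:
            raster[step, indices[k]] = True
            k += 1
    return raster


# At 10 ms neuron 0 fires; at 20 ms neurons 0 and 1 fire; at 30 ms neuron 1 fires.
raster = spike_raster(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
```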

update(_t, _i, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.brainobjects.PoissonInput(size, freqs, seed=None, **kwargs)[source]

Poisson Neuron Group.

Parameters
  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

update(_t, _i, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.brainobjects.Molecular(name, **kwargs)[source]

Base class to model molecular objects.

Parameters
  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

class brainpy.simulation.brainobjects.Network(*ds_tuple, monitors=None, name=None, **ds_dict)[source]

Base class to model network objects, an alias of Container.

Network instantiates a network that holds neurons, synapses, and other brain objects.

Parameters
  • name (str, optional) – The network name.

  • monitors (optional, list of str, tuple of str) – The items to monitor.

  • ds_tuple – A list/tuple container of dynamical systems.

  • ds_dict – A dict container of dynamical systems.

class brainpy.simulation.brainobjects.NeuGroup(size, name=None, steps=('update',), **kwargs)[source]

Base class to model neuronal groups.

There are several essential attributes:

  • size: the geometry of the neuron group. For example, (10, ) denotes a line of neurons, (10, 10) denotes a neuron group aligned in a 2D space, (10, 15, 4) denotes a 3-dimensional neuron group.

  • num: the flattened number of neurons in the group. For example, size=(10, ) => num=10, size=(10, 10) => num=100, size=(10, 15, 4) => num=600.

  • shape: the variable shape with (num, num_batch).
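The relationship between `size` and `num` is simply the product of the geometry's dimensions. A small sketch, where the helper name `flattened_num` is hypothetical:

```python
import math


def flattened_num(size):
    """Hypothetical helper: number of neurons for a given group geometry."""
    if isinstance(size, int):
        size = (size,)
    return math.prod(size)


examples = {
    (10,): flattened_num((10,)),
    (10, 10): flattened_num((10, 10)),
    (10, 15, 4): flattened_num((10, 15, 4)),
}
```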

Parameters
  • size (int, tuple of int, list of int) – The neuron group geometry.

  • num_batch (optional, int) – The batch size.

  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The step functions: a callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (optional, str) – The name of the dynamic system.

update(_t, _dt)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.brainobjects.Channel(**kwargs)[source]

Base class to model ion channels.

Notes

The __init__() function in Channel is used to specify the parameters of the channel. The __call__() function is used to initialize the variables in this channel.

class brainpy.simulation.brainobjects.Soma(name, **kwargs)[source]

Base class to model soma in multi-compartment neuron models.

class brainpy.simulation.brainobjects.Dendrite(name, **kwargs)[source]

Base class to model dendrites.

class brainpy.simulation.brainobjects.TwoEndConn(pre, post, conn=None, name=None, steps=('update',), **kwargs)[source]

Base class to model two-end synaptic connections.

Parameters
  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The step functions.

  • pre (NeuGroup) – Pre-synaptic neuron group.

  • post (NeuGroup) – Post-synaptic neuron group.

  • conn (math.ndarray, dict of (str, math.ndarray), TwoEndConnector) – The connection method between pre- and post-synaptic groups.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

  • show_code (bool) – Whether to show the formatted code.

DNN Layers

This module provides various interfaces to model DNN layers. You can access them through brainpy.layers.XXX.

Module([steps, monitors, name])

Basic module class for DNN networks.

Sequential(*arg_ds[, monitors, name])

Basic sequential object to control data flow.

Activation(activation[, name])

Activation Layer.

Conv2D(num_input, num_output, kernel_size[, ...])

Apply a 2D convolution on a 4D-input batch of shape (N,C,H,W).

Dense(num_hidden, num_input[, w, b])

A fully connected layer implemented as the dot product of inputs and weights.

Dropout(prob[, seed])

A layer that stochastically ignores a subset of inputs each training step.

LinearReadout(num_hidden, num_input[, ...])

Neuron group to read out information linearly.

RNNCore(num_hidden, num_input, **kwargs)

VanillaRNN(num_hidden, num_input, num_batch)

Basic fully-connected RNN core.

GRU(num_hidden, num_input, num_batch[, wx, ...])

Gated Recurrent Unit.

LSTM(num_hidden, num_input, num_batch[, w, ...])

Long short-term memory (LSTM) RNN core.

class brainpy.simulation.layers.Module(steps=None, monitors=None, name=None)[source]

Basic module class for DNN networks.

target_backend = 'jax'

Specifies the target backend on which the model runs.

class brainpy.simulation.layers.Sequential(*arg_ds, monitors=None, name=None, **kwarg_ds)[source]

Basic sequential object to control data flow.

Parameters
  • arg_ds – The modules without name specifications.

  • name (str, optional) – The name of the sequential module.

  • kwarg_ds – The modules with name specifications.

update(*args, **kwargs)[source]

Functional call.

Parameters
  • args (list, tuple) – The *args arguments.

  • kwargs (dict) – The configuration arguments shared across modules. If the “__call__” function of a submodule accepts a “config” argument, this “config” parameter is passed to that function.

class brainpy.simulation.layers.Activation(activation, name=None, **setting)[source]

Activation Layer.

Parameters
  • activation (str) – The name of the activation function.

  • name (optional, str) – The name of the class.

  • setting (Any) – The settings for the activation function.

update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.Conv2D(num_input, num_output, kernel_size, strides=1, dilations=1, groups=1, padding='SAME', w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]

Apply a 2D convolution on a 4D-input batch of shape (N,C,H,W).

Parameters
  • num_input (int) – The number of channels of the input tensor.

  • num_output (int) – The number of channels of the output tensor.

  • kernel_size (int, tuple of int) – The size of the convolution kernel, either tuple (height, width) or single number if they’re the same.

  • strides (int, tuple of int) – The convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.

  • dilations (int, tuple of int) – The spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.

  • groups (int) – The number of groups into which the input and output channels are split. When groups > 1, the convolution is applied to each group separately. num_input and num_output must both be divisible by groups.

  • padding (int, str) – The padding of the input tensor, either “SAME”, “VALID” or numerical values (low, high).

  • w (Initializer, JaxArray, jax.numpy.ndarray) – The initializer for the convolution kernel (a function that takes in a HWIO shape and returns a 4D array).

  • b (Initializer, JaxArray, jax.numpy.ndarray, optional) – The bias initialization.

  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.
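How `kernel_size`, `strides`, `dilations` and `padding` interact is easiest to see from the resulting output size. The sketch below computes the spatial output shape under the usual “SAME”/“VALID” conventions of XLA-style convolutions; that BrainPy follows exactly these conventions is an assumption, and the helper name `conv2d_output_hw` is hypothetical.

```python
import math


def conv2d_output_hw(in_hw, kernel_size, strides=1, dilations=1, padding="SAME"):
    """Hypothetical helper: spatial output size (H, W) of a 2D convolution."""
    def as_pair(v):
        return (v, v) if isinstance(v, int) else tuple(v)

    kernel_size, strides, dilations = map(as_pair, (kernel_size, strides, dilations))
    out = []
    for n, k, s, d in zip(in_hw, kernel_size, strides, dilations):
        k_eff = (k - 1) * d + 1  # kernel extent once dilation gaps are included
        if padding == "SAME":
            out.append(math.ceil(n / s))       # input is padded so only the stride shrinks it
        elif padding == "VALID":
            out.append((n - k_eff) // s + 1)   # no padding: the kernel must fit entirely
        else:
            raise ValueError("this sketch only handles 'SAME' and 'VALID'")
    return tuple(out)


same_hw = conv2d_output_hw((32, 32), kernel_size=3, padding="SAME")
valid_hw = conv2d_output_hw((32, 32), kernel_size=3, padding="VALID")
```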

update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.Dense(num_hidden, num_input, w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]

A fully connected layer implemented as the dot product of inputs and weights.

Parameters
  • num_hidden (int) – The neuron group size.

  • num_input (int) – The input size.

  • w (Initializer, JaxArray, jax.numpy.ndarray) – Initializer for the weights.

  • b (Initializer, JaxArray, jax.numpy.ndarray, optional) – Initializer for the bias.

  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, datastructures.Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.

update(x)[source]

Returns the results of applying the linear transformation to input x.

class brainpy.simulation.layers.Dropout(prob, seed=None, **kwargs)[source]

A layer that stochastically ignores a subset of inputs each training step.

In training, to compensate for the fraction of input values dropped (rate = 1 - prob), all surviving values are multiplied by 1 / (1 - rate).

The parameter shared_axes specifies a list of axes on which the mask is shared: the dropout mask has size 1 on those axes and is broadcast. Sharing reduces randomness, but can save memory.

This layer is active only during training (mode='train'). In other circumstances it is a no-op.

Originally introduced in the paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting” available under the following link: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf

Parameters
  • prob (float) – Probability of keeping each element of the tensor.

  • steps (tuple of str, tuple of function, dict of (str, function), optional) – The callable function, or a list of callable functions.

  • monitors (None, list, tuple, Monitor) – Variables to monitor.

  • name (str, optional) – The name of the dynamic system.
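The rescaling rule described above can be sketched in a few lines of NumPy. This is an illustration under assumptions rather than the layer's actual code: the function name `dropout`, the explicit `rng` argument and the `train` flag are hypothetical, and `prob` is interpreted as the keep probability, as stated in the parameter list.

```python
import numpy as np


def dropout(x, prob, rng, train=True):
    """Hypothetical inverted-dropout sketch; `prob` is the keep probability."""
    if not train:
        return x  # outside training the layer is a no-op
    keep = rng.random(x.shape) < prob
    # Survivors are scaled by 1 / prob so the expected value stays unchanged.
    return np.where(keep, x / prob, 0.0)


rng = np.random.default_rng(0)
x = np.ones((1000, 100))
y = dropout(x, prob=0.8, rng=rng)
```

With `prob=0.8`, each surviving element becomes 1.25 and roughly 20% of the elements are zeroed, so the mean of `y` stays close to the mean of `x`.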

update(x, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.LinearReadout(num_hidden, num_input, num_batch=1, w_init=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b_init=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, has_bias=True, s_init=<brainpy.simulation.initialize.random_inits.Uniform object>, train_mask=None, **kwargs)[source]

Neuron group to read out information linearly.

Parameters
  • num_hidden (int) – The neuron group size.

  • num_input (int) – The input size.

  • w_init (Initializer) – Initializer for the weights.

  • b_init (Initializer) – Initializer for the bias.

  • has_bias (bool) – Whether to include a bias term in the computation.

  • s_init (Initializer) – Initializer for variable states.

  • train_mask (optional, math.ndarray) – The training mask for the weights.

update(x, **kwargs)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.RNNCore(num_hidden, num_input, **kwargs)[source]

abstract update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.VanillaRNN(num_hidden, num_input, num_batch, h=<brainpy.simulation.initialize.random_inits.Uniform object>, w=<brainpy.simulation.initialize.random_inits.XavierNormal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]

Basic fully-connected RNN core.

Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes

\[h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]

The output is equal to the new state, \(h_t\).

update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.GRU(num_hidden, num_input, num_batch, wx=<brainpy.simulation.initialize.random_inits.Orthogonal object>, wh=<brainpy.simulation.initialize.random_inits.Orthogonal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, h=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]

Gated Recurrent Unit.

The implementation is based on (Chung, et al., 2014) [1] with biases.

Given \(x_t\) and the previous state \(h_{t-1}\) the core computes

\[\begin{split}\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\end{split}\]

where \(z_t\) and \(r_t\) are the update and reset gates, respectively.

The output is equal to the new hidden state, \(h_t\).
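The four equations above translate directly into a single-step function. The NumPy sketch below is only an illustration of those equations: the function name `gru_step` and the layout of the weights as dictionaries keyed by `"z"`, `"r"` and `"a"` are assumptions and do not reflect how the BrainPy class stores its `wx`, `wh` and `b` parameters.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, W_i, W_h, b):
    """Hypothetical one-step GRU; W_i, W_h and b are dicts keyed by 'z', 'r', 'a'."""
    z = sigmoid(x @ W_i["z"] + h @ W_h["z"] + b["z"])        # update gate
    r = sigmoid(x @ W_i["r"] + h @ W_h["r"] + b["r"])        # reset gate
    a = np.tanh(x @ W_i["a"] + (r * h) @ W_h["a"] + b["a"])  # candidate state
    return (1.0 - z) * h + z * a                             # new hidden state


num_input, num_hidden, num_batch = 3, 4, 2
W_i = {k: np.zeros((num_input, num_hidden)) for k in "zra"}
W_h = {k: np.zeros((num_hidden, num_hidden)) for k in "zra"}
b = {k: np.zeros(num_hidden) for k in "zra"}
h_new = gru_step(np.ones((num_batch, num_input)), np.ones((num_batch, num_hidden)), W_i, W_h, b)
```

With all weights and biases set to zero, both gates equal 0.5 and the candidate state is 0, so the new state is half of the previous one.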

Warning: Backwards compatibility of GRU weights is currently unsupported.

References

[1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.

update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

class brainpy.simulation.layers.LSTM(num_hidden, num_input, num_batch, w=<brainpy.simulation.initialize.random_inits.Orthogonal object>, b=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, hc=<brainpy.simulation.initialize.regular_inits.ZeroInit object>, **kwargs)[source]

Long short-term memory (LSTM) RNN core.

The implementation is based on (Zaremba, et al., 2014) [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]

where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden state, \(h_t\).

Notes

Forget gate initialization: Following (Jozefowicz, et al., 2015) [2], we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.
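As an illustration of the equations and of the forget-gate note, the following NumPy sketch performs one LSTM step. It is not the library's implementation: the name `lstm_step`, the dictionary layout of the weights, and the `forget_bias` argument (which stands in for the 1.0 added to \(b_f\) at initialization) are assumptions made for this example.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h, c, W_i, W_h, b, forget_bias=1.0):
    """Hypothetical one-step LSTM; W_i, W_h and b are dicts keyed by 'i', 'f', 'g', 'o'."""
    i = sigmoid(x @ W_i["i"] + h @ W_h["i"] + b["i"])                # input gate
    f = sigmoid(x @ W_i["f"] + h @ W_h["f"] + b["f"] + forget_bias)  # forget gate
    g = np.tanh(x @ W_i["g"] + h @ W_h["g"] + b["g"])                # cell update
    o = sigmoid(x @ W_i["o"] + h @ W_h["o"] + b["o"])                # output gate
    c_new = f * c + i * g
    h_new = o * np.tanh(c_new)
    return h_new, c_new


num_input, num_hidden, num_batch = 3, 4, 2
W_i = {k: np.zeros((num_input, num_hidden)) for k in "ifgo"}
W_h = {k: np.zeros((num_hidden, num_hidden)) for k in "ifgo"}
b = {k: np.zeros(num_hidden) for k in "ifgo"}
x = np.ones((num_batch, num_input))
h0 = np.zeros((num_batch, num_hidden))
c0 = np.ones((num_batch, num_hidden))
h1, c1 = lstm_step(x, h0, c0, W_i, W_h, b, forget_bias=1.0)
```

With zero weights, the forget gate equals sigmoid(1.0), about 0.73, instead of 0.5, so more of the previous cell state is retained early in training.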

References

[1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. “Recurrent neural network regularization.” arXiv preprint arXiv:1409.2329 (2014).

[2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. “An empirical exploration of recurrent network architectures.” In International conference on machine learning, pp. 2342-2350. PMLR, 2015.

update(x)[source]

The function to specify the updating rule.

Parameters
  • _t (float) – The current time.

  • _dt (float) – The time step.

Synaptic Connectivity

This module provides methods to construct connectivity between neuron groups. You can access them through brainpy.connect.XXX.

Base Class

Connector()

Base synaptic connector class.

TwoEndConnector()

Synaptic connector to build synapse connections between two neuron groups.

OneEndConnector()

Synaptic connector to build synapse connections within a population of neurons.

class brainpy.simulation.connect.Connector[source]

Base synaptic connector class.

class brainpy.simulation.connect.TwoEndConnector[source]

Synaptic connector to build synapse connections between two neuron groups.

class brainpy.simulation.connect.OneEndConnector[source]

Synaptic connector to build synapse connections within a population of neurons.

Custom Connections

MatConn(conn_mat)

Connector built from the connection matrix.

IJConn(i, j)

Connector built from the pre_ids and post_ids connections.

class brainpy.simulation.connect.MatConn(conn_mat)[source]

Connector built from the connection matrix.

class brainpy.simulation.connect.IJConn(i, j)[source]

Connector built from the pre_ids and post_ids connections.

Random Connections

FixedProb(prob[, include_self, seed])

Connect the post-synaptic neurons with fixed probability.

FixedPreNum(num[, include_self, seed])

Connect a fixed number of pre-synaptic neurons to each post-synaptic neuron.

FixedPostNum(num[, include_self, seed])

Connect a fixed number of post-synaptic neurons to each pre-synaptic neuron.

GaussianProb(sigma[, encoding_values, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function.

SmallWorld(num_neighbor, prob[, directed, ...])

Build a Watts–Strogatz small-world graph.

ScaleFreeBA(m[, directed, seed])

Build a random graph according to the Barabási–Albert preferential attachment model.

ScaleFreeBADual(m1, m2, p[, directed, seed])

Build a random graph according to the dual Barabási–Albert preferential attachment model.

PowerLaw(m, p[, directed, seed])

Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering.

class brainpy.simulation.connect.FixedProb(prob, include_self=True, seed=None)[source]

Connect the post-synaptic neurons with fixed probability.

Parameters
  • prob (float) – The connection probability.

  • include_self (bool) – Whether to create (i, i) connections.

  • seed (optional, int) – Seed the random generator.
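Conceptually, fixed-probability connectivity draws an independent Bernoulli sample for every (pre, post) pair. The NumPy sketch below illustrates that idea with a dense boolean matrix; it is not the library's implementation, and the function name `fixed_prob_mat` is hypothetical.

```python
import numpy as np


def fixed_prob_mat(num_pre, num_post, prob, include_self=True, seed=None):
    """Hypothetical sketch: boolean connection matrix with i.i.d. connections."""
    rng = np.random.default_rng(seed)
    conn = rng.random((num_pre, num_post)) < prob
    if not include_self:
        np.fill_diagonal(conn, False)  # remove the (i, i) connections
    return conn


conn = fixed_prob_mat(200, 200, prob=0.1, include_self=False, seed=0)
```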

class brainpy.simulation.connect.FixedPreNum(num, include_self=True, seed=None)[source]

Connect a fixed number of pre-synaptic neurons to each post-synaptic neuron.

Parameters
  • num (float, int) – The connection probability (if “num” is a float) or the fixed number of connections (if “num” is an int).

  • include_self (bool) – Whether to create (i, i) connections.

  • seed (None, int) – The seed of the random generator.

  • method (str) –

    The method used to create the connection.

    • matrix: This method first creates a large matrix of shape \((N_{pre}, N_{post})\) and then constructs the connectivity from it. In a large network it consumes a lot of memory: the matrix \((N_{pre}, N_{post})\) plus two vectors, \(2 * N_{need} * N_{post}\).

    • iter: This method builds the synaptic connections iteratively. It has the smallest memory footprint, only \(2 * N_{need} * N_{post}\) (the i and j vectors).
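The iterative strategy can be sketched as follows: for each post-synaptic neuron, sample the required number of distinct pre-synaptic partners and append them to the i and j vectors. This is a hedged illustration, not BrainPy's code; the function name `fixed_pre_num` is hypothetical and only an integer `num` is handled.

```python
import numpy as np


def fixed_pre_num(num_pre, num_post, num, include_self=True, seed=None):
    """Hypothetical 'iter'-style sketch returning the (i, j) index vectors."""
    rng = np.random.default_rng(seed)
    pre_ids, post_ids = [], []
    for j in range(num_post):
        candidates = np.arange(num_pre)
        if not include_self:
            candidates = candidates[candidates != j]
        # Sample `num` distinct pre-synaptic partners for post-synaptic neuron j.
        pre_ids.append(rng.choice(candidates, size=num, replace=False))
        post_ids.append(np.full(num, j))
    return np.concatenate(pre_ids), np.concatenate(post_ids)


i, j = fixed_pre_num(num_pre=20, num_post=10, num=3, include_self=False, seed=0)
```

Each returned vector has length `num * num_post`, which matches the \(2 * N_{need} * N_{post}\) memory estimate given for the iter method.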

class brainpy.simulation.connect.FixedPostNum(num, include_self=True, seed=None)[source]

Connect a fixed number of post-synaptic neurons to each pre-synaptic neuron.

Parameters
  • num (float, int) – The connection probability (if “num” is a float) or the fixed number of connections (if “num” is an int).

  • include_self (bool) – Whether to create (i, i) connections.

  • seed (None, int) – The seed of the random generator.

  • method (str) –

    The method used to create the connection.

    • matrix: This method first creates a large matrix of shape \((N_{pre}, N_{post})\) and then constructs the connectivity from it. In a large network it consumes a lot of memory: the matrix \((N_{pre}, N_{post})\) plus two vectors, \(2 * N_{need} * N_{pre}\).

    • iter: This method builds the synaptic connections iteratively. It has the smallest memory footprint, only \(2 * N_{need} * N_{pre}\) (the i and j vectors).

class brainpy.simulation.connect.GaussianProb(sigma, encoding_values=None, normalize=True, include_self=True, periodic_boundary=False, seed=None)[source]

Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function.

Specifically, for any pair of neurons \((i, j)\),

\[p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2})\]

where \(v_k^i\) is the \(i\)-th neuron’s encoded value at dimension \(k\).

Parameters
  • sigma (float) – Width of the Gaussian function.

  • encoding_values (optional, list, tuple, int, float) –

    The value ranges to encode for neurons at each axis.

    • If encoding_values is not provided, each neuron encodes only its position, i.e., \((i, j, k, ...)\), where \(i, j, k\) are the indices in the high-dimensional space.

    • If encoding_values is a single tuple/list of int/float, neurons at each dimension encode the same range of values. For example, with encoding_values=(0, np.pi), neurons at each dimension encode the continuous value space [0, np.pi].

    • If encoding_values is a tuple/list of list/tuple, the value space is different for each dimension. For example, encoding_values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).

  • periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.

  • normalize (bool) – Whether to normalize the connection probability.

  • include_self (bool) – Whether to create connections at the same position.

  • seed (optional, int) – The random seed.
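For a one-dimensional population whose encoded values are the position indices, the probability formula reduces to a function of the pairwise distance. The sketch below illustrates this case only; the function name `gaussian_prob_1d` is hypothetical, the wrap-around distance used for the periodic boundary is an assumption, and the `normalize` option is not modelled.

```python
import numpy as np


def gaussian_prob_1d(num, sigma, periodic_boundary=False, include_self=True):
    """Hypothetical sketch: Gaussian connection probabilities on a line or ring."""
    v = np.arange(num, dtype=float)        # encoded value = position index
    d = np.abs(v[:, None] - v[None, :])    # pairwise distance |v_i - v_j|
    if periodic_boundary:
        d = np.minimum(d, num - d)         # shortest distance around the ring
    p = np.exp(-d ** 2 / (2 * sigma ** 2))
    if not include_self:
        np.fill_diagonal(p, 0.0)
    return p


p_line = gaussian_prob_1d(10, sigma=2.0)
p_ring = gaussian_prob_1d(10, sigma=2.0, periodic_boundary=True)
```

On the ring, neurons 0 and 9 are one step apart, so they connect with the same probability as neurons 0 and 1.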

class brainpy.simulation.connect.SmallWorld(num_neighbor, prob, directed=False, include_self=False)[source]

Build a Watts–Strogatz small-world graph.

Parameters
  • num_neighbor (int) – Each node is joined with its k nearest neighbors in a ring topology.

  • prob (float) – The probability of rewiring each edge.

  • directed (bool) – Whether the graph is a directed graph.

  • include_self (bool) – Whether to include self-connections.

Notes

First create a ring over \(num\_node\) nodes [1]. Then each node in the ring is joined to its \(num\_neighbor\) nearest neighbors (or \(num\_neighbor - 1\) neighbors if \(num\_neighbor\) is odd). Then shortcuts are created by replacing some edges as follows: for each edge \((u, v)\) in the underlying “\(num\_node\)-ring with \(num\_neighbor\) nearest neighbors” with probability \(prob\) replace it with a new edge \((u, w)\) with uniformly random choice of existing node \(w\).
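The construction described in these notes can be sketched in pure Python for the undirected case. This is an illustration of the Watts–Strogatz procedure under assumptions, not BrainPy's implementation; the function name `watts_strogatz_edges` is hypothetical.

```python
import random


def watts_strogatz_edges(num_node, num_neighbor, prob, seed=None):
    """Hypothetical sketch: undirected Watts-Strogatz graph as a set of edges."""
    rng = random.Random(seed)
    half = num_neighbor // 2
    # Ring lattice: join each node to its `half` nearest neighbors on each side.
    neighbors = {u: set() for u in range(num_node)}
    for u in range(num_node):
        for k in range(1, half + 1):
            v = (u + k) % num_node
            neighbors[u].add(v)
            neighbors[v].add(u)
    # Rewire each lattice edge (u, u + k) with probability `prob`.
    for k in range(1, half + 1):
        for u in range(num_node):
            v = (u + k) % num_node
            if rng.random() >= prob:
                continue
            if len(neighbors[u]) >= num_node - 1:
                continue  # u is already connected to every other node
            w = rng.randrange(num_node)
            while w == u or w in neighbors[u]:
                w = rng.randrange(num_node)  # avoid self-loops and duplicate edges
            neighbors[u].discard(v)
            neighbors[v].discard(u)
            neighbors[u].add(w)
            neighbors[w].add(u)
    return {(min(u, v), max(u, v)) for u in neighbors for v in neighbors[u]}


lattice = watts_strogatz_edges(20, 4, prob=0.0)
rewired = watts_strogatz_edges(20, 4, prob=0.3, seed=1)
```

Rewiring replaces one edge with another, so the number of edges stays at `num_node * (num_neighbor // 2)` regardless of `prob`.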

References

[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.

class brainpy.simulation.connect.ScaleFreeBA(m, directed=False, seed=None)[source]

Build a random graph according to the Barabási–Albert preferential attachment model.

A graph of \(num\_node\) nodes is grown by attaching new nodes each with \(m\) edges that are preferentially attached to existing nodes with high degree.

Parameters
  • m (int) – Number of edges to attach from a new node to existing nodes

  • seed (integer, random_state, or None (default)) – Indicator of random number generation state.

Raises

ValueError – If m does not satisfy 1 <= m < n.
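Preferential attachment can be sketched by keeping a list in which every node appears once per incident edge; sampling uniformly from that list is then sampling proportionally to degree. The pure-Python sketch below illustrates this idea and is not the library's code; the function name `barabasi_albert_edges` is hypothetical.

```python
import random


def barabasi_albert_edges(num_node, m, seed=None):
    """Hypothetical sketch: edge list of a Barabasi-Albert graph."""
    if not 1 <= m < num_node:
        raise ValueError("m must satisfy 1 <= m < num_node")
    rng = random.Random(seed)
    edges = []
    targets = list(range(m))  # the first new node attaches to the m initial nodes
    repeated = []             # each node appears once per incident edge
    for source in range(m, num_node):
        edges.extend((source, t) for t in targets)
        repeated.extend(targets)
        repeated.extend([source] * m)
        # Uniform sampling from `repeated` favors nodes with high degree.
        chosen = set()
        while len(chosen) < m:
            chosen.add(rng.choice(repeated))
        targets = list(chosen)
    return edges


edges = barabasi_albert_edges(num_node=50, m=3, seed=0)
```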

References

[1] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.

class brainpy.simulation.connect.ScaleFreeBADual(m1, m2, p, directed=False, seed=None)[source]

Build a random graph according to the dual Barabási–Albert preferential attachment model.

A graph of \(num\_node\) nodes is grown by attaching new nodes each with either \(m_1\) edges (with probability \(p\)) or \(m_2\) edges (with probability \(1-p\)) that are preferentially attached to existing nodes with high degree.

Parameters
  • m1 (int) – Number of edges to attach from a new node to existing nodes with probability \(p\)

  • m2 (int) – Number of edges to attach from a new node to existing nodes with probability \(1-p\)

  • p (float) – The probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges)

  • seed (integer, random_state, or None (default)) – Indicator of random number generation state.

Raises

ValueError – If m1 and m2 do not satisfy 1 <= m1,m2 < n or p does not satisfy 0 <= p <= 1.

References

[1] N. Moshiri “The dual-Barabasi-Albert model”, arXiv:1810.10538.

class brainpy.simulation.connect.PowerLaw(m, p, directed=False, seed=None)[source]

Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering.

Parameters
  • m (int) – The number of random edges to add for each new node.

  • p (float) – Probability of adding a triangle after adding a random edge.

  • seed (integer, random_state, or None (default)) – Indicator of random number generation state.

Notes

The average clustering has a hard time getting above a certain cutoff that depends on \(m\). This cutoff is often quite low. The transitivity (fraction of triangles to possible triangles) seems to decrease with network size.

It is essentially the Barabási–Albert (BA) growth model with an extra step that each random edge is followed by a chance of making an edge to one of its neighbors too (and thus a triangle).

This algorithm improves on BA in the sense that it enables a higher average clustering to be attained if desired.

It seems possible to have a disconnected graph with this algorithm since the initial \(m\) nodes may not be all linked to a new node on the first iteration like the BA model.

Raises

ValueError – If \(m\) does not satisfy \(1 <= m <= n\) or \(p\) does not satisfy \(0 <= p <= 1\).

References

[1] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.

Regular Connections

One2One()

Connect two neuron groups one by one.

All2All([include_self])

Connect each neuron in the pre-synaptic group to all neurons in the post-synaptic group.

GridFour([include_self])

The nearest four neighbors connection method.

GridEight([include_self])

The nearest eight neighbors connection method.

GridN([N, include_self])

The nearest (2*N+1) * (2*N+1) neighbors connection method.

class brainpy.simulation.connect.One2One[source]

Connect two neuron groups one by one. This means the two neuron groups should have the same size.

class brainpy.simulation.connect.All2All(include_self=True)[source]

Connect each neuron in the pre-synaptic group to all neurons in the post-synaptic group. This kind of connection creates (num_pre x num_post) synapses.

class brainpy.simulation.connect.GridFour(include_self=False)[source]

The nearest four neighbors connection method.

class brainpy.simulation.connect.GridEight(include_self=False)[source]

The nearest eight neighbors connection method.

class brainpy.simulation.connect.GridN(N=1, include_self=False)[source]

The nearest (2*N+1) * (2*N+1) neighbors connection method.

Parameters
  • N (int) –

    Extent of the connection scope. For example, when N=1,

    [x x x]
    [x I x]
    [x x x]

    When N=2,

    [x x x x x]
    [x x x x x]
    [x x I x x]
    [x x x x x]
    [x x x x x]

  • include_self (bool) – Whether to create (i, i) connections.
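The neighborhood pictured above can be enumerated with two nested offset loops. The sketch below is illustrative only: the function name `grid_n_pairs` is hypothetical, neurons are assumed to be numbered row by row, and neighbors falling outside the grid are assumed to be dropped rather than wrapped around.

```python
def grid_n_pairs(height, width, N=1, include_self=False):
    """Hypothetical sketch: (pre, post) index pairs of a GridN-style pattern."""
    pairs = []
    for r in range(height):
        for c in range(width):
            for dr in range(-N, N + 1):
                for dc in range(-N, N + 1):
                    if dr == 0 and dc == 0 and not include_self:
                        continue  # skip the (i, i) connection
                    rr, cc = r + dr, c + dc
                    # Keep only neighbors that lie inside the grid.
                    if 0 <= rr < height and 0 <= cc < width:
                        pairs.append((r * width + c, rr * width + cc))
    return pairs


pairs = grid_n_pairs(5, 5, N=1)
center_targets = [post for pre, post in pairs if pre == 12]
corner_targets = [post for pre, post in pairs if pre == 0]
```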

Formatter Functions

ij2mat(i, j[, num_pre, num_post])

Convert i-j connection to matrix connection.

mat2ij(conn_mat)

Get the i-j connections from connectivity matrix.

pre2post(i, j[, num_pre])

Get pre2post connections from i and j indexes.

post2pre(i, j[, num_post])

Get post2pre connections from i and j indexes.

pre2syn(i[, num_pre])

Get pre2syn connections from i and j indexes.

post2syn(j[, num_post])

Get post2syn connections from i and j indexes.

pre_slice(i, j[, num_pre])

Get post slicing connections by pre-synaptic ids.

post_slice(i, j[, num_post])

Get pre slicing connections by post-synaptic ids.
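The relationship between the i-j representation, the matrix representation, and the pre-to-post lookup can be illustrated with NumPy. The helpers below are hypothetical stand-ins named `ij_to_mat`, `mat_to_ij` and `pre_to_post`; they mirror the purpose of the functions listed above but are not their implementation, and the return formats are assumptions.

```python
import numpy as np


def ij_to_mat(i, j, num_pre, num_post):
    """Hypothetical: boolean matrix with True at every (i[k], j[k])."""
    mat = np.zeros((num_pre, num_post), dtype=bool)
    mat[i, j] = True
    return mat


def mat_to_ij(conn_mat):
    """Hypothetical: pre- and post-synaptic indices of the nonzero entries."""
    i, j = np.nonzero(conn_mat)
    return i, j


def pre_to_post(i, j, num_pre):
    """Hypothetical: list of post-synaptic ids for each pre-synaptic neuron."""
    post_of = [[] for _ in range(num_pre)]
    for pre, post in zip(i, j):
        post_of[int(pre)].append(int(post))
    return post_of


i = np.array([0, 0, 2])
j = np.array([1, 2, 0])
mat = ij_to_mat(i, j, num_pre=3, num_post=3)
i_back, j_back = mat_to_ij(mat)
lookup = pre_to_post(i, j, num_pre=3)
```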

Weight Initialization

This module provides methods to initialize weights. You can access them through brainpy.initialize.XXX.

Base Class

Initializer()

Base Initialization Class.

class brainpy.simulation.initialize.Initializer[source]

Base Initialization Class.

Regular Initializers

ZeroInit()

Zero initializer.

OneInit([value])

One initializer.

Identity([value])

Returns the identity matrix.

class brainpy.simulation.initialize.ZeroInit[source]

Zero initializer.

Initialize the weights with zeros.

class brainpy.simulation.initialize.OneInit(value=1.0)[source]

One initializer.

Initialize the weights with the given values.

Parameters

value (float, int, math.ndarray) – The value to specify.

class brainpy.simulation.initialize.Identity(value=1.0)[source]

Returns the identity matrix.

This initializer was proposed in (Le, et al., 2015) [1].

Parameters

value (float) – The optional scaling factor.

Returns

shape – The weight shape/size.

Return type

tuple of int

References

[1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. “A simple way to initialize recurrent networks of rectified linear units.” arXiv preprint arXiv:1504.00941 (2015).

Random Initializers

Normal([scale])

Initialize weights with normal distribution.

Uniform([min_val, max_val, scale])

Initialize weights with uniform distribution.

VarianceScaling(scale, mode, distribution[, ...])

KaimingUniform([scale, mode, distribution, ...])

KaimingNormal([scale, mode, distribution, ...])

XavierUniform([scale, mode, distribution, ...])

XavierNormal([scale, mode, distribution, ...])

LecunUniform([scale, mode, distribution, ...])

LecunNormal([scale, mode, distribution, ...])

Orthogonal([scale, axis])

Construct an initializer for uniformly distributed orthogonal matrices.

DeltaOrthogonal([scale, axis])

Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

class brainpy.simulation.initialize.Normal(scale=1.0)[source]

Initialize weights with normal distribution.

Parameters

scale (float) – The gain of the standard deviation of the normal distribution.

class brainpy.simulation.initialize.Uniform(min_val=0.0, max_val=1.0, scale=0.01)[source]

Initialize weights with uniform distribution.

Parameters
  • min_val (float) – The lower limit of the uniform distribution.

  • max_val (float) – The upper limit of the uniform distribution.

class brainpy.simulation.initialize.VarianceScaling(scale, mode, distribution, in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.KaimingUniform(scale=2.0, mode='fan_in', distribution='uniform', in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.KaimingNormal(scale=2.0, mode='fan_in', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.XavierUniform(scale=1.0, mode='fan_avg', distribution='uniform', in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.XavierNormal(scale=1.0, mode='fan_avg', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.LecunUniform(scale=1.0, mode='fan_in', distribution='uniform', in_axis=-2, out_axis=-1)[source]

class brainpy.simulation.initialize.LecunNormal(scale=1.0, mode='fan_in', distribution='truncated_normal', in_axis=-2, out_axis=-1)[source]
class brainpy.simulation.initialize.Orthogonal(scale=1.0, axis=-1)[source]

Construct an initializer for uniformly distributed orthogonal matrices.

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
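One common way to draw such matrices is the QR decomposition of a Gaussian matrix. The NumPy sketch below illustrates that technique only; it is not taken from the BrainPy source, and the helper name orthogonal_init and its seed argument are hypothetical.

```python
import numpy as np

def orthogonal_init(shape, scale=1.0, seed=0):
    """Hypothetical helper: draw a (rows, cols) matrix with orthonormal rows or columns."""
    rows, cols = shape
    rng = np.random.default_rng(seed)
    # QR-decompose a tall Gaussian matrix so that Q has orthonormal columns.
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    # Fix the column signs so the distribution of Q does not depend on QR conventions.
    q = q * np.sign(np.diag(r))
    if rows < cols:
        q = q.T  # fewer rows than columns: the rows are orthonormal
    return scale * q

w = orthogonal_init((3, 5))
print(np.allclose(w @ w.T, np.eye(3)))  # checks that the three rows are orthonormal
```

For a non-square shape the smaller side determines which of the rows or columns are orthonormal, matching the description above.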

class brainpy.simulation.initialize.DeltaOrthogonal(scale=1.0, axis=-1)[source]

Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

The shape must be 3D, 4D or 5D.

Decay Initializers

GaussianDecay(sigma, max_w[, min_w, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with gaussian function.

DOGDecay(sigmas, max_ws[, min_w, ...])

Builds a Difference-Of-Gaussian (dog) connectivity pattern within a population of neurons.

class brainpy.simulation.initialize.GaussianDecay(sigma, max_w, min_w=None, encoding_values=None, periodic_boundary=False, include_self=True, normalize=False)[source]

Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with gaussian function.

Specifically, for any pair of neurons \((i, j)\), the weight is computed as

\[w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2})\]

where \(v_k^i\) is the \(i\)-th neuron’s encoded value at dimension \(k\).

Parameters
  • sigma (float) – Width of the Gaussian function.

  • max_w (float) – The weight amplitude of the Gaussian function.

  • min_w (float, None) – The minimum weight value below which synapses are not created (default: \(0.005 * max\_w\)).

  • include_self (bool) – Whether to create connections at the same position (self-connections).

  • encoding_values (optional, list, tuple, int, float) –

    The value ranges to encode for neurons at each axis.

    • If encoding_values is not provided, each neuron encodes only its position, i.e., \((i, j, k, ...)\), where \(i, j, k\) are its indices in the high-dimensional space.

    • If encoding_values is a single tuple/list of int/float, neurons at each dimension encode the same range of values. For example, with encoding_values=(0, np.pi), neurons at each dimension encode the continuous value space [0, np.pi].

    • If encoding_values is a tuple/list of list/tuple, the value space is different for each dimension. For example, encoding_values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).

  • periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.

  • normalize (bool) – Whether to normalize the connection probability.
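The Gaussian weight formula above can be evaluated directly with NumPy. The sketch below is a minimal illustration of that formula with the min_w threshold and the include_self switch; it is not the library implementation, it ignores periodic_boundary and normalize, and the helper name gaussian_decay_weights is hypothetical.

```python
import numpy as np

def gaussian_decay_weights(values, sigma, max_w, min_w=None, include_self=True):
    """Hypothetical helper: dense weights from encoded values of shape (num_neuron, num_dim)."""
    values = np.asarray(values, dtype=float)
    if min_w is None:
        min_w = 0.005 * max_w  # default threshold stated in the parameter list
    # Squared Euclidean distance between every pair of encoded values.
    diff = values[:, None, :] - values[None, :, :]
    dist2 = np.sum(diff ** 2, axis=-1)
    w = max_w * np.exp(-dist2 / (2.0 * sigma ** 2))
    w[w < min_w] = 0.0  # weights below the threshold are not created
    if not include_self:
        np.fill_diagonal(w, 0.0)
    return w

positions = np.arange(5.0)[:, None]  # five neurons on a line, each encoding its index
w = gaussian_decay_weights(positions, sigma=1.0, max_w=1.0)
```

Here neurons four positions apart fall below the default threshold, so their weight is set to zero.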

class brainpy.simulation.initialize.DOGDecay(sigmas, max_ws, min_w=None, encoding_values=None, periodic_boundary=False, normalize=True, include_self=True)[source]

Builds a Difference-Of-Gaussian (dog) connectivity pattern within a population of neurons.

Mathematically, for the given pair of neurons \((i, j)\), the weight between them is computed as

\[w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}) - w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2})\]

where weights smaller than min_w (default: \(0.005 * max(w_{max}^+, w_{max}^-)\)) are not created. Self-connections are controlled by the parameter include_self.

Parameters
  • sigmas (tuple) – Widths of the positive and negative Gaussian functions.

  • max_ws (tuple) – The weight amplitudes of the positive and negative Gaussian functions.

  • min_w (float, None) – The minimum weight value below which synapses are not created (default: \(0.005 * max(max\_ws)\)).

  • include_self (bool) – Whether to create connections at the same position (self-connections).

  • normalize (bool) – Whether to normalize the connection probability.

  • encoding_values (optional, list, tuple, int, float) –

    The value ranges to encode for neurons at each axis.

    • If encoding_values is not provided, each neuron encodes only its position, i.e., \((i, j, k, ...)\), where \(i, j, k\) are its indices in the high-dimensional space.

    • If encoding_values is a single tuple/list of int/float, neurons at each dimension encode the same range of values. For example, with encoding_values=(0, np.pi), neurons at each dimension encode the continuous value space [0, np.pi].

    • If encoding_values is a tuple/list of list/tuple, the value space is different for each dimension. For example, encoding_values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi)).

  • periodic_boundary (bool) – Whether the neurons encode the value space with a periodic boundary.
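The Difference-of-Gaussian formula above can be written out in a few lines of NumPy. This is a sketch under stated assumptions, not the library code: the helper name dog_decay_weights is hypothetical, periodic_boundary, normalize and include_self are left out, and the threshold is assumed to apply to the magnitude of the weight.

```python
import numpy as np

def dog_decay_weights(values, sigmas, max_ws, min_w=None):
    """Hypothetical helper: dense DoG weights from encoded values of shape (num_neuron, num_dim)."""
    values = np.asarray(values, dtype=float)
    sigma_pos, sigma_neg = sigmas
    w_pos, w_neg = max_ws
    if min_w is None:
        min_w = 0.005 * max(max_ws)  # default threshold stated in the parameter list
    diff = values[:, None, :] - values[None, :, :]
    dist2 = np.sum(diff ** 2, axis=-1)
    w = (w_pos * np.exp(-dist2 / (2.0 * sigma_pos ** 2))
         - w_neg * np.exp(-dist2 / (2.0 * sigma_neg ** 2)))
    # Assumption: the threshold is applied to the absolute value of the weight.
    w[np.abs(w) < min_w] = 0.0
    return w

positions = np.arange(7.0)[:, None]  # seven neurons on a line, each encoding its index
w = dog_decay_weights(positions, sigmas=(1.0, 3.0), max_ws=(1.0, 0.5))
```

With a narrow positive Gaussian and a wide negative one, nearby neurons get positive weights and more distant neurons get negative weights.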

Current Inputs

This module provides various methods to form current inputs. You can access them through brainpy.inputs.XXX.

section_input(values, durations[, dt, ...])

Format an input current with different sections.

constant_input(I_and_duration[, dt])

Format constant input in durations.

constant_current(I_and_duration[, dt])

Format constant input in durations.

spike_input(sp_times, sp_lens, sp_sizes, ...)

Format current input like a series of short-time spikes.

spike_current(sp_times, sp_lens, sp_sizes, ...)

Format current input like a series of short-time spikes.

ramp_input(c_start, c_end, duration[, ...])

Get the gradually changed input current.

ramp_current(c_start, c_end, duration[, ...])

Get the gradually changed input current.
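To make the shape of these inputs concrete, the NumPy sketch below builds a sectioned current and a ramp current by hand. The two helpers are hypothetical stand-ins written for illustration, not the brainpy.inputs functions, and they assume that durations and dt share the same time unit.

```python
import numpy as np

def section_input_sketch(values, durations, dt=0.1):
    """Hypothetical stand-in: piecewise-constant current, one section per (value, duration)."""
    sections = [np.full(int(round(d / dt)), v, dtype=float)
                for v, d in zip(values, durations)]
    return np.concatenate(sections)

def ramp_input_sketch(c_start, c_end, duration, dt=0.1):
    """Hypothetical stand-in: current changing linearly from c_start to c_end."""
    return np.linspace(c_start, c_end, int(round(duration / dt)))

# No input for 100 time units, a step of amplitude 1 for 300, then no input for 100.
current = section_input_sketch(values=[0.0, 1.0, 0.0], durations=[100.0, 300.0, 100.0], dt=0.1)
ramp = ramp_input_sketch(0.0, 2.0, duration=50.0, dt=0.1)
```

Each array holds one current value per simulation step, so its length is the total duration divided by dt.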

Measurements

This module aims to provide commonly used analysis methods for simulated neuronal data. You can access them through brainpy.measure.XXX.

cross_correlation(spikes, bin[, dt])

Calculate cross correlation index between neurons.

voltage_fluctuation(potentials)

Calculate neuronal synchronization via voltage variance.

raster_plot(sp_matrix, times)

Get spike raster plot which displays the spiking activity of a group of neurons over time.

firing_rate(sp_matrix, width[, dt])

Calculate the mean firing rate of a neuron group.
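As an illustration of what a population firing rate measurement computes, the NumPy sketch below averages spikes over neurons and smooths the result with a moving window. It is a hypothetical helper, not brainpy.measure.firing_rate, and it assumes sp_matrix has the shape (num_time_step, num_neuron) and that width and dt are in milliseconds.

```python
import numpy as np

def firing_rate_sketch(sp_matrix, width, dt=0.1):
    """Hypothetical helper: population-averaged rate in Hz, smoothed over `width` ms."""
    sp_matrix = np.asarray(sp_matrix, dtype=float)
    # Fraction of neurons spiking at each step, converted to spikes per second.
    rate = sp_matrix.mean(axis=1) * 1000.0 / dt
    n = max(int(round(width / dt)), 1)
    window = np.ones(n) / n  # moving-average window covering `width` ms
    return np.convolve(rate, window, mode="same")

rng = np.random.default_rng(0)
spikes = rng.random((10000, 50)) < 0.001  # roughly 10 Hz per neuron at dt = 0.1 ms
rate = firing_rate_sketch(spikes, width=50.0, dt=0.1)
```

The values near both ends of the returned array are biased low, because the window extends past the recorded data there.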

Monitors

Monitor(variables[, intervals, target])

The basic Monitor class to store the past variable trajectories.

class brainpy.simulation.monitor.Monitor(variables, intervals=None, target=None)[source]

The basic Monitor class to store the past variable trajectories.

Currently, brainpy.simulation.Monitor supports specifying:

  • variable key by strings.

  • variable index by None, int, list, tuple, 1D array/tensor (==> all will be transformed into a 1D array/tensor)

  • variable monitor interval by None, int, float

Users can instantiate a monitor object in multiple ways:

  1. list of strings.

>>> Monitor(target=..., variables=['a', 'b', 'c'])

1.1. list of strings and list of intervals

>>> Monitor(target=..., variables=['a', 'b', 'c'],
...         intervals=[None, 1, 2] # ms
...        )

  2. list of strings and string + indices

>>> Monitor(target=..., variables=['a', ('b', math.array([1,2,3])), 'c'])

2.1. list of string (+ indices) and list of intervals

>>> Monitor(target=..., variables=['a', ('b', math.array([1,2,3])), 'c'],
...         intervals=[None, 2, 3])

  3. a dictionary with the format of {key: indices}

>>> Monitor(target=..., variables={'a': None, 'b': math.array([1,2,3])})

3.1. a dictionary of variables and indices, and a dictionary of time intervals

>>> Monitor(target=..., variables={'a': None, 'b': math.array([1,2,3])},
...         intervals={'b': 2.})

Note

brainpy.simulation.Monitor records any target variable in a two-dimensional array/tensor with the shape of (num_time_step, variable_size). This means that any variable, whatever the shape of its data (int, float, vector, matrix, 3D array/tensor), is reshaped into a one-dimensional vector at each time step.
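The resulting layout can be pictured with plain NumPy. The snippet below only illustrates the (num_time_step, variable_size) shape described in the note; it does not use the Monitor class, and the variable names are made up for the example.

```python
import numpy as np

num_step = 5
history = []
state = np.zeros((2, 3))               # a matrix-shaped variable
for step in range(num_step):
    state = state + 1.0                # stand-in for one simulation update
    history.append(state.reshape(-1))  # each record is flattened to a 1D vector

trajectory = np.stack(history)
print(trajectory.shape)  # (5, 6)
```

A variable with shape (2, 3) therefore occupies six columns, one row per recorded time step.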

brainpy.analysis module

This module provides analysis tools for differential equations.

  • The symbolic module uses SymPy symbolic inference to analyze low-dimensional dynamical systems (only supports ODEs).

  • The numeric module uses numerical optimization to analyze high-dimensional dynamical systems (supports ODEs and discrete systems).

  • The continuation module provides analysis with numerical continuation methods.

  • Moreover, we provide several useful functions in the stability module which may help your dynamical system analysis, like:

    >>> get_1d_stability_types()
    ['saddle node', 'stable point', 'unstable point']
    

Details in the following.

Dynamics Analysis (Symbolic)

Dynamics analysis with the aid of SymPy symbolic inference.

This module provides basic dynamics analysis for low-dimensional dynamical systems, including

  • phase plane analysis (1d or 2d systems)

  • bifurcation analysis (1d or 2d systems)

  • fast slow bifurcation analysis (2d or 3d systems)

BaseSymAnalyzer(model_or_integrals, target_vars)

Dynamics Analyzer for Neuron Models.

Base1DSymAnalyzer(*args, **kwargs)

Neuron analyzer for 1D systems.

Base2DSymAnalyzer(*args, **kwargs)

Neuron analyzer for 2D systems.

Bifurcation(integrals, target_pars, target_vars)

A tool class for bifurcation analysis.

FastSlowBifurcation(integrals, fast_vars, ...)

Fast-slow analysis proposed by John Rinzel [1] [2] [3].

PhasePlane(model, target_vars[, fixed_vars, ...])

A tool class for phase plane analysis.

class brainpy.analysis.symbolic.BaseSymAnalyzer(model_or_integrals, target_vars, fixed_vars=None, target_pars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]

Dynamics Analyzer for Neuron Models.

This is a base class for analyzing the dynamics of neuron models. A neuron model is characterized by a series of dynamical variables and parameters:

\[{dF \over dt} = F(v_1, v_2, ..., p_1, p_2, ...)\]

where \(v_1, v_2\) are variables, \(p_1, p_2\) are parameters.

Parameters
  • model_or_integrals (Any) – A model of the population, the integrator function, or a list/tuple of integrator functions.

  • target_vars (dict) – The target/dynamical variables.

  • fixed_vars (dict) – The fixed variables.

  • target_pars (dict, optional) – The parameters which can be dynamically varied.

  • pars_update (dict, optional) – The parameters to update.

  • numerical_resolution (float, dict) –

    The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example,

    • set numerical_resolution=0.1 will generalize it to all variables and parameters;

    • set numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} will specify the particular resolutions to variables and parameters.

    • Moreover, you can also set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the search points to explore for variable var1. This is useful for placing dense search points near inflection points.

  • options (dict, optional) –

    The other setting parameters, which includes:

    • perturbation: float. The small perturbation used to solve the function derivative.

    • sympy_solver_timeout: float, with the unit of second. The maximum time allowed to use sympy solver to get the variable relationship.

    • escape_sympy_solver: bool. Whether to skip the sympy solver and directly use numerical optimization to solve the nullclines and fixed points.

    • lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].

class brainpy.analysis.symbolic.Base1DSymAnalyzer(*args, **kwargs)[source]

Neuron analyzer for 1D systems.

It supports the analysis of 1D dynamical systems.

\[{dx \over dt} = f(x, t)\]
get_f_dfdx(origin=True)[source]

Get the derivative of f by variable x.

get_f_dx()[source]

Get the derivative function of the first variable.

get_f_fixed_point()[source]

Get the function to solve the fixed point.
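The idea behind a 1D analysis (locate the roots of f, then read stability from the sign of df/dx) can be sketched numerically. The code below is an illustration in plain NumPy, not the analyzer's implementation: the example system, the function names, and the use of bisection are assumptions. The stability labels are the three strings listed earlier for get_1d_stability_types().

```python
import numpy as np

def f(x):
    return x - x ** 3  # example system dx/dt = f(x), chosen only for illustration

def classify_1d(dfdx, eps=1e-8):
    if abs(dfdx) < eps:
        return "saddle node"
    return "stable point" if dfdx < 0 else "unstable point"

def find_fixed_points_1d(f, x_min, x_max, resolution=0.1, perturbation=1e-4):
    """Scan a grid for sign changes of f, refine each root by bisection, then classify it."""
    xs = np.arange(x_min, x_max + resolution, resolution)
    fs = f(xs)
    results = []
    for a, b, fa, fb in zip(xs[:-1], xs[1:], fs[:-1], fs[1:]):
        if fa == 0.0:
            root = float(a)            # the grid point is itself a root
        elif fa * fb < 0:
            lo, hi = float(a), float(b)
            for _ in range(60):        # bisection on the bracketing interval
                mid = 0.5 * (lo + hi)
                if f(lo) * f(mid) <= 0:
                    hi = mid
                else:
                    lo = mid
            root = 0.5 * (lo + hi)
        else:
            continue
        # Central-difference derivative with a small perturbation.
        dfdx = (f(root + perturbation) - f(root - perturbation)) / (2 * perturbation)
        results.append((root, classify_1d(dfdx)))
    return results

fixed_points = find_fixed_points_1d(f, -1.95, 1.95)
```

A finer resolution reduces the risk of missing two roots that fall between the same pair of grid points.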

class brainpy.analysis.symbolic.Base2DSymAnalyzer(*args, **kwargs)[source]

Neuron analyzer for 2D systems.

It supports the analysis of 2D dynamical systems.

\[ \begin{align}\begin{aligned}{dx \over dt} = f(x, t, y)\\{dy \over dt} = g(y, t, x)\end{aligned}\end{align} \]
Parameters

options (dict, optional) –

The other setting parameters, which includes:

  • shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.

  • show_shgo: bool. Whether to print the shgo’s value.

  • fl_tol: float. The tolerance of the function value to recognize it as a candidate of function root point.

  • xl_tol: float. The tolerance of the l2 norm distances between this point and previous points. If the norm distances are all bigger than xl_tol, this point is treated as a new function root point.

get_f_dfdy(origin=True)[source]

Get the derivative of f by variable y.

get_f_dgdx(origin=True)[source]

Get the derivative of g by variable x.

get_f_dgdy(origin=True)[source]

Get the derivative of g by variable y.

get_f_dy()[source]

Get the derivative function of the second variable.

get_f_fixed_point()[source]

Get the function to solve the fixed point.

get_f_jacobian()[source]

Get the function to solve jacobian matrix.
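A Jacobian of a 2D system can be approximated with a small perturbation, in the spirit of the perturbation option. The sketch below uses central differences on a FitzHugh-Nagumo system; the example system, its parameter values, and the helper names are assumptions made for illustration and are not the analyzer's own code.

```python
import numpy as np

def fhn(x, y, a=0.7, b=0.8, tau=12.5, I=0.5):
    """FitzHugh-Nagumo right-hand sides (dx/dt, dy/dt), used here only as an example."""
    dx = x - x ** 3 / 3 - y + I
    dy = (x + a - b * y) / tau
    return np.array([dx, dy])

def numerical_jacobian(func, x, y, perturbation=1e-6):
    """Central-difference Jacobian [[df/dx, df/dy], [dg/dx, dg/dy]] at the point (x, y)."""
    d_dx = (func(x + perturbation, y) - func(x - perturbation, y)) / (2 * perturbation)
    d_dy = (func(x, y + perturbation) - func(x, y - perturbation)) / (2 * perturbation)
    return np.stack([d_dx, d_dy], axis=1)

J = numerical_jacobian(fhn, 0.5, 0.2)
```

For this system the analytic Jacobian is [[1 - x^2, -1], [1/tau, -b/tau]], which gives a direct check of the numerical result.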

get_f_optimize_x_nullcline(coords=None)[source]

Get the function to solve X nullcline by using numerical optimization method.

Parameters

coords (str) – The coordinates.

get_f_optimize_y_nullcline(coords=None)[source]

Get the function to solve Y nullcline by using numerical optimization method.

Parameters

coords (str) – The coordinates.

get_x_by_y_in_x_eq()[source]

Get the expression of “x_by_y_in_x_eq”.

Specifically, self.analyzed_results['x_by_y_in_x_eq'] is a Dict, with the following keywords:

  • status : ‘sympy_success’, ‘sympy_failed’, ‘escape’

  • subs : substituted expressions (relationship) of x_by_y

  • f : function of x_by_y

get_x_by_y_in_y_eq()[source]

Get the expression of “x_by_y_in_y_eq”.

Specifically, self.analyzed_results['x_by_y_in_y_eq'] is a Dict, with the following keywords:

  • status : ‘sympy_success’, ‘sympy_failed’, ‘escape’

  • subs : substituted expressions (relationship) of x_by_y

  • f : function of x_by_y

get_y_by_x_in_x_eq()[source]

Get the expression of “y_by_x_in_x_eq”.

Specifically, self.analyzed_results['y_by_x_in_x_eq'] is a Dict, with the following keywords:

  • status : ‘sympy_success’, ‘sympy_failed’, ‘escape’

  • subs : substituted expressions (relationship) of y_by_x

  • f : function of y_by_x

get_y_by_x_in_y_eq()[source]

Get the expression of “y_by_x_in_y_eq”.

Specifically, self.analyzed_results['y_by_x_in_y_eq'] is a Dict, with the following keywords:

  • status : ‘sympy_success’, ‘sympy_failed’, ‘escape’

  • subs : substituted expressions (relationship) of y_by_x

  • f : function of y_by_x

class brainpy.analysis.symbolic.Bifurcation(integrals, target_pars, target_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]

A tool class for bifurcation analysis.

The bifurcation analyzer is restricted to analyze the bifurcation relation between membrane potential and a given model parameter (co-dimension-1 case) or two model parameters (co-dimension-2 case).

Externally injected current is also treated as a model parameter in this class, instead of a model state.

Examples

Parameters
  • integrals (function, functions) – The integral functions defined with brainpy.odeint or brainpy.sdeint or brainpy.ddeint, or brainpy.fdeint.

  • target_vars (dict) – The target dynamical variables. It must be a dictionary which specifies the boundary of the variables: {‘var1’: [min, max]}.

  • fixed_vars (dict, optional) – The fixed variables. Each must be a fixed value, with the format of {‘var1’: value}.

  • target_pars (dict, optional) – The parameters which can be dynamically varied. It must be a dictionary which specifies the boundary of the parameters: {‘par1’: [min, max]}.

  • pars_update (dict, optional) – The parameters to update; they can be treated as static parameters. As with fixed_vars, they must be fixed values with the format of {‘par1’: value}.

  • numerical_resolution (float, dict, optional) –

    The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example,

    • set numerical_resolution=0.1 will generalize it to all variables and parameters;

    • set numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} will specify the particular resolutions to variables and parameters.

    • Moreover, you can also set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the search points to explore for variable var1. This is useful for placing dense search points near inflection points.

  • options (dict, optional) –

    The other setting parameters, which includes:

    • perturbation: float. The small perturbation used to solve the function derivatives.

    • sympy_solver_timeout: float, with the unit of second. The maximum time allowed to use sympy solver to get the variable relationship.

    • escape_sympy_solver: bool. Whether to skip the sympy solver and directly use numerical optimization to solve the nullclines and fixed points.

    • lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].

    The parameters which are useful for two-dimensional bifurcation analysis:

    • shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.

    • show_shgo: bool. Whether to print the shgo’s value.

    • fl_tol: float. The tolerance of the function value to recognize it as a candidate of function root point.

    • xl_tol: float. The tolerance of the l2 norm distances between this point and previous points. If the norm distances are all bigger than xl_tol, this point is treated as a new function root point.

plot_bifurcation(*args, **kwargs)[source]

Plot the bifurcation diagram. Both co-dimension 1 and co-dimension 2 bifurcation analysis are supported.

Parameters

show (bool) – Whether to show the bifurcation figure.

Returns

points – The bifurcation points, given as the fixed points and their corresponding stability.

Return type

dict
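To illustrate what a co-dimension 1 analysis collects, the sketch below scans one parameter of the saddle-node normal form dx/dt = r - x^2 and groups the fixed points by stability. The example system, the scanned range, and the dictionary layout are assumptions made for illustration; the actual structure returned by plot_bifurcation may differ.

```python
import numpy as np

def dfdx(x):
    return -2.0 * x  # derivative of f(x, r) = r - x**2 with respect to x

points = {"stable point": [], "unstable point": []}
for r in np.linspace(-0.95, 0.95, 20):        # scanned range of the parameter r
    for root in np.roots([-1.0, 0.0, r]):     # roots of r - x**2 = 0
        if abs(root.imag) > 1e-9:
            continue                          # no real fixed point when r < 0
        x = float(root.real)
        key = "stable point" if dfdx(x) < 0 else "unstable point"
        points[key].append((float(r), x))
```

For r > 0 the scan finds a stable branch at x = sqrt(r) and an unstable branch at x = -sqrt(r); both disappear for r < 0, which is the saddle-node bifurcation.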

plot_limit_cycle_by_sim(var, duration=100, inputs=(), plot_style=None, tol=0.001, show=False)[source]

Plot limit cycles by the simulation results.

This function helps users plot limit cycles from simulation results: periodic signals are found automatically and then treated as limit cycle candidates.

Parameters
  • var (str) – The target variable whose limit cycles are to be found.

  • duration (int, float, tuple, list) – The simulation duration.

  • inputs (tuple, list) – The simulation inputs.

  • plot_style (dict) – The limit cycle plotting style settings.

  • tol (float) – The tolerance used to find periodic signals.

  • show (bool) – Whether to show the figure.
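One simple way to recognize a periodic signal in a simulated trajectory is to check that successive peaks agree within a tolerance. The sketch below applies this idea to a Van der Pol oscillator integrated with forward Euler. The criterion, the example system, and the function names are assumptions for illustration; the method used internally by plot_limit_cycle_by_sim may differ.

```python
import numpy as np

def simulate_van_der_pol(mu=1.0, dt=0.01, duration=100.0, x0=0.1, y0=0.0):
    """Forward-Euler simulation of x' = y, y' = mu * (1 - x^2) * y - x; returns x(t)."""
    n = int(round(duration / dt))
    xs = np.empty(n)
    x, y = x0, y0
    for i in range(n):
        dx = y
        dy = mu * (1 - x ** 2) * y - x
        x, y = x + dt * dx, y + dt * dy
        xs[i] = x
    return xs

def find_limit_cycle_extent(signal, tol=0.001):
    """Return (min, max) of the signal if successive peaks agree within tol, else None."""
    tail = signal[len(signal) // 2:]          # drop the transient
    is_peak = (tail[1:-1] > tail[:-2]) & (tail[1:-1] > tail[2:])
    peaks = tail[1:-1][is_peak]
    if len(peaks) < 2 or np.max(np.abs(np.diff(peaks))) > tol:
        return None
    return float(tail.min()), float(tail.max())

extent = find_limit_cycle_extent(simulate_van_der_pol(), tol=0.001)
```

A decaying oscillation fails the peak test and returns None, so only sustained oscillations are reported as limit cycle candidates.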

class brainpy.analysis.symbolic.FastSlowBifurcation(integrals, fast_vars, slow_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]

Fast-slow analysis proposed by John Rinzel [1] [2] [3].

Rinzel (1985, 1986, 1987) proposed that in a fast-slow dynamical system, we can treat the slow variables as bifurcation parameters, and then study how different values of the slow variables affect the bifurcation of the fast subsystem.
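The idea of freezing the slow variable can be sketched with NumPy. In the illustration below the fast subsystem is dx/dt = x - x^3/3 - y, with the slow variable y treated as a constant parameter; the example system and the function name are assumptions and are not part of BrainPy.

```python
import numpy as np

def fast_fixed_points(y):
    """Real roots of x - x**3/3 - y = 0: the fast subsystem with the slow variable y frozen."""
    roots = np.roots([-1.0 / 3.0, 0.0, 1.0, -y])
    return np.sort(roots[np.abs(roots.imag) < 1e-9].real)

# Sweep the frozen slow variable and collect the fixed points of the fast subsystem.
branches = {float(y): fast_fixed_points(y) for y in np.linspace(-1.0, 1.0, 9)}
```

Between the two knees at y = -2/3 and y = 2/3 the fast subsystem has three fixed points, and outside that range only one remains.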

Examples

Parameters
  • integrals (function, functions) – The integral functions defined with brainpy.odeint or brainpy.sdeint or brainpy.ddeint, or brainpy.fdeint.

  • fast_vars (dict) – The fast dynamical variables. It must be a dictionary which specifies the boundary of the variables: {‘var1’: [min, max]}.

  • slow_vars (dict) – The slow dynamical variables. It must be a dictionary which specifies the boundary of the variables: {‘var1’: [min, max]}.

  • fixed_vars (dict) – The fixed variables. Each must be a fixed value, with the format of {‘var1’: value}.

  • pars_update (dict, optional) – The parameters to update; they can be treated as static parameters. As with fixed_vars, they must be fixed values with the format of {‘par1’: value}.

  • numerical_resolution (float, dict) –

    The resolution for numerical iterative solvers. Default is 0.1. It can set the numerical resolution of dynamical variables or dynamical parameters. For example,

    • set numerical_resolution=0.1 will generalize it to all variables and parameters;

    • set numerical_resolution={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05} will specify the particular resolutions to variables and parameters.

    • Moreover, you can also set numerical_resolution={var1: np.array([...]), var2: 0.1} to specify the search points to explore for variable var1. This is useful for placing dense search points near inflection points.

  • options (dict, optional) –

    The other setting parameters, which includes:

    • perturbation: float. The small perturbation used to solve the function derivatives.

    • sympy_solver_timeout: float, with the unit of second. The maximum time allowed to use sympy solver to get the variable relationship.

    • escape_sympy_solver: bool. Whether to skip the sympy solver and directly use numerical optimization to solve the nullclines and fixed points.

    • lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].

References

[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.

[2] Rinzel, John, and Y. S. Lee. “On different mechanisms for membrane potential bursting.” In Nonlinear oscillations in biology and chemistry. Springer, Berlin, Heidelberg, 1986.

[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.

plot_bifurcation(*args, **kwargs)[source]

Plot bifurcation.

Parameters

show (bool) – Whether to show the bifurcation figure.

Returns

points – The bifurcation points, given as the fixed points and their corresponding stability.

Return type

dict

plot_limit_cycle_by_sim(*args, **kwargs)[source]

Plot limit cycles by the simulation results.

This function helps users plot limit cycles from simulation results: periodic signals are found automatically and then treated as limit cycle candidates.

Parameters
  • var (str) – The target variable whose limit cycles are to be found.

  • duration (int, float, tuple, list) – The simulation duration.

  • inputs (tuple, list) – The simulation inputs.

  • plot_style (dict) – The limit cycle plotting style settings.

  • tol (float) – The tolerance used to find periodic signals.

  • show (bool) – Whether to show the figure.

plot_trajectory(*args, **kwargs)[source]

Plot trajectory.

This function helps users to plot specific trajectories.

Parameters
  • initials (list, tuple) – The initial value setting of the targets. It can be a tuple/list of floats to specify each value of dynamical variables (for example, (a, b)). It can also be a tuple/list of tuple to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).

  • duration (int, float, tuple, list) – The running duration. Same with the duration in NeuGroup.run(). It can be a int/float (t_end) to specify the same running end time, or it can be a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time. Or, it can be a list of tuple ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the specific start and end simulation time for each initial value.

  • plot_duration (tuple/list of tuple, optional) – The duration to plot. It can be a tuple with (start, end). It can also be a list of tuple [(start1, end1), (start2, end2)] to specify the plot duration for each initial value running.

  • show (bool) – Whether to show the figure.

class brainpy.analysis.symbolic.PhasePlane(model, target_vars, fixed_vars=None, pars_update=None, numerical_resolution=0.1, options=None)[source]

A tool class for phase plane analysis.

PhasePlane is used to analyze the phase portrait of 1D or 2D dynamical systems. It can also be used to analyze the phase portrait of a high-dimensional system, by fixing the other variables so that only one or two variables remain dynamical.

Examples

Parameters
  • model (DynamicalSystem, Integrator, list of Integrator, tuple of Integrator) – The neuron model which defines the differential equations.

  • target_vars (dict) – The target variables to analyze, with the format of {‘var1’: [var_min, var_max], ‘var2’: [var_min, var_max]}.

  • fixed_vars (dict, optional) – The fixed variables, which means the variables will not be updated.

  • pars_update (dict, optional) – The parameters in the differential equations to update.

  • numerical_resolution (float, dict, optional) –

    The variable resolution for numerical iterative solvers. This variable will be useful in the solving of nullcline and fixed points by using the iterative optimization method.

    • It can be a float, which will be used as numpy.arange(var_min, var_max, resolution).

    • Or, it can be a dict, with the format of {'var1': resolution1, 'var2': resolution2}.

    • Or, it can be a dict with the format of {'var1': np.arange(x, x, x), 'var2': np.arange(x, x, x)}.

  • options (dict, optional) –

    The other setting parameters, which includes:

    • lim_scale: float. The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to [var_min * (1-lim_scale)/2, var_max * (var_max-1)/2].

    • sympy_solver_timeout: float, with the unit of second. The maximum time allowed to use sympy solver to get the variable relationship.

    • escape_sympy_solver: bool. Whether to skip the sympy solver and directly use numerical optimization to solve the nullclines and fixed points.

    • shgo_args: dict. Arguments of shgo optimization method, which can be used to set the fields of: constraints, n, iters, callback, minimizer_kwargs, options, sampling_method.

    • show_shgo: bool. Whether to print the shgo’s value.

    • perturbation: float. The small perturbation used to solve the function derivative.

    • fl_tol: float. The tolerance of the function value to recognize it as a candidate of function root point.

    • xl_tol: float. The tolerance of the l2 norm distances between this point and previous points. If the norm distances are all bigger than xl_tol, this point is treated as a new function root point.
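The role of target_vars and numerical_resolution can be illustrated by computing nullclines on a grid and locating their crossing. The sketch below uses a FitzHugh-Nagumo system in plain NumPy; the system, its parameter values, and the crossing test are assumptions for illustration and do not reproduce the PhasePlane implementation.

```python
import numpy as np

a, b, I = 0.7, 0.8, 0.5
target_vars = {"x": [-3.0, 3.0], "y": [-3.0, 3.0]}  # same format as the target_vars parameter
resolution = 0.01                                   # plays the role of numerical_resolution

xs = np.arange(*target_vars["x"], resolution)
x_nullcline = xs - xs ** 3 / 3 + I   # dx/dt = 0  ->  y = x - x^3/3 + I
y_nullcline = (xs + a) / b           # dy/dt = 0  ->  y = (x + a) / b

# Fixed points sit where the two nullclines cross: look for sign changes of their difference.
gap = x_nullcline - y_nullcline
crossings = np.where(np.sign(gap[:-1]) * np.sign(gap[1:]) < 0)[0]
fixed_x = 0.5 * (xs[crossings] + xs[crossings + 1])
fixed_y = (fixed_x + a) / b
```

A smaller resolution locates the crossing more precisely at the cost of evaluating the nullclines at more points.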

plot_fixed_point(*args, **kwargs)[source]

Plot fixed points.

plot_limit_cycle_by_sim(initials, duration, tol=0.001, show=False)[source]

Plot limit cycles according to the settings.

Parameters
  • initials (list, tuple) – The initial value setting of the targets. It can be a tuple/list of floats to specify each value of dynamical variables (for example, (a, b)). It can also be a tuple/list of tuple to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).

  • duration (int, float, tuple, list) – The running duration. Same as the duration in NeuGroup.run(). It can be an int/float (t_end) to specify the same running end time, or it can be a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time. Or, it can be a list of tuple ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the specific start and end simulation time for each initial value.

  • show (bool) – Whether to show the figure.

plot_nullcline(*args, **kwargs)[source]

Plot nullcline (only supported in 2D system).

plot_trajectory(initials, duration, plot_duration=None, axes='v-v', show=False)[source]

Plot trajectories according to the settings.

Parameters
  • initials (list, tuple, dict) – The initial value setting of the targets. It can be a tuple/list of floats to specify each value of dynamical variables (for example, (a, b)). It can also be a tuple/list of tuple to specify multiple initial values (for example, [(a1, b1), (a2, b2)]).

  • duration (int, float, tuple, list) – The running duration. Same with the duration in NeuGroup.run(). It can be a int/float (t_end) to specify the same running end time, or it can be a tuple/list of int/float ((t_start, t_end)) to specify the start and end simulation time. Or, it can be a list of tuple ([(t1_start, t1_end), (t2_start, t2_end)]) to specify the specific start and end simulation time for each initial value.

  • plot_duration (tuple, list, optional) – The duration to plot. It can be a tuple with (start, end). It can also be a list of tuple [(start1, end1), (start2, end2)] to specify the plot duration for each initial value running.

  • axes (str) –

    The axes to plot. It can be:

    • ’v-v’

      Plot the trajectory in the ‘x_var’-‘y_var’ axis.

    • ’t-v’

      Plot the trajectory in the ‘time’-‘var’ axis.

  • show (bool) – Whether to show the figure.

plot_vector_field(*args, **kwargs)[source]

Plot the vector field of a 2D/1D system.

Dynamics Analysis (Numeric)

FixedPointFinder(f_cell[, f_type, ...])

Find fixed points by numerical optimization.

class brainpy.analysis.numeric.FixedPointFinder(f_cell, f_type='df', f_loss_batch=None, verbose=True, num_opt_batch=100, num_opt_max=10000, opt_setting=None, noise=0.0, tol_opt=1e-05, tol_speed=1e-05, tol_unique=0.025, tol_outlier=1.0)[source]

Find fixed points by numerical optimization.

Parameters
  • f_cell (callable, function) – The function to compute the recurrent units.

  • f_type (str) –

    The system’s type: continuous system or discrete system.

    • ’df’: continuous derivative function, denotes this is a continuous system, or

    • ’F’: discrete update function, denotes this is a discrete system.

  • f_loss_batch (callable, function) – The function to compute the loss.

  • verbose (bool) – Whether to print the optimization progress.

  • num_opt_max (int) – The maximum number of optimization steps.

  • num_opt_batch (int) – Print training information during optimization every so often.

  • noise (float) – Gaussian noise added to fixed point candidates before optimization.

  • tol_opt (float) – Stop optimizing when the average value of the batch is below this value.

  • tol_speed (float) – Discard fixed points with squared speed larger than this value.

  • tol_unique (float) – Tolerance for determination of identical fixed points.

  • tol_outlier (float) – Any point whose distance to its closest fixed point is greater than tol_outlier is an outlier.
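The optimization behind such a finder can be pictured as gradient descent on the squared speed of the system. The NumPy sketch below does this for a small continuous-time recurrent system with a hand-written Jacobian. It illustrates the general technique only, not FixedPointFinder itself; the weights, the step size, and the function names are assumptions.

```python
import numpy as np

W = np.array([[2.0, 0.0], [0.0, 2.0]])  # illustrative recurrent weights

def f_cell(h):
    """Derivative dh/dt = -h + tanh(W h) for a batch of states with shape (batch, dim)."""
    return -h + np.tanh(h @ W.T)

def speed_and_grad(h):
    """Loss q = 0.5 * ||f(h)||^2 for each point, and its gradient J^T f."""
    f = f_cell(h)
    # Jacobian of f at each point: -I + diag(1 - tanh^2(W h)) W
    d = 1.0 - np.tanh(h @ W.T) ** 2
    J = -np.eye(2)[None] + d[:, :, None] * W[None]
    grad = np.einsum("bij,bi->bj", J, f)
    return 0.5 * np.sum(f ** 2, axis=1), grad

rng = np.random.default_rng(0)
candidates = rng.uniform(-1.5, 1.5, size=(64, 2))
h = candidates.copy()
for _ in range(500):                    # plain gradient descent on the speed
    loss, grad = speed_and_grad(h)
    h -= 0.5 * grad
loss, _ = speed_and_grad(h)
fixed_points = h[loss < 1e-5]           # analogue of discarding points above tol_speed
```

Because the two units are uncoupled here, every coordinate converges to a solution of x = tanh(2x), that is, to 0 or to approximately ±0.9575.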

compute_jacobians(points)[source]

Compute the jacobian matrices at the points.

Parameters

points (np.ndarray, bm.JaxArray, jax.ndarray) – The fixed points with the shape of (num_point, num_dim).

Returns

jacobians – The Jacobian matrices at the points, with the shape of (num_point, num_dim, num_dim).

Return type

bm.JaxArray

decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=True)[source]

Compute the eigenvalues of the matrices.

Parameters
  • matrices (np.ndarray, bm.JaxArray, jax.ndarray) – A 3D array with the shape of (num_matrices, dim, dim).

  • do_compute_lefts (bool) – Whether to compute the left eigenvectors, which requires a pseudo-inverse call.

Returns

decompositions – A list of dictionaries with sorted eigenvalues components: (eigenvalues, right eigenvectors, and left eigenvectors).

Return type

list
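A decomposition of this kind can be sketched with NumPy as follows. The dictionary keys, the descending order by magnitude, and the helper name are assumptions made for illustration; the exact keys and ordering used by decompose_eigenvalues are not shown in this reference.

```python
import numpy as np

def decompose_sketch(matrices, do_compute_lefts=True):
    """Hypothetical helper: eigen-decompose each matrix, sorted by eigenvalue magnitude."""
    decompositions = []
    for m in np.asarray(matrices, dtype=float):
        eigenvalues, rights = np.linalg.eig(m)
        order = np.argsort(-np.abs(eigenvalues))  # largest magnitude first (assumed order)
        eigenvalues, rights = eigenvalues[order], rights[:, order]
        # Columns of pinv(R).T are the left eigenvectors (rows of the inverse of R).
        lefts = np.linalg.pinv(rights).T if do_compute_lefts else None
        decompositions.append({"eig_values": eigenvalues, "R": rights, "L": lefts})
    return decompositions

matrices = np.array([[[2.0, 1.0], [0.0, -3.0]]])
result = decompose_sketch(matrices)[0]
```

The pseudo-inverse is the extra cost mentioned for do_compute_lefts, so it can be skipped when only the eigenvalues and right eigenvectors are needed.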

exclude_outliers(fixed_points, metric='euclidean')[source]

Exclude points whose closest neighbor is further than threshold.

Parameters
  • fixed_points (np.ndarray) – The fixed points with the shape of (num_point, num_dim).

  • metric (str) – The distance metric passed to scipy.spatial.pdist. Defaults to “euclidean”

Returns

fps_and_ids – A 2-tuple of (kept fixed points, ids of kept fixed points).

Return type

tuple
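The outlier rule can be written with plain NumPy distances. The sketch below is an illustration with Euclidean distance only, not the library method; the helper name and the handling of a single point are assumptions.

```python
import numpy as np

def exclude_outliers_sketch(fixed_points, tol_outlier=1.0):
    """Hypothetical helper: keep points whose nearest neighbor is within tol_outlier."""
    fixed_points = np.asarray(fixed_points, dtype=float)
    if len(fixed_points) <= 1:
        return fixed_points, np.arange(len(fixed_points))
    diff = fixed_points[:, None, :] - fixed_points[None, :, :]
    dist = np.sqrt(np.sum(diff ** 2, axis=-1))
    np.fill_diagonal(dist, np.inf)  # ignore each point's distance to itself
    keep_ids = np.where(dist.min(axis=1) <= tol_outlier)[0]
    return fixed_points[keep_ids], keep_ids

points = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 10.0]])
kept, ids = exclude_outliers_sketch(points, tol_outlier=1.0)
```

In this example the isolated point at (10, 10) is dropped and the two nearby points are kept.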

find_fixed_points(candidates)[source]

Top-level routine to find fixed points, keeping only valid fixed points.

This function will:

  1. Add noise to the fixed point candidates

  2. Optimize to find the closest fixed points / slow points

  3. Exclude any fixed points whose fixed point loss is above the threshold (tol_speed)

  4. Exclude any non-unique fixed points according to a tolerance (tol_unique)

  5. Exclude any far-away “outlier” fixed points (tol_outlier)

Parameters
  • rnn_fun – one-step update function as a function of hidden state

  • candidates – ndarray with shape npoints x ndims

  • hyper_params – dict of hyper parameters for fp optimization, including tolerances related to keeping fixed points

Returns

A 4-tuple of (kept fixed points sorted with the slowest points first, fixed point losses, indices of kept fixed points, details of optimization).

keep_unique(fixed_points)[source]

Filter unique fixed points by choosing a representative within tolerance.

Parameters

fixed_points (np.ndarray) – The fixed points with the shape of (num_point, num_dim).

Returns

fps_and_ids – A 2-tuple of (kept fixed points, ids of kept fixed points).

Return type

tuple
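
One simple way to choose a representative within a tolerance is a greedy pass over the points, sketched below. The helper name and the greedy strategy are assumptions; the actual method may cluster the points differently.

```python
import numpy as np

def keep_unique_sketch(fixed_points, tol_unique):
    """Greedily keep one representative per group of near-identical points."""
    fixed_points = np.asarray(fixed_points, dtype=float)
    keep_ids = []
    for i, point in enumerate(fixed_points):
        # A point is new if it is farther than tol_unique from every kept point.
        if all(np.linalg.norm(point - fixed_points[j]) > tol_unique for j in keep_ids):
            keep_ids.append(i)
    keep_ids = np.array(keep_ids, dtype=int)
    return fixed_points[keep_ids], keep_ids

points = [[1.0, 1.0], [1.0, 1.0 + 1e-4], [-1.0, 0.0], [1.0 - 1e-4, 1.0]]
unique_points, ids = keep_unique_sketch(points, tol_unique=1e-2)
```

With this input, the three points near (1, 1) collapse onto the first one encountered.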

optimize_fixed_points(candidates)[source]

Find fixed points via optimization.

Parameters

candidates (jax.ndarray, JaxArray) – The array with the shape of (batch size, state dim) of hidden states of RNN to start training for fixed points.

Returns

fps_and_losses – A tuple of (the fixed points, the optimization losses).

Return type

tuple
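
To make the optimization concrete, the sketch below runs plain gradient descent on the loss 0.5 * ||F(h) - h||^2 for a toy one-step update F. The network weights, the learning rate, and all function names are assumptions for illustration; the finder itself may use a different loss scaling and optimizer.

```python
import numpy as np

# Hypothetical one-step RNN update h -> tanh(W h + b), used only for illustration.
W = np.array([[0.5, -0.3], [0.2, 0.4]])
b = np.array([0.1, -0.2])

def step(h):
    return np.tanh(h @ W.T + b)

def speed_and_grad(h):
    """Return 0.5 * ||F(h) - h||^2 per row and its gradient w.r.t. h."""
    f = step(h)
    q = f - h
    # Chain rule: d(loss)/dh = (J_F - I)^T q, with J_F = diag(1 - f^2) W.
    grad = ((1.0 - f ** 2) * q) @ W - q
    return 0.5 * (q ** 2).sum(axis=1), grad

def optimize_fixed_points_sketch(candidates, lr=0.2, num_opt_max=5000, tol_opt=1e-12):
    h = np.array(candidates, dtype=float)
    for _ in range(num_opt_max):
        losses, grad = speed_and_grad(h)
        if losses.mean() < tol_opt:  # stop once the batch-average loss is small
            break
        h -= lr * grad
    return h, speed_and_grad(h)[0]

rng = np.random.default_rng(0)
fps, losses = optimize_fixed_points_sketch(rng.normal(size=(4, 2)))
```

The returned pair mirrors the (fixed points, optimization losses) tuple described above, with one loss per candidate.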

speed_tolerance_filter(fixed_points)[source]

Filter out fixed points whose speed is larger than a given tolerance.

Parameters

fixed_points (np.ndarray) – The ndarray with shape of (num_point, num_dim).

Returns

fps_and_ids – A 2-tuple of (kept fixed points, ids of kept fixed points).

Return type

tuple
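
A minimal sketch of this filter follows. It assumes the speed of a point x is the squared norm of a vector field f evaluated at x; the helper name, the toy field, and the explicit `tol_speed` argument are illustrative assumptions.

```python
import numpy as np

def speed_tolerance_filter_sketch(f, fixed_points, tol_speed):
    """Keep points whose squared speed ||f(x)||^2 is within `tol_speed`."""
    fixed_points = np.asarray(fixed_points, dtype=float)
    speeds = np.array([np.sum(f(x) ** 2) for x in fixed_points])
    keep_ids = np.where(speeds <= tol_speed)[0]
    return fixed_points[keep_ids], keep_ids

# Toy vector field dx/dt = x - x^3 with true fixed points at -1, 0 and 1.
def f(x):
    return x - x ** 3

candidates = [[1.0], [0.001], [0.5], [-1.0]]
kept, ids = speed_tolerance_filter_sketch(f, candidates, tol_speed=1e-5)
```

The candidate at 0.5 moves too fast to count as a fixed point and is discarded, while the slow point near 0 survives.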

Continuation Analysis

Stability Analysis

get_1d_stability_types()

Get the stability types of 1D system.

get_2d_stability_types()

Get the stability types of 2D system.

get_3d_stability_types()

Get the stability types of 3D system.

stability_analysis(derivatives)

Stability analysis for fixed points.
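
For a 2D system, stability can be read off the trace and determinant of the Jacobian at the fixed point. The sketch below shows that classification; the function name and the returned labels are assumptions and need not match the strings returned by get_2d_stability_types().

```python
def classify_2d_fixed_point(a, b, c, d):
    """Classify the fixed point of a 2D system with Jacobian [[a, b], [c, d]]."""
    trace = a + d
    det = a * d - b * c
    disc = trace ** 2 - 4 * det  # sign decides real vs. complex eigenvalues
    if det < 0:
        return 'saddle'
    if det == 0:
        return 'degenerate'
    if trace == 0:
        return 'center'
    kind = 'node' if disc >= 0 else 'focus'
    return ('stable ' if trace < 0 else 'unstable ') + kind

label = classify_2d_fixed_point(0.1, 1.0, -1.0, 0.1)
```

A negative determinant means eigenvalues of opposite sign (a saddle); otherwise the sign of the trace decides stability and the discriminant separates nodes from foci.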

brainpy.visualization module

Visualization toolkit.

get_figure(row_num, col_num[, row_len, col_len])

Get the constrained_layout figure.

line_plot(ts, val_matrix[, plot_ids, ax, ...])

Show the specified value in the given object (Neurons or Synapses.)

raster_plot(ts, sp_matrix[, ax, marker, ...])

Show the raster plot of the spikes.

animate_2D(values, net_size[, dt, val_min, ...])

Animate the potentials of the neuron group.

animate_1D(dynamical_vars[, static_vars, ...])

Animation of one-dimensional data.

plot_style1([fontsize, axes_edgecolor, ...])

Plot style for publication.

brainpy.tools module

AST-to-Code

ast2code(ast_node[, indent, line_length])

Decompiles an AST into Python code.
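
To show what an AST-to-code round trip looks like, the sketch below uses the standard library's ast.unparse (Python 3.9 or later) rather than ast2code itself, whose exact output format is not documented here. The RenameX transformer is a hypothetical example.

```python
import ast

source = "y = a * x + b"
tree = ast.parse(source)

# Rename every occurrence of the variable `x` to `V` by rewriting the AST.
class RenameX(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id == 'x':
            return ast.copy_location(ast.Name(id='V', ctx=node.ctx), node)
        return node

new_tree = ast.fix_missing_locations(RenameX().visit(tree))
new_source = ast.unparse(new_tree)  # requires Python 3.9 or later
```

Editing the tree and then decompiling it is a common way to rewrite user-defined equations before compilation.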

Code Tools

copy_doc(source_f)

code_lines_to_func(lines, func_name, ...[, ...])

get_identifiers(expr[, include_numbers])

Return all the identifiers in a given string expr, that is, everything that looks like a programming-language variable name, implemented here as the regexp \b[A-Za-z_][A-Za-z0-9_]*\b.

indent(text[, num_tabs, spaces_per_tab, tab])

deindent(text[, num_tabs, spaces_per_tab, ...])

word_replace(expr, substitutions[, exclude_dot])

Applies a dict of word substitutions.

is_lambda_function(func)

Check whether the function is a lambda function.

get_main_code(func[, codes])

Get the code string of the main function.

get_func_source(func)

change_func_name(f, name)
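
The identifier regexp quoted for get_identifiers can be tried out with the standard re module. The two helpers below are hypothetical stand-ins that only approximate get_identifiers and word_replace; options such as include_numbers and exclude_dot are not modelled.

```python
import re

IDENTIFIER = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]*\b')

def identifiers_in(expr):
    """Return the set of identifier-like words found in `expr`."""
    return set(IDENTIFIER.findall(expr))

def replace_words(expr, substitutions):
    """Replace whole words in `expr` according to the `substitutions` dict."""
    def swap(match):
        word = match.group(0)
        return substitutions.get(word, word)
    return IDENTIFIER.sub(swap, expr)

expr = 'dV = (-V + V_rest + R * I) / tau'
names = identifiers_in(expr)
renamed = replace_words(expr, {'V': 'self.V', 'tau': 'self.tau'})
```

Because matching is done on whole words, replacing `V` leaves `dV` and `V_rest` untouched.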

New Dict

DictPlus(*args, **kwargs)

Python dictionaries with advanced dot notation access.

class brainpy.tools.DictPlus(*args, **kwargs)[source]

Python dictionaries with advanced dot notation access.

For example:

>>> d = DictPlus({'a': 10, 'b': 20})
>>> d.a
10
>>> d['a']
10
>>> d.c  # this will raise a KeyError
KeyError: 'c'
>>> d.c = 30  # but you can assign a value to a non-existing item
>>> d.c
30

copy() → a shallow copy of D[source]

setdefault(key, default=None)[source]

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
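
The dot-notation behaviour can be approximated in a few lines by forwarding attribute access to item access. This is only a sketch under that assumption; DotDict is a hypothetical name and DictPlus itself may do more, for example handling nested dictionaries.

```python
class DotDict(dict):
    """A dict whose items can also be read and written as attributes (sketch)."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; a missing key
        # propagates as KeyError, matching the doctest above.
        return self[name]

    def __setattr__(self, name, value):
        self[name] = value

d = DotDict({'a': 10, 'b': 20})
d.c = 30
d.update({'b': 21}, e=5)
```

One trade-off of this design is that raising KeyError instead of AttributeError can confuse tools that probe objects with hasattr().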

Name Checking

check_name(name, obj)

get_name(type)

Other Tools

size2num(size)

brainpy.jaxsetting module

enable_x64(mode)

set_platform(platform)

Changes platform to CPU, GPU, or TPU.

set_host_device_count(n)

By default, XLA considers all CPU cores as one device.
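
One common way to make XLA expose several host devices is the XLA_FLAGS environment variable, as sketched below. Whether set_host_device_count(n) does exactly this is an assumption, and force_host_device_count is a hypothetical helper.

```python
import os

def force_host_device_count(n):
    """Ask XLA to split the host CPU into `n` devices (hypothetical helper)."""
    flag = '--xla_force_host_platform_device_count'
    # Preserve any unrelated flags that are already set.
    existing = [f for f in os.environ.get('XLA_FLAGS', '').split()
                if not f.startswith(flag)]
    os.environ['XLA_FLAGS'] = ' '.join([f'{flag}={n}'] + existing)

# Must run before JAX is imported, because XLA reads the flag at start-up.
force_host_device_count(4)
```

Multiple host devices are mainly useful for testing pmap-style parallel code on a machine without accelerators.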

Release notes

Version 1.1.5

API changes:

  • fix bugs on ndarray import in brainpy.base.function.py

  • add a convenient ‘get_param’ interface in brainpy.simulation.layers

  • add more weight initialization methods

Doc changes:

  • add more examples in README

Version 1.1.4

API changes:

  • add .struct_run() in DynamicalSystem

  • add numpy_array() conversion in brainpy.math.utils module

  • add Adagrad, Adadelta, RMSProp optimizers

  • remove setting methods in brainpy.math.jax module

  • remove import jax in brainpy.__init__.py and enable jax setting, including

    • enable_x64()

    • set_platform()

    • set_host_device_count()

  • enable b=None as no bias in brainpy.simulation.layers

  • set int_ and float_ as default 32 bits

  • remove dtype setting in Initializer constructor

Doc changes:

  • add optimizer in “Math Foundation”

  • add dynamics training docs

  • improve others

Version 1.1.3

  • fix bugs of JAX parallel API imports

  • fix bugs of post_slice structure construction

  • update docs

Version 1.1.2

  • add pre2syn and syn2post operators

  • add verbose and check option to Base.load_states()

  • fix bugs on JIT DynamicalSystem (numpy backend)

Version 1.1.1

  • fix bugs on symbolic analysis: model trajectory

  • change absolute access to relative access in variable saving and loading

  • add UnexpectedTracerError hints in JAX transformation functions

Version 1.1.0

This release is a new version of BrainPy.

Highlights of core changes:

math module

  • support numpy backend

  • support JAX backend

  • support jit, vmap and pmap on class objects on JAX backend

  • support grad, jacobian, hessian on class objects on JAX backend

  • support make_loop, make_while, and make_cond on JAX backend

  • support jit (based on numba) on class objects on numpy backend

  • unified numpy-like ndarray operation APIs

  • numpy-like random sampling APIs

  • FFT functions

  • gradient descent optimizers

  • activation functions

  • loss function

  • backend settings

base module

  • Base for the whole BrainPy ecosystem

  • Function to wrap functions

  • Collector and TensorCollector to collect variables, integrators, nodes and others

integrators module

  • class integrators for ODE numerical methods

  • class integrators for SDE numerical methods

simulation module

  • support modular and composable programming

  • support multi-scale modeling

  • support large-scale modeling

  • support simulation on GPUs

  • fix bugs on firing_rate()

  • remove _i in the update() function and replace it with _dt, meaning the dynamical system has the canonical equation form of \(dx/dt = f(x, t, dt)\)

  • reimplement the input_step and monitor_step in a more intuitive way

  • support setting dt at the level of a single object (i.e., a single instance of DynamicSystem)

  • commonly used DNN layers

  • weight initializations

  • refine synaptic connections
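
The `_dt`-based update() form mentioned in the simulation module notes above can be sketched with a plain Python class. ExponentialDecay and its forward-Euler step are hypothetical and do not use any BrainPy base class.

```python
class ExponentialDecay:
    """Hypothetical dynamical object obeying dx/dt = -x / tau."""

    def __init__(self, x0=1.0, tau=10.0):
        self.x = x0
        self.tau = tau

    def derivative(self, x, t):
        return -x / self.tau

    def update(self, _t, _dt):
        # Forward Euler: the step depends on the state, the time and dt only,
        # so no integer step index `_i` is needed.
        self.x = self.x + _dt * self.derivative(self.x, _t)

decay = ExponentialDecay()
dt = 0.1
for i in range(100):
    decay.update(_t=i * dt, _dt=dt)
```

Passing `_dt` explicitly is also what allows each object to run with its own time step.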

Version 1.0.3

Fix bugs on

  • firing rate measurement

  • stability analysis

Version 1.0.2

This release continues to improve the user-friendliness.

Highlights of core changes:

  • Remove support for Numba-CUDA backend

  • Super initialization super(XXX, self).__init__() can be done anywhere (it is no longer required at the bottom of the __init__() function).

  • Add the output message of the step function running error.

  • More powerful support for Monitoring

  • More powerful support for running order scheduling

  • Remove unsqueeze() and squeeze() operations in brainpy.ops

  • Add reshape() operation in brainpy.ops

  • Improve docs for numerical solvers

  • Improve tests for numerical solvers

  • Add keywords checking in ODE numerical solvers

  • Add more unified operations in brainpy.ops

  • Support “@every” in steps and monitor functions

  • Fix ODE solver bugs for class-bound functions

  • Add build phase in Monitor

Version 1.0.1

  • Fix bugs

Version 1.0.0

  • NEW VERSION OF BRAINPY

  • Change the coding style to object-oriented programming

  • Systematically improve the documentation

Version 0.3.5

  • Add ‘timeout’ in sympy solver in neuron dynamics analysis

  • Reconstruct and generalize phase plane analysis

  • Generalize the repeat mode of Network to support different running durations between two runs

  • Update benchmarks

  • Update detailed documentation

Version 0.3.1

  • Add a more flexible way for NeuState/SynState initialization

  • Fix bugs of “is_multi_return”

  • Add “hand_overs”, “requires” and “satisfies”.

  • Update documentation

  • Auto-transform range to numba.prange

  • Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models

Version 0.3.0

Computation API

  • Rename “brainpy.numpy” to “brainpy.backend”

  • Delete “pytorch”, “tensorflow” backends

  • Add “numba” requirement

  • Add GPU support

Profile setting

  • Delete “backend” profile setting, add “jit”

Core systems

  • Delete “autopep8” requirement

  • Delete the format code prefix

  • Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”

  • Change the “ST” declaration out of “requires”

  • Add “repeat” mode run in Network

  • Change “vector-based” to “mode” in NeuType and SynType definition

Package installation

  • Remove “pypi” installation; installation now relies only on “conda”

Version 0.2.4

API changes

  • Fix bugs

Version 0.2.3

API changes

  • Add “animate_1D” in visualization module

  • Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in inputs module

  • Update phase_portrait_analyzer.py

Models and examples

  • Add CANN examples

Version 0.2.2

API changes

  • Redesign visualization

  • Redesign connectivity

  • Update docs

Version 0.2.1

API changes

  • Fix bugs in numba import

  • Fix bugs in numpy mode with scalar model

Version 0.2.0

API changes

  • For computation: numpy, numba

  • For model definition: NeuType, SynConn

  • For model running: Network, NeuGroup, SynConn, Runner

  • For numerical integration: integrate, Integrator, DiffEquation

  • For connectivity: One2One, All2All, GridFour, grid_four, GridEight, grid_eight, GridN, FixedPostNum, FixedPreNum, FixedProb, GaussianProb, GaussianWeight, DOG

  • For visualization: plot_value, plot_potential, plot_raster, animation_potential

  • For measurement: cross_correlation, voltage_fluctuation, raster_plot, firing_rate

  • For inputs: constant_current, spike_current, ramp_current.

Models and examples

  • Neuron models: HH model, LIF model, Izhikevich model

  • Synapse models: AMPA, GABA, NMDA, STP, GapJunction

  • Network models: gamma oscillation

Indices and tables