BrainPy documentation
BrainPy is a highly flexible and extensible framework targeting general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:
JIT compilation and automatic differentiation for class objects.
Numerical methods for ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), fractional differential equations (FDEs), etc.
Dynamics building with the modular and composable programming interface.
Dynamics simulation for various brain objects with parallel supports.
Dynamics training with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.
Dynamics analysis for low- and high-dimensional systems, including phase plane analysis, bifurcation analysis, linearization analysis, and fixed/slow point finding.
And more.
Installation
BrainPy is designed to run across platforms, including Windows, GNU/Linux, and macOS. It relies only on Python libraries.
Installation with pip
You can install BrainPy from PyPI. To do so, use:
pip install brainpy
To update BrainPy to the latest version, use:
pip install -U brainpy
If you want to install the pre-release version (the latest development version) of BrainPy, you can use:
pip install --pre brainpy
Installation from source
If you decide not to use pip, you can install BrainPy from GitHub or OpenI. To do so, use:
pip install git+https://github.com/PKU-NIP-Lab/BrainPy
# or
pip install git+https://git.openi.org.cn/OpenI/BrainPy
Dependency 1: NumPy
To make BrainPy work normally, users should install several dependent Python packages.
The basic functionality of BrainPy relies only on NumPy, which is easy to install through pip or conda:
pip install numpy
# or
conda install numpy
Dependency 2: JAX
BrainPy relies on JAX. JAX is a high-performance JIT compiler which enables users to run Python code on CPU, GPU, and TPU devices. Core functionalities of BrainPy (>=2.0.0) have been migrated to the JAX backend.
Linux & macOS
Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. The provided binary releases of jax and jaxlib for Linux and macOS systems are available at
for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html
for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
If you want to install a CPU-only version of jax and jaxlib, you can run
pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
If you want to install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN, if they have not already been installed. Next, run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Alternatively, you can download the preferred release “.whl” file for jaxlib from the above release links, and install it via pip:
pip install xxx-0.3.14-xxx.whl
pip install jax==0.3.14
Note
The versions of jaxlib and jax should be consistent. For example, if you are using jax==0.3.14, you should also install jaxlib==0.3.14.
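As a quick sanity check for the note above, the installed versions can be compared with the standard library's importlib.metadata. The helper names below (installed_version, versions_match) are hypothetical, and comparing the full version strings assumes the strict one-to-one pairing described in the note.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def versions_match(jax_version: Optional[str], jaxlib_version: Optional[str]) -> bool:
    """True when both packages are installed and report the same version string."""
    return jax_version is not None and jax_version == jaxlib_version


if __name__ == "__main__":
    jax_v, jaxlib_v = installed_version("jax"), installed_version("jaxlib")
    print(f"jax={jax_v}, jaxlib={jaxlib_v}, consistent={versions_match(jax_v, jaxlib_v)}")
```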
Windows
For Windows users, jax and jaxlib can be installed from community-supported builds. Specifically, you can install them through:
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html
If you are using a GPU, you can install GPU-versioned wheels through:
pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html
Alternatively, you can manually install your favourite version of jax and jaxlib by downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html. Then install them via pip:
pip install xxx-0.3.14-xxx.whl
pip install jax==0.3.14
WSL
Moreover, on Windows 10 or later, we recommend using the Windows Subsystem for Linux (WSL). The installation guide can be found in WSL Installation Guide for Windows 10/11. Then, you can install JAX in WSL just as you would on Linux/macOS.
Dependency 3: brainpylib
Many customized operators in BrainPy are implemented in brainpylib. brainpylib can also be installed through PyPI:
pip install brainpylib
For GPU operators, you should compile brainpylib from source. For details, please see Compile GPU operators in brainpylib.
Other Dependencies
To get full support for BrainPy, we recommend installing the following packages:
Numba: needed in some NumPy-based computations
pip install numba
# or
conda install numba
matplotlib: required in some visualization functions, but now it is recommended that users explicitly import matplotlib for visualization
pip install matplotlib
# or
conda install matplotlib
Simulating a Brain Dynamics Model
One of the most important approaches to studying brain dynamics is to build a dynamic model and simulate it. Generally, there are two ways to construct a dynamic model. The first is spiking models, which attempt to finely simulate the activity of each neuron in the target population. They are named spiking models because the simulation records the precise spike timing of every neuron. The second is rate models, which regard a population of neurons with similar properties as a single firing unit and examine the firing rate of this population. In this section, we will illustrate how to build and simulate a spiking neural network, i.e., an SNN.
To begin with, the BrainPy package should be imported:
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.8'
Simulating an E-I balanced network
Building an E-I balanced network
First, let’s try to build an E-I balanced network. It was proposed to explain the irregular firing of neurons in cortical areas [1]. Since the structure of an E-I balanced network is relatively simple, it is a good exercise for learning the basic paradigm of brain dynamics simulation in BrainPy. The structure of an E-I balanced network is as follows:

An E-I balanced network is composed of two neuron groups and the synaptic connections between them. Specifically, they include:
a group of excitatory neurons, \(\mathrm{E}\),
a group of inhibitory neurons, \(\mathrm{I}\),
synaptic connections within the excitatory and inhibitory neuron groups, respectively, and
the inter-connections between these two groups.
To construct the network, we need to define these components one by one. BrainPy provides plenty of handy built-in models for brain dynamics simulation. They are contained in brainpy.dyn. Let’s choose the simplest yet most canonical neuron model, the Leaky Integrate-and-Fire (LIF) model, to build the excitatory and inhibitory neuron groups:
E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60.,
tau=20., tau_ref=5., method='exp_auto',
V_initializer=bp.init.Normal(-60., 2.))
I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60.,
tau=20., tau_ref=5., method='exp_auto',
V_initializer=bp.init.Normal(-60., 2.))
When defining the LIF neuron group, the parameters can be tuned according to users’ needs. The first parameter denotes the number of neurons; here the ratio of excitatory to inhibitory neurons is set to 4:1. V_rest denotes the resting potential, V_th the firing threshold, V_reset the reset value after firing, tau the membrane time constant, and tau_ref the duration of the refractory period. method refers to the numerical integration method used in the simulation.
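To make the roles of these parameters concrete, here is a plain-Python sketch of a single LIF neuron (not BrainPy's implementation; simulate_lif is a hypothetical helper). It assumes the standard LIF equation tau * dV/dt = -(V - V_rest) + R * I with R = 1, ignores synaptic input, and uses exponential Euler integration, which is exact for this linear equation under a constant input (method='exp_auto' selects an exponential-Euler-type integrator in BrainPy).

```python
import math


def simulate_lif(I_ext, duration=100., dt=0.1, V_rest=-60., V_th=-50., V_reset=-60.,
                 tau=20., tau_ref=5., R=1.):
    """Simulate one LIF neuron driven by a constant current; return spike times (ms)."""
    V = V_rest
    t_last_spike = -1e7          # time of the most recent spike
    spike_times = []
    decay = math.exp(-dt / tau)  # exact decay factor over one step
    for step in range(round(duration / dt)):
        t = step * dt
        if t - t_last_spike < tau_ref:
            continue             # refractory: V stays at V_reset
        # exponential Euler: V relaxes toward V_inf = V_rest + R * I_ext
        V_inf = V_rest + R * I_ext
        V = V_inf + (V - V_inf) * decay
        if V >= V_th:            # threshold crossing -> spike and reset
            spike_times.append(t)
            V = V_reset
            t_last_spike = t
    return spike_times
```

With an input of 20, the steady-state potential V_rest + 20 = -40 mV lies above V_th, so the neuron fires regularly; with an input of 5 it settles at -55 mV and never fires.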
Then the synaptic connections between these two groups can be defined as follows:
E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6,
tau=5., output=bp.synouts.COBA(E=0.),
method='exp_auto')
E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6,
tau=5., output=bp.synouts.COBA(E=0.),
method='exp_auto')
I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7,
tau=10., output=bp.synouts.COBA(E=-80.),
method='exp_auto')
I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7,
tau=10., output=bp.synouts.COBA(E=-80.),
method='exp_auto')
Here we use the Exponential synapse model (bp.synapses.Exponential) to simulate synaptic connections. Among the parameters of the model, the first two denote the pre- and post-synaptic neuron groups, respectively. The third one specifies the connection type. In this example, we use bp.conn.FixedProb, which connects presynaptic neurons to postsynaptic neurons with a given probability (detailed information is available in Synaptic Connection). The following three parameters (g_max, tau, and output) describe the dynamic properties of the synapse, and the last one is the numerical integration method, as in the LIF model.
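The plain-Python sketch below illustrates two ideas from this paragraph under simplifying assumptions: fixed-probability connectivity (each pre/post pair is connected independently with probability prob) and an exponential synapse whose conductance jumps by g_max on each presynaptic spike, decays with time constant tau, and drives a conductance-based (COBA) current g * (E - V). The function names are hypothetical and this is not BrainPy's code; BrainPy may differ in details such as the handling of self-connections or the sampling procedure.

```python
import math
import random


def fixed_prob_conn(num_pre, num_post, prob, seed=0):
    """Return (pre, post) index pairs, each pair included independently with `prob`."""
    rng = random.Random(seed)
    return [(i, j) for i in range(num_pre) for j in range(num_post)
            if rng.random() < prob]


def exp_synapse_step(g, pre_spiked, g_max, tau, dt):
    """One step of an exponential synapse: decay, then add g_max per presynaptic spike."""
    g = g * math.exp(-dt / tau)
    return g + g_max * pre_spiked


def coba_current(g, E, V):
    """Conductance-based output: the current pushes V toward the reversal potential E."""
    return g * (E - V)
```

For example, an inhibitory synapse with E=-80 produces a negative (hyperpolarizing) current whenever the membrane potential is above -80 mV.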
After defining all the components, they can be combined to form a network:
net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I)
In the definition, neurons and synapses are given to the network. The excitatory and inhibitory neuron groups (E and I) are passed with names because they will be specifically operated on in the simulation (here they will be given input currents).
We have successfully constructed an E-I balanced network using BrainPy’s built-in models. BrainPy also enables users to flexibly customize their own dynamic models, such as neuron groups, synapses, and networks. In fact, brainpy.dyn.Network() is a simple example of a customized network model. Please refer to Dynamic Simulation for more information.
Running a simulation
After building an SNN, we can use it for dynamics simulation. To run a simulation, we first need to wrap the network model into a runner. BrainPy provides DSRunner in brainpy.dyn, which is covered in more detail in the Runners tutorial. Users can initialize DSRunner as follows:
runner = bp.DSRunner(net,
monitors=['E.spike', 'I.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)],
dt=0.1)
To make dynamics simulation more applicable and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the spike variable of the E and I LIF models, which records the spiking status of each neuron group, and give a constant input to both neuron groups. The time step of numerical integration, dt (with a default value of 0.1), can also be specified.
For more details on giving inputs and setting monitors, please refer to Dynamic Simulation.
After creating the runner, we can run a simulation by calling the runner:
runner.run(100)
where the argument is the simulation duration (usually in milliseconds). BrainPy accelerates simulation with just-in-time (JIT) compilation; please refer to Just-In-Time Compilation for more details.
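Conceptually, a runner repeatedly advances the model by one step of size dt and appends the monitored variables at every step, so runner.run(100) with dt=0.1 performs round(100 / 0.1) = 1000 updates. The loop below is a hypothetical, uncompiled sketch of that idea (run, leaky_update, and the state dictionary are illustrative only); DSRunner additionally JIT-compiles the loop and resolves inputs and monitors by name, and its exact time stamps may differ.

```python
def run(update, state, duration, dt, monitors):
    """Advance `state` with `update` and record the monitored keys at every step."""
    num_steps = round(duration / dt)
    history = {"ts": []}
    history.update({key: [] for key in monitors})
    for i in range(num_steps):
        t = i * dt
        state = update(state, t, dt)
        history["ts"].append(t)
        for key in monitors:
            history[key].append(state[key])
    return history


def leaky_update(state, t, dt, tau=20., I=20.):
    """Toy update: a leaky variable driven by a constant input (forward Euler)."""
    v = state["v"] + dt * (-state["v"] + I) / tau
    return {"v": v}


if __name__ == "__main__":
    mon = run(leaky_update, {"v": 0.0}, duration=100., dt=0.1, monitors=["v"])
    print(len(mon["ts"]), mon["v"][-1])
```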
The simulation results are stored as NumPy arrays in the monitors, and can be visualized easily:
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4.5))
plt.subplot(121)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=False)
plt.subplot(122)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], show=True)

In the code above, brainpy.visualize contains some useful functions for visualizing simulation results based on the matplotlib package. Since the simulation results are stored as NumPy arrays, users can also use matplotlib directly for visualization.
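A raster plot draws a dot at (time, neuron index) for every True entry of the monitored spike matrix, whose rows are time steps and whose columns are neurons. The hypothetical helper below extracts those coordinates in plain Python; it sketches the idea behind a raster plot and does not reproduce bp.visualize.raster_plot itself.

```python
def raster_coordinates(ts, spike_matrix):
    """Return (times, neuron_indices) of all spikes in a (num_steps x num_neurons) matrix."""
    times, indices = [], []
    for t, row in zip(ts, spike_matrix):
        for neuron, spiked in enumerate(row):
            if spiked:
                times.append(t)
                indices.append(neuron)
    return times, indices
```

With NumPy arrays such as runner.mon['E.spike'], the same coordinates can be obtained with numpy.nonzero (row indices map to time steps, column indices to neurons) and passed to plt.scatter.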
Simulating a decision-making network
Building a decision-making network
After learning how to build an E-I balanced network, we can try a more complex model. In 2002, Wang proposed a decision-making model that could choose between two conflicting inputs by accumulating evidence over time [2].
The structure of a decision-making network is as follows. Similar to the E-I balanced network, the decision-making network contains an excitatory and an inhibitory neuron group, with connections within each group and between them. The difference is that there are two specific subpopulations of neurons, A and B, that receive conflicting inputs from outside (other brain areas). After the external inputs are given, if the activity of A prevails over that of B, the network chooses option A, and vice versa.

To construct a decision-making network, we should build all neuron groups:
Two excitatory neuron groups with different selectivity, \(\mathrm{A}\) and \(\mathrm{B}\), and other excitatory neurons, \(\mathrm{N}\);
An inhibitory neuron group, \(\mathrm{I}\);
Neurons generating external inputs \(\mathrm{I_A}\) and \(\mathrm{I_B}\);
Neurons generating noise to all neuron groups, \(\mathrm{noise_A}\), \(\mathrm{noise_B}\), \(\mathrm{noise_N}\), and \(\mathrm{noise_I}\).
And the synaptic connections between them:
Connection from excitatory neurons to others, \(\mathrm{A2A}\), \(\mathrm{A2B}\), \(\mathrm{A2N}\), \(\mathrm{A2I}\), \(\mathrm{B2A}\), \(\mathrm{B2B}\), \(\mathrm{B2N}\), \(\mathrm{B2I}\), and \(\mathrm{N2A}\), \(\mathrm{N2B}\), \(\mathrm{N2N}\), \(\mathrm{N2I}\);
Connection from inhibitory neurons to others, \(\mathrm{I2A}\), \(\mathrm{I2B}\), \(\mathrm{I2N}\), \(\mathrm{I2I}\);
Connection from external inputs to selective neuron groups, \(\mathrm{IA2A}\), \(\mathrm{IB2B}\);
Connection from noise neurons to excitatory and inhibitory neurons, \(\mathrm{noise2A}\), \(\mathrm{noise2B}\), \(\mathrm{noise2N}\), \(\mathrm{noise2I}\).
Now let’s build these neuron groups and connections.
First of all, to imitate the biophysical experiments, we define three periods:
pre_stimulus_period = 100.  # time before the external stimuli are given
stimulus_period = 1000.  # time within which the external stimuli are given
delay_period = 500.  # time after the external stimuli are removed
total_period = pre_stimulus_period + stimulus_period + delay_period
To build \(\mathrm{I_A}\) and \(\mathrm{I_B}\), we define a class of neuron groups that generate stochastic Poisson stimuli. Custom neuron groups should inherit from brainpy.dyn.NeuGroup.
class PoissonStim(bp.NeuGroup):
    def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs):
        super(PoissonStim, self).__init__(size=size, **kwargs)
        # initialize parameters
        self.freq_mean = freq_mean
        self.freq_var = freq_var
        self.t_interval = t_interval
        # initialize variables
        self.freq = bm.Variable(bm.zeros(1))
        self.freq_t_last_change = bm.Variable(bm.ones(1) * -1e7)
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.rng = bm.random.RandomState()

    def update(self, tdi):
        in_interval = bm.logical_and(pre_stimulus_period < tdi.t, tdi.t < pre_stimulus_period + stimulus_period)
        freq = bm.where(in_interval, self.freq[0], 0.)
        change = bm.logical_and(in_interval, (tdi.t - self.freq_t_last_change[0]) >= self.t_interval)
        self.freq[:] = bm.where(change, self.rng.normal(self.freq_mean, self.freq_var), freq)
        self.freq_t_last_change[:] = bm.where(change, tdi.t, self.freq_t_last_change[0])
        self.spike.value = self.rng.random(self.num) < self.freq[0] * tdi.dt / 1000.
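The last line of update uses the usual Bernoulli approximation of a Poisson process: in each step of length dt (ms), a neuron firing at freq Hz spikes with probability freq * dt / 1000. The plain-Python check below (poisson_spike_counts is a hypothetical helper, independent of BrainPy) confirms that this rule yields about freq spikes per second of simulated time.

```python
import random


def poisson_spike_counts(num_neurons, freq, duration, dt, seed=0):
    """Count spikes per neuron when each step spikes with probability freq * dt / 1000."""
    rng = random.Random(seed)
    p = freq * dt / 1000.            # freq in Hz, dt in ms
    num_steps = round(duration / dt)
    return [sum(rng.random() < p for _ in range(num_steps))
            for _ in range(num_neurons)]
```

For 40 Hz over 1000 ms, the average count should be close to 40; the approximation is accurate as long as freq * dt / 1000 is much smaller than 1.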
Because there are so many neuron groups and connections, it is much clearer to define a new network class that inherits from brainpy.dyn.Network to accommodate all these neurons and synapses:
class DecisionMaking(bp.Network):
    def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, dt=bm.get_dt()):
        super(DecisionMaking, self).__init__()
        # initialize neuron-group parameters
        num_exc = int(1600 * scale)
        num_inh = int(400 * scale)
        num_A = int(f * num_exc)
        num_B = int(f * num_exc)
        num_N = num_exc - num_A - num_B
        poisson_freq = 2400.  # Hz
        # initialize synapse parameters
        w_pos = 1.7
        w_neg = 1. - f * (w_pos - 1.) / (1. - f)
        g_ext2E_AMPA = 2.1  # nS
        g_ext2I_AMPA = 1.62  # nS
        g_E2E_AMPA = 0.05 / scale  # nS
        g_E2I_AMPA = 0.04 / scale  # nS
        g_E2E_NMDA = 0.165 / scale  # nS
        g_E2I_NMDA = 0.13 / scale  # nS
        g_I2E_GABAa = 1.3 / scale  # nS
        g_I2I_GABAa = 1.0 / scale  # nS
        # parameters of the AMPA synapse
        ampa_par = dict(delay_step=int(0.5 / dt), tau=2.0, output=bp.synouts.COBA(E=0.))
        # parameters of the GABA synapse
        gaba_par = dict(delay_step=int(0.5 / dt), tau=5.0, output=bp.synouts.COBA(E=-70.))
        # parameters of the NMDA synapse
        nmda_par = dict(delay_step=int(0.5 / dt), tau_decay=100, tau_rise=2.,
                        a=0.5, output=bp.synouts.MgBlock(E=0., cc_Mg=1.))
        # excitatory and inhibitory neuron groups, A, B, N, and I
        A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        I = bp.neurons.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05,
                           tau_ref=1., V_initializer=bp.init.OneInit(-70.))
        # neurons generating external inputs, I_A and I_B
        IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence)
        IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence)
        # noise neurons
        self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq)
        self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq)
        self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq)
        self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq)
        # connection from excitatory neurons to others
        self.N2B_AMPA = bp.synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
        self.N2A_AMPA = bp.synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
        self.N2N_AMPA = bp.synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
        self.N2I_AMPA = bp.synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
        self.N2B_NMDA = bp.synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
        self.N2A_NMDA = bp.synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
        self.N2N_NMDA = bp.synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
        self.N2I_NMDA = bp.synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)
        self.B2B_AMPA = bp.synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par)
        self.B2A_AMPA = bp.synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
        self.B2N_AMPA = bp.synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
        self.B2I_AMPA = bp.synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
        self.B2B_NMDA = bp.synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par)
        self.B2A_NMDA = bp.synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
        self.B2N_NMDA = bp.synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
        self.B2I_NMDA = bp.synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)
        self.A2B_AMPA = bp.synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
        self.A2A_AMPA = bp.synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par)
        self.A2N_AMPA = bp.synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
        self.A2I_AMPA = bp.synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
        self.A2B_NMDA = bp.synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
        self.A2A_NMDA = bp.synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par)
        self.A2N_NMDA = bp.synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
        self.A2I_NMDA = bp.synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)
        # connection from inhibitory neurons to others
        self.I2B = bp.synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
        self.I2A = bp.synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
        self.I2N = bp.synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
        self.I2I = bp.synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, **gaba_par)
        # connection from external inputs to selective neuron groups
        self.IA2A = bp.synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
        self.IB2B = bp.synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
        # connection from noise neurons to excitatory and inhibitory neurons
        self.noise2B = bp.synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
        self.noise2A = bp.synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
        self.noise2N = bp.synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
        self.noise2I = bp.synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, **ampa_par)
        # add A, B, I, N to the class
        self.A = A
        self.B = B
        self.N = N
        self.I = I
        self.IA = IA
        self.IB = IB
Though the code seems longer than the E-I balanced network, the basic building paradigm is the same: building neuron groups and the connections among them.
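It can help to evaluate a few of the derived quantities by hand. The snippet below repeats formulas from __init__ in plain Python with the default arguments (scale=1., mu0=40., coherence=25.6, f=0.15). With f = 0.15, each selective group receives 240 of the 1600 excitatory neurons; w_neg satisfies f * w_pos + (1 - f) * w_neg = 1; and coherence shifts the mean Poisson rates of \(\mathrm{I_A}\) and \(\mathrm{I_B}\) symmetrically around mu0.

```python
# Derived quantities of DecisionMaking with its default arguments (plain-Python check)
scale, mu0, coherence, f = 1., 40., 25.6, 0.15

num_exc = int(1600 * scale)            # 1600 excitatory neurons
num_inh = int(400 * scale)             # 400 inhibitory neurons
num_A = int(f * num_exc)               # 240 neurons selective to A
num_B = int(f * num_exc)               # 240 neurons selective to B
num_N = num_exc - num_A - num_B        # 1120 non-selective neurons

w_pos = 1.7
w_neg = 1. - f * (w_pos - 1.) / (1. - f)   # about 0.8765

freq_A = mu0 + mu0 / 100. * coherence  # mean rate of I_A, about 50.24 Hz
freq_B = mu0 - mu0 / 100. * coherence  # mean rate of I_B, about 29.76 Hz

print(num_A, num_B, num_N, num_inh, round(w_neg, 4), round(freq_A, 2), round(freq_B, 2))
```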
Running a simulation
After building the network, the simulation process is much the same as for the E-I balanced network. First, we wrap the network into a runner:
net = DecisionMaking(scale=1., coherence=25.6, mu0=40.)
runner = bp.DSRunner(net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'])
Then we call the runner to run the simulation:
runner.run(total_period)
Finally, we visualize the simulation results using matplotlib:
fig, gs = plt.subplots(4, 1, figsize=(10, 12), sharex='all')
t_start = 0.
# the raster plot of A
plt.sca(gs[0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['A.spike'], markersize=1)
plt.title("Spiking activity of group A")
plt.ylabel("Neuron Index")
# the raster plot of B
plt.sca(gs[1])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['B.spike'], markersize=1)
plt.title("Spiking activity of group B")
plt.ylabel("Neuron Index")
# the firing rate of A and B
plt.sca(gs[2])
rateA = bp.measure.firing_rate(runner.mon['A.spike'], width=10.)
rateB = bp.measure.firing_rate(runner.mon['B.spike'], width=10.)
plt.plot(runner.mon.ts, rateA, label="Group A")
plt.plot(runner.mon.ts, rateB, label="Group B")
plt.ylabel('Firing rate [Hz]')
plt.title("Population activity")
plt.legend()
# the external stimuli
plt.sca(gs[3])
plt.plot(runner.mon.ts, runner.mon['IA.freq'], label="group A")
plt.plot(runner.mon.ts, runner.mon['IB.freq'], label="group B")
plt.title("Input activity")
plt.ylabel("Firing rate [Hz]")
plt.legend()
# mark the onset and offset of the stimulus period in every panel
for i in range(4):
    gs[i].axvline(pre_stimulus_period, linestyle='dashed', color=u'#444444')
    gs[i].axvline(pre_stimulus_period + stimulus_period, linestyle='dashed', color=u'#444444')
plt.xlim(t_start, total_period + 1)
plt.xlabel("Time [ms]")
plt.tight_layout()
plt.show()

For more information about brain dynamic simulation, please refer to Dynamics Simulation in the BDP tutorial.
Simulating a firing rate-based network
Neural mass model
A neural mass model is a low-dimensional population model of spiking neural networks. It aims to describe the coarse-grained activity of large populations of neurons and synapses. Mathematically, it is a dynamical system of nonlinear ODEs. A classical neural mass model is the two-dimensional Wilson–Cowan model, which tracks the activity of an excitatory population of neurons coupled to an inhibitory population. When augmented with more realistic forms of synaptic and network interaction, such models have proved especially successful in fitting neuroimaging data.
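For intuition, here is a plain-Python Euler integration of the classic Wilson–Cowan equations, tau_e dE/dt = -E + (1 - E) S(w_ee E - w_ei I + P) and tau_i dI/dt = -I + (1 - I) S(w_ie E - w_ii I + Q), with the shifted sigmoid S(u) = 1/(1 + exp(-a(u - theta))) - 1/(1 + exp(a theta)), which vanishes at u = 0. This form, the weight names, and any correspondence to the arguments of bp.rates.WilsonCowanModel are assumptions for illustration; BrainPy's model may differ in its details.

```python
import math


def sigmoid(u, a=1.5, theta=3.):
    """Shifted logistic function, zero at u = 0 (classic Wilson-Cowan choice)."""
    return 1. / (1. + math.exp(-a * (u - theta))) - 1. / (1. + math.exp(a * theta))


def wilson_cowan(E0, I0, P=0., Q=0., w_ee=16., w_ei=12., w_ie=15., w_ii=3.,
                 tau_e=1., tau_i=1., dt=0.01, duration=10.):
    """Euler-integrate the classic Wilson-Cowan equations; return the E and I traces."""
    E, I = E0, I0
    Es, Is = [E], [I]
    for _ in range(round(duration / dt)):
        dE = (-E + (1. - E) * sigmoid(w_ee * E - w_ei * I + P)) / tau_e
        dI = (-I + (1. - I) * sigmoid(w_ie * E - w_ii * I + Q)) / tau_i
        E, I = E + dt * dE, I + dt * dI
        Es.append(E)
        Is.append(I)
    return Es, Is
```

The factors (1 - E) and (1 - I) keep both activities below 1, which is why these variables are usually read as the fraction of active cells in each population.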
Here, let’s try the Wilson-Cowan model.
wc = bp.rates.WilsonCowanModel(2,
wEE=16., wIE=15., wEI=12., wII=3.,
E_a=1.5, I_a=1.5, E_theta=3., I_theta=3.,
method='exp_euler_auto',
x_initializer=bm.asarray([-0.2, 1.]),
y_initializer=bm.asarray([0.0, 1.]))
runner = bp.DSRunner(wc, monitors=['x', 'y'], inputs=['input', -0.5])
runner.run(10.)
fig, gs = bp.visualize.get_figure(1, 2, 4, 3)
ax = fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, plot_ids=[0, 1], legend='e', ax=ax)
ax = fig.add_subplot(gs[0, 1])
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, plot_ids=[0, 1], legend='i', ax=ax, show=True)

We can see that this model has at least two stable states.
Bifurcation diagram
With the automatic analysis module in BrainPy, we can easily inspect the bifurcation diagram of the model. Bifurcation diagrams give an overview of how different parameters of the model affect its dynamics (for details of BrainPy’s automatic analysis support, see the introduction in Analyzing a Dynamical Model and the tutorials in Dynamics Analysis). In this case, we use x_ext as the bifurcation parameter and examine how the system behavior changes as x_ext varies.
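The idea behind a bifurcation diagram can be illustrated on a one-dimensional toy system, dx/dt = p + x - x^3/3, which is an assumed example and not one of BrainPy's models. For each value of the parameter p on a grid, we locate fixed points by scanning for sign changes and refining them by bisection, then classify stability from the sign of the derivative 1 - x^2. BrainPy's Bifurcation2D handles two-dimensional systems and finds fixed points by optimization (as its log output indicates), so this is only a sketch of the concept.

```python
def f(x, p):
    """Right-hand side of the toy system dx/dt = p + x - x**3 / 3."""
    return p + x - x ** 3 / 3.


def fixed_points(p, x_min=-3., x_max=3., num=600, iters=60):
    """Find fixed points of f(., p) in [x_min, x_max] by sign scan plus bisection."""
    def positive(v):
        return v >= 0.

    xs = [x_min + (x_max - x_min) * k / num for k in range(num + 1)]
    roots = []
    for a, b in zip(xs[:-1], xs[1:]):
        if positive(f(a, p)) == positive(f(b, p)):
            continue                      # no sign change in [a, b]
        for _ in range(iters):            # bisection keeps the sign change inside [a, b]
            mid = 0.5 * (a + b)
            if positive(f(mid, p)) == positive(f(a, p)):
                a = mid
            else:
                b = mid
        roots.append(0.5 * (a + b))
    return roots


def is_stable(x):
    """A fixed point of a 1-D flow is stable when df/dx = 1 - x**2 is negative."""
    return 1. - x ** 2 < 0.


# sweep the parameter to collect the data of a bifurcation diagram
diagram = {round(p / 100., 2): [(x, is_stable(x)) for x in fixed_points(p / 100.)]
           for p in range(-100, 101, 5)}
```

At p = 0 there are three fixed points (two stable ones at ±√3 and an unstable one at 0), while for |p| > 2/3 only one stable fixed point remains; plotting x against p from diagram gives the familiar S-shaped curve with two saddle-node bifurcations.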
bf = bp.analysis.Bifurcation2D(
wc,
target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
target_pars={'x_ext': [-2, 2]},
pars_update={'y_ext': 0.},
resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 40000 candidates
I am trying to filter out duplicate fixed points ...
Found 579 fixed points.
I am plotting the limit cycle ...


Similarly, simulating and analyzing a rate-based FitzHugh-Nagumo model is also a piece of cake with BrainPy.
fhn = bp.rates.FHN(1, method='exp_auto')
bf = bp.analysis.Bifurcation2D(
fhn,
target_vars={'x': [-2, 2], 'y': [-2, 2]},
target_pars={'x_ext': [0, 2]},
pars_update={'y_ext': 0.},
resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 20000 candidates
I am trying to filter out duplicate fixed points ...
Found 200 fixed points.
I am plotting the limit cycle ...


In this model, we find that when the external input x_ext lies in [0.72, 1.4], the model generates limit cycles. We can verify this by simulation.
runner = bp.DSRunner(fhn, monitors=['x', 'y'], inputs=['input', 1.0])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y', show=True)

Whole-brain model
A rate-based whole-brain model is a network model consisting of coupled brain regions. Each brain region is represented by a neural mass model, which is connected to other brain regions according to the underlying network structure of the brain, also known as the connectome. To illustrate BrainPy’s support for whole-brain modeling, we provide a processed dataset at the following link:
A processed data from ConnectomeDB of the Human Connectome Project (HCP): https://share.weiyun.com/wkPpARKy
Please download the dataset and place it at your preferred PATH.
PATH = './data/hcp.npz'
In general, a dataset for whole-brain modeling consists of the following parts:
1. A structural connectivity matrix, which captures the synaptic connection strengths between brain areas. It is often derived from DTI tractography of the whole brain. The connectome is then typically parcellated in a preferred atlas (for example, the AAL2 atlas), and the number of axonal fibers connecting each brain area with every other area is counted. This number serves as an indication of the synaptic coupling strengths between the areas of the brain.
2. A delay matrix, calculated from the average length of the axonal fibers connecting each pair of brain areas.
3. A set of functional data that can act as a target for model optimization. Resting-state fMRI offers an easy and fairly unbiased way for calibrating whole-brain models. EEG data could be used as well.
Now, let’s load the dataset.
data = bm.load(PATH)
# The structural connectivity matrix
data['Cmat'].shape
(80, 80)
# The fiber length matrix
data['Dmat'].shape
(80, 80)
# The functional data for 7 subjects
data['FCs'].shape
(7, 80, 80)
Let’s have a look at what the data look like.
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'plasma'
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.subplots_adjust(wspace=0.28)
im = axs[0].imshow(data['Cmat'])
axs[0].set_title("Connection matrix")
fig.colorbar(im, ax=axs[0], fraction=0.046, pad=0.04)
im = axs[1].imshow(data['Dmat'], cmap='inferno')
axs[1].set_title("Fiber length matrix")
fig.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)
im = axs[2].imshow(data['FCs'][0], cmap='inferno')
axs[2].set_title("Empirical FC of subject 1")
fig.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04)
plt.show()

Let’s first compute the delay matrix from the fiber length matrix, the signal transmission speed between areas, and the numerical integration step dt. Here, we assume that the axonal transmission speed is 20 and that the simulation time step is dt=0.1 ms.
signal_speed = 20.
# the number of the delay steps
delay_mat = data['Dmat'] / signal_speed / bm.get_dt()
delay_mat = bm.asarray(delay_mat, dtype=bm.int_)
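One detail worth keeping in mind: converting to an integer dtype truncates toward zero, so floating-point rounding can shave off one step. The plain-Python sketch below (both helper names are hypothetical) shows this with a fiber length of 6 and the same speed and dt as above; rounding before the conversion may be preferable if exact step counts matter.

```python
def delay_steps_truncated(length, speed, dt):
    """Delay in integration steps, truncated like an integer-dtype cast."""
    return int(length / speed / dt)


def delay_steps_rounded(length, speed, dt):
    """Delay in integration steps, rounded to the nearest step."""
    return round(length / speed / dt)


# 6 / 20 / 0.1 evaluates to 2.9999999999999996 in floating point
print(delay_steps_truncated(6., 20., 0.1), delay_steps_rounded(6., 20., 0.1))  # prints: 2 3
```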
The connectivity matrix can be obtained directly by multiplying the structural connectivity matrix by a global coupling strength parameter gc:
gc = 1.
conn_mat = bm.asarray(data['Cmat'] * gc)
# It is necessary to exclude the self-connections
bm.fill_diagonal(conn_mat, 0)
We are now ready to instantiate a whole-brain model with the neural mass model and the dataset processed above.
class WholeBrainNet(bp.Network):
    def __init__(self, Cmat, Dmat):
        super(WholeBrainNet, self).__init__()
        self.fhn = bp.rates.FHN(
            80,
            x_ou_sigma=0.01,
            y_ou_sigma=0.01,
            method='exp_auto'
        )
        self.syn = bp.synapses.DiffusiveCoupling(
            self.fhn.x,
            self.fhn.x,
            var_to_output=self.fhn.input,
            conn_mat=Cmat,
            delay_steps=Dmat.astype(bm.int_),
            initial_delay_data=bp.init.Uniform(0, 0.05)
        )
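Diffusive coupling is commonly written as input_i(t) = sum_j C[i][j] * (x_j(t - D[i][j]) - x_i(t)), so each region is pulled toward the delayed activity of the regions connected to it. Assuming DiffusiveCoupling follows this standard form (its API reference gives the exact definition and index convention), the computation can be sketched in plain Python as follows; diffusive_input is a hypothetical helper, and history[k] holds the state vector from k steps ago.

```python
def diffusive_input(history, C, D):
    """Coupling input to each region from delayed neighbours.

    history[k] is the state vector k steps in the past (history[0] is the present),
    C[i][j] is the coupling strength from region j to region i, and
    D[i][j] is the corresponding delay in integration steps.
    """
    x_now = history[0]
    n = len(x_now)
    return [sum(C[i][j] * (history[D[i][j]][j] - x_now[i]) for j in range(n))
            for i in range(n)]
```

Because the diagonal of the connectivity matrix was set to zero above, a region never couples to its own past activity in this formulation.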
net = WholeBrainNet(conn_mat, delay_mat)
runner = bp.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
The simulated results can be used to estimate the functional connectivity matrix.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
fc = bp.measure.functional_connectivity(runner.mon['fhn.x'])
im = axs[0].imshow(fc)
plt.colorbar(im, ax=axs[0])
axs[1].plot(runner.mon.ts, runner.mon['fhn.x'][:, ::5], alpha=0.8)
plt.tight_layout()
plt.show()

We can compute the element-wise Pearson correlation of the functional connectivity matrices of the simulated data to the empirical data to estimate how well the model captures the inter-areal functional correlations found in empirical resting-state recordings.
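Both measures reduce to Pearson correlations: a functional connectivity matrix holds the correlation between the time series of every pair of regions, and comparing two FC matrices amounts to correlating their entries. The plain-Python sketch below assumes that only the upper-triangular, off-diagonal entries are compared, since FC matrices are symmetric with a unit diagonal; the helpers merely mirror the names used above, and bp.measure may differ in such details.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def functional_connectivity(signals):
    """signals[t][i] is region i at time t; return the region-by-region correlation matrix."""
    series = list(zip(*signals))          # one time series per region
    n = len(series)
    return [[pearson(series[i], series[j]) for j in range(n)] for i in range(n)]


def matrix_correlation(m1, m2):
    """Correlate the upper-triangular, off-diagonal entries of two square matrices."""
    n = len(m1)
    upper1 = [m1[i][j] for i in range(n) for j in range(i + 1, n)]
    upper2 = [m2[i][j] for i in range(n) for j in range(i + 1, n)]
    return pearson(upper1, upper2)
```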
scores = [bp.measure.matrix_correlation(fc, fcemp)
for fcemp in data['FCs']]
print("Correlation per subject:", [f"{s:.2}" for s in scores])
print("Mean FC/FC correlation: {:.2f}".format(bm.mean(bm.asarray(scores))))
Correlation per subject: ['0.57', '0.46', '0.56', '0.5', '0.56', '0.48', '0.46']
Mean FC/FC correlation: 0.51
References
van Vreeswijk, C., & Sompolinsky, H. (1996). Chaos in neuronal networks with balanced excitatory and inhibitory activity. Science (New York, N.Y.), 274(5293), 1724–1726. https://doi.org/10.1126/science.274.5293.1724
Wang X. J. (2002). Probabilistic decision making by slow reverberation in cortical circuits. Neuron, 36(5), 955–968. https://doi.org/10.1016/s0896-6273(02)01092-9
Training a Brain Dynamics Model
In recent years, training dynamical systems from data or tasks has provided important insights into brain functions. To support this, BrainPy provides various interfaces to help users train dynamical systems.
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
bm.enable_x64()
bm.set_platform('cpu')
bp.__version__
'2.3.8'
import matplotlib.pyplot as plt
Training a reservoir network model
For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.
Gauthier et al. (Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model based on nonlinear vector autoregression (NVAR).
The difference between the two models is illustrated in the following figure.
(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using a linear weight of time-delay states of the time series data and nonlinear functionals of this data.
Here, let's implement a next-generation reservoir model to predict a chaotic time series, the Lorenz attractor. Specifically, we expect the network to be able to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the number of steps predicted ahead.
dt = 0.01
data = bd.chaos.LorenzEq(100, dt=dt)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.xs.flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.ys.flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.zs.flatten()))
plt.ylabel('z')
plt.show()
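Here brainpy_datasets generates the trajectory for us. As a hedged illustration of what such a generator does, the sketch below integrates the Lorenz equations \(\dot x = \sigma (y - x)\), \(\dot y = x(\rho - z) - y\), \(\dot z = xy - \beta z\) with the classic fourth-order Runge-Kutta scheme. The classic parameters \(\sigma = 10\), \(\rho = 28\), \(\beta = 8/3\), the initial condition, and the function names are assumptions; bd.chaos.LorenzEq may use different defaults or a different integrator.

```python
import numpy as np

def lorenz(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Right-hand side of the Lorenz system (classic parameters assumed)."""
    x, y, z = state
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

def integrate_lorenz(duration, dt=0.01, init=(1.0, 1.0, 1.0)):
    """RK4 integration; returns an array of shape (num_step, 3)."""
    num_step = int(round(duration / dt))
    traj = np.empty((num_step, 3))
    s = np.asarray(init, dtype=float)
    for i in range(num_step):
        k1 = lorenz(s)
        k2 = lorenz(s + 0.5 * dt * k1)
        k3 = lorenz(s + 0.5 * dt * k2)
        k4 = lorenz(s + dt * k3)
        s = s + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        traj[i] = s
    return traj

traj = integrate_lorenz(100, dt=0.01)
print(traj.shape)  # (10000, 3): columns x, y, z
```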

Let's first create a function that extracts a segment of the data as a batch of shape (1, num_time, 3).
def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    return res.reshape((1, ) + res.shape)
To accomplish this task, we implement a next-generation reservoir model that uses 4 delayed states with a stride of 5, together with their quadratic polynomial monomials, as in Gauthier et al. (Nature Communications, 2021).
class NGRC(bp.DynamicalSystem):
    def __init__(self, num_in, num_out):
        super(NGRC, self).__init__()
        self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
        self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bm.training_mode)
    def update(self, sha, x):
        # "sha" is the arguments shared across all nodes.
        # other arguments like "x" can be customized by users.
        return self.o(sha, self.r(sha, x))
with bm.environment(bm.batching_mode):
    model = NGRC(num_in=3, num_out=3)
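At every time step, the NVAR layer builds a feature vector from the current input, a few delayed copies of it, and their pairwise products; the Dense layer then reads out a linear combination of these features. The NumPy sketch below builds such a vector for delay=4, order=2 (quadratic terms) and stride=5, following the idea in Gauthier et al. It is not BrainPy's implementation: whether bp.layers.NVAR adds a constant term or orders the features differently is not assumed here, and nvar_features is a hypothetical name.

```python
import itertools
import numpy as np

def nvar_features(history, delay=4, stride=5):
    """NG-RC style feature vector for the most recent time step.

    history: array (num_time, num_in) with the latest sample last; it needs at
    least (delay - 1) * stride + 1 rows.
    """
    # Linear part: x(t), x(t - stride), ..., x(t - (delay - 1) * stride)
    lin = np.concatenate([history[-1 - k * stride] for k in range(delay)])
    # Nonlinear part: all unique quadratic monomials lin_i * lin_j with i <= j
    pairs = itertools.combinations_with_replacement(range(lin.size), 2)
    quad = np.array([lin[i] * lin[j] for i, j in pairs])
    return np.concatenate([lin, quad])

history = np.random.default_rng(4).standard_normal((16, 3))  # inputs x, y, z
features = nvar_features(history)
print(features.shape)  # (90,): 12 linear terms + 78 quadratic terms
```

The feature count grows quadratically with the number of delayed states, which is why NG-RC keeps the delay small; the readout stays linear, so it can be trained in closed form.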
Next, we use the ridge regression method to train the model.
trainer = bp.RidgeTrainer(model, alpha=1e-6)
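Ridge regression has the closed-form solution \(W = (X^\top X + \alpha I)^{-1} X^\top Y\), where \(X\) stacks the reservoir features over time and \(Y\) the targets. The NumPy sketch below applies this formula to synthetic data; whether bp.RidgeTrainer solves it in exactly this form (for example, how it treats a bias term) is an assumption, and ridge_fit is a hypothetical helper.

```python
import numpy as np

def ridge_fit(X, Y, alpha=1e-6):
    """Closed-form ridge regression: solve (X^T X + alpha I) W = X^T Y."""
    num_feature = X.shape[1]
    A = X.T @ X + alpha * np.eye(num_feature)
    return np.linalg.solve(A, X.T @ Y)

rng = np.random.default_rng(1)
X = rng.standard_normal((500, 8))     # 500 time steps, 8 features
W_true = rng.standard_normal((8, 3))  # 3 outputs, like (x, y, z)
Y = X @ W_true + 1e-3 * rng.standard_normal((500, 3))
W = ridge_fit(X, Y, alpha=1e-6)
print(np.max(np.abs(W - W_true)))     # small: the linear readout is recovered
```

Solving the linear system with np.linalg.solve is preferred over forming an explicit inverse, and the small \(\alpha\) mainly guards against ill-conditioning of \(X^\top X\).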
We warm up the network with the first 20 ms of data.
warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)
# outputs should be an array with the shape of
# (num_batch, num_time, num_out)
outs.shape
(1, 2000, 3)
The training data is the time series from 20 ms to 80 ms. We want the network to forecast 1 time step ahead.
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)
_ = trainer.fit([x_train, y_train])
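The slicing above pairs the input at each time step \(t\) with the target at \(t + l\) (here \(l = 1\)). A small pure-Python sketch of the same indexing, using a hypothetical helper make_forecast_pairs, makes the alignment explicit:

```python
def make_forecast_pairs(series, start, end, lead):
    """Return (inputs, targets) with targets[i] == series[start + i + lead]."""
    inputs = series[start:end]
    targets = series[start + lead:end + lead]
    return inputs, targets

series = list(range(20))
x, y = make_forecast_pairs(series, start=5, end=10, lead=1)
print(x)  # [5, 6, 7, 8, 9]
print(y)  # [6, 7, 8, 9, 10]
```

Because the targets are read \(l\) steps later, the last input must stay \(l\) steps before the end of the recording, which is why the test inputs stop at int(100/dt) - l.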
Then we test the trained network with the next 20 ms.
x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))
predictions = trainer.predict(x_test)
bp.losses.mean_squared_error(y_test, predictions)
Array(5.36414848e-10, dtype=float64)
def plot_difference(truths, predictions):
    truths = bm.as_numpy(truths)
    predictions = bm.as_numpy(predictions)
    plt.subplot(311)
    plt.plot(truths[0, :, 0], label='Ground Truth')
    plt.plot(predictions[0, :, 0], label='Prediction')
    plt.ylabel('x')
    plt.legend()
    plt.subplot(312)
    plt.plot(truths[0, :, 1], label='Ground Truth')
    plt.plot(predictions[0, :, 1], label='Prediction')
    plt.ylabel('y')
    plt.legend()
    plt.subplot(313)
    plt.plot(truths[0, :, 2], label='Ground Truth')
    plt.plot(predictions[0, :, 2], label='Prediction')
    plt.ylabel('z')
    plt.legend()
    plt.show()
plot_difference(y_test, predictions)

We can make the task harder by forecasting 10 time steps ahead.
warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])
x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)
plot_difference(y_test, predictions)

Or forecast 100 time steps ahead.
warmup_data = get_subset(data, 0, int(20/dt))
_ = trainer.predict(warmup_data)
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])
x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)
plot_difference(y_test, predictions)

As you can see, forecasting further ahead makes learning more difficult.
Training an artificial recurrent network#
In recent years, artificial recurrent neural networks trained with back-propagation through time (BPTT) have become a useful tool for studying the network mechanisms of brain functions. To support training networks with BPTT, BrainPy provides the brainpy.train.BPTT interface.
Here, we demonstrate how to train an artificial recurrent neural network on a white noise integration task, in which the trained RNN model should be able to integrate white noise. For example, suppose we have a time series of noise data:
noises = bm.random.normal(0, 0.2, size=10)
plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()

Now, we want a model that can integrate the noise, i.e., compute bm.cumsum(noises) * dt:
dt = 0.1
integrals = bm.cumsum(noises) * dt
plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()

Here, we first define a task which generates the input data and the target integration results.
dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128
@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def build_inputs_and_targets(mean=0.025, scale=0.01):
    # Create the white noise input: a per-trial constant bias plus noise
    sample = bm.random.normal(size=(num_batch, 1, 1))
    bias = mean * 2.0 * (sample - 0.5)
    samples = bm.random.normal(size=(num_batch, num_step, 1))
    noise_t = scale / dt ** 0.5 * samples
    inputs = bias + noise_t
    # Target: running sum of the inputs along the time axis
    targets = bm.cumsum(inputs, axis=1)
    return inputs, targets
def train_data():
    for _ in range(100):
        yield build_inputs_and_targets()
Then, we create and initialize the model. Note that we want the model to also train its initial state, so we set train_state=True for the RNNCell instance.
class RNN(bp.DynamicalSystem):
    def __init__(self, num_in, num_hidden):
        super(RNN, self).__init__()
        self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
        self.out = bp.layers.Dense(num_hidden, 1)
    def update(self, sha, x):
        # "sha" is the arguments shared across all nodes.
        return self.out(sha, self.rnn(sha, x))
with bm.training_environment():
    model = RNN(1, 100)
The brainpy.BPTT trainer receives a loss function and an optimizer. The loss function can be selected from the brainpy.losses module, or it can be any callable that receives (predictions, targets) as arguments. The optimizer must be an instance of brainpy.optim.Optimizer.
Here we define a loss function that uses the mean squared error (MSE) to measure the error between the predictions and the targets. We also apply an L2 regularization to the trainable variables.
# define loss function
def loss(predictions, targets, l2_reg=2e-4):
    mse = bp.losses.mean_squared_error(predictions, targets)
    l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
    return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.BPTT(model,
loss_fun=loss,
optimizer=opt)
# train the model
trainer.fit(train_data, num_epoch=30)
Train 0 epoch, use 2.2865 s, loss 0.3474464803593203
Train 1 epoch, use 0.9384 s, loss 0.026605883508514516
Train 2 epoch, use 0.9478 s, loss 0.021708405535614907
Train 3 epoch, use 0.9892 s, loss 0.02143528795935897
Train 4 epoch, use 0.9300 s, loss 0.02107475878707875
Train 5 epoch, use 0.9298 s, loss 0.020932997073748006
Train 6 epoch, use 0.9566 s, loss 0.020855205349191275
Train 7 epoch, use 0.9072 s, loss 0.020789264013002805
Train 8 epoch, use 0.9631 s, loss 0.02066686516861348
Train 9 epoch, use 0.8968 s, loss 0.020621776997250745
Train 10 epoch, use 0.9438 s, loss 0.020569378459587808
Train 11 epoch, use 0.8785 s, loss 0.02049849876797083
Train 12 epoch, use 0.8515 s, loss 0.02047079743844964
Train 13 epoch, use 0.8533 s, loss 0.02039058010677752
Train 14 epoch, use 0.9209 s, loss 0.02035540442302181
Train 15 epoch, use 0.9844 s, loss 0.0203037559193207
Train 16 epoch, use 1.0006 s, loss 0.02025348558545429
Train 17 epoch, use 0.9397 s, loss 0.020213293431676486
Train 18 epoch, use 0.8990 s, loss 0.0201581882182178
Train 19 epoch, use 0.8812 s, loss 0.020140154456911596
Train 20 epoch, use 0.8976 s, loss 0.020074232282055494
Train 21 epoch, use 0.8953 s, loss 0.020042431126861646
Train 22 epoch, use 0.8584 s, loss 0.020006458781740344
Train 23 epoch, use 0.9393 s, loss 0.019969488637816196
Train 24 epoch, use 0.9012 s, loss 0.019924267557334032
Train 25 epoch, use 0.8897 s, loss 0.019882438760702597
Train 26 epoch, use 0.9212 s, loss 0.019851161733026937
Train 27 epoch, use 0.8935 s, loss 0.019828587850703315
Train 28 epoch, use 0.9329 s, loss 0.01978077887574506
Train 29 epoch, use 0.8505 s, loss 0.019749579454078393
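Written out, the objective above is \(\mathcal{L} = \frac{1}{N}\sum_i (\hat y_i - y_i)^2 + \lambda \lVert \theta \rVert_2^2\), with \(\lambda\) = l2_reg. We also assume that ExponentialDecay follows the usual rule \(\eta_k = \eta_0 \, r^{k/s}\) for update step \(k\), decay rate \(r\), and decay steps \(s\); please check the brainpy.optim documentation before relying on it. The NumPy sketch below uses hypothetical helper names.

```python
import numpy as np

def mse_l2_loss(predictions, targets, params, l2_reg=2e-4):
    """Mean squared error plus an L2 penalty summed over all parameter arrays."""
    mse = np.mean((predictions - targets) ** 2)
    l2 = l2_reg * sum(np.sum(p ** 2) for p in params)
    return mse + l2

def exponential_decay(step, lr=0.025, decay_steps=1, decay_rate=0.99975):
    """Learning rate after `step` optimizer updates (assumed decay rule)."""
    return lr * decay_rate ** (step / decay_steps)

preds = np.array([0.0, 1.0, 2.0])
targs = np.array([0.0, 1.0, 4.0])
params = [np.ones((2, 2))]                # squared L2 norm = 4
print(mse_l2_loss(preds, targs, params))  # 4/3 + 2e-4 * 4
print(exponential_decay(3000))            # learning rate after 3000 updates
```

Under this assumed rule, the 30 epochs of 100 batches above (3000 updates) shrink the learning rate to roughly 47% of its initial value, since \(0.99975^{3000} \approx e^{-0.75}\).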
The training losses can be retrieved by the .get_hist_metric() function.
plt.figure(figsize=(8, 3))
plt.plot(trainer.get_hist_metric(metric='loss'))
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()

Finally, let’s try the trained network, and test whether it can generate the correct integration results.
model.reset_state(num_batch)
x, y = build_inputs_and_targets()
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()

Training a spiking neural network#
BrainPy also supports training spiking neural networks.
In the following, we demonstrate how to use back-propagation algorithms to train spiking neurons with a simple example.
Our model is a simple three-layer model:
an input layer
a LIF layer
a readout layer
Adjacent layers are connected by the exponential synapse model.
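To make the two building blocks concrete, here is a minimal NumPy sketch of one simulation step, assuming the standard forms: an exponential synapse \(\tau_s \, \mathrm{d}g/\mathrm{d}t = -g\), incremented by the weighted presynaptic spikes, feeding a current-based LIF neuron \(\tau \, \mathrm{d}V/\mathrm{d}t = -(V - V_\mathrm{rest}) + I\) that fires and resets when \(V\) reaches V_th. The parameter values mirror the SNN below (tau=10, V_rest=V_reset=0, V_th=1, synaptic tau=10), dt = 0.1 ms is assumed, the synapse decay uses the exact exponential factor while the voltage uses forward Euler, and BrainPy's own update order and discretization may differ. Function names are hypothetical.

```python
import numpy as np

def exp_synapse_step(g, pre_spike, weights, tau_s=10.0, dt=0.1):
    """Exponential synapse: decay g exactly over dt, then add weighted spikes."""
    g = g * np.exp(-dt / tau_s)
    return g + pre_spike @ weights

def lif_step(V, I, tau=10.0, V_rest=0.0, V_reset=0.0, V_th=1.0, dt=0.1):
    """Forward-Euler LIF step; returns the new voltage and a boolean spike vector."""
    V = V + dt / tau * (-(V - V_rest) + I)
    spike = V >= V_th
    V = np.where(spike, V_reset, V)
    return V, spike

rng = np.random.default_rng(2)
num_in, num_rec = 100, 10
weights = rng.normal(0.0, 1.0, (num_in, num_rec))
g = np.zeros(num_rec)
V = np.zeros(num_rec)
spike_count = np.zeros(num_rec, dtype=int)
for _ in range(2000):
    pre_spike = (rng.random(num_in) < 0.05).astype(float)
    g = exp_synapse_step(g, pre_spike, weights)
    V, spike = lif_step(V, g)      # current-based input: I = g
    spike_count += spike
print(spike_count)                 # spikes emitted by each LIF neuron
```

Because the spike condition is a hard threshold, its derivative is zero almost everywhere; training such a network with back-propagation therefore relies on a surrogate gradient for the spike function.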
class SNN(bp.Network):
    def __init__(self, num_in, num_rec, num_out):
        super(SNN, self).__init__()
        # parameters
        self.num_in = num_in
        self.num_rec = num_rec
        self.num_out = num_out
        # neuron groups
        self.i = bp.neurons.InputGroup(num_in)
        self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
        self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)
        # synapse: i->r
        self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), tau=10.,
                                           output=bp.synouts.CUBA(target_var=None),
                                           g_max=bp.init.KaimingNormal(scale=20.))
        # synapse: r->o
        self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), tau=10.,
                                           output=bp.synouts.CUBA(target_var=None),
                                           g_max=bp.init.KaimingNormal(scale=20.))
        # whole model
        self.model = bp.Sequential(self.i, self.i2r, self.r, self.r2o, self.o)
    def update(self, tdi, spike):
        self.model(tdi, spike)
        return self.o.V.value
with bm.training_environment():
    net = SNN(100, 10, 2)  # our task is a two-label classification task
We use this simple task to classify random spike data into two classes.
num_step = 2000
num_sample = 256
freq = 5 # Hz
mask = bm.random.rand(num_sample, num_step, net.num_in)
x_data = bm.zeros((num_sample, num_step, net.num_in))
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)
def get_data():
    for _ in range(1):
        yield x_data, y_data
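The masking trick above draws Poisson-like spike trains: in each time step of length dt (in ms), an input fires with probability freq · dt / 1000, so its expected rate is freq Hz. A NumPy check of that arithmetic, assuming dt = 0.1 ms (taken here to be BrainPy's default step, an assumption) and a hypothetical helper bernoulli_spikes; the batch is smaller than above only to keep memory low:

```python
import numpy as np

def bernoulli_spikes(num_sample, num_step, num_in, freq_hz, dt_ms, seed=0):
    """Binary spike tensor; each entry is 1 with probability freq * dt / 1000."""
    rng = np.random.default_rng(seed)
    p = freq_hz * dt_ms / 1000.0
    return (rng.random((num_sample, num_step, num_in)) < p).astype(float)

dt_ms = 0.1
spikes = bernoulli_spikes(16, 2000, 100, freq_hz=5, dt_ms=dt_ms)
duration_s = 2000 * dt_ms / 1000.0                # 0.2 s per sample
rate_hz = spikes.sum(axis=1).mean() / duration_s  # mean rate per input channel
print(f"empirical rate: {rate_hz:.2f} Hz")        # close to 5 Hz
```

With these numbers, each input channel emits on average only one spike per 200 ms sample, so the classification has to rely on very sparse input.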
As with the artificial recurrent neural network, we use the Adam optimizer; here the loss is the cross-entropy computed on the maximum of each readout unit's potential over time.
opt = bp.optim.Adam(lr=2e-3)
def loss(predicts, targets):
    return bp.losses.cross_entropy_loss(bm.max(predicts, axis=1), targets)
trainer = bp.BPTT(net,
loss_fun=loss,
optimizer=opt)
trainer.fit(train_data=get_data,
num_report=10,
num_epoch=200)
Train 10 steps, use 0.5471 s, loss 0.8119578839372008
Train 20 steps, use 0.5587 s, loss 0.7544552171954115
Train 30 steps, use 0.5898 s, loss 0.7020412181865148
Train 40 steps, use 0.5782 s, loss 0.6750034517542842
Train 50 steps, use 0.5996 s, loss 0.6638762207402613
Train 60 steps, use 0.5469 s, loss 0.6457876645959164
Train 70 steps, use 0.5286 s, loss 0.625173858852444
Train 80 steps, use 0.5480 s, loss 0.6108308776234428
Train 90 steps, use 0.5672 s, loss 0.5928884702255813
Train 100 steps, use 0.5597 s, loss 0.5864131114195933
Train 110 steps, use 0.5726 s, loss 0.5829393663656248
Train 120 steps, use 0.5730 s, loss 0.5648404392129361
Train 130 steps, use 0.5711 s, loss 0.5471623937603349
Train 140 steps, use 0.5665 s, loss 0.5335702569949301
Train 150 steps, use 0.5618 s, loss 0.5198388588329248
Train 160 steps, use 0.5529 s, loss 0.5154762874039034
Train 170 steps, use 0.5734 s, loss 0.49874691621484557
Train 180 steps, use 0.5523 s, loss 0.4925224810657929
Train 190 steps, use 0.5375 s, loss 0.48411000993072006
Train 200 steps, use 0.5524 s, loss 0.46809022931937005
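The loss above first takes the maximum potential of each readout unit over time and then treats those maxima as logits for a softmax cross-entropy. A NumPy sketch of that computation, assuming predictions of shape (num_sample, num_step, num_out), integer class labels, and that bp.losses.cross_entropy_loss averages the softmax cross-entropy over the batch; max_over_time_ce is a hypothetical name:

```python
import numpy as np

def max_over_time_ce(predicts, labels):
    """Softmax cross-entropy on the per-unit maximum over the time axis."""
    logits = predicts.max(axis=1)                        # (num_sample, num_out)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(log_prob[np.arange(len(labels)), labels])

predicts = np.zeros((2, 5, 2))
predicts[0, 3, 0] = 2.0   # sample 0: unit 0 peaks, so class 0 is favored
predicts[1, 1, 1] = 2.0   # sample 1: unit 1 peaks, so class 1 is favored
labels = np.array([0, 1])
print(max_over_time_ce(predicts, labels))  # log(1 + e^-2), about 0.127
```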
The training loss decreases steadily, which shows that the network is learning.
# visualize the training losses
plt.plot(trainer.get_hist_metric())
plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.show()

Let’s visualize the trained spiking neurons.
import numpy as np
from matplotlib.gridspec import GridSpec
def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
    plt.figure(figsize=(15, 8))
    gs = GridSpec(*dim)
    mem = 1. * mem
    if spk is not None:
        mem[spk > 0.0] = spike_height
    mem = bm.as_numpy(mem)
    for i in range(np.prod(dim)):
        if i == 0:
            a0 = ax = plt.subplot(gs[i])
        else:
            ax = plt.subplot(gs[i], sharey=a0)
        ax.plot(mem[i])
    plt.tight_layout()
    plt.show()
# get the prediction results and neural activity
runner = bp.DSRunner(
net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}
)
out = runner.run(inputs=x_data, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))

# the prediction accuracy
m = bm.max(out, axis=1) # max over time
am = bm.argmax(m, axis=1) # argmax over output units
acc = bm.mean(y_data == am) # compare to labels
print("Accuracy %.3f" % acc)
Accuracy 0.672
Analyzing a Brain Dynamics Model#
In BrainPy, defined models can be used not only for simulation but also for automatic dynamics analysis.
BrainPy provides rich interfaces to support analysis, including
Phase plane analysis, bifurcation analysis, and fast-slow bifurcation analysis for low-dimensional systems;
linearization analysis and fixed/slow point finding for high-dimensional systems.
Here we introduce three brief examples: 1-D bifurcation analysis, 2-D phase plane analysis, and slow point analysis of a high-dimensional system. For more details and more examples, please refer to the tutorials of dynamics analysis.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bm.enable_x64() # it's better to use x64 computation
bp.__version__
'2.3.1'
Bifurcation analysis of a 1D model#
Here, we demonstrate how to perform a bifurcation analysis through a one-dimensional neuron model.
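Let's analyze how the external input influences the dynamics of the Exponential Integrate-and-Fire (ExpIF) model. The ExpIF model is a one-variable neuron model whose subthreshold dynamics are defined by:
\[\tau \frac{\mathrm{d}V}{\mathrm{d}t} = -(V - V_\mathrm{rest}) + \Delta_T e^{\frac{V - V_T}{\Delta_T}} + R I_\mathrm{ext},\]
where \(\tau\) is the membrane time constant, \(V_T\) the threshold of the exponential term, \(\Delta_T\) the slope factor, and \(R\) the membrane resistance.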
We can analyze the change of \({\dot {V}}\) with respect to \(V\). First, let's create an ExpIF model using the predefined models in brainpy.neurons:
expif = bp.neurons.ExpIF(1, delta_T=1.)
The default value of other parameters can be accessed directly by their names:
expif.V_rest, expif.V_T, expif.R, expif.tau
(-65.0, -59.9, 1.0, 10.0)
After defining the model, we can use it for bifurcation analysis.
bif = bp.analysis.Bifurcation1D(
model=expif,
target_vars={'V': [-70., -55.]},
target_pars={'I_ext': [0., 6.]},
resolutions={'I_ext': 0.01}
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...

In the Bifurcation1D analyzer, model refers to the model to be analyzed (essentially, the analyzer accesses the derivative function of the model), target_vars denotes the target variables and their ranges, target_pars denotes the changing parameters and their ranges, and resolutions determines the resolution of the analysis.
In the image above, there are two lines that “merge” together to form a bifurcation. The dots making up the lines are the fixed points of the model, i.e., the values of \(V\) where \(\mathrm{d}V/\mathrm{d}t = 0\). To the left of the bifurcation point (where the two lines merge), there are two fixed points for each external input \(I_\mathrm{ext}\): one is stable and the other is unstable. As \(I_\mathrm{ext}\) increases, the two fixed points move closer to each other, merge, and finally disappear.
Bifurcation analysis provides insight into the dynamics of a model because it shows how the number and stability of its fixed points change with the parameters.
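The bifurcation point can also be checked by hand. Setting \(\mathrm{d}V/\mathrm{d}t = 0\) gives \(f(V) = -(V - V_\mathrm{rest}) + \Delta_T e^{(V - V_T)/\Delta_T} + R I_\mathrm{ext} = 0\). Since \(f'(V) = -1 + e^{(V - V_T)/\Delta_T}\) vanishes at \(V = V_T\), the two fixed points merge when \(f(V_T) = 0\), i.e. at \(I^* = (V_T - V_\mathrm{rest} - \Delta_T)/R = 4.1\) for the default parameters printed above, which you can compare with where the two branches meet in the diagram. The pure-Python sketch below counts sign changes of \(f\) on a grid over the same range [-70, -55]; the helper names are hypothetical, and the grid search is only a crude stand-in for BrainPy's analyzer.

```python
import math

V_REST, V_T, DELTA_T, R = -65.0, -59.9, 1.0, 1.0  # defaults printed above

def f(V, I_ext):
    """tau * dV/dt of the ExpIF model; its zeros are the fixed points."""
    return -(V - V_REST) + DELTA_T * math.exp((V - V_T) / DELTA_T) + R * I_ext

def count_fixed_points(I_ext, lo=-70.0, hi=-55.0, step=1e-3):
    """Count sign changes of f on a uniform grid over [lo, hi]."""
    n = int(round((hi - lo) / step))
    values = [f(lo + k * step, I_ext) for k in range(n + 1)]
    return sum(1 for a, b in zip(values, values[1:]) if a * b < 0)

I_star = (V_T - V_REST - DELTA_T) / R
print(round(I_star, 2))         # 4.1
print(count_fixed_points(0.0))  # 2 fixed points below the bifurcation
print(count_fixed_points(5.0))  # 0 fixed points above it
```

Below \(V_T\) the slope \(f'(V)\) is negative, so the lower fixed point is the stable one and the upper one is unstable.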
Phase plane analysis of a 2D model#
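Besides bifurcation analysis, another important tool is phase plane analysis, which displays the trajectory of the state point in the vector field. Let's take the FitzHugh–Nagumo (FHN) neuron model as an example. The dynamics of the FHN model are given by:
\[\begin{aligned}
\frac{\mathrm{d}v}{\mathrm{d}t} &= v - \frac{v^3}{3} - w + I_\mathrm{ext}, \\
\tau \frac{\mathrm{d}w}{\mathrm{d}t} &= v + a - b w,
\end{aligned}\]
where \(a\), \(b\), and \(\tau\) are model parameters.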
BrainPy also provides the FHN model, so users can easily create one:
fhn = bp.neurons.FHN(1)
Because there are two variables, \(v\) and \(w\), in the FHN model, we shall use 2-D phase plane analysis to visualize how these two variables change over time.
analyzer = bp.analysis.PhasePlane2D(
model=fhn,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
pars_update={'I_ext': 0.8},
resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
analyzer.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 866 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 V=-0.2729223248464073, w=0.5338542697673022 is a unstable node.
I am plotting the trajectory ...

In the PhasePlane2D analyzer, the parameters model, target_vars, and resolutions are the same as those in Bifurcation1D. pars_update specifies parameter values that are fixed during the analysis (here, I_ext = 0.8). After defining the analyzer, users can visualize the nullclines, vector field, fixed points, and trajectory in the image. The phase plane gives users an intuitive interpretation of how \(v\) and \(w\) change, guided by the vector field (violet arrows).
Slow point analysis of a high-dimensional system#
BrainPy is also capable of performing fixed/slow point analysis of high-dimensional systems. Moreover, it can perform automatic linearization analysis around the fixed point.
In the following, we use a gap junction coupled FitzHugh–Nagumo (FHN) network as an example to demonstrate how to find fixed/slow points of a high-dimensional system.
We first define the gap-junction-coupled FHN network as a normal DynamicalSystem class.
class GJCoupledFHN(bp.DynamicalSystem):
    def __init__(self, num=4, method='exp_auto'):
        super(GJCoupledFHN, self).__init__()
        # parameters
        self.num = num
        self.a = 0.7
        self.b = 0.8
        self.tau = 12.5
        self.gjw = 0.0001
        # variables
        self.V = bm.Variable(bm.random.uniform(-2, 2, num))
        self.w = bm.Variable(bm.random.uniform(-2, 2, num))
        self.Iext = bm.Variable(bm.zeros(num))
        # functions
        self.int_V = bp.odeint(self.dV, method=method)
        self.int_w = bp.odeint(self.dw, method=method)
    def dV(self, V, t, w, Iext=0.):
        gj = (V.reshape((-1, 1)) - V).sum(axis=0) * self.gjw
        dV = V - V * V * V / 3 - w + Iext + gj
        return dV
    def dw(self, w, t, V):
        dw = (V + self.a - self.b * w) / self.tau
        return dw
    def update(self, tdi):
        t, dt = tdi.get('t'), tdi.get('dt')
        self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
        self.w.value = self.int_w(self.w, t, self.V, dt)
        self.Iext[:] = 0.
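In dV, the expression (V.reshape((-1, 1)) - V).sum(axis=0) builds the matrix of pairwise differences \(V_i - V_j\) and sums over the first index \(i\), so each neuron \(j\) receives \(g \sum_i (V_i - V_j) = g\,(\sum_i V_i - N V_j)\) with \(g\) = gjw, i.e., all-to-all coupling with equal weights. A quick NumPy check of that identity, with a hypothetical helper name:

```python
import numpy as np

def gap_junction_current(V, gjw):
    """All-to-all gap-junction input to each neuron j: gjw * sum_i (V_i - V_j)."""
    return (V.reshape((-1, 1)) - V).sum(axis=0) * gjw

V = np.array([-1.0, 0.5, 2.0, 0.0])
gj = gap_junction_current(V, gjw=0.1)
closed_form = 0.1 * (V.sum() - V.size * V)
print(gj, np.allclose(gj, closed_form))  # the two forms agree
```

The currents sum to zero: the coupling only redistributes current between neurons and pulls each voltage toward the population mean.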
Through simulation, we can easily see that this system has a limit cycle attractor, which suggests that an unstable fixed point exists.
# initialize a network
model = GJCoupledFHN(4)
model.gjw = 0.1
# simulation with an input
Iext = bm.asarray([0., 0., 0., 0.6])
runner = bp.DSRunner(model, monitors=['V'], inputs=['Iext', Iext])
runner.run(300.)
# visualization
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',
plot_ids=list(range(model.num)),
show=True)

Let's search for the fixed points of this system by optimization. Note that we only consider the variables V and w. Unlike the low-dimensional analyzers, the high-dimensional analyzer requires us to provide candidate (initial) points from which the search starts.
# init a slow point finder
finder = bp.analysis.SlowPointFinder(f_cell=model,
target_vars={'V': model.V, 'w': model.w},
inputs=[model.Iext, Iext])
# optimize to find fixed points
finder.find_fps_with_gd_method(
candidates={'V': bm.random.normal(0., 2., (1000, model.num)),
'w': bm.random.normal(0., 2., (1000, model.num))},
tolerance=1e-6,
num_batch=200,
optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)),
)
# filter fixed points whose loss is bigger than the threshold
finder.filter_loss(1e-8)
# remove the duplicate fixed points
finder.keep_unique()
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.29 sec, Training loss 0.0003104926
Batches 201-400 in 0.28 sec, Training loss 0.0002287778
Batches 401-600 in 0.28 sec, Training loss 0.0001775225
Batches 601-800 in 0.30 sec, Training loss 0.0001401555
Batches 801-1000 in 0.30 sec, Training loss 0.0001119446
Batches 1001-1200 in 0.30 sec, Training loss 0.0000904519
Batches 1201-1400 in 0.30 sec, Training loss 0.0000738873
Batches 1401-1600 in 0.30 sec, Training loss 0.0000609509
Batches 1601-1800 in 0.30 sec, Training loss 0.0000506783
Batches 1801-2000 in 0.36 sec, Training loss 0.0000424477
Batches 2001-2200 in 0.29 sec, Training loss 0.0000357793
Batches 2201-2400 in 0.29 sec, Training loss 0.0000303206
Batches 2401-2600 in 0.33 sec, Training loss 0.0000258537
Batches 2601-2800 in 0.28 sec, Training loss 0.0000221875
Batches 2801-3000 in 0.29 sec, Training loss 0.0000191505
Batches 3001-3200 in 0.30 sec, Training loss 0.0000166231
Batches 3201-3400 in 0.30 sec, Training loss 0.0000144943
Batches 3401-3600 in 0.29 sec, Training loss 0.0000126804
Batches 3601-3800 in 0.30 sec, Training loss 0.0000111463
Batches 3801-4000 in 0.29 sec, Training loss 0.0000098656
Batches 4001-4200 in 0.28 sec, Training loss 0.0000087958
Batches 4201-4400 in 0.35 sec, Training loss 0.0000078796
Batches 4401-4600 in 0.31 sec, Training loss 0.0000070861
Batches 4601-4800 in 0.29 sec, Training loss 0.0000063897
Batches 4801-5000 in 0.28 sec, Training loss 0.0000057697
Batches 5001-5200 in 0.28 sec, Training loss 0.0000052188
Batches 5201-5400 in 0.28 sec, Training loss 0.0000047263
Batches 5401-5600 in 0.29 sec, Training loss 0.0000042864
Batches 5601-5800 in 0.28 sec, Training loss 0.0000038972
Batches 5801-6000 in 0.28 sec, Training loss 0.0000035515
Batches 6001-6200 in 0.29 sec, Training loss 0.0000032389
Batches 6201-6400 in 0.29 sec, Training loss 0.0000029477
Batches 6401-6600 in 0.28 sec, Training loss 0.0000026731
Batches 6601-6800 in 0.28 sec, Training loss 0.0000024145
Batches 6801-7000 in 0.28 sec, Training loss 0.0000021735
Batches 7001-7200 in 0.35 sec, Training loss 0.0000019521
Batches 7201-7400 in 0.28 sec, Training loss 0.0000017512
Batches 7401-7600 in 0.28 sec, Training loss 0.0000015672
Batches 7601-7800 in 0.28 sec, Training loss 0.0000013971
Batches 7801-8000 in 0.27 sec, Training loss 0.0000012403
Batches 8001-8200 in 0.27 sec, Training loss 0.0000010954
Batches 8201-8400 in 0.27 sec, Training loss 0.0000009603
Stop optimization as mean training loss 0.0000009603 is below tolerance 0.0000010000.
Excluding fixed points with squared speed above tolerance 1e-08:
Kept 815/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
Kept 1/815 unique fixed points with uniqueness tolerance 0.025.
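Conceptually, the finder minimizes the speed \(q(x) = \tfrac{1}{2}\lVert F(x) \rVert^2\) from many candidate states, keeps the candidates whose final \(q\) is below a tolerance, and merges near-duplicates, which is what the log above reports. The NumPy sketch below reproduces that pipeline with plain gradient descent (not the Adam optimizer used above) on a toy system \(F(x) = x - x^3\) applied element-wise in two dimensions, chosen only because its nine fixed points \(\{-1, 0, 1\}^2\) are known. All names and settings are illustrative.

```python
import numpy as np

def F(x):
    """Toy vector field, element-wise: fixed points at -1, 0 and 1."""
    return x - x ** 3

def grad_q(x):
    """Gradient of q(x) = 0.5 * ||F(x)||^2, i.e. J(x)^T F(x) with diagonal J."""
    return F(x) * (1.0 - 3.0 * x ** 2)

rng = np.random.default_rng(3)
candidates = rng.uniform(-1.4, 1.4, size=(500, 2))  # 500 initial states

x = candidates.copy()
for _ in range(3000):                 # plain gradient descent on the speed
    x -= 0.05 * grad_q(x)

q = 0.5 * np.sum(F(x) ** 2, axis=1)
kept = x[q < 1e-8]                    # keep low-speed points (cf. filter_loss)

unique = []                           # merge near-duplicates (cf. keep_unique)
for point in kept:
    if all(np.linalg.norm(point - u) > 0.025 for u in unique):
        unique.append(point)
unique = np.array(sorted(map(tuple, np.round(unique, 6))))
print(len(unique), "fixed points found")  # expected: 9
```

In a real network, minima with small but nonzero \(q\) are also informative: these "slow points" are regions where the dynamics nearly stop, which is why the tolerance passed to filter_loss is a modeling choice.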
print('fixed points:', )
finder.fixed_points
fixed points:
{'V': array([[-1.17757852, -1.17757852, -1.17757852, -0.81465053]]),
'w': array([[-0.59697314, -0.59697314, -0.59697314, -0.14331316]])}
print('fixed point losses:', )
finder.losses
fixed point losses:
array([2.46519033e-32])
Let's perform a linearization analysis around the found fixed point and visualize its decomposition results.
_ = finder.compute_jacobians(finder.fixed_points, plot=True)

This is an unstable fixed point, because one of its eigenvalues has the real part bigger than 1.
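Linearization means evaluating the Jacobian \(J_{ij} = \partial F_i / \partial x_j\) at the fixed point and inspecting its eigenvalues. For a continuous-time vector field, the fixed point is unstable if any eigenvalue has a positive real part; for a one-step update map, it is unstable if any eigenvalue has magnitude greater than 1. BrainPy computes the Jacobian for us; purely as an illustration (not BrainPy's implementation), the sketch below estimates a Jacobian with central finite differences and checks it on a linear system whose Jacobian is known. The function name is hypothetical.

```python
import numpy as np

def numerical_jacobian(func, x, eps=1e-6):
    """Central-difference estimate of J[i, j] = dF_i / dx_j at point x."""
    x = np.asarray(x, dtype=float)
    f0 = np.asarray(func(x))
    J = np.empty((f0.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        J[:, j] = (func(x + step) - func(x - step)) / (2 * eps)
    return J

A = np.array([[0.5, -1.0],
              [1.0, -0.2]])
J = numerical_jacobian(lambda x: A @ x, np.zeros(2))
eigvals = np.linalg.eigvals(J)
print(np.allclose(J, A), eigvals)
print("unstable (continuous time):", np.any(eigvals.real > 0))
print("unstable (discrete map):", np.any(np.abs(eigvals) > 1))
```

This example is deliberately chosen so that the two criteria disagree: the eigenvalues \(0.15 \pm 0.94i\) have a positive real part but magnitude below 1, so it is always important to know whether a Jacobian describes a vector field or an update map.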
Further reading#
For more details about how to perform bifurcation analysis and phase plane analysis, please see the tutorial of Low-dimensional Analyzers.
A good example of phase plane analysis and bifurcation analysis is the decision-making model; please see the tutorial Analysis of a Decision-making Model.
If you want to learn how to analyze the slow points (or fixed points) of high-dimensional dynamical models, please see the tutorial of High-dimensional Analyzers.
Concept 1: Object-oriented Transformation#
Most computation in BrainPy relies on JAX. JAX provides powerful transformations for Python programs, including differentiation, vectorization, parallelization, and just-in-time compilation. If you are not familiar with it, please see its documentation.
However, JAX only supports functional programming, i.e., its transformations operate on Python functions. This is not sufficient for our purposes, because brain dynamics modeling needs object-oriented programming.
To meet this requirement, BrainPy defines the interface for object-oriented (OO) transformations. These OO transformations can be easily performed for BrainPy objects.
In this section, let’s talk about the BrainPy concept of object-oriented transformations.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.0'
Illustrating example: Training a network#
To illustrate this concept, we need a demonstration example. Here, we choose the common task of neural network training.
In this example, we want to teach the neural network to classify a random array into one of two labels (True or False). That is, we have the training data:
num_in = 100
num_sample = 256
X = bm.random.rand(num_sample, num_in)
Y = (bm.random.rand(num_sample) < 0.5).astype(float)
We use a two-layer feedforward network:
class Linear(bp.BrainPyObject):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.num_in = n_in
        self.num_out = n_out
        init = bp.init.XavierNormal()
        self.W = bm.Variable(init((n_in, n_out)))
        self.b = bm.Variable(bm.zeros((1, n_out)))
    def __call__(self, x):
        return x @ self.W + self.b
net = bp.Sequential(Linear(num_in, 20),
                    bm.relu,
                    Linear(20, 2))
print(net)
Sequential(
[0] Linear0
[1] relu
[2] Linear1
)
Here, we use a supervised learning training paradigm.
rng = bm.random.RandomState(123)
# Loss function
@bm.to_object(child_objs=net, dyn_vars=rng)
def loss():
    # shuffle the data
    key = rng.split_key()
    x_data = rng.permutation(X, key=key)
    y_data = rng.permutation(Y, key=key)
    # prediction
    predictions = net(dict(), x_data)
    # loss
    l = bp.losses.cross_entropy_loss(predictions, y_data)
    return l
# Gradient function
grad = bm.grad(loss, grad_vars=net.vars(), return_value=True)
# Optimizer
optimizer = bp.optim.SGD(lr=1e-2, train_vars=net.vars())
# Training step
@bm.to_object(child_objs=(grad, optimizer))
def train(i):
    grads, l = grad()
    optimizer.update(grads)
    return l
num_step = 400
for i in range(0, 4000, num_step):
    # train 400 steps at once
    ls = bm.for_loop(train, operands=bm.arange(i, i + num_step))
    print(f'Train {i + num_step} epoch, loss = {bm.mean(ls):.4f}')
Train 400 epoch, loss = 0.6710
Train 800 epoch, loss = 0.5992
Train 1200 epoch, loss = 0.5332
Train 1600 epoch, loss = 0.4720
Train 2000 epoch, loss = 0.4189
Train 2400 epoch, loss = 0.3736
Train 2800 epoch, loss = 0.3335
Train 3200 epoch, loss = 0.2972
Train 3600 epoch, loss = 0.2644
Train 4000 epoch, loss = 0.2346
In the above example, we have seen the classical elements of neural network training, such as
net: neural network
loss: loss function
grad: gradient function
optimizer: parameter optimizer
train: training step
In BrainPy, all these elements can be defined as class objects and can be used for performing OO transformations.
In essence, the concept of BrainPy object-oriented transformation has three components:
BrainPyObject: the base class for object-oriented programming
Variable: the variables in the class object, whose values can be changed/updated during transformations
ObjectTransform: the transformations for computation involving BrainPyObject and Variable
BrainPyObject and its Variable#
BrainPyObject is the base class for object-oriented programming in BrainPy.
It can be viewed as a container that holds all the Variable instances needed for our computation.
In the above example, each Linear object has two Variables: W and b. The net we defined is composed of two Linear objects, so we can expect four variables to be retrieved from it.
net.vars().keys()
dict_keys(['Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])
An important question is: how should we define Variable instances in a BrainPyObject so that all of them can be retrieved?
In fact, every Variable instance that can be accessed as a self. attribute can be retrieved recursively from a BrainPyObject.
No matter how deep the composition of BrainPyObject is, as long as the BrainPyObject instances and their Variable instances are accessible through self. attributes, all of them will be retrieved.
class SuperLinear(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.l1 = Linear(10, 20)
        self.v1 = bm.Variable(3)
sl = SuperLinear()
# retrieve Variable
sl.vars().keys()
dict_keys(['SuperLinear0.v1', 'Linear2.W', 'Linear2.b'])
# retrieve BrainPyObject
sl.nodes().keys()
dict_keys(['SuperLinear0', 'Linear2'])
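The recursive retrieval can be pictured with a few lines of plain Python. The toy classes below are hypothetical stand-ins, not BrainPy's implementation: each object gets an auto-numbered name (like Linear0), and vars() walks the attributes reachable through self., descending into child objects. Variables stored inside a Python list are skipped, mirroring the limitation of BrainPy's automatic retrieval.

```python
import itertools
from collections import defaultdict

class Var:
    """Toy stand-in for a Variable: just wraps a value."""
    def __init__(self, value):
        self.value = value

class Obj:
    """Toy stand-in for a container object with auto-numbered names."""
    _counters = defaultdict(itertools.count)  # one counter per class name

    def __init__(self):
        cls_name = type(self).__name__
        self.name = f"{cls_name}{next(Obj._counters[cls_name])}"

    def vars(self):
        """Collect Var attributes reachable via `self.`, recursing into child Obj."""
        found = {}
        for attr, value in self.__dict__.items():
            if isinstance(value, Var):
                found[f"{self.name}.{attr}"] = value
            elif isinstance(value, Obj):
                found.update(value.vars())
        return found

class Linear(Obj):
    def __init__(self):
        super().__init__()
        self.W = Var(1.0)
        self.b = Var(0.0)

class SuperLinear(Obj):
    def __init__(self):
        super().__init__()
        self.l1 = Linear()
        self.v1 = Var(3)
        self.hidden = [Linear()]  # inside a list: not discovered

sl = SuperLinear()
print(sorted(sl.vars()))  # ['Linear0.W', 'Linear0.b', 'SuperLinear0.v1']
```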
However, BrainPyObject or Variable instances stored in a Python container (like a tuple, list, or dict) cannot be retrieved this way. In this case, we can register our objects and variables through .register_implicit_nodes() and .register_implicit_vars():
class SuperSuperLinear(bp.BrainPyObject):
    def __init__(self, register=False):
        super().__init__()
        self.ss = [SuperLinear(), SuperLinear()]
        self.vv = {'v_a': bm.Variable(3)}
        if register:
            self.register_implicit_nodes(self.ss)
            self.register_implicit_vars(self.vv)
# without register
ssl = SuperSuperLinear(register=False)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys([])
dict_keys(['SuperSuperLinear0'])
# with register
ssl = SuperSuperLinear(register=True)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys(['SuperSuperLinear1.v_a', 'SuperLinear3.v1', 'SuperLinear4.v1', 'Linear5.W', 'Linear5.b', 'Linear6.W', 'Linear6.b'])
dict_keys(['SuperSuperLinear1', 'SuperLinear3', 'SuperLinear4', 'Linear5', 'Linear6'])
Transform a function to BrainPyObject#
Let's go back to our network training.
After defining net, we further define a loss function whose computation involves the net object for neural network prediction and a rng Variable for data shuffling.
This Python function is then transformed into a BrainPyObject instance by the brainpy.math.to_object interface.
loss
FunAsObject(nodes=[Sequential0],
num_of_vars=1)
All Variable instances used in this object can also be retrieved through:
loss.vars().keys()
dict_keys(['loss0._var0', 'Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])
Note that when using to_object(), we need to explicitly declare all BrainPyObject and Variable instances used in this Python function.
Thanks to the recursive retrieval property of BrainPyObject, we only need to specify the outermost composed object.
In the above loss object, we do not need to specify the two Linear objects. Instead, we only need to pass the top-level object net into the to_object() transform.
Similarly, when we transform the train function into a BrainPyObject, we only need to point out the grad and optimizer objects it uses, rather than the underlying loss, net, or rng.
BrainPy object-oriented transformations#
BrainPy object-oriented transformations are designed to work on BrainPyObject.
These transformations include autograd brainpy.math.grad() and JIT brainpy.math.jit().
In our case, we used two OO transformations provided in BrainPy.
First, the grad object is defined with the loss function. Within it, we specify, through grad_vars, the variables whose gradients need to be computed.
Note that the OO transformation of any BrainPyObject results in another BrainPyObject. Therefore, it can be recursively used as a component to form larger scopes of object-oriented programming and object-oriented transformation.
grad
GradientTransform(target=loss0,
num_of_grad_vars=4,
num_of_dyn_vars=1)
Next, we train 400 steps at once by using a for_loop transformation. Different from grad, which returns a BrainPyObject instance, for_loop directly returns the loop results.
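Readers who know plain JAX can picture the same pattern as a gradient transformation inside a loop that returns per-step results, written with jax.grad and jax.lax.scan. The toy linear-regression data, the learning rate 0.1, and the function names are assumptions for illustration. Unlike in BrainPy, the parameters here are passed around explicitly rather than stored in Variables, so the correspondence is an analogy rather than a description of BrainPy internals.

```python
import jax
import jax.numpy as jnp

# Toy data (assumed for illustration): y = 3x + 1
x_data = jnp.linspace(-1.0, 1.0, 32)
y_data = 3.0 * x_data + 1.0


def loss(params):
    w, b = params
    return jnp.mean((w * x_data + b - y_data) ** 2)


# jax.grad returns a new function, much like `grad` wraps `loss` in BrainPy.
grad_fn = jax.grad(loss)


def train_step(params, _):
    grads = grad_fn(params)
    new_params = tuple(p - 0.1 * g for p, g in zip(params, grads))
    # The carry is updated; the loss is recorded once per step.
    return new_params, loss(params)


@jax.jit
def train(params):
    # Run 400 steps in one call; scan returns the final carry
    # together with the stacked per-step outputs.
    return jax.lax.scan(train_step, params, xs=None, length=400)


final_params, losses = train((jnp.zeros(()), jnp.zeros(())))
```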
Concept 2: Dynamical System#
BrainPy supports modeling in brain simulation and brain-inspired computing.
All these supports are based on one common concept: Dynamical System via brainpy.DynamicalSystem.
Therefore, it is essential to understand:
what is brainpy.DynamicalSystem?
how to define brainpy.DynamicalSystem?
how to run brainpy.DynamicalSystem?
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.3.8'
What is DynamicalSystem?#
All models used in brain simulation and brain-inspired computing are DynamicalSystem instances.
Note
DynamicalSystem is a subclass of BrainPyObject. Therefore, it supports the object-oriented transformations described in the previous tutorial.
A DynamicalSystem defines the updating rule of the model at a single time step.
For models with state, DynamicalSystem defines the state transition from \(t\) to \(t+dt\), i.e., \(S(t+dt) = F\left(S(t), x, t, dt\right)\), where \(S\) is the state, \(x\) is the input, \(t\) is the time, and \(dt\) is the time step. This is the case for recurrent neural networks (like GRU, LSTM), neuron models (like HH, LIF), and synapse models, which are widely used in brain simulation.
However, for models in deep learning, like convolution and fully-connected linear layers, DynamicalSystem defines the input-to-output mapping, i.e., \(y=F\left(x, t\right)\).
How to define DynamicalSystem?#
Keep in mind that the usage of DynamicalSystem has several constraints in BrainPy.
1. .update() function#
First, all DynamicalSystem subclasses should implement the .update() function, which receives two arguments:
class YourModel(bp.DynamicalSystem):
    def update(self, s, x):
        pass
s (or any other name): a dict of shared arguments across all nodes/layers in the network, like
the current time t,
the current running index i,
the current time step dt, or
the current phase of training or testing fit=True/False.
x (or any other name): the individual input for this node/layer.
We call s shared arguments because they are the same for, and shared by, all nodes/layers. In contrast, different nodes/layers receive different inputs x.
However, users can also use the simplified DynamicalSystemNS to define models, which only receives one argument:
class YourModel(bp.DynamicalSystemNS):
    def update(self, x):
        pass
Example: LIF neuron model for brain simulation
Here we illustrate the first constraint of DynamicalSystem using the Leaky Integrate-and-Fire (LIF) model.
The LIF model was first proposed in brain simulation for modeling neuron dynamics. Its equation is given by
\(\tau \frac{dV}{dt} = -(V - V_{rest}) + I\), and when \(V \geq V_{th}\), the neuron emits a spike and \(V\) is reset to \(V_{rest}\).
For the details of the model, users should refer to Wikipedia or other resources.
class LIF_for_BrainSimulation(bp.DynamicalSystemNS):
    def __init__(self, size, V_rest=0., V_th=1., tau=5., mode=None):
        super().__init__(mode=mode)
        # this model only supports non-batching mode
        bp.check.is_subclass(self.mode, bm.NonBatchingMode)
        # parameters
        self.size = size
        self.V_rest = V_rest
        self.V_th = V_th
        self.tau = tau
        # variables
        self.V = bm.Variable(bm.ones(size) * V_rest)
        self.spike = bm.Variable(bm.zeros(size, dtype=bool))
        # integrate differential equation with exponential euler method
        self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')
    def update(self, x):
        # DynamicalSystemNS.update() receives only the input x;
        # shared arguments (t, dt) are read from bp.share
        t = bp.share.load('t')
        dt = bp.share.load('dt')
        # define how the model states update according to the external input
        V = self.integral(self.V, t, x, dt=dt)
        spike = V >= self.V_th
        self.V.value = bm.where(spike, self.V_rest, V)
        self.spike.value = spike
        return spike
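To make concrete what a single call to update() computes, here is a NumPy sketch of one LIF step. The sketch assumes the input is constant during the step. Under that assumption, exponential Euler gives the exact solution of this linear equation over the step. Whether bp.odeint(..., method='exp_auto') produces exactly this expression internally is also an assumption, and lif_step is a hypothetical helper.

```python
import numpy as np


def lif_step(V, I, V_rest=0.0, V_th=1.0, tau=5.0, dt=0.1):
    """One LIF update for tau * dV/dt = -(V - V_rest) + I (hypothetical helper)."""
    # Exponential Euler: exact solution over one step if I is constant.
    V_inf = V_rest + I
    V_new = V_inf + (V - V_inf) * np.exp(-dt / tau)
    # Threshold crossing, then reset of the spiking neurons.
    spike = V_new >= V_th
    V_new = np.where(spike, V_rest, V_new)
    return V_new, spike


# The same rule works for one population (shape (3,))
# and for a batch of populations (shape (2, 3)).
V, spike = lif_step(np.zeros(3), np.full(3, 2.0))
V_batch, spike_batch = lif_step(np.zeros((2, 3)), np.ones((2, 3)))
```

The batched call hints at the next constraint. The update rule itself does not care about a leading batch axis, but the Variables holding the state must be created with the right shape.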
2. Computing mode#
Second, explicitly consider which computing mode your DynamicalSystem supports.
Brain simulation usually builds models without a batching dimension (we refer to this as non-batching mode, as in the above LIF model), while brain-inspired computing trains models with a batch of data (batching mode or training mode).
So, to write a model applicable to a broad range of applications in brain simulation and brain-inspired computing, you need to consider which modes your model supports: one of them, or both.
Example: LIF neuron model for both brain simulation and brain-inspired computing
When considering the computing mode, we can program a general LIF model for both brain simulation and brain-inspired computing.
To overcome the non-differentiable spike in the LIF model for brain simulation, i.e., the line
spike = V >= self.V_th
LIF models used in brain-inspired computing calculate the spiking state with a surrogate gradient function. Usually, the forward pass keeps the step function, while its backward gradient is replaced with a smooth function (in the code below, bm.surrogate.inv_square_grad by default).
class LIF(bp.DynamicalSystemNS):
    def __init__(self, size, f_surrogate=None, V_rest=0., V_th=1., tau=5., mode=None):
        super().__init__(mode=mode)
        bp.check.is_subclass(self.mode, [bm.NonBatchingMode, bm.BatchingMode, bm.TrainingMode])
        # Parameters
        self.size = size
        self.num = bp.tools.size2num(size)
        self.V_rest = V_rest
        self.V_th = V_th
        self.tau = tau
        if f_surrogate is None:
            f_surrogate = bm.surrogate.inv_square_grad
        self.f_surrogate = f_surrogate
        # integrate differential equation with exponential euler method
        self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')
        # Initialize a Variable:
        # - if non-batching mode, batch axis of V is None
        # - if batching mode, batch axis of V is 0
        self.V = bp.init.variable_(bm.zeros, self.size, self.mode)
        self.V[:] = self.V_rest
        self.spike = bp.init.variable_(bm.zeros, self.size, self.mode)
    def reset_state(self, batch_size=None):
        self.V.value = bp.init.variable_(bm.ones, self.size, batch_size) * self.V_rest
        self.spike.value = bp.init.variable_(bm.zeros, self.size, batch_size)
    def update(self, x):
        t = bp.share.load('t')
        dt = bp.share.load('dt')
        V = self.integral(self.V, t, x, dt=dt)
        # replace non-differentiable heaviside function
        # with a surrogate gradient function
        spike = self.f_surrogate(V - self.V_th)
        # reset membrane potential
        self.V.value = (1. - spike) * V + spike * self.V_rest
        self.spike.value = spike
        return spike
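The surrogate-gradient idea can be sketched in plain JAX with jax.custom_vjp. The forward pass is the exact step function, and the backward pass uses a smooth derivative instead. The inverse-square form \(1/(\alpha|x|+1)^2\) and \(\alpha=100\) are assumptions made for this sketch and may differ from BrainPy's bm.surrogate.inv_square_grad.

```python
import jax
import jax.numpy as jnp

ALPHA = 100.0  # sharpness of the surrogate derivative (assumed value)


@jax.custom_vjp
def spike_fn(x):
    # Forward pass: exact Heaviside step, 1.0 where x >= 0.
    return (x >= 0).astype(jnp.float32)


def spike_fwd(x):
    # Save x so the backward pass can evaluate the surrogate at it.
    return spike_fn(x), x


def spike_bwd(x, g):
    # Backward pass: replace the true derivative (zero almost everywhere)
    # with a smooth inverse-square bump centered at the threshold.
    return (g / (ALPHA * jnp.abs(x) + 1.0) ** 2,)


spike_fn.defvjp(spike_fwd, spike_bwd)

# V - V_th for three neurons: below, at, and above threshold
v_minus_th = jnp.array([-0.5, 0.0, 0.5])
spikes = spike_fn(v_minus_th)
grads = jax.grad(lambda v: spike_fn(v).sum())(v_minus_th)
```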
Model composition#
The LIF model we have defined above can be recursively composed to construct networks in brain simulation and brain-inspired computing.
The following code snippet utilizes the LIF model to build an E/I balanced network EINet, which is a classical network model in brain simulation.
class EINet(bp.DynamicalSystemNS):
    def __init__(self, num_exc, num_inh):
        super().__init__()
        self.E = LIF(num_exc, V_rest=-55, V_th=-50., tau=20.)
        self.I = LIF(num_inh, V_rest=-55, V_th=-50., tau=20.)
        self.E2E = bp.experimental.Exponential(bp.conn.FixedProb(0.02, pre=num_exc, post=num_exc), g_max=1.62, tau=5.)
        self.E2I = bp.experimental.Exponential(bp.conn.FixedProb(0.02, pre=num_exc, post=num_inh), g_max=1.62, tau=5.)
        self.I2E = bp.experimental.Exponential(bp.conn.FixedProb(0.02, pre=num_inh, post=num_exc), g_max=-9.0, tau=10.)
        self.I2I = bp.experimental.Exponential(bp.conn.FixedProb(0.02, pre=num_inh, post=num_inh), g_max=-9.0, tau=10.)
    def update(self, x):
        # x is the background input
        e2e = self.E2E(self.E.spike)
        e2i = self.E2I(self.E.spike)
        i2e = self.I2E(self.I.spike)
        i2i = self.I2I(self.I.spike)
        self.E(e2e + i2e + x)
        self.I(e2i + i2i + x)
with bm.environment(mode=bm.nonbatching_mode):
    net1 = EINet(3200, 800)
Moreover, our LIF model can also be used in brain-inspired computing scenarios. The following AINet uses the LIF model to construct a model for AI training.
# This network can be used in AI applications
class AINet(bp.DynamicalSystemNS):
    def __init__(self, sizes):
        super().__init__()
        self.neu1 = LIF(sizes[0])
        self.syn1 = bp.layers.Dense(sizes[0], sizes[1])
        self.neu2 = LIF(sizes[1])
        self.syn2 = bp.layers.Dense(sizes[1], sizes[2])
        self.neu3 = LIF(sizes[2])
    def update(self, x):
        return x >> self.neu1 >> self.syn1 >> self.neu2 >> self.syn2 >> self.neu3
with bm.environment(mode=bm.training_mode):
    net2 = AINet([100, 50, 10])
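The expression x >> self.neu1 >> ... relies on Python's reflected right-shift operator. The pure-Python sketch below shows how such chaining can work, using a hypothetical Layer class. BrainPy's own implementation may differ in detail.

```python
class Layer:
    """Hypothetical minimal layer that supports `x >> layer` chaining."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __rrshift__(self, x):
        # Called for `x >> layer` when x's own type does not handle `>>`
        # with a Layer, e.g. a plain number.
        return self(x)


add_one = Layer(lambda x: x + 1)
double = Layer(lambda x: x * 2)

# `>>` is left-associative: (3 >> add_one) >> double == (3 + 1) * 2
result = 3 >> add_one >> double
print(result)  # 8
```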
How to run DynamicalSystem?#
As stated above, DynamicalSystem only defines the updating rule at a single time step. To run a DynamicalSystem instance over time, we need a for-loop mechanism.
1. brainpy.math.for_loop#
for_loop is a structural control flow API which runs a function by looping over the inputs. Moreover, this API just-in-time compiles the looping process into machine code.
Suppose we have 200 time steps with a step size of 0.1. We can run the model with:
def run_net2(t, currents):
    bp.share.save(t=t)
    return net2(currents)
with bm.environment(dt=0.1):
    # construct a set of shared arguments with the given time steps
    shared = bm.shared_args_over_time(num_step=200)
    # construct the inputs with shape of (time, batch, feature)
    currents = bm.random.rand(200, 10, 100)
    # run the model
    net2.reset_state(batch_size=10)
    out = bm.for_loop(run_net2, (shared, currents), child_objs=net2)
out.shape
(200, 10, 10)
2. brainpy.LoopOverTime#
Different from for_loop, brainpy.LoopOverTime is used for constructing a dynamical system that automatically loops the model over time when receiving an input.
In other words, for_loop runs the model over time directly, while brainpy.LoopOverTime creates a model that runs the wrapped model over time whenever it is called.
net2.reset_state(batch_size=10)
looper = bp.LoopOverTime(net2)
out = looper(currents)
out.shape
(200, 10, 10)
3. brainpy.DSRunner#
Another way to run the model in BrainPy is to use the structural running objects DSRunner and DSTrainer. They provide a more flexible way to monitor the variables in a DynamicalSystem. For details, users should refer to the DSRunner tutorial.
with bm.environment(dt=0.1):
    runner = bp.DSRunner(net1, monitors={'E.spike': net1.E.spike, 'I.spike': net1.I.spike})
    runner.run(inputs=bm.ones(1000) * 20.)
bp.visualize.raster_plot(runner.mon['ts'], runner.mon['E.spike'])

Math Foundation#
brainpy.math.Variable#
In BrainPy, the JIT compilation for class objects relies on brainpy.math.Variable. In this section, we are going to understand:
What is a Variable?
What are the subtypes of Variable?
How to update a Variable?
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
brainpy.math.Variable#
brainpy.math.Variable is a pointer referring to a JAX Array. It stores an array as its value. The data in a Variable can be changed during our object-oriented JIT compilation. If an array is labeled as a Variable, it is treated as a dynamical variable that changes over time.
Arrays that are not marked as Variables will be JIT compiled as static data. Modifications of these arrays will be invalid or cause an error.
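The "static data" behavior can be observed in plain JAX. A jitted function bakes any non-argument array it reads into the compiled code as a constant at trace time. In the minimal sketch below, a hypothetical Holder class stands in for a model object that does not use Variable.

```python
import jax
import jax.numpy as jnp


class Holder:
    """Hypothetical model object storing a plain array (not a Variable)."""

    def __init__(self):
        self.data = jnp.zeros(3)


h = Holder()


@jax.jit
def read():
    # h.data is read at trace time and embedded as a constant.
    return h.data + 1.0


first = read()        # traced and compiled here
h.data = jnp.ones(3)  # this change is invisible to the compiled function
second = read()       # cache hit: still computed from the old zeros
```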
Creating a Variable
Passing an array into brainpy.math.Variable creates a Variable. For example:
b1 = bm.random.random(5)
b1
Array([0.70490587, 0.9825947 , 0.79977 , 0.21864283, 0.70959914], dtype=float32)
b2 = bm.Variable(b1)
b2
Variable([0.70490587, 0.9825947 , 0.79977 , 0.21864283, 0.70959914], dtype=float32)
Accessing the value in a Variable
The data in a Variable can be obtained through .value.
b2.value
Array([0.70490587, 0.9825947 , 0.79977 , 0.21864283, 0.70959914], dtype=float32)
(b2.value == b1).all()
Array(True, dtype=bool)
Supported operations on Variables
Variables support almost all the operations for a JAX array.
b2 + 1.
Array([1.7049059, 1.9825947, 1.79977 , 1.2186428, 1.7095991], dtype=float32)
b2 ** 2
Array([0.49689227, 0.9654924 , 0.63963205, 0.04780469, 0.5035309 ], dtype=float32)
bm.floor(b2)
Array([0., 0., 0., 0., 0.], dtype=float32)
Subtypes of Variable#
brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.RandomState. Subtypes can also be customized and extended by users.
1. TrainVar#
brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Usually, trainable variables are those whose gradients are computed to derive the corresponding update values. However, users can also use TrainVar for other purposes.
b = bm.random.rand(4)
b
Array([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)
bm.TrainVar(b)
TrainVar([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)
2. Parameter#
brainpy.math.Parameter labels a dynamically changed parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the Collector.subset method.
b = bm.random.rand(1)
b
Array([0.47972953], dtype=float32)
bm.Parameter(b)
Parameter([0.47972953], dtype=float32)
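Type-based retrieval is the idea behind Collector.subset. The pure-Python sketch below shows how filtering a name-to-variable mapping by type picks out only Parameter or only TrainVar entries. It uses hypothetical stand-in classes and a hypothetical subset() helper, not BrainPy's Collector.

```python
class Variable:
    """Hypothetical stand-in for brainpy.math.Variable."""

    def __init__(self, value):
        self.value = value


class TrainVar(Variable):
    """Hypothetical stand-in: a trainable variable."""


class Parameter(Variable):
    """Hypothetical stand-in: a dynamically changed parameter."""


collection = {
    'Linear0.W': TrainVar([[0.1, 0.2]]),
    'Linear0.b': TrainVar([0.0]),
    'LIF0.tau': Parameter(5.0),
    'LIF0.V': Variable([0.0, 0.0]),
}


def subset(coll, var_type):
    """Keep entries whose value is an instance of `var_type` (subclasses included)."""
    return {name: v for name, v in coll.items() if isinstance(v, var_type)}


print(sorted(subset(collection, TrainVar)))   # ['Linear0.W', 'Linear0.b']
print(sorted(subset(collection, Parameter)))  # ['LIF0.tau']
```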
3. RandomState#
brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. RandomState must store the dynamically changed key information (see JAX random number designs). Every time a RandomState performs random sampling, its “key” changes. Therefore, it is worth labeling a RandomState as a Variable.
state = bm.random.RandomState(1234)
state
RandomState(key=([ 0, 1234], dtype=uint32))
# perform a "random" sampling
state.random(1)
state # the value changed
RandomState(key=([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling
state.sample(1)
state # the value changed too
RandomState(key=([1076515368, 3893328283], dtype=uint32))
Every instance of RandomState can create a new seed from the current seed with .split_key().
state.split_key()
Array([3028232624, 826525938], dtype=uint32)
It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.
state.split_keys(2)
Array([[4198471980, 1111166693],
[1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
Array([[3244149147, 2659778815],
[2548793527, 3057026599],
[ 874320145, 4142002431],
[3368470122, 3462971882],
[1756854521, 1662729797]], dtype=uint32)
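The key-splitting behavior can be sketched with plain jax.random. The StatefulRNG class below is a hypothetical minimal wrapper, not BrainPy's RandomState. It assumes only that each sampling call consumes the current key and stores a fresh one.

```python
import jax
import jax.numpy as jnp


class StatefulRNG:
    """Hypothetical minimal stateful wrapper around a JAX PRNG key."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def split_key(self):
        # Consume the current key: keep one half as the new state,
        # hand out the other half for a single sampling call.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def split_keys(self, n):
        # Derive n independent keys (e.g. one per parallel worker)
        # plus one key that becomes the new state.
        keys = jax.random.split(self.key, n + 1)
        self.key = keys[0]
        return keys[1:]

    def random(self, shape=()):
        return jax.random.uniform(self.split_key(), shape)


rng = StatefulRNG(1234)
k0 = rng.key          # [0, 1234] for the default threefry implementation
x = rng.random((3,))  # sampling replaced rng.key with a new key
```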
There is a default RandomState in the brainpy.math.random module: DEFAULT.
bm.random.DEFAULT
RandomState(key=([1682297581, 3751629511], dtype=uint32))
The inherent random methods like randint(), rand(), shuffle(), etc. use this DEFAULT state. If you want to reseed the default RandomState, please use the seed() method.
bm.random.seed(654321)
bm.random.DEFAULT
RandomState(key=([ 0, 654321], dtype=uint32))
In-place updating#
In BrainPy, the transformations (like JIT) usually need to update variables or arrays in-place. In-place updating keeps the reference to the variable unchanged while changing the data stored in it.
For example, here we have a variable a.
a = bm.Variable(bm.zeros(5))
The ids of the variable and the data stored in the variable are:
id_of_a = id(a)
id_of_data = id(a.value)
assert id_of_a != id_of_data
print('id(a) = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a) = 2781947736704
id(a.value) = 2781965742144
In-place update (here we use [:]) does not change the object referred to by the variable, but changes its data:
a[:] = 1.
print('id(a) = ', id(a))
print('id(a.value) = ', id(a.value))
id(a) = 2781947736704
id(a.value) = 2781965752128
print('(id(a) == id_of_a) =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a) = True
(id(a.value) == id_of_data) = False
However, if you assign data without an in-place operator, the name a is rebound to a new object with a different id. This will cause serious errors when using transformations in BrainPy.
a = 10.
print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) = 2781946941520
(id(a) == id_of_a) = False
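Here is why the identity matters. A transformation keeps references to the Variable objects it was given. In the pure-Python sketch below, a hypothetical Box class stands in for a Variable, and the list tracked stands in for what a transformation recorded.

```python
class Box:
    """Hypothetical stand-in for a Variable: a fixed container whose payload can change."""

    def __init__(self, value):
        self.value = value


a = Box([0.0, 0.0, 0.0])
tracked = [a]  # what a transformation recorded before running

a.value = [1.0, 1.0, 1.0]  # in-place style: same Box, new payload
seen_after_inplace = tracked[0].value  # [1.0, 1.0, 1.0]

a = Box([2.0, 2.0, 2.0])  # rebinding: the name `a` now refers to a new object
seen_after_rebind = tracked[0].value  # still [1.0, 1.0, 1.0]
```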
The following in-place operators are not limited to brainpy.math.Variable and its subclasses. They can also apply to brainpy.math.Array.
Here, we list several commonly used in-place operators.
v = bm.Variable(bm.arange(10))
old_id = id(v)
def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'
1. Indexing and slicing#
Indexing and slicing are the two most commonly used operators. The details of indexing and slicing are in Array Objects Indexing.
Indexing: v[i] = a or v[[1, 3]] = c (index multiple values)
v[0] = 1
check_no_change(v)
Slicing: v[i:j] = b
v[1: 2] = 1
check_no_change(v)
Slicing all values: v[:] = d, v[...] = e
v[:] = 0
check_no_change(v)
v[...] = bm.arange(10)
check_no_change(v)
2. Augmented assignment#
All augmented assignments are in-place operations, which include
add: +=
subtract: -=
divide: /=
multiply: *=
floor divide: //=
modulo: %=
power: **=
and: &=
or: |=
xor: ^=
left shift: <<=
right shift: >>=
v += 1
check_no_change(v)
v *= 2
check_no_change(v)
v |= bm.random.randint(0, 2, 10)
check_no_change(v)
v **= 2
check_no_change(v)
v >>= 2
check_no_change(v)
3. .value assignment#
Another way to in-place update a variable is to assign new data to .value. This operation is very safe, because it checks whether the type and shape of the new data are consistent with the current ones.
v.value = bm.arange(10)
check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
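These checks can be sketched with a property setter. CheckedVar is a hypothetical NumPy-based stand-in, not BrainPy's Variable, and it raises ValueError instead of brainpy.errors.MathError.

```python
import numpy as np


class CheckedVar:
    """Hypothetical Variable-like wrapper that validates assignments to `.value`."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        new = np.asarray(new)
        # Reject data whose shape or dtype differs from the current data.
        if new.shape != self._value.shape:
            raise ValueError(f'shape mismatch: {self._value.shape} vs {new.shape}')
        if new.dtype != self._value.dtype:
            raise ValueError(f'dtype mismatch: {self._value.dtype} vs {new.dtype}')
        self._value = new


v = CheckedVar(np.arange(10))
v.value = np.arange(10, 20)  # same shape and dtype: accepted
```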
4. .update() method#
Actually, the .value assignment is the same operation as the .update() method. Users who want a safe assignment can choose this method too.
v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
Control Flows for JIT compilation#
Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.
Python has two types of control structures:
Selection: used for decisions and branching.
Repetition: used for looping, i.e., repeating a piece of code multiple times.
In this section, we are going to talk about how to build effective control flows in the context of JIT compilation.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.0'
1. Selection#
In Python, the selection statements are also known as decision control statements or branching statements. The selection statement allows a program to test several conditions and execute instructions based on which condition is true. The commonly used control statements include:
if-else
nested if
if-elif-else
Non-Variable-based control statements#
BrainPy (based on JAX) allows you to write control flows just like in familiar Python programs, as long as the conditional statement depends on non-Variable instances. For example,
class OddEven(bp.BrainPyObject):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a
In the above example, the condition in the if (statement) syntax relies on a scalar, which is not an instance of brainpy.math.Variable. In this case, the conditional statements can be arbitrarily complex. You can write your models with normal Python code. These models will work very well with JIT compilation.
model = bm.jit(OddEven(type_=1))
model()
Variable([1.], dtype=float32)
model = bm.jit(OddEven(type_=2))
model()
Variable([-1.], dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
Variable-based control statements#
However, if the condition in an if ... else ... syntax relies on instances of brainpy.math.Variable, writing Pythonic control flows will cause errors when using JIT compilation.
class OddEvenCauseError(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        if self.rand < 0.5:
            self.a += 1
        else:
            self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())
try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: This problem may be caused by several ways:
1. Your if-else conditional statement relies on instances of brainpy.math.Variable.
2. Your if-else conditional statement relies on functional arguments which do not set in "static_argnames" when applying JIT compilation. More details please see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
3. The static variables which set in the "static_argnames" are provided as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].
To perform conditional statements on Variable instances, we need structural control flow syntax. Specifically, BrainPy provides several options (based on JAX):
brainpy.math.where: return element-wise conditional comparison results.
brainpy.math.ifelse: conditional statements of if-else, or if-elif-else, … for a scalar-typed value.
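The same distinction exists in plain JAX, which BrainPy builds on. In the minimal sketch below, a Python if on a traced value fails under jax.jit, while jnp.where and jax.lax.cond work. The function names are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def python_if(x):
    # A Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x < 0.5:
        return x + 1.0
    return x - 1.0


@jax.jit
def with_where(x):
    # Both branches are computed; the result is selected element-wise.
    return jnp.where(x < 0.5, x + 1.0, x - 1.0)


@jax.jit
def with_cond(x):
    # Only the selected branch runs; the predicate must be a scalar.
    return jax.lax.cond(x < 0.5, lambda v: v + 1.0, lambda v: v - 1.0, x)


try:
    python_if(0.2)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True
```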
brainpy.math.where#
The where(condition, x, y) function returns elements chosen from x or y depending on condition. It works well on scalars, vectors, and high-dimensional arrays.
a = 1.
bm.where(a < 0, 0., 1.)
Array(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
Array([0., 0., 1., 1., 0.], dtype=float32, weak_type=True)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
Array([[0., 0., 1.],
[0., 0., 1.],
[0., 0., 0.]], dtype=float32, weak_type=True)
For the above example, we can rewrite it by using where syntax as:
class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable([1.], dtype=float32)
brainpy.math.ifelse#
Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.
In its simplest case, brainpy.math.ifelse(condition, branches, operands, dyn_vars=None) is equivalent to:
def ifelse(condition, branches, operands, dyn_vars=None):
    true_fun, false_fun = branches
    if condition:
        return true_fun(operands)
    else:
        return false_fun(operands)
Based on this function, we can rewrite the above example by using ifelse as:
class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))
    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda _: 1., lambda _: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable([-1.], dtype=float32)
If you want to write control flows with multiple branches, brainpy.math.ifelse(conditions, branches, operands, dyn_vars=None) can also help you accomplish this easily. The multiple branching case is equivalent to:
def ifelse(conditions, branches, operands, dyn_vars=None):
    # conditions: [pred1, pred2, ...]; branches: [func1, func2, ..., funcN],
    # where the last branch funcN is the final "else" case
    for pred, func in zip(conditions, branches[:-1]):
        if pred:
            return func(operands)
    return branches[-1](operands)
For example, if you have the following code:
def f(a):
    if a > 10:
        return 1.
    elif a > 5:
        return 2.
    elif a > 0:
        return 3.
    elif a > -5:
        return 4.
    else:
        return 5.
It can be expressed as:
def f(a):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[1., 2., 3., 4., 5.])
f(11.)
Array(1., dtype=float32, weak_type=True)
f(6.)
Array(2., dtype=float32, weak_type=True)
f(1.)
Array(3., dtype=float32, weak_type=True)
f(-4.)
Array(4., dtype=float32, weak_type=True)
f(-6.)
Array(5., dtype=float32, weak_type=True)
A more complex example is:
def f2(a, x):
    return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                     branches=[lambda x: x*2,
                               2.,
                               lambda x: x**2 -1,
                               lambda x: x - 4.,
                               5.],
                     operands=x)
f2(11, 1.)
Array(2., dtype=float32, weak_type=True)
f2(6, 1.)
Array(2., dtype=float32, weak_type=True)
f2(1, 1.)
Array(0., dtype=float32, weak_type=True)
f2(-4, 1.)
Array(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
Array(5., dtype=float32, weak_type=True)
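For intuition, this kind of multi-way branching can also be written in plain JAX with jax.lax.switch: the list of conditions is turned into the index of the first true condition. This is one possible construction and not necessarily how brainpy.math.ifelse is implemented. f_switch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_switch(a):
    # Evaluate all predicates; they are traced values under jit.
    conds = jnp.stack([a > 10, a > 5, a > 0, a > -5])
    # Index of the first True predicate, or 4 (the final "else" branch) if none holds.
    first_true = jnp.argmax(conds.astype(jnp.int32))
    idx = jnp.where(conds.any(), first_true, conds.shape[0])
    # One branch per outcome; lax.switch runs only the selected one.
    branches = [lambda _, c=c: jnp.float32(c) for c in (1., 2., 3., 4., 5.)]
    return jax.lax.switch(idx, branches, a)
```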
If instances of brainpy.math.Variable are used in branching functions, you can declare them in the dyn_vars argument.
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])
print('a:', a)
print('b:', b)
a: Variable([1., 1.], dtype=float32)
b: Variable([0., 0.], dtype=float32)
2. Repetition#
A repetition statement is used to repeat a group (block) of programming instructions.
In Python, we generally have two loops/repetitive statements:
for loop: Execute a set of statements once for each item in a sequence.
while loop: Execute a block of statements repeatedly until a given condition is satisfied.
Pythonic loop syntax#
JAX allows you to write Pythonic loops: you just iterate over your sequence data and apply your logic to the iterated items. Such Pythonic loops are compatible with JIT compilation, but the loop is unrolled during tracing, so tracing and compilation can take a long time. For example,
class LoopSimple(bp.BrainPyObject):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = bm.Variable(rng.random(1000))
        self.res = bm.Variable(bm.zeros(1))
    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time
def measure_time(f, return_res=False, verbose=True):
    t0 = time.time()
    r = f()
    t1 = time.time()
    if verbose:
        print(f'Result: {r}, Time: {t1 - t0}')
    return r if return_res else None
model = bm.jit(LoopSimple())
# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 0.9443738460540771
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
When the model is complex and the iteration is long, the compilation during the first running will become unbearable. For such cases, you need structural loop syntax.
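The tracing cost can be made visible by counting the equations in the traced program. In this JAX sketch (function names are hypothetical), the Python loop is unrolled into a separate set of operations for each element, whereas jax.lax.fori_loop traces its body only once.

```python
import jax
import jax.numpy as jnp

seq = jnp.arange(1000, dtype=jnp.float32)


def unrolled_sum(xs):
    total = jnp.zeros(())
    for x in xs:  # Python loop: one add (plus indexing) is traced per element
        total = total + x
    return total


def structured_sum(xs):
    # The loop body is traced once, independent of the number of iterations.
    return jax.lax.fori_loop(0, xs.shape[0], lambda i, acc: acc + xs[i], jnp.zeros(()))


n_unrolled = len(jax.make_jaxpr(unrolled_sum)(seq).jaxpr.eqns)
n_structured = len(jax.make_jaxpr(structured_sum)(seq).jaxpr.eqns)
# the unrolled count grows with len(seq); the fori_loop count does not
print(n_unrolled, n_structured)
```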
JAX provides several important structural loop primitives, including:
jax.lax.fori_loop
jax.lax.scan
jax.lax.while_loop
BrainPy also provides its own loop syntax, which is especially suitable for cases where users are using brainpy.math.Variable. Specifically, they are:
brainpy.math.for_loop
brainpy.math.while_loop
In this section, we only talk about how to use our provided loop functions.
brainpy.math.for_loop()#
brainpy.math.for_loop() is used to build a for loop when you use Variable.
Suppose that you are using several Variables (grouped as dyn_vars) to implement your body function “body_fun”, and you want to gather the values that the body function returns at every step. With plain Python syntax, this can be realized as
def for_loop_function(body_fun, dyn_vars, xs):
    ys = []
    for x in xs:
        # 'dyn_vars' are updated in 'body_fun()'
        results = body_fun(x)
        ys.append(results)
    return ys
In BrainPy, you can define this logic using brainpy.math.for_loop():
import brainpy.math
history = brainpy.math.for_loop(body_fun, operands=operands, dyn_vars=dyn_vars)
For the above example, we can rewrite it by using brainpy.math.for_loop as:
class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))
    def __call__(self):
        def add(s):
            self.res += s
            return self.res.value
        return bm.for_loop(body_fun=add, dyn_vars=[self.res], operands=self.seq)
model = bm.jit(LoopStruct())
r = measure_time(model, verbose=False, return_res=True)
r.shape
(1000, 1)
In essence,
body_fun defines the one-step updating rule of how variables are updated. All returns of body_fun will be gathered as the history values.
dyn_vars defines all dynamical variables used in body_fun.
operands specifies the inputs of body_fun. It will be looped over along the first axis.
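For comparison, here is the same cumulative sum in plain JAX with jax.lax.scan. The carried value plays the role of dyn_vars, the scanned array plays the role of operands, and the second return value of the body is gathered into the history. This mapping is a conceptual analogy, not a description of BrainPy internals.

```python
import jax
import jax.numpy as jnp

seq = jax.random.uniform(jax.random.PRNGKey(123), (1000,))


def body(res, s):
    # `res` is the carried state (the role of dyn_vars);
    # the second return value is collected at every step (the history).
    res = res + s
    return res, res


final, history = jax.lax.scan(body, jnp.zeros(()), seq)
```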
brainpy.math.while_loop()#
brainpy.math.while_loop() is used to build a while loop when you use Variable. It supports the following loop logic:
while condition:
    statements
When using brainpy.math.while_loop(), condition should be wrapped as a cond_fun function which returns a boolean value, and statements should be packed as a body_fun function which receives the old values at the latest step and returns the updated values at the current step:
while cond_fun(x):
    x = body_fun(x)
Note the difference between brainpy.math.for_loop and brainpy.math.while_loop:
The returns of brainpy.math.for_loop are the values to be gathered as the history values, while the returns of brainpy.math.while_loop should have the same shape and type as its inputs, because they represent the updated values.
brainpy.math.for_loop can receive anything without explicit requirements on its returns, but brainpy.math.while_loop should return what it receives.
A concrete example of brainpy.math.while_loop is as follows:
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))
def cond_f():
    return i[0] < 10
def body_f():
    i.value += 1.
    counter.value += i
bm.while_loop(body_f, cond_f, dyn_vars=[i, counter], operands=())
()
In the above example, we implement a sum from 0 to 10 by using two Variables, i and counter.
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)
Or, similarly,
i = bm.Variable(bm.zeros(1))
def cond_f(counter):
    return i[0] < 10
def body_f(counter):
    i.value += 1.
    return counter + i[0]
bm.while_loop(body_f, cond_f, dyn_vars=[i], operands=(1., ))
(Array(56., dtype=float32),)
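Here is the same counter written directly with jax.lax.while_loop, where an explicit carry tuple replaces the dyn_vars Variables. The body function must return a carry with the same structure, shapes and dtypes as the one it received, which is the same constraint described above for brainpy.math.while_loop.

```python
import jax
import jax.numpy as jnp


def cond_fun(carry):
    i, counter = carry
    return i < 10


def body_fun(carry):
    # Must return a carry with the same structure, shapes and dtypes it received.
    i, counter = carry
    i = i + 1.0
    return i, counter + i


i_final, counter_final = jax.lax.while_loop(
    cond_fun, body_fun, (jnp.zeros(()), jnp.zeros(())))
```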
Random Number Generation for JIT Compilation#
Although brainpy.math.random is designed to be seamlessly compatible with numpy.random, there are still some differences in the context of JIT compilation.
In this section, we are going to talk about how to write JIT-compatible code with brainpy.math.random.
import brainpy as bp
import brainpy.math as bm
import numpy as np
# bm.set_platform('cpu')
bp.__version__
'2.3.0'
Using bm.random outside functions to JIT#
Using brainpy.math.random outside of functions to JIT is the same as using numpy.random.
This usage corresponds to the cases of generating random data for further processing. For example,
bm.random.rand(10)
Array([0.7222161 , 0.2043277 , 0.59838593, 0.255252 , 0.14954388,
0.05150986, 0.214692 , 0.03857851, 0.81150043, 0.4669956 ], dtype=float32)
np.random.rand(10)
array([0.60626677, 0.69464463, 0.81361761, 0.1583908 , 0.50378113,
0.17677626, 0.7507633 , 0.75699064, 0.33320096, 0.38958635])
When you are using API functions in brainpy.math.random, you are actually calling functions of a default RandomState. Similarly, numpy.random also has a default RandomState. Calling a random function in the numpy.random module corresponds to calling the random function in this default NumPy RandomState.
bm.random.DEFAULT
RandomState(key=([3014124236, 2009892527], dtype=uint32))
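The NumPy half of this statement can be checked directly. After the module-level generator is seeded, numpy.random.rand draws the same numbers as an explicitly created numpy.random.RandomState with the same seed. This sketch only concerns NumPy and does not inspect brainpy.math.random.DEFAULT.

```python
import numpy as np

# Module-level functions draw from NumPy's hidden default RandomState.
np.random.seed(42)
from_module = np.random.rand(5)

# An explicit RandomState with the same seed produces the same stream.
explicit = np.random.RandomState(42)
from_instance = explicit.rand(5)

print(np.array_equal(from_module, from_instance))  # True
```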
Using bm.random inside a function to JIT#
If you use random sampling inside a JIT-compiled function, there are several things to pay attention to. Otherwise, you are likely to get errors or unexpected results.
As stated above, brainpy.math.random functions use the default RandomState. A RandomState is an instance of brainpy Variable, which means its value changes after any of its built-in random functions is called. What changes is the key of the RandomState. For instance,
bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([ 873106783, 4065854088], dtype=uint32))
bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([3526960574, 230845945], dtype=uint32))
Therefore, if you do not declare the DEFAULT RandomState you are using, repeatedly calling random functions of the brainpy.math.random module inside a JIT function will not give you what you want, because its key cannot be updated. For instance,
@bm.jit
def get_data():
    return bm.random.random(2)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)
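A toy, pure-Python model shows why the two calls above return identical numbers. In this model, a random draw is a pure function of a key that returns a value and a new key, in the spirit of JAX-style RNGs. The mixing function is a SplitMix64-style integer hash chosen only for illustration. It is an assumption and not the generator that BrainPy or JAX uses, and ToyRandomState is a hypothetical stand-in for RandomState. If the new key is never written back, as happens when the key is frozen into a compiled function, every call starts from the same key and returns the same number.

```python
MASK64 = (1 << 64) - 1


def mix(x):
    """SplitMix64-style finalizer: an illustrative integer hash, not a real JAX PRNG."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)


def uniform(key):
    """Pure function of the key: returns (sample in [0, 1), next key)."""
    value = (mix(key) >> 11) / float(1 << 53)
    next_key = mix(key + 1)
    return value, next_key


class ToyRandomState:
    """Stateful wrapper: every draw writes the next key back, so the key changes per call."""

    def __init__(self, seed):
        self.key = seed

    def random(self):
        value, self.key = uniform(self.key)
        return value


rng = ToyRandomState(seed=0)
captured_key = rng.key  # a snapshot, like a key frozen into a compiled function


def get_data_frozen():
    value, _ = uniform(captured_key)  # the new key is discarded
    return value


def get_data_threaded():
    return rng.random()  # the new key is stored for the next call


print(get_data_frozen() == get_data_frozen())  # True: same key, same number
print(get_data_threaded() == get_data_threaded())
```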
A correct way is to explicitly declare that you are using this DEFAULT variable in the JIT transformation.
bm.random.seed()
from functools import partial
@partial(bm.jit, dyn_vars=(bm.random.DEFAULT, ))
def get_data_v2():
    return bm.random.random(2)
get_data_v2()
Array([0.38541543, 0.5843446 ], dtype=float32)
get_data_v2()
Array([0.85543776, 0.36957836], dtype=float32)
Or, declare the function as a BrainPyObject, then use jit().
@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def get_data_v3():
    return bm.random.random(2)
get_data_v3()
Array([0.31096482, 0.7970413 ], dtype=float32)
get_data_v3()
Array([0.26830554, 0.15947664], dtype=float32)
Using RandomState for objects to JIT#
Another approach I recommend is to use instances of RandomState for objects to JIT. For example, you can initialize a RandomState in the __init__() function and then use the initialized RandomState anywhere.
class MyOb(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.rng = bm.random.RandomState(123)
    def __call__(self):
        size = (50, 100)
        u = self.rng.random(size)
        v = self.rng.uniform(size=size)
        z = bm.sqrt(-2 * bm.log(u)) * bm.cos(2 * bm.pi * v)
        return z
ob = bm.jit(MyOb())
ob()
Array([[ 1.3595979 , -1.3462192 , 0.7149456 , ..., 1.4283268 ,
-1.1362855 , -0.18378317],
[-0.26401126, -1.6798397 , -0.8422355 , ..., 1.0795223 ,
0.41247413, -0.955116 ],
[ 0.6234829 , -0.44811824, -0.03835859, ..., -2.5203867 ,
-0.02713326, 1.6490041 ],
...,
[-0.9861029 , 0.36676335, -0.31499916, ..., 1.526808 ,
-0.7946268 , -0.86713606],
[-1.7008592 , -0.05957834, -0.5677447 , ..., -0.04765594,
0.574145 , -0.11830498],
[-0.22663854, -1.8517947 , -1.3546717 , ..., 1.2332705 ,
-0.79247886, -1.9352005 ]], dtype=float32)
Note that any Variable instance that can be directly accessed through self. can be found automatically by brainpy's JIT transformation functions. Therefore, in this case, we do not need to pass rng into the dyn_vars of the bm.jit() function.
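The expression computed in MyOb.__call__ is the Box–Muller transform, which turns two independent uniform samples into one standard-normal sample. Below is a stdlib-only sketch of the same formula, using a dedicated random.Random generator in the role of self.rng. It draws u as 1 - random(), so u lies in (0, 1] and log(u) is always finite. The example above does not include this guard.

```python
import math
import random


def box_muller(rng, n):
    """Draw n standard-normal samples with the Box–Muller cosine branch."""
    samples = []
    for _ in range(n):
        u = 1.0 - rng.random()  # in (0, 1], so log(u) is finite
        v = rng.random()        # in [0, 1)
        samples.append(math.sqrt(-2.0 * math.log(u)) * math.cos(2.0 * math.pi * v))
    return samples


rng = random.Random(123)  # a dedicated generator, like self.rng in MyOb
z = box_muller(rng, 20_000)
mean = sum(z) / len(z)
var = sum((x - mean) ** 2 for x in z) / (len(z) - 1)
print(f"mean={mean:.3f}, var={var:.3f}")  # mean near 0, variance near 1
```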
Model Building#
Using Built-in Models#
BrainPy enables modular programming and easy model debugging. To build a complex brain dynamics model, you just need to group its building blocks. In this section, we describe the building blocks we provide and how to use them.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
Initializing a neuron model#
All neuron models implemented in brainpy are subclasses of brainpy.dyn.NeuGroup. To initialize a neuron model, you only need to provide the geometry size of the neurons in a population group.
hh = bp.neurons.HH(size=1) # only 1 neuron
hh = bp.neurons.HH(size=10) # 10 neurons in a group
hh = bp.neurons.HH(size=(10, 10)) # a grid of (10, 10) neurons in a group
hh = bp.neurons.HH(size=(5, 4, 2)) # a column of (5, 4, 2) neurons in a group
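The size argument is either an integer or a tuple that describes the geometry, and the total number of neurons is the product of its entries. The helper below is a hypothetical illustration of that bookkeeping, not BrainPy's own code. BrainPy computes this total for you and stores it as the group's num attribute.

```python
import math


def total_neurons(size):
    """Hypothetical helper: number of neurons implied by a `size` argument."""
    if isinstance(size, int):
        return size
    return math.prod(size)  # e.g. (10, 10) -> 100


for size in [1, 10, (10, 10), (5, 4, 2)]:
    print(size, '->', total_neurons(size))
```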
Generally speaking, there are two types of arguments that users can set:
parameters: the model parameters, like gNa, which refers to the maximum conductance of the sodium channel in the brainpy.dyn.HH model.
variables: the model variables, like V, which refers to the membrane potential of a neuron model.
By default, model parameters are homogeneous, i.e., they are just scalar values.
hh = bp.neurons.HH(5) # there are five neurons in this group
hh.gNa
120.0
However, neuron models support heterogeneous parameters when performing computations in a neuron group. Heterogeneous parameters can be initialized in several ways.
1. Array
Users can directly provide an array as the parameter.
hh = bp.neurons.HH(5, gNa=bm.random.uniform(110, 130, size=5))
hh.gNa
Array([127.70759, 125.13152, 112.63894, 127.90401, 129.41827], dtype=float32)
2. Initializer
BrainPy provides rich support for initialization. You can pass an initializer as the parameter to instruct the model to initialize heterogeneous parameters.
hh = bp.neurons.HH(5, ENa=bp.init.OneInit(50.))
hh.ENa
Array([50., 50., 50., 50., 50.], dtype=float32)
3. Callable function
You can also directly provide a callable function that receives a shape argument.
hh = bp.neurons.HH(5, ENa=lambda shape: bm.random.uniform(40, 60, shape))
hh.ENa
Array([57.047512, 53.037655, 59.74895 , 50.8206 , 44.256607], dtype=float32)
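One way to think about the three options is that each is normalized to either a scalar or an array of the group's shape. The sketch below is a hypothetical, stdlib-only resolver written for illustration. BrainPy performs this step internally, and its actual function names and checks may differ. In this sketch, scalars stay scalars (homogeneous), callables (including a toy initializer object, UniformInit, which is also hypothetical) are called with the shape, and array-likes are checked against the number of neurons.

```python
import random


def resolve_param(value, num):
    """Hypothetical helper: turn a parameter spec into a scalar or a list of length num."""
    if isinstance(value, (int, float)):
        return value  # homogeneous: one scalar shared by every neuron
    if callable(value):
        value = value((num,))  # initializer or function receiving a shape
    values = list(value)
    if len(values) != num:
        raise ValueError(f"expected {num} values, got {len(values)}")
    return values


class UniformInit:
    """Toy initializer: a callable object that receives a shape."""

    def __init__(self, low, high, seed=0):
        self.low, self.high = low, high
        self.rng = random.Random(seed)

    def __call__(self, shape):
        return [self.rng.uniform(self.low, self.high) for _ in range(shape[0])]


print(resolve_param(120.0, 5))                                # 120.0
print(resolve_param([50.0] * 5, 5))                           # [50.0, 50.0, 50.0, 50.0, 50.0]
print(len(resolve_param(UniformInit(110.0, 130.0), 5)))       # 5
print(len(resolve_param(lambda shape: [0.0] * shape[0], 5)))  # 5
```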
Here, let’s see how the heterogeneous parameters influence our model simulation.
# we create 3 neurons in a group. Each neuron has a unique "gNa"
model = bp.neurons.HH(3, gNa=bp.init.Uniform(min_val=100, max_val=140))
runner = bp.dyn.DSRunner(model, monitors=['V'], inputs=['input', 5.])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 2], show=True)

Similarly, the initial values of a variable can be set in the same three ways: Array, Initializer, and Callable function. For example,
hh = bp.neurons.HH(
    3,
    V_initializer=bp.init.Uniform(-80., -60.),  # Initializer
    m_initializer=lambda shape: bm.random.random(shape),  # function
    h_initializer=bm.random.random(3),  # Array
)
print('V: ', hh.V)
print('m: ', hh.m)
print('h: ', hh.h)
V: Variable([-72.55422 , -61.628696, -71.0226 ], dtype=float32)
m: Variable([0.7881355 , 0.40693295, 0.7243513 ], dtype=float32)
h: Variable([0.47316658, 0.15884387, 0.6759169 ], dtype=float32)
Initializing a synapse model#
Initializing a synapse model needs to provide its pre-synaptic group (pre
), post-synaptic group (post
) and the connection method between them (conn
). The below is an example to create an Exponential synapse model:
neu = bp.neurons.LIF(10)
# here we create a synaptic projection within a population
syn = bp.synapses.Exponential(pre=neu, post=neu, conn=bp.conn.All2All())
BrainPy’s built-in synapse models support heterogeneous synaptic weights and delay steps through an Array, an Initializer, or a Callable function. For example,
syn = bp.synapses.Exponential(neu, neu, bp.conn.FixedProb(prob=0.1),
g_max=bp.init.Uniform(min_val=0.1, max_val=1.),
delay_step=lambda shape: bm.random.randint(10, 30, shape))
syn.g_max
Array([0.5255966 , 0.13250259, 0.49933627, 0.9400071 , 0.56140935,
0.7105977 , 0.89582247, 0.63783807, 0.97180253, 0.2137514 ], dtype=float32)
syn.delay_step
Array([10, 14, 22, 18, 22, 23, 25, 21, 28, 10], dtype=int32)
However, the built-in synapse models in BrainPy only support homogeneous synaptic parameters, like the time constant \(\tau\). Users who need heterogeneous synaptic parameters can customize their own synapse models.
Similarly, the synaptic variables can be initialized heterogeneously by using Array, Initializer, and Callable functions.
Changing model parameters during simulation#
In BrainPy, all dynamically changed variables, whether they are changed inside or outside the jitted function, should be marked as brainpy.math.Variable. BrainPy's built-in models also support modifying model parameters during simulation.
For example, suppose you want to keep gNa fixed during the first 100 ms of simulation and then decrease its value in the following simulation. In this case, we can provide gNa as an instance of brainpy.math.Variable when initializing the model.
hh = bp.neurons.HH(5, gNa=bm.Variable(bm.asarray([120.])))
runner = bp.dyn.DSRunner(hh, monitors=['V'], inputs=['input', 5.])
# the first running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

# change the gNa first
hh.gNa[:] = 100.
# the second running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

Examples of using built-in models#
Here we show how to simulate a famous neuron model: the Morris-Lecar neuron model, a two-dimensional “reduced” excitation model applicable to systems with two non-inactivating voltage-sensitive conductances.
group = bp.neurons.MorrisLecar(1)
Then users can use the various tools provided by BrainPy to easily simulate the Morris-Lecar neuron model. We do not go into details here, so please read the corresponding tutorials if you want to learn more.
runner = bp.dyn.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.))
runner.run(1000)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)

Next, we give an intuitive example of building a network composed of different neuron and synapse models. Users can simply initialize these models as below and pass them into brainpy.dyn.Network.
neu1 = bp.neurons.HH(1)
neu2 = bp.neurons.HH(1)
syn1 = bp.synapses.AMPA(neu1, neu2, bp.connect.All2All())
net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
By selecting a proper runner, users can simulate the network efficiently and plot the simulation results.
runner = bp.dyn.DSRunner(net,
inputs=(neu1.input, 6.),
monitors=['pre.V', 'post.V', 'syn.g'])
runner.run(150.)
import matplotlib.pyplot as plt
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
plt.legend()
fig.add_subplot(gs[1, 0])
plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
plt.legend()
plt.show()

Building Conductance-based Neuron Models#
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
There are basically two types of neuron models: conductance-based models and simplified models. In conductance-based models, a single neuron can be regarded as an electric circuit, where the membrane is a capacitor, ion channels are conductors, and ion gradients are batteries. The neuronal activity is captured by the currents flowing through those ion channels. Sometimes there is an external input to the neuron, which can also be included in the equivalent circuit (see the figure below, which shows potassium channels, sodium channels and leaky channels).

On the other hand, simplified models do not care about the physiological features of neurons but mainly focus on reproducing the exact spike timing. Therefore, they are simpler and may not be biologically interpretable.
For ease of use, BrainPy provides a large number of predefined neuron models, both conductance-based and simplified. In this section, we only discuss how to build conductance-based models from ion channels. Please refer to Customizing Your Neuron Models for more information.
Building an ion channel#
As we know, ion channels are crucial for conductance-based neuron models. So how do we model an ion channel? Let’s take the potassium channel as an example.

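The diagram above shows how a potassium channel is modeled as an electric circuit. From this circuit, we obtain the differential equation:
\[ c_\mathrm{M} \frac{\mathrm{d}V_\mathrm{M}}{\mathrm{d}t} = -\frac{V_\mathrm{M} - E_\mathrm{K}}{R_\mathrm{K}} = -g_\mathrm{K}\left(V_\mathrm{M} - E_\mathrm{K}\right), \]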
in which \(c_\mathrm{M}\) is the membrane capacitance, \(V_\mathrm{M}\) is the membrane potential, \(E_\mathrm{K}\) is the equilibrium potential of potassium ions, and \(R_\mathrm{K}\) (\(g_\mathrm{K}\)) refers to the resistance (conductance) of the potassium channel. We define currents flowing from inside to outside as positive.
In the equation above, the conductance of potassium channels \(g_\mathrm{K}\) is not constant but changes with the membrane potential, which is why the channel is categorized as a voltage-gated ion channel. To build an ion channel model, we need to determine how the conductance of the channel changes with the membrane potential.
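Fortunately, a lot of work has addressed this issue by formulating analytical expressions. For example, the conductance of one typical potassium channel can be written as:
\[ g_\mathrm{K} = \bar{g}_\mathrm{K}\, n^4, \qquad \frac{\mathrm{d}n}{\mathrm{d}t} = \phi\left[\alpha_n(V)(1 - n) - \beta_n(V)\, n\right], \]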
in which \(\bar{g}_\mathrm{K}\) refers to the maximal conductance and \(n\), also called the gating variable, refers to the probability (proportion) of potassium channels being open. \(\phi\) is a parameter reflecting the effect of temperature. The differential equation of \(n\) contains two parameters, \(\alpha_n(V)\) and \(\beta_n(V)\), that change with the membrane potential:
\[ \alpha_n(V) = \frac{0.01\,(V + 55)}{1 - \exp\left(-\frac{V + 55}{10}\right)}, \qquad \beta_n(V) = 0.125 \exp\left(-\frac{V + 65}{80}\right). \]
Now we have learned the mathematical expression of the potassium channel. Next, we try to build this channel in BrainPy.
class IK(bp.Channel):
    def __init__(self, size, E=-77., g_max=36., phi=1., method='exp_auto'):
        super(IK, self).__init__(size)
        self.g_max = g_max
        self.E = E
        self.phi = phi
        self.n = bm.Variable(bm.zeros(size))  # variables should be packed with bm.Variable
        self.integral = bp.odeint(self.dn, method=method)
    def dn(self, n, t, V):
        alpha_n = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
        beta_n = 0.125 * bm.exp(-(V + 65) / 80)
        return self.phi * (alpha_n * (1. - n) - beta_n * n)
    def update(self, tdi, V):
        self.n.value = self.integral(self.n, tdi.t, V, dt=tdi.dt)
    def current(self, V):
        return self.g_max * self.n ** 4 * (self.E - V)
Note that besides the initialization and update functions, you must also implement a function named current(), which computes the current flowing through this channel. This potassium channel model can then be used as a building block for assembling a conductance-based neuron model.
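Before wiring the channel into a neuron, it can help to check the rate functions numerically. The stdlib sketch below reuses the alpha_n and beta_n expressions from IK.dn(). It computes the steady-state opening probability n_inf = alpha_n / (alpha_n + beta_n), the time constant tau_n = 1 / (phi * (alpha_n + beta_n)), and the value of the IK.current() expression at n = n_inf for a few clamped voltages. It is a plain-Python check, not part of the BrainPy model. alpha_n has a removable 0/0 singularity at V = -55 mV, so the sketch avoids that exact voltage.

```python
import math


def alpha_n(V):
    # same expression as in IK.dn(); undefined (0/0) exactly at V = -55 mV
    return 0.01 * (V + 55) / (1 - math.exp(-(V + 55) / 10))


def beta_n(V):
    return 0.125 * math.exp(-(V + 65) / 80)


def n_steady_state(V, phi=1.0):
    """Return (n_inf, tau_n in ms) for a clamped voltage V (mV)."""
    a, b = alpha_n(V), beta_n(V)
    return a / (a + b), 1.0 / (phi * (a + b))


def potassium_current(V, g_max=36.0, E=-77.0):
    """Same expression as IK.current(), evaluated at n = n_inf(V)."""
    n_inf, _ = n_steady_state(V)
    return g_max * n_inf ** 4 * (E - V)


for V in (-80.0, -65.0, -40.0, 0.0):
    n_inf, tau_n = n_steady_state(V)
    print(f"V={V:6.1f} mV  n_inf={n_inf:.3f}  tau_n={tau_n:.2f} ms  "
          f"current={potassium_current(V):8.3f}")
```

With the sign convention of IK.current(), the value is positive below E = -77 mV and negative above it.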
Building a conductance-based neuron model with ion channels#
Instead of building a conductance-based model from scratch, we can use ion channel models as building blocks to assemble a neuron model in a modular and convenient way. Now let’s construct a Hodgkin-Huxley (HH) model (see the Hodgkin–Huxley Model section below for the complete mathematical expression of the HH model).
The HH neuron models the currents flowing through potassium, sodium, and leaky channels. Besides the potassium channel that we implemented, we can import the other channel models from brainpy.dyn.channels:
from brainpy.dyn.channels import INa_HH1952, IL
# actually the potassium channel we implemented can also be found in this package as 'IK_HH1952'
Then we wrap these three channels into a single neuron model:
class HH(bp.CondNeuGroup):
    def __init__(self, size):
        super(HH, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
        self.IK = IK(size, E=-77., g_max=36.)
        self.INa = INa_HH1952(size, E=50., g_max=120.)
        self.IL = IL(size, E=-54.39, g_max=0.03)
Here the HH class inherits from the superclass bp.CondNeuGroup. During a simulation, bp.CondNeuGroup automatically integrates the current flows by calling the current() function of each channel model to compute the neuronal activity.
Surprisingly, the model construction is finished! Users do not need to implement the update function of the neuron model, because CondNeuGroup implicitly updates its variables (like the membrane potential V and the spiking sequence spike) in its own way.
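Conceptually, a conductance-based group advances the membrane potential by summing the current() contributions of its channels together with the external input. The toy below illustrates this idea with a forward-Euler step and a single leak channel. The class names Leak and ToyCondNeuron, the capacitance C = 1 and the step size are assumptions for illustration. BrainPy's CondNeuGroup uses its own integrators and may apply additional scaling.

```python
class Leak:
    """Toy leak channel exposing current(V) with the same sign convention as IK.current()."""

    def __init__(self, E=-54.39, g_max=0.03):
        self.E, self.g_max = E, g_max

    def current(self, V):
        return self.g_max * (self.E - V)


class ToyCondNeuron:
    """Toy conductance-based neuron: C dV/dt = sum of channel currents + external input."""

    def __init__(self, channels, V0=-70.0, C=1.0):
        self.channels = channels
        self.V = V0
        self.C = C

    def update(self, dt, I_ext=0.0):
        I_total = sum(ch.current(self.V) for ch in self.channels) + I_ext
        self.V += dt * I_total / self.C   # one forward-Euler step


neuron = ToyCondNeuron([Leak()])
for _ in range(20_000):               # 20,000 steps of 0.1 ms = 2 s
    neuron.update(dt=0.1)
print(round(neuron.V, 2))             # -54.39, the leak reversal potential
```

Adding more channel objects to the list (for example, a potassium channel built like IK) changes only the sum. This is why composing channels is enough to define the neuron.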
Now let’s run a simulation of this HH model to examine the changes of the inner variables.
First of all, we instantiate a neuron group with 1 HH neuron:
neu = HH(1)
Then we wrap the neuron group into a dynamical-system runner, DSRunner, to run a simulation:
runner = bp.DSRunner(
neu,
monitors=['V', 'IK.n', 'INa.p', 'INa.q'],
inputs=('input', 6.) # constant external inputs of 6 mA to all neurons
)
Then we run the simulation and visualize the result:
runner.run(200) # the running time is 200 ms
import matplotlib.pyplot as plt
plt.plot(runner.mon['ts'], runner.mon['V'])
plt.xlabel('t (ms)')
plt.ylabel('V (mV)')
plt.show()

We can also visualize the changes of the gating variables of sodium and potassium channels:
plt.figure(figsize=(6, 2))
plt.plot(runner.mon['ts'], runner.mon['IK.n'], label='n')
plt.plot(runner.mon['ts'], runner.mon['INa.p'], label='m')
plt.plot(runner.mon['ts'], runner.mon['INa.q'], label='h')
plt.xlabel('t (ms)')
plt.legend()
plt.show()

By combining different ion channels, we can easily and straightforwardly obtain different types of conductance-based neuron models. To see all predefined channel models in BrainPy, please click here.
Building Synapse Models#
In BrainPy, synapse models can be created in several ways. In this section, we describe a structured building process with brainpy.dyn.TwoEndConn, which is used to create models with pre- and post-synaptic neuron groups.
In brainpy.dyn.TwoEndConn, a synapse model is decomposed into several components. In this way, building a synapse model follows a modular and composable programming interface.
For more details on defining a general synapse model, please refer to brainpy.dyn.SynConn in Tutorial: Customizing Synapse Models.
import brainpy as bp
# bp.math.set_platform('cpu')
bp.__version__
'2.3.1'
Synaptic models with brainpy.dyn.TwoEndConn#
In general, brainpy.dyn.TwoEndConn is used to model synapses of the following form:
where each synapse model has its dynamical conductance \(g\), synaptic weight \(g_{\mathrm{max}}\), and
\(I_{\mathrm{post}}\) is the synaptic current onto the post-synaptic neurons,
\(f_{\mathrm{dyn}}\) is the function to compute synaptic dynamics,
\(f_{\mathrm{LTP}}\) is the function for computing synaptic long-term plasticity,
\(f_{\mathrm{STP}}\) is the function for computing synaptic short-term plasticity,
\(f_{\mathrm{out}}\) is the way to output synaptic currents on post-synaptic neurons.
Example 1: Exponential synapse model#
For an exponential synapse model,
where \(f_{\mathrm{dyn}}\) is defined as equation (1), \(f_{\mathrm{LTP}}\) and \(f_{\mathrm{STP}}\) are the identity function \(f(x) = x\), and \(f_{\mathrm{out}}\) is defined in a conductance-based form with \((V_{\mathrm{post}}(t)-E)\).
Therefore, in BrainPy, we can define this model as the following form:
# a pre-synaptic neuron which generates spikes at 1 ms, 11 ms, 21 ms.
pre = bp.neurons.SpikeTimeGroup(1, [1., 11., 21.], [0, 0, 0])
# a post-synaptic integrator which integrates synaptic inputs
post = bp.neurons.LeakyIntegrator(1)
# the synaptic model we want, whose output function is defined with `bp.synouts.COBA`
bp.synapses.Exponential(pre, post, bp.conn.All2All(),
output=bp.synouts.COBA(E=0.))
Exponential(name=Exponential0, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
Similarly, an Exponential synapse model with the current-based output can be defined as:
bp.synapses.Exponential(pre, post, bp.conn.All2All(),
output=bp.synouts.CUBA())
Exponential(name=Exponential1, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
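The difference between the two outputs can be sketched without BrainPy. The sketch assumes the usual exponential-synapse dynamics: g decays as dg/dt = -g/tau and jumps by 1 at each presynaptic spike. The conductance-based term is written here with the sign that makes it an input current, g_max * g * (E - V_post). The current-based output uses g_max * g directly. The forward-Euler loop, the parameter values and the clamped V_post are illustrative assumptions. BrainPy's Exponential model has its own integrator and default parameters.

```python
dt, tau, g_max = 0.1, 8.0, 1.0     # ms, ms, arbitrary units (illustrative values)
E, V_post = 0.0, -65.0             # reversal potential and a clamped postsynaptic V (mV)
spike_times = [1.0, 11.0, 21.0]    # same spike train as the SpikeTimeGroup above
spike_steps = {round(t / dt) for t in spike_times}

g = 0.0
trace = []                          # rows of (t, g, COBA drive, CUBA drive)
for step in range(round(40.0 / dt)):
    g -= g / tau * dt               # exponential decay, forward Euler
    if step in spike_steps:
        g += 1.0                    # jump at each presynaptic spike
    i_coba = g_max * g * (E - V_post)   # conductance-based: depends on V_post
    i_cuba = g_max * g                  # current-based: independent of V_post
    trace.append((step * dt, g, i_coba, i_cuba))

peak_g = max(row[1] for row in trace)
print(f"peak g = {peak_g:.3f}")
```

Because V_post is clamped here, the COBA drive is just the CUBA drive scaled by (E - V_post). In a real network, the COBA term shrinks as the postsynaptic potential approaches E.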
Example 2: NMDA synapse model#
The NMDA synapse model differs from other models because its current onto the post-synaptic group is regulated by magnesium. Specifically, the net NMDA receptor-mediated synaptic current is given by
where \(g_{\infty}\) represents the fraction of channels that are not blocked by magnesium.
Here \([{Mg}^{2+}]_{o}\) is the extracellular magnesium concentration, usually 1 mM.
In BrainPy, this kind of magnesium-mediated synaptic output is provided by brainpy.synouts.MgBlock. Therefore, an NMDA synapse can be defined with:
bp.synapses.NMDA(pre, post, bp.conn.All2All(),
output=bp.synouts.MgBlock(E=0., cc_Mg=1.2))
NMDA(name=NMDA0, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
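The magnesium-block factor \(g_{\infty}\) is commonly written in the Jahr–Stevens form g_inf(V) = 1 / (1 + exp(-alpha * V) * [Mg2+]_o / beta), with alpha = 0.062 /mV and beta = 3.57 mM. The sketch below evaluates that form for the cc_Mg = 1.2 mM used above. These constants reflect an assumption about the standard formulation, so check the MgBlock documentation for the exact defaults that BrainPy uses.

```python
import math


def mg_block(V, cc_Mg=1.2, alpha=0.062, beta=3.57):
    """Fraction of NMDA channels not blocked by Mg2+ (Jahr–Stevens form), V in mV."""
    return 1.0 / (1.0 + math.exp(-alpha * V) * cc_Mg / beta)


for V in (-80.0, -65.0, -40.0, 0.0, 20.0):
    print(f"V={V:6.1f} mV  g_inf={mg_block(V):.3f}")
```

The block is strong near rest and is relieved by depolarization. This voltage dependence is what distinguishes the MgBlock output from a plain COBA output.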
Example 3: Synapse models with short-term plasticity#
Short-term synaptic plasticity is ubiquitous in synapse dynamics. BrainPy provides brainpy.synplast.STD for short-term depression and brainpy.synplast.STP for general short-term plasticity. Short-term synaptic plasticity can be added to most synapse models in BrainPy. For instance, here we define the AMPA, GABA, and NMDA synapse models used in (Guoshi Li, et al., 2017) [1].
[1] Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. “Unified thalamic model generates multiple distinct oscillations with state-dependent entrainment by stimulation.” PLoS computational biology 13.10 (2017): e1005797.
# AMPA synapse model with STD
bp.synapses.AMPA(pre, post, bp.conn.FixedProb(0.3),
stp=bp.synplast.STD(tau=700, U=0.07),
output=bp.synouts.COBA(E=0.),
alpha=0.94, beta=0.18, g_max=6e-3)
AMPA(name=AMPA0, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
# GABA synapse model with STD
bp.synapses.GABAa(pre, post, bp.conn.FixedProb(0.3),
stp=bp.synplast.STD(tau=700, U=0.07),
output=bp.synouts.COBA(E=-80),
alpha=10.5, beta=0.166, g_max=3e-3)
GABAa(name=GABAa0, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
# NMDA synapse model with STD
bp.synapses.NMDA(pre, post, bp.conn.FixedProb(0.3),
stp=bp.synplast.STD(tau=700, U=0.07),
output=bp.synouts.MgBlock(E=0., cc_Mg=1.2))
NMDA(name=NMDA1, mode=NonBatchingMode,
pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)),
post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
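A common formulation of short-term depression, in the style of the Tsodyks–Markram model, tracks the fraction of available resources x. Between spikes, x recovers as dx/dt = (1 - x)/tau. Each presynaptic spike consumes a fraction U of the available resources (x becomes x - U*x), and the synaptic efficacy is scaled by the resources available at the time of the spike. The sketch below uses tau = 700 ms and U = 0.07 from the examples above. It illustrates the standard model, and the exact update order inside brainpy.synplast.STD may differ.

```python
tau, U, dt = 700.0, 0.07, 0.1      # ms, fraction consumed per spike, ms


def simulate_std(spike_times, t_end):
    """Return the resource level x sampled just before each spike."""
    spike_steps = {round(t / dt) for t in spike_times}
    x = 1.0
    before_spike = []
    for step in range(round(t_end / dt)):
        x += (1.0 - x) / tau * dt      # recovery toward 1 (forward Euler)
        if step in spike_steps:
            before_spike.append(x)
            x -= U * x                 # each spike consumes a fraction U
    return before_spike


# A 20 Hz train of 10 spikes: x declines from spike to spike.
levels = simulate_std([50.0 * k for k in range(1, 11)], t_end=600.0)
print([round(v, 3) for v in levels])
```

With a 700 ms recovery time constant, 50 ms between spikes is far too short for full recovery. The available resources therefore settle toward a depressed steady state.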
Example 4: synapse models with long-term plasticity#
TODO.
Building Network Models#
Previous sections illustrated how to define neuron models with brainpy.dyn.NeuGroup and synapse models with brainpy.dyn.TwoEndConn. This section introduces brainpy.dyn.Network, which is the base class used to build network models.
In essence, brainpy.dyn.Network is a container, whose function is to compose the individual elements. It is a subclass of a more general class: brainpy.dyn.Container.
Below, we take an excitation-inhibition (E-I) balanced network model as an example to illustrate how to compose the LIF neurons and Exponential synapses defined in previous tutorials into a network.
import brainpy as bp
# bp.math.set_platform('cpu')
bp.__version__
'2.3.1'
Excitation-Inhibition (E-I) Balanced Network#
The E-I balanced network was first proposed to explain the irregular firing patterns of cortical neurons and has been confirmed by experimental data. The network [1] we are going to implement consists of excitatory (E) and inhibitory (I) neurons in a ratio of about 4 : 1. In the implementation below, the two populations differ in their synapses. The reversal potential of inhibitory synapses (-80 mV) is much lower than that of excitatory synapses (0 mV). The time constant of inhibitory synapses (10 ms) is also longer than that of excitatory synapses (5 ms), so inhibitory synaptic currents decay more slowly.
[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007). Simulation of networks of spiking neurons: a review of tools and strategies. J. Comput. Neurosci., 23(3), 349–398.
# BrainPy has some built-in canonical neuron and synapse models
LIF = bp.neurons.LIF
Exponential = bp.synapses.Exponential
Two ways to define network models#
There are several ways to define a Network model.
1. Defining a network as a class#
The first way to define a network model is as follows.
class EINet(bp.Network):
    def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
        super(EINet, self).__init__(**kwargs)
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        E = LIF(num_exc, **pars, method=method)
        I = LIF(num_inh, **pars, method=method)
        E.V.value = bp.math.random.randn(num_exc) * 2 - 55.
        I.V.value = bp.math.random.randn(num_inh) * 2 - 55.
        # synapses
        w_e = 0.6  # excitatory synaptic weight
        w_i = 6.7  # inhibitory synaptic weight
        E_pars = dict(output=bp.synouts.COBA(E=0.), g_max=w_e, tau=5.)
        I_pars = dict(output=bp.synouts.COBA(E=-80.), g_max=w_i, tau=10.)
        # Neurons connect to each other randomly with a connection probability of 2%
        self.E2E = Exponential(E, E, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
        self.E2I = Exponential(E, I, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
        self.I2E = Exponential(I, E, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
        self.I2I = Exponential(I, I, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
        self.E = E
        self.I = I
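With bp.conn.FixedProb(prob=0.02), the expected number of synapses in a projection is prob * N_pre * N_post, assuming each pre/post pair is kept independently. The stdlib sketch below draws such connection lists for a scaled-down 4:1 network (400 E and 100 I neurons, chosen to keep the loop fast) and compares the counts with that expectation. The helper fixed_prob_pairs is hypothetical. Whether BrainPy samples per pair in exactly this way (for example, how it treats self-connections) is an assumption here, so consult the FixedProb documentation for the details.

```python
import random


def fixed_prob_pairs(num_pre, num_post, prob, seed=0):
    """Hypothetical sketch: keep each (pre, post) pair independently with probability prob."""
    rng = random.Random(seed)
    return [(i, j)
            for i in range(num_pre)
            for j in range(num_post)
            if rng.random() < prob]


num_exc, num_inh, prob = 400, 100, 0.02   # a scaled-down 4:1 E/I network
e2e = fixed_prob_pairs(num_exc, num_exc, prob)
e2i = fixed_prob_pairs(num_exc, num_inh, prob, seed=1)
print(len(e2e), 'connections; expectation', prob * num_exc * num_exc)
print(len(e2i), 'connections; expectation', prob * num_exc * num_inh)
```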
In an instance of brainpy.dyn.Network, all elements accessed through self. can be gathered automatically by the .nodes() function.
EINet(8, 2).nodes(level=-1).subset(bp.DynamicalSystem)
{'EINet0': EINet(),
'Exponential0': Exponential(name=Exponential0, mode=NonBatchingMode,
pre=LIF(name=LIF0, mode=NonBatchingMode, size=(8,)),
post=LIF(name=LIF0, mode=NonBatchingMode, size=(8,))),
'Exponential1': Exponential(name=Exponential1, mode=NonBatchingMode,
pre=LIF(name=LIF0, mode=NonBatchingMode, size=(8,)),
post=LIF(name=LIF1, mode=NonBatchingMode, size=(2,))),
'Exponential2': Exponential(name=Exponential2, mode=NonBatchingMode,
pre=LIF(name=LIF1, mode=NonBatchingMode, size=(2,)),
post=LIF(name=LIF0, mode=NonBatchingMode, size=(8,))),
'Exponential3': Exponential(name=Exponential3, mode=NonBatchingMode,
pre=LIF(name=LIF1, mode=NonBatchingMode, size=(2,)),
post=LIF(name=LIF1, mode=NonBatchingMode, size=(2,))),
'LIF0': LIF(name=LIF0, mode=NonBatchingMode, size=(8,)),
'LIF1': LIF(name=LIF1, mode=NonBatchingMode, size=(2,)),
'COBA2': COBA,
'COBA4': COBA,
'COBA3': COBA,
'COBA5': COBA}
Note that in the above EINet, we do not define the update() function. This is because any subclass of brainpy.dyn.Network has a default update function, which automatically gathers the elements defined in the network and runs the update function of each element in sequence.
Let’s try to simulate our defined EINet model.
net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method
runner = bp.DSRunner(net,
monitors=['E.spike', 'I.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
title='Spikes of Inhibitory Neurons', show=True)
Used time None s


2. Instantiating a network directly#
Another way to instantiate a network model is to pass the elements directly into the constructor of brainpy.Network. It accepts both *args and **kwargs arguments.
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = LIF(3200, **pars)
I = LIF(800, **pars)
E.V.value = bp.math.random.randn(E.num) * 2 - 55.
I.V.value = bp.math.random.randn(I.num) * 2 - 55.
# synapses
E_pars = dict(output=bp.synouts.COBA(E=0.), g_max=0.6, tau=5.)
I_pars = dict(output=bp.synouts.COBA(E=-80.), g_max=6.7, tau=10.)
E2E = Exponential(E, E, bp.conn.FixedProb(prob=0.02), **E_pars)
E2I = Exponential(E, I, bp.conn.FixedProb(prob=0.02), **E_pars)
I2E = Exponential(I, E, bp.conn.FixedProb(prob=0.02), **I_pars)
I2I = Exponential(I, I, bp.conn.FixedProb(prob=0.02), **I_pars)
# Network
net2 = bp.Network(E2E, E2I, I2E, I2I, exc_group=E, inh_group=I)
All elements passed as **kwargs arguments can be accessed by the provided keys. These keys are also used to refer to the elements in the subsequent dynamics simulation, for example in the monitors and inputs of the runner below.
net2.exc_group
LIF(name=LIF4, mode=NonBatchingMode, size=(3200,))
net2.inh_group
LIF(name=LIF5, mode=NonBatchingMode, size=(800,))
After construction, the simulation goes the same way:
runner = bp.DSRunner(net2,
monitors=['exc_group.spike', 'inh_group.spike'],
inputs=[('exc_group.input', 20.), ('inh_group.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['exc_group.spike'],
title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['inh_group.spike'],
title='Spikes of Inhibitory Neurons', show=True)
Used time None s


Customizing update function#
If you want to control the updating logic in a network, you can override the update function update(tdi) and customize it yourself.
For the above E/I balanced network model, we can define its update function as:
class EINetV2(bp.Network):
    def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
        super(EINetV2, self).__init__(**kwargs)
        # neurons
        self.N = LIF(num_exc + num_inh,
                     V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                     method=method, V_initializer=bp.init.Normal(-55., 2.))
        # synapses
        self.Esyn = bp.synapses.Exponential(self.N[:num_exc], self.N,
                                            bp.conn.FixedProb(prob=0.02),
                                            output=bp.synouts.COBA(E=0.),
                                            g_max=0.6, tau=5.,
                                            method=method)
        self.Isyn = bp.synapses.Exponential(self.N[num_exc:], self.N,
                                            bp.conn.FixedProb(prob=0.02),
                                            output=bp.synouts.COBA(E=-80.),
                                            g_max=6.7, tau=10.,
                                            method=method)
    def update(self, tdi):
        self.Esyn(tdi)
        self.Isyn(tdi)
        self.N(tdi)
        self.update_local_delays()  # IMPORTANT
In the above, we define one population and create E (excitatory) and I (inhibitory) projections within it. We first update the synapse models by calling self.Esyn(tdi) and self.Isyn(tdi). This ensures that all synaptic inputs are gathered onto the neuron models before the neurons are updated. After updating the synapses, we update the state of the neurons by calling self.N(tdi). Finally, note that we need to update all delays used in this network through self.update_local_delays(). The delay variables depend on the neurons, so once the neuronal states have been updated, the delays must be refreshed with these new values.
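A toy ring buffer shows why the delays must be refreshed after the neurons. Synapses read a value emitted some steps ago, so the buffer must receive the newest neuronal state only after that state has been computed. The ToyDelay class and the spike train are hypothetical illustrations of the idea behind update_local_delays(), not BrainPy's delay implementation.

```python
from collections import deque


class ToyDelay:
    """Ring buffer of the last `max_delay` values of a neuronal variable."""

    def __init__(self, max_delay, init=0):
        self.buffer = deque([init] * max_delay, maxlen=max_delay)

    def read(self, delay_step):
        # Read before this step's push: buffer[-1] was pushed 1 step ago,
        # so buffer[-delay_step] was pushed delay_step steps ago.
        return self.buffer[-delay_step]

    def push(self, value):
        self.buffer.append(value)  # the oldest entry drops out automatically


delay = ToyDelay(max_delay=3)
spikes = [1, 0, 0, 1, 0, 1, 0, 0]   # hypothetical spike train of one neuron
received = []
for spike in spikes:
    received.append(delay.read(2))  # 1. synapses read the value from 2 steps earlier
    delay.push(spike)               # 2. neurons update, then the delay is refreshed
print(received)  # [0, 0, 1, 0, 0, 1, 0, 1]: the spike train shifted by 2 steps
```

If push ran before read within a step, every value would arrive one step early and the effective delay would shrink by one.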
net = EINetV2(3200, 800)
runner = bp.DSRunner(net, monitors={'spikes': net.N.spike}, inputs=[(net.N.input, 20.)])
runner.run(100.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)

Customizing Your Neuron Models#
The previous sections showed the available models that users can use by simply instantiating them. In the following sections, we go into detail on how to build a neuron model with brainpy.dyn.NeuGroup. Neurons are the most basic components in neural dynamics simulation, and in BrainPy, brainpy.dyn.NeuGroup is used for neuron modeling.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.3.1'
brainpy.dyn.NeuGroup#
Generally, any neuron model can evolve continuously or discontinuously. Discontinuous evolution may be triggered by events, such as the reset of membrane potential. Moreover, it is common in a neural system that a dynamical system has different states, such as the excitable or refractory state in a leaky integrate-and-fire (LIF) model. In this section, we will use two examples to illustrate how to capture these complexities in neuron modeling.
Defining a neuron model in BrainPy is simple. You just need to inherit from brainpy.dyn.NeuGroup and satisfy the following two requirements:
Providing the
size
of the neural group in the constructor when initialize a new neural group class.size
can be a integer referring to the number of neurons or a tuple/list of integers referring to the geometry of the neural group in different dimensions. According to the provided groupsize
, NeuroGroup will automatically calculate the total numbernum
of neurons in this group.Creating an
update(tdi)
function. Update function provides the rule how the neuron states are evolved from the current time \(\mathrm{tdi.t}\) to the next time \(\mathrm{tdi.t + tdi.dt}\).
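To make the two requirements concrete without relying on BrainPy internals, here is a plain-Python sketch. MiniNeuGroup is a hypothetical stand-in, not the real brainpy.dyn.NeuGroup; it only mimics the size-to-num bookkeeping described above, assuming that a tuple size is flattened by multiplying its entries.

```python
import math


class MiniNeuGroup:
    """Hypothetical stand-in for the bookkeeping of brainpy.dyn.NeuGroup (not its real code)."""

    def __init__(self, size):
        # requirement 1: "size" is an int or a tuple/list of ints
        if isinstance(size, int):
            size = (size,)
        elif isinstance(size, (tuple, list)):
            size = tuple(size)
        else:
            raise TypeError(f'size must be an int or a tuple/list of ints, got {size!r}')
        self.size = size
        self.num = math.prod(size)  # total number of neurons in the group

    def update(self, tdi):
        # requirement 2: advance the states from tdi.t to tdi.t + tdi.dt
        raise NotImplementedError('subclasses define the update rule')


print(MiniNeuGroup(10).num)      # 10
print(MiniNeuGroup((4, 5)).num)  # 20
```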
In the following part, a Hodgkin-Huxley (HH) model is used as an example for illustration.
Hodgkin–Huxley Model#
The Hodgkin-Huxley (HH) model is a continuous-time dynamical system. It is one of the most successful mathematical models of a complex biological process that has ever been formulated. Changes in the membrane potential influence the conductances of different ion channels, which lets the model capture neural activity in biological systems in detail. Mathematically, the model is given by:
\[
\begin{aligned}
C_m \frac{dV}{dt} &= -\bar{g}_{Na} m^3 h (V - E_{Na}) - \bar{g}_K n^4 (V - E_K) - \bar{g}_l (V - E_l) + I_{ext}, \\
\frac{dx}{dt} &= \alpha_x(V)(1 - x) - \beta_x(V)\, x, \qquad x \in \{m, h, n\},
\end{aligned}
\]
where \(V\) is the membrane potential, \(C_m\) is the membrane capacitance per unit area, \(E_K\) and \(E_{Na}\) are the potassium and sodium reversal potentials, respectively, \(E_l\) is the leak reversal potential, \(\bar{g}_K\) and \(\bar{g}_{Na}\) are the maximal potassium and sodium conductances per unit area, respectively, \(\bar{g}_l\) is the leak conductance per unit area, and \(I_{ext}\) is the external input current. Because the potassium and sodium channels are voltage-sensitive, the gating variables \(m\), \(n\) and \(h\), fitted to biological experiments, are used to describe the activation of the channels. Specifically, \(n\) measures the activation of potassium channels, while \(m\) and \(h\) measure the activation and inactivation of sodium channels, respectively. \(\alpha_{x}\) and \(\beta_{x}\) are the rate constants for the gating variable \(x\) and depend exclusively on the membrane potential.
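Each gating equation \(dx/dt = \alpha_x(1-x) - \beta_x x\) has the fixed point \(x_\infty = \alpha_x/(\alpha_x+\beta_x)\) and the time constant \(\tau_x = 1/(\alpha_x+\beta_x)\). The NumPy sketch below evaluates both at -65 mV, reusing the rate expressions from the dm, dh and dn methods of the HH class defined below (so it assumes V in mV and rates in 1/ms). Note that the \(\alpha_m\) and \(\alpha_n\) expressions are 0/0 at exactly -40 mV and -55 mV, respectively, so evaluating them exactly there yields nan.

```python
import numpy as np

# Rate functions copied from the dm, dh and dn methods of the HH class (V in mV, rates in 1/ms)
rates = {
    'm': (lambda V: 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10)),
          lambda V: 4.0 * np.exp(-(V + 65) / 18)),
    'h': (lambda V: 0.07 * np.exp(-(V + 65) / 20.),
          lambda V: 1 / (1 + np.exp(-(V + 35) / 10))),
    'n': (lambda V: 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10)),
          lambda V: 0.125 * np.exp(-(V + 65) / 80)),
}


def steady_state(alpha, beta, V):
    """Fixed point x_inf and time constant tau_x of dx/dt = alpha*(1-x) - beta*x at voltage V."""
    a, b = alpha(V), beta(V)
    return a / (a + b), 1.0 / (a + b)


V_rest = -65.0
for name, (alpha, beta) in rates.items():
    x_inf, tau_x = steady_state(alpha, beta, V_rest)
    print(f'{name}_inf = {x_inf:.3f}, tau_{name} = {tau_x:.2f} ms')
```

At -65 mV this gives roughly \(m_\infty \approx 0.053\), \(h_\infty \approx 0.596\) and \(n_\infty \approx 0.318\), which can be compared with the initial values 0.5, 0.6 and 0.32 used for m, h and n in the HH class.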
To implement the HH model, variables should be specified. According to the above equations, the following five state variables change with respect to time:
V: the membrane potential
m: the activation of sodium channels
h: the inactivation of sodium channels
n: the activation of potassium channels
input: the external/synaptic input
Besides, the spiking state and the last spiking time can also be recorded for statistical analysis:
spike: whether a spike is produced
t_last_spike: the last spiking time
Based on these state variables, the HH model can be implemented as below.
class HH(bp.NeuGroup):
    def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
                 V_th=20., C=1.0, **kwargs):
        # providing the group "size" information
        super(HH, self).__init__(size=size, **kwargs)
        # initialize parameters
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.C = C
        self.V_th = V_th
        # initialize variables
        self.V = bm.Variable(bm.random.randn(self.num) - 70.)
        self.m = bm.Variable(0.5 * bm.ones(self.num))
        self.h = bm.Variable(0.6 * bm.ones(self.num))
        self.n = bm.Variable(0.32 * bm.ones(self.num))
        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
        # integral functions
        self.int_V = bp.odeint(f=self.dV, method='exp_auto')
        self.int_m = bp.odeint(f=self.dm, method='exp_auto')
        self.int_h = bp.odeint(f=self.dh, method='exp_auto')
        self.int_n = bp.odeint(f=self.dn, method='exp_auto')

    def dV(self, V, t, m, h, n, Iext):
        I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
        I_K = (self.gK * n ** 4.0) * (V - self.EK)
        I_leak = self.gL * (V - self.EL)
        dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
        return dVdt

    def dm(self, m, t, V):
        alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
        beta = 4.0 * bm.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        return dmdt

    def dh(self, h, t, V):
        alpha = 0.07 * bm.exp(-(V + 65) / 20.)
        beta = 1 / (1 + bm.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        return dhdt

    def dn(self, n, t, V):
        alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
        beta = 0.125 * bm.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        return dndt

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        # compute V, m, h, n
        V = self.int_V(self.V, _t, self.m, self.h, self.n, self.input, dt=_dt)
        self.h.value = self.int_h(self.h, _t, self.V, dt=_dt)
        self.m.value = self.int_m(self.m, _t, self.V, dt=_dt)
        self.n.value = self.int_n(self.n, _t, self.V, dt=_dt)
        # update the spiking state and the last spiking time
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
        # update V
        self.V.value = V
        # reset the external input
        self.input[:] = 0.
In the HH model, each differential equation is turned into an ODEIntegrator by brainpy.odeint. The details are covered in the Numerical Solvers for ODEs tutorial.
The variables that are updated during the dynamics simulation should be wrapped as brainpy.math.Variable, so that they can be processed by the JIT compiler to accelerate the simulation.
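The integrators above use method='exp_auto'. We assume this selects an exponential Euler scheme; see the Numerical Solvers for ODEs tutorial for what BrainPy actually does. The NumPy sketch below (not BrainPy code) illustrates why exponential updates suit gating equations: with the rates frozen over one step, \(x_{t+\Delta t} = x_\infty + (x_t - x_\infty)e^{-\Delta t/\tau_x}\) is exact, whereas forward Euler becomes unstable once \(\Delta t > 2\tau_x\). The fast m-gate at -65 mV has \(\tau_m \approx 0.24\) ms, so a deliberately large step of 0.5 ms makes forward Euler diverge while the exponential update settles at \(m_\infty\).

```python
import numpy as np

V = -65.0  # membrane potential held fixed (mV)
# m-gate rates, same expressions as the dm method of the HH class
alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
beta = 4.0 * np.exp(-(V + 65) / 18)
m_inf = alpha / (alpha + beta)
tau_m = 1.0 / (alpha + beta)


def exp_euler_step(x, dt):
    # exact solution of dx/dt = (m_inf - x) / tau_m over one step
    return m_inf + (x - m_inf) * np.exp(-dt / tau_m)


def forward_euler_step(x, dt):
    return x + dt * (alpha * (1 - x) - beta * x)


dt = 0.5               # ms, larger than 2 * tau_m (about 0.47 ms)
m_exp = m_fwd = 0.5    # same initial value as in the HH class
for _ in range(20):
    m_exp = exp_euler_step(m_exp, dt)
    m_fwd = forward_euler_step(m_fwd, dt)

print(f'tau_m = {tau_m:.3f} ms, m_inf = {m_inf:.4f}')
print(f'exponential Euler: {m_exp:.4f}, forward Euler: {m_fwd:.4f}')
```

Forward Euler ends far outside the valid range [0, 1] for a gating variable, while the exponential update has converged to \(m_\infty\).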
In the following part, a leaky integrate-and-fire (LIF) model is introduced as another example for illustration.
Leaky Integrate-and-Fire Model#
The LIF model is the classical neuron model which contains a continuous process and a discontinuous spike reset operation. Formally, it is given by:
\[
\tau_m \frac{dV}{dt} = -(V - V_{rest}) + R\, I(t) \qquad (1)
\]
\[
\text{if } V(t) \ge V_{th}: \quad V(t) \leftarrow V_{reset}, \quad t_{last} \leftarrow t, \quad \text{and } V \text{ stays at } V_{reset} \text{ while } t - t_{last} \le \tau_{ref} \qquad (2)
\]
where \(V\) is the membrane potential, \(V_{rest}\) is the resting membrane potential, \(V_{reset}\) is the reset potential, \(V_{th}\) is the spike threshold, \(\tau_m\) is the membrane time constant, \(\tau_{ref}\) is the refractory period, \(R\) is the membrane resistance, \(t_{last}\) is the last spike time, and \(I\) is the time-varying synaptic input.
The above two equations model the continuous change and the spiking of neurons, respectively. Moreover, the model has multiple states: the subthreshold state, and the spiking or refractory state. The membrane potential \(V\) is integrated according to equation (1) while it is below \(V_{th}\). Once \(V\) reaches the threshold \(V_{th}\), according to equation (2), \(V\) is reset to \(V_{reset}\), and the neuron enters the refractory period, during which the membrane potential \(V\) remains constant for the following \(\tau_{ref}\) ms.
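For a constant input \(I\), equation (1) can be solved in closed form, which gives a quick sanity check for simulations. After a spike, \(V\) is held at \(V_{reset}\) for \(\tau_{ref}\) and then relaxes towards \(V_\infty = V_{rest} + R I\), so it reaches \(V_{th}\) after \(\tau_m \ln\frac{V_\infty - V_{reset}}{V_\infty - V_{th}}\), provided \(V_\infty > V_{th}\). The sketch below assumes the default parameters of the LIF class defined later; a time-stepped simulation will differ slightly because of the step size dt.

```python
import math


def lif_isi(I, V_rest=0., V_reset=-5., V_th=20., R=1., tau=10., t_ref=5.):
    """Interspike interval (ms) of the LIF model under a constant input I.

    Assumes V is held at V_reset for t_ref ms after a spike and then follows
    tau * dV/dt = -(V - V_rest) + R * I.
    """
    V_inf = V_rest + R * I        # potential the neuron relaxes towards
    if V_inf <= V_th:
        return math.inf           # threshold never reached: no periodic firing
    # solve V_inf + (V_reset - V_inf) * exp(-t / tau) = V_th for t
    t_charge = tau * math.log((V_inf - V_reset) / (V_inf - V_th))
    return t_ref + t_charge


print(lif_isi(22.))  # about 31.03 ms, i.e. roughly 32 spikes per second
print(lif_isi(15.))  # inf
```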
The neuronal variables, like the membrane potential and external input, can be captured by the following two variables:
V: the membrane potential
input: the external/synaptic input
In order to define the different states of a LIF neuron, we define additional variables:
spike: whether a spike is produced
refractory: whether the neuron is in the refractory period
t_last_spike: the last spiking time
Based on these state variables, the LIF model can be implemented as below.
class LIF(bp.NeuGroup):
    def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., R=1., tau=10., t_ref=5., **kwargs):
        super(LIF, self).__init__(size=size, **kwargs)
        # initialize parameters
        self.V_rest = V_rest
        self.V_reset = V_reset
        self.V_th = V_th
        self.R = R
        self.tau = tau
        self.t_ref = t_ref
        # initialize variables
        self.V = bm.Variable(bm.random.randn(self.num) + V_reset)
        self.input = bm.Variable(bm.zeros(self.num))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
        self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # integral function
        self.integral = bp.odeint(f=self.derivative, method='exp_auto')

    def derivative(self, V, t, Iext):
        dvdt = (-V + self.V_rest + self.R * Iext) / self.tau
        return dvdt

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        # Whether the neurons are in the refractory period
        refractory = (_t - self.t_last_spike) <= self.t_ref
        # compute the membrane potential
        V = self.integral(self.V, _t, self.input, dt=_dt)
        # computed membrane potential is valid only when the neuron is not in the refractory period
        V = bm.where(refractory, self.V, V)
        # update the spiking state
        spike = self.V_th <= V
        self.spike.value = spike
        # update the last spiking time
        self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
        # update the membrane potential and reset spiked neurons
        self.V.value = bm.where(spike, self.V_reset, V)
        # update the refractory state
        self.refractory.value = bm.logical_or(refractory, spike)
        # reset the external input
        self.input[:] = 0.
In the code above, the discontinuous reset is implemented with the brainpy.math.where operation.
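A Python if cannot be used here: the reset decision is made separately for every neuron in an array, and inside JIT-compiled (JAX-traced) code Python control flow cannot depend on array values anyway. where evaluates both alternatives element-wise and selects one per element. The NumPy snippet below shows the same masking pattern; we assume brainpy.math.where mirrors numpy.where for this use.

```python
import numpy as np

V_th, V_reset = 20., -5.                 # defaults of the LIF class above
V = np.array([12.0, 19.9, 20.0, 27.5])   # membrane potentials after integration
spike = V_th <= V                        # same comparison as in LIF.update
V = np.where(spike, V_reset, V)          # reset only the neurons that spiked
print(spike, V)
```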
Instantiation and running#
Here, let’s try to instantiate an HH neuron group:
neu = HH(10)
in which a neural group containing 10 HH neurons is generated.
The details of the model simulation will be expanded in the Runners section. In brief, running any dynamical system instance should be accomplished with a runner, such as brainpy.DSRunner and brainpy.ReportRunner. The variables to be monitored and the input currents to be applied in the simulation can be provided when initializing the runner. The details are accessible in Monitors and Inputs.
runner = bp.DSRunner(
    neu,
    monitors=['V'],
    inputs=('input', 22.)  # constant external input of 22 to all neurons
)
Then the simulation can be performed with a given time period, and the simulation result can be visualized:
runner.run(200) # the running time is 200 ms
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

A LIF neural group can be instantiated and applied in simulation in a similar way:
group = LIF(10)
runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 22.), jit=True)
runner.run(200)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

Customizing Your Synapse Models#
Synaptic computation is the core of brain dynamics programming. This is because in a real project most of the simulation time is spent on the computation of synapses. To achieve efficient synaptic computation, BrainPy provides many useful tools. Here, we are going to explore the details of these tools.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
Synapse Models in Math#
Before we talk about the implementation of synapses in BrainPy, it’s better to understand the targets (synapse models) we are going to implement. For different illustration purposes, we are going to implement two synapse models: exponential synapse model and AMPA synapse model.
1. The exponential synapse model#
The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state rises instantaneously and then decays with a time constant \(\tau_{decay}\). Its dynamics are given by:
\[
\frac{dg}{dt} = -\frac{g}{\tau_{decay}} + \sum_{k} \delta(t - t^{k} - D)
\]
where \(g\) is the synaptic state, \(t^{k}\) is the spike time of the pre-synaptic neuron, and \(D\) is the synaptic delay.
Afterward, the current output onto the post-synaptic neuron is given in the conductance-based form:
\[
I_{syn}(t) = g_{max}\, g\, (E - V(t))
\]
where \(E\) is the reversal potential of the synapse, \(V\) is the post-synaptic membrane potential, \(g_{max}\) is the maximum synaptic conductance.
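Between spikes, the equation above is a pure exponential decay, so it can be stepped exactly with \(g \leftarrow g\, e^{-\Delta t/\tau_{decay}}\), and each spike arriving after the delay \(D\) adds 1 to \(g\). The NumPy sketch below illustrates these dynamics under those assumptions (it is not how BrainPy integrates the model): with a single spike at t = 0 and D = 1 ms, g jumps to 1 at t = 1 ms and has decayed to \(e^{-1}\) one time constant later.

```python
import numpy as np


def simulate_exp_synapse(spike_times, tau_decay=8.0, D=1.0, dt=0.1, t_end=50.0):
    """Step dg/dt = -g/tau_decay + sum_k delta(t - t_k - D) on a grid of width dt (ms)."""
    n_steps = int(round(t_end / dt))
    arrivals = np.zeros(n_steps)
    for t_k in spike_times:
        idx = int(round((t_k + D) / dt))   # grid step at which the delayed spike arrives
        if idx < n_steps:
            arrivals[idx] += 1.0
    decay = np.exp(-dt / tau_decay)        # exact decay factor for one step
    g = np.zeros(n_steps)
    for i in range(n_steps):
        prev = g[i - 1] if i > 0 else 0.0
        g[i] = prev * decay + arrivals[i]  # decay, then add the instantaneous jumps
    return g


g = simulate_exp_synapse([0.0], tau_decay=8.0, D=1.0, dt=0.1)
g_max, E, V_post = 1.0, 0.0, -60.0
I_syn = g_max * g * (E - V_post)           # conductance-based current at a fixed V
print(g[10], g[90])                        # 1.0 at t = 1 ms, about exp(-1) at t = 9 ms
```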
2. The AMPA synapse model#
A classical model of the AMPA synapse uses a Markov process to model the switching of ion channels. Here \(g\) represents the probability of channel opening, \(1-g\) represents the probability of channel closing, and \(\alpha\) and \(\beta\) are the transition probabilities. Specifically, its formula is given by
\[
\frac{dg}{dt} = \alpha [T] (1 - g) - \beta g
\]
where \(\alpha [T]\) denotes the transition probability from state \((1-g)\) to state \((g)\); and \(\beta\) represents the transition probability of the other direction. \(\alpha\) is the binding constant. \(\beta\) is the unbinding constant. \([T]\) is the neurotransmitter concentration, and has the duration of 0.5 ms.
Moreover, the post-synaptic current on the post-synaptic neuron is formulated as
\[
I_{syn}(t) = g_{max}\, g\, (E - V(t))
\]
where \(g_{max}\) is the maximum conductance, \(E\) is the reversal potential, and \(V\) is the post-synaptic membrane potential.
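The NumPy sketch below integrates this AMPA model for a single spike, assuming \([T]\) is a rectangular pulse of height T lasting T_duration after the spike and using the defaults of the BaseAMPASyn class defined later (alpha=0.98, beta=0.18, T=0.5, T_duration=0.5). Because \([T]\) is piecewise constant, each step can use the exact solution of the linear ODE. Starting from \(g = 0\) and ignoring synaptic delay, the peak at the end of the pulse is \(\frac{\alpha T}{\alpha T + \beta}\left(1 - e^{-(\alpha T + \beta) T_{duration}}\right)\).

```python
import numpy as np


def simulate_ampa(spike_times, alpha=0.98, beta=0.18, T=0.5, T_duration=0.5,
                  dt=0.01, t_end=20.0):
    """Step dg/dt = alpha*[T]*(1-g) - beta*g, where [T] = T for T_duration ms after each spike."""
    n_steps = int(round(t_end / dt))
    pulse_steps = int(round(T_duration / dt))      # steps during which [T] > 0
    spike_steps = {int(round(t / dt)) for t in spike_times}
    g, steps_since = 0.0, np.inf
    trace = np.empty(n_steps)
    for i in range(n_steps):
        if i in spike_steps:
            steps_since = 0
        TT = T if steps_since < pulse_steps else 0.0
        rate = alpha * TT + beta                   # total relaxation rate (1/ms)
        g_inf = alpha * TT / rate                  # fixed point for the current [T]
        g = g_inf + (g - g_inf) * np.exp(-dt * rate)  # exact for constant [T]
        trace[i] = g
        steps_since += 1
    return trace


trace = simulate_ampa([0.0])
print(trace.max())  # peak reached at the end of the 0.5 ms pulse
```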
Synapse Models in Silicon#
Synapse models are implemented with the brainpy.dyn.TwoEndConn interface. In this section, we describe the tools BrainPy provides for implementing synapse models in silicon.
1. brainpy.dyn.TwoEndConn#
In BrainPy, brainpy.dyn.SynConn is used to model two-end synaptic computations.
To define a synapse model, two requirements should be satisfied:
1. A constructor function __init__(), which needs three key arguments:
pre: the pre-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.
post: the post-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.
conn (optional): the connection type between these two groups. BrainPy provides abundant connection types, which are described in detail in the Synaptic Connections tutorial.
2. An update function update(tdi), which describes the updating rule from the current time \(\mathrm{tdi.t}\) to the next time \(\mathrm{tdi.t + tdi.dt}\).
2. Variable delays#
As seen in the above two synapse models, synaptic computations are usually involved with variable delays. A delay time (typically 0.3–0.5 ms) is usually required for a neurotransmitter to be released from a pre-synaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane.
BrainPy provides several kinds of delay variables for users, including:
brainpy.math.LengthDelay: a delay variable which defines a constant number of steps for the delay.
brainpy.math.TimeDelay: a delay variable which defines a constant time length for the delay.
Assume here we need a delay variable with a 1 ms delay. If the numerical integration step dt is 0.1 ms, then we can create a brainpy.math.LengthDelay with 10 delay time steps.
target_data_to_delay = bm.Variable(bm.zeros(10))
example_delay = bm.LengthDelay(target_data_to_delay,
delay_len=10) # delay 10 steps
example_delay(5) # call the delay data at 5 delay step
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(10) # call the delay data at 10 delay step
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
Alternatively, we can create an instance of brainpy.math.TimeDelay, which uses the time t as the index to retrieve the delayed data.
t0 = 0.
example_delay = bm.TimeDelay(target_data_to_delay,
delay_len=1.0, t0=t0) # delay 1.0 ms
example_delay(t0 - 1.0) # the delay data at t-1. ms
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(t0 - 0.5) # the delay data at t-0.5 ms
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
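A constant-length delay like the ones above is commonly implemented as a ring buffer. RingDelay below is a hypothetical NumPy sketch of that idea, not the implementation or indexing convention of brainpy.math.LengthDelay. It also converts the delay time into steps with round(): plain int() truncates, and int(0.3 / 0.1) evaluates to 2 because 0.3 / 0.1 is 2.9999999999999996 in floating point.

```python
import numpy as np


class RingDelay:
    """Hypothetical constant-length delay line (not brainpy.math.LengthDelay).

    Keeps the last delay_len + 1 snapshots of a 1-D array; retrieve(k) returns the
    snapshot pushed k updates ago (k = 0 is the newest one).
    """

    def __init__(self, init, delay_len):
        self.size = delay_len + 1
        self.buf = np.tile(np.asarray(init, dtype=float), (self.size, 1))
        self.head = 0  # row that holds the newest snapshot

    def update(self, value):
        self.head = (self.head + 1) % self.size
        self.buf[self.head] = value

    def retrieve(self, k):
        if not 0 <= k < self.size:
            raise ValueError(f'delay step {k} outside [0, {self.size - 1}]')
        return self.buf[(self.head - k) % self.size]


dt, delay = 0.1, 1.0
delay_len = int(round(delay / dt))    # 10 steps; int(0.3 / 0.1) would give 2, not 3
d = RingDelay(np.zeros(3), delay_len)
for step in range(1, 16):             # push the snapshots 1, 2, ..., 15
    d.update(np.full(3, float(step)))
print(d.retrieve(0), d.retrieve(10))  # [15. 15. 15.] [5. 5. 5.]
```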
3. Synaptic connections#
Synaptic computations usually need to create connections between groups. BrainPy provides extensive support for constructing synaptic connections. Simply speaking, brainpy.conn.Connector can create the various data structures you want through the require() function. Take the random connection brainpy.conn.FixedProb, which will be used in the following, as an example:
example_conn = bp.conn.FixedProb(0.2)(pre_size=(5,), post_size=(8, ))
we can require the connection matrix (which has the shape of (num_pre, num_post)):
example_conn.require('conn_mat')
Array([[False, True, False, False, False, False, True, False],
[False, True, True, False, False, True, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, True, False],
[False, False, False, True, True, False, False, False]], dtype=bool)
we can also require the connected indices of pre-synaptic neurons (pre_ids) and post-synaptic neurons (post_ids):
example_conn.require('pre_ids', 'post_ids')
(Array([0, 1, 2, 3, 4], dtype=int32), Array([3, 0, 2, 3, 7], dtype=int32))
Or, we can require the connection structure pre2post, which stores how each pre-synaptic neuron connects to post-synaptic neurons:
example_conn.require('pre2post')
(Array([7, 6, 1, 4, 1], dtype=int32), Array([0, 1, 2, 3, 4, 5], dtype=int32))
Warning
Every require() call establishes a new connection pattern and returns the data structure users have required. Therefore, any two separate require() calls will return different connection patterns, just like the examples above. Please remember to require all the data structures at once if you want a consistent connection pattern.
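When you build connection structures yourself, one way to stay consistent is to derive every representation from a single pattern. The plain-NumPy sketch below (not the brainpy.conn API) builds pre_ids/post_ids and a pre2post pair from one connection matrix; it assumes pre2post is a CSR-style tuple of post-synaptic indices and row pointers, which matches the example output above.

```python
import numpy as np

rng = np.random.default_rng(0)
conn_mat = rng.random((5, 8)) < 0.2  # one random pattern, shape (num_pre, num_post)

# COO form: one (pre, post) index pair per synapse; np.nonzero scans row by row,
# so the pairs are already grouped by pre-synaptic neuron.
pre_ids, post_ids = np.nonzero(conn_mat)

# CSR form: post_ids[indptr[i]:indptr[i + 1]] are the targets of pre-synaptic neuron i.
indptr = np.concatenate(([0], np.cumsum(conn_mat.sum(axis=1))))
pre2post = (post_ids, indptr)

# All three views describe the same synapses.
for i in range(conn_mat.shape[0]):
    assert np.array_equal(post_ids[indptr[i]:indptr[i + 1]], np.flatnonzero(conn_mat[i]))
print(conn_mat.sum(), len(pre_ids), indptr[-1])  # the same synapse count three times
```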
For more details on the connection structures, please see the tutorial of Synaptic Connections.
Achieving efficient synaptic computation is difficult#
Synaptic computations usually need to transform the data of the pre-synaptic dimension into the data of the post-synaptic dimension, or into data with the shape of the synapse number. There is no universal computation method that is efficient in all cases. Usually, we need different approaches for different connection situations to achieve efficient synaptic computation. In the next two sections, we will talk about how to define efficient synaptic models when your connections are sparse or dense.
Before we start, we need to define some useful helper functions to define and show synapse models. Then, we will highlight the key differences of model definition when using different synaptic connections.
# Basic Model to define the exponential synapse model. This class
# defines the basic parameters, variables, and integral functions.
class BaseExpSyn(bp.SynConn):
    def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., method='exp_auto'):
        super(BaseExpSyn, self).__init__(pre=pre, post=post, conn=conn)
        # check whether the pre group has the needed attribute: "spike"
        self.check_pre_attrs('spike')
        # check whether the post group has the needed attributes: "input" and "V"
        self.check_post_attrs('input', 'V')
        # parameters
        self.E = E
        self.tau = tau
        self.delay = delay
        self.g_max = g_max
        # use "LengthDelay" to store the spikes of the pre-synaptic neuron group;
        # round() guards against float truncation, e.g. int(0.3 / 0.1) == 2
        self.delay_step = int(round(delay / bm.get_dt()))
        self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)
        # integral function
        self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
# Basic Model to define the AMPA synapse model. This class
# defines the basic parameters, variables, and integral functions.
class BaseAMPASyn(bp.SynConn):
    def __init__(self, pre, post, conn, delay=0., g_max=0.42, E=0., alpha=0.98,
                 beta=0.18, T=0.5, T_duration=0.5, method='exp_auto'):
        super(BaseAMPASyn, self).__init__(pre=pre, post=post, conn=conn)
        # check whether the pre group has the needed attribute: "spike"
        self.check_pre_attrs('spike')
        # check whether the post group has the needed attributes: "input" and "V"
        self.check_post_attrs('input', 'V')
        # parameters
        self.delay = delay
        self.g_max = g_max
        self.E = E
        self.alpha = alpha
        self.beta = beta
        self.T = T
        self.T_duration = T_duration
        # use "LengthDelay" to store the spikes of the pre-synaptic neuron group;
        # round() guards against float truncation, e.g. int(0.3 / 0.1) == 2
        self.delay_step = int(round(delay / bm.get_dt()))
        self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)
        # store the arrival time of the pre-synaptic spikes
        self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)
        # integral function
        self.integral = bp.odeint(self.derivative, method=method)

    def derivative(self, g, t, TT):
        dg = self.alpha * TT * (1 - g) - self.beta * g
        return dg
# for more details on how to run a simulation, please see the tutorials in "Dynamics Simulation"
def show_syn_model(model):
    pre = bp.neurons.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
    post = bp.neurons.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
    syn = model(pre, post, conn=bp.conn.One2One())
    net = bp.Network(pre=pre, post=post, syn=syn)
    runner = bp.DSRunner(net,
                         monitors=['pre.V', 'post.V', 'syn.g'],
                         inputs=('pre.input', 22.))
    runner.run(100.)
    fig, gs = bp.visualize.get_figure(1, 2, 3, 4)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['syn.g'], legend='syn.g')
    fig.add_subplot(gs[0, 1])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['pre.V'], legend='pre.V')
    bp.visualize.line_plot(runner.mon.ts, runner.mon['post.V'], legend='post.V', show=True)
Computation with Dense Connections#
Matrix-based synaptic computation is straightforward. In particular, when your models are densely connected, using a matrix is highly efficient.
conn_mat#
Assume two neuron groups are connected through a fixed probability of 0.7.
conn = bp.conn.FixedProb(0.7)(pre_size=6, post_size=8)
Then you can create the connection matrix through conn.require("conn_mat"):
conn.require('conn_mat')
Array([[ True, False, True, True, True, True, True, False],
[False, True, False, True, False, True, True, True],
[ True, True, True, False, False, True, True, False],
[ True, True, True, False, False, True, True, True],
[ True, True, True, True, False, True, True, True],
[ True, True, True, False, True, True, True, True]], dtype=bool)
conn_mat has the shape of (num_pre, num_post). Therefore, transforming data of the pre-synaptic dimension into data of the post-synaptic dimension is very easy. You just need to perform a matrix multiplication: brainpy.math.dot(pre_values, conn_mat) (\(\mathbb{R}^\mathrm{num\_pre} @ \mathbb{R}^\mathrm{(num\_pre, num\_post)} \to \mathbb{R}^\mathrm{num\_post}\)).
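A tiny NumPy example of this transformation, where np.dot plays the role of brainpy.math.dot: with a binary conn_mat, each entry of the product counts the spiking pre-synaptic neurons that connect to that post-synaptic neuron.

```python
import numpy as np

conn_mat = np.array([[1, 0, 1, 1],
                     [0, 1, 0, 1],
                     [1, 1, 1, 0]], dtype=float)  # (num_pre=3, num_post=4)
pre_spike = np.array([True, False, True])

post_input = np.dot(pre_spike.astype(float), conn_mat)  # shape (num_post,)
print(post_input)  # [2. 1. 2. 1.]
```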
With the synaptic connection conn_mat above, we can define the exponential synapse model as follows. It is worth noting that, because the exponential synapse equation is linear in \(g\) and all synapses share the same time constant, the states of the synapses that output onto the same post-synaptic neuron can be superposed. This means we can declare the synapse variables with the shape of the post-synaptic group, rather than the total number of synapses.
class ExpConnMat(BaseExpSyn):
    def __init__(self, *args, **kwargs):
        super(ExpConnMat, self).__init__(*args, **kwargs)
        # connection matrix
        self.conn_mat = self.conn.require('conn_mat').astype(float)
        # synapse gating variable
        # -------
        # NOTE: Here the number of synapse states equals
        #       the number of post-synaptic neurons. This is
        #       different from the AMPA synapse.
        self.g = bm.Variable(bm.zeros(self.post.num))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        # pull the delayed pre spikes for computation
        delayed_spike = self.pre_spike(self.delay_step)
        # push the latest pre spikes into the bottom
        self.pre_spike.update(self.pre.spike)
        # integrate the synapse state
        self.g.value = self.integral(self.g, _t, dt=_dt)
        # update synapse states according to the pre spikes
        post_sps = bm.dot(delayed_spike.astype(float), self.conn_mat)
        self.g += post_sps
        # get the post-synaptic current
        self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpConnMat)

We can also use conn_mat to define an AMPA synapse model. Note that here the shape of the synapse variable \(g\) is (num_pre, num_post), rather than self.post.num as in the above exponential synapse model. This is because the synaptic states of the AMPA model cannot be superposed: the saturating term \(\alpha [T](1-g)\) depends on each synapse's own \(g\).
class AMPAConnMat(BaseAMPASyn):
    def __init__(self, *args, **kwargs):
        super(AMPAConnMat, self).__init__(*args, **kwargs)
        # connection matrix
        self.conn_mat = self.conn.require('conn_mat').astype(float)
        # synapse gating variable
        # -------
        # NOTE: Here the synapse shape is (num_pre, num_post),
        #       in contrast to the ExpConnMat
        self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        # pull the delayed pre spikes for computation
        delayed_spike = self.pre_spike(self.delay_step)
        # push the latest pre spikes into the bottom
        self.pre_spike.update(self.pre.spike)
        # record the time at which the pre spikes arrive at the post synapse
        self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
        # get the neurotransmitter concentration at the current time
        TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
        # integrate the synapse state
        TT = TT.reshape((-1, 1)) * self.conn_mat  # NOTE: only keep the concentrations
                                                  #       on the valid connections
        self.g.value = self.integral(self.g, _t, TT, dt=_dt)
        # get the post-synaptic current
        g_post = self.g.sum(axis=0)
        self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAConnMat)

Special connections#
Sometimes, we can define synapse models with special connection types, such as all-to-all or one-to-one connections. In these special situations, the connection information can even be ignored, i.e., we do not need conn_mat or other structures anymore.
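The reason is that for these patterns the matrix product collapses to something simpler. The NumPy check below is illustrative only: with an all-to-all matrix, every post-synaptic neuron receives the total number of pre-synaptic spikes (which is why ExpAll2All below adds delayed_spike.sum()), and with a one-to-one matrix the product is just the spike vector itself.

```python
import numpy as np

rng = np.random.default_rng(1)
num_pre = num_post = 6
spikes = (rng.random(num_pre) < 0.5).astype(float)  # 1.0 where a pre neuron spiked

# All-to-all: the matrix is all ones, so every post neuron receives the total spike count.
all2all = np.ones((num_pre, num_post))
assert np.allclose(spikes @ all2all, spikes.sum())

# One-to-one: the matrix is the identity, so post neuron i receives the spike of pre neuron i.
one2one = np.eye(num_pre)
assert np.allclose(spikes @ one2one, spikes)
```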
Assume the pre-synaptic group connects to the post-synaptic group in an all-to-all fashion. Then, the exponential synapse model can be defined as:
class ExpAll2All(BaseExpSyn):
    def __init__(self, *args, **kwargs):
        super(ExpAll2All, self).__init__(*args, **kwargs)
        # synapse gating variable
        # -------
        # The synapse variable has the shape of the post-synaptic group
        self.g = bm.Variable(bm.zeros(self.post.num))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        delayed_spike = self.pre_spike(self.delay_step)
        self.pre_spike.update(self.pre.spike)
        self.g.value = self.integral(self.g, _t, dt=_dt)
        self.g += delayed_spike.sum()  # NOTE: HERE is the difference
        self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpAll2All)

Similarly, the AMPA synapse model can be defined as
class AMPAAll2All(BaseAMPASyn):
    def __init__(self, *args, **kwargs):
        super(AMPAAll2All, self).__init__(*args, **kwargs)
        # synapse gating variable
        # -------
        # The synapse variable has the shape of (num_pre, num_post)
        self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        delayed_spike = self.pre_spike(self.delay_step)
        self.pre_spike.update(self.pre.spike)
        self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
        TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
        TT = TT.reshape((-1, 1))  # NOTE: here is the difference
        self.g.value = self.integral(self.g, _t, TT, dt=_dt)
        g_post = self.g.sum(axis=0)  # NOTE: here is also different
        self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAAll2All)

Actually, synaptic computation with these special connections can be very efficient! For a concrete example, please see the decision-making spiking model in BrainPy-Examples. This implementation achieves at least a fourfold speedup compared with implementations in other frameworks.
Computation with Sparse Connections#
However, in real neural systems, neurons are essentially connected sparsely.
Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need \(10^8\) floats to save the synaptic states, and at each update step, you need to compute on \(10^8\) floats. However, the number of synapses that actually exist is only about \(10^7\). This is a huge waste of memory and computing resources. Moreover, at any given time \(\mathrm{\_t}\), only a small fraction of the pre-synaptic neurons are spiking on average. This means that many useless computations are made when synaptic computations are defined with matrix-based connections (a zero spike vector multiplied by the connection matrix simply yields zeros).
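The arithmetic behind this argument, as a quick Python check. We assume float32 states and int32 indices, and a CSR-style layout with one state and one post-synaptic index per synapse plus num_pre + 1 row pointers; other layouts change the constants, but the dense matrix still stores ten times more states than there are synapses.

```python
num_pre = num_post = 10_000
p = 0.1                          # connection probability
float_bytes = int_bytes = 4      # float32 states, int32 indices

dense_states = num_pre * num_post    # 10**8 matrix entries
n_syn = round(p * dense_states)      # about 10**7 synapses actually exist

dense_mb = dense_states * float_bytes / 1e6
sparse_mb = (n_syn * (float_bytes + int_bytes) + (num_pre + 1) * int_bytes) / 1e6
print(f'dense: {dense_mb:.0f} MB, sparse: {sparse_mb:.0f} MB')  # dense: 400 MB, sparse: 80 MB
```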
Therefore, we need new ways to define synapse models. Specifically, we use vectors to store the connected neuron indices, like the pre_ids and post_ids (see Synaptic Connections).
In the below, we assume you have learned the synaptic connection types detailed in the tutorial of Synaptic Connections.
The pre2post operator#
A notable difference between brain dynamics models and deep learning models is that the former are sparse and event-driven. To support this significantly different kind of computation, BrainPy provides many useful operators. In this section, we introduce a set of operators needed for pre2post computations.
As noted before, the exponential synapse model can perform its computations at the dimension of the post-synaptic group. Therefore, we can directly transform the pre-synaptic data into data with the post-synaptic shape. brainpy.math.pre2post_event_sum(events, pre2post, post_num, values) meets this requirement. This operator needs the synaptic structure pre2post (a tuple containing the post_ids and indptr of the pre-synaptic neurons).
If values is a scalar, pre2post_event_sum is equivalent to:
import numpy as np
post_ids, indptr = pre2post
pre_num = len(indptr) - 1
post_val = np.zeros(post_num)
for i in range(pre_num):
    if events[i]:
        # synapses of pre-synaptic neuron i are indexed by j
        for j in range(indptr[i], indptr[i + 1]):
            post_val[post_ids[j]] += values
If values is a vector, pre2post_event_sum is equivalent to:
import numpy as np
post_ids, indptr = pre2post
pre_num = len(indptr) - 1
post_val = np.zeros(post_num)
for i in range(pre_num):
    if events[i]:
        # synapses of pre-synaptic neuron i are indexed by j
        for j in range(indptr[i], indptr[i + 1]):
            post_val[post_ids[j]] += values[j]
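The loops above can be cross-checked against a loop-free NumPy formulation. event_sum_loop and event_sum_vectorized are hypothetical helpers, not BrainPy APIs; both assume pre2post is the CSR pair (post_ids, indptr) and index post_ids by the synapse index j.

```python
import numpy as np


def event_sum_loop(events, pre2post, post_num, values):
    """Reference loop: for each spiking pre neuron, add values onto its targets."""
    post_ids, indptr = pre2post
    scalar = np.isscalar(values)
    post_val = np.zeros(post_num)
    for i in range(len(indptr) - 1):
        if events[i]:
            for j in range(indptr[i], indptr[i + 1]):
                post_val[post_ids[j]] += values if scalar else values[j]
    return post_val


def event_sum_vectorized(events, pre2post, post_num, values):
    """Same result without Python loops: mask synapses of spiking neurons, then scatter-add."""
    post_ids, indptr = pre2post
    syn_active = np.repeat(np.asarray(events, dtype=bool), np.diff(indptr))  # one flag per synapse
    weights = np.broadcast_to(np.asarray(values, dtype=float), post_ids.shape)
    return np.bincount(post_ids[syn_active], weights=weights[syn_active], minlength=post_num)


# 4 pre neurons -> 4 post neurons; pre neuron 1 has no outgoing synapse
post_ids = np.array([1, 3, 0, 1, 2, 1])
indptr = np.array([0, 2, 2, 5, 6])
events = np.array([True, True, False, True])
print(event_sum_loop(events, (post_ids, indptr), 4, 1.))        # [0. 2. 0. 1.]
print(event_sum_vectorized(events, (post_ids, indptr), 4, 1.))  # [0. 2. 0. 1.]
```

np.bincount performs the scatter-add, so repeated post-synaptic indices are accumulated rather than overwritten.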
With this operator, exponential synapse model can be defined as:
class ExpSparse(BaseExpSyn):
    def __init__(self, *args, **kwargs):
        super(ExpSparse, self).__init__(*args, **kwargs)
        # connections
        self.pre2post = self.conn.require('pre2post')
        # synapse variable
        self.g = bm.Variable(bm.zeros(self.post.num))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        delayed_spike = self.pre_spike(self.delay_step)
        self.pre_spike.update(self.pre.spike)
        self.g.value = self.integral(self.g, _t, dt=_dt)
        # NOTE: update synapse states according to the pre spikes
        post_sps = bm.pre2post_event_sum(delayed_spike, self.pre2post, self.post.num, 1.)
        self.g += post_sps
        self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpSparse)
Running this cell emits a UserWarning: Please use ``brainpylib.event_ops.event_csr_matvec()`` instead.

This model will be very efficient when your synapses are connected sparsely.
The pre2syn and syn2post operators#
However, for the AMPA synapse model, the pre-synaptic values cannot be directly transformed into post-synaptic dimensional data. Therefore, we need to first transform the pre-synaptic data into data of the synapse dimension, and then transform the synapse-dimensional data into post-synaptic dimensional data.
Therefore, the core problem of synaptic computation is how to convert values among arrays of different shapes. Specifically, in the above AMPA synapse model, we have three kinds of array shapes (see the following figure): arrays with the dimension of the pre-synaptic group, arrays with the dimension of the post-synaptic group, and arrays with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state and grouping the synaptic variables into the post-synaptic current are the central problems of synaptic computation.
Here BrainPy provides two operators, brainpy.math.pre2syn(pre_values, pre_ids) and brainpy.math.syn2post(syn_values, post_ids, post_num), to convert vectors among different dimensions.
brainpy.math.pre2syn() receives two arguments: “pre_values” (the variable of the pre-synaptic dimension) and “pre_ids” (the connected pre-synaptic neuron indices).
brainpy.math.syn2post() receives three arguments: “syn_values” (the variable with the synaptic size), “post_ids” (the connected post-synaptic neuron indices) and “post_num” (the number of post-synaptic neurons).
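Conceptually, pre2syn is a gather and syn2post is a scatter-sum. The functions below are hypothetical NumPy stand-ins written from the argument descriptions above, not BrainPy's implementations: each synapse takes the value of its pre-synaptic neuron, and the synaptic values are then summed per post-synaptic neuron.

```python
import numpy as np


def pre2syn(pre_values, pre_ids):
    """Gather: give every synapse the value of its pre-synaptic neuron."""
    return np.asarray(pre_values)[pre_ids]


def syn2post(syn_values, post_ids, post_num):
    """Scatter-sum: add every synapse's value onto its post-synaptic neuron."""
    return np.bincount(post_ids, weights=syn_values, minlength=post_num)


pre_ids = np.array([0, 0, 1, 2, 2])       # 5 synapses from 3 pre neurons ...
post_ids = np.array([1, 2, 0, 1, 2])      # ... onto 3 post neurons
arrival = np.array([3.0, -1e7, 9.5])      # last spike arrival time of each pre neuron (ms)
g = np.array([0.1, 0.2, 0.3, 0.4, 0.5])   # one gating value per synapse

syn_arrival = pre2syn(arrival, pre_ids)     # [3, 3, -1e7, 9.5, 9.5]
g_post = syn2post(g, post_ids, post_num=3)  # [0.3, 0.5, 0.7]
```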
Based on these two operators, we can define the AMPA synapse model as:
class AMPASparse(BaseAMPASyn):
    def __init__(self, *args, **kwargs):
        super(AMPASparse, self).__init__(*args, **kwargs)
        # connection indices
        self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
        # synapse gating variable
        # -------
        # NOTE: Here the synapse shape is (num_syn,)
        self.g = bm.Variable(bm.zeros(len(self.pre_ids)))

    def update(self, tdi, x=None):
        _t, _dt = tdi.t, tdi.dt
        delayed_spike = self.pre_spike(self.delay_step)
        self.pre_spike.update(self.pre.spike)
        # record the time at which the pre spikes arrive at the post synapse
        self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
        # get the arrival time with the synapse dimension
        arrival_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
        # get the neurotransmitter concentration at the current time
        TT = ((_t - arrival_times) < self.T_duration) * self.T
        # integrate the synapse state
        self.g.value = self.integral(self.g, _t, TT, dt=_dt)
        # get the post-synaptic current
        g_post = bm.syn2post(self.g, self.post_ids, self.post.num)
        self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPASparse)

We hope this tutorial will help your synapse models be defined efficiently.
Model Simulation#
Simulation with DSRunner#
@Tianqiu Zhang @Chaoming Wang @Xiaoyu Chen
The convenient simulation interface for dynamical systems in BrainPy is implemented by brainpy.dyn.DSRunner. It can simulate various levels of models including channels, neurons, synapses and systems. In this tutorial, we will introduce how to use brainpy.dyn.DSRunner in detail.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.3.1'
Initializing a DSRunner#
Generally, we can initialize a runner for dynamical systems with the format of:
runner = DSRunner(target=instance_of_dynamical_system,
inputs=inputs_for_target_DynamicalSystem,
monitors=interested_variables_to_monitor,
dyn_vars=dynamical_changed_variables,
jit=enable_jit_or_not,
progress_bar=report_the_running_progress,
numpy_mon_after_run=transform_into_numpy_ndarray
)
Here,
target specifies the model to be simulated. It must be an instance of brainpy.DynamicalSystem.
inputs defines the input operations for specific variables. It should have the format [(target, value, [type, operation])], where target is the input target, value is the input value, type is the input type (such as “fix”, “iter”, “func”), and operation is the input operation (such as “+”, “-”, “*”, “/”, “=”). To specify multiple inputs, give multiple (target, value, [type, operation]) tuples, such as [(target1, value1), (target2, value2)]. inputs can also be a function that manually specifies the inputs for the target variables. This input function should receive one argument, tdi, which contains the shared arguments such as the time t, the time step dt, and the index i.
monitors defines the variables in the model to record. During the simulation, the history values of the monitored variables are recorded. Variables can also be monitored through callable functions; in that case, monitors should be a dict whose key is a string for later retrieval by runner.mon[key] and whose value is a callable function that receives one argument, tdi.
dyn_vars specifies all the dynamically changed variables used in the target model.
jit determines whether to use JIT compilation during the simulation.
progress_bar determines whether to use a progress bar to report the running progress.
numpy_mon_after_run determines whether to transform the monitored JAX arrays into NumPy ndarrays when the simulation finishes.
Running a DSRunner#
After initializing the runner, users can call the .run() function to run the simulation. The format of .run() is as follows:
runner.run(duration=simulation_time_length,
inputs=input_data,
reset_state=whether_reset_the_model_states,
shared_args=shared_arguments_across_different_layers,
progress_bar=report_the_running_progress,
eval_time=evaluate_the_running_time
)
Here,
duration is the simulation time length.
inputs is the input data. If inputs_are_batching=True, inputs must be a PyTree of data whose leading dimensions are (num_sample, num_time, ...). Otherwise, inputs should be a PyTree of data whose leading dimension is (num_time, ...).
reset_state determines whether to reset the model states.
shared_args contains the arguments shared across different layers. All layers can access the elements in shared_args.
progress_bar determines whether to use a progress bar to report the running progress.
eval_time determines whether to evaluate the running time.
Here we define an E/I balance network as the simulation model.
class EINet(bp.Network):
    def __init__(self, scale=1.0, method='exp_auto'):
        super(EINet, self).__init__()
        # network size
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        self.E = bp.neurons.LIF(num_exc, **pars, method=method)
        self.I = bp.neurons.LIF(num_inh, **pars, method=method)
        # synapses
        prob = 0.1
        we = 0.6 / scale / (prob / 0.02) ** 2  # excitatory synaptic weight (voltage)
        wi = 6.7 / scale / (prob / 0.02) ** 2  # inhibitory synaptic weight
        self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob),
                                           output=bp.synouts.COBA(E=0.), g_max=we,
                                           tau=5., method=method)
        self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(prob),
                                           output=bp.synouts.COBA(E=0.), g_max=we,
                                           tau=5., method=method)
        self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(prob),
                                           output=bp.synouts.COBA(E=-80.), g_max=wi,
                                           tau=10., method=method)
        self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(prob),
                                           output=bp.synouts.COBA(E=-80.), g_max=wi,
                                           tau=10., method=method)
Then we wrap it into DSRunner for simulation. brainpy.dyn.DSRunner aims to provide high-performance model simulation: it uses a structural loop primitive to lower the whole model onto XLA devices.
# instantiate EINet
net = EINet()
# initialize DSRunner
runner = bp.DSRunner(target=net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
# run the simulation
runner.run(duration=1000.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'])
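Conceptually, the runner repeatedly calls the model's one-step update for duration / dt steps, applies the inputs, and appends the monitored values to the history; DSRunner compiles this loop with JAX. The following plain-NumPy sketch spells that loop out for a hypothetical toy LIF group (forward-Euler update, no refractory period, parameters borrowed from EINet). It illustrates the idea only and is not DSRunner's code.

```python
import numpy as np

def run_lif(duration, dt=0.1, num=5, bg_input=20.,
            V_rest=-60., V_th=-50., V_reset=-60., tau=20.):
    """Toy simulation loop: update, apply input, record a monitor at each step."""
    num_step = int(round(duration / dt))
    V = np.full(num, V_rest)
    # each neuron gets a slightly different input so that they fire at different times
    I = bg_input * np.linspace(0.8, 1.2, num)
    ts, spikes = [], []
    for i in range(num_step):
        t = i * dt
        # forward-Euler update of dV/dt = (V_rest - V + I) / tau
        V = V + dt * (V_rest - V + I) / tau
        spike = V >= V_th
        V = np.where(spike, V_reset, V)
        ts.append(t)           # history times, analogous to runner.mon.ts
        spikes.append(spike)   # a monitored variable, analogous to runner.mon['E.spike']
    return np.asarray(ts), np.asarray(spikes)

ts, spikes = run_lif(100.)
print(ts.shape, spikes.shape)  # (1000,) (1000, 5)
```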

We have run a simple example of DSRunner, but it supports many more advanced usages. Next, we formally introduce two aspects of DSRunner that are used frequently: monitors and inputs.
Monitors in DSRunner#
In BrainPy, any instance of brainpy.dyn.DSRunner has a built-in monitor. Users can set up the monitor when initializing a runner, and there are several ways to do so. The first is to initialize the monitor with a list of strings.
Initialization with a list of strings#
# initialize monitor through a list of strings
runner1 = bp.DSRunner(target=net,
monitors=['E.spike', 'E.V', 'I.spike', 'I.V'], # 4 elements in monitors
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
where each string corresponds to the name of a variable in the E/I network:
net.E.V, net.E.spike
(Variable([-55.31656 , -58.02285 , -61.898117, ..., -55.487587, -53.33741 ,
-56.158283], dtype=float32),
Variable([False, False, False, ..., False, False, False], dtype=bool))
Once we call the runner with a given time duration, the monitor automatically records the evolution of the variables in the corresponding models. Afterwards, users can access these variable trajectories with .mon[variable_name]. The history times .mon.ts are also generated after the model finishes running. Let’s see an example.
runner1.run(100.)
bp.visualize.raster_plot(runner1.mon.ts, runner1.mon['E.spike'], show=True)

Initialization with index specification#
The second method is similar to the first one, except that an index specification is added. Index specification means that users only monitor specific neurons and ignore all the others. Sometimes we do not care about all the contents of a variable and are only interested in the values at certain indices. Moreover, for a huge network with a long simulation, monitors can consume a large part of RAM, so monitoring variables only at selected indices is more practical. BrainPy supports monitoring a subset of the elements in a Variable with a tuple like this:
# initialize monitor through a list of strings with index specification
runner2 = bp.DSRunner(target=net,
monitors=[('E.spike', [1, 2, 3]), # monitor values of Variable at index of [1, 2, 3]
'E.V'], # monitor all values of Variable 'V'
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner2.run(100.)
print('The monitor shape of "E.V" is (run length, variable size) = {}'.format(runner2.mon['E.V'].shape))
print('The monitor shape of "E.spike" is (run length, index size) = {}'.format(runner2.mon['E.spike'].shape))
The monitor shape of "E.V" is (run length, variable size) = (1000, 3200)
The monitor shape of "E.spike" is (run length, index size) = (1000, 3)
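The memory argument can be made concrete with a back-of-the-envelope estimate. The shapes above imply 1000 steps for a 100 ms run (i.e., dt = 0.1 ms). The sketch below assumes one byte per bool spike and four bytes per float32 voltage, and ignores container overhead.

```python
# Rough memory estimate of monitored histories (ignores array/container overhead).
num_step = int(round(100. / 0.1))    # 100 ms with dt = 0.1 ms -> 1000 steps
bytes_per_float32 = 4
bytes_per_bool = 1

full_V = num_step * 3200 * bytes_per_float32    # monitor all of 'E.V'
full_spike = num_step * 3200 * bytes_per_bool   # monitor all of 'E.spike'
part_spike = num_step * 3 * bytes_per_bool      # monitor 'E.spike' at [1, 2, 3]

print(full_V / 1e6, 'MB')       # 12.8 MB
print(full_spike / 1e6, 'MB')   # 3.2 MB
print(part_spike / 1e3, 'kB')   # 3.0 kB
```

For a 10-second run the full histories grow 100-fold, which is where index specification pays off.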
Explicit monitor target#
The third method is to use a dict with explicit monitor targets. Users can access the model instance directly and pass its variables as monitor targets:
# initialize monitor through a dict with the explicit monitor target
runner3 = bp.DSRunner(target=net,
monitors={'spike': net.E.spike, 'V': net.E.V},
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner3.run(100.)
print('The monitor shape of "V" is = {}'.format(runner3.mon['V'].shape))
print('The monitor shape of "spike" is = {}'.format(runner3.mon['spike'].shape))
The monitor shape of "V" is = (1000, 3200)
The monitor shape of "spike" is = (1000, 3200)
Explicit monitor target with index specification#
The fourth method is similar to the third one, with the difference that the index specification is added:
# initialize monitor through a dict with the explicit monitor target
runner4 = bp.DSRunner(target=net,
monitors={'E.spike': (net.E.spike, [1, 2]), # monitor values of Variable at index of [1, 2]
'E.V': net.E.V}, # monitor all values of Variable 'V'
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner4.run(100.)
print('The monitor shape of "E.V" is = {}'.format(runner4.mon['E.V'].shape))
print('The monitor shape of "E.spike" is = {}'.format(runner4.mon['E.spike'].shape))
The monitor shape of "E.V" is = (1000, 3200)
The monitor shape of "E.spike" is = (1000, 2)
In addition to the four methods mentioned above, BrainPy also provides a convenient parameter for monitoring more complicated variables: fun_monitor. Users can describe a monitor with a function and pass it to fun_monitor. fun_monitor must be a dict whose key is a string for later retrieval by runner.mon[key] and whose value is a callable function that receives one argument, tdi. The format of fun_monitor is shown below:
fun_monitor = {'key_name': lambda tdi: body_func(tdi)}
Here we monitor a variable that concatenates the spikes of the E and I groups; note that the callable is passed through the monitors dict:
runner5 = bp.DSRunner(target=net,
monitors={'E-I.spike': lambda tdi: bm.concatenate((net.E.spike, net.I.spike), axis=0)},
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner5.run(100.)
bp.visualize.raster_plot(runner5.mon.ts, runner5.mon['E-I.spike'])
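To illustrate what a callable monitor does at each step, here is a minimal NumPy sketch. The spike vectors are hypothetical stand-ins for net.E.spike and net.I.spike, and tdi is modeled as a types.SimpleNamespace with t, dt, and i attributes (an assumption that mirrors the tdi.t / tdi.dt access in the AMPA example). Each callable is evaluated once per step, and the results are stacked along a new time axis.

```python
from types import SimpleNamespace
import numpy as np

# Hypothetical stand-ins for net.E.spike and net.I.spike (3200 and 800 neurons).
E_spike = np.zeros(3200, dtype=bool)
I_spike = np.zeros(800, dtype=bool)

# Same structure as the dict passed to `monitors`: key -> callable(tdi)
monitors = {'E-I.spike': lambda tdi: np.concatenate((E_spike, I_spike), axis=0)}

history = {key: [] for key in monitors}
dt = 0.1
for i in range(5):
    tdi = SimpleNamespace(t=i * dt, dt=dt, i=i)   # shared arguments of this step
    E_spike[:] = False
    E_spike[i] = True                             # pretend E neuron i fires at step i
    for key, fun in monitors.items():
        # np.concatenate returns a new array, so each recorded step is an independent copy
        history[key].append(fun(tdi))
mon = {key: np.stack(values) for key, values in history.items()}
print(mon['E-I.spike'].shape)   # (5, 4000)
```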

Inputs in DSRunner#
In brain dynamics simulation, various inputs are usually given to different units of the dynamical system. In BrainPy, inputs can be specified to runners for dynamical systems. The aim of inputs is to mimic the input operations in experiments like Transcranial Magnetic Stimulation (TMS) and patch clamp recording.
inputs should have the format (target, value, [type, operation]), where
target is the target variable to inject the input into.
value is the input value. It can be a scalar, a tensor, or an iterable object/function.
type is the type of the input value. It supports two types of input: fix and iter. The first means that the data is static; the second means that the data is iterable, no matter whether the input value is a tensor or a function. The iter type must be stated explicitly.
operation is the input operation on the target variable. It should be one of { + , - , * , / , = }. If users do not provide it explicitly, it is set to ‘+’ by default, which means that the target variable is updated as val = val + input.
Users can also give multiple inputs for different target variables, like:
inputs=[(target1, value1, [type1, op1]),
(target2, value2, [type2, op2]),
... ]
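The (target, value, [type, operation]) convention can be illustrated with a small NumPy sketch. The helper apply_inputs and the OPS table are hypothetical. They only cover the fix and iter types, apply the defaults described above (type 'fix', operation '+'), and emulate a model clearing its input after every update, which is an assumption made for the illustration.

```python
import operator
import numpy as np

# Map each operation symbol to an update rule new_val = f(old_val, input)
OPS = {'+': operator.add, '-': operator.sub, '*': operator.mul,
       '/': operator.truediv, '=': lambda old, inp: inp}

def apply_inputs(variables, inputs, step):
    """Apply a list of (target, value, [type, operation]) tuples at one time step."""
    for item in inputs:
        target, value = item[0], item[1]
        in_type = item[2] if len(item) > 2 else 'fix'      # default: static input
        op = item[3] if len(item) > 3 else '+'             # default: val = val + input
        inp = value[step] if in_type == 'iter' else value  # 'iter': one element per step
        # write in place so the target keeps its shape and dtype
        variables[target][...] = OPS[op](variables[target], inp)

variables = {'E.input': np.zeros(3), 'I.input': np.zeros(3)}
iter_values = np.array([1., 2., 3.])
for step in range(3):
    variables['E.input'][...] = 0.   # emulate a model clearing its input after each update
    apply_inputs(variables,
                 [('E.input', 20.),                         # fix, '+'
                  ('I.input', iter_values, 'iter', '=')],   # iter, overwrite
                 step)
print(variables['E.input'], variables['I.input'])   # [20. 20. 20.] [3. 3. 3.]
```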
Static inputs#
The first example provides static inputs. The excitatory and inhibitory neurons all receive the same current intensity:
runner6 = bp.DSRunner(target=net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)], # static inputs
jit=True)
runner6.run(100.)
bp.visualize.raster_plot(runner6.mon.ts, runner6.mon['E.spike'])

Iterable inputs#
The second example provides iterable inputs. Users need to set type='iter' and pass an iterable object or function as value:
I, length = bp.inputs.section_input(values=[0, 20., 0],
durations=[100, 1000, 100],
return_length=True,
dt=0.1)
runner7 = bp.DSRunner(target=net,
monitors=['E.spike'],
inputs=[('E.input', I, 'iter'), ('I.input', I, 'iter')], # iterable inputs
jit=True)
runner7.run(length)
bp.visualize.raster_plot(runner7.mon.ts, runner7.mon['E.spike'])
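The piecewise-constant current used above can be pictured with a small NumPy sketch. The function section_input_sketch is hypothetical; it assumes that section_input holds values[k] for durations[k] ms, samples every dt, and returns the total duration as the length.

```python
import numpy as np

def section_input_sketch(values, durations, dt=0.1):
    """Piecewise-constant current: values[k] is held for durations[k] ms."""
    pieces = [np.full(int(round(d / dt)), v, dtype=float)
              for v, d in zip(values, durations)]
    return np.concatenate(pieces), float(sum(durations))

I, length = section_input_sketch([0, 20., 0], [100, 1000, 100], dt=0.1)
print(I.shape, length)                        # (12000,) 1200.0
print(I[999], I[1000], I[10999], I[11000])    # 0.0 20.0 20.0 0.0
```

With these values the run covers 1200 ms, i.e., 12000 steps of 0.1 ms.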

The examples above illustrate the usage of the inputs parameter. Similar to monitors, inputs can also take a more flexible functional form. BrainPy provides fun_inputs to receive customized functional inputs created by users; as the following example shows, such a function can also be passed to inputs directly.
def set_input(tdi):
    net.E.input[:] = 20.
    net.I.input[:] = 20.
runner8 = bp.DSRunner(target=net,
monitors=['E.spike'],
inputs=set_input, # functional inputs
jit=True)
runner8.run(200.)
bp.visualize.raster_plot(runner8.mon.ts, runner8.mon['E.spike'])

Parallel Simulation for Parameter Exploration#
Parameter exploration and selection is an essential part of brain dynamics modeling. In general, parameter exploration raises two problems:
how to run multiple models concurrently?
how to manage device memory so that multiple models can run concurrently?
First, most BrainPy models support multiple kinds of parallelization, including multi-threading and multi-processing on a single machine, and parallelization across multiple devices. Below, we illustrate these parallelization APIs one by one.
Second, every call of a BrainPy model consumes a fraction of device memory. Therefore, BrainPy provides an API, brainpy.math.clear_buffer_memory(), to clean up memory.
In the following, we illustrate how to combine them to get an efficient parameter exploration for your models.
import brainpy as bp
import brainpy.math as bm
import numpy as np
# bm.set_platform('cpu')
bp.__version__
'2.3.0'
Parallelization across different CPU processors#
Parallelization across multiple CPU processors can be achieved with a single function call, brainpy.running.cpu_ordered_parallel(). The following pseudocode demonstrates the usage of this API.
import brainpy as bp
# define your function
def run_model(par):
    model = YourModel(par)
    runner = bp.DSRunner(model)
    runner.run(duration)
    return runner.mon
# define all parameter values to explore
all_params = [...]
# run models in Jupyter
results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)
# run models in a Python file
if __name__ == '__main__':
    results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)
We will use a simple HH neuron model as an example to show this kind of parallelization. In this example, we use multi-processing to test ten different current values as input.
First, define your running function with well-defined input and output data.
def hh_spike_num(bg_current):  # "input" is the bg_current
    import brainpy as bp  # packages need to be re-imported when
                          # running the function in Jupyter
    model = bp.neurons.HH(1)
    runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
    runner.run(1000.)
    return runner.mon['spike'].sum()  # "output" is the spike number
Then, define all your parameter spaces.
current = bm.linspace(1, 10.1, 10) # here only one parameter
Finally, run your model concurrently with the parallelization syntax.
r = bp.running.cpu_ordered_parallel(hh_spike_num, [current], num_process=10)
r
[0, 0, 1, 48, 53, 0, 54, 63, 66, 68]
However, the above usage accumulates buffer memory on the running device. If a single model occupies too much memory, an out-of-memory error will be raised during the parameter exploration.
A simple way to solve this issue is to clear all buffers after each run of the function, i.e., to call brainpy.math.clear_buffer_memory() before returning your results.
def hh_spike_num2(bg_current):  # "input" is the bg_current
    import brainpy as bp  # packages need to be re-imported when
                          # running the function in Jupyter
    bg_current = bp.math.as_jax(bg_current)
    model = bp.neurons.HH(1)
    runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
    runner.run(1000.)
    bp.math.clear_buffer_memory()
    return runner.mon['spike'].sum()  # "output" is the spike number
Note that clear_buffer_memory() clears all JAX arrays on the device, so it is better to give inputs as NumPy arrays and to return outputs as NumPy arrays.
current = np.linspace(1., 10., 10)
r = bp.running.cpu_ordered_parallel(hh_spike_num2, [current], num_process=10)
r
[0, 0, 1, 0, 0, 57, 60, 58, 65, 68]
If the order of the results does not matter, you can also use the cpu_unordered_parallel() function. This can maximize the efficiency of all processors, since all workers run in a non-blocking, unordered manner.
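The difference between ordered and unordered execution can be illustrated with the standard library. The sketch below uses threads from concurrent.futures only so that it stays portable; BrainPy's helpers use multiple processes, and fake_run is a hypothetical stand-in for a model run whose cost depends on the parameter.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_run(par):
    # Stand-in for a model run: smaller parameters take longer here.
    time.sleep(0.01 * (5 - par))
    return par * 10

params = [1, 2, 3, 4]
with ThreadPoolExecutor(max_workers=4) as pool:
    # Ordered: results come back in the order of `params`, whatever the finishing order.
    ordered = list(pool.map(fake_run, params))
    # Unordered: results are collected as soon as each worker finishes.
    futures = [pool.submit(fake_run, p) for p in params]
    unordered = [f.result() for f in as_completed(futures)]

print(ordered)            # [10, 20, 30, 40]
print(sorted(unordered))  # [10, 20, 30, 40]
```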
Parallelization with jax.vmap#
The second approach to multi-threaded parallelization is JAX's vectorizing map, jax.vmap. jax.vmap vectorizes a function by compiling the mapped axis into primitive operations. It avoids recompiling the model for each element of the batch and automatically parallelizes the model run on the given machine. The following pseudocode demonstrates how simple this approach is.
from jax import vmap
def run_model(par):
    model = YourModel(par)
    runner = bp.DSRunner(model)
    runner.run(duration)
    return runner.mon
# define all parameter values to explore
all_params = [...]
# batch simulation through jax.vmap
r = vmap(run_model)(*all_params)
Note that if you have too many parameters to search, jax.vmap will consume too much memory. In this case, you can use our wrapped API brainpy.running.jax_vectorize_map(), which controls the running batch size through the num_parallel parameter. Set a smaller value of num_parallel when your device memory is not enough (on either CPU or GPU devices).
def hh_spike_num3(bg_current):  # "input" is the bg_current
    model = bp.neurons.HH(1)
    runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current],
                         numpy_mon_after_run=False)
    runner.run(1000.)
    return runner.mon['spike'].sum()  # "output" is the spike number
current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3)
r
Array([ 0, 0, 0, 0, 0, 45, 60, 63, 66, 68], dtype=int32)
The function passed to jax_vectorize_map() cannot call clear_buffer_memory(); otherwise, errors will be raised. Instead, users can set clear_buffer=True/False in jax_vectorize_map(). With this usage, all inputs and outputs are automatically transformed into NumPy arrays.
current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3, clear_buffer=True)
r
array([ 0, 1, 1, 0, 0, 57, 60, 63, 66, 68])
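The idea behind num_parallel can be sketched without JAX: a vectorized function evaluates a whole batch of parameters at once (what jax.vmap provides automatically), and a chunked map feeds it at most num_parallel values at a time to bound memory. Below, spike_count_batched is a hypothetical toy LIF simulation in NumPy, and chunked_map is an illustrative helper, not BrainPy's implementation.

```python
import numpy as np

def spike_count_batched(bg_current, duration=200., dt=0.1, V_rest=-60., V_th=-50., tau=20.):
    """Toy LIF neurons, one per parameter value, simulated in a single batched array."""
    V = np.full(bg_current.shape, V_rest)
    count = np.zeros(bg_current.shape, dtype=int)
    for _ in range(int(round(duration / dt))):
        V = V + dt * (V_rest - V + bg_current) / tau
        spike = V >= V_th
        V = np.where(spike, V_rest, V)
        count += spike
    return count

def chunked_map(fun, params, num_parallel):
    """Evaluate `fun` on at most `num_parallel` parameter values at a time."""
    outs = [fun(params[i:i + num_parallel]) for i in range(0, len(params), num_parallel)]
    return np.concatenate(outs)

current = np.linspace(1., 30., 10)
r_all = spike_count_batched(current)                       # one batch of 10
r_chunked = chunked_map(spike_count_batched, current, 3)   # batches of 3, 3, 3, 1
print(np.array_equal(r_all, r_chunked))   # True
```

Chunking trades some throughput for a bounded peak memory, which is the same trade-off num_parallel controls.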
Parallelization across multiple devices#
BrainPy supports parallel running on multiple devices (e.g., multiple GPU devices or TPU cores) or HPC systems (e.g., supercomputers). Unlike the thread-based and processor-based parallelization methods above, in which the same model runs in parallel on the same device, device-based parallelization runs the same model in parallel on multiple devices.
One way to express the multi-device parallelization of BrainPy models is the jax.pmap instruction. JAX provides jax.pmap to express SIMD programs. It offers an interface to run the same model on multiple devices with different parameter values, and its usage is analogous to jax.vmap. The following pseudocode presents an example of running BrainPy models on multiple devices.
from jax import pmap
def run_model(par):
    model = YourModel(par)
    runner = bp.DSRunner(model)
    runner.run(<int>)
    return runner.mon
# define all parameter values to explore
all_params = [...]
# parallel simulation through jax.pmap
r = pmap(run_model)(*all_params)
jax.pmap has a similar issue to jax.vmap when you parallelize across many parameters. In this case, you can use the wrapped function brainpy.running.jax_parallelize_map().
If you are using pmap on a CPU device, you can set the number of virtual devices by calling brainpy.math.set_host_device_count(n). Then you can safely call jax_parallelize_map() on your CPU platform.
bp.math.set_host_device_count(10)  # this should be placed at the top of the file
current = bm.linspace(1., 10.1, 20)
r = bp.running.jax_parallelize_map(hh_spike_num3, [current], num_parallel=10, clear_buffer=True)
r
array([ 0, 0, 0, 0, 0, 0, 0, 49, 52, 54, 56, 58, 59, 61, 62, 63, 65,
66, 67, 68])
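For reference, the usual JAX mechanism for creating several virtual CPU devices is the XLA flag --xla_force_host_platform_device_count, which must be set before JAX initializes its backends. We assume (without checking BrainPy's source here) that set_host_device_count does something along these lines; the function below is a hypothetical sketch that only edits the environment variable.

```python
import os
import re

def set_host_device_count_sketch(n):
    """Hypothetical helper: request `n` virtual CPU devices via XLA_FLAGS."""
    flags = os.environ.get('XLA_FLAGS', '')
    # drop any earlier setting of the flag, then prepend the new one
    flags = re.sub(r'--xla_force_host_platform_device_count=\S+', '', flags).split()
    os.environ['XLA_FLAGS'] = ' '.join([f'--xla_force_host_platform_device_count={n}'] + flags)

set_host_device_count_sketch(4)
set_host_device_count_sketch(10)   # replaces the earlier setting instead of appending
print(os.environ['XLA_FLAGS'])
# This only has an effect if it runs before JAX is first imported in the process.
```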
BrainPy also works well with job scheduling systems such as SLURM in supercomputing centers. Therefore, another way to express multi-device parallelization is to employ a classical resource management system. The following script is an example batch script submitted to SLURM.
#!/bin/bash
#SBATCH -J <name>
#SBATCH -o <file name>
#SBATCH -p <str>
#SBATCH -n <int>
#SBATCH -N <int>
#SBATCH -c <int>
python your_script.py
Model Training#
This tutorial shows how to train a dynamical system from data or task.
Building Training Models#
In this section, we are going to talk about how to build models for training.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.3.1'
Use built-in models#
The built-in brainpy.dyn.DynamicalSystem models provided in BrainPy can be used for model training.
mode settings#
Some built-in models implement the training interface. Users can instantiate these models for training by providing the parameter mode=brainpy.modes.training.
For example, brainpy.neurons.LIF is a model commonly used in computational simulation, but it can also be used in training.
# Instantiate a LIF model for simulation
lif = bp.neurons.LIF(1)
lif.mode
NonBatchingMode
# Instantiate a LIF model for training.
# In this mode, the model implements variables and functions
# compatible with BrainPy's training interface.
lif = bp.neurons.LIF(1, mode=bp.modes.training)
lif.mode
TrainingMode
However, some built-in models do not support training.
try:
    bp.layers.NVAR(1, 1, mode=bp.modes.training)
except Exception as e:
    print(type(e), e)
<class 'NotImplementedError'> NVAR does not support TrainingMode. We only support BatchingMode, NonBatchingMode.
The mode can also be used to control the weight types. Let’s take a dense layer as another example. For a non-trainable dense layer, the weights and bias are Array instances, whereas in training mode the weights become TrainVar instances.
l = bp.layers.Dense(3, 4, mode=bm.batching_mode)
l.W
Array([[ 0.13143115, 0.7037631 , -0.50639415, -0.49906093],
[-0.39095506, 0.5210247 , 0.6293488 , 0.7321653 ],
[ 0.2841127 , 0.3818757 , -0.19256772, -0.6708007 ]], dtype=float32)
l = bp.layers.Dense(3, 4, mode=bm.training_mode)
l.W
TrainVar([[-0.36896244, 0.6050412 , -0.53849053, 0.03913487],
[-0.78182685, -0.7611104 , -0.00870763, -0.06463569],
[-0.2160572 , 0.5157468 , 0.09730986, 0.16213563]], dtype=float32)
Moreover, for some recurrent models, e.g., LSTM or GRU, the state can be set to be trainable or not through the train_state argument. When train_state=True is set for the recurrent instance, a new attribute .state2train is created.
rnn = bp.layers.RNNCell(1, 3, train_state=True)
rnn.state2train
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [7], in <cell line: 3>()
1 rnn = bp.layers.RNNCell(1, 3, train_state=True)
----> 3 rnn.state2train
AttributeError: 'RNNCell' object has no attribute 'state2train'
Note the difference between .state2train and the original .state:
.state2train has no batch axis.
When calling node.reset_state(), all values in .state are filled with .state2train.
rnn.reset_state(batch_size=5)
rnn.state
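In array terms, the relationship described above amounts to broadcasting a state that has no batch axis over the batch dimension. A minimal NumPy sketch, assuming .state2train has shape (num_out,) and that resetting copies it into every row of .state (the values are made up for illustration):

```python
import numpy as np

num_out, batch_size = 3, 5
state2train = np.array([0.1, -0.2, 0.3])   # trainable initial state, no batch axis

# reset: fill every row of the batched state with the trainable initial state
state = np.broadcast_to(state2train, (batch_size, num_out)).copy()
print(state.shape)   # (5, 3)
```

The .copy() matters: np.broadcast_to returns a read-only view, while the batched state must be writable as it evolves.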
Naming a node#
For convenience, you can name a layer by specifying the name keyword argument:
bp.layers.Dense(128, 100, name='hidden_layer')
Initializing parameters#
Many models have parameters. The parameters of a model can be set in the following ways.
Arrays
If an array is provided, this is used unchanged as the parameter variable. For example:
l = bp.layers.Dense(10, 50, W_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.W.shape
Callable function
If a callable function (which receives a shape argument) is provided, the callable is called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:
def init(shape):
    return bm.random.random(shape)
l = bp.layers.Dense(20, 30, W_initializer=init)
l.W.shape
Instance of brainpy.init.Initializer
If a brainpy.init.Initializer instance is provided, the initial parameter values are generated with the desired shape by the Initializer instance. For example:
l = bp.layers.Dense(20, 30, W_initializer=bp.init.Normal(0.01))
l.W.shape
The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.init for more information).
None parameter
Some types of parameter variables (e.g., biases) can also be set to None at initialization. In that case, the parameter variable is omitted. For example, a dense layer without biases is created as follows:
l = bp.layers.Dense(20, 100, b_initializer=None)
print(l.b)
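The four cases above can be summarized as a small dispatch function. The sketch below is hypothetical (init_param_sketch and NormalSketch are not BrainPy APIs) and only illustrates the rules: None omits the parameter, an array is used unchanged, and a callable or initializer-like object is called with the desired shape.

```python
import numpy as np

def init_param_sketch(param, shape):
    """Resolve a parameter spec following the four cases described above (illustrative only)."""
    if param is None:                  # omitted parameter, e.g. no bias
        return None
    if isinstance(param, np.ndarray):  # used unchanged; the shape must match
        if param.shape != tuple(shape):
            raise ValueError(f'expected shape {tuple(shape)}, got {param.shape}')
        return param
    if callable(param):                # callable or initializer-like object
        return param(shape)
    raise ValueError(f'unsupported parameter spec: {param!r}')

class NormalSketch:
    """Hypothetical initializer-like class: zero-mean normal samples with std `scale`."""
    def __init__(self, scale=0.01, seed=0):
        self.scale = scale
        self.rng = np.random.default_rng(seed)
    def __call__(self, shape):
        return self.rng.normal(0., self.scale, size=shape)

print(init_param_sketch(np.zeros((10, 50)), (10, 50)).shape)   # (10, 50)
print(init_param_sketch(lambda s: np.ones(s), (20, 30)).shape)  # (20, 30)
print(init_param_sketch(NormalSketch(0.01), (20, 30)).shape)    # (20, 30)
print(init_param_sketch(None, (100,)))                           # None
```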
Customize your models#
Customizing your training models is simple. You just need to subclass brainpy.DynamicalSystem and implement its update() and reset_state() functions.
Here, we demonstrate the model customization using two examples. The first is a recurrent layer.
class RecurrentLayer(bp.DynamicalSystem):
    def __init__(self, num_in, num_out):
        super(RecurrentLayer, self).__init__()
        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))
        # define parameters
        self.num_in = num_in
        self.num_out = num_out
        # define variables (the shape must be passed as a tuple)
        self.state = bm.Variable(bm.zeros((1, num_out)), batch_axis=0)
        # define weights
        self.win = bm.TrainVar(bm.random.normal(0., 1./num_in ** 0.5, size=(num_in, num_out)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1./num_out ** 0.5, size=(num_out, num_out)))
    def reset_state(self, batch_size):
        # this function defines how to reset the model states
        self.state.value = bm.zeros((batch_size, self.num_out))
    def update(self, sha, x):
        # this function defines how the model updates its state and produces its output
        out = bm.dot(x, self.win) + bm.dot(self.state, self.wrec)
        self.state.value = bm.tanh(out)
        return self.state.value
This simple example illustrates many features essential for a training model. The reset_state() function defines how to reset the model states and is called at the first time step; the update() function defines how the model states evolve and is called at every time step.
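The same recurrence, h_t = tanh(x_t W_in + h_{t-1} W_rec), can be written in plain NumPy to show when the two functions are called. RecurrentLayerSketch is a hypothetical mirror of RecurrentLayer above, with random weights and a time-major input of shape (num_time, batch, num_in) chosen for illustration.

```python
import numpy as np

class RecurrentLayerSketch:
    """NumPy mirror of RecurrentLayer: h_t = tanh(x_t @ W_in + h_{t-1} @ W_rec)."""
    def __init__(self, num_in, num_out, seed=0):
        rng = np.random.default_rng(seed)
        self.num_out = num_out
        self.win = rng.normal(0., 1. / num_in ** 0.5, size=(num_in, num_out))
        self.wrec = rng.normal(0., 1. / num_out ** 0.5, size=(num_out, num_out))
        self.state = np.zeros((1, num_out))

    def reset_state(self, batch_size):
        self.state = np.zeros((batch_size, self.num_out))

    def update(self, x):
        self.state = np.tanh(x @ self.win + self.state @ self.wrec)
        return self.state

layer = RecurrentLayerSketch(num_in=2, num_out=4)
xs = np.random.default_rng(1).normal(size=(10, 8, 2))   # (num_time, batch, num_in)
layer.reset_state(batch_size=8)                          # called once, before the first step
outs = np.stack([layer.update(x) for x in xs])           # update() called at every step
print(outs.shape)   # (10, 8, 4)
```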
Another example is the dropout layer, which can be useful to demonstrate how to define a model with multiple behaviours.
class Dropout(bp.dyn.DynamicalSystem):
    def __init__(self, prob: float, seed: int = None, name: str = None):
        super(Dropout, self).__init__(name=name)
        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode))
        self.prob = prob
        self.rng = bm.random.RandomState(seed=seed)
    def update(self, sha, x):
        if sha.get('fit', True):
            keep_mask = self.rng.bernoulli(self.prob, x.shape)
            return bm.where(keep_mask, x / self.prob, 0.)
        else:
            return x
Here, the model produces different outputs depending on the value of the shared parameter fit.
You can define your own shared parameters and then provide them when calling the trainer objects (see the following section).
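The training branch of Dropout is inverted dropout: each element is kept with probability prob and rescaled by 1/prob, so the expected output equals the input. A NumPy sketch of both branches follows, where dropout_sketch is a hypothetical helper and a boolean fit argument stands in for the shared parameter.

```python
import numpy as np

def dropout_sketch(x, prob, fit, rng):
    """Inverted dropout as in Dropout.update() above; `prob` is the keep probability."""
    if not fit:
        return x                              # evaluation phase: identity
    keep_mask = rng.random(x.shape) < prob    # keep each element with probability `prob`
    return np.where(keep_mask, x / prob, 0.)  # rescale kept elements to preserve the mean

rng = np.random.default_rng(0)
x = np.ones(100_000)
y_train = dropout_sketch(x, prob=0.8, fit=True, rng=rng)
y_eval = dropout_sketch(x, prob=0.8, fit=False, rng=rng)
print(y_train.mean(), y_eval.mean())   # close to 1.0, and exactly 1.0
```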
Examples of training models#
In the following, we illustrate several examples to build a trainable neural network model.
Artificial neural networks#
BrainPy provides neural network layers which can be useful to define artificial neural networks.
Here, let’s define a deep RNN model.
class DeepRNN(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_recs, num_out):
        super(DeepRNN, self).__init__()
        # num_recs: hidden sizes of the four stacked LSTM layers
        self.l1 = bp.layers.LSTM(num_in, num_recs[0])
        self.d1 = bp.layers.Dropout(0.2)
        self.l2 = bp.layers.LSTM(num_recs[0], num_recs[1])
        self.d2 = bp.layers.Dropout(0.2)
        self.l3 = bp.layers.LSTM(num_recs[1], num_recs[2])
        self.d3 = bp.layers.Dropout(0.2)
        self.l4 = bp.layers.LSTM(num_recs[2], num_recs[3])
        self.d4 = bp.layers.Dropout(0.2)
        self.lout = bp.layers.Dense(num_recs[3], num_out)
    def update(self, sha, x):
        x = self.d1(sha, self.l1(sha, x))
        x = self.d2(sha, self.l2(sha, x))
        x = self.d3(sha, self.l3(sha, x))
        x = self.d4(sha, self.l4(sha, x))
        return self.lout(sha, x)
with bm.training_environment():
    model = DeepRNN(100, [200, 200, 200, 200], 10)
Note that, unlike model building in PyTorch, the first argument of the update() function should be the shared parameters sha (i.e., parameters shared across all models, like the time t, the running index i, and the model running phase fit). The other individual arguments can all be customized by users. The details of the model definition specification can be seen in ????
Moreover, it is worth noting that this model only defines the one-step updating rule of how the model evolves according to the input x.
Reservoir computing models#
In this example, we define a reservoir computing model called next-generation reservoir computing (NG-RC) using the built-in models provided in BrainPy.
class NGRC(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_out):
        super(NGRC, self).__init__(mode=bm.batching_mode)
        self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5, mode=bm.batching_mode)
        self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
In the above model, brainpy.layers.NVAR is a nonlinear vector autoregression machine, which does not support training. Therefore, we set its mode to batching mode. In contrast, brainpy.layers.Dense has trainable weights for model training.
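To give a feel for the features that NVAR produces, here is a NumPy sketch of a generic nonlinear vector autoregression. It assumes that delay counts the number of past samples, spaced stride steps apart, that form the linear part, and that the nonlinear part contains all unique quadratic monomials of the linear part. This follows the general NG-RC recipe; BrainPy's NVAR may order the features differently or add a constant term, which is why the model above reads the size from self.r.num_out rather than hard-coding it.

```python
import numpy as np
from itertools import combinations_with_replacement

def nvar_features_sketch(x_hist, delay, stride, order=2):
    """Linear part: `delay` samples spaced `stride` steps apart (most recent first);
    nonlinear part: all unique degree-`order` monomials of the linear part."""
    # x_hist: (num_time, num_in), most recent sample last
    taps = [x_hist[-1 - k * stride] for k in range(delay)]
    linear = np.concatenate(taps)
    nonlinear = np.array([np.prod(linear[list(idx)])
                          for idx in combinations_with_replacement(range(linear.size), order)])
    return np.concatenate([linear, nonlinear])

num_in, delay, stride = 3, 4, 5
x_hist = np.random.default_rng(0).normal(size=(1 + (delay - 1) * stride, num_in))
feat = nvar_features_sketch(x_hist, delay, stride)
print(feat.shape)   # (90,): 12 linear + 78 quadratic terms
```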
Spiking Neural Networks#
Building trainable spiking neural networks in BrainPy is also a piece of cake. BrainPy provides commonly used spiking models for traditional dynamics simulation, and most of them can be used for training too.
In the following, we provide an implementation of spiking neural networks in (Neftci, Mostafa, & Zenke, 2019) for surrogate gradient learning.
class SNN(bp.dyn.Network):
    def __init__(self, num_in, num_rec, num_out):
        super(SNN, self).__init__()
        # neuron groups
        self.i = bp.neurons.InputGroup(num_in)
        self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
        self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)
        # synapse: i->r
        self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                           output=bp.synouts.CUBA(), tau=10.,
                                           g_max=bp.init.KaimingNormal(scale=20.))
        # synapse: r->o
        self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                           output=bp.synouts.CUBA(), tau=10.,
                                           g_max=bp.init.KaimingNormal(scale=20.))
    def update(self, tdi, spike):
        self.i2r(tdi, spike)
        self.r2o(tdi)
        self.r(tdi)
        self.o(tdi)
        return self.o.V.value
with bm.training_environment():
    snn = SNN(10, 100, 2)
Note that here the mode of all models is set to brainpy.modes.TrainingMode through the bm.training_environment() context.
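Surrogate gradient learning replaces the derivative of the non-differentiable spike function with a smooth surrogate during back-propagation. Below is a NumPy sketch of the two pieces. It uses the derivative of a fast sigmoid as the surrogate, which is one common choice; the surrogate used by BrainPy's LIF in training mode may differ, and beta is a hypothetical sharpness parameter.

```python
import numpy as np

def spike_fn(v, v_th=1.):
    # Forward pass: Heaviside step, whose true derivative is zero almost everywhere.
    return (v >= v_th).astype(float)

def spike_surrogate_grad(v, v_th=1., beta=10.):
    # Backward pass: derivative of a fast sigmoid, used in place of the step's derivative.
    return 1. / (1. + beta * np.abs(v - v_th)) ** 2

v = np.array([0.0, 0.9, 1.0, 1.1, 2.0])
print(spike_fn(v))              # [0. 0. 1. 1. 1.]
print(spike_surrogate_grad(v))  # peaks at v == v_th and decays away from it
```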
Training with Offline Algorithms#
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import matplotlib.pyplot as plt
bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
BrainPy provides many offline training algorithms that can help users train models such as reservoir computing models.
Train a reservoir model#
Here, we train an echo-state machine to predict chaotic dynamics. This example illustrates how to use brainpy.train.OfflineTrainer.
We first get the training dataset.
def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    # Training data must have batch size, here the batch is 1
    return res.reshape((1, ) + res.shape)
dt = 0.01
t_warmup, t_train, t_test = 5., 100., 50. # ms
num_warmup, num_train, num_test = int(t_warmup/dt), int(t_train/dt), int(t_test/dt)
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
dt=dt,
inits={'x': 17.67715816276679,
'y': 12.931379185960404,
'z': 43.91404334248268})
X_warmup = get_subset(lorenz_series, 0, num_warmup - 5)
X_train = get_subset(lorenz_series, num_warmup - 5, num_warmup + num_train - 5)
X_test = get_subset(lorenz_series,
num_warmup + num_train - 5,
num_warmup + num_train + num_test - 5)
# our target data is the activity 5 time steps ahead
Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
Y_test = get_subset(lorenz_series,
num_warmup + num_train,
num_warmup + num_train + num_test)
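The index arithmetic above can be checked directly: each input window starts five steps before its target window, so the model learns to predict five steps ahead. A plain-Python sketch with the same constants (int(round(...)) is used here as a guard against floating-point truncation):

```python
dt = 0.01
num_warmup, num_train, num_test = (int(round(t / dt)) for t in (5., 100., 50.))
shift = 5   # predict 5 time steps ahead

# (start, end) of the slices used above for inputs X and targets Y
x_train = (num_warmup - shift, num_warmup + num_train - shift)
y_train = (num_warmup, num_warmup + num_train)
x_test = (num_warmup + num_train - shift, num_warmup + num_train + num_test - shift)
y_test = (num_warmup + num_train, num_warmup + num_train + num_test)

print(num_warmup, num_train, num_test)                 # 500 10000 5000
print(x_train, y_train)                                # (495, 10495) (500, 10500)
print(y_train[0] - x_train[0], y_test[0] - x_test[0])  # 5 5
```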
Then, we try to build an echo-state machine to predict the chaotic dynamics ahead of five time steps.
class ESN(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_hidden, num_out):
        super(ESN, self).__init__()
        self.r = bp.layers.Reservoir(num_in, num_hidden,
                                     Win_initializer=bp.init.Uniform(-0.1, 0.1),
                                     Wrec_initializer=bp.init.Normal(scale=0.1),
                                     in_connectivity=0.02,
                                     rec_connectivity=0.02,
                                     comp_type='dense')
        self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal(),
                                 mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
model = ESN(3, 100, 3)
Here, we use ridge regression as the training algorithm to train the chaotic model.
trainer = bp.train.OfflineTrainer(model, fit_method=bp.algorithms.RidgeRegression(1e-7), dt=dt)
# first warmup the reservoir
_ = trainer.predict(X_warmup)
# then fit the reservoir model
_ = trainer.fit([X_train, Y_train])
def plot_lorenz(ground_truth, predictions):
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(121, projection='3d')
    ax.set_title("Generated attractor")
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.set_zlabel("$z$")
    ax.grid(False)
    ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])
    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title("Real attractor")
    ax2.grid(False)
    ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
    plt.show()
# finally, predict the model with the test data
outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS: 0.041717741900418666

Switch different training algorithms#
brainpy.train.OfflineTrainer makes it easy to switch between training algorithms. You just need to provide the fit_method argument when instantiating an offline trainer.
Many offline algorithms, such as linear regression, ridge regression, and Lasso regression, are provided as built-in algorithms.
model = ESN(3, 100, 3)
model.reset_state(1)
trainer = bp.train.OfflineTrainer(model, fit_method=bp.algorithms.LinearRegression())
_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())

Customize your training algorithms#
brainpy.train.OfflineTrainer also supports training models with your own customized algorithms.
Specifically, a customized offline algorithm should follow the interface of brainpy.algorithms.OfflineAlgorithm, in which users specify how the model parameters are calculated from the input, prediction, and target data.
For instance, here we use the Lasso model provided in the scikit-learn package to define an offline training algorithm.
from sklearn.linear_model import Lasso
class LassoAlgorithm(bp.algorithms.OfflineAlgorithm):
    def __init__(self, alpha=1., max_iter=int(1e4)):
        super(LassoAlgorithm, self).__init__()
        self.model = Lasso(alpha=alpha, max_iter=max_iter)
    def __call__(self, identifier, y, x, outs=None):
        # scikit-learn works on NumPy arrays; use the first (and only) batch
        x = bm.as_numpy(x[0])
        y = bm.as_numpy(y[0])
        x_new = self.model.fit(x, y).coef_.T
        return bm.expand_dims(bm.asarray(x_new), 1)
model = ESN(3, 100, 3)
model.reset_state(1)
# note that scikit-learn algorithms do not support JAX jit,
# therefore the "jit" of the "fit" phase is set to False.
trainer = bp.train.OfflineTrainer(model, fit_method=LassoAlgorithm(),
                                  jit={'fit': False})
_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())

Training with Online Algorithms#
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
import brainpy_datasets as bd
bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
Online training algorithms, such as FORCE learning, have played vital roles in brain modeling. BrainPy provides brainpy.train.OnlineTrainer for model training with online algorithms.
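To give a feel for what an online algorithm does, here is a minimal NumPy sketch of recursive least squares (RLS), the algorithm family behind bp.algorithms.RLS used later in this section and at the core of FORCE learning. Instead of solving for the readout in one shot, RLS corrects it after every sample. The function name rls_fit and its arguments are hypothetical, and BrainPy's own implementation may differ in details such as initialization.

```python
import numpy as np


def rls_fit(H, Y, alpha=1e-2):
    """Recursive least squares (RLS) for a linear readout y = W^T h.

    H: (T, n_features) inputs, presented one row at a time.
    Y: (T, n_out) targets.
    alpha: P starts as I / alpha (smaller alpha = weaker regularization).
    """
    n_features = H.shape[1]
    W = np.zeros((n_features, Y.shape[1]))
    P = np.eye(n_features) / alpha       # running estimate of the inverse correlation matrix
    for h, y in zip(H, Y):
        h = h[:, None]                   # column vector, shape (n_features, 1)
        Ph = P @ h
        k = Ph / (1.0 + h.T @ Ph)        # gain vector, shape (n_features, 1)
        err = (W.T @ h).ravel() - y      # error of the current weights on this sample
        W -= k @ err[None, :]            # move the weights against the error
        P -= k @ Ph.T                    # rank-1 update of P
    return W


rng = np.random.default_rng(0)
H = rng.normal(size=(200, 3))
W_true = rng.normal(size=(3, 2))
W = rls_fit(H, H @ W_true)
print(np.round(W - W_true, 4))
```

Because no forgetting factor is used, after the last sample the weights coincide with the ridge-regression solution whose regularization strength equals alpha; the online version simply reaches it incrementally.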
Train a reservoir model#
Here, we are going to use brainpy.train.OnlineTrainer to train a next generation reservoir computing (NGRC) model to predict chaotic dynamics.
We first get the training dataset.
def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    # Training data must have a batch dimension; here the batch size is 1
    return res.reshape((1, ) + res.shape)
dt = 0.01
t_warmup, t_train, t_test = 5., 100., 50. # ms
num_warmup, num_train, num_test = int(t_warmup/dt), int(t_train/dt), int(t_test/dt)
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
dt=dt,
inits={'x': 17.67715816276679,
'y': 12.931379185960404,
'z': 43.91404334248268})
X_warmup = get_subset(lorenz_series, 0, num_warmup - 5)
X_train = get_subset(lorenz_series, num_warmup - 5, num_warmup + num_train - 5)
X_test = get_subset(lorenz_series,
num_warmup + num_train - 5,
num_warmup + num_train + num_test - 5)
# our target data is the activity 5 time steps ahead
Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
Y_test = get_subset(lorenz_series,
num_warmup + num_train,
num_warmup + num_train + num_test)
Then, we build an NGRC model to predict the chaotic dynamics five time steps ahead.
class NGRC(bp.dyn.DynamicalSystem):
    def __init__(self, num_in):
        super(NGRC, self).__init__()
        # nonlinear vector autoregression: delayed inputs plus their polynomial terms
        self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True)
        self.o = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None, mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
model = NGRC(3)
model.reset_state(1)
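To show what kind of features an NVAR layer hands to the linear readout, here is a NumPy sketch of the feature vector commonly used in next generation reservoir computing: an optional constant, the current and delayed inputs (the linear part), and all unique degree-2 monomials of the linear part. The helper nvar_features is hypothetical, and we have not checked that its feature ordering matches bp.layers.NVAR. For num_in=3, delay=2, order=2, constant=True it yields 1 + 6 + 21 = 28 features, which should correspond to self.r.num_out above, assuming BrainPy counts features the same way.

```python
import itertools

import numpy as np


def nvar_features(series, delay=2, order=2, constant=True):
    """NVAR-style feature vectors for a (T, num_in) time series.

    For each t >= delay - 1 the row is
    [1 (optional), linear part, unique degree-`order` monomials of the linear part],
    where the linear part stacks x(t), x(t-1), ..., x(t-delay+1).
    """
    T, num_in = series.shape
    rows = []
    for t in range(delay - 1, T):
        lin = np.concatenate([series[t - k] for k in range(delay)])
        # combinations_with_replacement yields x1*x1, x1*x2, ... without duplicates such as x2*x1
        nonlin = [np.prod(c) for c in itertools.combinations_with_replacement(lin, order)]
        rows.append(([1.0] if constant else []) + list(lin) + nonlin)
    return np.asarray(rows)


series = np.random.default_rng(0).normal(size=(10, 3))
features = nvar_features(series)
print(features.shape)  # (9, 28): 1 constant + 6 linear + 21 quadratic terms
```

The readout stays linear in these features, which is why an online least-squares method such as RLS is enough to train it.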
Here, we use recursive least squares (RLS) as the online training algorithm to train the chaotic model.
trainer = bp.train.OnlineTrainer(model, fit_method=bp.algorithms.RLS(), dt=dt)
# first warmup the reservoir
_ = trainer.predict(X_warmup)
# then fit the reservoir model
_ = trainer.fit([X_train, Y_train])
def plot_lorenz(ground_truth, predictions):
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(121, projection='3d')
    ax.set_title("Generated attractor")
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.set_zlabel("$z$")
    ax.grid(False)
    ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])
    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title("Real attractor")
    ax2.grid(False)
    ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
    plt.show()
# finally, predict the model with the test data
outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS: 0.0007826608198954596

Training with Back-propagation Algorithms#
Back-propagation (BP) training has become a foundation of machine learning. In this section, we are going to talk about how to train models with BP.
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import numpy as np
bm.set_mode(bm.training_mode)
bm.set_platform('cpu')
bp.__version__
'2.3.1'
Here, we train two kinds of models to classify the Fashion-MNIST dataset. The first is an ANN model of the kind commonly used in deep neural networks. The second is an SNN model.
Train an ANN model#
We first build a three-layer ANN model:
i >> r >> o
where the recurrent layer r is an LSTM cell, and the output o is a linear readout.
class ANNModel(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_rec, num_out):
        super(ANNModel, self).__init__()
        self.rec = bp.layers.LSTMCell(num_in, num_rec)
        self.out = bp.layers.Dense(num_rec, num_out)
    def update(self, sha, x):
        x = self.rec(sha, x)
        x = self.out(sha, x)
        return x
Before training this model, we load and preprocess the data.
root = r"D:\data"
train_dataset = bd.vision.FashionMNIST(root, split='train', download=True)
test_dataset = bd.vision.FashionMNIST(root, split='test', download=True)
def get_data(dataset, batch_size=256):
    rng = bm.random.clone_rng()
    def data_generator():
        X = bm.array(dataset.data, dtype=bm.float_) / 255
        Y = bm.array(dataset.targets, dtype=bm.float_)
        # shuffle images and labels with the same key so that pairs stay aligned
        key = rng.split_key()
        rng.shuffle(X, key=key)
        rng.shuffle(Y, key=key)
        for i in range(0, len(dataset), batch_size):
            yield X[i: i + batch_size], Y[i: i + batch_size]
    return data_generator
Then, we train the ANN model defined above with the brainpy.train.BPTT training interface.
# model: each 28x28 image is fed row by row (28 time steps of 28 pixels)
model = ANNModel(28, 100, 10)
# loss function
def loss_fun(predicts, targets):
    # take the maximum output over the time axis
    predicts = bm.max(predicts, axis=1)
    loss = bp.losses.cross_entropy_loss(predicts, targets)
    acc = bm.mean(predicts.argmax(axis=-1) == targets)
    return loss, {'acc': acc}
# optimizer
optimizer = bp.optim.Adam(lr=1e-3)
# trainer
trainer = bp.train.BPTT(model,
loss_fun=loss_fun,
loss_has_aux=True,
optimizer=optimizer)
trainer.fit(train_data=get_data(train_dataset, 256),
test_data=get_data(test_dataset, 512),
num_epoch=10)
Train 0 epoch, use 11.7911 s, loss 0.9686446785926819, acc 0.657629668712616
Test 0 epoch, use 1.1449 s, loss 0.6075307726860046, acc 0.7804630398750305
Train 1 epoch, use 9.1454 s, loss 0.5276323556900024, acc 0.8128490447998047
Test 1 epoch, use 0.2927 s, loss 0.5302323698997498, acc 0.8131089210510254
Train 2 epoch, use 9.0979 s, loss 0.4588756263256073, acc 0.8355662226676941
Test 2 epoch, use 0.3037 s, loss 0.4678855538368225, acc 0.8310604095458984
Train 3 epoch, use 9.1892 s, loss 0.42316296696662903, acc 0.8461214303970337
Test 3 epoch, use 0.2911 s, loss 0.45249971747398376, acc 0.8364487886428833
Train 4 epoch, use 9.0769 s, loss 0.3961907625198364, acc 0.8553800582885742
Test 4 epoch, use 0.2947 s, loss 0.4217829704284668, acc 0.8458294868469238
Train 5 epoch, use 8.9839 s, loss 0.3784363567829132, acc 0.8621509075164795
Test 5 epoch, use 0.3015 s, loss 0.41539546847343445, acc 0.8533375859260559
Train 6 epoch, use 8.9756 s, loss 0.362664133310318, acc 0.8676861524581909
Test 6 epoch, use 0.3095 s, loss 0.3904822766780853, acc 0.8592026829719543
Train 7 epoch, use 9.0243 s, loss 0.34826964139938354, acc 0.8724401593208313
Test 7 epoch, use 0.2876 s, loss 0.3742746412754059, acc 0.8639591336250305
Train 8 epoch, use 9.0806 s, loss 0.3381759822368622, acc 0.8756925463676453
Test 8 epoch, use 0.2951 s, loss 0.3876565992832184, acc 0.8559110760688782
Train 9 epoch, use 9.1019 s, loss 0.32923951745033264, acc 0.8797928094863892
Test 9 epoch, use 0.2833 s, loss 0.3779725432395935, acc 0.8602309226989746
import matplotlib.pyplot as plt
plt.plot(trainer.get_hist_metric('fit'), label='fit')
plt.plot(trainer.get_hist_metric('test'), label='test')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

Train an SNN model#
Similarly, brainpy.train.BPTT can also be used to train SNN models.
We first build a three-layer SNN model:
i >> [exponential synapse] >> r >> [exponential synapse] >> o
class SNNModel(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_rec, num_out):
        super(SNNModel, self).__init__()
        # parameters
        self.num_in = num_in
        self.num_rec = num_rec
        self.num_out = num_out
        # neuron groups
        self.i = bp.neurons.InputGroup(num_in)
        self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
        self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)
        # synapse: i->r
        self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                           output=bp.synouts.CUBA(),
                                           tau=10.,
                                           g_max=bp.init.KaimingNormal(scale=2.))
        # synapse: r->o
        self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                           output=bp.synouts.CUBA(),
                                           tau=10.,
                                           g_max=bp.init.KaimingNormal(scale=2.))
    def update(self, shared, spike):
        self.i2r(shared, spike)
        self.r2o(shared)
        self.r(shared)
        self.o(shared)
        return self.o.V.value
As the model receives spiking inputs, we define helper functions that transform the continuous pixel values into spiking data.
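def current2firing_time(x, tau=20., thr=0.2, tmax=1.0, epsilon=1e-7):
    # remember which inputs are below threshold before clipping them
    idx = x < thr
    x = np.clip(x, thr + epsilon, 1e9)
    T = tau * np.log(x / (x - thr))
    # sub-threshold inputs never fire within the window
    T = np.where(idx, tmax, T)
    return T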
def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True):
    labels_ = np.array(y, dtype=bm.int_)
    sample_index = np.arange(len(X))
    # compute discrete firing times
    tau_eff = 2. / bm.get_dt()
    unit_numbers = np.arange(nb_units)
    firing_times = np.array(current2firing_time(X, tau=tau_eff, tmax=nb_steps), dtype=bm.int_)
    if shuffle:
        np.random.shuffle(sample_index)
    counter = 0
    number_of_batches = len(X) // batch_size
    while counter < number_of_batches:
        batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
        all_batch, all_times, all_units = [], [], []
        for bc, idx in enumerate(batch_index):
            # keep only the units that fire within the simulation window
            c = firing_times[idx] < nb_steps
            times, units = firing_times[idx][c], unit_numbers[c]
            batch = bc * np.ones(len(times), dtype=bm.int_)
            all_batch.append(batch)
            all_times.append(times)
            all_units.append(units)
        all_batch = np.concatenate(all_batch).flatten()
        all_times = np.concatenate(all_times).flatten()
        all_units = np.concatenate(all_units).flatten()
        # dense spike tensor of shape (batch, time, units)
        x_batch = bm.zeros((batch_size, nb_steps, nb_units))
        x_batch[all_batch, all_times, all_units] = 1.
        y_batch = bm.asarray(labels_[batch_index])
        yield x_batch, y_batch
        counter += 1
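One way to read the formula in current2firing_time: \(\tau \log\frac{x}{x-\theta}\) is the time a leaky integrator \(dV/dt = (x - V)/\tau\), driven by a constant input \(x\) and starting from \(V=0\), needs to reach the threshold \(\theta\) (solve \(x(1-e^{-T/\tau}) = \theta\) for \(T\)). Brighter pixels therefore fire earlier (latency coding), and pixels below threshold never fire. The NumPy sketch below checks this numerically; the function names are hypothetical, and the parameter values are simply the defaults of current2firing_time.

```python
import numpy as np


def latency(x, tau=20.0, thr=0.2):
    """Closed-form time for dV/dt = (x - V) / tau, V(0) = 0, to reach thr (needs x > thr)."""
    return tau * np.log(x / (x - thr))


def simulated_latency(x, tau=20.0, thr=0.2, dt=1e-3):
    """The same quantity, measured by Euler-integrating the leaky integrator."""
    v, t = 0.0, 0.0
    while v < thr:
        v += dt * (x - v) / tau
        t += dt
    return t


for x in (0.3, 0.5, 0.9):
    print(x, latency(x), simulated_latency(x))
```

Larger inputs give shorter latencies, so each image is encoded as at most one spike per pixel, with the timing carrying the intensity.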
Now, we can define a BP trainer for this SNN model.
def loss_fun(predicts, targets):
    predicts, mon = predicts
    # L1 loss on total number of spikes
    l1_loss = 1e-5 * bm.sum(mon['r.spike'])
    # L2 loss on spikes per neuron
    l2_loss = 1e-5 * bm.mean(bm.sum(bm.sum(mon['r.spike'], axis=0), axis=0) ** 2)
    # predictions
    predicts = bm.max(predicts, axis=1)
    loss = bp.losses.cross_entropy_loss(predicts, targets)
    acc = bm.mean(predicts.argmax(-1) == targets)
    return loss + l2_loss + l1_loss, {'acc': acc}
model = SNNModel(num_in=28*28, num_rec=100, num_out=10)
trainer = bp.train.BPTT(
model,
loss_fun=loss_fun,
loss_has_aux=True,
optimizer=bp.optim.Adam(lr=1e-3),
monitors={'r.spike': model.r.spike},
)
The training process is similar to that of the ANN model, except that the data are generated by the sparse data generator function defined above.
x_train = bm.array(train_dataset.data, dtype=bm.float_) / 255
y_train = bm.array(train_dataset.targets, dtype=bm.int_)
trainer.fit(lambda: sparse_data_generator(x_train.reshape(x_train.shape[0], -1),
y_train,
batch_size=256,
nb_steps=100,
nb_units=28 * 28),
num_epoch=10)
Train 0 epoch, use 56.6148 s, loss 10.524602890014648, acc 0.3441840410232544
Train 1 epoch, use 48.7201 s, loss 1.947080373764038, acc 0.4961271286010742
Train 2 epoch, use 50.2106 s, loss 1.5027152299880981, acc 0.5980067849159241
Train 3 epoch, use 53.0944 s, loss 1.371555209159851, acc 0.63353031873703
Train 4 epoch, use 54.2528 s, loss 1.294083833694458, acc 0.6476696133613586
Train 5 epoch, use 56.5207 s, loss 1.2385631799697876, acc 0.6586705446243286
Train 6 epoch, use 61.7909 s, loss 1.2144725322723389, acc 0.6649806499481201
Train 7 epoch, use 72.7359 s, loss 1.1915594339370728, acc 0.6712072491645813
Train 8 epoch, use 76.2446 s, loss 1.153993010520935, acc 0.6776843070983887
Train 9 epoch, use 79.4869 s, loss 1.1312021017074585, acc 0.682542085647583
plt.plot(trainer.get_hist_metric('fit'), label='fit')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

Customize your BP training#
Actually, brainpy.train.BPTT is just one way to perform back-propagation training with your model. You can easily customize your training process.
Below, we demonstrate how to define a BP training process by hand for the ANN model above.
# packages we need
from time import time
# define the model
model = ANNModel(28, 100, 10)
# define the loss function
@bm.to_object(child_objs=model)
def loss_fun(inputs, targets):
    # run the model over the whole input sequence, starting from a reset state
    runner = bp.train.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs, reset_state=True)
    predicts = bm.max(predicts, axis=1)
    loss = bp.losses.cross_entropy_loss(predicts, targets)
    acc = bm.mean(predicts.argmax(-1) == targets)
    return loss, acc
# define the gradient function which computes the
# gradients of the trainable weights
grad_fun = bm.grad(loss_fun,
grad_vars=model.train_vars().unique(),
has_aux=True,
return_value=True)
# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())
# training function
@bm.jit
@bm.to_object(child_objs=(opt, grad_fun))
def train(xs, ys):
    grads, loss, acc = grad_fun(xs, ys)
    opt.update(grads)
    return loss, acc
# start training
k = 0
num_batch = 256
running_loss = 0
running_acc = 0
print_step = 100
X_train = bm.asarray(x_train)
Y_train = bm.asarray(y_train)
t0 = time()
for _ in range(10):  # number of epochs
    # permute inputs and labels with the same key so that pairs stay aligned
    X_train = bm.random.permutation(X_train, key=123)
    Y_train = bm.random.permutation(Y_train, key=123)
    for i in range(0, X_train.shape[0], num_batch):
        X = X_train[i: i + num_batch]
        Y = Y_train[i: i + num_batch]
        loss_, acc_ = train(X, Y)
        running_loss += loss_
        running_acc += acc_
        k += 1
        if k % print_step == 0:
            print('Step {}, Used {:.4f} s, Loss {:0.4f}, Acc {:0.4f}'.format(
                k, time() - t0, running_loss / print_step, running_acc / print_step)
            )
            t0 = time()
            running_loss = 0
            running_acc = 0
Step 100, Used 6.7523 s, Loss 1.2503, Acc 0.5630
Step 200, Used 5.3020 s, Loss 0.6340, Acc 0.7779
Step 300, Used 6.5825 s, Loss 0.5545, Acc 0.8056
Step 400, Used 5.3013 s, Loss 0.5028, Acc 0.8198
Step 500, Used 5.3458 s, Loss 0.4659, Acc 0.8340
Step 600, Used 5.3190 s, Loss 0.4601, Acc 0.8316
Step 700, Used 5.2990 s, Loss 0.4297, Acc 0.8443
Step 800, Used 5.3577 s, Loss 0.4244, Acc 0.8456
Step 900, Used 5.3054 s, Loss 0.4053, Acc 0.8538
Step 1000, Used 5.3404 s, Loss 0.3913, Acc 0.8568
Step 1100, Used 5.2744 s, Loss 0.3943, Acc 0.8534
Step 1200, Used 5.4739 s, Loss 0.3863, Acc 0.8592
Step 1300, Used 5.4073 s, Loss 0.3709, Acc 0.8647
Step 1400, Used 5.3310 s, Loss 0.3791, Acc 0.8607
Step 1500, Used 5.3793 s, Loss 0.3644, Acc 0.8643
Step 1600, Used 5.3164 s, Loss 0.3562, Acc 0.8718
Step 1700, Used 5.4404 s, Loss 0.3585, Acc 0.8677
Step 1800, Used 5.4584 s, Loss 0.3533, Acc 0.8716
Step 1900, Used 5.4216 s, Loss 0.3460, Acc 0.8727
Step 2000, Used 5.4207 s, Loss 0.3445, Acc 0.8729
Step 2100, Used 5.3493 s, Loss 0.3375, Acc 0.8749
Step 2200, Used 5.3991 s, Loss 0.3317, Acc 0.8773
Step 2300, Used 5.3003 s, Loss 0.3356, Acc 0.8755
Introduction to Echo State Network#
import brainpy as bp
import brainpy.math as bm
# enable x64 computation
bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')
bp.__version__
'2.3.1'
import brainpy_datasets as bd
bd.__version__
'0.0.0.2'
import matplotlib.pyplot as plt
Echo State Network#
Echo State Networks (ESNs) are applied to supervised temporal machine learning tasks where for a given training input signal \(x(n)\) a desired target output signal \(y^{target}(n)\) is known. Here \(n=1, ..., T\) is the discrete time and \(T\) is the number of data points in the training dataset.
The task is to learn a model with output \(y(n)\) that matches \(y^{target}(n)\) as well as possible, minimizing an error measure \(E(y, y^{target})\), and, more importantly, that generalizes well to unseen data.
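ESNs use an RNN type with leaky-integrated discrete-time continuous-value units. The typical update equations are
\[ \hat{h}(n) = \tanh\left(W^{in} x(n) + W^{rec} h(n-1)\right), \]
\[ h(n) = (1-\alpha)\, h(n-1) + \alpha\, \hat{h}(n), \]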
where \(h(n)\) is a vector of reservoir neuron activations, \(W^{in}\) and \(W^{rec}\) are the input and recurrent weight matrices respectively, and \(\alpha \in (0, 1]\) is the leaking rate. The model is also sometimes used without the leaky integration, which is a special case of \(\alpha=1\).
The linear readout layer is defined as
\[ y(n) = W^{out} h(n) + b^{out}, \]
where \(y(n)\) is network output, \(W^{out}\) the output weight matrix, and \(b^{out}\) is the output bias.
An additional nonlinearity can be applied to \(y(n)\), as well as feedback connections \(W^{fb}\) from \(y(n-1)\) to \(\hat{h}(n)\).
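A minimal NumPy sketch of the leaky ESN state update and the linear readout, assuming the common form \(h(n) = (1-\alpha)h(n-1) + \alpha \tanh(W^{in}x(n) + W^{rec}h(n-1))\) from Lukoševičius' practical guide, with no feedback \(W^{fb}\), no reservoir bias, and randomly drawn weights. The function names and weight scales are hypothetical; this is an illustration, not the implementation of bp.layers.Reservoir.

```python
import numpy as np


def esn_states(x_seq, W_in, W_rec, alpha):
    """Collect reservoir states h(n) for an input sequence x_seq of shape (T, num_in)."""
    h = np.zeros(W_rec.shape[0])
    states = []
    for x in x_seq:
        h_hat = np.tanh(W_in @ x + W_rec @ h)   # candidate activation \hat{h}(n)
        h = (1 - alpha) * h + alpha * h_hat     # leaky integration with rate alpha
        states.append(h)
    return np.stack(states)                     # shape (T, num_hidden)


def esn_readout(states, W_out, b_out):
    """Linear readout y(n) = W^{out} h(n) + b^{out}, applied to every time step."""
    return states @ W_out.T + b_out


rng = np.random.default_rng(0)
num_in, num_hidden, num_out, T = 1, 50, 1, 100
W_in = rng.uniform(-0.5, 0.5, size=(num_hidden, num_in))
W_rec = rng.normal(scale=1.0 / np.sqrt(num_hidden), size=(num_hidden, num_hidden))
x_seq = np.sin(np.linspace(0, 8 * np.pi, T))[:, None]
H = esn_states(x_seq, W_in, W_rec, alpha=0.3)
Y = esn_readout(H, rng.normal(size=(num_out, num_hidden)), np.zeros(num_out))
print(H.shape, Y.shape)  # (100, 50) (100, 1)
```

Only W_out and b_out are trained in an ESN; W_in and W_rec stay fixed after initialization, which is what makes a closed-form readout fit possible.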
A graphical representation of an ESN illustrating our notation and the idea for training is depicted in the following figure.
Ridge regression#
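Finding the optimal weights \(W^{out}\) that minimize the squared error between \(y(n)\) and \(y^{target}(n)\) amounts to solving a typically overdetermined system of linear equations
\[ W^{out} H = Y^{target}, \]
where the columns of \(H\) are the reservoir states \(h(n)\) and the columns of \(Y^{target}\) are the targets \(y^{target}(n)\), \(n = 1, \dots, T\) (the bias \(b^{out}\) can be absorbed into \(W^{out}\) by appending a constant 1 to every \(h(n)\)).
Probably the most universal and stable solution is ridge regression, also known as regression with Tikhonov regularization:
\[ W^{out} = Y^{target} H^{\top} \left(H H^{\top} + \lambda I\right)^{-1}, \]
where \(\lambda\) is the regularization coefficient and \(I\) is the identity matrix.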
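A minimal NumPy sketch of the ridge-regression readout, assuming the closed form \(W^{out} = Y^{target} H^{\top} (H H^{\top} + \lambda I)^{-1}\) with states and targets stored as columns. Solving a linear system instead of forming the inverse explicitly is the usual, numerically safer choice. The function name is hypothetical, and this is not the code of BrainPy's RidgeTrainer.

```python
import numpy as np


def ridge_readout(H, Y, lam=1e-6):
    """Ridge-regression readout.

    H: (n_features, T) reservoir states as columns.
    Y: (n_out, T) targets as columns.
    Returns W_out of shape (n_out, n_features).
    """
    A = H @ H.T + lam * np.eye(H.shape[0])   # symmetric (n_features, n_features)
    # W_out A = Y H^T  <=>  A W_out^T = H Y^T, because A is symmetric
    return np.linalg.solve(A, H @ Y.T).T


rng = np.random.default_rng(0)
H = rng.normal(size=(5, 200))
W_true = rng.normal(size=(2, 5))
Y = W_true @ H
W_small = ridge_readout(H, Y, lam=1e-6)
W_large = ridge_readout(H, Y, lam=1e3)
print(np.linalg.norm(W_small), np.linalg.norm(W_large))
```

With a tiny \(\lambda\) the noise-free weights are recovered almost exactly; a large \(\lambda\) shrinks the readout weights, trading fit for robustness.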
Dataset#
The Mackey-Glass equation is a delay differential equation describing the temporal behaviour of various physiological signals, for example, the relative quantity of mature blood cells over time.
The equation is defined as:
\[ \frac{dP(t)}{dt} = \frac{\beta P(t-\tau)}{1 + P(t-\tau)^{n}} - \gamma P(t), \]
where \(\beta = 0.2\), \(\gamma = 0.1\), \(n = 10\), and the time delay \(\tau = 17\). \(\tau\) controls the chaotic behaviour of the equation: the higher it is, the more chaotic the time series becomes, and \(\tau=17\) already gives good chaotic results.
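To make the delay term concrete, here is a minimal sketch that integrates the Mackey-Glass equation with the explicit Euler method and a history buffer. Two assumptions are not taken from brainpy_datasets: a constant initial history \(P(t) = p_0\) for \(t \le 0\), and the delay rounded to a whole number of steps. With \(\beta = 0.2\), \(\gamma = 0.1\), \(n = 10\), the nonzero equilibrium satisfies \(\beta / (1 + P^{*n}) = \gamma\), i.e. \(P^* = 1\), which gives a simple sanity check.

```python
import numpy as np


def mackey_glass(n_steps, dt=0.1, tau=17.0, beta=0.2, gamma=0.1, n=10, p0=1.2):
    """Euler integration of dP/dt = beta*P(t-tau)/(1+P(t-tau)^n) - gamma*P(t).

    Returns P at times 0, dt, ..., n_steps*dt (n_steps + 1 values).
    """
    lag = int(round(tau / dt))                          # delay measured in steps
    P = np.full(lag + n_steps + 1, p0, dtype=float)     # first lag + 1 entries: history
    for i in range(lag, lag + n_steps):
        p_tau = P[i - lag]                              # delayed value P(t - tau)
        P[i + 1] = P[i] + dt * (beta * p_tau / (1.0 + p_tau ** n) - gamma * P[i])
    return P[lag:]


series = mackey_glass(5000)
print(series.min(), series.max())
```

The only difference from an ordinary ODE solver is that each step reads a value lag steps in the past, so the solver has to keep (at least) the last \(\tau/dt\) states.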
def plot_mackey_glass_series(ts, x_series, x_tau_series, num_sample):
    plt.figure(figsize=(13, 5))
    plt.subplot(121)
    plt.title(f"Time series - {num_sample} timesteps")
    plt.plot(ts[:num_sample], x_series[:num_sample], lw=2, color="lightgrey", zorder=0)
    plt.scatter(ts[:num_sample], x_series[:num_sample], c=ts[:num_sample], cmap="viridis", s=6)
    plt.xlabel("$t$")
    plt.ylabel("$P(t)$")
    ax = plt.subplot(122)
    ax.margins(0.05)
    plt.title(f"Phase diagram: $P(t) = f(P(t-\\tau))$")
    plt.plot(x_tau_series[: num_sample], x_series[: num_sample], lw=1, color="lightgrey", zorder=0)
    plt.scatter(x_tau_series[:num_sample], x_series[: num_sample], lw=0.5, c=ts[:num_sample], cmap="viridis", s=6)
    plt.xlabel("$P(t-\\tau)$")
    plt.ylabel("$P(t)$")
    cbar = plt.colorbar()
    cbar.ax.set_ylabel('$t$')
    plt.tight_layout()
    plt.show()
An easy way to get Mackey-Glass time-series data is to use brainpy_datasets.chaos.MackeyGlassEq. If you want to see the details of the implementation, please see the corresponding source code.
dt = 0.1
mg_data = bd.chaos.MackeyGlassEq(25000, dt=dt, tau=17, beta=0.2, gamma=0.1, n=10)
ts = mg_data.ts
xs = mg_data.xs
ys = mg_data.ys
plot_mackey_glass_series(ts, xs, ys, num_sample=int(1000 / dt))

Task 1: prediction of Mackey-Glass timeseries#
Predict \(P(t+1), \cdots, P(t+N)\) from \(P(t)\).
Prepare the data#
def get_data(t_warm, t_forecast, t_train, sample_rate=1):
    warmup = int(t_warm / dt)  # number of steps used to warm up the reservoir
    forecast = int(t_forecast / dt)  # number of steps to predict ahead
    train_length = int(t_train / dt)
    X_warm = xs[:warmup:sample_rate]
    X_warm = bm.expand_dims(X_warm, 0)
    X_train = xs[warmup: warmup+train_length: sample_rate]
    X_train = bm.expand_dims(X_train, 0)
    # targets are the inputs shifted `forecast` steps into the future
    Y_train = xs[warmup+forecast: warmup+train_length+forecast: sample_rate]
    Y_train = bm.expand_dims(Y_train, 0)
    X_test = xs[warmup + train_length: -forecast: sample_rate]
    X_test = bm.expand_dims(X_test, 0)
    Y_test = xs[warmup + train_length + forecast::sample_rate]
    Y_test = bm.expand_dims(Y_test, 0)
    return X_warm, X_train, Y_train, X_test, Y_test
# First warmup the reservoir using the first 100 ms
# Then, train the network in 20000 ms to predict 1 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(100, 1, 20000)
sample = 3000
fig = plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()

Prepare the ESN#
class ESN(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_hidden, num_out, sr=1., leaky_rate=0.3,
                 Win_initializer=bp.init.Uniform(0, 0.2)):
        super(ESN, self).__init__()
        self.r = bp.layers.Reservoir(
            num_in, num_hidden,
            Win_initializer=Win_initializer,
            spectral_radius=sr,
            leaky_rate=leaky_rate,
        )
        self.o = bp.layers.Dense(num_hidden, num_out, mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
Train and test ESN#
model = ESN(1, 100, 1)
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
# warmup
_ = trainer.predict(x_warm)
# train
_ = trainer.fit([x_train, y_train])
Test on the training data.
ys_predict = trainer.predict(x_train)
start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0],
lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_train)[0, start:end, 0], linestyle="--",
lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_train)}')
plt.legend()
plt.show()

Test on the testing data.
ys_predict = trainer.predict(x_test)
start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0], lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_test)[0,start:end, 0], linestyle="--", lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_test)}')
plt.legend()
plt.show()

Make the task harder#
# First warmup the reservoir using the first 100 ms
# Then, train the network in 20000 ms to predict 10 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(100, 10, 20000)
sample = 3000
plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()

model = ESN(1, 100, 1, sr=1.1)
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
# warmup
_ = trainer.predict(x_warm)
# train
_ = trainer.fit([x_train, y_train])
ys_predict = trainer.predict(x_test)
start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0], lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_test)[0, start:end, 0], linestyle="--", lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_test)}')
plt.legend()
plt.show()

Diving into the reservoir#
Let’s have a look at the effect of some of the hyperparameters of the ESN.
Spectral radius#
The spectral radius is defined as the largest absolute value of the eigenvalues of the reservoir (recurrent) weight matrix.
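Because the eigenvalues of a real matrix can be complex, the spectral radius is computed from their moduli. A common way to set it is to rescale a random matrix by \(sr / \rho(W)\), sketched below in NumPy. This is an assumption about how a reservoir can be initialized, not a description of the internals of bp.layers.Reservoir, and the function name is hypothetical. Note that a negative value such as -0.5 in the experiment below flips the sign of the matrix while giving spectral radius 0.5.

```python
import numpy as np


def scale_to_spectral_radius(W, sr):
    """Rescale a square matrix so that its spectral radius equals |sr|."""
    rho = np.max(np.abs(np.linalg.eigvals(W)))   # current spectral radius
    return W * (sr / rho)


rng = np.random.default_rng(0)
W = rng.normal(size=(100, 100))
W_scaled = scale_to_spectral_radius(W, 1.25)
print(np.max(np.abs(np.linalg.eigvals(W_scaled))))
```

Scaling a matrix by a constant scales all of its eigenvalues by the same constant, which is why a single division by \(\rho(W)\) suffices.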
num_sample = 20
all_radius = [-0.5, 0.5, 1.25, 2.5, 10.]
plt.figure(figsize=(15, len(all_radius) * 3))
for i, s in enumerate(all_radius):
    model = ESN(1, 100, 1, sr=s)
    model.reset_state(1)
    runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
    _ = runner.predict(x_test[:, :10000])
    states = bm.as_numpy(runner.mon['state'])
    plt.subplot(len(all_radius), 1, i + 1)
    plt.plot(states[0, :, :num_sample])
    plt.ylabel(f"spectral radius=${all_radius[i]}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()

spectral radius < 1 \(\rightarrow\) stable dynamics
spectral radius > 1 \(\rightarrow\) chaotic dynamics
In most cases, it should have a value around \(1.0\) to ensure the echo state property (ESP): the dynamics of the reservoir should not depend on the chosen initial state, while remaining close to chaos.
This value also heavily depends on the input scaling.
Input scaling#
The input scaling controls how the ESN interacts with the inputs. It is a coefficient applied to the input matrix \(W^{in}\).
num_sample = 20
all_input_scaling = [0.1, 1.0, 10.0]
plt.figure(figsize=(15, len(all_input_scaling) * 3))
for i, s in enumerate(all_input_scaling):
    model = ESN(1, 100, 1, sr=1., Win_initializer=bp.init.Uniform(max_val=s))
    model.reset_state(1)
    runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
    _ = runner.predict(x_test[:, :10000])
    states = bm.as_numpy(runner.mon['state'])
    plt.subplot(len(all_input_scaling), 1, i + 1)
    plt.plot(states[0, :, :num_sample])
    plt.ylabel(f"input scaling=${all_input_scaling[i]}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()

Leaking rate#
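The leaking rate (\(\alpha\)) controls the “memory feedback” of the ESN. The ESN states are computed as:
\[ h(n) = (1-\alpha)\, h(n-1) + \alpha\, f\left(x(n), h(n-1)\right), \]
where \(h\) is the state, \(x\) is the input data, and \(f\) is the ESN model function, defined as:
\[ f\left(x(n), h(n-1)\right) = \tanh\left(W^{in} x(n) + W^{rec} h(n-1)\right). \]
\(\alpha\) must be in \((0, 1]\).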
num_sample = 20
all_rates = [0.001, 0.01, 0.1, 1.]
plt.figure(figsize=(15, len(all_rates) * 3))
for i, s in enumerate(all_rates):
    model = ESN(1, 100, 1, sr=1., leaky_rate=s,
                Win_initializer=bp.init.Uniform(max_val=1.))
    model.reset_state(1)
    runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
    _ = runner.predict(x_test[:, :10000])
    states = bm.as_numpy(runner.mon['state'])
    plt.subplot(len(all_rates), 1, i + 1)
    plt.plot(states[0, :, :num_sample])
    plt.ylabel(f"leaky rate=${all_rates[i]}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()

Let’s reduce the input influence to see what is happening inside the reservoir (input scaling set to 0.2):
num_sample = 20
all_rates = [0.001, 0.01, 0.1, 1.]
plt.figure(figsize=(15, len(all_rates) * 3))
for i, s in enumerate(all_rates):
    model = ESN(1, 100, 1, sr=1., leaky_rate=s,
                Win_initializer=bp.init.Uniform(max_val=.2))
    model.reset_state(1)
    runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
    _ = runner.predict(x_test[:, :10000])
    states = bm.as_numpy(runner.mon['state'])
    plt.subplot(len(all_rates), 1, i + 1)
    plt.plot(states[0, :, :num_sample])
    plt.ylabel(f"leaky rate=${all_rates[i]}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()

high leaking rate \(\rightarrow\) low inertia, little memory of previous states
low leaking rate \(\rightarrow\) high inertia, big memory of previous states
The leaking rate can be seen as the inverse of the reservoir’s time constant.
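The time-constant interpretation can be checked with a small NumPy sketch. Assume, for this illustration only, that the drive \(f\) is switched off; the leaky update then reduces to \(h(n) = (1-\alpha)h(n-1)\), so a state decays by a factor \(e\) after roughly \(1/\alpha\) steps. The helper name is hypothetical.

```python
import numpy as np


def steps_to_decay(alpha, factor=np.e):
    """Steps until a state shrinks by `factor` under h(n) = (1 - alpha) * h(n-1)."""
    h, n = 1.0, 0
    while h > 1.0 / factor:
        h *= 1.0 - alpha
        n += 1
    return n


for alpha in (0.01, 0.1, 1.0):
    print(alpha, steps_to_decay(alpha), 1.0 / alpha)
```

A small \(\alpha\) therefore means the reservoir averages its inputs over many steps (high inertia), while \(\alpha = 1\) discards the previous state entirely.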
Task 2: generation of Mackey-Glass timeseries#
In generative mode, the output of the ESN is fed back as its next input.
During this task, the ESN is trained to make a short forecast of the time series (1 time step ahead). Then, it is asked to run on its own outputs, trying to predict its own behaviour.
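The difference between training and generative mode can be shown without an ESN. In this NumPy sketch, which uses a linear stand-in model rather than the ESN below, a one-step predictor is fitted with teacher forcing on a sine wave (the true previous values are always given as input), and is then run in closed loop on its own outputs. A sine wave satisfies \(x[t+1] = 2\cos(w)\,x[t] - x[t-1]\) exactly, so the closed-loop rollout should keep tracking it.

```python
import numpy as np

w = 0.1
x = np.sin(w * np.arange(400))

# 1) teacher forcing: learn x[t+1] from (x[t], x[t-1]) using the true series
X = np.stack([x[1:-1], x[:-2]], axis=1)
y = x[2:]
coef, *_ = np.linalg.lstsq(X, y, rcond=None)

# 2) generative mode: each prediction becomes the next input
gen = [x[0], x[1]]
for _ in range(len(x) - 2):
    gen.append(coef[0] * gen[-1] + coef[1] * gen[-2])
gen = np.asarray(gen)

print(coef, np.max(np.abs(gen - x)))
```

For a chaotic signal such as Mackey-Glass, small one-step errors are amplified in closed loop, so the generated trajectory is expected to drift from the ground truth after some time even when the one-step error is small.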
# First warmup the reservoir using the first 500 ms
# Then, train the network in 20000 ms to predict 1 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(500, 1, 20000, sample_rate=int(1/dt))
sample = 300
fig = plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()

model = ESN(1, 100, 1, sr=1.1, Win_initializer=bp.init.Uniform(max_val=.2), )
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-7)
# warmup
_ = trainer.predict(x_warm)
# train
trainer.fit([x_train, y_train])
# test
ys_predict = trainer.predict(x_train)
start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0],
lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_train)[0, start:end, 0], linestyle="--",
lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_train)}')
plt.legend()
plt.show()

jit_model = bm.jit(model)
outputs = [x_test[:, 0]]
truths = [x_test[:, 0]]
for i in range(200):
    # feed the previous output back in as the next input
    outputs.append(jit_model(dict(), outputs[-1]))
    truths.append(x_test[:, i + 1])
outputs = bm.asarray(outputs)
truths = bm.asarray(truths)
plt.figure(figsize=(15, 10))
plt.plot(bm.as_numpy(truths).squeeze()[:200], label='Ground truth')
plt.plot(bm.as_numpy(outputs).squeeze()[:200], label='Prediction')
plt.legend()
plt.show()

References#
Jaeger, H.: The “echo state” approach to analysing and training recurrent neural networks. Technical Report GMD Report 148, German National Research Center for Information Technology (2001)
Lukoševičius, Mantas. “A Practical Guide to Applying Echo State Networks.” Neural Networks: Tricks of the Trade (2012).
Model Analysis#
Low-dimensional Analyzers#
We have talked about model simulation and training for dynamical systems with BrainPy. In this tutorial, we are going to dive into how to perform automatic analysis for your defined systems.
Dynamics analysis is essential in neurodynamics, because blind simulation of nonlinear systems is likely to produce few results or misleading results. BrainPy has good support for low-dimensional systems, no matter how nonlinear your defined system is. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:
phase plane analysis;
codimension 1 or codimension 2 bifurcation analysis;
bifurcation analysis of the fast-slow system.
BrainPy will help you probe the dynamical mechanism of your defined systems rapidly.
import brainpy as bp
import brainpy.math as bm
bm.enable_x64() # It's better to enable x64 when performing analysis
bm.set_platform('cpu')
bp.__version__
'2.3.0'
import numpy as np
import matplotlib.pyplot as plt
A simple case#
Here we test BrainPy with a simple case:
\[ \frac{dx}{dt} = \sin(x) + I, \]
where \(x \in [-10, 10]\).
When \(I=0\), this system has multiple fixed points (where \(\frac{dx}{dt} = 0\)), located at \(x = k\pi\) for integer \(k\).
xs = np.arange(-10, 10, 0.01)
plt.plot(xs, np.sin(xs))
plt.scatter([-3*np.pi, -1*np.pi, 1*np.pi, 3 * np.pi], np.zeros(4), s=80, edgecolors='y')
plt.scatter([-2*np.pi, 0, 2*np.pi], np.zeros(3), s=80, facecolors='none', edgecolors='r')
plt.axhline(0)
plt.show()

According to dynamical systems theory, the red hollow points are unstable fixed points and the solid ones are stable fixed points.
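The stability criterion behind these markers can be checked by hand: for a 1D system \(\frac{dx}{dt} = f(x)\), a fixed point \(x^*\) is stable when \(f'(x^*) < 0\) and unstable when \(f'(x^*) > 0\). The minimal NumPy sketch below (not part of BrainPy) applies this rule to \(f(x) = \sin(x)\), whose derivative is \(\cos(x)\):

```python
import numpy as np

# dx/dt = f(x) = sin(x) + I with I = 0, so f'(x) = cos(x)
fixed_points = np.arange(-3, 4) * np.pi   # -3*pi, ..., 3*pi all lie in [-10, 10]
slopes = np.cos(fixed_points)

for x, s in zip(fixed_points, slopes):
    # linear stability in 1D: a negative slope pulls nearby states back
    label = 'stable' if s < 0 else 'unstable'
    print(f"x = {x:+.4f}: f'(x) = {s:+.0f} -> {label}")
```

The alternating signs of the slope match the alternating filled (stable) and hollow (unstable) markers in the plot.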
Now let’s come back to BrainPy, and test whether BrainPy can give us the right answer.
As the analysis interfaces in BrainPy only receive an ODEIntegrator or an instance of DynamicalSystem, we first define an integrator with BrainPy (to learn how to define an ODE integrator, please refer to the tutorial Numerical Solvers for ODEs):
@bp.odeint
def int_x(x, t, Iext):
    return bp.math.sin(x) + Iext
This is a one-dimensional dynamical system, so we use brainpy.analysis.PhasePlane1D for phase plane analysis. The usage of phase plane analysis is detailed in the following section. For now, we focus on the following four arguments:
model: The target system to analyze. It can be an ODEIntegrator or a list/tuple of ODEIntegrator. It can also be an instance of DynamicalSystem; for a DynamicalSystem argument, we use model.ints().subset(bp.ode.ODEIntegrator) to retrieve all instances of ODEIntegrator.
target_vars: The variables to analyze. It must be a dict with the format <var_name, var_interval>, where var_name is the variable name and var_interval is the boundary of this variable.
pars_update: The parameters to update.
resolutions: The resolution used to evaluate the fixed points.
Let’s try it.
pp = bp.analysis.PhasePlane1D(
model=int_x,
target_vars={'x': [-10, 10]},
pars_update={'Iext': 0.},
resolutions={'x': 0.01}
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)
I am creating the vector field ...
I am searching fixed points ...
Fixed point #1 at x=-9.424777960769386 is a stable point.
Fixed point #2 at x=-6.283185307179586 is a unstable point.
Fixed point #3 at x=-3.1415926535897984 is a stable point.
Fixed point #4 at x=3.552755127361717e-18 is a unstable point.
Fixed point #5 at x=3.1415926535897984 is a stable point.
Fixed point #6 at x=6.283185307179586 is a unstable point.
Fixed point #7 at x=9.424777960769386 is a stable point.

As expected, brainpy.analysis.PhasePlane1D finds the correct fixed points and correctly evaluates their stability.
Phase plane analysis is important because it gives an intuitive picture of how the system evolves under the given parameters. However, when we care about how the parameters affect the system’s behavior, we need bifurcation analysis. brainpy.analysis.Bifurcation1D is a convenient interface for gaining insight into how the dynamics of a 1D system change with its parameters.
Similar to brainpy.analysis.PhasePlane1D, brainpy.analysis.Bifurcation1D receives arguments like “model”, “target_vars”, “pars_update”, and “resolutions”. In addition, one more important argument, “target_pars”, should be provided; it specifies the range of the target parameter in the bifurcation analysis.
Here, we systematically change the parameter “Iext” from 0 to 1.5. According to bifurcation theory, this simple system undergoes a fold (saddle-node) bifurcation at \(I=1.0\): at \(I=1.0\), pairs of fixed points collide with each other into a saddle-node point and then disappear. Is brainpy.analysis.Bifurcation1D capable of capturing this? Let’s try.
bif = bp.analysis.Bifurcation1D(
model=int_x,
target_vars={'x': [-10, 10]},
target_pars={'Iext': [0., 1.5]},
resolutions={'Iext': 0.005, 'x': 0.05}
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...

Once again, the BrainPy analysis toolkit gives the right answer. It shows how the fixed points evolve as the parameter \(I\) increases.
It is worth noting that bifurcation analysis in BrainPy can hardly locate the saddle-node point itself (at \(I=1.0\) for this system). This is because the saddle-node point exists only at a single parameter value, and the numerical method used in the BrainPy analysis toolkit almost never evaluates the system exactly at that value. However, with minimal knowledge of bifurcation theory, the saddle-node point (the collision point of two fixed points) can easily be inferred from the evolution of the fixed points.
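To see why the fold is easy to infer but hard to hit exactly, here is a small NumPy sketch (an illustration independent of BrainPy’s implementation; fixed_points_in_one_period is a hypothetical helper) that brackets the fixed points of \(\sin(x) + I\) on a fine grid over one period centred on \(x = -\pi/2\). The pair of fixed points there sits at \(-\pi/2 \pm \arccos(I)\), so their distance shrinks to zero as \(I \to 1\), and a sign-change search finds none once \(I > 1\):

```python
import numpy as np

def fixed_points_in_one_period(I, n=100000):
    # grid over one period centred on x = -pi/2, where the fold happens
    xs = np.linspace(-1.5 * np.pi, 0.5 * np.pi, n)
    f = np.sin(xs) + I
    # a fixed point lies between two grid points where dx/dt changes sign
    idx = np.nonzero(f[:-1] * f[1:] < 0)[0]
    return 0.5 * (xs[idx] + xs[idx + 1])

for I in [0.0, 0.5, 0.9, 0.99, 1.2]:
    fps = fixed_points_in_one_period(I)
    print(I, len(fps), np.round(fps, 3))
```

Exactly at \(I = 1\) the two points merge into a tangency, which produces no sign change on a grid; this is the same reason the bifurcation diagram does not show the saddle-node point itself.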
BrainPy’s analysis toolkit is highly useful, especially when the equations are too complex to solve analytically. For an example, please refer to the tutorial Analysis of a Decision-making Model.
Phase plane analysis#
Phase plane analysis is one of the most important techniques for studying the behavior of nonlinear systems, since there is usually no analytical solution for a nonlinear system. BrainPy can help users plot the phase plane of 1D or 2D systems. Specifically, we provide brainpy.analysis.PhasePlane1D and brainpy.analysis.PhasePlane2D, which can plot:
Nullcline: The zero-growth isoclines, such as \(f(x, y)=0\) and \(g(x, y)=0\).
Fixed points: The equilibrium points of the system, which are located where the nullclines intersect.
Vector field: The vector field of the system.
Limit cycles: The limit cycles.
Trajectories: A simulation trajectory with the given initial values.
We discussed brainpy.analysis.PhasePlane1D above. Now we focus on brainpy.analysis.PhasePlane2D, using the well-known FitzHugh-Nagumo neuron model.
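The FitzHugh-Nagumo model is given by:
\[\begin{aligned}
\frac{dV}{dt} &= V - \frac{V^3}{3} - w + I_{ext}, \\
\tau \frac{dw}{dt} &= V + a - b w.
\end{aligned}\]
There are two variables \(V\) and \(w\), so this is a two-dimensional system with three parameters \(a, b\) and \(\tau\).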
Users can define the system to analyze either with plain brainpy.odeint functions or as a subclass of DynamicalSystem. We define this FitzHugh-Nagumo model as a class because later we will run a simulation to verify the analysis results.
class FitzHughNagumoModel(bp.DynamicalSystem):
    def __init__(self, method='exp_auto'):
        super(FitzHughNagumoModel, self).__init__()
        # parameters
        self.a = 0.7
        self.b = 0.8
        self.tau = 12.5
        # variables
        self.V = bm.Variable(bm.zeros(1))
        self.w = bm.Variable(bm.zeros(1))
        self.Iext = bm.Variable(bm.zeros(1))
        # functions
        def dV(V, t, w, Iext=0.):
            return V - V * V * V / 3 - w + Iext
        def dw(w, t, V, a=0.7, b=0.8):
            return (V + a - b * w) / self.tau
        self.int_V = bp.odeint(dV, method=method)
        self.int_w = bp.odeint(dw, method=method)
    def update(self, tdi):
        self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
        self.w.value = self.int_w(self.w, tdi.t, self.V, self.a, self.b, tdi.dt)
        self.Iext[:] = 0.
model = FitzHughNagumoModel()
Here we perform a phase plane analysis with parameters \(a=0.7, b=0.8, \tau=12.5\), and input \(I_{ext} = 0.8\).
pp = bp.analysis.PhasePlane2D(
model,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
pars_update={'Iext': 0.8},
resolutions={'V': 0.01, 'w': 0.01},
)
# By default, nullclines are plotted as points,
# but we can set the plot style to lines
pp.plot_nullcline(x_style={'fmt': '-'}, y_style={'fmt': '-'})
# The vector field can be plotted in two ways:
# - plot_method="streamplot" (default)
# - plot_method="quiver"
pp.plot_vector_field()
# There are many ways to search fixed points.
# By default, it will use the nullcline points of the first
# variable ("V") as the initial points to perform fixed point searching
pp.plot_fixed_point()
# Trajectory plotting receives the initial points.
# There may be multiple trajectories, so the initial points
# should be provided as a list/tuple/numpy.ndarray/Array
pp.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
# show the phase plane figure
pp.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 866 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 V=-0.2729223248464073, w=0.5338542697673022 is a unstable node.
I am plotting the trajectory ...

We can see an unstable node at the point (\(V=-0.27, w=0.53\)) inside a limit cycle.
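As a hand check of this classification, the sketch below (plain NumPy, written for this tutorial rather than taken from BrainPy) solves for the fixed point of the FitzHugh-Nagumo equations with \(a=0.7, b=0.8, \tau=12.5, I_{ext}=0.8\) and inspects the Jacobian. Eliminating \(w = (V + a)/b\) leaves a cubic in \(V\); real eigenvalues with positive real parts (a positive discriminant) indicate an unstable node, whereas a complex pair would indicate a focus.

```python
import numpy as np

a, b, tau, I_ext = 0.7, 0.8, 12.5, 0.8

# On the w-nullcline w = (V + a) / b; substituting into dV/dt = 0 gives
#   V**3 / 3 + (1/b - 1) V + a/b - I_ext = 0
roots = np.roots([1. / 3., 0., 1. / b - 1., a / b - I_ext])
V = roots[np.argmin(np.abs(roots.imag))].real   # the single real root
w = (V + a) / b

# Jacobian of (dV/dt, dw/dt) with respect to (V, w) at the fixed point
J = np.array([[1. - V ** 2, -1.],
              [1. / tau, -b / tau]])
trace, det = np.trace(J), np.linalg.det(J)
eigvals = np.linalg.eigvals(J)
print(f"V={V:.4f}, w={w:.4f}")
print("trace:", trace, "det:", det, "discriminant:", trace ** 2 - 4 * det)
print("eigenvalues:", eigvals)
```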
We can run a simulation with the same parameters and initial values to verify the periodic activity that corresponds to the limit cycle.
runner = bp.DSRunner(model, monitors=['V', 'w'], inputs=['Iext', 0.8])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)

Understanding settings#
There are several key settings worth understanding.
resolutions#
resolutions is one of the most important parameters in the PhasePlane and Bifurcation analysis toolkits of BrainPy, because it has a profound impact on the efficiency of model analysis.
We can set resolutions in the following ways:
None. If no resolution is set for a variable, its resolution will be \(\frac{\mathrm{max\_value} - \mathrm{min\_value}}{20}\).
A float. It sets the same resolution for every target variable and parameter.
A dict. It specifies different resolutions for individual variables/parameters. Each value can be a float, or a vector in the format of Array or numpy.ndarray.
Note
It is highly recommended to specify resolutions for particular parameters or variables with a dict, rather than setting a single float value that is applied to all variables. Otherwise, the computation may occupy too much memory when the resolution is very small. For example, to set the resolution of variable x to 0.01, use resolutions={'x': 0.01}.
Setting resolutions with an array gives the user maximal flexibility. Numerical analysis usually does not work well near inflection points, so we can increase the granularity around them. For example, if there is an inflection point at \(1\), we can set the resolution with:
r1 = bm.arange(0.00, 0.95, 0.01)
r2 = bm.arange(0.95, 1.05, 0.001)
r3 = bm.arange(1.05, 1.50, 0.01)
resolution = bm.concatenate([r1, r2, r3])
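The same idea can be wrapped in a small helper. The sketch below uses NumPy instead of brainpy.math, and the function name refined_grid and its arguments are hypothetical, chosen for illustration. It builds each segment from an integer number of steps, so floating-point rounding in arange cannot leave a gap or duplicate a value at the segment boundaries:

```python
import numpy as np

def refined_grid(lo, hi, center, half_width, coarse=0.01, fine=0.001):
    """Hypothetical helper: a grid on [lo, hi) with spacing `coarse`,
    refined to spacing `fine` on [center - half_width, center + half_width)."""
    w_lo, w_hi = center - half_width, center + half_width
    # integer step counts avoid the endpoint ambiguity of float arange
    left = lo + coarse * np.arange(int(round((w_lo - lo) / coarse)))
    middle = w_lo + fine * np.arange(int(round((w_hi - w_lo) / fine)))
    right = w_hi + coarse * np.arange(int(round((hi - w_hi) / coarse)))
    return np.concatenate([left, middle, right])

grid = refined_grid(0.0, 1.5, center=1.0, half_width=0.05)
print(len(grid), grid[93:97], grid[193:197])
```

The resulting array could then be given as a per-parameter entry in the resolutions dict, for example resolutions={'Iext': grid}.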
Tip: For bifurcation analysis, we usually need to set a small resolution for the parameters and leave the resolutions of the variables as the default. See the following examples.
vars and pars#
What can be set as variables (*_vars) or parameters (*_pars), such as target_vars or target_pars, for further analysis? Variables and parameters are recognized in the same way as in the programming paradigm of ODE numerical integrators. Simply speaking, the arguments before t are variables, while the arguments after t are parameters.
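To make the convention concrete, the pure-Python sketch below classifies the arguments of the FitzHugh-Nagumo derivative functions. It only illustrates the rule stated above and is not BrainPy’s internal code; split_vars_and_pars is a hypothetical helper. Note that in this example w appears after t in dV, yet it is still a variable of the whole system because it is the variable of dw:

```python
import inspect

def split_vars_and_pars(*derivatives):
    """Hypothetical helper: classify arguments of ODE derivative functions.

    Arguments before `t` are the function's variables; arguments after `t`
    are parameters unless they are the variable of another equation.
    """
    variables, others = [], []
    for f in derivatives:
        names = list(inspect.signature(f).parameters)
        i = names.index('t')
        variables.extend(names[:i])
        others.extend(n for n in names[i + 1:] if n not in others)
    parameters = [n for n in others if n not in variables]
    return variables, parameters

def dV(V, t, w, Iext=0.):
    return V - V ** 3 / 3 - w + Iext

def dw(w, t, V, a=0.7, b=0.8):
    return (V + a - b * w) / 12.5

print(split_vars_and_pars(dV, dw))   # (['V', 'w'], ['Iext', 'a', 'b'])
```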
BrainPy’s analysis toolkit only supports one variable per differential equation; it cannot analyze a joint differential equation in which multiple variables are defined in the same function.
Moreover, the low-dimensional analyzers in BrainPy cannot analyze dynamical systems that depend explicitly on time \(t\).
Bifurcation analysis#
The behavior of a nonlinear dynamical system is characterized by its parameters: when a parameter changes, the system’s behavior may change qualitatively. Therefore, we care about how the system changes as its parameters vary smoothly.
Codimension 1 bifurcation analysis
We first look at the codimension 1 bifurcation analysis of the model. For example, we vary the input \(I_{ext}\) between 0 and 1 and see how the system changes its stability.
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
target_pars={'Iext': [0., 1.]},
resolutions={'Iext': 0.002},
)
# "num_rank" specifies the number of initial points for
# fixed point optimization under each set of parameters
analyzer.plot_bifurcation(num_rank=10)
# show figure
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 5000 candidates
I am trying to filter out duplicate fixed points ...
Found 500 fixed points.
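The stability change along this branch can be cross-checked by hand. The NumPy sketch below (an independent calculation, not BrainPy code) follows the unique FitzHugh-Nagumo fixed point as \(I_{ext}\) goes from 0 to 1 and reports the first input value at which the largest real part of the Jacobian eigenvalues becomes positive. With \(a=0.7, b=0.8, \tau=12.5\), this happens where the trace \(1 - V^2 - b/\tau\) crosses zero while the determinant stays positive, i.e. a complex pair crosses the imaginary axis, the signature of a Hopf bifurcation:

```python
import numpy as np

a, b, tau = 0.7, 0.8, 12.5

def fixed_point_V(I_ext):
    # V**3/3 + (1/b - 1) V + a/b - I_ext = 0 has exactly one real root here
    roots = np.roots([1. / 3., 0., 1. / b - 1., a / b - I_ext])
    return roots[np.argmin(np.abs(roots.imag))].real

def max_real_eig(V):
    # Jacobian of the FitzHugh-Nagumo equations at (V, w)
    J = np.array([[1. - V ** 2, -1.],
                  [1. / tau, -b / tau]])
    return np.linalg.eigvals(J).real.max()

Is = np.linspace(0., 1., 501)
growth = np.array([max_real_eig(fixed_point_V(I)) for I in Is])
I_cross = Is[np.argmax(growth > 0)]   # first grid value with an unstable fixed point
print(f"fixed point loses stability near I_ext = {I_cross:.3f}")
```

With these parameters it prints a value close to 0.33; for larger inputs the fixed point stays unstable up to \(I_{ext} = 1\), consistent with the unstable node found at \(I_{ext} = 0.8\).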


Codimension 2 bifurcation analysis
We simultaneously vary \(I_{ext}\) and the parameter \(a\).
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars=dict(V=[-3, 3], w=[-3., 3.]),
target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
resolutions={'a': 0.01, 'Iext': 0.01},
)
analyzer.plot_bifurcation(num_rank=10, tol_aux=1e-9)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 50000 candidates
I am trying to filter out duplicate fixed points ...
Found 4997 fixed points.


Fast-slow system bifurcation#
BrainPy also provides tools for fast-slow system bifurcation analysis: brainpy.analysis.FastSlow1D and brainpy.analysis.FastSlow2D. This method was proposed by John Rinzel [1, 2, 3] (1985, 1986, 1987): in a fast-slow dynamical system, we can treat the slow variables as bifurcation parameters and study how different values of the slow variables affect the bifurcations of the fast subsystem.
Fast-slow bifurcation methods are very useful in the analysis of bursting neurons. We illustrate this with the Hindmarsh–Rose model. The Hindmarsh–Rose model of neuronal activity aims to study the spiking-bursting behavior of the membrane potential observed in experiments on single neurons. Its dynamics are governed by:
\[\begin{aligned}
\frac{dx}{dt} &= y - a x^3 + b x^2 - z + I, \\
\frac{dy}{dt} &= c - d x^2 - y, \\
\frac{dz}{dt} &= r \left(s (x - x_r) - z\right).
\end{aligned}\]
First, let’s define the Hindmarsh–Rose model with BrainPy.
a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9
@bp.odeint
def int_x(x, t, y, z, Isyn):
    return y - a * x ** 3 + b * x * x - z + Isyn
@bp.odeint
def int_y(y, t, x):
    return c - d * x * x - y
@bp.odeint
def int_z(z, t, x):
    return r * (s * (x - x_r) - z)
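Rinzel’s idea can be illustrated on the fast subsystem by hand. In the NumPy sketch below (an illustration using the parameter values defined above with \(I_{syn} = 0.5\), not BrainPy’s implementation; fast_fixed_points is a hypothetical helper), the slow variable \(z\) is frozen and treated as a parameter. Setting \(\frac{dy}{dt} = 0\) gives \(y = c - d x^2\), and substituting into \(\frac{dx}{dt} = 0\) leaves a cubic in \(x\) whose real roots are the equilibria of the fast subsystem. Their number changes with \(z\): for example, there are three equilibria at \(z = 1\) but only one at \(z = 0\) or \(z = 2\).

```python
import numpy as np

a, b, c, d = 1., 3., 1., 5.
Isyn = 0.5

def fast_fixed_points(z):
    # with z frozen, dy/dt = 0 gives y = c - d x^2; substituting into dx/dt = 0:
    #   -a x^3 + (b - d) x^2 + (c - z + Isyn) = 0
    roots = np.roots([-a, b - d, 0., c - z + Isyn])
    return np.sort(roots[np.abs(roots.imag) < 1e-7].real)

for z in [0.0, 1.0, 2.0]:
    print(z, fast_fixed_points(z))
```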
Now we can analyze the underlying bifurcation mechanism.
analyzer = bp.analysis.FastSlow2D(
[int_x, int_y, int_z],
fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
slow_vars={'z': [-5., 5.]},
pars_update={'Isyn': 0.5},
resolutions={'z': 0.01}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 20000 candidates
I am trying to filter out duplicate fixed points ...
Found 1156 fixed points.


References#
[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.
[2] Rinzel, John, and Y. S. Lee. “On Different Mechanisms for Membrane Potential Bursting.” In Nonlinear Oscillations in Biology and Chemistry. Springer, Berlin, Heidelberg, 1986.
[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.
High-dimensional Analyzers#
Analyzing high-dimensional systems is hard, but it is often necessary.
Based on numerical optimization methods, BrainPy provides brainpy.analysis.SlowPointFinder to help users find slow points (or fixed points) [1] of high-dimensional dynamical systems.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bp.__version__
'2.3.0'
What are slow points?#
For the given system,
\[\dot{x} = f(x),\]
we wish to find values \(x^*\) around which the system is approximately linear. Using a Taylor series expansion, we have
\[f(x^* + \delta x) = f(x^*) + F'(x^*)\,\delta x + \frac{1}{2}\,\delta x\, F''(x^*)\,\delta x + \cdots,\]
where \(F'(x^*)\) is the Jacobian matrix of \(f\) at \(x^*\).
We want the first derivative term (i.e., the linear term) to be dominant, which requires \(f(x^*) = 0\) or \(f(x^*) \approx 0\).
If \(f(x^*)\) is nonzero but small, we call \(x^*\) a slow point.
More specifically, if \(f(x^*) = 0\), \(x^*\) is a fixed point.
How to find slow points?#
To find slow points, we first define an auxiliary scalar function for a continuous system \(\dot{x} = f(x)\):
\[p(x) = \frac{1}{2} |f(x)|^2.\]
Or, if the system is discrete, \(x_n = f(x_{n-1})\), the auxiliary scalar function can be defined as
\[p(x) = \frac{1}{2} |x - f(x)|^2.\]
If \(x^*\) is a slow point, \(p(x^*) \to 0\).
Then, by minimizing the scalar function \(p(x)\), we obtain candidate slow points for further linearization. For the linearized system, stability is evaluated by the eigenvalues of the Jacobian matrix.
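The procedure can be sketched in plain NumPy. The toy system below, \(\dot{x} = -x + 2\tanh(x)\) applied to each coordinate of a 2D state, is a hypothetical example chosen because its Jacobian is diagonal, and the optimizer is plain gradient descent with a fixed learning rate rather than the optimizers BrainPy uses. Since \(\nabla p(x) = J(x)^{\top} f(x)\), each step moves the candidates downhill on \(p\); afterwards, we keep low-loss, unique candidates and classify them by the signs of the Jacobian eigenvalues:

```python
import numpy as np

W = 2.0  # self-coupling gain of the hypothetical toy system

def f(x):
    # continuous system dx/dt = -x + W * tanh(x), applied elementwise
    return -x + W * np.tanh(x)

def jacobian_diag(x):
    # the Jacobian is diagonal here, with entries -1 + W * sech^2(x)
    return -1.0 + W * (1.0 - np.tanh(x) ** 2)

def p(x):
    # auxiliary scalar function p(x) = 1/2 |f(x)|^2, one value per candidate
    return 0.5 * np.sum(f(x) ** 2, axis=-1)

rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, size=(200, 2))   # 200 candidate points in 2D
lr = 0.1
for _ in range(3000):
    # gradient of p is J(x)^T f(x); J is diagonal, so this is elementwise
    x = x - lr * jacobian_diag(x) * f(x)

# keep candidates whose loss is below a tolerance
fps = x[p(x) < 1e-10]

# keep unique points (greedy, Euclidean tolerance 0.025)
unique = []
for pt in fps:
    if all(np.linalg.norm(pt - u) > 0.025 for u in unique):
        unique.append(pt)
unique = np.array(unique)

# the eigenvalues of a diagonal Jacobian are its diagonal entries
eigs = jacobian_diag(unique)
n_stable = int(np.sum(np.all(eigs < 0, axis=1)))
print(len(unique), "fixed points,", n_stable, "stable")
```

On this toy problem the sketch recovers nine fixed points: the unstable origin, four saddles, and four stable points.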
Here, BrainPy provides brainpy.analysis.SlowPointFinder. It receives f_cell to specify the target function/object to analyze.
If the provided f_cell is a function, SlowPointFinder supports specifying:
f_type: the type of the function (it can be “continuous” or “discrete”).
f_loss: the loss function to minimize the optimization error.
args: extra arguments passed into the function when performing fixed point optimization.
If the provided f_cell is an instance of DynamicalSystem, SlowPointFinder supports specifying:
f_loss: the loss function to minimize the optimization error.
args: extra arguments passed into the defined update() function when performing fixed point optimization.
inputs and fun_inputs: inputs to this dynamical system, similar to the inputs of DSRunner and DSTrainer.
target_vars: the selected variables which are used to optimize fixed points. Other variables like “input” and “spike” can be ignored.
Then, brainpy.analysis.SlowPointFinder can help you:
optimize to find the fixed/slow points with gradient descent algorithms (find_fps_with_gd_method()) or a nonlinear optimization solver (find_fps_with_opt_solver());
exclude any fixed points whose losses are above a threshold: filter_loss();
exclude any non-unique fixed points according to a tolerance: keep_unique();
exclude any far-away “outlier” fixed points: exclude_outliers();
compute the Jacobian matrix for the given fixed/slow points: compute_jacobians().
Example 1: Decision Making Model#
brainpy.analysis.SlowPointFinder aims to find slow/fixed points of high-dimensional systems, but it can also find fixed points of low-dimensional systems. We take the 2D decision-making system as an example.
# parameters
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154
JE = 0.3725 # self-coupling strength [nA]
JI = -0.1137 # cross-coupling strength [nA]
JAext = 0.00117 # Stimulus input strength [nA]
mu = 20. # Stimulus firing rate [spikes/sec]
coh = 0.5 # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
    I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
    r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
    return - s1 / tau + (1. - s1) * gamma * r1
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
    I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
    r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
    return - s2 / tau + (1. - s2) * gamma * r2
def step(s):
    ds1 = int_s1.f(s[0], 0., s[1])
    ds2 = int_s2.f(s[1], 0., s[0])
    return bm.asarray([ds1, ds2])
We first use brainpy.analysis.PhasePlane2D to get a reference answer.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
resolutions=0.001,
)
analyzer.plot_fixed_point(select_candidates='aux_rank', with_plot=False)
I am searching fixed points ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 100 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.2827633321285248, s2=0.40635180473327637 is a saddle node.
#2 s1=0.013946513645350933, s2=0.6573889851570129 is a stable node.
#3 s1=0.7004518508911133, s2=0.004864312242716551 is a stable node.
Then, let’s check whether the high-dimensional analyzer also works.
finder = bp.analysis.SlowPointFinder(f_cell=step, f_type="continuous")
finder.find_fps_with_gd_method(
candidates=bm.random.random((1000, 2)),
tolerance=1e-5,
num_batch=200,
optimizer=bp.optimizers.Adam(bp.optimizers.ExponentialDecay(0.01, 1, 0.9999))
)
finder.filter_loss(1e-5)
finder.keep_unique()
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.24 sec, Training loss 0.0510474481
Batches 201-400 in 0.25 sec, Training loss 0.0046035680
Batches 401-600 in 0.32 sec, Training loss 0.0007384720
Batches 601-800 in 0.27 sec, Training loss 0.0001601687
Batches 801-1000 in 0.25 sec, Training loss 0.0000381663
Batches 1001-1200 in 0.25 sec, Training loss 0.0000088441
Stop optimization as mean training loss 0.0000088441 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
Kept 934/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
Kept 3/934 unique fixed points with uniqueness tolerance 0.025.
finder.fixed_points
array([[0.28276306, 0.40635154],
[0.7004519 , 0.00486429],
[0.01394659, 0.6573889 ]], dtype=float32)
The fixed points found by brainpy.analysis.PhasePlane2D and brainpy.analysis.SlowPointFinder are nearly identical.
Example 2: Continuous-attractor Neural Network#
The continuous-attractor neural network (CANN) [2], proposed by Si Wu et al., is a special model that has a line of attractors.
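What a line of attractors means for the linearization can be seen in a deliberately tiny example. The 2D linear system below is hypothetical and is not the CANN model: with \(W = \frac{1}{2}\begin{pmatrix}1 & 1\\1 & 1\end{pmatrix}\) and \(\dot{x} = (W - I)x\), every point on the diagonal \(x_1 = x_2\) is a fixed point, and the Jacobian has one zero eigenvalue (along the line) and one negative eigenvalue (attraction towards the line). In a CANN, the bump can sit at any position in feature space, which gives an analogous continuum of fixed points.

```python
import numpy as np

# Hypothetical 2D linear system dx/dt = (W - I) x with a line attractor
W = np.array([[0.5, 0.5],
              [0.5, 0.5]])
A = W - np.eye(2)   # the Jacobian; constant because the system is linear

def f(x):
    return A @ x

# every point on the diagonal x1 == x2 is a fixed point
line = [s * np.array([1.0, 1.0]) for s in np.linspace(-1., 1., 5)]
speeds = [np.linalg.norm(f(pt)) for pt in line]

eigvals, eigvecs = np.linalg.eigh(A)   # A is symmetric; eigenvalues ascending
print("speeds on the line:", speeds)
print("eigenvalues:", eigvals)         # one negative, one (numerically) zero
print("zero-eigenvalue direction:", eigvecs[:, 1])
```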
class CANN1D(bp.dyn.NeuGroup):
    def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-bm.pi, z_max=bm.pi):
        super(CANN1D, self).__init__(size=num)
        # parameters
        self.tau = tau  # The synaptic time constant
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value
        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
        self.rho = num / self.z_range  # The neural density
        self.dx = self.z_range / num  # The stimulus density
        # variables
        self.u = bm.Variable(bm.zeros(num))
        self.input = bm.Variable(bm.zeros(num))
        # The connection matrix
        self.conn_mat = self.make_conn(self.x)
        # function
        self.integral = bp.odeint(self.derivative)
    def derivative(self, u, t, Iext):
        r1 = bm.square(u)
        r2 = 1.0 + self.k * bm.sum(r1)
        r = r1 / r2
        Irec = bm.dot(self.conn_mat, r)
        du = (-u + Irec + Iext) / self.tau
        return du
    def dist(self, d):
        d = bm.remainder(d, self.z_range)
        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
        return d
    def make_conn(self, x):
        assert bm.ndim(x) == 1
        x_left = bm.reshape(x, (-1, 1))
        x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
        d = self.dist(x_left - x_right)
        Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
        return Jxx
    def get_stimulus_by_pos(self, pos):
        return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))
    def update(self, tdi):
        self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
        self.input[:] = 0.
    def cell(self, u):
        return self.derivative(u, 0., 0.)
cann = CANN1D(num=512, k=0.1, A=30)
The following code demonstrates how to use SlowPointFinder to find fixed points of a continuous attractor neural network.
# initialize an instance of slow point finder
finder = bp.analysis.SlowPointFinder(
f_cell=cann,
target_vars={'u': cann.u},
dt=1.,
)
# we can initialize our candidate points with noisy bumps.
candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1)))
candidates += bm.random.normal(0., 0.01, candidates.shape)
# optimize to find fixed points
finder.find_fps_with_opt_solver({'u': candidates})
finder.filter_loss(1e-6)
finder.keep_unique()
Optimizing with BFGS to find fixed points:
Found 629 fixed points from 629 initial points.
Excluding fixed points with squared speed above tolerance 1e-06:
Kept 357/629 fixed points with tolerance under 1e-06.
Excluding non-unique fixed points:
Kept 357/357 unique fixed points with uniqueness tolerance 0.025.
The fixed points found form a line of attractors. We can visualize this line of attractors in a 2D space using PCA.
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points['u'])
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()

These fixed points can also be plotted in the feature space. Below, we plot a selection of them.
def visualize_fixed_points(fps, plot_ids=(0,), xs=None):
    for i in plot_ids:
        if xs is None:
            plt.plot(fps[i], label=f'FP-{i}')
        else:
            plt.plot(xs, fps[i], label=f'FP-{i}')
    plt.legend()
    plt.xlabel('Feature')
    plt.ylabel('Bump activity')
    plt.show()
visualize_fixed_points(finder.fixed_points['u'],
plot_ids=(10, 20, 30, 40, 50, 60, 70, 80),
xs=cann.x)

Next, let’s compute the linear part, i.e., the Jacobian matrix, around the fixed points. We decompose the Jacobian matrix and then visualize its stability.
from jax.tree_util import tree_map
# select the first ten fixed points
fps = tree_map(lambda a: a[:10], finder._fixed_points)
# compute jacobian and visualize the decomposed jacobian matrix
J = finder.compute_jacobians(fps, plot=True, num_col=2)

For more examples of dynamics analysis, such as analyzing the fixed points of a recurrent neural network, please see BrainPy Examples.
References#
[1] Sussillo, D., and O. Barak. “Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks.” Neural Computation 25.3 (2013): 626-649.
[2] Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.
Analysis of a Decision-making Model#
In this section, we use the low-dimensional analyzers to perform phase plane and bifurcation analyses of the decision-making model proposed by Wong & Wang [1].
Decision making model#
This model considers two excitatory neural assemblies, populations 1 and 2, that compete with each other through a shared pool of inhibitory neurons. In our analysis, we use the following model equations.
Let \(r_1\) and \(r_2\) be the firing rates of populations 1 and 2. The total synaptic input current \(I_i\) and the resulting firing rate \(r_i\) of neural population \(i\) obey the following input-output relationship (\(F - I\) curve):
\[r_i = F(I_i) = \frac{a I_i - b}{1 - \exp\left(-d\,(a I_i - b)\right)},\]
which captures the current-frequency function of a leaky integrate-and-fire neuron. The parameter values are \(a\) = 270 Hz/nA, \(b\) = 108 Hz, \(d\) = 0.154 sec.
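Assume that the synaptic drive variables \(S_1\) and \(S_2\) obey
\[\frac{dS_i}{dt} = -\frac{S_i}{\tau_s} + (1 - S_i)\,\gamma\, r_i,\]
where \(\gamma\) = 0.641. The net current into each population is given by
\[\begin{aligned}
I_1 &= J_E S_1 + J_I S_2 + I_0 + J_{ext} \mu_1, \\
I_2 &= J_E S_2 + J_I S_1 + I_0 + J_{ext} \mu_2.
\end{aligned}\]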
The synaptic time constant is \(\tau_s\) = 100 ms (NMDA time constant). The synaptic coupling strengths are \(J_E\) = 0.2609 nA, \(J_I\) = -0.0497 nA, and \(J_{ext}\) = 0.00052 nA, and the background current is \(I_0\) = 0.3255 nA. Stimulus-selective inputs to populations 1 and 2 are governed by unitless parameters \(\mu_1\) and \(\mu_2\), respectively.
For the decision-making paradigm, the input rates \(\mu_1\) and \(\mu_2\) are determined by the stimulus coherence \(c'\), which ranges between 0 (0%) and 1 (100%):
\[\mu_1 = \mu_0 (1 + c'), \qquad \mu_2 = \mu_0 (1 - c').\]
import brainpy as bp
import brainpy.math as bm
bp.math.enable_x64()
# bp.math.set_platform('cpu')
bp.__version__
'2.3.0'
Parameters#
gamma = 0.641 # Saturation factor for gating variable
tau = 0.1 # Synaptic time constant [sec]
a = 270. # Hz/nA
b = 108. # Hz
d = 0.154 # sec
I0 = 0.3255 # background current [nA]
JE = 0.2609 # self-coupling strength [nA]
JI = -0.0497 # cross-coupling strength [nA]
JAext = 0.00052 # Stimulus input strength [nA]
Ib = 0.3255 # The background input
Model implementation#
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
    I1 = JE * s1 + JI * s2 + Ib + JAext * mu * (1. + coh)
    r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
    return - s1 / tau + (1. - s1) * gamma * r1
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
    I2 = JE * s2 + JI * s1 + Ib + JAext * mu * (1. - coh)
    r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
    return - s2 / tau + (1. - s2) * gamma * r2
Phase plane analysis#
The advantage of the reduced model is that we can use phase plane analysis to understand which dynamical behaviors the model generates for a particular parameter set, and then explore how this behavior changes when the model parameters are varied (bifurcation analysis).
To this end, we will use the brainpy.analysis module.
We construct the phase portraits of the reduced model for different stimulus inputs (see Figure 4 and Figure 5 in (Wong & Wang, 2006) [1]).
No stimulus: \(\mu_0 =0\) Hz. In the absence of a stimulus, the two nullclines intersect with each other five times, producing five steady states, of which three are stable (attractors) and two are unstable.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 0.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 5 fixed points.
#1 s1=0.5669871605297269, s2=0.031891419715715866 is a stable node.
#2 s1=0.31384492489136057, s2=0.05578533347184539 is a saddle node.
#3 s1=0.1026514458219984, s2=0.10265095098914433 is a stable node.
#4 s1=0.05578534267632889, s2=0.3138449310808786 is a saddle node.
#5 s1=0.03189144636489119, s2=0.5669870352865433 is a stable node.
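The reported fixed points can be double-checked outside the analyzer. The NumPy sketch below re-implements the right-hand sides of int_s1 and int_s2 with the parameter values above (an independent re-implementation for checking, not BrainPy code), evaluates them at fixed point #1 for \(\mu_0 = 0\), and estimates the Jacobian with central finite differences; numerical_jacobian is a hypothetical helper. A near-zero residual and two negative real eigenvalues are consistent with the stable node reported above.

```python
import numpy as np

gamma, tau, a, b, d = 0.641, 0.1, 270., 108., 0.154
JE, JI, JAext, Ib = 0.2609, -0.0497, 0.00052, 0.3255

def F(I):
    # F-I curve r = (a I - b) / (1 - exp(-d (a I - b)))
    x = a * I - b
    return x / (1. - np.exp(-d * x))

def rhs(s, mu=0., coh=0.):
    s1, s2 = s
    I1 = JE * s1 + JI * s2 + Ib + JAext * mu * (1. + coh)
    I2 = JE * s2 + JI * s1 + Ib + JAext * mu * (1. - coh)
    return np.array([-s1 / tau + (1. - s1) * gamma * F(I1),
                     -s2 / tau + (1. - s2) * gamma * F(I2)])

def numerical_jacobian(fun, s, h=1e-6):
    # central finite differences, one column per state variable
    J = np.zeros((len(s), len(s)))
    for j in range(len(s)):
        e = np.zeros(len(s))
        e[j] = h
        J[:, j] = (fun(s + e) - fun(s - e)) / (2 * h)
    return J

s_star = np.array([0.5669871605297269, 0.031891419715715866])  # fixed point #1 above
residual = rhs(s_star)
eigvals = np.linalg.eigvals(numerical_jacobian(rhs, s_star))
print("residual:", residual)
print("eigenvalues:", eigvals)
```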

Symmetric stimulus: \(\mu_0=30\) Hz, \(c'=0\). When a stimulus is applied, the phase space of the model is reconfigured. The spontaneous state vanishes. At the same time, a saddle-type unstable steady state is created that separates the two asymmetrical attractors.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 0.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.658694232143127, s2=0.05180719943991283 is a stable node.
#2 s1=0.42445578984858384, s2=0.4244556283731401 is a saddle node.
#3 s1=0.05180717720080605, s2=0.6586942355713474 is a stable node.

Biased stimulus: \(\mu_0=30\) Hz, \(c' = 0.14\) (14 % coherence). The phase space changes when a weak motion stimulus is presented. The phase space is no longer symmetrical: the attractor state s1 (correct choice) has a larger basin of attraction than attractor s2.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 0.14},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.6679776124172938, s2=0.04583022226100692 is a stable node.
#2 s1=0.3845586078985544, s2=0.4536309035289816 is a saddle node.
#3 s1=0.059110032802350894, s2=0.6481046659437735 is a stable node.

Stimulus to one population only: \(\mu_0=30\) Hz, \(c'=1.\) (100 % coherence). When \(c'\) is sufficiently large, the saddle steady state annihilates with the less favored attractor, leaving only one choice attractor.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 1.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 s1=0.7092805209334904, s2=0.02396366304199462 is a stable node.

Bifurcation analysis#
To see how the phase portrait of the system changes when we change the stimulus current, we generate a bifurcation diagram for the reduced model. A bifurcation diagram shows the fixed points of the model as a function of a varying parameter.
Next, we generate bifurcation diagrams by varying different parameters.
Fix the coherence \(c'=0\), vary the stimulus strength \(\mu_0\). See Figure 10 in (Wong & Wang, 2006) [1].
analyzer = bp.analysis.Bifurcation2D(
model=[int_s1, int_s2],
target_vars={'s1': [0., 1.], 's2': [0., 1.]},
target_pars={'mu': [-30., 90.]},
pars_update={'coh': 0.},
resolutions={'mu': 0.2},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 30000 candidates
I am trying to filter out duplicate fixed points ...
Found 1744 fixed points.


Fix the stimulus strength \(\mu_0 = 30\) Hz, vary the coherence \(c'\).
analyzer = bp.analysis.Bifurcation2D(
model=[int_s1, int_s2],
target_vars={'s1': [0., 1.], 's2': [0., 1.]},
target_pars={'coh': [0., 1.]},
pars_update={'mu': 30.},
resolutions={'coh': 0.005},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 10000 candidates
I am trying to filter out duplicate fixed points ...
Found 474 fixed points.


References#
[1] Wong, K.-F. and Wang, X.-J. (2006). A recurrent network mechanism of time integration in perceptual decisions. J. Neurosci. 26, 1314-1328.
How do low-dimensional analyzers work?#
Dynamics analysis is essential in neurodynamics, because blind simulation of nonlinear systems often yields few insights or even misleading results. BrainPy provides good support for analyzing low-dimensional systems, however nonlinear they are. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:
phase plane analysis;
codimension 1 or codimension 2 bifurcation analysis;
bifurcation analysis of the fast-slow system.
With these tools, BrainPy helps you quickly probe the dynamical mechanisms of the systems you define.
import brainpy as bp
import brainpy.math as bm
# bp.math.set_platform('cpu')
bp.math.enable_x64() # It's better to enable x64 when performing analysis
import numpy as np
import matplotlib.pyplot as plt
In this section, we provide a basic tutorial on how brainpy.analysis.LowDimAnalyzer works.
Terminology#
Given the FitzHugh-Nagumo model below, we define an analyzer:
class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
    def __init__(self, method='exp_auto'):
        super(FitzHughNagumoModel, self).__init__()
        # parameters
        self.a = 0.7
        self.b = 0.8
        self.tau = 12.5
        # variables
        self.V = bm.Variable(bm.zeros(1))
        self.w = bm.Variable(bm.zeros(1))
        self.Iext = bm.Variable(bm.zeros(1))
        # functions
        def dV(V, t, w, Iext=0.):
            return V - V * V * V / 3 - w + Iext
        def dw(w, t, V, a=0.7, b=0.8):
            return (V + a - b * w) / self.tau
        self.int_V = bp.odeint(dV, method=method)
        self.int_w = bp.odeint(dw, method=method)
    def update(self, tdi):
        self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
        self.w.value = self.int_w(self.w, tdi.t, self.V, self.a, self.b, tdi.dt)
        self.Iext[:] = 0.
model = FitzHughNagumoModel()
analyzer = bp.analysis.PhasePlane2D(
[model.int_V, model.int_w],
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
resolutions={'V': 0.01, 'w': 0.01},
)
In this instance of brainpy.analysis.LowDimAnalyzer, we use the following terminology.
x_var and y_var are defined by the order of the user setting. If the user sets “target_vars” as “{‘V’: …, ‘w’: …}”, x_var and y_var will be “V” and “w” respectively. Otherwise, if “target_vars” = “{‘w’: …, ‘V’: …}”, x_var and y_var will be “w” and “V” respectively.
analyzer.x_var, analyzer.y_var
('V', 'w')
fx and fy are defined as the differential equations of x_var and y_var respectively, i.e., fx is
def dV(V, t, w, Iext=0.):
    return V - V * V * V / 3 - w + Iext
and fy is
def dw(w, t, V, a=0.7, b=0.8):
    return (V + a - b * w) / self.tau
analyzer.F_fx, analyzer.F_fy
(JITTransform(target=f2,
num_of_vars=0),
JITTransform(target=f2,
num_of_vars=0))
int_x and int_y are defined as the integral functions of the differential equations for x_var and y_var respectively.
analyzer.F_int_x, analyzer.F_int_y
(functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x00000268A2599E50>),
functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x00000268A2599EE0>))
x_by_y_in_fx and y_by_x_in_fx: they denote that x_var and y_var can be separated from each other in the “fx” nullcline function. Specifically, x_by_y_in_fx or y_by_x_in_fx denotes \(x = F(y)\) or \(y = F(x)\) according to the \(f_x=0\) equation. For example, in the above FitzHugh-Nagumo model, \(w\) can easily be expressed in terms of \(V\) when \(\mathrm{dV(V, t, w, I_{ext})} = 0\), i.e., y_by_x_in_fx is \(w= V - V ^3 / 3 + I_{ext}\).
Similarly, x_by_y_in_fy (\(x=F(y)\)) and y_by_x_in_fy (\(y=F(x)\)) denote that x_var and y_var can be separated from each other in the “fy” nullcline function. For example, in the above FitzHugh-Nagumo model, y_by_x_in_fy is \(w= \frac{V + a}{b}\), and x_by_y_in_fy is \(V= b w - a\).
x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy and y_by_x_in_fy can be set in the options argument.
Mechanism for 1D system analysis#
To understand the advantages and disadvantages of BrainPy’s analysis toolkit, it helps to know the basic mechanism by which brainpy.analysis works.
The automatic model analysis in BrainPy relies heavily on numerical optimization methods, including Brent’s method and the BFGS method. For example, for the one-dimensional system \(\frac{dx}{dt} = \mathrm{sin}(x) + I\) with the variable boundary [-10, 10], after the user sets the resolution to 0.001, we get the following evaluation points:
bp.math.arange(-10, 10, 0.001)
Array([-10. , -9.999, -9.998, ..., 9.997, 9.998, 9.999], dtype=float64)
Then, BrainPy selects the candidate intervals in which roots lie. Specifically, it finds all intervals \([x_1, x_2]\) where \(f(x_1) * f(x_2) \le 0\) for the 1D system \(\frac{dx}{dt} = f(x)\).
For example, the following two points have function values of opposite signs, so the interval between them is a candidate interval.
def plot_interval(x0, x1, f):
    xs = np.linspace(x0, x1, 100)
    plt.plot(xs, f(xs))
    plt.scatter([x0, x1], f(np.asarray([x0, x1])), edgecolors='r')
    plt.axhline(0)
    plt.show()
plot_interval(-0.001, 0.001, lambda x: np.sin(x))

According to the intermediate value theorem, a continuous \(f\) must have a root between \(x_1\) and \(x_2\) when \(f(x_1) * f(x_2) \le 0\).
Based on these candidate intervals, BrainPy uses Brent’s method to find the roots of \(f(x) = 0\). After obtaining each root, BrainPy uses automatic differentiation to evaluate its stability.
Overall, BrainPy’s analysis toolkit has both significant advantages and disadvantages.
Pros: because BrainPy uses numerical methods to find roots and evaluate their stability, it does not care how complex your function is. Therefore, it applies to general problems, including any 1D and 2D dynamical system, and some low-dimensional systems with \(\ge 3\) variables (see later sections). In particular, BrainPy’s analysis toolkit is highly useful when the mathematical equations are too complex to solve analytically (for an example, please refer to the tutorial Analysis of A Decision Making Model).
Cons: however, it is hard for the numerical methods used in BrainPy to find fixed points that exist only momentarily. Moreover, when resolution is small, the amount of computation becomes large, so users should choose suitable resolution settings.
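The 1D pipeline described above (sign-change intervals, root refinement, stability from the derivative) can be sketched in plain Python. Two substitutions are made on purpose and should be read as assumptions: bisection stands in for Brent’s method, and a central finite difference stands in for automatic differentiation. The function name find_fixed_points_1d is hypothetical and is not part of BrainPy.

```python
import math

def find_fixed_points_1d(f, lo, hi, resolution=0.001, tol=1e-12):
    """Sketch of 1D fixed point finding for dx/dt = f(x) on [lo, hi].

    Bisection replaces Brent's method, and a central finite difference
    replaces automatic differentiation when judging stability.
    """
    n = int(round((hi - lo) / resolution))
    xs = [lo + i * resolution for i in range(n + 1)]  # evaluation points
    results = []
    for a, b in zip(xs[:-1], xs[1:]):
        fa, fb = f(a), f(b)
        if fa * fb > 0:
            continue  # same sign at both ends: not a candidate interval
        while b - a > tol:  # shrink the bracket [a, b] around the root
            m = 0.5 * (a + b)
            fm = f(m)
            if fa * fm <= 0:
                b = m
            else:
                a, fa = m, fm
        root = 0.5 * (a + b)
        if results and abs(root - results[-1][0]) < resolution:
            continue  # a root lying on a grid point is bracketed twice
        h = 1e-6
        slope = (f(root + h) - f(root - h)) / (2 * h)  # approximates f'(x*)
        results.append((root, 'stable' if slope < 0 else 'unstable'))
    return results

# Sweeping I: the fixed points of dx/dt = sin(x) + I disappear once |I| > 1
counts = {I: len(find_fixed_points_1d(lambda x, I=I: math.sin(x) + I, -10, 10))
          for I in (0.0, 0.5, 1.5)}
```

For \(I = 0\), the roots are the multiples of \(\pi\) in [-10, 10]; the odd multiples come out stable because \(\cos(k\pi) < 0\) there.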
Mechanism for 2D system analysis#
plot_vector_field()
Plotting the vector field is simple: we just evaluate each differential equation at every point of the grid given by the resolution setting.
plot_nullcline()
Nullclines are evaluated with Brent’s method. To get all \((x, y)\) values that satisfy fx=0 (i.e., \(f_x(x, y) = 0\)), we first fix \(y=y_0\), then apply Brent’s method to get all \(x'\) that satisfy \(f_x(x', y_0) = 0\) (alternatively, we can fix \(x\) and solve for \(y\)). Therefore, Brent’s method is run many times, once for each \(y\) value given by the resolution setting.
plot_fixed_point()
Fixed point finding in BrainPy relies on the BFGS method. First, we define an auxiliary function \(L(x, t)\):
\[L(x, t) = \frac{1}{2} \sum_{i=1}^{d} f_i^2(x, t),\]
where \(f_i\) is the \(i\)-th differential equation of the \(d\)-dimensional system.
\(L(x, t)\) is always non-negative and equals 0 exactly at fixed points. We use BFGS optimization to find its local minima. Finally, we keep the minima whose losses are smaller than \(10^{-8}\) as fixed points.
The challenge of this method is how to choose the initial points for optimization, especially when the parameter resolutions are small. Generally, BrainPy provides four methods.
fx-nullcline: Choose the points in “fx” nullcline as the initial points for optimization.
fy-nullcline: Choose the points in “fy” nullcline as the initial points for optimization.
nullclines: Choose both the points in “fx” nullcline and “fy” nullcline as the initial points for optimization.
aux_rank: For a given set of parameters, we evaluate the loss function at each point according to the resolution setting, then choose the first num_rank (default 100) points with the smallest losses.
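The aux_rank strategy combined with the final loss filter can be sketched for the FitzHugh-Nagumo model with \(I_{ext} = 0\). This is plain Python under stated assumptions, not BrainPy’s implementation: the auxiliary function is \(L = (f_x^2 + f_y^2)/2\), the candidates are the num_rank grid points with the smallest \(L\), and each candidate is refined with Newton’s method on \(f = 0\), a deliberate swap for BFGS on \(L\). The grid resolution 0.1 and num_rank=10 are illustrative choices.

```python
import math

A, B, TAU, IEXT = 0.7, 0.8, 12.5, 0.0  # FitzHugh-Nagumo parameters from the text

def f(v, w):
    """Right-hand sides (fx, fy) of the FitzHugh-Nagumo model."""
    return v - v ** 3 / 3 - w + IEXT, (v + A - B * w) / TAU

def aux_loss(v, w):
    """Auxiliary function L = (fx^2 + fy^2) / 2, zero exactly at fixed points."""
    fx, fy = f(v, w)
    return 0.5 * (fx * fx + fy * fy)

def newton_refine(v, w, n_iter=50):
    """Refine one candidate by Newton's method on f = 0 (a stand-in for BFGS on L)."""
    for _ in range(n_iter):
        fx, fy = f(v, w)
        j11, j12, j21, j22 = 1 - v * v, -1.0, 1 / TAU, -B / TAU  # Jacobian entries
        det = j11 * j22 - j12 * j21
        dv = (-fx * j22 + j12 * fy) / det  # Cramer's rule for J @ [dv, dw] = -[fx, fy]
        dw = (-fy * j11 + j21 * fx) / det
        v, w = v + dv, w + dw
        if abs(v) > 1e6:
            return math.nan, math.nan  # diverged; the loss filter will drop it
    return v, w

def find_fixed_points(lo=-3.0, hi=3.0, resolution=0.1, num_rank=10, loss_tol=1e-8):
    n = int(round((hi - lo) / resolution)) + 1
    grid = [(lo + i * resolution, lo + j * resolution) for i in range(n) for j in range(n)]
    # "aux_rank": keep the num_rank grid points with the smallest auxiliary loss
    candidates = sorted(grid, key=lambda p: aux_loss(*p))[:num_rank]
    fixed_points = []
    for v0, w0 in candidates:
        v, w = newton_refine(v0, w0)
        if not aux_loss(v, w) < loss_tol:
            continue  # not a true fixed point (or NaN after divergence)
        if all(abs(v - pv) + abs(w - pw) > 1e-6 for pv, pw in fixed_points):
            fixed_points.append((v, w))  # skip duplicates
    return fixed_points

print(find_fixed_points())
```

With these parameters the model has a single fixed point near \(V \approx -1.1994\), \(w \approx -0.6243\), and every candidate converges to it, so duplicate filtering leaves one point.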
However, if users provide one of the functions x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy and y_by_x_in_fy, things become much simpler: the 2D system can be reduced to a 1D system, and the fixed points can then be found with Brent’s method alone.
For the given FitzHugh-Nagumo model, we can set
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars=dict(V=[-3, 3], w=[-3., 3.]),
target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
resolutions={'a': 0.01, 'Iext': 0.01},
options={bp.analysis.C.y_by_x_in_fy: (lambda V, a=0.7, b=0.8: (V + a) / b)}
)
analyzer.plot_bifurcation()
analyzer.show_figure()
I am making bifurcation analysis ...
I am trying to find fixed points by brentq optimization ...
I am trying to filter out duplicate fixed points ...
Found 5000 fixed points.


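The reduction enabled by y_by_x_in_fy can be sketched in plain Python: substituting \(w = (V + a)/b\) into \(f_x\) leaves a 1D function \(h(V)\) whose roots are the fixed points. This is an illustration, not BrainPy’s code; bisection stands in for Brent’s method, and it assumes \(b = 0.8\) as in the model, for which \(h'(V) = 1 - V^2 - 1/b < 0\), so a single sign change on [-3, 3] brackets the only fixed point.

```python
B, TAU = 0.8, 12.5  # b and tau of the FitzHugh-Nagumo model

def fx(V, w, Iext):
    return V - V ** 3 / 3 - w + Iext

def fy(V, w, a):
    return (V + a - B * w) / TAU

def y_by_x_in_fy(V, a):
    """fy = 0 solved for w: w = (V + a) / b."""
    return (V + a) / B

def fixed_point(a, Iext, lo=-3.0, hi=3.0, tol=1e-12):
    """Find the fixed point by bisection on the reduced 1D function h(V)."""
    def h(V):
        return fx(V, y_by_x_in_fy(V, a), Iext)  # fx restricted to the fy-nullcline
    h_lo = h(lo)
    if h_lo * h(hi) > 0:
        return None  # no sign change: no fixed point bracketed in [lo, hi]
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        h_mid = h(mid)
        if h_lo * h_mid <= 0:
            hi = mid
        else:
            lo, h_lo = mid, h_mid
    V = 0.5 * (lo + hi)
    return V, y_by_x_in_fy(V, a)

# A sweep over Iext at fixed a: one point of the bifurcation diagram per value
branch = [fixed_point(0.7, k / 10) for k in range(11)]
```

Each parameter value now costs one 1D root search instead of a 2D optimization, which is why supplying these functions speeds up the analysis.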
References#
[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.
[2] Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.
[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.
Interoperation with other JAX frameworks#
BrainPy is designed to be easily interoperated with other JAX frameworks.
import jax
import brainpy as bp
# math library of BrainPy, JAX, NumPy
import brainpy.math as bm
import jax.numpy as jnp
import numpy as np
1. Data are exchangeable among different frameworks.#
This is possible because an Array can be directly converted to a JAX ndarray or a NumPy ndarray.
Convert an Array into a JAX ndarray.
b = bm.random.randint(10, size=5)
# Array.value is a JAX DeviceArray
b.value
DeviceArray([9, 9, 0, 4, 7], dtype=int32)
Convert an Array into a NumPy ndarray.
# Array can be easily converted to a numpy ndarray
np.asarray(b)
array([9, 9, 0, 4, 7])
Convert a NumPy ndarray into an Array.
bm.asarray(np.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
Convert a JAX ndarray into an Array.
bm.asarray(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
bm.Array(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
2. Transformations in brainpy.math also work on functions.#
APIs in other JAX frameworks can be naturally integrated in BrainPy. Let’s take the gradient-based optimization library Optax as an example to illustrate how to use other JAX frameworks in BrainPy.
import optax
# First create several useful functions.
network = jax.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))
optimizer = optax.adam(learning_rate=1e-1)
def compute_loss(params, x, y):
    y_pred = network(params, x)
    loss = bm.mean(optax.l2_loss(y_pred, y))
    return loss
@bm.jit
def train(params, opt_state, xs, ys):
    grads = bm.grad(compute_loss)(params, xs.value, ys)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
# Generate some data
bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer
params = bm.array([0.0, 0.0])
opt_state = optimizer.init(params)
# A simple update loop
for _ in range(1000):
    params, opt_state = train(params, opt_state, xs, ys)
assert bm.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'
3. Other JAX frameworks can be integrated into a BrainPy program.#
In this example, we use Flax, a library for deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN model into an RNN model defined with BrainPy’s syntax.
Here, we first use flax to define a CNN network.
from flax import linen as nn
class CNN(nn.Module):
    """A CNN model implemented by using Flax."""
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        return x
Then, we define an RNN model by using our BrainPy interface.
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
class Network(bp.dyn.DynamicalSystem):
    """A network model implemented by BrainPy"""
    def __init__(self):
        super(Network, self).__init__()
        # cnn and its parameters
        self.cnn = CNN()
        rng = bm.random.DEFAULT.split_key()
        params = self.cnn.init(rng, jnp.ones([1, 4, 28, 1]))['params']
        leaves, self.tree = tree_flatten(params)
        self.implicit_vars.update(tree_map(bm.TrainVar, leaves))
        # rnn
        self.rnn = bp.layers.GRU(256, 100)
        # readout
        self.linear = bp.layers.Dense(100, 10)
    def update(self, sha, x):
        params = tree_unflatten(self.tree, [v.value for v in self.implicit_vars.values()])
        x = self.cnn.apply({'params': params}, bm.as_jax(x))
        x = self.rnn(sha, x)
        x = self.linear(sha, x)
        return x
We initialize the network, optimizer, loss function, and BP trainer.
net = Network()
opt = bp.optim.Momentum(0.1)
def loss_func(predictions, targets):
    logits = bm.max(predictions, axis=1)
    loss = bp.losses.cross_entropy_loss(logits, targets)
    accuracy = bm.mean(bm.argmax(logits, -1) == targets)
    return loss, {'accuracy': accuracy}
trainer = bp.train.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)
Next, we load the MNIST dataset.
train_dataset = bp.datasets.MNIST(r'D:\data\mnist', train=True, download=True)
X = train_dataset.data.reshape((-1, 7, 4, 28, 1)) / 255
Y = train_dataset.targets
Finally, we train the model with the BPTT.fit() function.
trainer.fit([X, Y], batch_size=256, num_epoch=10)
Train 100 steps, use 32.5824 s, train loss 0.96465, accuracy 0.66015625
Train 200 steps, use 30.9035 s, train loss 0.38974, accuracy 0.89453125
Train 300 steps, use 33.1075 s, train loss 0.31525, accuracy 0.890625
Train 400 steps, use 31.4062 s, train loss 0.23846, accuracy 0.91015625
Train 500 steps, use 32.3371 s, train loss 0.21995, accuracy 0.9296875
Train 600 steps, use 32.5692 s, train loss 0.20885, accuracy 0.92578125
Train 700 steps, use 33.0139 s, train loss 0.24748, accuracy 0.90625
Train 800 steps, use 31.9635 s, train loss 0.14563, accuracy 0.953125
Train 900 steps, use 31.8845 s, train loss 0.17017, accuracy 0.94140625
Train 1000 steps, use 32.0537 s, train loss 0.09413, accuracy 0.95703125
Train 1100 steps, use 32.3714 s, train loss 0.06015, accuracy 0.984375
Train 1200 steps, use 31.6957 s, train loss 0.12061, accuracy 0.94921875
Train 1300 steps, use 31.8346 s, train loss 0.13908, accuracy 0.953125
Train 1400 steps, use 31.5252 s, train loss 0.10718, accuracy 0.953125
Train 1500 steps, use 31.7274 s, train loss 0.07869, accuracy 0.96875
Train 1600 steps, use 32.3928 s, train loss 0.08295, accuracy 0.96875
Train 1700 steps, use 31.7718 s, train loss 0.07569, accuracy 0.96484375
Train 1800 steps, use 31.9243 s, train loss 0.08607, accuracy 0.9609375
Train 1900 steps, use 32.2454 s, train loss 0.04332, accuracy 0.984375
Train 2000 steps, use 31.6231 s, train loss 0.02369, accuracy 0.9921875
Train 2100 steps, use 31.7800 s, train loss 0.03862, accuracy 0.9765625
Train 2200 steps, use 31.5431 s, train loss 0.01871, accuracy 0.9921875
Train 2300 steps, use 32.1064 s, train loss 0.03255, accuracy 0.9921875
Numerical Solvers for Ordinary Differential Equations#
The brain modeling toolkit in BrainPy focuses on differential equations, and solving them is the essence of neurodynamics simulation. Exact analytical solutions are available only for simple, low-order differential equations. For coupled, high-dimensional, nonlinear brain dynamical systems, we need to resort to numerical methods.
This section illustrates how to define ordinary differential equations (ODEs) and their numerical integration methods in BrainPy.
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
bm.set_platform('cpu')
bp.__version__
'2.3.0'
How to define ODE functions?#
BrainPy provides a convenient and intuitive way to define ODE systems. For the ODEs
\[\frac{dx}{dt} = f_1(x, t, y, p_1), \qquad \frac{dy}{dt} = g_1(y, t, x, p_2),\]
we can define them in a Python function:
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy
where t denotes the current time, x and y passed before t denote the dynamical variables, and p1 and p2 after t denote the parameters needed in this system. In the function body, the derivatives f1 and g1 can be customized as needed. Finally, the corresponding derivatives dx and dy are returned in the same order as the variables in the function arguments.
Each variable can be a scalar (var_type = bp.integrators.SCALAR_VAR), a vector/matrix (var_type = bp.integrators.POP_VAR), or a system (var_type = bp.integrators.SYSTEM_VAR). “System” means that the argument x denotes an array of variables. Taking the above example again, we can redefine it as:
def diff(xy, t, p1, p2):
    x, y = xy
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return bm.array([dx, dy])
How to define the numerical integration for ODEs?#
Once the ODE functions are defined, defining their numerical integration is easy: we just put the decorator bp.odeint above the ODE function.
@bp.odeint
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy
After being wrapped by bp.odeint, the function becomes an instance of ODEIntegrator.
isinstance(diff, bp.ode.ODEIntegrator)
True
bp.odeint receives several arguments:
“method”: A string, used to specify the numerical methods to integrate the ODE functions. The default method is Euler.
diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x15abadeedc0>
“dt”: A float, used to set the default numerical step size. The default “dt” is 0.1.
diff.dt
0.1
“show_code”: bool, indicating whether to show the generated numerical integration code. Let’s take the Euler and RK4 methods as examples.
@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy
diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
    dx_k1, dy_k1 = f(x, y, t, p1, p2)
    x_new = x + dx_k1 * dt * 1
    y_new = y + dy_k1 * dt * 1
    return x_new, y_new
{'f': <function diff at 0x0000015ABADFBF70>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x15abae03550>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy
diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
    dx_k1, dy_k1 = f(x, y, t, p1, p2)
    k2_x_arg = x + dt * dx_k1 * 0.5
    k2_y_arg = y + dt * dy_k1 * 0.5
    k2_t_arg = t + dt * 0.5
    dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
    k3_x_arg = x + dt * dx_k2 * 0.5
    k3_y_arg = y + dt * dy_k2 * 0.5
    k3_t_arg = t + dt * 0.5
    dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
    k4_x_arg = x + dt * dx_k3
    k4_y_arg = y + dt * dy_k3
    k4_t_arg = t + dt
    dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
    x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
    y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
    return x_new, y_new
{'f': <function diff at 0x0000015ABAE023A0>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abae03d30>
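To make the generated Euler code above less opaque, here is a plain-Python sketch of what such an integrator does with a derivative function. euler_odeint is a hypothetical helper, not part of BrainPy; following the same convention as bp.odeint, it treats the arguments before t as variables and those after t as parameters, and it accepts an optional dt keyword like the generated function.

```python
import inspect

def euler_odeint(dt=0.1):
    """Hypothetical decorator (not BrainPy's API) that builds an Euler stepper."""
    def decorator(f):
        names = list(inspect.signature(f).parameters)
        n_vars = names.index('t')  # arguments before `t` are dynamical variables

        def step(*args, dt=dt):
            derivs = f(*args)
            if n_vars == 1:
                derivs = (derivs,)  # single variable: f returns a bare derivative
            new = tuple(x + d * dt for x, d in zip(args[:n_vars], derivs))
            return new[0] if n_vars == 1 else new
        return step
    return decorator

@euler_odeint(dt=0.01)
def diff(x, y, t, p1, p2):
    dx = -p1 * x
    dy = -p2 * y
    return dx, dy

x, y = 1.0, 1.0
for i in range(100):
    x, y = diff(x, y, i * 0.01, 1.0, 2.0)  # advance from t = i * dt to t + dt
```

Here f1 and g1 are replaced by the illustrative linear decays \(-p_1 x\) and \(-p_2 y\), so after 100 steps x equals \(0.99^{100}\) up to rounding.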
Two Examples#
Example 1: FitzHugh–Nagumo model#
Now, let’s take the well-known FitzHugh–Nagumo (FHN) model as an example to illustrate how to define ODE solvers for brain modeling. The FHN model has two dynamical variables, which are governed by the following equations:
\[\frac{dV}{dt} = V - \frac{V^3}{3} - w + I_{ext}, \qquad \tau \frac{dw}{dt} = V + a - b w.\]
For this FHN model, we can code it in BrainPy like this:
@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
    dw = (V + a - b * w) / tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw
After defining the numerical solver, we can easily solve the ODE system over a given time interval. For example, for the given parameters,
a = 0.7; b = 0.8; tau = 12.5; Iext = 1.
the solution of the FHN model between 0 and 100 ms can be approximated by
hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
    V, w = integral(V, w, t, Iext, a, b, tau)
    hist_V.append(V)
plt.plot(hist_times, hist_V)
plt.show()

Such a manual Python loop is usually slow. BrainPy therefore provides a structured runner for integrators, brainpy.integrators.IntegratorRunner, which benefits from JIT compilation.
runner = bp.IntegratorRunner(
integral,
monitors=['V'],
inits=dict(V=0., w=0.),
args=dict(a=a, b=b, tau=tau, Iext=Iext),
dt=0.01
)
runner.run(100.)
plt.plot(runner.mon.ts, runner.mon.V)
plt.show()

Example 2: Hodgkin–Huxley model#
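Another, more complex example is the classical Hodgkin–Huxley (HH) neuron model. In the HH model, four dynamical variables (V, m, n, h) are used to model the initiation and propagation of the action potential. Specifically, they are governed by the following equations:
\[C \frac{dV}{dt} = -g_{Na} m^3 h (V - E_{Na}) - g_K n^4 (V - E_K) - g_L (V - E_L) + I_{ext},\]
\[\frac{dx}{dt} = \alpha_x(V)(1 - x) - \beta_x(V)\, x, \quad x \in \{m, h, n\},\]
with the rate functions
\[\alpha_m(V) = \frac{0.1 (V + 40)}{1 - e^{-(V + 40)/10}}, \quad \beta_m(V) = 4 e^{-(V + 65)/18},\]
\[\alpha_h(V) = 0.07 e^{-(V + 65)/20}, \quad \beta_h(V) = \frac{1}{1 + e^{-(V + 35)/10}},\]
\[\alpha_n(V) = \frac{0.01 (V + 55)}{1 - e^{-(V + 55)/10}}, \quad \beta_n(V) = 0.125 e^{-(V + 65)/80}.\]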
In BrainPy, such dynamical system can be coded as:
@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C
    return dVdt, dmdt, dhdt, dndt
As with the FHN model, we can integrate the HH model with the given parameters over a given time interval:
Iext = 10.; ENa = 50.; EK = -77.; EL = -54.387
C = 1.0; gNa = 120.; gK = 36.; gL = 0.03
runner = bp.IntegratorRunner(
integral,
monitors=list('Vmhn'),
inits=[0., 0., 0., 0.],
args=dict(Iext=Iext, gNa=gNa, ENa=ENa, gK=gK, EK=EK, gL=gL, EL=EL, C=C),
dt=0.01
)
runner.run(100.)
plt.subplot(211)
plt.plot(runner.mon.ts, runner.mon.V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(runner.mon.ts, runner.mon.m, label='m')
plt.plot(runner.mon.ts, runner.mon.h, label='h')
plt.plot(runner.mon.ts, runner.mon.n, label='n')
plt.legend()
plt.show()

Provided ODE Numerical Solvers#
BrainPy provides several types of numerical methods for ODEs, including explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler methods.
1. Explicit Runge-Kutta (RK) methods for ODEs#
The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are listed in the following table:
Methods | Keywords |
---|---|
Euler | euler |
Midpoint | midpoint |
Heun’s second-order method | heun2 |
Ralston’s second-order method | ralston2 |
RK2 | rk2 |
RK3 | rk3 |
RK4 | rk4 |
Heun’s third-order method | heun3 |
Ralston’s third-order method | ralston3 |
Third-order Strong Stability Preserving Runge-Kutta | ssprk3 |
Ralston’s fourth-order method | ralston4 |
Runge-Kutta 3/8-rule fourth-order method | rk4_38rule |
Users can utilize these methods by specifying the method option in brainpy.odeint() with the corresponding keyword. For example:
@bp.odeint(method='rk4')
def int_v(v, t, p):
    # do something
    return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abc88be50>
Alternatively, you can directly instantiate your favorite integrator:
@bp.ode.RK4
def int_v(v, t, p):
    # do something
    return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abca4daf0>
def derivative(v, t, p):
    # do something
    return v
int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abc88bca0>
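The explicit methods in the table above differ only in their Butcher tableau coefficients. The plain-Python sketch below (not BrainPy’s internal representation) encodes three standard tableaux, Euler, the explicit midpoint method and classical RK4, whose stages match the generated RK4 code shown earlier, and compares their global errors on the illustrative test problem \(dy/dt = -y\) with \(dt = 0.1\).

```python
import math

# Butcher tableaux (c, A, b) of three methods from the table above
TABLEAUX = {
    'euler': ([0.0], [[]], [1.0]),
    'midpoint': ([0.0, 0.5], [[], [0.5]], [0.0, 1.0]),
    'rk4': ([0.0, 0.5, 0.5, 1.0],
            [[], [0.5], [0.0, 0.5], [0.0, 0.0, 1.0]],
            [1 / 6, 1 / 3, 1 / 3, 1 / 6]),
}

def rk_step(f, y, t, dt, method):
    """One explicit Runge-Kutta step for dy/dt = f(y, t) using a Butcher tableau."""
    c, a, b = TABLEAUX[method]
    ks = []
    for ci, ai in zip(c, a):
        # each stage uses only the slopes of earlier stages (explicit method)
        yi = y + dt * sum(aij * kj for aij, kj in zip(ai, ks))
        ks.append(f(yi, t + ci * dt))
    return y + dt * sum(bi * ki for bi, ki in zip(b, ks))

def integrate(f, y0, t_end, dt, method):
    y, t = y0, 0.0
    for _ in range(int(round(t_end / dt))):
        y = rk_step(f, y, t, dt, method)
        t += dt
    return y

exact = math.exp(-1.0)
errors = {m: abs(integrate(lambda y, t: -y, 1.0, 1.0, 0.1, m) - exact) for m in TABLEAUX}
print(errors)
```

The errors shrink with the order of the method (roughly \(2 \times 10^{-2}\), \(7 \times 10^{-4}\) and \(3 \times 10^{-7}\)), which is the accuracy side of the trade-offs mentioned above; the cost is one derivative evaluation per stage.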
2. Adaptive Runge-Kutta (RK) methods for ODEs#
The second category of ODE numerical support is the adaptive RK methods. Unlike the explicit RK methods, adaptive methods produce an estimate of the local truncation error within a single Runge-Kutta step, and this estimate is used to control the step size adaptively. Specifically, if \(error > tol\), \(dt\) is replaced with a smaller \(dt_{new}\) and the step is repeated. Adaptive RK methods therefore allow a varying step size. BrainPy provides the following adaptive RK methods:
Methods | Keywords |
---|---|
Runge–Kutta–Fehlberg 4(5) | rkf45 |
Runge–Kutta–Fehlberg 1(2) | rkf12 |
Dormand–Prince method | rkdp |
Cash–Karp method | ck |
Bogacki–Shampine method | bs |
Heun–Euler method | heun_euler |
By default, the above methods are not adaptive unless users pass the keyword adaptive=True to brainpy.odeint(). When adaptive RK methods are used for numerical integration, the instantaneously adjusted step size dt is appended to the function arguments. Moreover, the tolerance tol for step-size adjustment can also be modified. Let’s take the Lorenz system as an example:
# adaptively adjust step-size
@bm.jit
@bp.odeint(method='rkf45',
           adaptive=True,  # activate the "adaptive" option
           tol=0.001)      # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
    # one more argument "dt" must be provided when using an adaptive RK method
    x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)
    hist_x.append(x.value)
    hist_y.append(y.value)
    hist_z.append(z.value)
    hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()


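The error-estimate-and-retry loop described at the start of this subsection can be sketched with the simplest pair from the table, Heun–Euler: the difference between the first-order Euler result and the second-order Heun result estimates the local error. This is plain Python, not BrainPy’s implementation; the step-size rule \(dt \cdot 0.9\sqrt{tol/err}\) (capped at doubling) is a common textbook choice assumed here, and the test problem \(dy/dt = -y\) is illustrative.

```python
import math

def heun_euler_adaptive(f, y0, t0, t_end, dt=0.5, tol=1e-4, safety=0.9):
    """Integrate dy/dt = f(y, t) with an adaptive Heun-Euler (order 2(1)) pair."""
    t, y = t0, y0
    accepted_dts = []
    while t < t_end - 1e-12:
        dt = min(dt, t_end - t)              # never step past t_end
        k1 = f(y, t)
        k2 = f(y + dt * k1, t + dt)
        y_low = y + dt * k1                  # Euler, first order
        y_high = y + dt * 0.5 * (k1 + k2)    # Heun, second order
        err = abs(y_high - y_low)            # local truncation error estimate
        if err > tol:
            dt *= safety * math.sqrt(tol / err)  # reject: shrink dt, repeat the step
            continue
        t, y = t + dt, y_high                # accept the higher-order result
        accepted_dts.append(dt)
        # grow (at most 2x) or shrink dt for the next step
        dt *= min(2.0, safety * math.sqrt(tol / err)) if err > 0 else 2.0
    return y, accepted_dts

y_end, dts = heun_euler_adaptive(lambda y, t: -y, 1.0, 0.0, 1.0)
print(y_end, len(dts), min(dts), max(dts))
```

The initial guess dt = 0.5 is rejected and shrunk, and later steps grow as the solution decays and the error estimate falls, which is the varying step size shown in the "Adaptive dt" plot.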
3. Exponential Euler methods for ODEs#
Finally, BrainPy provides exponential integrators for ODEs. For your ODE systems, we highly recommend the exponential Euler method. The exponential Euler method in BrainPy uses automatic differentiation to find the linear part of the equation.
Methods | Keywords |
---|---|
Exponential Euler | exp_euler |
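Let’s take a linear system as the theoretical demonstration,
\[\frac{dy}{dt} = A - By,\]
for which the exponential Euler schema is given by:
\[y(t + dt) = y(t)\, e^{-B\,dt} + \frac{A}{B}\left(1 - e^{-B\,dt}\right).\]
As you can see, for such linear systems with constant \(A\) and \(B\), the exponential Euler schema coincides with the exact solution.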
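The claim can be checked numerically for \(dy/dt = A - By\), whose exponential Euler update is \(y(t+dt) = y(t)e^{-B\,dt} + (A/B)(1 - e^{-B\,dt})\). The constants A = 2, B = 5 and the step dt = 0.5 are hypothetical choices; in this sketch B is known in closed form, whereas BrainPy obtains the linear coefficient by automatic differentiation. Because \(B\,dt = 2.5 > 2\), forward Euler is outside its stability region and diverges, while exponential Euler stays on the exact solution.

```python
import math

A, B = 2.0, 5.0  # hypothetical constants in dy/dt = A - B * y

def exact(y0, t):
    """Closed-form solution of dy/dt = A - B y."""
    return A / B + (y0 - A / B) * math.exp(-B * t)

def euler_step(y, dt):
    return y + dt * (A - B * y)

def exp_euler_step(y, dt):
    # y(t + dt) = y(t) e^{-B dt} + (A / B) (1 - e^{-B dt})
    decay = math.exp(-B * dt)
    return y * decay + A / B * (1 - decay)

def simulate(step, y0, dt, n_steps):
    y = y0
    for _ in range(n_steps):
        y = step(y, dt)
    return y

# dt = 0.5 gives B * dt = 2.5 > 2, outside the stability region of forward Euler
y_true = exact(0.0, 10.0)
y_euler = simulate(euler_step, 0.0, 0.5, 20)
y_exp = simulate(exp_euler_step, 0.0, 0.5, 20)
print(y_true, y_euler, y_exp)
```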
However, the exponential Euler method requires each derivative function to be written separately; otherwise, automatic differentiation will produce wrong results.
Interestingly, the computationally expensive Hodgkin–Huxley neuron model is a linear-like ODE system. With the exponential Euler method, the numerical step can be greatly enlarged to save computation time.
Iext=10.; ENa=50.; EK=-77.; EL=-54.387
C=1.0; gNa=120.; gK=36.; gL=0.03
def dm(m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    return dmdt
def dh(h, t, V):
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    return dhdt
def dn(n, t, V):
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    return dndt
def dV(V, t, m, h, n, Iext):
    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C
    return dVdt
Although we define the HH differential equations as separate functions, we can integrate them jointly with brainpy.JointEq.
hh_derivative = bp.JointEq([dV, dm, dh, dn])
def run(method, Iext=10., dt=0.1):
    integral = bp.odeint(hh_derivative, method=method)
    runner = bp.IntegratorRunner(
        integral,
        monitors=list('Vmhn'),
        inits=[0., 0., 0., 0.],
        args=dict(Iext=Iext),
        dt=dt
    )
    runner.run(100.)
    plt.subplot(211)
    plt.plot(runner.mon.ts, runner.mon.V, label='V')
    plt.legend()
    plt.subplot(212)
    plt.plot(runner.mon.ts, runner.mon.m, label='m')
    plt.plot(runner.mon.ts, runner.mon.h, label='h')
    plt.plot(runner.mon.ts, runner.mon.n, label='n')
    plt.legend()
    plt.show()
Euler method: fails to complete the integration when the time step is slightly larger.
run('euler', Iext=10, dt=0.02)

run('euler', Iext=10, dt=0.1)

RK4 method: better than the Euler method, but still requires a small time step.
run('rk4', Iext=10, dt=0.1)

run('rk4', Iext=10, dt=0.2)

Exponential Euler method: allows a larger time step and generates accurate results.
run('exp_euler', Iext=10, dt=0.2)

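Part of the reason the HH model is called linear-like is visible in a single gate: with V held fixed, \(dm/dt = \alpha_m - (\alpha_m + \beta_m)m\) is linear in m, so exponential Euler relaxes m exactly toward \(m_\infty = \alpha_m/(\alpha_m + \beta_m)\). The sketch below uses the same rate formulas as dm above; freezing V at -65 mV and using dt = 1.0 are illustrative assumptions (in the full model V changes every step). At this voltage \(\alpha_m + \beta_m \approx 4.2\), so forward Euler on the gate is unstable once dt exceeds about 0.47.

```python
import math

def m_rates(V):
    """alpha_m and beta_m of the HH sodium activation gate (same formulas as dm above)."""
    alpha = 0.1 * (V + 40) / (1 - math.exp(-(V + 40) / 10))
    beta = 4.0 * math.exp(-(V + 65) / 18)
    return alpha, beta

def euler_gate(m, V, dt):
    alpha, beta = m_rates(V)
    return m + dt * (alpha * (1 - m) - beta * m)

def exp_euler_gate(m, V, dt):
    # With V fixed, dm/dt = alpha - (alpha + beta) m is linear in m, so m relaxes
    # exponentially toward m_inf with time constant 1 / (alpha + beta).
    alpha, beta = m_rates(V)
    m_inf = alpha / (alpha + beta)
    return m_inf + (m - m_inf) * math.exp(-(alpha + beta) * dt)

V, dt = -65.0, 1.0
alpha, beta = m_rates(V)
m_inf = alpha / (alpha + beta)
m_euler, m_exp = 0.0, 0.0
for _ in range(20):
    m_euler = euler_gate(m_euler, V, dt)
    m_exp = exp_euler_gate(m_exp, V, dt)
print(m_inf, m_euler, m_exp)
```

Each exponential Euler update is a weighted average of the old m and \(m_\infty\), so the gate stays within [0, 1] whatever the step size.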
Numerical Solvers for Stochastic Differential Equations#
BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and the exponential Euler method for SDE numerical integration.
import brainpy as bp
bp.__version__
'2.3.0'
import matplotlib.pyplot as plt
How to define SDE functions?#
A one-dimensional stochastic differential equation (SDE) with scalar Wiener noise is given by
\[ dX_t = f(X_t, t)\,dt + g(X_t, t)\,dW_t, \qquad (1) \]
where \(X_t = X(t)\) is the realization of a stochastic process or random variable, \(f(X_t, t)\) is the drift coefficient, \(g(X_t, t)\) is the diffusion coefficient, and the stochastic process \(W_t\) is the Wiener process.
For this SDE system, we can define two Python functions \(f\) and \(g\) to represent it.
def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df
As with ODE functions, the arguments before \(t\) denote the random variables, while the arguments after \(t\) are the parameters. For an SDE with scalar noise, the returned \(df\) and \(dg\) must have the same size, for example \(df \in R^d, dg \in R^d\).
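To make the roles of \(f\) and \(g\) concrete, here is a minimal Euler-Maruyama sketch in plain Python (the simplest Ito scheme, used here only for illustration, not BrainPy's generated code). It follows the same calling convention as f_part and g_part above: the variable first, then t, then shared parameters. The Ornstein-Uhlenbeck drift and diffusion and all function names are hypothetical examples.

```python
import math
import random

def euler_maruyama(f, g, x0, dt, n_steps, params=(), t0=0.0, seed=0):
    """Ito Euler-Maruyama for dX = f(X, t) dt + g(X, t) dW (scalar noise)."""
    rng = random.Random(seed)
    x, t = x0, t0
    xs = [x]
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(dt))  # Wiener increment ~ N(0, dt)
        x = x + f(x, t, *params) * dt + g(x, t, *params) * dW
        t += dt
        xs.append(x)
    return xs

# Hypothetical example: an Ornstein-Uhlenbeck process
def f_part(x, t, theta, sigma):
    return -theta * x   # drift pulls x back toward 0

def g_part(x, t, theta, sigma):
    return sigma        # constant (additive) noise

path = euler_maruyama(f_part, g_part, x0=1.0, dt=0.01, n_steps=1000,
                      params=(2.0, 0.3))
```

With sigma = 0 the noise term vanishes and the loop reduces to the ordinary forward Euler method.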
A more general SDE system is usually driven by a multi-dimensional Wiener process:
\[ dX_t = f(X_t, t)\,dt + \sum_{\alpha=1}^{m} g_{\alpha}(X_t, t)\,dW_t^{\alpha}. \]
For such a system with \(m\)-dimensional noise, the coding scheme is the same as for scalar noise, except that \(dg\) has one more dimension. For example, \(df \in R^{d}, dg \in R^{m \times d}\).
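The shape convention can be illustrated with a single Euler-Maruyama step in plain Python: \(dg\) has one row per noise source, and each of the \(m\) independent increments \(dW^{\alpha}\) multiplies its own row. This sketch is for illustration only and says nothing about how BrainPy stores or evaluates \(dg\); the drift, diffusion and function names are hypothetical.

```python
import math
import random

def em_step_vector(f, g, x, t, dt, dW):
    """One Euler-Maruyama step with an m-dimensional Wiener process.

    x  : list of d floats (the state)
    f  : returns d floats (drift)
    g  : returns an m x d nested list, one row of d coefficients per noise source
    dW : list of m independent Wiener increments, each ~ N(0, dt)
    """
    drift = f(x, t)
    diff = g(x, t)
    m, d = len(diff), len(x)
    assert len(dW) == m and all(len(row) == d for row in diff)
    return [x[i] + drift[i] * dt + sum(diff[a][i] * dW[a] for a in range(m))
            for i in range(d)]

# Hypothetical 2-D state (d = 2) driven by 3 noise sources (m = 3)
f = lambda x, t: [-x[0], -x[1]]
g = lambda x, t: [[0.1, 0.0], [0.0, 0.1], [0.05, 0.05]]
rng = random.Random(0)
dt = 0.01
dW = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(3)]
x_next = em_step_vector(f, g, [1.0, 2.0], 0.0, dt, dW)
```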
How to define the numerical integration for SDEs?#
Before numerically integrating SDE functions, we should distinguish two kinds of stochastic integrals. Integrating system (1) gives
\[ X_t = X_{t_0} + \int_{t_0}^{t} f(X_s, s)\,ds + \int_{t_0}^{t} g(X_s, s)\,dW_s. \qquad (2) \]
In the 1940s, the Japanese mathematician K. Ito defined the type of integral in (2), called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol “\(\circ\)” to distinguish it from the Ito integral:
\[ X_t = X_{t_0} + \int_{t_0}^{t} f(X_s, s)\,ds + \int_{t_0}^{t} g(X_s, s) \circ dW_s. \qquad (3) \]
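The Ito integral (2) and the Stratonovich integral (3) differ only in the second integral term, which can be written in the general form
\[ \int_{t_0}^{t} g(X_s, s)\,dW_s = \lim_{h \to 0} \sum_{k=0}^{N-1} g(X_{\tau_k}, \tau_k)\big(W_{t_{k+1}} - W_{t_k}\big), \qquad \tau_k = (1-\lambda)\,t_k + \lambda\,t_{k+1}, \]
where \(t_0 < t_1 < \dots < t_N = t\) partitions \([t_0, t]\) with step \(h\) and \(0 \le \lambda \le 1\).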
In the stochastic integral of the Ito SDE, \(\lambda=0\), thus \(\tau_k=t_k\);
In the definition of the Stratonovich integral, \(\lambda=0.5\), thus \(\tau_k=(t_{k+1} + t_{k}) / 2\).
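The choice of \(\lambda\) matters even for the simplest integrand \(g = W\). The plain-Python sketch below simulates one Brownian path and evaluates \(\int_0^T W\,dW\) with \(\tau_k = t_k\) (Ito) and with the midpoint \(\tau_k\) (Stratonovich). The two sums approach the known limits \((W_T^2 - T)/2\) and \(W_T^2/2\), so they differ by about \(T/2\) however fine the partition is. The function name and parameter values are illustrative.

```python
import math
import random

def w_dw_integrals(T=1.0, n=10_000, seed=0):
    # Brownian path on a grid twice as fine as the partition t_0 < ... < t_n,
    # so W is also known at the midpoints (t_k + t_{k+1}) / 2
    rng = random.Random(seed)
    h = T / n
    W = [0.0]
    for _ in range(2 * n):
        W.append(W[-1] + rng.gauss(0.0, math.sqrt(h / 2)))
    ito = strat = 0.0
    for k in range(n):
        dW = W[2 * k + 2] - W[2 * k]      # W(t_{k+1}) - W(t_k)
        ito += W[2 * k] * dW              # lambda = 0:   tau_k = t_k
        strat += W[2 * k + 1] * dW        # lambda = 0.5: tau_k = midpoint
    return ito, strat, W[-1]

ito, strat, W_T = w_dw_integrals()
# limits: Ito -> (W_T**2 - T) / 2, Stratonovich -> W_T**2 / 2
```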
In BrainPy, both kinds of integrals are easy to use: pass the keyword intg_type to bp.sdeint (used either as a decorator or as a function). intg_type can be “bp.integrators.ITO_SDE” (default) or “bp.integrators.STRA_SDE”. Likewise, the type of Wiener process is selected with the wiener_type keyword, which can be “bp.integrators.SCALAR_WIENER” (default) or “bp.integrators.VECTOR_WIENER”.
Now, let’s numerically integrate the SDE (1) in the Ito sense with the Milstein method:
def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)
@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)
Or, it can be expressed as:
def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)
integral = bp.sdeint(f=f_part, g=g_part, method='milstein')
To numerically integrate an SDE driven by a multi-dimensional Wiener process in the Stratonovich sense, you can code it like this:
def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(m, d)
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)
integral = bp.sdeint(f=f_part,
                     g=g_part,
                     method='milstein',
                     intg_type=bp.integrators.STRA_SDE,
                     wiener_type=bp.integrators.VECTOR_WIENER)
Example: Noisy Lorenz system#
Here, let’s demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:
sigma, beta = 10, 8 / 3
rho, p = 28, 0.1
def lorenz_g(x, y, z, t):
    return p * x, p * y, p * z
def lorenz_f(x, y, z, t):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz
lorenz = bp.sdeint(f=lorenz_f,
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER)
To run this integrator, we use brainpy.integrators.IntegratorRunner, which JIT-compiles the model to speed up the simulation.
runner = bp.IntegratorRunner(
lorenz,
monitors=['x', 'y', 'z'],
inits=[1., 1., 1.],
dt=0.001
)
runner.run(50.)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
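For readers who want to see what scalar Wiener noise means for this system, here is a plain-Python Euler-Maruyama loop for the same lorenz_f/lorenz_g pair, in the Ito sense. With SCALAR_WIENER, a single increment \(dW\) per step drives all three components. This is an illustrative sketch, not the code BrainPy generates; em_lorenz and its defaults are hypothetical.

```python
import math
import random

sigma, beta, rho, p = 10.0, 8.0 / 3.0, 28.0, 0.1

def lorenz_f(x, y, z, t):
    return sigma * (y - x), x * (rho - z) - y, x * y - beta * z

def lorenz_g(x, y, z, t):
    return p * x, p * y, p * z

def em_lorenz(x0=(1.0, 1.0, 1.0), dt=0.001, n_steps=5000, noise=True, seed=0):
    rng = random.Random(seed)
    x, y, z = x0
    t = 0.0
    for _ in range(n_steps):
        # scalar Wiener process: ONE increment shared by x, y and z
        dW = rng.gauss(0.0, math.sqrt(dt)) if noise else 0.0
        fx, fy, fz = lorenz_f(x, y, z, t)
        gx, gy, gz = lorenz_g(x, y, z, t)
        x, y, z = (x + fx * dt + gx * dW,
                   y + fy * dt + gy * dW,
                   z + fz * dt + gz * dW)
        t += dt
    return x, y, z

noisy = em_lorenz()
```

With VECTOR_WIENER, each component would instead receive its own independent increment.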

We can also rewrite the above differential equation as a JointEq of separate equations, so that the exponential Euler method can be applied to it.
dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
lorenz_f = bp.JointEq(dx, dy, dz)
lorenz = bp.sdeint(f=lorenz_f,
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER,
method='exp_euler')
runner = bp.IntegratorRunner(
lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.], dt=0.001
)
runner.run(50.)
plt.figure()
ax = plt.axes(projection='3d')
ax.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Supported SDE Numerical Methods#
BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and the exponential Euler method for SDE numerical integration.
| Methods | Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support |
---|---|---|---|---|---|
| Strong SRK scheme: SRI1W1 | srk1w1_scalar | Yes | | Yes | |
| Strong SRK scheme: SRI2W1 | srk2w1_scalar | Yes | | Yes | |
| Strong SRK scheme: KlPl | KlPl_scalar | Yes | | Yes | |
| Euler method | euler | Yes | Yes | Yes | Yes |
| Heun method | heun | | Yes | Yes | Yes |
| Milstein | milstein | Yes | Yes | Yes | Yes |
| Derivative-free Milstein | milstein_grad_free | Yes | Yes | Yes | Yes |
| Exponential Euler | exp_euler | Yes | | Yes | Yes |
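Why the Ito and Stratonovich columns matter: the same SDE converges to different solutions depending on the scheme. The plain-Python sketch below integrates geometric Brownian motion \(dX = \mu X\,dt + s X\,dW\) along one Brownian path with Euler-Maruyama, which converges to the Ito solution \(X_0 e^{(\mu - s^2/2)T + sW_T}\), and with a stochastic Heun predictor-corrector, which converges to the Stratonovich solution \(X_0 e^{\mu T + sW_T}\). The test equation, parameters and function name are illustrative; this is not BrainPy's code.

```python
import math
import random

def gbm_ito_vs_stratonovich(mu=0.5, s=0.5, x0=1.0, T=1.0, n=20_000, seed=0):
    rng = random.Random(seed)
    dt = T / n
    f = lambda x: mu * x   # drift coefficient
    g = lambda x: s * x    # diffusion coefficient
    x_em = x_heun = x0
    W = 0.0
    for _ in range(n):
        dW = rng.gauss(0.0, math.sqrt(dt))
        W += dW
        # Euler-Maruyama (converges to the Ito solution)
        x_em = x_em + f(x_em) * dt + g(x_em) * dW
        # stochastic Heun: Euler predictor, trapezoidal corrector
        # (converges to the Stratonovich solution)
        x_pred = x_heun + f(x_heun) * dt + g(x_heun) * dW
        x_heun = (x_heun + 0.5 * (f(x_heun) + f(x_pred)) * dt
                  + 0.5 * (g(x_heun) + g(x_pred)) * dW)
    ito_exact = x0 * math.exp((mu - 0.5 * s ** 2) * T + s * W)
    strat_exact = x0 * math.exp(mu * T + s * W)
    return x_em, x_heun, ito_exact, strat_exact

x_em, x_heun, ito_exact, strat_exact = gbm_ito_vs_stratonovich()
```

The two exact solutions differ by the factor \(e^{s^2 T/2}\), which is why a method must be matched to the intended integral type.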
Numerical Solvers for Fractional Differential Equations#
import matplotlib.pyplot as plt
import brainpy as bp
import brainpy.math as bm
bp.__version__
'2.3.0'
Fractional derivatives can be defined in a variety of ways, which often do not lead to the same result, even for smooth functions. In neuroscience, we usually use the following two definitions:
Grünwald-Letnikov derivative
Caputo fractional derivative
See Fractional calculus - Wikipedia for more details.
Methods for Caputo FDEs#
For a given fractional differential equation
\[ \frac{d^{\alpha} x}{d t^{\alpha}} = F(x, t), \]
where the fractional order \(0<\alpha\le 1\), BrainPy provides two kinds of methods:
Euler method - brainpy.fde.CaputoEuler
L1 schema integration - brainpy.fde.CaputoL1Schema
brainpy.fde.CaputoEuler#
brainpy.fde.CaputoEuler provides a one-step Euler method for integrating Caputo fractional differential equations.
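The reason a Caputo solver needs the whole history shows up in a standard explicit scheme, the fractional rectangle (Euler) rule \(y_{n+1} = y_0 + \frac{h^{\alpha}}{\Gamma(\alpha+1)}\sum_{j=0}^{n}\big[(n+1-j)^{\alpha} - (n-j)^{\alpha}\big] F(y_j, t_j)\): every step sums over all previous right-hand-side values. The plain-Python sketch below implements that rule. We have not checked that brainpy.fde.CaputoEuler uses exactly these weights, and caputo_euler is a hypothetical name. For \(\alpha = 1\) all weights equal 1 and the scheme reduces to the ordinary forward Euler method.

```python
import math

def caputo_euler(f, y0, alpha, dt, n_steps):
    """Explicit fractional Euler (rectangle) rule for D^alpha y = f(y, t)."""
    scale = dt ** alpha / math.gamma(alpha + 1)
    ys, fs = [y0], []
    for n in range(n_steps):
        fs.append(f(ys[n], n * dt))
        # every step revisits the WHOLE history f(y_0), ..., f(y_n)
        acc = sum(((n + 1 - j) ** alpha - (n - j) ** alpha) * fs[j]
                  for j in range(n + 1))
        ys.append(y0 + scale * acc)
    return ys

# Hypothetical usage: fractional relaxation D^0.9 y = -y, y(0) = 1
ys = caputo_euler(lambda y, t: -y, y0=1.0, alpha=0.9, dt=0.01, n_steps=500)
```

The cost grows quadratically with the number of steps, which is the motivation for short-memory methods.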
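Given the fractional-order Qi chaotic system
\[ D_t^{\alpha} x = a(y - x) + yz, \quad D_t^{\alpha} y = cx - y - xz, \quad D_t^{\alpha} z = xy - bz, \]
we can solve the equation system by: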
a, b, c = 35, 8 / 3, 80
def qi_system(x, y, z, t):
dx = -a * x + a * y + y * z
dy = c * x - y - x * z
dz = -b * z + x * y
return dx, dy, dz
dt = 0.001
duration = 50
inits = [0.1, 0.2, 0.3]
# The numerical integration of an FDE needs the whole
# history, so we pass the total number of simulation
# steps as "num_memory" to store all history values.
integrator = bp.fde.CaputoEuler(qi_system,
alpha=0.98, # fractional order
num_memory=int(duration / dt),
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()

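brainpy.fde.CaputoL1Schema#
brainpy.fde.CaputoL1Schema is another commonly used method for integrating Caputo fractional differential equations. Let’s try it with a fractional-order Lorenz system, which is given by
\[ D_t^{\alpha} x = a(y - x), \quad D_t^{\alpha} y = x(b - z) - y, \quad D_t^{\alpha} z = xy - cz. \]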
a, b, c = 10, 28, 8 / 3
def lorenz_system(x, y, z, t):
dx = a * (y - x)
dy = x * (b - z) - y
dz = x * y - c * z
return dx, dy, dz
dt = 0.001
duration = 50
inits = [1, 2, 3]
integrator = bp.fde.CaputoL1Schema(lorenz_system,
alpha=0.99, # fractional order
num_memory=int(duration / dt),
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
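For comparison with the Euler rule, the L1 scheme approximates the Caputo derivative at \(t_n\) by \(\frac{h^{-\alpha}}{\Gamma(2-\alpha)}\sum_{k=0}^{n-1} b_k (y_{n-k} - y_{n-k-1})\) with \(b_k = (k+1)^{1-\alpha} - k^{1-\alpha}\). Evaluating the right-hand side at the previous step gives the explicit update in the plain-Python sketch below. It is a textbook form written for illustration; we have not verified that brainpy.fde.CaputoL1Schema uses the same explicit variant, and the function names are hypothetical.

```python
import math

def l1_coefficients(alpha, n):
    # b_k = (k + 1)^(1 - alpha) - k^(1 - alpha), k = 0 .. n-1 (b_0 = 1 for alpha < 1)
    return [(k + 1) ** (1 - alpha) - k ** (1 - alpha) for k in range(n)]

def caputo_l1(f, y0, alpha, dt, n_steps):
    """Explicit L1 scheme for the Caputo FDE D^alpha y = f(y, t)."""
    b = l1_coefficients(alpha, n_steps)
    scale = math.gamma(2 - alpha) * dt ** alpha
    ys = [y0]
    for n in range(1, n_steps + 1):
        # weighted sum of all earlier increments (the "memory" term)
        memory = sum(b[k] * (ys[n - k] - ys[n - k - 1]) for k in range(1, n))
        ys.append(ys[n - 1] + scale * f(ys[n - 1], (n - 1) * dt) - memory)
    return ys

# Hypothetical usage: D^0.8 y = -y, y(0) = 1
ys = caputo_l1(lambda y, t: -y, y0=1.0, alpha=0.8, dt=0.01, n_steps=300)
```

For \(\alpha = 1\) every \(b_k\) with \(k \ge 1\) vanishes, so the update becomes forward Euler.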

Methods for Grünwald-Letnikov FDEs#
Grünwald-Letnikov FDEs are another commonly used type in neuroscience. Here, we provide an efficient computation method based on the short-memory principle of the Grünwald-Letnikov method.
brainpy.fde.GLShortMemory#
brainpy.fde.GLShortMemory is highly efficient because it does not require an infinite memory length for the numerical solution. Because the coefficients decay, brainpy.fde.GLShortMemory uses a limited memory length to reduce the computational time. Specifically, it only relies on a memory window of num_memory steps. The wider the memory window, the more accurate the numerical approximation.
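The short-memory idea can be sketched with a common explicit Grünwald-Letnikov scheme, \(y_k = F(y_{k-1}, t_{k-1})\,h^{\alpha} - \sum_{j=1}^{\min(k, L)} c_j\, y_{k-j}\), where \(c_j = (-1)^j\binom{\alpha}{j}\) and \(L\) plays the role of num_memory. The plain-Python version below is for illustration only; we assume, without having checked the source, that GLShortMemory works along similar lines, and the function names are hypothetical. For \(\alpha = 1\), \(c_1 = -1\) and all later coefficients vanish, so the scheme reduces to forward Euler.

```python
def gl_coefficients(alpha, n):
    # c_j = (-1)^j * binom(alpha, j), via the recursion
    # c_0 = 1, c_j = (1 - (1 + alpha) / j) * c_{j-1}
    c = [1.0]
    for j in range(1, n + 1):
        c.append((1 - (1 + alpha) / j) * c[-1])
    return c

def gl_short_memory(f, y0, alpha, dt, n_steps, num_memory):
    c = gl_coefficients(alpha, num_memory)
    h_alpha = dt ** alpha
    ys = [y0]
    for k in range(1, n_steps + 1):
        # only the last num_memory values enter the history sum
        window = min(k, num_memory)
        memory = sum(c[j] * ys[k - j] for j in range(1, window + 1))
        ys.append(f(ys[k - 1], (k - 1) * dt) * h_alpha - memory)
    return ys

# Hypothetical usage: D^0.9 y = -y with a 200-step memory window
ys = gl_short_memory(lambda y, t: -y, 1.0, 0.9, 0.01, 1000, num_memory=200)
```

Truncating the sum changes the cost per step from \(O(k)\) to \(O(L)\); the price is the neglected tail of small coefficients.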
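Here, we demonstrate it with a fractional-order Chua system, defined as
\[ D_t^{\alpha_1} x = a\big(y - x - f(x)\big), \quad D_t^{\alpha_2} y = x - y + z, \quad D_t^{\alpha_3} z = -by - cz, \]
with \(f(x) = m_1 x + \frac{1}{2}(m_0 - m_1)\big(|x+1| - |x-1|\big)\) and \((\alpha_1, \alpha_2, \alpha_3) = (0.93, 0.99, 0.92)\) in the code below.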
a, b, c = 10.725, 10.593, 0.268
m0, m1 = -1.1726, -0.7872
def chua_system(x, y, z, t):
    f = m1 * x + 0.5 * (m0 - m1) * (abs(x + 1) - abs(x - 1))
    dx = a * (y - x - f)
    dy = x - y + z
    dz = -b * y - c * z
    return dx, dy, dz
dt = 0.001
duration = 200
inits = [0.2, -0.1, 0.1]
integrator = bp.fde.GLShortMemory(chua_system,
alpha=[0.93, 0.99, 0.92],
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.y, runner.mon.z)
plt.show()

The coefficients used in brainpy.fde.GLShortMemory can be inspected through:
plt.figure(figsize=(10, 6))
coef = integrator.binomial_coef
alphas = bm.as_numpy(integrator.alpha)
plt.subplot(211)
for i in range(3):
    plt.plot(coef[:, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.subplot(212)
for i in range(3):
    plt.plot(coef[:10, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.show()

As you can see, the coefficients decay very quickly!
Further reading#
For more examples of how to use the numerical solvers for fractional differential equations in BrainPy, please see:
Numerical Solvers for Delay Differential Equations#
Delays are very common in real-world systems, such as automatic control, biology, economics, and long transmission lines. Delay differential equations (DDEs) are used to describe these dynamical systems.
In a delay differential equation, the derivative at a certain time is given in terms of the values of the function at previous times.
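Let’s take delay ODEs as an example. The simplest constant-delay equations have the form
\[ y'(t) = f\big(t, y(t), y(t-\tau_1), y(t-\tau_2), \ldots, y(t-\tau_k)\big), \]
where the time delays (lags) \(\tau_j\) are positive constants.
For neutral-type DDEs, delays also appear in the derivative terms:
\[ y'(t) = f\big(t, y(t), y(t-\tau_1), \ldots, y(t-\tau_k), y'(t-\sigma_1), \ldots, y'(t-\sigma_k)\big). \]
More generally, state-dependent delays may depend on the solution, that is, \(\tau_i = \tau_i (t,y(t))\).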
In BrainPy, we support delay differential equations based on delay variables. Specifically, for state-dependent delays, we have:
brainpy.math.TimeDelay
brainpy.math.LengthDelay
For neutral-type delays, we use:
brainpy.math.NeuTimeDelay
brainpy.math.NeuLenDelay
import matplotlib.pyplot as plt
import brainpy as bp
import brainpy.math as bm
bm.enable_x64()
bp.__version__
'2.3.0'
Delay variables#
For an ODE system, a numerical method needs the initial condition \(y(t_0)=y_0\) and the rule for the derivative \(y'(t)\). For DDEs, however, it is not enough to give initial values of the function and its derivatives at \(t_0\); one must also give history functions that provide the values for \(t_0 - \max(\tau) \leq t \leq t_0\).
Therefore, delay variables are needed to store these histories. brainpy.math.TimeDelay defines delay variables that depend on the states, and brainpy.math.NeuTimeDelay defines delay variables that depend on the derivatives.
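To make the idea concrete, a delay variable on a fixed time grid can be pictured as a ring buffer of past values plus a history function for times before \(t_0\). The class below is a hypothetical plain-Python stand-in written for illustration, not brainpy.math.TimeDelay: it looks up the nearest grid point (in the spirit of interp_method='round') and, like the TimeDelay example that follows, rejects requests later than the current time.

```python
from collections import deque

class RingBufferDelay:
    """Hypothetical fixed-step delay variable (illustration only)."""

    def __init__(self, init, delay_len, dt, t0=0.0, before_t0=0.0):
        self.dt = dt
        self.t0 = t0
        self.t = t0                   # current time
        self.delay_len = delay_len
        self.before_t0 = before_t0    # constant or function of t, used for t < t0
        n_slots = int(round(delay_len / dt)) + 1
        self.buffer = deque([init], maxlen=n_slots)  # buffer[-1] is the value at self.t

    def update(self, value):
        # store the value at the next grid time; the oldest value falls out
        self.buffer.append(value)
        self.t += self.dt

    def __call__(self, t):
        if t > self.t + 1e-9:
            raise ValueError(f'The request time should be less than the current '
                             f'time {self.t}. But we got {t} > {self.t}.')
        if self.t - t > self.delay_len + 1e-9:
            raise ValueError(f'Time {t} is older than the stored delay window.')
        if t < self.t0:
            h = self.before_t0
            return h(t) if callable(h) else h
        steps_back = int(round((self.t - t) / self.dt))
        return self.buffer[-1 - steps_back]

d = RingBufferDelay(0.0, delay_len=10, dt=1.0, t0=0.0, before_t0=lambda t: t)
```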
d = bm.TimeDelay(bm.zeros(2), delay_len=10, dt=1, t0=0, before_t0=lambda t: t)
d(0.)
Array([0., 0.], dtype=float64)
d(-0.5)
Array([-0.5, -0.5], dtype=float64)
Requesting a time outside \([t_0 - max\_delay, t_0]\), for example a time later than the current time \(t_0\), raises an error.
try:
    d(0.1)
except Exception as e:
    print(e)
INTERNAL: Generated function failed: CpuCallback error: ValueError: The request time should be less than the current time 0.0. But we got 0.1 > 0.0
Delay ODEs#
Here we illustrate how to numerically integrate delay ODEs with several examples. Before that, we define a general function to simulate a delay ODE.
def delay_odeint(duration, eq, args=None, inits=None,
                 state_delays=None, neutral_delays=None,
                 monitors=('x',), method='euler', dt=0.1):
    # define integrators of ODEs based on `brainpy.odeint`
    dde = bp.odeint(eq,
                    state_delays=state_delays,
                    neutral_delays=neutral_delays,
                    method=method)
    # define IntegratorRunner
    runner = bp.IntegratorRunner(dde,
                                 args=args,
                                 monitors=monitors,
                                 dt=dt,
                                 inits=inits)
    runner.run(duration)
    return runner.mon
Example #1: First-order DDE with one constant delay and a constant initial history function#
Let the following DDE be given:
\[ y'(t) = -y(t-1), \]
where the delay is 1 s. The example compares the solutions of three cases that use three different constant history functions:
Case #1: \(\phi(t)=-1\)
Case #2: \(\phi(t)=0\)
Case #3: \(\phi(t)=1\)
def equation(x, t, xdelay):
    return -xdelay(t - 1)
case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round')
case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=0., interp_method='round')
case3_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=1., interp_method='round')
case1 = delay_odeint(20., equation, args={'xdelay': case1_delay},
state_delays={'x': case1_delay}) # delay for variable "x"
case2 = delay_odeint(20., equation, args={'xdelay': case2_delay}, state_delays={'x': case2_delay})
case3 = delay_odeint(20., equation, args={'xdelay': case3_delay}, state_delays={'x': case3_delay})
fig, axs = plt.subplots(3, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1)$")
axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')
axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=0$')
axs[2].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[2].set_title('$ihf(t)=1$')
plt.show()
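The logic of delay_odeint for this example can be mimicked in plain Python. In the sketch below the delay of 1 s is a whole number of steps, so the delayed value falls exactly on a stored grid point whenever it lies after \(t=0\); otherwise it is read from the history function. On \([0, 1]\) the delayed term only sees the constant history, so the numerical solution is a straight line there. The function name and the initial value \(y(0)=1\) are our own illustrative choices; dt = 0.1 matches the default of delay_odeint.

```python
def euler_dde(history, y0, tau, dt, t_end):
    """Forward Euler for y'(t) = -y(t - tau) with y(t) = history(t) for t < 0."""
    n_tau = int(round(tau / dt))        # the delay expressed in steps
    n_steps = int(round(t_end / dt))
    ys = [y0]
    for n in range(n_steps):
        k = n - n_tau                   # grid index of t_n - tau
        y_delayed = ys[k] if k >= 0 else history(k * dt)
        ys.append(ys[n] - dt * y_delayed)
    return ys

# Case #1: constant history phi(t) = -1; y(0) = 1 is chosen for illustration
ys = euler_dde(lambda t: -1.0, y0=1.0, tau=1.0, dt=0.1, t_end=20.0)
```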

Example #2: First-order DDE with one constant delay and a non-constant initial history function#
Let the following DDE be given:
\[ y'(t) = -y(t-2), \]
where the delay is 2 s. The example compares the solutions of four cases that use two different non-constant history functions and two different intervals of \(t\):
Case #1: \(\phi(t)=e^{-t} - 1, t \in [0, 4]\)
Case #2: \(\phi(t)=e^{t} - 1, t \in [0, 4]\)
Case #3: \(\phi(t)=e^{-t} - 1, t \in [0, 60]\)
Case #4: \(\phi(t)=e^{t} - 1, t \in [0, 60]\)
def eq(x, t, xdelay):
    return -xdelay(t - 2)
delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t) - 1, dt=0.01, interp_method='round')
delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t) - 1, dt=0.01, interp_method='round')
delay3 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t) - 1, dt=0.01, interp_method='round')
delay4 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t) - 1, dt=0.01, interp_method='round')
case1 = delay_odeint(4., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(4., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
case3 = delay_odeint(60., eq, args={'xdelay': delay3}, state_delays={'x': delay3}, dt=0.01)
case4 = delay_odeint(60., eq, args={'xdelay': delay4}, state_delays={'x': delay4}, dt=0.01)
fig, axs = plt.subplots(2, 2)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-2)$")
axs[0, 0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0, 0].set_title(r'$ihf(t)=e^{-t} - 1, t \in [0, 4]$')
axs[0, 1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[0, 1].set_title(r'$ihf(t)=e^t - 1, t \in [0, 4]$')
axs[1, 0].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[1, 0].set_title(r'$ihf(t)=e^{-t} - 1, t \in [0, 60]$')
axs[1, 1].plot(case4.ts, case4.x, color='red', linewidth=1)
axs[1, 1].set_title(r'$ihf(t)=e^t - 1, t \in [0, 60]$')
plt.show()

Example #3: First-order DDE with two constant delays and a constant initial history function#
Let the following DDE be given:
\[ y'(t) = -y(t-1) + 0.3\,y(t-2), \]
where there are two constant delays, 1 s and 2 s. The initial history function is also constant: \(\phi(t)=1\).
def eq(x, t):
    return -delay(t - 1) + 0.3 * delay(t - 2)
delay = bm.TimeDelay(bm.ones(1), 2., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(10., eq, inits=[1.], state_delays={'x': delay}, dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle(r"$y'(t)=-y(t-1) + 0.3\ y(t-2)$")
axs.plot(mon.ts, mon.x, color='red', linewidth=1)
axs.set_title('$ihf(t)=1$')
plt.show()

Example #4: System of two first-order DDEs with one constant delay and two constant initial history functions#
Let the following system of DDEs be given:
\[ x'(t) = x(t)\,y(t-0.5), \qquad y'(t) = y(t)\,x(t-0.5), \]
where there is a single constant delay of 0.5 s. The initial history functions are also constant. Since this is a first-order system with two unknowns, one history function is needed for each unknown: \(x(t)=1\) and \(y(t)=-1\) for \(t \le 0\).
def eq(x, y, t):
    dx = x * ydelay(t - 0.5)
    dy = y * xdelay(t - 0.5)
    return dx, dy
xdelay = bm.TimeDelay(bm.ones(1), 0.5, before_t0=1., dt=0.01, interp_method='round')
ydelay = bm.TimeDelay(-bm.ones(1), 0.5, before_t0=-1., dt=0.01, interp_method='round')
mon = delay_odeint(3., eq, inits=[1., -1], state_delays={'x': xdelay, 'y': ydelay},
dt=0.01, monitors=['x', 'y'])
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$x'(t)=x(t) y(t-d); y'(t)=y(t) x(t-d)$")
axs.plot(mon.ts, mon.x.flatten(), color='red', linewidth=1)
axs.plot(mon.ts, mon.y.flatten(), color='blue', linewidth=1)
axs.set_title('$ihf_x(t)=1; ihf_y(t)=-1; d=0.5$')
plt.show()

Example #5: Second-order DDE with one constant delay and two constant initial history functions#
Let the following DDE be given:
\[ y''(t) = -y'(t) - 2\,y(t) - 0.5\,y(t-1), \]
where there is a single constant delay of 1 s. Since the DDE is second order (the second derivative of the unknown function appears), two history functions are needed: one giving the values of the unknown \(y(t)\) for \(t \le 0\), and one giving the values of its first derivative \(y'(t)\) for \(t \le 0\).
In this example they are the following two constant functions: \(y(t)=1, y'(t)=0\).
With \(y_1 = y\) and \(y_2 = y'\), the given DDE is equivalent to the following system of first-order equations:
\[ y_1'(t) = y_2(t), \qquad y_2'(t) = -y_2(t) - 2\,y_1(t) - 0.5\,y_1(t-1), \]
so the implementation falls into the case of the previous example of a system of first-order equations (in the code, x plays the role of \(y_1\) and y that of \(y_2\)).
def eq(x, y, t):
    dx = y
    dy = -y - 2 * x - 0.5 * xdelay(t - 1)
    return dx, dy
xdelay = bm.TimeDelay(bm.ones(1), 1., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(16., eq, inits=[1., 0.], state_delays={'x': xdelay}, monitors=['x', 'y'], dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y''(t)=-y'(t) - 2 y(t) - 0.5 y(t-1)$")
axs.plot(mon.ts, mon.x[:, 0], color='red', linewidth=1)
axs.plot(mon.ts, mon.y[:, 0], color='green', linewidth=1)
axs.set_title(r'$ih \, f_y(t)=1; ihf\,dy/dt(t)=0$')
plt.show()

Example #6: First-order DDE with one non-constant delay and a constant initial history function#
Let the following DDE be given:
\[ y'(t) = y\big(t - \mathrm{delay}(y, t)\big), \]
where the delay is not constant but is given by the function \(\mathrm{delay}(y, t)=|\frac{1}{10} t\, y(\frac{1}{10} t)|\). The example compares the solutions of two cases that use two different constant history functions:
Case #1: \(\phi(t)=-1\)
Case #2: \(\phi(t)=1\)
def eq(x, t, xdelay):
    delay = abs(t * xdelay(t - 0.9 * t) / 10)  # a tensor with (1,)
    delay = delay[0]
    return xdelay(t - delay)
Note
Here we do not know the maximum length of the delay in advance. Therefore, we declare a fixed-length delay variable whose delay_len is equal to or even larger than the running duration.
delay1 = bm.TimeDelay(bm.ones(1), 30., before_t0=-1, dt=0.01)
delay2 = bm.TimeDelay(-bm.ones(1), 30., before_t0=1, dt=0.01)
case1 = delay_odeint(30., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(30., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
fig, axs = plt.subplots(2, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=y(t-delay(y, t))$")
axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')
axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=1$')
plt.show()

Delay SDEs#
As with delay ODEs, state-dependent delay variables can be passed to brainpy.sdeint through its state_delays argument.
delay = bm.TimeDelay(bm.zeros(1),
2.,
before_t0=lambda t: bm.exp(-t) - 1,
dt=0.01,
interp_method='round')
f = lambda x, t: -delay(t - 2)
g = lambda x, t, *args: 0.01
dt = 0.01
integral = bp.sdeint(f, g, state_delays={'x': delay})
runner = bp.IntegratorRunner(integral,
monitors=['x'],
dt=dt)
runner.run(100.)
plt.plot(runner.mon.ts, runner.mon.x)
plt.show()

Delay FDEs#
Fractional-order delay differential equations generalize delay differential equations and give more freedom when describing such systems. Let’s see how BrainPy can accelerate the simulation of fractional-order delay differential equations.
A fractional delay differential equation has the general form:
\[ D_t^{\alpha} y(t) = f\big(t, y(t), y(t-\tau)\big), \quad t \ge t_0; \qquad y(t) = \phi(t), \quad t \in [t_0 - \tau, t_0]. \]
Lemmings’ population cycle#
The fractional order version of the four-year life cycle of a population of lemmings is given by
\[ D_{t}^{\alpha} y(t)=3.5\, y(t)\left(1-\frac{y(t-0.74)}{19}\right), \quad y(0)=19.00001, \quad y(t)=19 \ \text{for}\ t<0. \]
dt = 0.05
delay = bm.TimeDelay(bm.asarray([19.00001]), 0.74, before_t0=19., dt=dt)
f = lambda y, t: 3.5 * y * (1 - delay(t - 0.74) / 19)
integral = bp.fde.GLShortMemory(f,
alpha=0.97,
inits=[19.00001],
num_memory=500,
state_delays={'y': delay})
runner = bp.IntegratorRunner(integral,
inits=bm.asarray([19.00001]),
monitors=['y'],
fun_monitors={'y(t-0.74)': lambda t, _: delay(t - 0.74)},
dyn_vars=delay.vars(),
dt=dt)
runner.run(100.)
plt.plot(runner.mon['y'], runner.mon['y(t-0.74)'])
plt.xlabel('y(t)')
plt.ylabel('y(t-0.74)')
plt.show()

Time delay Chen system#
The time-delay Chen system, a well-known chaotic system with time delay, has important applications in many fields.
dt = 0.001
tau = 0.009
xdelay = bm.TimeDelay(bm.asarray([0.2]), tau, dt=dt)
zdelay = bm.TimeDelay(bm.asarray([0.5]), tau, dt=dt)
def derivative(x, y, z, t):
    a, b, c = 35, 3, 27
    dx = a * (y - xdelay(t - tau))
    dy = (c - a) * xdelay(t - tau) - x * z + c * y
    dz = x * y - b * zdelay(t - tau)
    return dx, dy, dz
integral = bp.fde.GLShortMemory(derivative,
alpha=0.94,
inits=[0.2, 0., 0.5],
num_memory=500,
state_delays={'x': xdelay, 'z': zdelay})
runner = bp.IntegratorRunner(integral,
inits=[0.2, 0., 0.5],
monitors=['x', 'y', 'z'],
dyn_vars=xdelay.vars() + zdelay.vars(),
dt=dt)
runner.run(100.)
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
plt.show()

Enzyme kinetics#
Let’s see a more complex example of the fractional order version of enzyme kinetics with an inhibitor molecule:
dt = 0.01
tau = 4.
delay = bm.TimeDelay(bm.asarray([20.]), tau, before_t0=20, dt=dt)
def derivative(a, b, c, d, t):
    da = 10.5 - a / (1 + 0.0005 * delay(t - tau) ** 3)
    db = a / (1 + 0.0005 * delay(t - tau) ** 3) - b
    dc = b - c
    dd = c - 0.5 * d
    return da, db, dc, dd
integral = bp.fde.GLShortMemory(derivative,
alpha=0.95,
inits=[60, 10, 10, 20],
num_memory=500,
state_delays={'d': delay})
runner = bp.IntegratorRunner(integral,
inits=[60, 10, 10, 20],
monitors=list('abcd'),
dyn_vars=delay.vars(),
dt=dt)
runner.run(200.)
plt.plot(runner.mon.ts, runner.mon.a, label='a')
plt.plot(runner.mon.ts, runner.mon.b, label='b')
plt.plot(runner.mon.ts, runner.mon.c, label='c')
plt.plot(runner.mon.ts, runner.mon.d, label='d')
plt.legend()
plt.xlabel('Time [ms]')
plt.show()

Fractional matrix delayed differential equations#
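BrainPy is also capable of solving fractional matrix delay differential equations:
\[ D_t^{\alpha} x(t) = A\,x(t) + B\,x(t-\tau) + c(t), \]
where \(x(t)\) is the vector of system states and \(c(t)\) is a known disturbance function.
We explain the detailed usage with an example:
\[ D_t^{0.4} x(t) = \begin{pmatrix} 0&0&1&0\\ 0&0&0&1\\ 0&-2&0&0\\ -2&0&0&0 \end{pmatrix} x(t) + \begin{pmatrix} 0&0&0&0\\ 0&0&0&0\\ -2&0&0&0\\ 0&-2&0&0 \end{pmatrix} x(t-3.1416), \]
with initial condition:
\[ x(t) = \big(\sin t\cos t,\ \sin t\cos t,\ \cos^2 t-\sin^2 t,\ \cos^2 t-\sin^2 t\big)^{T}, \quad t \le 0. \]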
dt = 0.01
tau = 3.1416
f = lambda t: bm.asarray([bm.sin(t) * bm.cos(t),
bm.sin(t) * bm.cos(t),
bm.cos(t) ** 2 - bm.sin(t) ** 2,
bm.cos(t) ** 2 - bm.sin(t) ** 2])
delay = bm.TimeDelay(f(0.), tau, before_t0=f, dt=dt)
A = bm.asarray([[0, 0, 1, 0], [0, 0, 0, 1], [0, -2, 0, 0], [-2, 0, 0, 0]])
B = bm.asarray([[0, 0, 0, 0], [0, 0, 0, 0], [-2, 0, 0, 0], [0, -2, 0, 0]])
c = bm.asarray([0, 0, 0, 0])
derivative = lambda x, t: A @ x + B @ delay(t - tau) + c
integral = bp.fde.GLShortMemory(derivative,
alpha=0.4,
inits=[f(0.)],
num_memory=500,
state_delays={'x': delay})
runner = bp.IntegratorRunner(integral,
inits=[f(0.)],
monitors=['x'],
dyn_vars=delay.vars(),
dt=dt)
runner.run(100.)
plt.plot(runner.mon.x[:, 0], runner.mon.x[:, 2])
plt.xlabel('x0')
plt.ylabel('x2')
plt.show()

Acknowledgement#
This tutorial is heavily inspired by the work of Ettore Messina [1] and Qingyu Qu [2].
[1] Ettore Messina, Computational Mindset: Solving delay differential equations using numerical methods in Python.
Joint Differential Equations#
In a dynamical system, there may be multiple variables that change dynamically over time. Sometimes these variables are interconnected, and updating one variable requires others as input. For example, in the widely known Hodgkin–Huxley model, the variables \(V\), \(m\), \(h\), and \(n\) are updated synchronously and interdependently (please refer to Building Neuron Models for details). To achieve higher integration accuracy, it is recommended to use brainpy.JointEq to solve interconnected differential equations jointly.
import brainpy as bp
brainpy.JointEq#
brainpy.JointEq is used to merge individual but interconnected differential equations into a single joint equation. For example, below are the two differential equations of the Izhikevich model:
a, b = 0.02, 0.20
dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
du = lambda u, t, V: a * (b * V - u)
Here, updating \(V\) requires \(u\) as an input, and updating \(u\) requires \(V\) as an input. The joint equation can be defined as:
joint_eq = bp.JointEq(dV, du)
brainpy.JointEq receives only one argument named eqs, which can be a list or tuple containing multiple differential equations. The joint equation can then be packed into a numerical integrator that solves it with a specified method, just like any individual differential equation.
itg = bp.odeint(joint_eq, method='rk2')
There are several requirements for defining a joint equation:
Every individual differential equation should follow the BrainPy format for defining an ODE or SDE function. For example, the arguments before t denote the dynamical variables and the arguments after t denote the parameters.
The same variable in different equations should have the same name. Different variables should have different names.
Note that brainpy.JointEq supports nesting: an instance of JointEq can itself be used as an element to compose a new JointEq.
Why use brainpy.JointEq?#
Users may wonder why brainpy.JointEq is needed, since multiple differential equations can also be written in a single function:
def diff(V, u, t, Iext):
  dV = 0.04 * V * V + 5 * V + 140 - u + Iext
  du = a * (b * V - u)
  return dV, du
itg_V_u = bp.odeint(diff, method='rk2')
or packed into separate integrators:
int_V = bp.odeint(dV, method='rk2')
int_u = bp.odeint(du, method='rk2')
To illustrate the difference between joint and separate differential equations, let's look at the integration code generated for these two cases.
If we create a numerical solver for each derivative function, the equations will be solved independently:
import brainpy as bp
bp.odeint(dV, method='rk2', show_code=True)
def brainpy_itg_of_ode6(V, t, u, Iext, dt=0.1):
  dV_k1 = f(V, t, u, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  return V_new
{'f': <function <lambda> at 0x0000022948DD6A60>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x229660543a0>
As the generated code shows, the RK2 method evaluates the derivative of \(V\) twice. For the second derivative value dV_k2, the intermediate value of \(V\) (k2_V_arg) is combined with the original value of \(u\). This introduces a small error, because the values of \(V\) and \(u\) are taken at different time points.
To eliminate this error, the differential equations of \(V\) and \(u\) should be solved jointly through brainpy.JointEq:
eq = bp.JointEq(eqs=(dV, du))
bp.odeint(eq, method='rk2', show_code=True)
def brainpy_itg_of_ode12_joint_eq(V, u, t, Iext, dt=0.1):
  dV_k1, du_k1 = f(V, u, t, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_u_arg = u + dt * du_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
  return V_new, u_new
{'f': <brainpy.integrators.joint_eq.JointEq object at 0x0000022967EC0C40>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x22967ec0160>
This output code shows that the second derivative values of \(V\) and \(u\) are calculated using the intermediate values (k2_V_arg and k2_u_arg) taken at the same time point. This results in a more accurate integration.
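To make the difference concrete, here is a minimal NumPy-free sketch (not BrainPy's generated code) that performs one RK2 step of the Izhikevich equations above, once with \(V\) and \(u\) integrated separately and once jointly. The helper names rk2_separate and rk2_joint are hypothetical; the RK2 coefficients (2/3 for the intermediate point, 1/4 and 3/4 for the final combination) mirror the generated code shown above.

```python
a, b = 0.02, 0.20

def dV(V, t, u, Iext):
    return 0.04 * V * V + 5 * V + 140 - u + Iext

def du(u, t, V):
    return a * (b * V - u)

def rk2_separate(V, u, t, Iext, dt):
    # Each variable gets its own RK2 step; the other variable stays frozen
    # at its value from the start of the step.
    k1 = dV(V, t, u, Iext)
    k2 = dV(V + dt * k1 * 2 / 3, t + dt * 2 / 3, u, Iext)
    l1 = du(u, t, V)
    l2 = du(u + dt * l1 * 2 / 3, t + dt * 2 / 3, V)
    return V + dt * (0.25 * k1 + 0.75 * k2), u + dt * (0.25 * l1 + 0.75 * l2)

def rk2_joint(V, u, t, Iext, dt):
    # Both variables are advanced to the same intermediate time point
    # before the second derivative evaluation.
    k1, l1 = dV(V, t, u, Iext), du(u, t, V)
    V_mid, u_mid = V + dt * k1 * 2 / 3, u + dt * l1 * 2 / 3
    t_mid = t + dt * 2 / 3
    k2, l2 = dV(V_mid, t_mid, u_mid, Iext), du(u_mid, t_mid, V_mid)
    return V + dt * (0.25 * k1 + 0.75 * k2), u + dt * (0.25 * l1 + 0.75 * l2)

V0, u0 = -65.0, b * -65.0
print("separate:", rk2_separate(V0, u0, 0.0, 10.0, 0.2))
print("joint:   ", rk2_joint(V0, u0, 0.0, 10.0, 0.2))
```

Comparing both results against many tiny joint steps shows that the separate update of \(u\) lags behind, because it never sees the change in \(V\) within the step.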
The figure below compares the simulation results of the Izhikevich model using joint and separate differential equations (\(dt = 0.2\) ms). As the simulation time increases, the error caused by separate integration accumulates and becomes larger.

Synaptic Connections#
Synaptic connections are an essential part of building a neural dynamic system. BrainPy provides several commonly used connection methods in the brainpy.connect module (which can be accessed through the shortcut bp.conn) that help users easily construct many types of synaptic connections, including built-in and self-customized connectors.
An Overview of BrainPy Connectors#
Here we provide an overview of BrainPy connectors.
Base class: bp.conn.Connector#
The base class of connectors is brainpy.connect.Connector. All connectors, built-in or customized, should inherit from the Connector class.
Two subclasses: TwoEndConnector and OneEndConnector#
There are two classes inheriting from the base class bp.conn.Connector:
bp.conn.TwoEndConnector: a connector to build synaptic connections between two neuron groups.
bp.conn.OneEndConnector: a connector to build synaptic connections within a population of neurons.
Please refer to the API documentation of each class for details.
Connector.__init__()#
All connectors need to be initialized first. For each built-in connector, users need to pass in the corresponding parameters for initialization. For details, please see the specific connector types below.
Connector.__call__()#
After initialization, users should call the connector and pass in parameters that depend on the connection type:
TwoEndConnector: It has two input parameters, pre_size and post_size, representing the sizes of the pre- and post-synaptic neuron groups. It results in a connection matrix with the shape of (pre_num, post_num).
OneEndConnector: It has only one parameter, pre_size, which represents the size of the neuron group. It results in a connection matrix with the shape of (pre_num, pre_num).
The __call__ function returns the connector instance itself.
Connector.build_conn()#
Users can customize the connection in the build_conn() function. Note that there are three connection types users can provide:
Connection Types | Definition |
---|---|
'mat' | Dense connection, given as a connection matrix. |
'ij' | Index projection, given as a pre-neuron index vector and a post-neuron index vector. |
'csr' | Sparse connection, given as an index vector and an indptr vector. |
The return type can be either a dict or a tuple. Here are two examples of how to return your connection data:
Example 1:
def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)
  return dict(csr=(ind, indptr), mat=None, ij=None)
Example 2:
def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)
  return 'csr', (ind, indptr)
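To see what the two examples encode, the following NumPy sketch (an illustration, not part of BrainPy's API) expands a CSR pair (ind, indptr) into a dense boolean matrix. The function name csr_to_dense and the choice of four neurons are assumptions for illustration.

```python
import numpy as np

def csr_to_dense(ind, indptr, post_num):
    """Expand a CSR connection (ind, indptr) into a dense boolean matrix."""
    pre_num = len(indptr) - 1
    mat = np.zeros((pre_num, post_num), dtype=bool)
    for i in range(pre_num):
        # Post-synaptic targets of pre-synaptic neuron i.
        mat[i, ind[indptr[i]:indptr[i + 1]]] = True
    return mat

pre_num = 4
ind = np.arange(pre_num)          # same vectors as in the examples above
indptr = np.arange(pre_num + 1)
print(csr_to_dense(ind, indptr, post_num=pre_num))
```

Each pre-synaptic neuron i owns exactly one entry, ind[i] = i, so these examples describe a one-to-one (diagonal) connection.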
After creating the synaptic connection, users can use the require() method to access some useful properties of the connection.
Connector.require()#
This method returns the connection properties required by users. The connection properties are elaborated in the following sections in detail. Here is a brief summary of the connection properties users can require.
Connection properties | Structure | Definition |
---|---|---|
conn_mat | 2-D array (matrix) | Dense connection matrix |
pre_ids | 1-D array (vector) | Pre-synaptic neuron index of each synapse |
post_ids | 1-D array (vector) | Post-synaptic neuron index of each synapse |
pre2post | tuple (vector, vector) | The post-synaptic neuron indices and the corresponding pre-synaptic neuron pointers |
post2pre | tuple (vector, vector) | The pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers |
pre2syn | tuple (vector, vector) | The synapse indices sorted by pre-synaptic neurons and the corresponding pre-synaptic neuron pointers |
post2syn | tuple (vector, vector) | The synapse indices sorted by post-synaptic neurons and the corresponding post-synaptic neuron pointers |
Users can call this method as follows:
pre_ids, post_ids, pre2post, conn_mat = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
Note
This method can return multiple connection properties at once.
Connection Properties#
There are multiple connection properties that can be required by users.
1. conn_mat#
The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through connector.requires('conn_mat'). Each connection matrix is an array with the shape of \((n_{pre}, n_{post})\):

2. pre_ids and post_ids#
Using vectors to store the connection between neuron groups is a much more memory-efficient approach when the connection matrix is sparse. For the connection matrix conn_mat defined above, we can align the connected pre-synaptic neurons and post-synaptic neurons in two one-dimensional arrays: pre_ids and post_ids.

In this way, we only need two vectors (pre_ids and post_ids) to store the synaptic connection. syn_id in the figure indicates the index of each neuron pair, i.e. each synapse.
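A minimal NumPy sketch of this conversion, assuming a small hypothetical boolean matrix conn_mat: np.nonzero returns the row (pre) and column (post) indices of the nonzero entries in row-major order, so position k in both vectors describes synapse k.

```python
import numpy as np

# Hypothetical 3 x 4 connection matrix (rows: pre, columns: post).
conn_mat = np.array([[0, 1, 1, 0],
                     [1, 0, 0, 0],
                     [0, 0, 1, 1]], dtype=bool)

pre_ids, post_ids = np.nonzero(conn_mat)
for syn_id, (i, j) in enumerate(zip(pre_ids, post_ids)):
    print(f"syn_id={syn_id}: pre {i} -> post {j}")
```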
3. pre2post and post2pre#
Another two synaptic structures are pre2post and post2pre. They establish the mapping between the pre- and post-synaptic neurons.
pre2post is a tuple containing two vectors: the post-synaptic neuron indices and the corresponding pre-synaptic neuron pointers. For example, the following figure shows the indices of the pre-synaptic neurons and the post-synaptic neurons to which the pre-synaptic neurons project:

To record the connection, the post_ids are first concatenated into a single vector called the post-synaptic index vector (indices). Because the post-synaptic neuron indices are sorted by pre-synaptic neuron index, it is sufficient to record only the starting position of each pre-synaptic neuron's block of entries. These starting positions, together with the end position of the last block, make up the pre-synaptic index pointer vector (indptr), which is illustrated in the figure below.

The post-synaptic neuron indices to which pre-synaptic neuron \(i\) projects can be obtained by array slicing:
indices[indptr[i]:indptr[i+1]]
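The construction can be sketched in NumPy as follows, assuming pre_ids and post_ids vectors as in the previous subsection (the example values and the helper name build_pre2post are hypothetical). np.bincount counts the synapses of each pre-synaptic neuron, and its cumulative sum gives indptr.

```python
import numpy as np

def build_pre2post(pre_ids, post_ids, pre_num):
    """Return (indices, indptr): post ids grouped by pre id, plus pointers."""
    order = np.argsort(pre_ids, kind="stable")      # group synapses by pre neuron
    indices = post_ids[order]
    counts = np.bincount(pre_ids, minlength=pre_num)
    indptr = np.concatenate(([0], np.cumsum(counts)))
    return indices, indptr

pre_ids = np.array([0, 0, 1, 2, 2])
post_ids = np.array([1, 2, 0, 2, 3])
indices, indptr = build_pre2post(pre_ids, post_ids, pre_num=4)

for i in range(4):
    print(i, "->", indices[indptr[i]:indptr[i + 1]])
```

Neuron 3 has no outgoing synapses, so indptr[3] == indptr[4] and its slice is empty.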
Similarly, post2pre is a 2-element tuple containing the pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers. Taking the connection in the illustration above as an example, the post-synaptic neuron indices and the pre-synaptic neurons that project to each of them are shown as:

The pre-synaptic index vector (indices) and the post-synaptic index pointer vector (indptr) are listed below:

When the connection is sparse, pre2post (or post2pre) is a very efficient way to store the connection, since the lengths of the two vectors in the tuple are \(n_{synapse}\) and \(n_{pre}+1\) (or \(n_{post}+1\)), respectively.
4. pre2syn and post2syn#
The last two properties are pre2syn and post2syn, which record the pre- and post-synaptic projections, respectively.
For pre2syn, similar to pre2post and post2pre, there is a synapse index vector and a pre-synaptic index pointer vector that refers to the starting position of each pre-synaptic neuron index in the synapse index vector.
Below is the same example identifying the connection by pre-synaptic neuron indices and the synapses belonging to them.

For better understanding, the synapse indices and the pre- and post-synaptic neuron indices are shown below:

The pre-synaptic index pointer vector is computed in the same way as in pre2post:

Similarly, post2syn is also a tuple containing the synapse indices and the corresponding post-synaptic neuron pointers.

The only difference from pre2syn is that the synapse indices are (most of the time) originally sorted by pre-synaptic neurons, whereas for post2syn the synapses must be sorted by post-synaptic neuron indices:

The synapse index vector (the first row) and the post-synaptic index pointer vector (the last row) are listed below:

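A NumPy sketch of both structures, assuming synapses are numbered in row-major order of a boolean conn_mat (the order np.nonzero produces). For post2syn, a stable sort by post-synaptic index keeps synapses that share a post neuron in increasing id order. The helper name build_syn_maps and the example matrix are illustrative.

```python
import numpy as np

def build_syn_maps(conn_mat):
    pre_ids, post_ids = np.nonzero(conn_mat)    # synapse k = (pre_ids[k], post_ids[k])
    syn_ids = np.arange(pre_ids.size)
    n_pre, n_post = conn_mat.shape
    # pre2syn: synapse ids are already sorted by pre-synaptic neuron.
    pre_ptr = np.concatenate(([0], np.cumsum(np.bincount(pre_ids, minlength=n_pre))))
    pre2syn = (syn_ids, pre_ptr)
    # post2syn: re-sort synapse ids by post-synaptic neuron.
    order = np.argsort(post_ids, kind="stable")
    post_ptr = np.concatenate(([0], np.cumsum(np.bincount(post_ids, minlength=n_post))))
    post2syn = (syn_ids[order], post_ptr)
    return pre2syn, post2syn

conn_mat = np.array([[1, 0, 1],
                     [1, 0, 1],
                     [1, 0, 1],
                     [0, 1, 1],
                     [1, 1, 0]], dtype=bool)
pre2syn, post2syn = build_syn_maps(conn_mat)
print("pre2syn:", pre2syn)
print("post2syn:", post2syn)
```

For this matrix the sketch yields post2syn = ([0, 2, 4, 8, 6, 9, 1, 3, 5, 7], [0, 4, 6, 10]), which is the same result as the SparseMatConn example later in this tutorial, whose connection matrix is identical.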
import brainpy as bp
import brainpy.math as bm
# bp.math.set_platform('cpu')
bp.__version__
'2.3.0'
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
Built-in regular connections#
brainpy.connect.One2One#
The neurons in the pre-synaptic neuron group only connect to the neurons in the same position of the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size; otherwise, an error will occur.

conn = bp.connect.One2One()
conn(pre_size=10, post_size=10)
One2One
where pre_size denotes the size of the pre-synaptic neuron group and post_size denotes the size of the post-synaptic neuron group.
Note that a size parameter can be an int, or a tuple or list of ints in which each element represents one dimension of the neuron group.
In particular, for the One2One connection, pre_size and post_size must be the same.
The One2One class inherits from TwoEndConnector. Users can use the require or requires method to get specific connection properties.
Here is an example:
size = 5
conn = bp.connect.One2One()(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
pre_ids: Array([0, 1, 2, 3, 4], dtype=int32)
post_ids: Array([0, 1, 2, 3, 4], dtype=int32)
pre2post: (Array([0, 1, 2, 3, 4], dtype=int32), Array([0, 1, 2, 3, 4, 5], dtype=int32))
brainpy.connect.All2All#
All neurons of the post-synaptic population form connections with all neurons of the pre-synaptic population (dense connectivity). Users can choose whether to connect the neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
conn(pre_size=size, post_size=size)
All2All(include_self=False)
The All2All class inherits from TwoEndConnector. Users can use the require or requires method to get specific connection properties.
Here is an example:
conn = bp.connect.All2All(include_self=False)(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
print('conn_mat:', res[3])
pre_ids: Array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], dtype=int32)
post_ids: Array([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], dtype=int32)
pre2post: (Array([1, 2, 3, 4, 0, 2, 3, 4, 4, 3, 1, 0, 0, 1, 2, 4, 0, 1, 2, 3], dtype=int32), Array([ 0, 4, 8, 12, 16, 20], dtype=int32))
conn_mat: Array([[False, True, True, True, True],
[ True, False, True, True, True],
[ True, True, False, True, True],
[ True, True, True, False, True],
[ True, True, True, True, False]], dtype=bool)
brainpy.connect.GridFour#
GridFour is the four-nearest-neighbor connection: each neuron connects to its four nearest neurons.

conn = bp.connect.GridFour(include_self=False)
conn(pre_size=size)
GridFour(include_self=False, periodic_boundary=False)
The GridFour class inherits from OneEndConnector, so there is only one parameter, pre_size, representing the size of the neuron group, which should have a two-dimensional geometry.
Here is an example:
size = (4, 4)
conn = bp.connect.GridFour(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')
print('pre_ids', res[0])
pre_ids Array([ 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5,
6, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10,
10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15], dtype=int32)
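The neighbor rule can be reproduced with a short NumPy sketch. It is illustrative only: it assumes row-major neuron ids and no periodic boundary (matching periodic_boundary=False above), and the helper name grid_four_conn is hypothetical.

```python
import numpy as np

def grid_four_conn(height, width, include_self=False):
    """Dense matrix linking each neuron to its up/down/left/right neighbors."""
    n = height * width
    mat = np.zeros((n, n), dtype=bool)
    for r in range(height):
        for c in range(width):
            i = r * width + c                      # row-major neuron id
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < height and 0 <= cc < width:
                    mat[i, rr * width + cc] = True
            if include_self:
                mat[i, i] = True
    return mat

mat = grid_four_conn(4, 4)
pre_ids, post_ids = np.nonzero(mat)
print(pre_ids)
```

Counting how often each id appears recovers the pattern in the printed pre_ids: corner neurons have two neighbors, edge neurons three, and interior neurons four.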
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(res[1]))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.GridEight#
GridEight is the eight-nearest-neighbor connection: each neuron connects to its eight nearest neurons.

conn = bp.connect.GridEight(include_self=False)
conn(pre_size=size)
GridEight(N=1, include_self=False, periodic_boundary=False)
The GridEight class inherits from GridN, which is introduced below.
Here is an example:
size = (4, 4)
conn = bp.connect.GridEight(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')
print('pre_ids', res[0])
pre_ids Array([ 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 4,
4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6,
6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9,
9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11,
12, 12, 12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15], dtype=int32)
Take an interior neuron (e.g. id = 5) as an example: all eight surrounding neurons are its neighbors, so its row in conn_mat has True for those eight entries and False elsewhere, including at its own position.
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(res[1]))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.GridN#
GridN is also a nearest-neighbor connection. Each neuron connects to the neurons in the \((2N+1) \times (2N+1)\) square centered on it, i.e. \((2N+1) \cdot (2N+1)\) neurons if including itself.

Here are some examples to help understand GridN. GridN generalizes GridEight: GridEight is equivalent to GridN with N = 1.
When N = 1: \(\begin{bmatrix} x & x & x\\ x & I & x\\ x & x & x \end{bmatrix}\)
When N = 2: \( \begin{bmatrix} x & x & x & x & x\\ x & x & x & x & x\\ x & x & I & x & x\\ x & x & x & x & x\\ x & x & x & x & x \end{bmatrix} \)
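The square neighborhood can be sketched as a Chebyshev-distance rule: neuron j is a neighbor of neuron i when their row and column offsets are both at most N. This NumPy version (hypothetical helper grid_n_conn, no periodic boundary) is an illustration under that assumption, not BrainPy's implementation.

```python
import numpy as np

def grid_n_conn(height, width, N=1, include_self=False):
    rows, cols = np.divmod(np.arange(height * width), width)   # row-major coordinates
    dr = np.abs(rows[:, None] - rows[None, :])
    dc = np.abs(cols[:, None] - cols[None, :])
    mat = (dr <= N) & (dc <= N)              # inside the (2N+1) x (2N+1) square
    if not include_self:
        np.fill_diagonal(mat, False)
    return mat

mat = grid_n_conn(4, 4, N=1)
print(mat.sum(axis=1).reshape(4, 4))   # number of neighbors of each neuron
```

With N = 1 this gives 84 connections on a 4 x 4 grid, the same as the length of the GridEight pre_ids printed above.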
conn = bp.connect.GridN(N=2, include_self=False)
conn(pre_size=size)
GridN(N=2, include_self=False, periodic_boundary=False)
Here is an example:
size = (4, 4)
conn = bp.connect.GridN(N=1, include_self=False)(pre_size=size)
res = conn.require('conn_mat')
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(res))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

Built-in random connections#
brainpy.connect.FixedProb#
For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all-to-all projection, except that some synapses are not created, making the projection sparser.

The brainpy.connect.FixedProb class inherits from TwoEndConnector and receives three settings:
prob: Fixed probability of connecting to a pre-synaptic neuron, for each post-synaptic neuron.
include_self: Whether to connect a neuron to itself.
seed: Seed for the random generator.
Two parameters are passed in when calling the class instance: pre_size and post_size.
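A minimal NumPy sketch of the fixed-probability rule, assuming each (pre, post) pair is kept independently with probability prob. The helper name fixed_prob_conn is hypothetical, and its random stream differs from BrainPy's, so it will not reproduce BrainPy's matrix for the same seed.

```python
import numpy as np

def fixed_prob_conn(pre_num, post_num, prob, include_self=False, seed=None):
    rng = np.random.default_rng(seed)
    mat = rng.random((pre_num, post_num)) < prob   # independent Bernoulli trials
    if not include_self:
        k = min(pre_num, post_num)
        mat[np.arange(k), np.arange(k)] = False     # drop (i, i) connections
    return mat

mat = fixed_prob_conn(100, 100, prob=0.5, include_self=False, seed=134)
print(mat.mean())   # close to 0.5, slightly lower because the diagonal is removed
```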
conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=134)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[False, True, True, True],
[False, False, True, True],
[False, True, False, False],
[ True, False, False, False]], dtype=bool)
brainpy.connect.FixedPreNum#
Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

The brainpy.connect.FixedPreNum class inherits from TwoEndConnector and receives three settings:
num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).
include_self: Whether to connect a neuron to itself.
seed: Seed for the random generator.
Two parameters are passed in when calling the class instance: pre_size and post_size.
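The integer-num case of FixedPreNum can be sketched in NumPy as follows, assuming that for each post-synaptic neuron, num distinct pre-synaptic neurons are drawn uniformly without replacement. The helper fixed_pre_num_conn is hypothetical. Calling it with pre and post sizes swapped and transposing the result gives the analogous FixedPostNum pattern, in which every pre-synaptic neuron sends exactly num connections.

```python
import numpy as np

def fixed_pre_num_conn(pre_num, post_num, num, include_self=True, seed=None):
    rng = np.random.default_rng(seed)
    mat = np.zeros((pre_num, post_num), dtype=bool)
    for j in range(post_num):
        candidates = np.arange(pre_num)
        if not include_self:
            candidates = candidates[candidates != j]   # exclude pre neuron j
        chosen = rng.choice(candidates, size=num, replace=False)
        mat[chosen, j] = True
    return mat

mat = fixed_pre_num_conn(4, 4, num=2, include_self=True, seed=1234)
print(mat.sum(axis=0))   # every post-synaptic neuron receives exactly num inputs
```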
conn = bp.connect.FixedPreNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[ True, True, False, False],
[False, True, False, True],
[ True, False, True, False],
[False, False, True, True]], dtype=bool)
brainpy.connect.FixedPostNum#
Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

The brainpy.connect.FixedPostNum class inherits from TwoEndConnector and receives three settings:
num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).
include_self: Whether to connect a neuron to itself.
seed: Seed for the random generator.
Two parameters are passed in when calling the class instance: pre_size and post_size.
conn = bp.connect.FixedPostNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[ True, False, True, False],
[ True, True, False, False],
[False, False, True, True],
[False, True, False, True]], dtype=bool)
brainpy.connect.GaussianProb#
Builds a Gaussian connection pattern between the two populations, where the connection probability decays according to a Gaussian function.
Specifically,
\[ p = \exp\left(-\frac{(x-x_c)^2 + (y-y_c)^2}{2\sigma^2}\right), \]
where \((x, y)\) is the position of the pre-synaptic neuron and \((x_c,y_c)\) is the position of the post-synaptic neuron.
For example, in a \(30 \times 30\) two-dimensional network, when \(\beta = \frac{1}{2\sigma^2} = 0.1\), the connection pattern is shown as follows:

GaussianProb inherits from OneEndConnector and receives the following settings:
sigma: (float) Width of the Gaussian function.
encoding_values: (optional, list, tuple, int, float) The value ranges to encode for neurons at each axis.
periodic_boundary: (bool) Whether the neurons encode the value space with a periodic boundary.
normalize: (bool) Whether to normalize the connection probability.
include_self: (bool) Whether to create the connection at the same position.
seed: (int, optional) The random seed.
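As an illustration of the Gaussian rule, here is a NumPy sketch for a one-dimensional ring of neurons with a periodic boundary. It assumes the connection probability is \(\exp(-d^2 / (2\sigma^2))\), with \(d\) the wrapped distance between neuron positions, and it does not reproduce BrainPy's encoding_values or normalize handling; the helper name gaussian_ring_conn is hypothetical.

```python
import numpy as np

def gaussian_ring_conn(n, sigma, include_self=True, seed=None):
    rng = np.random.default_rng(seed)
    pos = np.arange(n)
    d = np.abs(pos[:, None] - pos[None, :])
    d = np.minimum(d, n - d)                     # periodic (ring) distance
    prob = np.exp(-d ** 2 / (2 * sigma ** 2))    # Gaussian decay with distance
    mat = rng.random((n, n)) < prob              # sample each pair independently
    if not include_self:
        np.fill_diagonal(mat, False)
    return prob, mat

prob, mat = gaussian_ring_conn(10, sigma=2, seed=21)
print(np.round(prob[0], 3))   # connection probability from neuron 0 to every neuron
```

Because of the periodic boundary, neuron 0 is equally likely to connect to neurons 1 and 9.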
conn = bp.connect.GaussianProb(sigma=2, periodic_boundary=True, normalize=True, include_self=True, seed=21)
conn(pre_size=10)
conn.require('conn_mat')
Array([[ True, True, False, True, False, False, False, False, True,
True],
[ True, True, True, True, False, False, False, False, False,
True],
[ True, True, True, True, False, False, False, False, False,
True],
[False, True, True, True, True, False, False, False, False,
False],
[False, False, False, True, True, True, False, False, False,
False],
[False, False, True, False, True, True, True, True, False,
False],
[False, False, False, True, False, True, True, True, False,
False],
[ True, False, False, False, False, True, True, True, False,
True],
[False, True, False, False, False, False, True, True, True,
True],
[ True, False, True, False, False, True, False, True, True,
True]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(conn.require('conn_mat')))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.SmallWorld#
SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined as a network in which the typical distance \(L\) between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes \(N\) in the network, that is:
\[ L \propto \log N \]
[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.
Currently, SmallWorld only supports a one-dimensional network with a ring structure. It receives four settings:
num_neighbor: the number of nearest neighbors to connect.
prob: the probability of rewiring each edge.
directed: whether the edges are directed (directed=True) or undirected (directed=False).
include_self: whether a neuron is allowed to connect to itself.
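For intuition, here is a NumPy sketch of the classic Watts–Strogatz construction on a ring (undirected, no self-connections). It starts from a ring lattice in which each node links to its k nearest neighbors (k/2 on each side, so k is assumed even here), then rewires one end of each lattice edge with probability p. This follows the textbook algorithm and is not guaranteed to match BrainPy's SmallWorld implementation or its handling of an odd num_neighbor; the function name watts_strogatz is hypothetical.

```python
import numpy as np

def watts_strogatz(n, k, p, seed=None):
    rng = np.random.default_rng(seed)
    mat = np.zeros((n, n), dtype=bool)
    half = k // 2
    for i in range(n):                       # ring lattice
        for s in range(1, half + 1):
            j = (i + s) % n
            mat[i, j] = mat[j, i] = True
    for s in range(1, half + 1):             # rewiring pass, one neighbor distance at a time
        for i in range(n):
            j = (i + s) % n
            if rng.random() < p and mat[i, j]:
                # candidate new targets: not i itself and not already connected to i
                free = np.flatnonzero(~mat[i])
                free = free[free != i]
                if free.size == 0:
                    continue
                new_j = rng.choice(free)
                mat[i, j] = mat[j, i] = False
                mat[i, new_j] = mat[new_j, i] = True
    return mat

mat = watts_strogatz(10, 4, 0.2, seed=0)
print(mat.sum(axis=1))   # node degrees; rewiring keeps the total edge count fixed
```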
conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, True, True, False, False, False, False, False, True,
True],
[ True, False, True, False, False, False, True, False, False,
True],
[ True, True, False, True, True, False, True, False, False,
False],
[False, False, True, False, True, True, True, False, False,
False],
[False, False, True, True, False, False, True, False, False,
False],
[False, False, False, True, False, False, True, True, False,
False],
[False, True, True, True, True, True, False, False, True,
False],
[False, False, False, False, False, True, False, False, True,
True],
[ True, False, False, False, False, False, True, True, False,
True],
[ True, True, False, False, False, False, False, True, True,
False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(conn.require('conn_mat')))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.ScaleFreeBA#
ScaleFreeBA is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA receives the following settings:
m: Number of edges to attach from a new node to existing nodes.
directed: whether the edges are directed (directed=True) or undirected (directed=False).
seed: Indicator of random number generation state.
[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.
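The Barabási–Albert growth rule can be sketched in NumPy as follows (undirected case). It mirrors the textbook procedure: start from m unconnected nodes, then connect each new node to m distinct existing nodes chosen with probability proportional to their degree. This is an illustration, not BrainPy's implementation, so it will not reproduce BrainPy's matrix for a given seed; the function name barabasi_albert is hypothetical.

```python
import numpy as np

def barabasi_albert(n, m, seed=None):
    rng = np.random.default_rng(seed)
    mat = np.zeros((n, n), dtype=bool)
    targets = list(range(m))      # the first new node links to all m initial nodes
    repeated = []                 # each node appears once per incident edge
    for source in range(m, n):
        for t in targets:
            mat[source, t] = mat[t, source] = True
        repeated.extend(targets)
        repeated.extend([source] * m)
        # degree-proportional sampling of m distinct targets for the next node
        chosen = set()
        while len(chosen) < m:
            chosen.add(int(rng.choice(repeated)))
        targets = list(chosen)
    return mat

mat = barabasi_albert(10, 5, seed=12345)
print(mat.sum() // 2)   # number of undirected edges
```

Each of the n - m added nodes contributes m edges, so the network has m(n - m) edges in total; for n = 10 and m = 5 that is 25, the same edge count as the ScaleFreeBA matrix in this section.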
conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False, False, False, True, True, False, False,
False],
[False, False, False, False, False, True, True, True, False,
True],
[False, False, False, False, False, True, True, True, True,
False],
[False, False, False, False, False, True, True, True, True,
True],
[False, False, False, False, False, True, False, False, False,
False],
[ True, True, True, True, True, False, True, True, True,
True],
[ True, True, True, True, False, True, False, True, True,
True],
[False, True, True, True, False, True, True, False, True,
True],
[False, False, True, True, False, True, True, True, False,
False],
[False, True, False, True, False, True, True, True, False,
False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_array(np.asarray(conn.require('conn_mat')))  # from_numpy_matrix was removed in NetworkX 3.0
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.ScaleFreeBADual#
ScaleFreeBADual is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:
p: The probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges).
m1: Number of edges to attach from a new node to existing nodes with probability \(p\).
m2: Number of edges to attach from a new node to existing nodes with probability \(1-p\).
directed: whether the edges are directed (directed=True) or undirected (directed=False).
seed: Indicator of random number generation state.
[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.
conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False, False, False, True, True, True, False,
True],
[False, False, False, False, False, True, False, True, True,
False],
[False, False, False, False, False, True, True, True, True,
True],
[False, False, False, False, False, False, False, False, False,
False],
[False, False, False, False, False, False, False, False, False,
False],
[ True, True, True, False, False, False, True, True, True,
True],
[ True, False, True, False, False, True, False, True, True,
True],
[ True, True, True, False, False, True, True, False, True,
True],
[False, True, True, False, False, True, True, True, False,
False],
[ True, False, True, False, False, True, True, True, False,
False]], dtype=bool)
brainpy.connect.PowerLaw#
PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:
m: the number of random edges to add for each new node.
p: the probability of adding a triangle after adding a random edge.
directed: whether the edges are directed (directed=True) or undirected (directed=False).
seed: Indicator of random number generation state.
[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.
conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False, True, True, False, True, True, True,
False],
[False, False, False, True, False, True, False, False, False,
False],
[False, False, False, True, True, True, True, True, False,
True],
[ True, True, True, False, True, False, False, False, False,
False],
[ True, False, True, True, False, True, False, False, True,
False],
[False, True, True, False, True, False, True, True, False,
True],
[ True, False, True, False, False, True, False, False, True,
False],
[ True, False, True, False, False, True, False, False, False,
True],
[ True, False, False, False, True, False, True, False, False,
False],
[False, False, True, False, False, True, False, True, False,
False]], dtype=bool)
Encapsulate your existing connections#
BrainPy also allows users to encapsulate existing connections with convenient class interfaces. Users can provide connection types as:
Index projection;
Dense matrix;
Sparse matrix.
Then users should provide the pre_size and post_size information to instantiate the connection. In this way, based on the following connection classes, users can easily generate any other synaptic structure (such as pre2post, pre2syn, conn_mat, etc.).
bp.conn.IJConn#
Here, let's take a simple connection as an example. In this example, we use bp.conn.IJConn to create a connection from a hand-written index projection.
pre_list = np.array([0, 1, 2])
post_list = np.array([0, 0, 0])
conn = bp.conn.IJConn(i=pre_list, j=post_list)
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
Array([[ True, False, False],
[ True, False, False],
[ True, False, False],
[False, False, False],
[False, False, False]], dtype=bool)
conn.requires('pre2post')
(Array([0, 0, 0], dtype=int32), Array([0, 1, 2, 3, 3, 3], dtype=int32))
conn.requires('pre2syn')
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2, 3, 3, 3], dtype=int32))
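The three structures above can be reproduced from the same index lists with plain NumPy, which may help when checking a hand-written projection. This sketch assumes synapses are numbered in the order of the (i, j) pairs and that the pairs are already sorted by pre-synaptic index, as they are here.

```python
import numpy as np

pre_list = np.array([0, 1, 2])
post_list = np.array([0, 0, 0])
pre_size, post_size = 5, 3

conn_mat = np.zeros((pre_size, post_size), dtype=bool)
conn_mat[pre_list, post_list] = True

# Pointers: start position of each pre-synaptic neuron's synapses.
indptr = np.concatenate(([0], np.cumsum(np.bincount(pre_list, minlength=pre_size))))
pre2post = (post_list, indptr)
pre2syn = (np.arange(pre_list.size), indptr)
print(conn_mat.astype(int))
print(pre2post)
print(pre2syn)
```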
bp.conn.MatConn#
In the next example, we use bp.conn.MatConn to create a connection from a hand-written dense connection matrix.
bp.math.random.seed(123)
conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_))
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
Array([[False, True, True],
[ True, True, True],
[ True, True, True],
[False, True, True],
[False, False, True]], dtype=bool)
conn.requires('pre2post')
(Array([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 2], dtype=int32),
Array([ 0, 2, 5, 8, 10, 11], dtype=int32))
conn.require('pre2syn')
(Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int32),
Array([ 0, 2, 5, 8, 10, 11], dtype=int32))
bp.conn.SparseMatConn#
In the last example, we use bp.conn.SparseMatConn to create a connection from a user-provided sparse connection matrix:
from scipy.sparse import csr_matrix
conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
sparse_mat = csr_matrix(conn_mat)
conn = bp.conn.SparseMatConn(sparse_mat)
conn = conn(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1])
conn.requires('conn_mat')
Array([[ True, False, True],
[ True, False, True],
[ True, False, True],
[False, True, True],
[ True, True, False]], dtype=bool)
conn.requires('pre2post')
(Array([0, 2, 0, 2, 0, 2, 1, 2, 0, 1], dtype=int32),
Array([ 0, 2, 4, 6, 8, 10], dtype=int32))
conn.requires('post2syn')
(Array([0, 2, 4, 8, 6, 9, 1, 3, 5, 7], dtype=int32),
Array([ 0, 4, 6, 10], dtype=int32))
Using NetworkX to provide connections and pass into Connector#
NetworkX is a Python package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks.
Users can design their own complex networks by using NetworkX.
import networkx as nx
G = nx.Graph()
By definition, a Graph is a collection of nodes (vertices) along with identified pairs of nodes (called edges, links, etc.).
To learn more about NetworkX, please check the official documentation: NetworkX tutorial
Using the brainpy.connect.MatConn class to construct connections is recommended here.
Dense adjacency matrix: a two-dimensional ndarray.
Here is an example illustrating how to transform a random graph into synaptic connections by using its dense adjacency matrix.
G = nx.fast_gnp_random_graph(5, 0.5) # initialize a random graph G
B = nx.adjacency_matrix(G)
A = np.array(nx.adjacency_matrix(G).todense()) # get dense adjacency matrix of G
print('dense adjacency matrix:')
print(A)
nx.draw(G, with_labels=True)
plt.show()
dense adjacency matrix:
[[0 1 1 0 1]
[1 0 1 0 0]
[1 1 0 1 0]
[0 0 1 0 1]
[1 0 0 1 0]]

Users can use the class MatConn, which inherits from TwoEndConnector, to construct connections. A dense adjacency matrix should be passed in when initializing MatConn.
Note that when calling the instance of the class, users should pass in two parameters: pre_size and post_size.
In this case, the shape of the dense adjacency matrix can be used for both parameters.
conn = bp.connect.MatConn(A)(pre_size=A.shape[0], post_size=A.shape[1])
res = conn.require('conn_mat')
print(res)
Array([[False, True, True, False, True],
[ True, False, True, False, False],
[ True, True, False, True, False],
[False, False, True, False, True],
[ True, False, False, True, False]], dtype=bool)
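If NetworkX is not at hand, the same kind of dense adjacency matrix can be built directly from an edge list with NumPy. The sketch below assumes an undirected graph, so every edge is written in both directions; the edge list is simply read off the matrix printed above, and the boolean cast mirrors the True/False matrix that MatConn reports.

```python
import numpy as np

# Edges of the undirected example graph above (read off its adjacency matrix)
edges = [(0, 1), (0, 2), (0, 4), (1, 2), (2, 3), (3, 4)]
num_nodes = 5

A = np.zeros((num_nodes, num_nodes), dtype=int)
rows, cols = np.array(edges).T
A[rows, cols] = 1
A[cols, rows] = 1  # undirected graph: each edge yields a synapse in both directions

conn_mat = A.astype(bool)  # non-zero entries become connections
```

Because the graph is undirected, the resulting matrix is symmetric, and 6 edges become 12 directed synapses.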
Customize your connections#
BrainPy allows users to customize their connections. The following requirements should be satisfied:
Your connection class should inherit from brainpy.connect.TwoEndConnector or brainpy.connect.OneEndConnector.
The __init__ function should be implemented and essential parameters should be initialized.
Users should also override the build_csr(), build_coo() or build_mat() function to describe how to build the connection.
Let’s take an example to illustrate the details of customization.
class FixedProb(bp.connect.TwoEndConnector):
    """Connect the post-synaptic neurons with fixed probability.
    Parameters
    ----------
    prob : float
        The conn probability.
    include_self : bool
        Whether to create (i, i) connection.
    seed : optional, int
        Seed the random generator.
    """
    def __init__(self, prob, include_self=True, seed=None):
        super(FixedProb, self).__init__()
        assert 0. <= prob <= 1.
        self.prob = prob
        self.include_self = include_self
        self.seed = seed
        self.rng = np.random.RandomState(seed=seed)
    def build_csr(self):
        ind = []
        count = np.zeros(self.pre_num, dtype=np.uint32)
        def _random_prob_conn(rng, pre_i, num_post, prob, include_self):
            p = rng.random(num_post) <= prob
            if (not include_self) and pre_i < num_post:
                p[pre_i] = False
            conn_j = np.asarray(np.where(p)[0], dtype=np.uint32)
            return conn_j
        for i in range(self.pre_num):
            posts = _random_prob_conn(self.rng, pre_i=i, num_post=self.post_num,
                                      prob=self.prob, include_self=self.include_self)
            ind.append(posts)
            count[i] = len(posts)
        ind = np.concatenate(ind)
        indptr = np.concatenate(([0], count)).cumsum()
        return ind, indptr
Then users can initialize their own connections as below:
conn = FixedProb(prob=0.5, include_self=True)(pre_size=5, post_size=5)
conn.require('conn_mat')
Array([[False, True, False, True, True],
[ True, False, True, True, True],
[False, True, False, False, False],
[ True, False, False, True, False],
[ True, True, False, False, False]], dtype=bool)
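The per-row loop in build_csr can also be written in vectorized form. The hypothetical helper below draws one uniform number per (pre, post) pair in the same row-major order as the loop, so with the same seed it should select the same connections; it returns plain NumPy arrays and is only a sketch of the idea, not a drop-in replacement tested against BrainPy.

```python
import numpy as np

def fixed_prob_csr(pre_num, post_num, prob, include_self=True, seed=None):
    """Hypothetical vectorized version of FixedProb.build_csr above."""
    rng = np.random.RandomState(seed)
    # One uniform draw per (pre, post) pair, row by row
    mask = rng.random((pre_num, post_num)) <= prob
    if not include_self:
        n = min(pre_num, post_num)
        mask[np.arange(n), np.arange(n)] = False  # remove (i, i) connections
    indices = np.nonzero(mask)[1].astype(np.uint32)  # post ids, grouped by pre id
    indptr = np.concatenate(([0], np.cumsum(mask.sum(axis=1))))
    return indices, indptr

indices, indptr = fixed_prob_csr(5, 5, prob=0.5, seed=42)
```

Whatever the random draws, a valid CSR pair always satisfies `indptr[0] == 0` and `indptr[-1] == len(indices)`.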
Synaptic Weights#
In a brain model, synaptic weights, the strength of the connection between presynaptic and postsynaptic neurons, are crucial to the dynamics of the model. In this section, we will illustrate how to build synaptic weights in a synapse model.
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt
bp.math.set_platform('cpu')
bp.__version__
'2.3.0'
Creating Static Weights#
Some computational models focus on the network structure and its influence on network dynamics, thus not modeling neural plasticity for simplicity. In this condition, synaptic weights are fixed and do not change in simulation. They can be stored as a scalar, a matrix or a vector depending on the connection strength and density.
1. Storing weights with a scalar#
If all synaptic weights are designed to be the same, the single weight value can be stored as a scalar in the synapse model to save memory space.
weight = 1.
The weight can be stored in a synapse model. When updating the synapse, this weight is assigned to all synapses by scalar multiplication.
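As a rough illustration of why a scalar suffices, the sketch below computes the input each postsynaptic neuron receives in one time step, assuming a simple model in which every presynaptic spike delivers the synaptic weight to all of its targets. The toy connection matrix and spike vector are made up for the example, and this is not how BrainPy's synapse models are written internally.

```python
import numpy as np

# Hypothetical toy network: 4 presynaptic -> 3 postsynaptic neurons
conn_mat = np.array([[1, 0, 1],
                     [0, 1, 1],
                     [1, 1, 0],
                     [0, 0, 1]], dtype=bool)
pre_spikes = np.array([1., 0., 1., 1.])  # presynaptic neurons 0, 2 and 3 fire

weight = 1.  # one scalar shared by every synapse
# Count the spikes arriving at each postsynaptic neuron, then scale once
post_input = weight * (pre_spikes @ conn_mat)
```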
2. Storing weights with a matrix#
When the synaptic connection is dense and the synapses are assigned with different weights, weights can be stored in a matrix \(W\), where \(W(i, j)\) refers to the weight of presynaptic neuron \(i\) to postsynaptic neuron \(j\).
BrainPy provides brainpy.initialize.Initializer (or brainpy.init for short) for weight initialization as a matrix. The tutorial of brainpy.init.Initializer is introduced later.
For example, a weight matrix can be constructed using brainpy.init.Uniform, which draws the weights from a uniform random distribution:
pre_size = (4, 4)
post_size = (3, 3)
uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init((pre_size, post_size))
print('shape of weights: {}'.format(weights.shape))
shape of weights: (16, 9)
Then, the weights can be assigned to a group of connections with the same shape. For example, an all-to-all connection matrix can be obtained by brainpy.conn.All2All(), whose tutorial is contained in Synaptic Connections:
conn = bp.conn.All2All()
conn(pre_size, post_size)
conn_mat = conn.requires('conn_mat') # request the connection matrix
Therefore, weights[i, j] refers to the weight of connection (i, j).
i, j = (2, 3)
print('whether (i, j) is connected: {}'.format(conn_mat[i, j]))
print('synaptic weights of (i, j): {}'.format(weights[i, j]))
whether (i, j) is connected: True
synaptic weights of (i, j): 0.974509060382843
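When weights live in a matrix but not every pair is connected, a natural way to use them is to mask the matrix with the connection matrix before summing the inputs. The NumPy sketch below assumes the same simple spike-times-weight model as before; the random connection mask and spikes are hypothetical stand-ins for the BrainPy objects above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pre, num_post = 16, 9
weights = rng.uniform(0., 1., size=(num_pre, num_post))  # W[i, j]: weight of i -> j
conn_mat = rng.random((num_pre, num_post)) < 0.3          # hypothetical sparse mask

effective_w = np.where(conn_mat, weights, 0.)  # unconnected pairs contribute nothing
pre_spikes = (rng.random(num_pre) < 0.5).astype(float)
post_input = pre_spikes @ effective_w          # sum of W[i, j] over the firing neurons i
```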
3. Storing weights with a vector#
When the synaptic connection is sparse, storing the synaptic weights in a matrix wastes memory on the many unconnected pairs. Instead, the weights can be stored in a vector whose length equals the number of synapses.

Weights can be assigned to the corresponding synapses as long as they are aligned with each other.
size = 5
conn = bp.conn.One2One()
conn(size, size)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')
print('presynaptic neuron ids: {}'.format(pre_ids))
print('postsynaptic neuron ids: {}'.format(post_ids))
print('synapse ids: {}'.format(bm.arange(size)))
presynaptic neuron ids: [0 1 2 3 4]
postsynaptic neuron ids: [0 1 2 3 4]
synapse ids: [0 1 2 3 4]
The weight vector is aligned with the synapse vector, i.e. the synapse ids:
weights = bm.random.uniform(0, 2, size)
for i in range(size):
    print('weight of synapse {}: {}'.format(i, weights[i]))
weight of synapse 0: 1.1543986797332764
weight of synapse 1: 1.7815501689910889
weight of synapse 2: 0.6559045314788818
weight of synapse 3: 0.48931145668029785
weight of synapse 4: 0.19386005401611328
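With a weight vector, synaptic inputs are typically accumulated synapse by synapse using the `pre_ids`/`post_ids` lists. The minimal NumPy sketch below shows this scatter-add with a hypothetical five-synapse connection; it illustrates the bookkeeping only and does not use BrainPy's event-driven operators.

```python
import numpy as np

# Synapse k connects pre_ids[k] -> post_ids[k] with weight w[k]
pre_ids = np.array([0, 0, 1, 2, 3])
post_ids = np.array([1, 2, 2, 0, 2])
w = np.array([0.5, 1.0, 1.5, 2.0, 0.25])
num_post = 3

pre_spikes = np.array([1., 1., 0., 1.])  # presynaptic neurons 0, 1 and 3 fire

post_input = np.zeros(num_post)
# Each synapse delivers spike * weight; np.add.at accumulates repeated post ids
np.add.at(post_input, post_ids, pre_spikes[pre_ids] * w)
```

`np.add.at` is used instead of `post_input[post_ids] += ...` because plain fancy-index assignment would keep only one contribution per repeated index.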
Conversion from a weight matrix to a weight vector#
To obtain the weight vector from a weight matrix, users can first build a connection according to the non-zero elements of the weight matrix and then slice the weight matrix according to the connection:
weight_mat = np.array([[1., 1.5, 0., 0.5], [0., 2.5, 0., 0.], [2., 0., 3, 0.]])
print('weight matrix: \n{}'.format(weight_mat))
conn = bp.conn.MatConn(weight_mat)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')
weight_vec = weight_mat[pre_ids, post_ids]
print('weight_vector: \n{}'.format(weight_vec))
weight matrix:
[[1. 1.5 0. 0.5]
[0. 2.5 0. 0. ]
[2. 0. 3. 0. ]]
weight_vector:
[1. 1.5 0.5 2.5 2. 3. ]
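The same conversion can be written with NumPy alone, which also makes it easy to go back from a vector to a matrix. Note that `np.nonzero` treats a connected synapse whose weight happens to be 0 as unconnected, so this sketch only works when a zero entry means "no synapse".

```python
import numpy as np

weight_mat = np.array([[1., 1.5, 0., 0.5],
                       [0., 2.5, 0., 0.],
                       [2., 0., 3, 0.]])
pre_ids, post_ids = np.nonzero(weight_mat)   # row-major order, as in the output above
weight_vec = weight_mat[pre_ids, post_ids]

# Reverse direction: scatter the vector back into a dense matrix
rebuilt = np.zeros(weight_mat.shape)
rebuilt[pre_ids, post_ids] = weight_vec
```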
Note
This approach is not recommended when the connection is sparse and large-scale, because generating the dense weight matrix takes up too much memory.
Creating Dynamic Weights#
Sometimes users may want to implement neural plasticity in a brain model, which requires the synaptic weights to change during simulation. In this case, weights should be treated as variables and thus defined as brainpy.math.Variable. If the weights are packed in a synapse model, weight updating should be implemented in the update(_t, _dt) function of the synapse model.
weights = bm.Variable(bm.ones(10))
weights
Variable([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)
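To give a feel for what "weights changing during simulation" means, here is a NumPy sketch of a hypothetical Hebbian-style rule applied in place at every step. In BrainPy the array would be a brainpy.math.Variable and the update would live in the synapse's update function; the rule, learning rate and bound used here are illustrative assumptions only.

```python
import numpy as np

# Hypothetical Hebbian-style rule: strengthen synapses whose pre and post neurons fire together
rng = np.random.default_rng(1)
weights = np.ones((4, 3))   # plays the role of a Variable: updated in place each step
lr, w_max = 0.1, 2.0

for step in range(5):
    pre = (rng.random(4) < 0.5).astype(float)
    post = (rng.random(3) < 0.5).astype(float)
    weights += lr * np.outer(pre, post)        # co-active pairs gain lr
    np.clip(weights, 0., w_max, out=weights)   # keep weights bounded
```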
Built-in Weight Initializers#
Base Class: bp.init.Initializer#
The base class of weight initializers is brainpy.initialize.Initializer, which can be accessed through the shortcut bp.init. All initializers, built-in or customized, should inherit the Initializer class.
Weight initialization is implemented in the __call__ function, so that it is performed automatically when the initializer is called. The __call__ function receives a shape parameter, whose meaning differs between the following two superclasses, and returns a weight matrix.
Superclass 1: bp.init.InterLayerInitializer#
The InterLayerInitializer is an abstract subclass of Initializer. Its subclasses initialize the weights between two fully connected layers. The shape parameter of the __call__ function should be a 2-element tuple \((m, n)\), which refers to the number of presynaptic neurons \(m\) and of postsynaptic neurons \(n\). The output of the __call__ function is a bp.math.ndarray with the shape of \((m, n)\), where the value at \((i, j)\) is the initialized weight of the presynaptic neuron \(i\) to the postsynaptic neuron \(j\).
Superclass 2: bp.init.IntraLayerInitializer#
The IntraLayerInitializer is also an abstract subclass of Initializer. Its subclasses initialize the weights within a single layer. The shape parameter of the __call__ function refers to the structure of the neural population \((n_1, n_2, ..., n_d)\). The __call__ function returns a 2-D bp.math.ndarray with the shape of \((\prod_{k=1}^d n_k, \prod_{k=1}^d n_k)\). In the 2-D array, the value at \((i, j)\) is the initialized weight of neuron \(i\) to neuron \(j\) of the flattened neural sequence.
1. Built-In Regular Initializers#
Regular initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a regular pattern. The built-in regular initializers are ZeroInit, OneInit, and Identity. Here we show how to use the OneInit initializer. The remaining two classes are used in a similar way.
# visualization
def mat_visualize(matrix, cmap='coolwarm'):
    # plt.matshow accepts a colormap name, which avoids the deprecated plt.cm.get_cmap
    im = plt.matshow(matrix, cmap=cmap)
    plt.colorbar(mappable=im, shrink=0.8, aspect=15)
    plt.show()
# 'OneInit' initializes all the weights with the same value
shape = (5, 6)
one_init = bp.init.OneInit(value=2.5)
weights = one_init(shape)
print(weights)
Array([[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5]], dtype=float32)
2. Built-In Random Initializers#
Random initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a random distribution. The built-in random initializers include Normal, Uniform, Orthogonal and others. Here we show how to use the Normal and Uniform initializers.
bp.init.Normal
This initializer draws the weights from a normal distribution whose spread is set by the scale parameter (the standard deviation). In the following example, 10 presynaptic neurons are fully connected to 20 postsynaptic neurons with random weight values:
shape = (10, 20)
normal_init = bp.init.Normal(scale=1.0)
weights = normal_init(shape)
mat_visualize(weights)

bp.init.Uniform
This initializer resembles brainpy.init.Normal but initializes the weights with a uniform distribution.
uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init(shape)
mat_visualize(weights)

3. Built-In Decay Initializers#
Decay initializers all belong to IntraLayerInitializer and initialize the connection weights within a layer with a decay function according to the neural distance. The built-in decay initializers are GaussianDecay and DOGDecay. Below are examples of how to use them.
brainpy.init.GaussianDecay
This initializer creates a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function of the distance between neurons. Specifically, for any pair of neurons \( (i, j) \), the weight is computed as
\[ w(i, j) = w_{max} \cdot \exp\left(-\frac{\sum_{k=1}^{n} |v_k^i - v_k^j|^2}{2\sigma^2}\right) \]
where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \), \( w_{max} \) is set by max_w, and \( \sigma \) by sigma.
The example below is a neural population with the size of \( 5 \times 5 \). Note that this shape is the structure of the target neural population, not the size of presynaptic and postsynaptic neurons.
size = (5, 5)
gaussian_init = bp.init.GaussianDecay(sigma=2., max_w=10., include_self=True)
weights = gaussian_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (25, 25)
Self-connections are created if include_self=True. The connection weights of neuron \(i\) with others are stored in row \(i\) of weights. For instance, the connection weights of neuron (1, 2) to other neurons are stored in weights[7] (\(5 \times 1 + 2 = 7\)). After reshaping, the weights are:
mat_visualize(weights[7].reshape(size), cmap='Reds')
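The formula can be checked with a small NumPy sketch that takes each neuron's encoded value to be its integer grid coordinate. This ignores any further options the BrainPy initializer may offer (for example, custom encoding values, periodic boundaries, or a minimum-weight cut-off), so treat it as an illustration of the formula rather than a reimplementation; `gaussian_decay` is a hypothetical name.

```python
import numpy as np

def gaussian_decay(size, sigma, max_w, include_self=True):
    """Sketch of w(i, j) = max_w * exp(-sum_k |v_k^i - v_k^j|^2 / (2 sigma^2))."""
    # Grid coordinates of every neuron, flattened in row-major order: shape (N, d)
    pos = np.indices(size).reshape(len(size), -1).T.astype(float)
    # Squared Euclidean distance between every pair of neurons: shape (N, N)
    d2 = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(axis=-1)
    w = max_w * np.exp(-d2 / (2 * sigma ** 2))
    if not include_self:
        np.fill_diagonal(w, 0.)
    return w

w = gaussian_decay((5, 5), sigma=2., max_w=10.)
row = w[7].reshape(5, 5)   # weights from neuron (1, 2) to every neuron
```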

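brainpy.init.DOGDecay
This initializer creates a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons. Specifically, for the given pair of neurons \( (i, j) \), the weight between them is computed as
\[ w(i, j) = w_{max}^{+} \cdot \exp\left(-\frac{\sum_{k=1}^{n} |v_k^i - v_k^j|^2}{2\sigma_{+}^2}\right) - w_{max}^{-} \cdot \exp\left(-\frac{\sum_{k=1}^{n} |v_k^i - v_k^j|^2}{2\sigma_{-}^2}\right) \]
where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \), \( (\sigma_{+}, \sigma_{-}) \) are set by sigmas, and \( (w_{max}^{+}, w_{max}^{-}) \) by max_ws.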

The example below is a neural population with the size of \( 10 \times 12 \):
size = (10, 12)
dog_init = bp.init.DOGDecay(sigmas=(1., 3.), max_ws=(10., 5.), min_w=0.1, include_self=True)
weights = dog_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (120, 120)
Weights smaller than min_w will not be created. If min_w is not assigned a value, it defaults to \( 0.005 \times \min(\mathrm{max\_ws}) \).
The organization of weights is similar to that in the GaussianDecay initializer. For instance, the connection weights of neuron (3, 4) to other neurons after reshaping are shown below:
mat_visualize(weights[3*12+4].reshape(size), cmap='Reds')
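A matching NumPy sketch of the DOG formula, again taking grid coordinates as the encoded values. How BrainPy treats weights below min_w is modeled here as "set to 0", which is an assumption based on the sentence above that such weights "will not be created"; `dog_decay` is a hypothetical name.

```python
import numpy as np

def dog_decay(size, sigmas, max_ws, min_w=None, include_self=True):
    """Sketch of a Difference-of-Gaussians weight matrix on a regular grid."""
    sigma_p, sigma_n = sigmas
    w_p, w_n = max_ws
    if min_w is None:
        min_w = 0.005 * min(max_ws)   # default stated in the text above
    pos = np.indices(size).reshape(len(size), -1).T.astype(float)
    d2 = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(axis=-1)
    w = (w_p * np.exp(-d2 / (2 * sigma_p ** 2))
         - w_n * np.exp(-d2 / (2 * sigma_n ** 2)))
    w[w < min_w] = 0.   # assumption: weights below min_w are not created (set to 0)
    if not include_self:
        np.fill_diagonal(w, 0.)
    return w

w = dog_decay((10, 12), sigmas=(1., 3.), max_ws=(10., 5.), min_w=0.1)
row = w[3 * 12 + 4].reshape(10, 12)   # weights from neuron (3, 4)
```

At zero distance both Gaussians equal 1, so the self-weight is simply 10 - 5 = 5.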

Customizing your initializers#
BrainPy also allows users to customize their own weight initializers. When customizing an initializer, users should follow the instructions below:
Your initializer should inherit brainpy.initialize.Initializer.
Override the __call__ function, which should receive the shape parameter.
Here is an example of creating an inter-layer initializer that initializes the weights as \( w(i, j) = \max(w_{max} - \sigma |i - j|, 0) \), i.e. the weight decreases linearly with the index distance between the presynaptic neuron \(i\) and the postsynaptic neuron \(j\):
class LinearDecay(bp.init.InterLayerInitializer):
    def __init__(self, max_w, sigma=1.):
        self.max_w = max_w
        self.sigma = sigma
    def __call__(self, shape, dtype=None):
        mat = bp.math.zeros(shape, dtype=dtype)
        n_pre, n_post = shape
        seq = np.arange(n_pre)
        current_w = self.max_w
        # fill the diagonals |i - j| = 0, 1, 2, ... with decreasing weights
        for i in range(max(n_pre, n_post)):
            if current_w <= 0:
                break
            seq_plus = ((seq + i) >= 0) & ((seq + i) < n_post)
            seq_minus = ((seq - i) >= 0) & ((seq - i) < n_post)
            mat[seq[seq_plus], (seq + i)[seq_plus]] = current_w
            mat[seq[seq_minus], (seq - i)[seq_minus]] = current_w
            current_w -= self.sigma
        return mat
shape = (10, 15)
lin_init = LinearDecay(max_w=5, sigma=1.)
weights = lin_init(shape)
mat_visualize(weights, cmap='Reds')
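Because the weight only depends on the distance |i - j|, the diagonal-filling loop in LinearDecay.__call__ is equivalent to a closed form that can be computed without a loop. A vectorized NumPy sketch (hypothetical `linear_decay`, returning a NumPy array rather than a brainpy.math array):

```python
import numpy as np

def linear_decay(shape, max_w, sigma=1.):
    """Closed form of the loop above: w(i, j) = max(max_w - sigma * |i - j|, 0)."""
    n_pre, n_post = shape
    dist = np.abs(np.arange(n_pre)[:, None] - np.arange(n_post)[None, :])
    return np.maximum(max_w - sigma * dist, 0.)

w = linear_decay((10, 15), max_w=5, sigma=1.)
```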

Note
Customized initializers (subclasses of brainpy.init.Initializer) are not limited to returning a matrix. Although all the built-in initializers currently store weights in a matrix, an initializer can also be designed to return a vector of synaptic weights.
Gradient Descent Optimizers#
Gradient descent is one of the most popular optimization methods. At present, gradient descent optimizers, combined with the loss function, are the key to machine learning, especially deep learning. In this section, we are going to understand:
how to use optimizers in BrainPy?
how to customize your own optimizer?
import brainpy as bp
import brainpy.math as bm
# bp.math.set_platform('cpu')
bp.__version__
'2.3.0'
import matplotlib.pyplot as plt
Optimizers in BrainPy#
The base optimizer class in BrainPy is brainpy.optim.Optimizer. The following optimizers are built on it:
SGD
Momentum
Nesterov momentum
Adagrad
Adadelta
RMSProp
Adam
All supported optimizers can be inspected through the brainpy.optim APIs.
Generally, an optimizer initialization receives the learning rate lr, the trainable variables train_vars, and other hyperparameters for the specific optimizer.
lr can be a float, or an instance of brainpy.optim.Scheduler.
train_vars should be a dict of Variable.
Here we launch an SGD optimizer.
a = bm.Variable(bm.ones((5, 4)))
b = bm.Variable(bm.zeros((3, 3)))
op = bp.optim.SGD(lr=0.001, train_vars={'a': a, 'b': b})
When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update() method.
op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})
print('a:', a)
print('b:', b)
a: Variable([[0.9993626 , 0.9997406 , 0.999853 , 0.999312 ],
[0.9993036 , 0.99934477, 0.9998294 , 0.9997739 ],
[0.99900717, 0.9997449 , 0.99976104, 0.99953616],
[0.9995185 , 0.99917144, 0.9990044 , 0.99914813],
[0.9997468 , 0.9999408 , 0.99917686, 0.9999825 ]], dtype=float32)
b: Variable([[-0.00034196, -0.00046545, -0.00027317],
[-0.00045028, -0.00076825, -0.00026088],
[-0.0007135 , -0.00020507, -0.00073902]], dtype=float32)
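The values above follow the plain SGD rule \(\theta \leftarrow \theta - \mathrm{lr} \cdot g\): with positive random gradients and lr=0.001, a moves slightly below 1 and b slightly below 0. A minimal NumPy sketch of that rule (hypothetical `sgd_update`, with fixed gradients instead of random ones):

```python
import numpy as np

def sgd_update(params, grads, lr):
    """Plain SGD: theta <- theta - lr * grad, applied in place to every entry of a dict."""
    for key, value in params.items():
        value -= lr * grads[key]   # in-place, like updating a Variable

params = {'a': np.ones((5, 4)), 'b': np.zeros((3, 3))}
grads = {'a': np.full((5, 4), 0.5), 'b': np.full((3, 3), 0.25)}
sgd_update(params, grads, lr=0.001)
```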
You can process the gradients before applying them. For example, here we clip each gradient array so that its L2 norm does not exceed 1.
grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}
grads_pre
{'a': Array([[0.6356058 , 0.10750175, 0.93578255, 0.2557603 ],
[0.77525663, 0.8615701 , 0.35919654, 0.6861898 ],
[0.9569112 , 0.98981357, 0.3033744 , 0.62852013],
[0.36589646, 0.86694443, 0.6335902 , 0.44947362],
[0.01782513, 0.11465573, 0.5505476 , 0.56196713]], dtype=float32),
'b': Array([[0.2326113 , 0.14437485, 0.6543677 ],
[0.46068823, 0.9811108 , 0.30460846],
[0.261765 , 0.71705794, 0.6173099 ]], dtype=float32)}
grads_post = bm.clip_by_norm(grads_pre, 1.)
grads_post
{'a': Array([[0.22753015, 0.0384828 , 0.33498552, 0.09155546],
[0.2775215 , 0.30841944, 0.12858291, 0.24563788],
[0.34254903, 0.3543272 , 0.10860006, 0.22499368],
[0.13098131, 0.3103433 , 0.22680864, 0.16089973],
[0.00638093, 0.04104374, 0.19708155, 0.20116945]], dtype=float32),
'b': Array([[0.14066657, 0.08730751, 0.39571446],
[0.27859107, 0.5933052 , 0.18420528],
[0.15829663, 0.433625 , 0.3733046 ]], dtype=float32)}
op.update(grads_post)
print('a:', a)
print('b:', b)
a: Variable([[0.9991351 , 0.9997021 , 0.99951804, 0.99922043],
[0.99902606, 0.9990364 , 0.99970084, 0.9995283 ],
[0.9986646 , 0.99939054, 0.99965245, 0.99931115],
[0.9993875 , 0.9988611 , 0.9987776 , 0.99898726],
[0.9997404 , 0.99989974, 0.9989798 , 0.9997813 ]], dtype=float32)
b: Variable([[-0.00048263, -0.00055276, -0.00066889],
[-0.00072887, -0.00136155, -0.00044508],
[-0.00087179, -0.0006387 , -0.00111233]], dtype=float32)
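The printed values suggest that bm.clip_by_norm rescales each gradient array separately (the entries of 'a' and 'b' are shrunk by different factors). Below is a NumPy sketch of that per-array behaviour, checked against the 'b' entry above; it is inferred from the output and is not BrainPy's source.

```python
import numpy as np

def clip_by_norm(grads, max_norm):
    """Rescale each array whose L2 norm exceeds max_norm so that its norm equals max_norm."""
    clipped = {}
    for key, g in grads.items():
        norm = np.sqrt(np.sum(g ** 2))
        scale = max_norm / max(norm, max_norm)   # 1.0 when already within bounds
        clipped[key] = g * scale
    return clipped

g_b = np.array([[0.2326113, 0.14437485, 0.6543677],
                [0.46068823, 0.9811108, 0.30460846],
                [0.261765, 0.71705794, 0.6173099]])  # grads_pre['b'] from above
out = clip_by_norm({'b': g_b}, 1.)
```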
Note
Optimizers usually have their own dynamically changing variables. If you JIT a function whose logic contains an optimizer update, the dyn_vars in bm.jit() should include the variables in Optimizer.vars().
op.vars()  # the SGD optimizer only has a `step` counter (held by its Constant scheduler) to record the training step
{'Constant0.step': Variable([2], dtype=int32)}
bp.optim.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Momentum has velocity variables
{'Momentum0.a_v': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Momentum0.b_v': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32),
'Constant1.step': Variable([0], dtype=int32)}
bp.optim.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Adam has more variables
{'Adam0.a_m': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Adam0.b_m': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32),
'Adam0.a_v': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Adam0.b_v': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32),
'Constant2.step': Variable([0], dtype=int32)}
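The extra variables exist because adaptive optimizers keep running statistics per parameter. The sketch below implements the standard Adam update in NumPy to show where first moments (`m`), second moments (`v`) and a step counter come from; it is a textbook version for illustration, and BrainPy's Adam may differ in details such as where epsilon is added. `AdamState` is a hypothetical name.

```python
import numpy as np

class AdamState:
    """Standard Adam update: one m and one v per parameter, plus a shared step count."""
    def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params, self.lr = params, lr
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.m = {k: np.zeros_like(p) for k, p in params.items()}  # first moments
        self.v = {k: np.zeros_like(p) for k, p in params.items()}  # second moments
        self.step = 0

    def update(self, grads):
        self.step += 1
        for k, p in self.params.items():
            g = grads[k]
            self.m[k] = self.beta1 * self.m[k] + (1 - self.beta1) * g
            self.v[k] = self.beta2 * self.v[k] + (1 - self.beta2) * g ** 2
            m_hat = self.m[k] / (1 - self.beta1 ** self.step)  # bias correction
            v_hat = self.v[k] / (1 - self.beta2 ** self.step)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)  # in-place update

params = {'a': np.ones((5, 4))}
opt = AdamState(params, lr=0.001)
opt.update({'a': np.full((5, 4), 0.5)})
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) equals the sign of the gradient, so every entry moves by about lr.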
Creating A Self-Customized Optimizer#
To create your own optimization algorithm, simply inherit from the bp.optim.Optimizer class and override the following methods:
__init__(): the init function that receives the learning rate (lr) and trainable variables (train_vars). Do not forget to register your dynamically changing variables into implicit_vars.
update(grads): the update function that computes the updated parameters.
The general structure is shown below:
class CustomizeOp(bp.optim.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        # customize your initialization
    def update(self, grads):
        # customize your update logic
        pass
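To show what the two methods typically contain, here is a stand-alone NumPy class that mirrors the same `__init__`/`update` structure for SGD with momentum. It deliberately does not subclass bp.optim.Optimizer, so it makes no claims about BrainPy's internals; `MomentumSketch` and its `velocity` dict (the state that would be registered as implicit variables) are hypothetical, and the momentum formula shown is one common formulation.

```python
import numpy as np

class MomentumSketch:
    """Hypothetical stand-alone optimizer mirroring the __init__/update structure."""
    def __init__(self, lr, train_vars, momentum=0.9):
        self.lr = lr
        self.train_vars = train_vars   # dict name -> array, updated in place
        self.momentum = momentum
        # Dynamically changing state, analogous to what would go into implicit_vars
        self.velocity = {k: np.zeros_like(v) for k, v in train_vars.items()}

    def update(self, grads):
        for k, var in self.train_vars.items():
            self.velocity[k] = self.momentum * self.velocity[k] + grads[k]
            var -= self.lr * self.velocity[k]

w = np.zeros(3)
opt = MomentumSketch(0.1, {'w': w})
opt.update({'w': np.ones(3)})   # velocity 1.0, w -> -0.1
opt.update({'w': np.ones(3)})   # velocity 1.9, w -> -0.29
```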
Schedulers#
A scheduler adjusts the learning rate during training by reducing it according to a predefined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.
Here we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.
sc = bp.optim.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)
show(steps, rates)
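Assuming ExponentialDecay follows the common continuous definition \(\mathrm{lr} \cdot \mathrm{decay\_rate}^{\,\mathrm{step} / \mathrm{decay\_steps}}\) used in other deep-learning libraries, the curve above can be reproduced with NumPy. Please check the brainpy.optim API reference for the exact formula BrainPy uses; the function name below is hypothetical.

```python
import numpy as np

def exponential_decay(step, lr=0.1, decay_steps=2, decay_rate=0.99):
    """Continuous exponential decay: lr * decay_rate ** (step / decay_steps)."""
    return lr * decay_rate ** (np.asarray(step) / decay_steps)

steps = np.arange(1000)
rates = exponential_decay(steps)
```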

After Optimizer initialization, the learning rate self.lr will always be an instance of bp.optim.Scheduler. Initializing with a scalar float learning rate results in a Constant scheduler.
op.lr
Constant(0.001)
One can get the current learning rate value by calling Scheduler.__call__(i=None).
If i is not provided, the learning rate value will be evaluated at the built-in training step.
Otherwise, the learning rate value will be evaluated at the given step i.
op.lr()
0.001
BrainPy provides several commonly used learning rate schedulers:
Constant
ExponentialDecay
InverseTimeDecay
PolynomialDecay
PiecewiseConstant
For more details, please see the brainpy.optim APIs.
# InverseTimeDecay scheduler
rates = bp.optim.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)

# PolynomialDecay scheduler
rates = bp.optim.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
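For reference, the textbook forms of these two schedules can be written in a few lines of NumPy. These are the definitions commonly used in deep-learning libraries (inverse time decay without staircase, linear polynomial decay with power 1 as an assumed default); BrainPy's implementations may differ in details, and the function names are hypothetical.

```python
import numpy as np

def inverse_time_decay(step, lr, decay_steps, decay_rate):
    """lr / (1 + decay_rate * step / decay_steps)"""
    return lr / (1 + decay_rate * np.asarray(step) / decay_steps)

def polynomial_decay(step, lr, decay_steps, final_lr, power=1.0):
    """Decay from lr to final_lr over decay_steps, then stay at final_lr."""
    frac = 1 - np.minimum(np.asarray(step), decay_steps) / decay_steps
    return (lr - final_lr) * frac ** power + final_lr

steps = np.arange(1000)
inv_rates = inverse_time_decay(steps, lr=0.01, decay_steps=10, decay_rate=0.999)
poly_rates = polynomial_decay(steps, lr=0.01, decay_steps=10, final_lr=0.0001)
```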

Creating a Self-Customized Scheduler#
If you want to implement your own scheduler, simply inherit from the bp.optim.Scheduler class and override the following methods:
__init__(): the init function.
__call__(i=None): the evaluation of the learning rate value.
class CustomizeScheduler(bp.optim.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        # customize your initialization
    def __call__(self, i=None):
        # customize your update logic
        pass
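As an example of the `__call__(i=None)` contract described above, here is a plain-Python step-decay schedule that evaluates at its own training step when `i` is None and at the given step otherwise. It does not subclass bp.optim.Scheduler; the class name and its `advance()` method for moving the internal step are hypothetical.

```python
class StepDecaySketch:
    """Hypothetical scheduler: multiply lr by `factor` every `every` steps."""
    def __init__(self, lr, every=100, factor=0.5):
        self.lr, self.every, self.factor = lr, every, factor
        self.step = 0   # built-in training step

    def __call__(self, i=None):
        i = self.step if i is None else i   # current step, or the given step i
        return self.lr * self.factor ** (i // self.every)

    def advance(self):
        self.step += 1

sched = StepDecaySketch(lr=0.1, every=100, factor=0.5)
for _ in range(150):   # pretend 150 training steps have run
    sched.advance()
```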
Saving and Loading#
Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.
import brainpy as bp
bp.math.set_platform('cpu')
Saving and loading variables#
Model saving and loading in BrainPy are implemented with the .save_states() and .load_states() functions.
BrainPy supports saving and loading model variables in various file formats, including
HDF5: .h5, .hdf5
.npz (NumPy file format)
.pkl (Python’s pickle utility)
.mat (Matlab file format)
Here’s a simple example:
class EINet(bp.dyn.Network):
    def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        E = bp.models.LIF(num_exc, **pars, method=method)
        I = bp.models.LIF(num_inh, **pars, method=method)
        E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
        I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.
        # synapses
        E2E = bp.models.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        E2I = bp.models.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        I2E = bp.models.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)
        I2I = bp.models.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)
        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
net = EINet()
import os
if not os.path.exists('./data'):
    os.makedirs('./data')
# model saving
net.save_states('./data/net.h5')
# model loading
net.load_states('./data/net.h5')
.save_states(filename, all_var=None) receives a string that specifies the output file name. If the variables to save are not provided, BrainPy retrieves all variables in the model through their relative paths.
.load_states(filename, verbose, check_missing) receives several arguments. The first is a string naming the file to load. The second, verbose, specifies whether to report the loading progress. The last, check_missing, warns about model variables that are missing from the file.
# model loading with warning and checking
net.load_states('./data/net.h5', verbose=True)
WARNING:brainpy.base.io:There are variable states missed in ./data/net.h5. The missed variables are: ['ExpCOBA0.pre.V', 'ExpCOBA0.pre.input', 'ExpCOBA0.pre.refractory', 'ExpCOBA0.pre.spike', 'ExpCOBA0.pre.t_last_spike', 'ExpCOBA1.pre.V', 'ExpCOBA1.pre.input', 'ExpCOBA1.pre.refractory', 'ExpCOBA1.pre.spike', 'ExpCOBA1.pre.t_last_spike', 'ExpCOBA1.post.V', 'ExpCOBA1.post.input', 'ExpCOBA1.post.refractory', 'ExpCOBA1.post.spike', 'ExpCOBA1.post.t_last_spike', 'ExpCOBA2.pre.V', 'ExpCOBA2.pre.input', 'ExpCOBA2.pre.refractory', 'ExpCOBA2.pre.spike', 'ExpCOBA2.pre.t_last_spike', 'ExpCOBA2.post.V', 'ExpCOBA2.post.input', 'ExpCOBA2.post.refractory', 'ExpCOBA2.post.spike', 'ExpCOBA2.post.t_last_spike', 'ExpCOBA3.pre.V', 'ExpCOBA3.pre.input', 'ExpCOBA3.pre.refractory', 'ExpCOBA3.pre.spike', 'ExpCOBA3.pre.t_last_spike'].
Loading E.V ...
Loading E.input ...
Loading E.refractory ...
Loading E.spike ...
Loading E.t_last_spike ...
Loading ExpCOBA0.g ...
Loading ExpCOBA0.pre_spike.data ...
Loading ExpCOBA0.pre_spike.in_idx ...
Loading ExpCOBA0.pre_spike.out_idx ...
Loading ExpCOBA1.g ...
Loading ExpCOBA1.pre_spike.data ...
Loading ExpCOBA1.pre_spike.in_idx ...
Loading ExpCOBA1.pre_spike.out_idx ...
Loading ExpCOBA2.g ...
Loading ExpCOBA2.pre_spike.data ...
Loading ExpCOBA2.pre_spike.in_idx ...
Loading ExpCOBA2.pre_spike.out_idx ...
Loading ExpCOBA3.g ...
Loading ExpCOBA3.pre_spike.data ...
Loading ExpCOBA3.pre_spike.in_idx ...
Loading ExpCOBA3.pre_spike.out_idx ...
Loading I.V ...
Loading I.input ...
Loading I.refractory ...
Loading I.spike ...
Loading I.t_last_spike ...
Note
By default, the model variables are retrieved by their relative paths. Relative path retrieval usually produces duplicate variables in the returned ArrayCollector, so there will always be missing keys when loading the variables.
Custom saving and loading#
You can easily write your own saving and loading functions, because all variables in the model can be collected through .vars(). Saving variables then amounts to converting these variables to numpy.ndarray and storing them on disk. Similarly, to load variables, you just need to read the numpy arrays from disk and write them back into the corresponding Variables.
The only gotcha is to avoid saving duplicated variables.
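Following that recipe, a minimal NumPy sketch might look as follows. Plain arrays stand in for Variables, the variable names are hypothetical, duplicates are detected by object identity (two paths pointing at the same array), and loading writes into the existing arrays in place so that every path sharing an array sees the restored values.

```python
import os
import tempfile
import numpy as np

def save_vars(filename, variables):
    """Save a dict name -> array to .npz, skipping entries that are the same object."""
    seen, unique = set(), {}
    for name, value in variables.items():
        if id(value) in seen:   # duplicate reference collected under another path
            continue
        seen.add(id(value))
        unique[name] = np.asarray(value)
    np.savez(filename, **unique)
    return list(unique)

def load_vars(filename, variables):
    """Copy stored arrays back into the existing arrays, in place."""
    with np.load(filename) as data:
        for name in data.files:
            variables[name][...] = data[name]

V = np.zeros(3)
variables = {'E.V': V, 'syn.pre.V': V, 'E.spike': np.zeros(3, dtype=bool)}
path = os.path.join(tempfile.mkdtemp(), 'net.npz')
saved = save_vars(path, variables)   # 'syn.pre.V' is skipped as a duplicate
V[:] = 1.                            # change the state ...
load_vars(path, variables)           # ... and restore it
```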
Inputs Construction#
In this section, we are going to talk about stimulus inputs.
import brainpy as bp
import brainpy.math as bm
bp.__version__
'2.3.0'
Input construction functions#
Like electrophysiological experiments, model simulation also needs various kinds of inputs. BrainPy provides several convenient input functions to help users construct input currents.
1. brainpy.inputs.section_input()#
brainpy.inputs.section_input() is an updated version of the earlier brainpy.inputs.constant_input() (see below).
Sometimes we need input currents with different values in different periods. For example, to get an input that is 0 in the first 100 ms, 1 in the next 300 ms, and 0 again in the last 100 ms, you can define:
current1, duration = bp.inputs.section_input(values=[0, 1., 0.],
durations=[100, 300, 100],
return_length=True,
dt=0.1)
Here, values receives a list/array of the current values in each section, and durations receives a list/array of the duration of each section. The function returns a tensor as the current, whose length is duration \(/\mathrm{d}t\) (if not specified, \(\mathrm{d}t=0.1 \mathrm{ms}\)). We can visualize the current input by:
import numpy as np
import matplotlib.pyplot as plt
def show(current, duration, title):
    ts = np.arange(0, duration, bm.get_dt())
    plt.plot(ts, current)
    plt.title(title)
    plt.xlabel('Time [ms]')
    plt.ylabel('Current Value')
    plt.show()
show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')
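The piecewise-constant current above is easy to picture as a concatenation of constant blocks, one per section, each holding duration/dt samples. A NumPy sketch for scalar values only (the hypothetical `section_input_sketch` ignores the broadcasting of array values that the real function supports):

```python
import numpy as np

def section_input_sketch(values, durations, dt=0.1):
    """Piecewise-constant current: values[k] held for durations[k] ms."""
    pieces = [np.full(int(round(d / dt)), float(v)) for v, d in zip(values, durations)]
    current = np.concatenate(pieces)
    return current, sum(durations)

current, duration = section_input_sketch([0, 1., 0], [100, 300, 100])
```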

2. brainpy.inputs.constant_input()#
brainpy.inputs.constant_input() helps users format constant currents over several periods.
We can generate the above input current with constant_input() by:
current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])
Each tuple in the list contains the value and duration of the input in that section.
show(current2, duration, '[(0, 100), (1, 300), (0, 100)]')

3. brainpy.inputs.spike_input()#
brainpy.inputs.spike_input() constructs an input containing a series of short-time spikes. It receives the following settings:
sp_times: The spike time-points. Must be an iterable object, for example, a list, tuple, or array.
sp_lens: The length of each point-current, mimicking the spike durations. It can be a scalar float specifying a unified duration, or a list/tuple/array of time lengths with the same length as sp_times.
sp_sizes: The current sizes. It can be a scalar value, or a list/tuple/array of spike current sizes with the same length as sp_times.
duration: The total current duration.
dt: The time step precision. The default is None (the default dt step will be used).
For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, where each spike lasts 1 ms and the average value for each spike is 0.5, then you can define the current by:
current3 = bp.inputs.spike_input(
sp_times=[10, 20, 30, 200, 300],
sp_lens=1., # can be a list to specify the spike length at each point
sp_sizes=0.5, # can be a list to specify the spike current size at each point
duration=400.)
show(current3, 400, 'Spike Input Example')
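Conceptually, each spike is a rectangular pulse of height sp_sizes[k] that starts at sp_times[k] and lasts sp_lens[k] ms. A NumPy sketch of that idea (hypothetical `spike_input_sketch`; rounding each time to the nearest sample index is an assumption and may differ from BrainPy's indexing):

```python
import numpy as np

def spike_input_sketch(sp_times, sp_lens, sp_sizes, duration, dt=0.1):
    """Rectangular pulses: sp_sizes[k] from sp_times[k] for sp_lens[k] ms."""
    n = len(sp_times)
    sp_lens = np.broadcast_to(sp_lens, (n,))    # scalar -> one length per spike
    sp_sizes = np.broadcast_to(sp_sizes, (n,))  # scalar -> one size per spike
    current = np.zeros(int(round(duration / dt)))
    for t, length, size in zip(sp_times, sp_lens, sp_sizes):
        start = int(round(t / dt))
        stop = int(round((t + length) / dt))
        current[start:stop] = size
    return current

current = spike_input_sketch([10, 20, 30, 200, 300], 1., 0.5, duration=400.)
```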

4. brainpy.inputs.ramp_input()#
brainpy.inputs.ramp_input() mimics a ramp or a step current to the input of the circuit. It receives the following settings:
c_start: The minimum (or maximum) current size.
c_end: The maximum (or minimum) current size.
duration: The total duration.
t_start: The time point at which the ramped current starts.
t_end: The time point at which the ramped current ends. The default is None.
dt: The current precision.
We illustrate the usage of brainpy.inputs.ramp_input() with two examples.
In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (500 ms).
duration = 500
current4 = bp.inputs.ramp_input(0, 1, duration)
show(current4, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
     r'$t_{start}$=0, $t_{end}$=None' % (duration))

In the second example, we increase the current size from 0. to 1. from 100 ms to 400 ms.
duration, t_start, t_end = 500, 100, 400
current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)
show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))
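The ramp itself is a linear interpolation between c_start and c_end over the window [t_start, t_end). A NumPy sketch (hypothetical `ramp_input_sketch`); setting the current to 0 outside the window is an assumption of the sketch, not a documented property of the BrainPy function:

```python
import numpy as np

def ramp_input_sketch(c_start, c_end, duration, t_start=0., t_end=None, dt=0.1):
    """Linear ramp from c_start to c_end between t_start and t_end, 0 elsewhere (assumed)."""
    t_end = duration if t_end is None else t_end
    current = np.zeros(int(round(duration / dt)))
    i0, i1 = int(round(t_start / dt)), int(round(t_end / dt))
    current[i0:i1] = np.linspace(c_start, c_end, i1 - i0, endpoint=False)
    return current

current = ramp_input_sketch(0, 1, 500, 100, 400)
```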

5. brainpy.inputs.wiener_process#
brainpy.inputs.wiener_process() is used to generate the increments \(dW\) of the basic Wiener process, i.e. random numbers drawn from a normal distribution with mean 0 and standard deviation \(\sqrt{dt}\).
duration = 200
current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.)
show(current6, duration, 'Wiener Process')
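Wiener increments over a step dt have variance dt, so they can be drawn as normal samples scaled by sqrt(dt). A NumPy sketch with the same duration, n, t_start and t_end arguments as the call above (hypothetical `wiener_process_sketch`; zero noise outside [t_start, t_end) is an assumption):

```python
import numpy as np

def wiener_process_sketch(duration, n=1, t_start=0., t_end=None, dt=0.1, seed=None):
    """Increments dW with standard deviation sqrt(dt) inside [t_start, t_end), 0 elsewhere."""
    rng = np.random.default_rng(seed)
    t_end = duration if t_end is None else t_end
    dW = np.zeros((int(round(duration / dt)), n))
    i0, i1 = int(round(t_start / dt)), int(round(t_end / dt))
    dW[i0:i1] = rng.normal(0., np.sqrt(dt), size=(i1 - i0, n))
    return dW

dW = wiener_process_sketch(200, n=2, t_start=10., t_end=180., seed=0)
```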

6. brainpy.inputs.ou_process#
brainpy.inputs.ou_process() is used to generate a noise time series from the Ornstein-Uhlenbeck process \(dx = (\mu - x)/\tau \cdot dt + \sigma \cdot dW\).
duration = 200
current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.)
show(current7, duration, 'Ornstein-Uhlenbeck Process')
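The OU equation can be integrated with the Euler-Maruyama scheme, reusing Wiener increments like the ones described above. A NumPy sketch (hypothetical `ou_process_sketch`; starting at the mean and omitting t_start/t_end are simplifications of the sketch). In the long run the process fluctuates around mean with a standard deviation of about sigma * sqrt(tau / 2).

```python
import numpy as np

def ou_process_sketch(mean, sigma, tau, duration, n=1, dt=0.1, seed=None):
    """Euler-Maruyama integration of dx = (mean - x) / tau * dt + sigma * dW."""
    rng = np.random.default_rng(seed)
    num_steps = int(round(duration / dt))
    dW = rng.normal(0., np.sqrt(dt), size=(num_steps, n))  # Wiener increments
    x = np.full(n, float(mean))   # start at the mean (an assumption)
    out = np.empty((num_steps, n))
    for k in range(num_steps):
        x = x + (mean - x) / tau * dt + sigma * dW[k]
        out[k] = x
    return out

current = ou_process_sketch(mean=1., sigma=0.1, tau=10., duration=200, n=2, seed=0)
```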

7. brainpy.inputs.sinusoidal_input#
brainpy.inputs.sinusoidal_input() generates sinusoidal inputs.
duration = 2000
current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration, t_start=100., )
show(current8, duration, 'Sinusoidal Input')
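As a rough sketch of what such an input looks like numerically, the NumPy function below produces zeros before t_start and a sine wave afterwards. The unit convention (frequency in Hz, time in ms) and the name sinusoid_sketch are our assumptions, not a description of BrainPy's internals.

```python
import numpy as np

def sinusoid_sketch(amplitude, frequency, duration, t_start=0., dt=0.1):
    """Hypothetical sketch: zero before t_start, then amplitude * sin(2*pi*f*t).

    Assumes `frequency` is in Hz and time is in ms (hence the division by 1000).
    """
    n_steps = int(round(duration / dt))
    i_start = int(round(t_start / dt))
    out = np.zeros(n_steps)
    t = np.arange(n_steps - i_start) * dt          # time since t_start (ms)
    out[i_start:] = amplitude * np.sin(2 * np.pi * frequency * t / 1000.)
    return out

# Same settings as above: a 2 Hz wave starting at 100 ms in a 2000 ms window
x = sinusoid_sketch(amplitude=1., frequency=2.0, duration=2000., t_start=100.)
```

Under these assumptions the first peak occurs a quarter period (125 ms) after t_start.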

8. brainpy.inputs.square_input#
brainpy.inputs.square_input() generates oscillatory square-wave inputs.
duration = 2000
current9 = bp.inputs.square_input(amplitude=1., frequency=2.0,
duration=duration, t_start=100)
show(current9, duration, 'Square Input')

More complex inputs#
Because each current input is stored as an array, a complex input can be built by combining several simple currents.
show(current1 + current5, 500, 'A Complex Current Input')

General properties of input functions#
1. Every input function accepts a dt argument. If dt is not provided, the input function uses the global default dt of the BrainPy system.
I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1)
I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01)
print('I1.shape: {}'.format(I1.shape))
print('I2.shape: {}'.format(I2.shape))
I1.shape: (600,)
I2.shape: (6000,)
2. All input functions automatically broadcast the current shapes if they are heterogeneous among different periods.
For example, if we give a scalar input during period 1, a vector input during period 2, and a matrix input during period 3, the input function broadcasts them to a common shape:
current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
durations=[100, 300, 100])
current.shape
(5000, 3, 10)
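Both properties can be reproduced with plain NumPy arithmetic: the number of time steps is the total duration divided by dt, and the trailing dimensions follow NumPy broadcasting of all section values. The helpers section_shape and section_sketch below are hypothetical and assume each section lasts round(duration / dt) steps and that the default dt is 0.1; they are not BrainPy functions.

```python
import numpy as np

def section_shape(values, durations, dt=0.1):
    """Hypothetical helper: shape of a piecewise-constant input, assuming each
    section lasts round(duration / dt) steps and all values broadcast together."""
    n_steps = sum(int(round(d / dt)) for d in durations)
    common = np.broadcast_shapes(*(np.shape(v) for v in values))
    return (n_steps,) + common

def section_sketch(values, durations, dt=0.1):
    """Build the array itself by broadcasting each value over its section."""
    common = np.broadcast_shapes(*(np.shape(v) for v in values))
    return np.concatenate([np.broadcast_to(v, (int(round(d / dt)),) + common)
                           for v, d in zip(values, durations)])

print(section_shape([0, 1, 2], [10, 20, 30], dt=0.1))    # (600,)
print(section_shape([0, 1, 2], [10, 20, 30], dt=0.01))   # (6000,)
mixed = [0., np.ones(10), np.random.default_rng(0).random((3, 10))]
print(section_sketch(mixed, [100, 300, 100]).shape)       # (5000, 3, 10)
```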
How to cite BrainPy?#
If BrainPy has been significant in your research, and you would like to acknowledge the project in your academic publication, we suggest citing the following papers:
If you are using BrainPy 2.x, please use:
Chaoming Wang, Xiaoyu Chen, Tianqiu Zhang, Si Wu. BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming. bioRxiv 2022.10.28.514024; doi: https://doi.org/10.1101/2022.10.28.514024
@article {Wang2022brainpy,
author = {Wang, Chaoming and Chen, Xiaoyu and Zhang, Tianqiu and Wu, Si},
title = {BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming},
elocation-id = {2022.10.28.514024},
year = {2022},
doi = {10.1101/2022.10.28.514024},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2022/10/28/2022.10.28.514024},
eprint = {https://www.biorxiv.org/content/early/2022/10/28/2022.10.28.514024.full.pdf},
journal = {bioRxiv}
}
If you are using BrainPy 1.x, please use:
Wang, C., Jiang, Y., Liu, X., Lin, X., Zou, X., Ji, Z., & Wu, S. (2021, December). A Just-In-Time Compilation Approach for Neural Dynamics Simulation. In International Conference on Neural Information Processing (pp. 15-26). Springer, Cham.
@inproceedings{wang2021just,
title={A Just-In-Time Compilation Approach for Neural Dynamics Simulation},
author={Wang, Chaoming and Jiang, Yingqian and Liu, Xinyu and Lin, Xiaohan and Zou, Xiaolong and Ji, Zilong and Wu, Si},
booktitle={International Conference on Neural Information Processing},
pages={15--26},
year={2021},
organization={Springer}
}
How is brainpy different from other frameworks?#
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
bp.__version__
BrainPy vs Brian2/NEST/NEURON …#
Unlike traditional brain simulators, most of which employ a descriptive language for programming brain dynamics models, BrainPy aims to provide full support for brain dynamics modeling.
Today, brain dynamics modeling goes far beyond simulation. Many new modeling approaches take inspiration from the machine learning community, and brain dynamics modeling has in turn inspired new developments in brain-inspired computation.
These new advances cannot be captured by traditional brain simulators. Therefore, BrainPy aims to provide an ecosystem for brain dynamics modeling, in which users can build various models easily and flexibly, and conveniently extend it with new modeling approaches.
The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation compiles your Python code into machine code “just in time” for execution, so the transformed code can run at native machine-code speed.
Based on this, BrainPy provides an integrated platform for brain dynamics modeling, including
universal model building
dynamics simulation
dynamics training
dynamics analysis
Such an integrative framework may help users study brain dynamics comprehensively.
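To illustrate the JIT compilation that BrainPy builds on, the sketch below uses plain JAX (jax.jit), not BrainPy's own class-level JIT API. The leaky update rule and its parameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def euler_step(v, i_ext, tau=10.0, dt=0.1):
    """One Euler step of the leaky equation dv/dt = (-v + i_ext) / tau."""
    return v + dt * (-v + i_ext) / tau

# jax.jit traces the function once per input shape/dtype, compiles it with XLA,
# and reuses the compiled code on later calls with matching inputs.
euler_step_jit = jax.jit(euler_step)

v = jnp.zeros(4)
for _ in range(3):
    v = euler_step_jit(v, 1.0)
print(v)  # every entry equals 1 - 0.99**3
```

Note that jax.jit only accepts pure functions of arrays like euler_step; this restriction is what brainpy.math relaxes for class objects.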
BrainPy vs JAX/Numba#
BrainPy relies on JAX and Numba, but it differs from them in important ways.
JAX and Numba are excellent JIT compilers in Python. However, they are designed to work only on pure Python functions.
Most computational neuroscience models have too many parameters and variables to manage with functions alone. Therefore, BrainPy provides an object-oriented programming interface for brain dynamics modeling. These object-oriented transformations are implemented in the brainpy.math module.
brainpy.math is not intended to be a reimplementation of the API of any other framework. All we are trying to do is to make a better brain dynamics programming framework for Python users.
There are important differences between brainpy.math and JAX or JAX-related frameworks.
Specifically, brainpy.math provides:
NumPy-like ndarray.
Python users are familiar with NumPy, especially its ndarray. JAX has similar ndarray structures and operations. However, several basic features are fundamentally different from the NumPy ndarray. For example, JAX ndarray does not support in-place mutating updates, like x[i] += y. To overcome these drawbacks, brainpy.math provides Array, which can be used in the same way as the NumPy ndarray.
b = bm.arange(5)
b
Array([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
Array([5, 1, 2, 3, 4], dtype=int32)
NumPy-like random sampling.
JAX has its own style of generating random numbers, which is very different from the original NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling, just like the numpy.random module. For example:
# random sampling in "brainpy.math.random"
bm.random.seed(12345)
bm.random.random(5)
Array([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ], dtype=float32)
bm.random.normal(0., 2., 5)
Array([-1.5375282, -0.5970201, -2.272839 , 3.233081 , -0.2738593], dtype=float32)
For more details, please see the Arrays tutorial.
JAX transformations on class objects.
Object-oriented programming (OOP) is central to Python. However, JAX’s excellent transformations (like JIT compilation) only support pure functions. To make them work with object-oriented code in brain dynamics programming, brainpy.math extends JAX transformations to Python classes. For details, please see BrainPy Concept of Object-oriented Transformation.
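For comparison, this is what the two behaviors described above look like in plain JAX, without brainpy.math: updates are functional, and random draws require explicit keys. The seed is arbitrary, and the sampled values will not match the brainpy.math.random output shown earlier.

```python
import jax
import jax.numpy as jnp

# 1) Immutable arrays: `x[0] += 5` raises an error on a JAX array.
#    The functional update returns a new array and leaves `x` untouched.
x = jnp.arange(5)
y = x.at[0].add(5)
print(x)  # [0 1 2 3 4]
print(y)  # [5 1 2 3 4]

# 2) Explicit random keys: every draw consumes a key, and keys are split
#    to obtain fresh randomness (no hidden global state as in numpy.random).
key = jax.random.PRNGKey(12345)
key, k1, k2 = jax.random.split(key, 3)
u = jax.random.uniform(k1, (5,))            # analogous to random(5)
z = 0. + 2. * jax.random.normal(k2, (5,))   # analogous to normal(0., 2., 5)
```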
BrainPy Ecosystem for Brain Dynamics Modeling#
BrainPy aims to build a complete ecosystem for brain dynamics modeling.
Although there is still a long way to go, so far we have made progress in the following projects:
BrainPy#
Based on JAX, BrainPy provides a universal simulation, training, and analysis engine, which serves as the foundation of the whole project. Specifically, BrainPy provides an object-oriented programming interface for brain dynamics modeling.
brainpy-examples#
brainpy-examples provides comprehensive examples for brain dynamics modeling with BrainPy. It implements many classical models introduced in the latest computational neuroscience and brain-inspired computation research.
brainpylib#
brainpylib aims to provide operators specialized for brain dynamics modeling. Brain dynamics features sparse connections and event-driven computation, and brainpylib provides dedicated operators for such sparse and event-based computation. These operators can be used in computational neuroscience research as well as in the brain-inspired computation community.
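To illustrate why event-driven sparse operators pay off, here is a minimal NumPy sketch of an event-driven product between a CSR connectivity matrix and a binary spike vector: only the rows of presynaptic neurons that actually spiked are visited. This is a conceptual sketch with a hypothetical function name, not brainpylib's kernel or API.

```python
import numpy as np

def event_csr_matvec(indices, indptr, weights, spikes, n_post):
    """Sketch of an event-driven CSR product.

    Row `pre` stores its targets in indices[indptr[pre]:indptr[pre + 1]] with
    matching weights; silent presynaptic neurons cost nothing.
    """
    out = np.zeros(n_post)
    for pre in np.flatnonzero(spikes):              # visit spiking neurons only
        start, end = indptr[pre], indptr[pre + 1]
        np.add.at(out, indices[start:end], weights[start:end])
    return out

indices = np.array([1, 2, 0, 0, 2])        # post ids, stored row by row
indptr = np.array([0, 2, 3, 5])            # 3 presynaptic neurons
weights = np.array([0.5, 1.0, 2.0, 1.5, 0.25])
spikes = np.array([True, False, True])
y = event_csr_matvec(indices, indptr, weights, spikes, n_post=3)
# expected values: 1.5, 0.5, 1.25
```

With sparse activity, the cost scales with the number of spikes times the fan-out rather than with the full matrix size.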
brainpy-datasets#
brainpy-datasets aims to provide commonly used datasets in brain dynamics modeling, including neuromorphic datasets and cognitive tasks for training brain-like neural networks.
brainpy-largescale#
brainpy-largescale provides one solution for large-scale modeling. It enables multi-device running for BrainPy models.
brainpy module#
Numerical Differential Integration#
|
Basic Integrator Class. |
|
Make a joint equation from multiple derivation functions. |
|
Structural runner for numerical integrators in brainpy. |
|
Numerical integration for ODEs. |
|
Numerical integration for SDEs. |
|
Numerical integration for FDEs. |
Building Dynamical System#
|
Base Dynamical System class. |
|
Container object which is designed to add other instances of DynamicalSystem. |
|
A sequential input-output module. |
|
Base class to model network objects, an alias of Container. |
|
Base class to model neuronal groups. |
|
Base class to model two-end synaptic connections. |
|
Base class for synaptic current output. |
|
Base class for synaptic short-term plasticity. |
|
Base class for synaptic long-term plasticity. |
|
Base class to model synaptic connections. |
|
Base class to model conductance-based neuron group. |
|
Abstract channel class. |
Simulating Dynamical System#
|
The runner for |
Training Dynamical System#
|
Structural Trainer for Dynamical Systems. |
|
The trainer implementing the back-propagation through time (BPTT) algorithm for training dynamical systems. |
|
The trainer implementing back propagation algorithm for feedforward neural networks. |
|
Online trainer for models with recurrent dynamics. |
|
FORCE learning. |
|
Offline trainer for models with recurrent dynamics. |
|
Trainer of ridge regression, also known as regression with Tikhonov regularization. |
Dynamical System Helpers#
|
Transform a single step |
brainpy.math module#
Basis for Object-oriented Transformations#
|
The BrainPyObject class for whole BrainPy ecosystem. |
|
Transform a Python function as a |
|
A list to represent a dynamically changed numerical sequence in which its element can be changed during JIT compilation. |
|
An object to represent a dict of node in which its element can be changed during JIT compilation. |
|
A sequence variable, whose contents can be changed during JIT compilation. |
|
A dict variable, in which its element can be changed during JIT compilation. |
|
The pointer to specify the dynamical variable. |
|
The pointer to specify the parameter. |
|
The pointer to specify the trainable variable. |
|
Object-oriented Transformations#
|
Automatic gradient computation for functions or class objects. |
|
Take vector-valued gradients for function |
|
Extending automatic Jacobian (reverse-mode) of |
|
Extending automatic Jacobian (reverse-mode) of |
|
Extending automatic Jacobian (forward-mode) of |
|
Hessian of |
|
Make a for-loop function, which iterates over inputs. |
|
Make a while-loop function. |
|
Make a condition (if-else) function. |
|
Simple conditional statement (if-else) with instance of |
|
Transform a Python function to |
|
Transform a Python function into a |
|
JIT (Just-In-Time) compilation for class objects. |
|
Object-oriented JAX transformation for BrainPy computation. |
Brain Dynamics Dedicated Operators#
|
Product of a sparse CSR matrix and a dense event vector. |
|
Collect event information, including event indices, and event number. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a scalar weight at each position. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a uniform distribution for its value. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a normal distribution for its value. |
|
Perform the \(Y=X@M\) operation, where \(X\), \(Y\) and \(M\) are matrices, and \(M\) is just-in-time randomly generated with a scalar weight at each position. |
|
Perform the \(Y=X@M\) operation, where \(X\), \(Y\) and \(M\) are matrices, and \(M\) is just-in-time randomly generated with a uniform distribution for its value. |
|
Perform the \(Y=X@M\) operation, where \(X\), \(Y\) and \(M\) are matrices, and \(M\) is just-in-time randomly generated with a normal distribution for its value. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a scalar weight at each position. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a uniform distribution for its value. |
|
Perform the \(y=M@v\) operation, where \(M\) is just-in-time randomly generated with a normal distribution for its value. |
|
|
|
The pre-to-post synaptic summation. |
|
The pre-to-post synaptic production. |
|
The pre-to-post synaptic maximization. |
|
The pre-to-post synaptic minimization. |
|
The pre-to-post synaptic mean computation. |
|
The pre-to-post event-driven synaptic summation with CSR synapse structure. |
|
The pre-to-post synaptic computation with event-driven summation. |
|
The pre-to-post synaptic computation with event-driven production. |
|
The pre-to-syn computation. |
|
The syn-to-post summation computation. |
|
The syn-to-post summation computation. |
|
The syn-to-post product computation. |
|
The syn-to-post maximum computation. |
|
The syn-to-post minimization computation. |
|
The syn-to-post mean computation. |
|
The syn-to-post softmax computation. |
|
Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. |
|
Product of COO sparse matrix and a dense vector using cuSPARSE algorithm. |
|
CSR sparse matrix product with a dense vector, which outperforms the cuSPARSE algorithm. |
|
Sparse matrix multiplication. |
|
Convert pre_ids, post_ids to (indices, indptr). |
|
Given CSR (indices, indptr) return COO (row, col) |
|
|
|
Create an XLA custom call operator. |
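The two conversion utilities listed above, (pre_ids, post_ids) to CSR (indices, indptr) and CSR back to COO (row, col), can be sketched in a few lines of NumPy. The function names and the stable-sort strategy are our assumptions for illustration, not the library's implementation.

```python
import numpy as np

def pre_post_to_csr(pre_ids, post_ids, num_pre):
    """Hypothetical sketch: sort synapses by presynaptic id; afterwards
    indices[indptr[i]:indptr[i + 1]] lists the targets of neuron i."""
    pre_ids, post_ids = np.asarray(pre_ids), np.asarray(post_ids)
    order = np.argsort(pre_ids, kind='stable')           # group synapses by row
    indices = post_ids[order]
    counts = np.bincount(pre_ids, minlength=num_pre)     # synapses per row
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return indices, indptr

def csr_to_coo(indices, indptr):
    """Expand each row id by the number of entries in that row."""
    row = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    return row, np.asarray(indices)

indices, indptr = pre_post_to_csr([2, 0, 0, 1], [0, 1, 2, 0], num_pre=3)
row, col = csr_to_coo(indices, indptr)
# indices -> [1, 2, 0, 0], indptr -> [0, 2, 3, 4], row -> [0, 0, 1, 2]
```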
Activation Functions#
|
Continuously-differentiable exponential linear unit activation. |
|
Exponential linear unit activation function. |
|
Gaussian error linear unit activation function. |
|
Gated linear unit activation function. |
|
Hard \(\mathrm{tanh}\) activation function. |
|
Hard Sigmoid activation function. |
|
Hard SiLU activation function |
|
Hard SiLU activation function |
|
Leaky rectified linear unit activation function. |
|
Log-sigmoid activation function. |
|
Log-Softmax function. |
|
One-hot encodes the given indices. |
|
Normalizes an array by subtracting mean and dividing by sqrt(var). |
|
|
|
Rectified Linear Unit 6 activation function. |
|
Sigmoid activation function. |
|
Soft-sign activation function. |
|
Softmax function. |
|
Softplus activation function. |
|
SiLU activation function. |
|
SiLU activation function. |
|
Scaled exponential linear unit activation. |
|
|
Similar to |
Delay Variables#
|
Delay variable which has a fixed delay time length. |
|
Delay variable which has a fixed delay length. |
|
Neutral Time Delay. |
|
Neutral Length Delay. |
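A delay with a fixed length can be pictured as a circular buffer that stores the last few values of a variable. The class below is a hypothetical NumPy sketch of that idea (a time delay would map to delay_len = round(delay_time / dt)); it is not how BrainPy's delay classes are necessarily implemented.

```python
import numpy as np

class RingDelay:
    """Hypothetical fixed-length delay built on a circular buffer.

    retrieve() returns the value pushed `delay_len` updates earlier
    (or `init` until the buffer has been filled once).
    """

    def __init__(self, delay_len, shape=(), init=0.0):
        self.buf = np.full((delay_len,) + tuple(shape), init, dtype=float)
        self.idx = 0  # slot holding the oldest stored value

    def retrieve(self):
        return self.buf[self.idx].copy()

    def update(self, value):
        self.buf[self.idx] = value                 # overwrite the oldest entry
        self.idx = (self.idx + 1) % len(self.buf)  # advance the write pointer

delay = RingDelay(delay_len=3)   # e.g. a 0.3 ms delay with dt = 0.1 ms
outs = []
for t in range(6):
    outs.append(float(delay.retrieve()))  # value pushed 3 steps earlier
    delay.update(float(t))
print(outs)  # [0.0, 0.0, 0.0, 0.0, 1.0, 2.0]
```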
Environment Settings#
|
Set global default float type. |
Get the default float data type. |
|
|
Set global default integer type. |
|
Get the default int data type. |
|
Set global default boolean type. |
|
Get the default boolean data type. |
|
Set global default complex type. |
Get the default complex data type. |
|
|
Set the default numerical integrator precision. |
|
Get the numerical integrator precision. |
|
Set the default computing mode. |
|
Get the default computing mode. |
|
Set the default computation environment. |
|
Set the default computation environment. |
|
Changes platform to CPU, GPU, or TPU. |
Get the computing platform. |
|
By default, XLA considers all CPU cores as one device. |
|
|
Clear all on-device buffers. |
Disable pre-allocating the GPU memory. |
|
Disable pre-allocating the GPU memory. |
|
|
Default int type. |
|
Default float type. |
|
Context-manager that sets a computing environment for brain dynamics computation. |
|
Environment with the batching mode. |
|
Environment with the training mode. |
Computing Modes#
|
Base class for computation Mode |
Normal non-batching mode. |
|
|
Batching mode. |
|
Training mode requires data batching. |
Normal non-batching mode. |
|
Batching mode. |
|
Training mode requires data batching. |
Array Interoperability#
|
Convert the input to a |
|
Convert the input to a |
|
Convert the input to a |
|
Convert the input to a |
|
Convert the input to a |
Array Operators with NumPy Syntax#
|
Similar to |
|
Similar to |
|
Similar to |
|
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Return the shape of an array. |
|
Return the number of elements along a given axis. |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Return the indices for the lower-triangle of an (n, m) array. |
|
Similar to |
|
Return the indices for the upper-triangle of an (n, m) array. |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Similar to |
|
Similar to |
|
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Return a copy of an array sorted along the first axis. |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
Similar to |
|
|
Similar to |
|
Similar to |
|
Similar to |
Similar to |
|
Return the current print options. |
|
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Returns True if first argument is a typecode lower/equal in type hierarchy. |
|
Determine if the first argument is a subclass of the second argument. |
|
Similar to |
Similar to |
|
|
Similar to |
|
Context manager for setting print options. |
|
Set printing options. |
|
Similar to |
Similar to |
|
Similar to |
|
|
Similar to |
Similar to |
|
|
Similar to |
|
Add documentation to an existing object, typically one defined in C |
|
Display a message on a device. |
|
Get help information for a function, class, or module. |
|
Determine if a class is a subclass of a second class. |
|
|
|
Similar to |
|
Show libraries in the system on which NumPy was built. |
|
|
Return a description for the given data type code. |
|
Create a data type object. |
|
Machine limits for floating point types. |
|
Machine limits for integer types. |
Convert a string or number to a floating point number, if possible. |
|
Convert a string or number to a floating point number, if possible. |
|
Convert a string or number to a floating point number, if possible. |
|
|
Add a docstring to a built-in obj if possible. |
add_ufunc_docstring(ufunc, new_docstring) |
Array Operators with PyTorch Syntax#
|
Flattens input by reshaping it into a one-dimensional tensor. |
Similar to |
|
|
Returns a new tensor with a dimension of size one inserted at the specified position. |
|
alias of |
Array Operators with TensorFlow Syntax#
Similar to |
|
|
Similar to |
|
Computes maximum of elements across dimensions of a tensor. |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Computes log(sum(exp(elements across dimensions of a tensor))). |
|
Similar to |
|
Similar to |
|
Similar to |
|
Computes the Euclidean norm of elements across dimensions of a tensor. |
|
Computes the sum along segments of a tensor divided by the sqrt(N). |
|
Computes the average along segments of a tensor. |
|
Computes the sum along segments of a tensor. |
|
Computes the product along segments of a tensor. |
|
Computes the maximum along segments of a tensor. |
|
Computes the minimum along segments of a tensor. |
|
Computes the average along segments of a tensor. |
|
Similar to |
|
Casts a tensor to a new type. |
brainpy.math.surrogate module: Surrogate Gradient Functions#
|
Spike function with the sigmoid-shaped surrogate gradient. |
|
Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. |
|
Judge spiking state with a piecewise exponential function [1]_. |
|
Judge spiking state with a soft sign function. |
|
Judge spiking state with an arctan function. |
|
Judge spiking state with a nonzero sign log function. |
|
Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. |
|
Judge spiking state with a squarewave fourier series. |
|
Judge spiking state with the S2NN surrogate spiking function [1]_. |
|
Judge spiking state with the q-PseudoSpike surrogate function [1]_. |
|
Judge spiking state with the Leaky ReLU function. |
|
Judge spiking state with the Log-tailed ReLU function [1]_. |
|
Spike function with the ReLU gradient function [1]_. |
|
Spike function with the Gaussian gradient function [1]_. |
|
Spike function with the inverse-square surrogate gradient. |
|
Spike function with the multi-Gaussian gradient function [1]_. |
|
Spike function with the slayer surrogate gradient function. |
|
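The surrogate functions listed above share one idea: the forward pass is a hard threshold (a spike), while the backward pass substitutes a smooth derivative. The sketch below implements a sigmoid-shaped surrogate with plain jax.custom_vjp; the name spike_fn and the steepness ALPHA = 4.0 are our assumptions, and BrainPy's own surrogate functions may use a different parameterization.

```python
import jax
import jax.numpy as jnp

ALPHA = 4.0  # assumed steepness of the sigmoid surrogate

@jax.custom_vjp
def spike_fn(x):
    # Forward pass: Heaviside step, 1 where the input reaches threshold 0
    return (x >= 0.).astype(x.dtype)

def spike_fwd(x):
    return spike_fn(x), x            # keep x as the residual for backward

def spike_bwd(x, g):
    # Backward pass: derivative of sigmoid(ALPHA * x) replaces the step's
    # zero-almost-everywhere derivative
    sg = jax.nn.sigmoid(ALPHA * x)
    return (g * ALPHA * sg * (1. - sg),)

spike_fn.defvjp(spike_fwd, spike_bwd)

x = jnp.array([-1., 0., 1.])
spikes = spike_fn(x)                              # [0. 1. 1.]
grads = jax.grad(lambda v: spike_fn(v).sum())(x)  # smooth, peaks at x = 0
```

At the threshold the surrogate gradient equals ALPHA / 4, so with ALPHA = 4.0 it is 1.0 there and decays symmetrically on both sides.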
brainpy.math.random module: Random Number Generation#
|
Sets a new random seed. |
|
|
|
Random values in a given shape. |
|
Return random integers from low (inclusive) to high (exclusive). |
|
Random integers of type np.int_ between low and high, inclusive. |
|
Return a sample (or samples) from the "standard normal" distribution. |
|
Return random floats in the half-open interval [0.0, 1.0). |
|
Return random floats in the half-open interval [0.0, 1.0). |
|
This is an alias of random_sample. See random_sample for the complete |
|
This is an alias of random_sample. See random_sample for the complete |
|
Generates a random sample from a given 1-D array |
|
Randomly permute a sequence, or return a permuted range. |
|
Modify a sequence in-place by shuffling its contents. |
|
Draw samples from a Beta distribution. |
|
Sample truncated standard normal random values with given shape and dtype. |
|
Sample Bernoulli random values with given shape and mean. |
|
Draw samples from a Weibull distribution. |
|
Sample from a Weibull distribution. |
|
Draw samples from a Zipf distribution. |
|
Sample from a one sided Maxwell distribution. |
|
Sample Student’s t random values. |
|
Sample uniformly from the orthogonal group O(n). |
|
Sample log-gamma random values. |
|
Sample random values from categorical distributions. |
|
Similar to |
|
Similar to |
|
Similar to |
|
RandomState that tracks the random generator state. |
alias of |
|
RandomState that tracks the random generator state. |
brainpy.math.linalg module: Linear algebra#
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
brainpy.math.fft module: Discrete Fourier Transform#
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
|
Similar to |
brainpy.channels module#
Basic Channel Classes#
|
Base class for ions. |
|
Base class for ion channels. |
|
The base calcium dynamics. |
|
Base class for Ih channel models. |
|
Base class for Calcium ion channels. |
|
Base class for sodium channel. |
|
Base class for potassium channel. |
|
Base class for leaky channel. |
Voltage-dependent Sodium Channel Models#
|
The sodium current model. |
|
The sodium current model described by (Traub and Miles, 1991) [1]_. |
|
The sodium current model described by Hodgkin–Huxley model [1]_. |
Voltage-dependent Potassium Channel Models#
|
The delayed rectifier potassium channel current. |
|
The potassium channel described by (Traub and Miles, 1991) [1]_. |
|
The potassium channel described by Hodgkin–Huxley model [1]_. |
|
The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. |
|
The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. |
|
The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. |
|
The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. |
|
A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. |
Voltage-dependent Calcium Channel Models#
|
Fixed Calcium dynamics. |
|
Calcium ion flow with dynamics. |
|
Dynamical Calcium model proposed. |
|
The first-order calcium concentration model. |
|
The calcium-activated non-selective cation channel model proposed by (Inoue & Strowbridge, 2008) [2]_. |
|
The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. |
|
The low-threshold T-type calcium current model for thalamic reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. |
|
The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. |
|
The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. |
Calcium-dependent Potassium Channel Models#
|
The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. |
Hyperpolarization-activated Cation Channel Models#
|
The hyperpolarization-activated cation current model proposed by (Huguenard & McCormick, 1992) [1]_. |
|
The hyperpolarization-activated cation current model proposed by (Destexhe, et al., 1996) [1]_. |
Leakage Channel Models#
|
The leakage channel current. |
|
The potassium leak channel current. |
brainpy.layers module#
Basic ANN Layer Class#
|
Base class for a layer of artificial neural network. |
Convolutional Layers#
|
One-dimensional convolution. |
|
Two-dimensional convolution. |
|
Three-dimensional convolution. |
alias of |
|
alias of |
|
alias of |
|
|
One dimensional transposed convolution (aka. |
|
Two dimensional transposed convolution (aka. |
|
Three dimensional transposed convolution (aka. |
Dropout Layers#
|
A layer that stochastically ignores a subset of inputs each training step. |
Function Layers#
|
Applies an activation function to the inputs |
|
Flattens a contiguous range of dims into 2D or 1D. |
|
Dense Connection Layers#
|
A linear transformation applied over the last dimension of the input. |
alias of |
|
|
A placeholder identity operator that is argument-insensitive. |
Normalization Layers#
|
1-D batch normalization [1]_. |
|
2-D batch normalization [1]_. |
|
3-D batch normalization [1]_. |
alias of |
|
alias of |
|
alias of |
|
|
Layer normalization (https://arxiv.org/abs/1607.06450). |
|
Group normalization layer. |
|
Instance normalization layer. |
NVAR Layers#
|
Nonlinear vector auto-regression (NVAR) node. |
Pooling Layers#
|
Pools the input by taking the maximum over a window. |
|
Pools the input by taking the minimum over a window. |
|
Pools the input by taking the average over a window. |
|
Applies a 1D average pooling over an input signal composed of several input |
|
Applies a 2D average pooling over an input signal composed of several input |
|
Applies a 3D average pooling over an input signal composed of several input |
|
Applies a 1D max pooling over an input signal composed of several input |
|
Applies a 2D max pooling over an input signal composed of several input |
|
Applies a 3D max pooling over an input signal composed of several input |
|
Adaptive one-dimensional average down-sampling. |
|
Adaptive two-dimensional average down-sampling. |
|
Adaptive three-dimensional average down-sampling. |
|
Adaptive one-dimensional maximum down-sampling. |
|
Adaptive two-dimensional maximum down-sampling. |
|
Adaptive three-dimensional maximum down-sampling. |
Reservoir Layers#
|
Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_. |
Artificial Recurrent Layers#
Basic fully-connected RNN core.
Gated Recurrent Unit.
Long short-term memory (LSTM) RNN core.
1-D convolutional LSTM.
2-D convolutional LSTM.
3-D convolutional LSTM.
brainpy.neurons module#
Biological Models#
Hodgkin–Huxley neuron model.
The Morris-Lecar neuron model.
The Pinsky and Rinzel (1994) model.
Wang-Buzsaki model [9], an implementation of a modified Hodgkin-Huxley model.
Fractional-order Models#
Fractional-order neuron model.
The fractional-order FH-R model [1].
Fractional-order Izhikevich model [10].
Reduced Models#
Leaky Integrator Model.
Leaky integrate-and-fire neuron model.
Exponential integrate-and-fire neuron model.
Adaptive exponential integrate-and-fire neuron model.
Quadratic Integrate-and-Fire neuron model.
Adaptive quadratic integrate-and-fire neuron model.
Generalized Integrate-and-Fire model.
Leaky Integrate-and-Fire model with spike-frequency adaptation (SFA) [1].
The Izhikevich neuron model.
Hindmarsh-Rose neuron model.
FitzHugh-Nagumo neuron model.
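To make the reduced models concrete, here is a forward-Euler sketch of the leaky integrate-and-fire neuron, tau * dV/dt = -(V - V_rest) + R * I, with a threshold, a reset, and an absolute refractory period. The default parameter values mirror those passed to bp.neurons.LIF in the EINet_V2 example in the release notes. The function name and the membrane resistance R are assumptions. BrainPy's model operates on arrays and integrates with the method you choose, such as 'exp_auto'.

```python
from typing import List


def simulate_lif(I: float, duration: float = 100.0, dt: float = 0.1,
                 V_rest: float = -60.0, V_reset: float = -60.0,
                 V_th: float = -50.0, tau: float = 20.0,
                 tau_ref: float = 5.0, R: float = 1.0) -> List[float]:
    """Return spike times (ms) of one LIF neuron driven by a constant current I."""
    V = V_rest
    last_spike = -float("inf")
    spikes = []
    for step in range(int(round(duration / dt))):
        t = step * dt
        if t - last_spike < tau_ref:
            continue  # refractory period: V stays at V_reset
        # forward Euler step of tau * dV/dt = -(V - V_rest) + R * I
        V += dt * (-(V - V_rest) + R * I) / tau
        if V >= V_th:
            spikes.append(t)
            V = V_reset
            last_spike = t
    return spikes
```

With I = 20 the steady-state voltage, V_rest + R * I = -40 mV, lies above threshold, so the neuron fires periodically. With I = 5 it settles at -55 mV and stays silent.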
Noise Models#
The Ornstein–Uhlenbeck process.
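The Ornstein–Uhlenbeck process dx = (mean - x) / tau * dt + sigma * dW can be simulated with the Euler–Maruyama scheme. Each Wiener increment dW is drawn from a normal distribution with mean 0 and standard deviation sqrt(dt). The stdlib sketch below follows that scheme. The parameter names are illustrative, and the seeded random.Random generator is an assumption made for reproducibility; BrainPy uses its own random module.

```python
import math
import random
from typing import List


def ou_process(x0: float, mean: float, tau: float, sigma: float,
               dt: float, steps: int, seed: int = 0) -> List[float]:
    """Euler-Maruyama simulation of dx = (mean - x) / tau * dt + sigma * dW."""
    rng = random.Random(seed)
    xs = [x0]
    x = x0
    for _ in range(steps):
        dW = rng.gauss(0.0, math.sqrt(dt))  # Wiener increment with variance dt
        x = x + (mean - x) / tau * dt + sigma * dW
        xs.append(x)
    return xs
```

With sigma = 0 the path relaxes deterministically toward the mean with time constant tau.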
Input Models#
Input neuron group used as a placeholder.
Output neuron group used as a placeholder.
The input neuron group characterized by spikes emitted at given times.
Poisson Neuron Group.
brainpy.rates module#
FitzHugh-Nagumo system used in [1].
FitzHugh-Nagumo model with recurrent neural feedback.
A mean-field model of a quadratic integrate-and-fire neuron population.
Stuart-Landau model with Hopf bifurcation.
Wilson-Cowan population model.
A threshold linear rate model.
brainpy.synapses module#
Abstract Models#
Voltage jump synapse model, also known as the delta synapse model.
Exponential decay synapse model.
Dual exponential synapse model.
Alpha synapse model.
NMDA synapse model.
Poisson input to the given Variable.
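The exponential decay synapse model lets the conductance g relax as tau * dg/dt = -g and increments it by g_max on each presynaptic spike. The sketch below applies the exact per-step decay factor exp(-dt / tau) to a single synapse. The function name and its arguments are hypothetical. In BrainPy the resulting conductance is then converted into a current by a synaptic output from brainpy.synouts, such as COBA.

```python
import math
from typing import List, Sequence


def exponential_synapse(spike_steps: Sequence[int], steps: int, dt: float,
                        tau: float, g_max: float) -> List[float]:
    """Conductance trace: g decays with time constant tau and jumps by g_max at each spike step."""
    decay = math.exp(-dt / tau)  # exact solution of tau * dg/dt = -g over one step
    spike_set = set(spike_steps)
    g = 0.0
    trace = []
    for step in range(steps):
        g *= decay
        if step in spike_set:
            g += g_max
        trace.append(g)
    return trace
```

Because the dynamics are linear, responses to successive spikes simply add up.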
Biological Models#
AMPA synapse model.
GABAa synapse model.
Biological NMDA synapse model.
Coupling Models#
Delay coupling.
Diffusive coupling.
Additive coupling.
Gap Junction Models#
Learning Rule Models#
brainpy.synouts module#
Conductance-based synaptic output.
Current-based synaptic output.
Synaptic output based on magnesium blocking.
brainpy.synplast module#
Synaptic output with short-term depression.
Synaptic output with short-term plasticity.
brainpy.integrators module#
ODE integrators#
Base ODE Integrator#
Numerical integrator for ordinary differential equations (ODEs).
Generic ODE Functions#
Set the default ODE numerical integrator method for differential equations.
Get the default ODE numerical integrator method.
Register a new ODE integrator.
Get all supported numerical methods for ODEs.
Explicit Runge-Kutta ODE Integrators#
Explicit Runge–Kutta methods for ordinary differential equations.
The Euler method for ODEs.
Explicit midpoint method for ODEs.
Heun's method for ODEs.
Ralston's method for ODEs.
Generic second-order Runge-Kutta method for ODEs.
Classical third-order Runge-Kutta method for ODEs.
Heun's third-order method for ODEs.
Ralston's third-order method for ODEs.
Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).
Classical fourth-order Runge-Kutta method for ODEs.
Ralston's fourth-order method for ODEs.
3/8-rule fourth-order method for ODEs.
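Every explicit method listed above is one instance of a single scheme defined by a Butcher tableau (A, b, c). The stages are k_i = f(t + c_i h, y + h * sum_j A_ij k_j), and the update is y <- y + h * sum_i b_i k_i. The sketch below implements that generic step for scalar ODEs and instantiates the Euler, Heun, and classical RK4 tableaus. It illustrates the technique and is not BrainPy's odeint code; the helper names are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Tableau = Tuple[List[List[float]], List[float], List[float]]  # (A, b, c)

EULER: Tableau = ([[0.0]], [1.0], [0.0])
HEUN2: Tableau = ([[0.0, 0.0], [1.0, 0.0]], [0.5, 0.5], [0.0, 1.0])
RK4: Tableau = ([[0.0, 0.0, 0.0, 0.0],
                 [0.5, 0.0, 0.0, 0.0],
                 [0.0, 0.5, 0.0, 0.0],
                 [0.0, 0.0, 1.0, 0.0]],
                [1 / 6, 1 / 3, 1 / 3, 1 / 6],
                [0.0, 0.5, 0.5, 1.0])


def rk_step(f: Callable[[float, float], float], t: float, y: float,
            h: float, tableau: Tableau) -> float:
    """One explicit Runge-Kutta step defined by a Butcher tableau."""
    A, b, c = tableau
    k: List[float] = []
    for i in range(len(b)):
        y_stage = y + h * sum(A[i][j] * k[j] for j in range(i))  # explicit: only earlier stages
        k.append(f(t + c[i] * h, y_stage))
    return y + h * sum(bi * ki for bi, ki in zip(b, k))


def integrate(f, y0: float, t0: float, t1: float, h: float, tableau: Tableau) -> float:
    t, y = t0, y0
    for _ in range(int(round((t1 - t0) / h))):
        y = rk_step(f, t, y, h, tableau)
        t += h
    return y


print(integrate(lambda t, y: y, 1.0, 0.0, 1.0, 0.1, RK4))  # close to e = 2.71828...
```

Adding a new explicit method then only requires writing down its tableau.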
Adaptive Runge-Kutta ODE Integrators#
Adaptive Runge-Kutta method for ordinary differential equations.
The Fehlberg RK1(2) method for ODEs.
The Runge–Kutta–Fehlberg method for ODEs.
The Dormand–Prince method for ODEs.
The Cash–Karp method for ODEs.
The Bogacki–Shampine method for ODEs.
The Heun–Euler method for ODEs.
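Adaptive methods embed two solutions of different order in each step and use their difference as a local error estimate to accept or reject the step and to resize it. The sketch below uses the Heun–Euler pair (orders 2 and 1) for a scalar ODE. The safety factor 0.9, the growth and shrink limits, and the function name are assumptions, not BrainPy's settings.

```python
import math
from typing import Callable, Tuple


def heun_euler_adaptive(f: Callable[[float, float], float], y0: float,
                        t0: float, t1: float, h0: float = 0.1,
                        tol: float = 1e-6) -> Tuple[float, int, int]:
    """Integrate dy/dt = f(t, y) from t0 to t1; return (y(t1), accepted, rejected)."""
    t, y, h = t0, y0, h0
    accepted = rejected = 0
    while t < t1:
        last = h >= t1 - t
        if last:
            h = t1 - t                       # do not step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_heun = y + 0.5 * h * (k1 + k2)     # 2nd-order solution
        y_euler = y + h * k1                 # embedded 1st-order solution
        err = abs(y_heun - y_euler)          # local error estimate
        if err <= tol:
            t, y = (t1 if last else t + h), y_heun
            accepted += 1
        else:
            rejected += 1
        # exponent 1/2 because the error estimate is O(h^2)
        factor = 2.0 if err == 0.0 else min(2.0, max(0.2, 0.9 * math.sqrt(tol / err)))
        h *= factor
    return y, accepted, rejected
```

The accepted value is the higher-order Heun solution (local extrapolation), so the tolerance actually controls the error of the cruder Euler estimate.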
Exponential ODE Integrators#
Exponential Euler method using automatic differentiation.
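The exponential Euler method linearizes f around the current state. With a = df/dy, it advances y <- y + (exp(a h) - 1) / a * f(y), so it is exact for linear equations such as the LIF membrane equation. BrainPy obtains the derivative by automatic differentiation. The stdlib sketch below uses a central finite difference instead, an assumption made to stay dependency-free, and handles only scalar autonomous ODEs.

```python
import math
from typing import Callable


def exp_euler_step(f: Callable[[float], float], y: float, h: float,
                   eps: float = 1e-6) -> float:
    """One exponential Euler step for the scalar ODE dy/dt = f(y)."""
    a = (f(y + eps) - f(y - eps)) / (2.0 * eps)  # finite-difference stand-in for autodiff
    fy = f(y)
    if abs(a) < 1e-12:
        return y + h * fy                        # limit of (exp(a h) - 1) / a is h
    return y + (math.expm1(a * h) / a) * fy


# LIF membrane: tau * dV/dt = -(V - V_rest) + I, with tau = 20, V_rest = -60, I = 20
lif = lambda V: (-(V + 60.0) + 20.0) / 20.0
V = -60.0
for _ in range(10):
    V = exp_euler_step(lif, V, 10.0)  # large steps are still exact for this linear ODE
```

Forward Euler with the same step of 10 ms would be far less accurate, which explains the popularity of exponential integrators in neuron simulations.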
SDE integrators#
Base SDE Integrator#
SDE integrator.
Generic SDE Functions#
Set the default SDE numerical integrator method for differential equations.
Get the default SDE numerical integrator method.
Register a new SDE integrator.
Get all supported numerical methods for SDEs.
Normal SDE Integrators#
Euler method for the Ito and Stratonovich integrals.
The Euler-Heun method for Stratonovich integral scheme.
Milstein method for Ito or Stratonovich integrals.
Derivative-free Milstein method for Ito or Stratonovich integrals.
First-order, explicit exponential Euler method.
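The Milstein method adds a correction term to Euler–Maruyama. For a scalar Itô SDE dX = f(X) dt + g(X) dW, the step is X <- X + f h + g dW + 0.5 g g' (dW^2 - h); for Stratonovich SDEs the correction is 0.5 g g' dW^2 instead. The sketch below takes the Wiener increments as an argument so that it is deterministic to test, and it approximates g' with a central difference. Both choices, and the function name, are assumptions rather than BrainPy's sdeint interface.

```python
from typing import Callable, List, Sequence


def milstein_ito(f: Callable[[float], float], g: Callable[[float], float],
                 x0: float, h: float, dWs: Sequence[float],
                 eps: float = 1e-6) -> List[float]:
    """Scalar Ito Milstein scheme driven by the given Wiener increments dWs."""
    xs = [x0]
    x = x0
    for dW in dWs:
        gx = g(x)
        dg = (g(x + eps) - g(x - eps)) / (2.0 * eps)  # g'(x) by central difference
        x = x + f(x) * h + gx * dW + 0.5 * gx * dg * (dW * dW - h)
        xs.append(x)
    return xs
```

When the noise is additive (g constant), g' = 0 and the scheme reduces to Euler–Maruyama.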
SRK methods for scalar Wiener process#
Order 2.0 weak SRK methods for SDEs with scalar Wiener process.
Order 1.5 strong SRK methods for SDEs with scalar noise.
FDE integrators#
Base FDE Integrator#
Numerical integrator for fractional differential equations (FDEs).
Generic FDE Functions#
Set the default FDE numerical integrator method for differential equations.
Get the default FDE numerical integrator method.
Register a new FDE integrator.
Get all supported numerical methods for FDEs.
Methods for Caputo Fractional Derivative#
One-step Euler method for Caputo fractional differential equations.
The L1 scheme method for the numerical approximation of the Caputo fractional-order derivative equations [3].
Methods for Riemann-Liouville Fractional Derivative#
Efficient computation of the short-memory principle in the Grünwald-Letnikov method [1].
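The Grünwald–Letnikov approximation of a fractional derivative of order alpha is D^alpha f(t) ≈ h^(-alpha) * sum_{j=0..n} c_j f(t - j h). Its weights are c_0 = 1 and c_j = (1 - (1 + alpha) / j) * c_{j-1}. The short-memory principle truncates this sum to the samples inside a fixed memory length. The stdlib sketch below evaluates the (optionally truncated) sum; the function and parameter names are hypothetical. The check uses the known result D^alpha t = t^(1 - alpha) / Gamma(2 - alpha).

```python
import math
from typing import Callable, List, Optional


def gl_weights(alpha: float, n: int) -> List[float]:
    """Weights c_j = (-1)^j * binom(alpha, j), computed by recurrence."""
    c = [1.0]
    for j in range(1, n + 1):
        c.append(c[-1] * (1.0 - (1.0 + alpha) / j))
    return c


def gl_derivative(f: Callable[[float], float], alpha: float, t: float, h: float,
                  memory: Optional[float] = None) -> float:
    """Grünwald-Letnikov estimate of the order-alpha derivative of f at t (history from 0)."""
    n = int(round(t / h))                    # steps back to the start time 0
    if memory is not None:
        n = min(n, int(round(memory / h)))   # short-memory truncation
    c = gl_weights(alpha, n)
    return sum(c[j] * f(t - j * h) for j in range(n + 1)) / h ** alpha


approx = gl_derivative(lambda s: s, 0.5, 1.0, 1e-3)
exact = 1.0 / math.gamma(1.5)                # about 1.1284
```

The truncation trades accuracy for a bounded cost per step, which is the motivation for the short-memory variant.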
brainpy.analysis module#
Low-dimensional Analyzers#
Phase plane analyzer for 1D dynamical systems.
Phase plane analyzer for 2D dynamical systems.
Bifurcation analysis of 1D systems.
Bifurcation analysis of 2D systems.
High-dimensional Analyzers#
Find fixed/slow points by numerical optimization.
brainpy.connect module#
Base Connection Classes and Tools#
Convert a dense matrix to (indices, indptr).
Convert CSR to CSC.
Convert (indices, indptr) to a dense matrix.
Convert pre_ids, post_ids to (indices, indptr) when 'jax_platform_name' = 'gpu'.
Base Synaptic Connector Class.
Synaptic connector to build connections between two neuron groups.
Synaptic connector to build synapse connections within a population of neurons.
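Several of the tools above convert between a dense connection matrix and the CSR pair (indices, indptr). For pre-synaptic neuron i, the post-synaptic targets are indices[indptr[i]:indptr[i+1]]. The stdlib sketch below shows the conversion in both directions plus a CSR-to-CSC transpose. The function names are hypothetical, and BrainPy's own tools operate on arrays and may return additional information.

```python
from typing import List, Sequence, Tuple


def mat_to_csr(mat: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Dense 0/1 matrix (rows = pre, cols = post) to (indices, indptr)."""
    indices: List[int] = []
    indptr = [0]
    for row in mat:
        indices.extend(j for j, v in enumerate(row) if v)
        indptr.append(len(indices))
    return indices, indptr


def csr_to_mat(indices: Sequence[int], indptr: Sequence[int], num_post: int) -> List[List[int]]:
    mat = []
    for i in range(len(indptr) - 1):
        row = [0] * num_post
        for j in indices[indptr[i]:indptr[i + 1]]:
            row[j] = 1
        mat.append(row)
    return mat


def csr_to_csc(indices: Sequence[int], indptr: Sequence[int],
               num_post: int) -> Tuple[List[int], List[int]]:
    """Transpose the sparsity pattern: for each post neuron, list its pre neurons."""
    counts = [0] * num_post
    for j in indices:
        counts[j] += 1
    col_ptr = [0]
    for cnt in counts:
        col_ptr.append(col_ptr[-1] + cnt)
    row_idx = [0] * len(indices)
    fill = col_ptr[:-1]                      # next free slot per column (a copy)
    for i in range(len(indptr) - 1):
        for j in indices[indptr[i]:indptr[i + 1]]:
            row_idx[fill[j]] = i
            fill[j] += 1
    return row_idx, col_ptr
```

CSR makes it cheap to iterate over the targets of a spiking pre-synaptic neuron, and CSC makes it cheap to iterate over the inputs to a post-synaptic neuron.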
Custom Connections#
Connector built from the dense connection matrix.
Connector built from the
Connector built from the CSR sparse connection matrix.
Connector built from the sparse connection matrix.
Random Connections#
Connect the post-synaptic neurons with a fixed probability.
Connect a fixed number of pre-synaptic neurons for each post-synaptic neuron.
Connect a fixed number of post-synaptic neurons for each pre-synaptic neuron.
Connect neurons with a fixed total number of synapses.
Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function.
Connection with a maximum distance under a probability p.
Build a Watts–Strogatz small-world graph.
Build a random graph according to the Barabási–Albert preferential attachment model.
Build a random graph according to the dual Barabási–Albert preferential attachment model.
Holme and Kim algorithm for growing graphs with power-law degree distribution and approximate average clustering.
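The fixed-probability rule keeps each (pre, post) pair independently with probability p, so about p * N_pre * N_post connections are expected. The EINet example in the release notes uses bp.connect.FixedProb(0.02). The sketch below is a plain-Python version for illustration. The function name, the include_self flag for excluding self-connections within one population, and the seeded generator are assumptions; BrainPy's connector is vectorized.

```python
import random
from typing import List, Tuple


def fixed_prob(num_pre: int, num_post: int, prob: float,
               include_self: bool = True, seed: int = 0) -> List[Tuple[int, int]]:
    """Keep each (pre, post) pair independently with probability prob."""
    rng = random.Random(seed)
    pairs = []
    for i in range(num_pre):
        for j in range(num_post):
            if not include_self and i == j:
                continue  # skip self-connections within one population
            if rng.random() < prob:
                pairs.append((i, j))
    return pairs
```

Fixing the seed makes the sampled network reproducible across runs.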
Regular Connections#
Connect two neuron groups one by one.
Connect each neuron in the first group to all neurons in the post-synaptic neuron group.
The nearest four neighbors connection method.
The nearest eight neighbors connection method.
The nearest (2*N+1) * (2*N+1) neighbors connection method.
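The grid-based rules differ only in which neighbor offsets they use: the four orthogonal neighbors, all eight surrounding cells, or every cell within a (2N+1) x (2N+1) square. The sketch below connects neurons on a height x width grid using row-major indices. The function and constant names are hypothetical, and the choice of no wrap-around at the borders is an assumption.

```python
from typing import List, Tuple

FOUR = [(-1, 0), (1, 0), (0, -1), (0, 1)]
EIGHT = FOUR + [(-1, -1), (-1, 1), (1, -1), (1, 1)]


def grid_n_offsets(n: int) -> List[Tuple[int, int]]:
    """All offsets in the (2n+1) x (2n+1) square except the center."""
    return [(dr, dc) for dr in range(-n, n + 1) for dc in range(-n, n + 1) if (dr, dc) != (0, 0)]


def grid_connect(height: int, width: int,
                 offsets: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """(pre, post) pairs between grid cells, using row-major neuron indices."""
    pairs = []
    for r in range(height):
        for c in range(width):
            for dr, dc in offsets:
                rr, cc = r + dr, c + dc
                if 0 <= rr < height and 0 <= cc < width:  # no wrap-around at the borders
                    pairs.append((r * width + c, rr * width + cc))
    return pairs
```

With n = 1 the square rule reproduces the eight-neighbor rule.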
brainpy.encoding module#
Base class for encoding rate values as spike trains.
Encode the rate input as a spike train.
Encode the rate input into a spike train according to [1].
Encode the rate input as a Poisson spike train.
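Poisson encoding turns a firing rate into a spike train by letting each input spike in a time step with probability rate * dt, independently across steps. The sketch below assumes rates in Hz and dt in milliseconds, and it clips the probability to [0, 1]. The function name and the seeded generator are hypothetical, not BrainPy's encoder API.

```python
import random
from typing import List, Sequence


def poisson_encode(rates_hz: Sequence[float], dt_ms: float, steps: int,
                   seed: int = 0) -> List[List[bool]]:
    """Spike matrix [step][input]: each input spikes with probability rate * dt per step."""
    rng = random.Random(seed)
    probs = [min(1.0, max(0.0, r * dt_ms / 1000.0)) for r in rates_hz]  # Hz * ms -> probability
    return [[rng.random() < p for p in probs] for _ in range(steps)]
```

The empirical spike rate approaches the requested rate as the number of steps grows.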
brainpy.initialize module#
This module provides methods to initialize weights. You can access them through brainpy.init.XXX.
Basic Initialization Classes#
Base Initialization Class.
Regular Initializers#
Zero initializer.
Constant initializer.
One initializer.
Returns the identity matrix.
Random Initializers#
Initialize weights with a normal distribution.
Initialize weights with a uniform distribution.
Construct an initializer for uniformly distributed orthogonal matrices.
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
Decay Initializers#
Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function.
Builds a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons.
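The decay initializers set each weight as a function of the distance d between two neurons. A Gaussian profile is w = max_w * exp(-d^2 / (2 sigma^2)), and a difference of Gaussians subtracts a wider, weaker Gaussian to produce a "Mexican hat". The sketch below places neurons on a 1-D ring with periodic distance. The ring geometry and the function names are assumptions; BrainPy's initializers support multi-dimensional layouts and other options.

```python
import math
from typing import List


def gaussian_decay(n: int, sigma: float, max_w: float) -> List[List[float]]:
    """w[i][j] = max_w * exp(-d^2 / (2 sigma^2)), with d the distance on a ring of n neurons."""
    W = []
    for i in range(n):
        row = []
        for j in range(n):
            d = min(abs(i - j), n - abs(i - j))  # periodic (ring) distance
            row.append(max_w * math.exp(-d * d / (2.0 * sigma * sigma)))
        W.append(row)
    return W


def dog_decay(n: int, sigma_exc: float, max_w_exc: float,
              sigma_inh: float, max_w_inh: float) -> List[List[float]]:
    """Difference of Gaussians: narrow excitation minus broad inhibition."""
    exc = gaussian_decay(n, sigma_exc, max_w_exc)
    inh = gaussian_decay(n, sigma_inh, max_w_inh)
    return [[a - b for a, b in zip(re, ri)] for re, ri in zip(exc, inh)]
```

The resulting profile is excitatory for nearby neurons and inhibitory for distant ones.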
brainpy.inputs module#
Format an input current with different sections.
Format constant input in durations.
Format current input like a series of short-time spikes.
Get the gradually changed input current.
Stimulus sampled from a Wiener process, i.e., with increments drawn from the normal distribution N(0, sqrt(dt)) (mean 0, standard deviation sqrt(dt)).
Ornstein–Uhlenbeck input.
Sinusoidal input.
Oscillatory square input.
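The first and fourth helpers above build a current trace as a list of per-step values. A sectioned input concatenates constant segments with the given durations, and a ramp changes linearly from a start value to an end value. The sketch below mirrors what those descriptions suggest. The function names and the rounding of duration / dt are assumptions, not BrainPy's exact signatures.

```python
from typing import List, Sequence


def section_input(values: Sequence[float], durations: Sequence[float], dt: float) -> List[float]:
    """Concatenate constant sections: values[k] held for durations[k] time units."""
    if len(values) != len(durations):
        raise ValueError("values and durations must have the same length")
    current: List[float] = []
    for value, duration in zip(values, durations):
        current.extend([value] * int(round(duration / dt)))
    return current


def ramp_input(c_start: float, c_end: float, duration: float, dt: float) -> List[float]:
    """Linearly changing current from c_start toward c_end over duration."""
    n = int(round(duration / dt))
    return [c_start + (c_end - c_start) * k / n for k in range(n)]


I = section_input([0.0, 1.0, 0.0], [10.0, 20.0, 10.0], dt=0.1)  # 400 steps in total
```

A trace like this can be fed to a runner step by step, as with the inputs argument of DSRunner in the release-note examples.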
brainpy.losses module#
Comparison#
This criterion combines
Computes the softmax cross-entropy loss.
Computes the sigmoid cross-entropy loss.
Creates a criterion that measures the mean absolute error (MAE) between each element in the logits \(x\) and targets \(y\).
Computes the L2 loss.
Huber loss.
Computes the mean absolute error between x and y.
Computes the mean squared error between x and y.
Computes the mean squared logarithmic error between y_true and y_pred.
Binary logistic loss.
Multiclass logistic loss.
Computes sigmoid cross entropy given logits and multiple class labels.
Computes the softmax cross entropy between sets of logits and labels.
Calculates the log-cosh loss for a set of predictions.
Computes CTC loss and CTC forward-probabilities.
Computes CTC loss.
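Two of the losses above behave quadratically for small residuals and linearly for large ones. For a residual r, the Huber loss is 0.5 r^2 when |r| <= delta and delta (|r| - 0.5 delta) otherwise. The log-cosh loss is log(cosh(r)), which the sketch rewrites as |r| + log1p(exp(-2|r|)) - log 2 to avoid overflow. Mean reduction and the function names are assumptions; BrainPy's functions operate on arrays.

```python
import math
from typing import Sequence


def huber_loss(pred: Sequence[float], target: Sequence[float], delta: float = 1.0) -> float:
    total = 0.0
    for p, t in zip(pred, target):
        r = abs(p - t)
        total += 0.5 * r * r if r <= delta else delta * (r - 0.5 * delta)
    return total / len(pred)


def log_cosh_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    total = 0.0
    for p, t in zip(pred, target):
        r = abs(p - t)
        # log(cosh(r)) = r + log(1 + exp(-2r)) - log(2); stable for large r
        total += r + math.log1p(math.exp(-2.0 * r)) - math.log(2.0)
    return total / len(pred)


print(huber_loss([0.5, 3.0], [0.0, 0.0]))  # (0.125 + 2.5) / 2 = 1.3125
```

Both losses limit the influence of outliers compared with the mean squared error.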
Regularization#
Computes the L2 loss.
Computes the mean absolute error between x and y.
Calculates the log-cosh loss for a set of predictions.
Apply label smoothing.
brainpy.measure module#
Calculate the cross-correlation index between neurons.
Calculate neuronal synchronization via voltage variance.
Pearson correlation of the lower triangle of two matrices.
Weighted Pearson correlation of two data series.
Functional connectivity matrix of time-series activities.
Get a spike raster plot, which displays the spiking activity of a group of neurons over time.
Calculate the mean firing rate of a neuron group.
A kernel-based method to calculate unitary local field potentials (uLFP) from a network of spiking neurons [1].
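The lower-triangle correlation above compares two symmetric matrices while ignoring the redundant upper half, for example a simulated and an empirical functional connectivity matrix. The sketch below computes the Pearson correlation of the elements strictly below the diagonal. Excluding the diagonal is an assumption, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def lower_triangle(m: Sequence[Sequence[float]]) -> List[float]:
    """Elements strictly below the diagonal, row by row."""
    return [m[i][j] for i in range(len(m)) for j in range(i)]


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def matrix_correlation(m1: Sequence[Sequence[float]], m2: Sequence[Sequence[float]]) -> float:
    return pearson(lower_triangle(m1), lower_triangle(m2))
```

The diagonal of a correlation-based connectivity matrix is always 1, so including it would inflate the similarity.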
brainpy.optim module#
Optimizers#
Base Optimizer Class.
Stochastic gradient descent optimizer.
Momentum optimizer.
Nesterov accelerated gradient optimizer [2].
Optimizer that implements the Adagrad algorithm.
Optimizer that implements the Adadelta algorithm.
Optimizer that implements the RMSprop algorithm.
Optimizer that implements the Adam algorithm.
Layer-wise adaptive rate scaling (LARS) optimizer [1].
Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1].
Adam with weight decay regularization [1].
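The optimizers above differ in how they turn a gradient into an update. The scalar sketch below shows two of them. Momentum SGD accumulates a velocity v <- mu * v + g and steps x <- x - lr * v. Adam rescales a bias-corrected first-moment estimate by the square root of a bias-corrected second-moment estimate. The class names, the defaults, and this momentum convention are assumptions; BrainPy's optimizers update whole collections of Variables rather than single floats.

```python
import math


class SGDMomentum:
    def __init__(self, lr: float = 0.1, momentum: float = 0.9):
        self.lr, self.momentum, self.velocity = lr, momentum, 0.0

    def step(self, x: float, grad: float) -> float:
        self.velocity = self.momentum * self.velocity + grad
        return x - self.lr * self.velocity


class Adam:
    def __init__(self, lr: float = 0.001, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = self.v = 0.0
        self.t = 0

    def step(self, x: float, grad: float) -> float:
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad          # first moment
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad   # second moment
        m_hat = self.m / (1 - self.beta1 ** self.t)                     # bias corrections
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return x - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
opt, x = SGDMomentum(lr=0.1, momentum=0.9), 0.0
for _ in range(500):
    x = opt.step(x, 2.0 * (x - 3.0))
```

Thanks to the bias correction, Adam's first step has magnitude close to lr regardless of the gradient's scale.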
Schedulers#
The learning rate scheduler.
Decays the learning rate of each parameter group by gamma every step_size epochs.
Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones.
Set the learning rate of each parameter group using a cosine annealing schedule, where \(\eta_{max}\) is set to the initial lr and \(T_{cur}\) is the number of epochs since the last restart in SGDR.
Set the learning rate of each parameter group using a cosine annealing
Decays the learning rate of each parameter group by gamma every epoch.
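The schedulers above can be written as closed-form functions of the epoch. For cosine annealing without restarts, the closed form is eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * T_cur / T_max)). The function names, argument names, and defaults below are hypothetical. BrainPy's schedulers are stateful objects, and their exact parameterization may differ.

```python
import math
from typing import Sequence


def step_lr(lr0: float, epoch: int, step_size: int, gamma: float = 0.1) -> float:
    return lr0 * gamma ** (epoch // step_size)


def multistep_lr(lr0: float, epoch: int, milestones: Sequence[int], gamma: float = 0.1) -> float:
    return lr0 * gamma ** sum(1 for m in milestones if epoch >= m)


def exponential_lr(lr0: float, epoch: int, gamma: float) -> float:
    return lr0 * gamma ** epoch


def cosine_annealing_lr(lr_max: float, epoch: int, T_max: int, lr_min: float = 0.0) -> float:
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * epoch / T_max))
```

The stateful and closed-form versions agree as long as the scheduler is stepped exactly once per epoch.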
brainpy.running module#
Perform a vectorized map of a function by using
Perform a parallelized map of a function by using
Run multiple models in multiple processes.
Run multiple models in multiple processes with a lock.
Performs a parallel ordered map with a progress bar.
Performs a parallel unordered map with a progress bar.
Release notes (brainpy)#
Note
For the complete release history, please see GitHub releases.
brainpy 2.2.x#
BrainPy 2.2.x is a complete redesign of the framework. It tackles the shortcomings of the brainpy 2.1.x generation and brings the framework in line with research needs and standards.
Version 2.2.1 (2022.09.09)#
This release fixes bugs found in the codebase and improves the usability and functions of BrainPy.
Bug fixes#
Fix the bug of operator customization in brainpy.math.XLACustomOp and brainpy.math.register_op. Operators can now be customized through the NumPy and Numba interfaces. For instance,
import brainpy.math as bm


# Abstract evaluation: the output has the same shape and dtype as `post_val`.
def abs_eval(events, indices, indptr, post_val, values):
    return post_val


# Concrete computation: for each pre-synaptic neuron that spiked,
# add `values` to all of its post-synaptic targets (CSR format).
def con_compute(outs, ins):
    post_val = outs
    events, indices, indptr, _, values = ins
    for i in range(events.size):
        if events[i]:
            for j in range(indptr[i], indptr[i + 1]):
                index = indices[j]
                old_value = post_val[index]
                post_val[index] = values + old_value


event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
Fix the bug of brainpy.tools.DotDict. It is now compatible with JAX transformations. For instance,
import brainpy as bp
from jax import vmap


@vmap
def multiple_run(I):
    hh = bp.neurons.HH(1)
    runner = bp.dyn.DSRunner(hh, inputs=('input', I), numpy_mon_after_run=False)
    runner.run(100.)
    return runner.mon  # the monitor container returned through vmap


mon = multiple_run(bp.math.arange(2, 10, 2))
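A custom operator such as event_sum from the first bug-fix example is easier to check against a plain-Python reference of what con_compute computes. Starting from zeros, each pre-synaptic neuron i that spiked adds the value to every post-synaptic target in indices[indptr[i]:indptr[i+1]]. The sketch assumes a zero-initialized output and a scalar value. The function name is hypothetical.

```python
from typing import List, Sequence


def event_sum_reference(events: Sequence[bool], indices: Sequence[int],
                        indptr: Sequence[int], num_post: int, value: float) -> List[float]:
    """Reference for the event-driven CSR summation performed by con_compute."""
    post_val = [0.0] * num_post
    for i, spiked in enumerate(events):
        if spiked:                                 # only spiking pre-synaptic neurons contribute
            for j in range(indptr[i], indptr[i + 1]):
                post_val[indices[j]] += value
    return post_val
```

Comparing the custom operator's output with this reference on small random inputs is a cheap correctness test.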
New features#
Add NumPy operators brainpy.math.mat, brainpy.math.matrix, and brainpy.math.asmatrix.
Improve the translation rules of brainpylib operators to increase their running speed.
Support DSView of DynamicalSystem instances. Models can now be defined with a slice view of a DS instance. For example,
import brainpy as bp
import brainpy.math as bm


class EINet_V2(bp.dyn.Network):
    def __init__(self, scale=1.0, method='exp_auto'):
        super(EINet_V2, self).__init__()

        # network size
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)

        # neurons
        self.N = bp.neurons.LIF(num_exc + num_inh,
                                V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                method=method, V_initializer=bp.initialize.Normal(-55., 2.))

        # synapses: slice views select the excitatory and inhibitory sub-populations
        we = 0.6 / scale  # excitatory synaptic weight (voltage)
        wi = 6.7 / scale  # inhibitory synaptic weight
        self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc], post=self.N,
                                            conn=bp.connect.FixedProb(0.02),
                                            g_max=we, tau=5.,
                                            output=bp.synouts.COBA(E=0.),
                                            method=method)
        self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:], post=self.N,
                                            conn=bp.connect.FixedProb(0.02),
                                            g_max=wi, tau=10.,
                                            output=bp.synouts.COBA(E=-80.),
                                            method=method)


net = EINet_V2(scale=1., method='exp_auto')
# simulation
runner = bp.dyn.DSRunner(
    net,
    monitors={'spikes': net.N.spike},
    inputs=[(net.N.input, 20.)]
)
runner.run(100.)
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)
Version 2.2.0 (2022.08.12)#
This release brings important improvements to BrainPy in usability, speed, functionality, and more.
Backwards Incompatible changes#
The brainpy.nn module is no longer supported and has been removed since version 2.2.0. Instead, users should use the brainpy.train module for training with back-propagation, online learning, or offline learning algorithms, and the brainpy.algorithms module for online/offline training algorithms.
The update() function for the model definition has been changed:
>>> # 2.1.x
>>>
>>> import brainpy as bp
>>>
>>> class SomeModel(bp.dyn.DynamicalSystem):
>>>     def __init__(self, ):
>>>         ......
>>>     def update(self, t, dt):
>>>         pass
>>> # 2.2.x
>>>
>>> import brainpy as bp
>>>
>>> class SomeModel(bp.dyn.DynamicalSystem):
>>>     def __init__(self, ):
>>>         ......
>>>     def update(self, tdi):
>>>         t, dt = tdi.t, tdi.dt
>>>         pass
where tdi can be given another name, such as sha, and represents the arguments shared across modules.
Deprecations#
brainpy.dyn.xxx (neurons) and brainpy.dyn.xxx (synapses) are no longer supported. Please use the brainpy.neurons and brainpy.synapses modules.
brainpy.running.monitor has been removed.
The brainpy.nn module has been removed.
New features#
brainpy.math.Variable accepts a batch_axis setting to represent the batch axis of the data.
>>> import brainpy.math as bm
>>> a = bm.Variable(bm.zeros((1, 4, 5)), batch_axis=0)
>>> a.value = bm.zeros((2, 4, 5)) # success
>>> a.value = bm.zeros((1, 2, 5)) # failed
MathError: The shape of the original data is (2, 4, 5), while we got (1, 2, 5) with batch_axis=0.
brainpy.train provides brainpy.train.BPTT for back-propagation algorithms, brainpy.train.OnlineTrainer for online training algorithms, and brainpy.train.OfflineTrainer for offline training algorithms.
The brainpy.Base class supports an _excluded_vars setting to ignore variables when retrieving variables with the Base.vars() method.
>>> class OurModel(bp.Base):
>>>     _excluded_vars = ('a', 'b')
>>>     def __init__(self):
>>>         super(OurModel, self).__init__()
>>>         self.a = bm.Variable(bm.zeros(10))
>>>         self.b = bm.Variable(bm.ones(20))
>>>         self.c = bm.Variable(bm.random.random(10))
>>>
>>> model = OurModel()
>>> model.vars().keys()
dict_keys(['OurModel0.c'])
brainpy.analysis.SlowPointFinder supports directly analyzing an instance of brainpy.dyn.DynamicalSystem.
>>> hh = bp.neurons.HH(1)
>>> finder = bp.analysis.SlowPointFinder(hh, target_vars={'V': hh.V, 'm': hh.m, 'h': hh.h, 'n': hh.n})
brainpy.datasets supports MNIST, FashionMNIST, and other datasets.
Support defining conductance-based neuron models.
>>> class HH(bp.dyn.CondNeuGroup):
>>>     def __init__(self, size):
>>>         super(HH, self).__init__(size)
>>>
>>>         self.INa = channels.INa_HH1952(size, )
>>>         self.IK = channels.IK_HH1952(size, )
>>>         self.IL = channels.IL(size, E=-54.387, g_max=0.03)
The brainpy.layers module provides commonly used models for DNN and reservoir computing.
Support composable definition of synaptic models by using TwoEndConn, SynOut, SynSTP, and SynLTP.
>>> bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob),
>>>                         g_max=0.03 / scale, tau=5,
>>>                         output=bp.synouts.COBA(E=0.),
>>>                         stp=bp.synplast.STD())
Provide commonly used surrogate gradient functions for spike generation, including
brainpy.math.spike_with_sigmoid_grad
brainpy.math.spike_with_linear_grad
brainpy.math.spike_with_gaussian_grad
brainpy.math.spike_with_mg_grad
Provide shortcuts for GPU memory management via brainpy.math.disable_gpu_memory_preallocation() and brainpy.math.clear_buffer_memory().
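A surrogate gradient keeps the non-differentiable Heaviside step in the forward pass but substitutes a smooth derivative in the backward pass. The sketch below shows the sigmoid variant in plain Python, with x = V - V_th. The scale parameter alpha, its default value, and the split into separate forward and backward functions are assumptions. The brainpy.math functions listed above combine the two into one differentiable operation, and their exact parameterization may differ.

```python
import math


def heaviside(x: float) -> float:
    """Forward pass: emit a spike (1.0) when x = V - V_th is non-negative."""
    return 1.0 if x >= 0.0 else 0.0


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_surrogate_grad(x: float, alpha: float = 4.0) -> float:
    """Backward pass: derivative of sigmoid(alpha * x), used in place of the Dirac delta."""
    s = sigmoid(alpha * x)
    return alpha * s * (1.0 - s)
```

Larger alpha makes the surrogate sharper and closer to the true step, at the cost of vanishing gradients away from the threshold.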
What’s Changed#
fix #207: synapses update first, then neurons, finally delay variables by @chaoming0625 in #219
new version of brainpy: V2.2.0-rc1 by @chaoming0625 in #226
update training apis by @chaoming0625 in #227
Update quickstart and the analysis module by @c-xy17 in #229
Essential updates for monitors, analysis, losses, and examples by @chaoming0625 in #230
Integrated simulation, simulaton and analysis by @chaoming0625 in #232
update docs by @chaoming0625 in #233
unify brainpy.layers with other modules in brainpy.dyn by @chaoming0625 in #234
fix bugs by @chaoming0625 in #235
update apis, docs, examples and others by @chaoming0625 in #236
fixes by @chaoming0625 in #237
updates by @chaoming0625 in #240
update training docs by @chaoming0625 in #241
change doc path/organization by @chaoming0625 in #242
Update advanced docs by @chaoming0625 in #243
update quickstart docs & enable jit error checking by @chaoming0625 in #244
update apis and examples by @chaoming0625 in #245
update apis and tests by @chaoming0625 in #246
version 2.2.0 by @chaoming0625 in #248
add norm and pooling & fix bugs in operators by @ztqakita in #249
Full Changelog: V2.1.12…V2.2.0
brainpy 2.1.x#
Version 2.1.12 (2022.05.17)#
Highlights#
This release brings important improvements.
We provide dozens of NumPy random sampling functions that are not supported in JAX, such as brainpy.math.random.bernoulli, brainpy.math.random.lognormal, brainpy.math.random.binomial, brainpy.math.random.chisquare, brainpy.math.random.dirichlet, brainpy.math.random.geometric, brainpy.math.random.f, brainpy.math.random.hypergeometric, brainpy.math.random.logseries, brainpy.math.random.multinomial, brainpy.math.random.multivariate_normal, brainpy.math.random.negative_binomial, brainpy.math.random.noncentral_chisquare, brainpy.math.random.noncentral_f, brainpy.math.random.power, brainpy.math.random.rayleigh, brainpy.math.random.triangular, brainpy.math.random.vonmises, brainpy.math.random.wald, and brainpy.math.random.weibull.
Make checking of numerical values efficient. Instead of direct id_tap() checking, which has a large overhead, brainpy.tools.check_erro_in_jit() is now highly efficient.
Fix JaxArray operator errors on None.
Improve oo-to-function transformation speeds.
Support io through .save_states() and .load_states().
What’s Changed#
support dtype setting in array interchange functions by @chaoming0625 in #209
fix #144: operations on None raise errors by @chaoming0625 in #210
add tests and new functions for random sampling by @c-xy17 in #213
feat: fix io for brainpy.Base by @chaoming0625 in #211
update advanced tutorial documentation by @chaoming0625 in #212
fix #149 (dozens of random samplings in NumPy) and fix JaxArray op errors by @chaoming0625 in #216
feat: efficient checking on numerical values by @chaoming0625 in #217
Full Changelog: V2.1.11…V2.1.12
Version 2.1.11 (2022.05.15)#
What’s Changed#
update apis, test and docs of numpy ops by @chaoming0625 in #202
update control flow, integrators, operators, and docs by @chaoming0625 in #205
improve oo-to-function transformation speed by @chaoming0625 in #208
Full Changelog: V2.1.10…V2.1.11
Version 2.1.10 (2022.05.05)#
What’s Changed#
update control flow APIs and Docs by @chaoming0625 in #192
doc: update docs of dynamics simulation by @chaoming0625 in #193
fix #125: add channel models and two-compartment Pinsky-Rinzel model by @chaoming0625 in #194
JIT errors do not change Variable values by @chaoming0625 in #195
Functionality improvements by @chaoming0625 in #197
update rate docs by @chaoming0625 in #198
update brainpy.dyn doc by @chaoming0625 in #199
Full Changelog: V2.1.8…V2.1.10
Version 2.1.8 (2022.04.26)#
What’s Changed#
Fix #120 by @chaoming0625 in #178
feat: brainpy.Collector supports addition and subtraction by @chaoming0625 in #179
feat: delay variables support “indices” and “reset()” function by @chaoming0625 in #180
Support reset functions in neuron and synapse models by @chaoming0625 in #181
update() function no longer needs _t and _dt by @chaoming0625 in #183
small updates by @chaoming0625 in #188
feat: easier control flows with brainpy.math.ifelse by @chaoming0625 in #189
feat: update delay couplings of DiffusiveCoupling and AdditiveCoupling by @chaoming0625 in #190
update version and changelog by @chaoming0625 in #191
Full Changelog: V2.1.7…V2.1.8
Version 2.1.7 (2022.04.22)#
What’s Changed#
synapse models support heterogeneous weights by @chaoming0625 in #170
more efficient synapse implementation by @chaoming0625 in #171
fix input models in brainpy.dyn by @chaoming0625 in #172
update README: ‘brain-py’ to ‘brainpy’ by @chaoming0625 in #174
fix: fix the updating rules in the STP model by @c-xy17 in #176
Updates and fixes by @chaoming0625 in #177
Full Changelog: V2.1.5…V2.1.7
Version 2.1.5 (2022.04.18)#
What’s Changed#
brainpy.math.random.shuffle is numpy like by @chaoming0625 in #153
update LICENSE by @chaoming0625 in #155
compatible apis of ‘brainpy.math’ with those of ‘jax.numpy’ in most modules by @chaoming0625 in #156
Important updates by @chaoming0625 in #157
Updates by @chaoming0625 in #159
Add LayerNorm, GroupNorm, and InstanceNorm as nn_nodes in normalization.py by @c-xy17 in #162
update setup.py by @chaoming0625 in #165
update synapses by @chaoming0625 in #167
get the deserved name: brainpy by @chaoming0625 in #168
update tests by @chaoming0625 in #169
Full Changelog: V2.1.4…V2.1.5
Version 2.1.4 (2022.04.04)#
What’s Changed#
fix doc parsing bug by @chaoming0625 in #127
Reorganization of brainpylib.custom_op and adding interface in brainpy.math by @ztqakita in #128
Fix: modify register_op and brainpy.math interface by @ztqakita in #130
new features about RNN training and delay differential equations by @chaoming0625 in #132
Fix #123: Add low-level operators docs and modify register_op by @ztqakita in #134
fix #133, support batch size training with offline algorithms by @chaoming0625 in #136
fix #84: support online training algorithms by @chaoming0625 in #137
fix: fix shape checking error by @chaoming0625 in #139
solve #131, support efficient synaptic computation for special connection types by @chaoming0625 in #140
feat: update the API and test for batch normalization by @c-xy17 in #142
Node is default trainable by @chaoming0625 in #143
Updates training apis and docs by @chaoming0625 in #145
fix: add dependencies and update version by @ztqakita in #147
update requirements by @chaoming0625 in #146
data pass of the Node is default SingleData by @chaoming0625 in #148
Full Changelog: V2.1.3…V2.1.4
Version 2.1.3 (2022.03.27)#
This release improves the functionality and usability of BrainPy. Core changes include:
support customization of low-level operators by using Numba
fix bugs
What’s Changed#
Provide custom operators written in numba for jax jit by @ztqakita in #122
fix DOGDecay bugs, add more features by @chaoming0625 in #124
fix bugs by @chaoming0625 in #126
Full Changelog: V2.1.2…V2.1.3
Version 2.1.2 (2022.03.23)#
This release improves the functionality and usability of BrainPy. Core changes include:
support rate-based whole-brain modeling
add more neuron models, including rate neurons/synapses
support Python 3.10
improve delay and other APIs
What’s Changed#
fix matplotlib dependency on “brainpy.analysis” module by @chaoming0625 in #110
add py3.6 test & delete multiple macos env by @ztqakita in #112
update python version by @chaoming0625 in #114
Enhance measure/input/brainpylib by @chaoming0625 in #117
fix #105: Add customize connections docs by @ztqakita in #118
fix bugs by @chaoming0625 in #119
Whole brain modeling by @chaoming0625 in #121
Full Changelog: V2.1.1…V2.1.2
Version 2.1.1 (2022.03.18)#
This release continues to update the functionality of BrainPy. Core changes include:
numerical solvers for fractional differential equations
more standard brainpy.nn interfaces
New Features#
- Numerical solvers for fractional differential equations
brainpy.fde.CaputoEuler
brainpy.fde.CaputoL1Schema
brainpy.fde.GLShortMemory
- Fractional neuron models
brainpy.dyn.FractionalFHR
brainpy.dyn.FractionalIzhikevich
- support shared_kwargs in RNNTrainer and RNNRunner
Version 2.1.0 (2022.03.14)#
Highlights#
We are excited to announce the release of BrainPy 2.1.0. This release is composed of nearly 270 commits since 2.0.2, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.
BrainPy 2.1.0 updates are focused on improving usability, functionality, and stability of BrainPy. Highlights of version 2.1.0 include:
New module brainpy.dyn for dynamics building and simulation. It is composed of many neuron models, synapse models, and others.
New module brainpy.nn for neural network building and training. It supports defining reservoir models, artificial neural networks, ridge regression training, and back-propagation through time training.
New module brainpy.datasets for convenient dataset construction and initialization.
New module brainpy.integrators.dde for numerical integration of delay differential equations.
Add more numpy-like operators in the brainpy.math module.
Add automatic continuous integration on Linux, Windows, and MacOS platforms.
Fully update brainpy documentation.
Fix bugs on brainpy.analysis and brainpy.math.autograd.
Incompatible changes#
Remove the brainpy.math.numpy module.
Remove numba requirements.
Remove matplotlib requirements.
Remove steps in brainpy.dyn.DynamicalSystem.
Remove Travis CI.
New Features#
brainpy.ddeint for numerical integration of delay differential equations; the supported methods include:
Euler
MidPoint
Heun2
Ralston2
RK2
RK3
Heun3
Ralston3
SSPRK3
RK4
Ralston4
RK4Rule38
- set default int/float/complex types
brainpy.math.set_dfloat()
brainpy.math.set_dint()
brainpy.math.set_dcomplex()
- Delay variables
brainpy.math.FixedLenDelay
brainpy.math.NeutralDelay
- Dedicated operators
brainpy.math.sparse_matmul()
- More numpy-like operators
- Neural network building: brainpy.nn
- Dynamics model building and simulation: brainpy.dyn
Version 2.0.2 (2022.02.11)#
BrainPy 2.0.2 contains important updates by Chaoming Wang:
provide the pre2post_event_prod operator
support array creation from a list/tuple of JaxArray in brainpy.math.asarray and brainpy.math.array
update brainpy.ConstantDelay, adding .latest and .oldest attributes
add brainpy.IntegratorRunner support for efficient simulation of brainpy integrators
support auto finding of RandomState when JIT-compiling SDE integrators
fix bugs in the SDE exponential_euler method
move parallel running APIs into brainpy.simulation
add brainpy.math.syn2post_mean, brainpy.math.syn2post_softmax, brainpy.math.pre2post_mean, and brainpy.math.pre2post_softmax operators
Version 2.0.1 (2022.01.31)#
Today we release BrainPy 2.0.1. This release is composed of over 70 commits since 2.0.0, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.
BrainPy 2.0.1 updates are focused on improving documentation and operators. Core changes include:
Improve brainpylib operators
Complete documentation for the programming system
Add more numpy APIs
Add jaxfwd in the autograd module
And other changes
Version 2.0.0.1 (2022.01.05)#
Add progress bar in brainpy.StructRunner
Version 2.0.0 (2021.12.31)#
Start a new version of BrainPy.
Highlight#
We are excited to announce the release of BrainPy 2.0.0. This release is composed of over 260 commits since 1.1.7, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.
BrainPy 2.0.0 updates are focused on improving the performance, usability, and consistency of BrainPy.
All the computations are migrated into JAX. Model building, simulation, training, and analysis are all based on JAX. Highlights of version 2.0.0 include:
brainpylib provides dedicated operators for brain dynamics programming.
Connection APIs in the brainpy.conn module are more efficient.
Update analysis tools for low-dimensional and high-dimensional systems in the brainpy.analysis module.
Support more general Exponential Euler methods based on automatic differentiation.
Improve the usability and consistency of the brainpy.math module.
Remove JIT compilation based on Numba.
Separate brain building from brain simulation.
Incompatible changes#
remove brainpy.math.use_backend()
remove the brainpy.math.numpy module
no longer support .run() in brainpy.DynamicalSystem (see New Features)
remove brainpy.analysis.PhasePlane (see New Features)
remove brainpy.analysis.Bifurcation (see New Features)
remove brainpy.analysis.FastSlowBifurcation (see New Features)
New Features#
- Exponential Euler method based on automatic differentiation:
  brainpy.ode.ExpEulerAuto
- Numerical-optimization-based low-dimensional analyzers:
  brainpy.analysis.PhasePlane1D
  brainpy.analysis.PhasePlane2D
  brainpy.analysis.Bifurcation1D
  brainpy.analysis.Bifurcation2D
  brainpy.analysis.FastSlow1D
  brainpy.analysis.FastSlow2D
- Numerical-optimization-based high-dimensional analyzer:
  brainpy.analysis.SlowPointFinder
- Dedicated operators in the brainpy.math module:
  brainpy.math.pre2post_event_sum
  brainpy.math.pre2post_sum
  brainpy.math.pre2post_prod
  brainpy.math.pre2post_max
  brainpy.math.pre2post_min
  brainpy.math.pre2syn
  brainpy.math.syn2post
  brainpy.math.syn2post_prod
  brainpy.math.syn2post_max
  brainpy.math.syn2post_min
- Conversion APIs in the brainpy.math module:
  brainpy.math.as_device_array()
  brainpy.math.as_variable()
  brainpy.math.as_jaxarray()
- New autograd APIs in the brainpy.math module:
  brainpy.math.vector_grad()
- Simulation runners:
  brainpy.ReportRunner
  brainpy.StructRunner
  brainpy.NumpyRunner
- Commonly used models in the brainpy.models module:
  brainpy.models.LIF
  brainpy.models.Izhikevich
  brainpy.models.AdExIF
  brainpy.models.SpikeTimeInput
  brainpy.models.PoissonInput
  brainpy.models.DeltaSynapse
  brainpy.models.ExpCUBA
  brainpy.models.ExpCOBA
  brainpy.models.AMPA
  brainpy.models.GABAa
- Naming cache clean:
  brainpy.clear_name_cache
- Safe in-place operations via the update() method and .value assignment for JaxArray
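The gather/scatter semantics behind operators such as pre2syn, syn2post and pre2post_event_sum can be illustrated with NumPy. This is a sketch under assumptions: connectivity is stored as two hypothetical index arrays `pre_ids` and `post_ids` (one entry per synapse), and the functions below are plain NumPy re-implementations written for illustration, not the brainpy.math signatures or the brainpylib kernels.

```python
import numpy as np


def pre2syn(pre_values, pre_ids):
    """Gather a presynaptic quantity onto every synapse."""
    return pre_values[pre_ids]


def syn2post(syn_values, post_ids, num_post):
    """Sum synaptic values into their postsynaptic neurons."""
    return np.bincount(post_ids, weights=syn_values, minlength=num_post)


def pre2post_event_sum(spikes, pre_ids, post_ids, weights, num_post):
    """Sum the weights of synapses whose presynaptic neuron spiked."""
    active = spikes[pre_ids]  # boolean mask over synapses
    return np.bincount(post_ids[active], weights=weights[active],
                       minlength=num_post)


# Three presynaptic and two postsynaptic neurons joined by four synapses.
pre_ids = np.array([0, 0, 1, 2])
post_ids = np.array([0, 1, 1, 0])
weights = np.array([0.5, 1.0, 2.0, 4.0])
spikes = np.array([True, False, True])
```

Here syn2post(weights, post_ids, 2) gives [4.5, 3.0], and the event-driven sum gives [4.5, 1.0] because only presynaptic neurons 0 and 2 spiked. Skipping inactive synapses is what makes the event-driven variant cheap when spikes are sparse.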
Documentation#
Complete tutorials for quickstart
Complete tutorials for dynamics building
Complete tutorials for dynamics simulation
Complete tutorials for dynamics training
Complete tutorials for dynamics analysis
Complete tutorials for API documentation
brainpy 1.1.x#
If you are using brainpy==1.x, you can find documentation, examples, and models through the following links:
Documentation: https://brainpy.readthedocs.io/en/brainpy-1.x/
Examples from papers: https://brainpy-examples.readthedocs.io/en/brainpy-1.x/
Canonical brain models: https://brainmodels.readthedocs.io/en/brainpy-1.x/
Version 1.1.7 (2021.12.13)#
fix bugs on numpy_array() conversion in brainpy.math.utils module
Version 1.1.5 (2021.11.17)#
API changes:
fix bugs on ndarray import in brainpy.base.function.py
convenient ‘get_param’ interface in brainpy.simulation.layers
add more weight initialization methods
Doc changes:
add more examples in README
Version 1.1.4#
API changes:
add .struct_run() in DynamicalSystem
add numpy_array() conversion in brainpy.math.utils module
add Adagrad, Adadelta, RMSProp optimizers
remove setting methods in brainpy.math.jax module
remove `import jax` in brainpy.__init__.py and enable JAX settings, including
  enable_x64()
  set_platform()
  set_host_device_count()
enable b=None as no bias in brainpy.simulation.layers
set int_ and float_ to 32 bits by default
remove dtype setting in Initializer constructor
Doc changes:
add optimizer in “Math Foundation”
add dynamics training docs
improve others
Version 1.1.3#
fix bugs of JAX parallel API imports
fix bugs of post_slice structure construction
update docs
Version 1.1.2#
add pre2syn and syn2post operators
add verbose and check option to Base.load_states()
fix bugs on JIT DynamicalSystem (numpy backend)
Version 1.1.1#
fix bugs on symbolic analysis: model trajectory
change absolute access to relative access in variable saving and loading
add UnexpectedTracerError hints in JAX transformation functions
Version 1.1.0 (2021.11.08)#
This release introduces a new version of BrainPy.
Highlights of core changes:
math module#
support numpy backend
support JAX backend
support jit, vmap and pmap on class objects on JAX backend
support grad, jacobian, hessian on class objects on JAX backend
support make_loop, make_while, and make_cond on JAX backend
support jit (based on numba) on class objects on numpy backend
unified numpy-like ndarray operation APIs
numpy-like random sampling APIs
FFT functions
gradient descent optimizers
activation functions
loss function
backend settings
base module#
Base for whole Version ecosystem
Function to wrap functions
Collector and TensorCollector to collect variables, integrators, nodes and others
integrators module#
class integrators for ODE numerical methods
class integrators for SDE numerical methods
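As a rough illustration of what an SDE integrator computes, here is the Euler-Maruyama scheme for dx = f(x, t) dt + g(x, t) dW in NumPy. It is a minimal sketch, not BrainPy's class-based integrator API; the function name `euler_maruyama` and the Ornstein-Uhlenbeck parameters are assumptions chosen for illustration.

```python
import numpy as np


def euler_maruyama(f, g, x0, t0, dt, steps, rng):
    """Integrate dx = f(x, t) dt + g(x, t) dW with the Euler-Maruyama scheme.

    Returns an array holding the initial state followed by `steps` updates.
    """
    xs = np.empty(steps + 1)
    xs[0] = x0
    x, t = x0, t0
    for i in range(steps):
        dw = rng.normal(0.0, np.sqrt(dt))  # Wiener increment, N(0, dt)
        x = x + f(x, t) * dt + g(x, t) * dw
        t += dt
        xs[i + 1] = x
    return xs


# Ornstein-Uhlenbeck example: dx = -x / tau dt + sigma dW (hypothetical parameters).
TAU, SIGMA = 5.0, 0.3


def ou_drift(x, t):
    return -x / TAU


def ou_diffusion(x, t):
    return SIGMA


trajectory = euler_maruyama(ou_drift, ou_diffusion, x0=1.0, t0=0.0, dt=0.1,
                            steps=100, rng=np.random.default_rng(0))
```

Setting the diffusion term to zero reduces the scheme to forward Euler for the ODE part, which makes a quick sanity check: the states then follow (1 - dt/τ)^n exactly.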
simulation module#
support modular and composable programming
support multi-scale modeling
support large-scale modeling
support simulation on GPUs
fix bugs on firing_rate()
remove _i in update() function and replace _i with _dt, meaning the dynamical system has the canonical equation form \(dx/dt = f(x, t, dt)\)
reimplement input_step and monitor_step in a more intuitive way
support setting dt at the single-object level (i.e., for a single instance of DynamicSystem)
commonly used DNN layers
weight initializations
refine synaptic connections
brainpy 1.0.x#
Version 1.0.3 (2021.08.18)#
Fix bugs on
firing rate measurement
stability analysis
Version 1.0.2#
This release continues to improve the user-friendliness.
Highlights of core changes:
Remove support for Numba-CUDA backend
The super initialization super(XXX, self).__init__() can be called anywhere (it is no longer required at the bottom of the __init__() function).
Add an output message for errors raised while running the step function.
More powerful support for Monitoring
More powerful support for running order scheduling
Remove unsqueeze() and squeeze() operations in brainpy.ops
Add reshape() operation in brainpy.ops
Improve docs for numerical solvers
Improve tests for numerical solvers
Add keywords checking in ODE numerical solvers
Add more unified operations in brainpy.ops
Support “@every” in steps and monitor functions
Fix ODE solver bugs for class-bound functions
Add build phase in Monitor
Version 1.0.1#
Fix bugs
Version 1.0.0#
NEW VERSION OF BRAINPY
Change the coding style to object-oriented programming
Systematically improve the documentation
brainpy 0.x#
Version 0.3.5#
Add ‘timeout’ in sympy solver in neuron dynamics analysis
Reconstruct and generalize phase plane analysis
Generalize the repeat mode of Network to allow different running durations between two runs
Update benchmarks
Update detailed documentation
Version 0.3.1#
Add a more flexible way for NeuState/SynState initialization
Fix bugs of “is_multi_return”
Add “hand_overs”, “requires” and “satisfies”.
Update documentation
Auto-transform range to numba.prange
Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models
Version 0.3.0#
Computation API#
Rename “brainpy.numpy” to “brainpy.backend”
Delete “pytorch”, “tensorflow” backends
Add “numba” requirement
Add GPU support
Profile setting#
Delete “backend” profile setting, add “jit”
Core systems#
Delete “autopepe8” requirement
Delete the format code prefix
Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”
Change the “ST” declaration out of “requires”
Add “repeat” mode run in Network
Change “vector-based” to “mode” in NeuType and SynType definition
Package installation#
Remove “pypi” installation; installation now relies only on “conda”
Version 0.2.4#
API changes#
Fix bugs
Version 0.2.3#
API changes#
Add “animate_1D” in visualization module
Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in inputs module
Update phase_portrait_analyzer.py
Models and examples#
Add CANN examples
Version 0.2.2#
API changes#
Redesign visualization
Redesign connectivity
Update docs
Version 0.2.1#
API changes#
Fix bugs in numba import
Fix bugs in numpy mode with scalar model
Version 0.2.0#
API changes#
For computation: numpy, numba
For model definition: NeuType, SynConn
For model running: Network, NeuGroup, SynConn, Runner
For numerical integration: integrate, Integrator, DiffEquation
For connectivity: One2One, All2All, GridFour, grid_four, GridEight, grid_eight, GridN, FixedPostNum, FixedPreNum, FixedProb, GaussianProb, GaussianWeight, DOG
For visualization: plot_value, plot_potential, plot_raster, animation_potential
For measurement: cross_correlation, voltage_fluctuation, raster_plot, firing_rate
For inputs: constant_current, spike_current, ramp_current.
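The connectivity names listed above describe common ways to build a presynaptic-by-postsynaptic connection matrix. The NumPy functions below sketch three of those ideas under assumptions: the function names, the boolean-matrix representation and the `seed` argument are hypothetical and do not reproduce the 0.2.0 API.

```python
import numpy as np


def one2one(num):
    """Neuron i connects only to neuron i."""
    return np.eye(num, dtype=bool)


def all2all(num_pre, num_post, include_self=True):
    """Every presynaptic neuron connects to every postsynaptic neuron."""
    mat = np.ones((num_pre, num_post), dtype=bool)
    if not include_self:
        # Drop i -> i links; only meaningful for recurrent connections.
        n = min(num_pre, num_post)
        mat[np.arange(n), np.arange(n)] = False
    return mat


def fixed_prob(num_pre, num_post, prob, seed=None):
    """Connect each pre-post pair independently with probability `prob`."""
    rng = np.random.default_rng(seed)
    return rng.random((num_pre, num_post)) < prob
```

Sampling every pair independently leaves the number of connections per neuron random; as their names suggest, schemes such as FixedPreNum and FixedPostNum instead fix that count.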
Models and examples#
Neuron models: HH model, LIF model, Izhikevich model
Synapse models: AMPA, GABA, NMDA, STP, GapJunction
Network models:
gamma oscillation