BrainPy documentation#
BrainPy is a highly flexible and extensible framework targeting high-performance Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:
JIT compilation and automatic differentiation for class objects.
Numerical methods for ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), fractional differential equations (FDEs), etc.
Dynamics simulation tools for various brain objects, such as neurons, synapses, networks, somas, dendrites, channels, and more.
Dynamics training tools with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.
Dynamics analysis tools for differential equations, including phase plane analysis, bifurcation analysis, linearization analysis, and fixed/slow point finding.
And many more.
For comprehensive examples of BrainPy, please see:
BrainPyExamples: https://brainpy-examples.readthedocs.io/
The code of BrainPy is open-sourced on GitHub: https://github.com/PKU-NIP-Lab/BrainPy
Installation#
BrainPy is designed to run across platforms, including Windows, GNU/Linux, and macOS. It only relies on Python libraries.
Installation with pip#
You can install BrainPy from PyPI. To do so, use:
pip install brainpy
To update BrainPy to the latest version, use:
pip install -U brainpy
If you want to install the pre-release version (the latest development version) of BrainPy, you can use:
pip install --pre brainpy
Installation from source#
If you decide not to use conda or pip, you can install BrainPy from GitHub or OpenI. To do so, use:
pip install git+https://github.com/PKU-NIP-Lab/BrainPy
# or
pip install git+https://git.openi.org.cn/OpenI/BrainPy
Dependency 1: NumPy#
In order to make BrainPy work normally, users should install several dependent Python packages.
The basic functions of BrainPy only rely on NumPy, which is easy to install through pip or conda:
pip install numpy
# or
conda install numpy
Dependency 2: JAX#
BrainPy relies on JAX. JAX is a high-performance JIT compiler which enables users to run Python code on CPU, GPU, and TPU devices. Core functionalities of BrainPy (>=2.0.0) have been migrated to the JAX backend.
Linux & MacOS#
Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. The provided binary releases of JAX for Linux and macOS systems are available at https://storage.googleapis.com/jax-releases/jax_releases.html.
To install a CPU-only version of JAX, you can run
pip install --upgrade "jax[cpu]"
If you want to install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN, if they have not already been installed. Next, run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Alternatively, you can download the preferred release “.whl” file for jaxlib and install it via pip:
pip install xxxx.whl
pip install jax
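After installing JAX and jaxlib, a quick sanity check (not part of the official installation steps, just a minimal sketch) can confirm that JAX imports correctly and report which devices it sees:
# minimal sanity check of the JAX installation
import jax

print(jax.__version__)   # the installed JAX version
print(jax.devices())     # e.g. [CpuDevice(id=0)] for a CPU-only install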
Warning
For Apple M1 macOS users, you should run your Python environment natively on Apple silicon rather than under Rosetta 2 (Intel emulation), since Rosetta 2 cannot translate jaxlib. One suggestion is to uninstall miniconda3 and install miniforge3 to manage your Python environment.
Windows#
For Windows users, JAX can be installed by the following methods:
Method 1: There are several community-supported Windows builds of JAX; please refer to the GitHub repository for more details: https://github.com/cloudhan/jax-windows-builder. In short, the provided binary releases of JAX for Windows are available at https://whls.blob.core.windows.net/unstable/index.html.
You can download the preferred release “.whl” file and install it via pip:
pip install xxxx.whl
pip install jax
Method 2: For Windows 10 or later, you can use the Windows Subsystem for Linux (WSL). The installation guide can be found in the WSL Installation Guide for Windows 10. Then you can install JAX in WSL just like the installation steps for Linux/macOS.
Method 3: You can also build JAX from source.
Other Dependencies#
To get full support for BrainPy's features, we recommend installing the following packages:
Numba: needed for some NumPy-based computations
pip install numba
# or
conda install numba
brainpylib: needed for dedicated operators
pip install brainpylib
matplotlib: required by some visualization functions; it is now recommended that users explicitly import matplotlib for visualization
pip install matplotlib
# or
conda install matplotlib
NetworkX: needed for the visualization of network training
pip install networkx
# or
conda install networkx
Simulating a Spiking Neural Network#
Spiking neural networks (SNNs) are one of the most important tools for studying brain dynamics in computational neuroscience. They simulate the biological processes of information transmission in the brain, including the change of membrane potentials, neuronal firing, and synaptic transmission. In this section, we will illustrate how to build and simulate an SNN.
Before we start, the BrainPy package should be imported:
import brainpy as bp
bp.math.set_platform('cpu')
Building an E-I balance network#
Let’s try to build an E-I balance network. The structure of an E-I balance network is as follows:

An E-I balance network is composed of two neuron groups and the synaptic connections between them. Specifically, they include:
a group of excitatory neurons (E),
a group of inhibitory neurons (I),
synaptic connections within the excitatory and inhibitory neuron groups, respectively, and
the inter-connections between these two groups.
To construct the network, we need to define these components one by one. BrainPy provides plenty of handy built-in models for brain dynamics simulation, which are contained in brainpy.dyn. Let’s choose the simplest yet most canonical neuron model, the Leaky Integrate-and-Fire (LIF) model, to build the excitatory and inhibitory neuron groups:
E = bp.dyn.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')
I = bp.dyn.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')
E.V[:] = bp.math.random.randn(3200) * 2 - 60.
I.V[:] = bp.math.random.randn(800) * 2 - 60.
When defining the LIF neuron groups, the parameters can be tuned according to users’ needs. The first parameter denotes the number of neurons; here the ratio of excitatory to inhibitory neurons is set to 4:1. V_rest denotes the resting potential, V_th denotes the firing threshold, V_reset denotes the reset value after firing, tau is the membrane time constant, and tau_ref is the duration of the refractory period. method refers to the numerical integration method to be used in the simulation.
Then the synaptic connections between these two groups can be defined as follows:
E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
E2I = bp.dyn.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
I2E = bp.dyn.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')
I2I = bp.dyn.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')
Here we use the exponential synapse model (ExpCOBA) to simulate synaptic connections. Among the parameters of the model, the first two denote the pre- and post-synaptic neuron groups, respectively. The third one refers to the connection type; in this example, we use bp.conn.FixedProb, which connects the presynaptic neurons to the postsynaptic neurons with a given probability (detailed information is available in Synaptic Connection). The following three parameters describe the dynamic properties of the synapse, and the last one is the numerical integration method, as in the LIF model.
After defining all the components, they can be combined to form a network:
net = bp.dyn.Network(E2E, E2I, I2E, I2I, E=E, I=I)
In the definition, neurons and synapses are given to the network. The excitatory and inhibitory neuron groups (E and I) are passed with names, for they will be specifically operated on in the simulation (here they will be given input currents).
We have successfully constructed an E-I balance network using BrainPy’s built-in models. On the other hand, BrainPy also enables users to flexibly customize their own dynamic models, such as neuron groups, synapses, and networks. In fact, brainpy.dyn.Network() is a simple example of customizing a network model; a sketch of this pattern is shown below. Please refer to Dynamic Simulation for more information.
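The following sketch (a hypothetical EINet class, not part of the original example) wraps the same neuron groups and synapses defined above into a subclass of brainpy.dyn.Network:
class EINet(bp.dyn.Network):
  def __init__(self):
    super(EINet, self).__init__()
    # the same components as above, registered as attributes of the network
    self.E = bp.dyn.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')
    self.I = bp.dyn.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')
    self.E2E = bp.dyn.ExpCOBA(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
    self.E2I = bp.dyn.ExpCOBA(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
    self.I2E = bp.dyn.ExpCOBA(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')
    self.I2I = bp.dyn.ExpCOBA(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')

net2 = EINet()
Such a subclass can then be wrapped into a runner in the same way as the net object defined above.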
Running a simulation#
After building an SNN, we can use it for dynamic simulation. To run a simulation, we first need to wrap the network model into a runner. Currently BrainPy provides DSRunner and ReportRunner in brainpy.dyn, which are expanded upon in the Runners tutorial. Here we use DSRunner as an example:
runner = bp.dyn.DSRunner(net,
monitors=['E.spike', 'I.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)],
dt=0.1)
To make dynamic simulation more applicable and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the spike variable in the E and I LIF neuron groups, which refers to the spiking status of each neuron, and give a constant input to both neuron groups. The time interval of numerical integration dt (with a default value of 0.1) can also be specified.
After creating the runner, we can run a simulation by calling the runner:
runner(100)
0.779956579208374
where the calling function receives the simulation time (usually in milliseconds) as the input and returns the time (seconds) spent on simulation. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to Just-In-Time Compilation for more details.
The simulation results are stored as NumPy arrays in the monitors, and can be visualized easily:
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4.5))
plt.subplot(121)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=False)
plt.subplot(122)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], show=True)

In the code above, brainpy.visualize contains some useful functions for visualizing simulation results based on the matplotlib package. Since the simulation results are stored as NumPy arrays, users can also directly use matplotlib for visualization.
Building a decision making network#
Simulating a Firing Rate Network Model#
Whole-brain modeling is a grand challenge of computational neuroscience. Simulating a whole-brain model with spiking neurons is still nearly impossible for normal users. However, by using rate-based neural mass models, in which each brain region is approximated by a few simple variables, we can build an abstract whole-brain model. In recent years, whole-brain models have been used to address a wide range of problems. In this section, we are going to talk about how to simulate a whole-brain neural mass model with BrainPy.
import brainpy as bp
import brainpy.math as bm
from brainpy.dyn import rates
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'plasma'
Neural mass model#
A neural mass model is a low-dimensional population model of spiking neural networks. It aims to describe the coarse-grained activity of large populations of neurons and synapses. Mathematically, it is a dynamical system of non-linear ODEs. A classical neural mass model is the two-dimensional Wilson–Cowan model, which tracks the activity of an excitatory population of neurons coupled to an inhibitory population. With the augmentation of such models by more realistic forms of synaptic and network interaction, they have proved especially successful in providing fits to neuroimaging data.
Here, let’s try the Wilson-Cowan model.
wc = rates.WilsonCowanModel(2,
wEE=16., wIE=15., wEI=12., wII=3.,
E_a=1.5, I_a=1.5, E_theta=3., I_theta=3.,
method='exp_euler_auto')
wc.x[:] = [-0.2, 1.]
wc.y[:] = [0.0, 1.]
runner = bp.dyn.DSRunner(wc, monitors=['x', 'y'], inputs=['input', -0.5])
runner.run(10.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.x,
plot_ids=[0, 1], legend='e', show=True)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

We can see that this model has at least two stable states.
Bifurcation diagram
With the automatic analysis module in BrainPy, we can easily inspect the bifurcation diagram of the model. Bifurcation diagrams give us an overview of how different parameters of the model affect its dynamics (for details of BrainPy's automatic analysis support, please see the introduction in Analyzing a Dynamical Model and the tutorials in Dynamics Analysis). In this case, we take x_ext as the bifurcation parameter and see how the system's behavior changes as x_ext varies.
bf = bp.analysis.Bifurcation2D(
wc,
target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
target_pars={'x_ext': [-2, 2]},
pars_update={'y_ext': 0.},
resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 40000 candidates
I am trying to filter out duplicate fixed points ...
Found 579 fixed points.
I am plotting the limit cycle ...
C:\Users\adadu\miniconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py:1868: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax_internal._check_user_dtype_supported(dtype, "asarray")


Similarly, simulating and analyzing a rate-based FitzHugh-Nagumo model is just as easy with BrainPy.
fhn = rates.FHN(1, method='exp_auto')
bf = bp.analysis.Bifurcation2D(
fhn,
target_vars={'x': [-2, 2], 'y': [-2, 2]},
target_pars={'x_ext': [0, 2]},
pars_update={'y_ext': 0.},
resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 20000 candidates
I am trying to filter out duplicate fixed points ...
Found 200 fixed points.
I am plotting the limit cycle ...
C:\Users\adadu\miniconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py:1868: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax_internal._check_user_dtype_supported(dtype, "asarray")


In this model, we find that when the external input x_ext takes values in [0.72, 1.4], the model generates limit cycles. We can verify this by simulation.
runner = bp.dyn.DSRunner(fhn, monitors=['x', 'y'], inputs=['input', 1.0])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y', show=True)

Whole-brain model#
A rate-based whole-brain model is a network model which consists of coupled brain regions. Each brain region is represented by a neural mass model that is connected to other brain regions according to the underlying network structure of the brain, also known as the connectome. To illustrate how to use BrainPy's support for whole-brain modeling, here we provide processed data at the following link:
Processed data from the ConnectomeDB of the Human Connectome Project (HCP): https://share.weiyun.com/wkPpARKy
Please download the dataset and place it in your favorite PATH.
PATH = './data/hcp.npz'
In general, a dataset for whole-brain modeling consists of the following parts:
1. A structural connectivity matrix which captures the synaptic connection strengths between brain areas. It is often derived from DTI tractography of the whole brain. The connectome is then typically parcellated according to a preferred atlas (for example the AAL2 atlas), and the number of axonal fibers connecting each brain area with every other area is counted. This number serves as an indication of the synaptic coupling strength between areas of the brain.
2. A delay matrix which is calculated from the average length of the axonal fibers connecting each brain area with another.
3. A set of functional data that can act as a target for model optimization. Resting-state fMRI offers an easy and fairly unbiased way to calibrate whole-brain models. EEG data could be used as well.
Now, let’s load the dataset.
data = bm.load(PATH)
# The structural connectivity matrix
data['Cmat'].shape
(80, 80)
# The fiber length matrix
data['Dmat'].shape
(80, 80)
# The functional data for 7 subjects
data['FCs'].shape
(7, 80, 80)
Let’s have a look at what the data looks like.
fig, axs = plt.subplots(1, 3, figsize=(15,5))
fig.subplots_adjust(wspace=0.28)
im = axs[0].imshow(data['Cmat'])
axs[0].set_title("Connection matrix")
fig.colorbar(im, ax=axs[0],fraction=0.046, pad=0.04)
im = axs[1].imshow(data['Dmat'], cmap='inferno')
axs[1].set_title("Fiber length matrix")
fig.colorbar(im, ax=axs[1],fraction=0.046, pad=0.04)
im = axs[2].imshow(data['FCs'][0], cmap='inferno')
axs[2].set_title("Empirical FC of subject 1")
fig.colorbar(im, ax=axs[2],fraction=0.046, pad=0.04)
plt.show()

Let’s first get the delay matrix according to the fiber length matrix, the signal transmission speed between areas, and the numerical integration step dt. Here, we assume the axonal transmission speed is 20 and the simulation time step dt=0.1 ms.
signal_speed = 20.
# the number of delay steps
delay_mat = data['Dmat'] / signal_speed / bm.get_dt()
delay_mat = bm.asarray(delay_mat, dtype=bm.int_)
The connectivity matrix can be obtained directly from the structural connectivity matrix multiplied by a global coupling strength parameter gc.
gc = 1.
conn_mat = bm.asarray(data['Cmat'] * gc)
# It is necessary to exclude the self-connections
bm.fill_diagonal(conn_mat, 0)
We are now ready to instantiate a whole-brain model with the neural mass model and the dataset processed before.
class WholeBrainNet(bp.dyn.Network):
  def __init__(self, Cmat, Dmat):
    super(WholeBrainNet, self).__init__()
    self.fhn = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01,
                         name='fhn', method='exp_auto')
    self.syn = rates.DiffusiveCoupling(self.fhn.x, self.fhn.x, self.fhn.input,
                                       conn_mat=Cmat,
                                       delay_steps=Dmat.astype(bm.int_),
                                       initial_delay_data=bp.init.Uniform(0, 0.05))

  def update(self, _t, _dt):
    self.syn.update(_t, _dt)
    self.fhn.update(_t, _dt)
net = WholeBrainNet(conn_mat, delay_mat)
runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
6.346501350402832
The simulated results can be used to estimate the functional correlation matrix.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
fc = bp.measure.functional_connectivity(runner.mon['fhn.x'])
ax = axs[0].imshow(fc)
plt.colorbar(ax, ax=axs[0])
axs[1].plot(runner.mon.ts, runner.mon['fhn.x'][:, ::5], alpha=0.8)
plt.tight_layout()
plt.show()

We can compute the element-wise Pearson correlation between the functional connectivity matrices of the simulated data and the empirical data to estimate how well the model captures the inter-areal functional correlations found in empirical resting-state recordings.
scores = [bp.measure.matrix_correlation(fc, fcemp)
for fcemp in data['FCs']]
print("Correlation per subject:", [f"{s:.2}" for s in scores])
print("Mean FC/FC correlation: {:.2f}".format(bm.mean(bm.asarray(scores))))
Correlation per subject: ['0.58', '0.45', '0.55', '0.49', '0.54', '0.5', '0.45']
Mean FC/FC correlation: 0.51
Training a Recurrent Neural Network#
In recent years, training dynamical systems from data or tasks has provided important insights into brain functions. To support this, BrainPy provides various interfaces to help users train dynamical systems.
import brainpy as bp
import brainpy.math as bm
bm.enable_x64()
bm.set_dfloat(bm.float64)
# bm.set_platform('cpu')
import matplotlib.pyplot as plt
General usage#
In BrainPy, we provide a general interface to build neural networks, supporting feedforward, recurrent, and feedback connections.
Model Building#
In general, each model is treated as a node. Based on node operations, like feedforward >>, feedback <<, etc., we can create an arbitrary node graph. For example,
feedforward_net = data >> reservoir >> readout
creates a simple network in which data first feeds forward to the reservoir node, and then the output of reservoir is read out by a readout node. Further, to create a feedback connection from readout to reservoir, we can use
feedback_net = reservoir << readout
After merging it with the previously defined feedforward_net, we can create a network with both feedforward and feedback connections:
model = feedforward_net & feedback_net
Model running & training#
Moreover, BrainPy provides various interfaces for network running and training, including the commonly used ridge regression method, the FORCE learning method, and the back-propagation through time (BPTT) algorithm. Users can create these runners and trainers with the following code:
runner = bp.nn.RNNRunner(model, ...)
or,
trainer = bp.nn.RidgeTrainer(model, ...)
trainer = bp.nn.FORCELearning(model, ...)
trainer = bp.nn.BPTT(model, ...)
Below, we demonstrate this support with several examples.
Echo state network#
We first illustrate the training interface of BrainPy using an echo state network.
For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.

# create the components we need
i = bp.nn.Input(3)
r = bp.nn.Reservoir(400, spectral_radius=1.4)
o = bp.nn.LinearReadout(3)
# create the model we need
model = i >> r >> o
model.plot_node_graph(fig_size=(5, 5), node_size=2000)

We use this network to predict a chaotic time series, the Lorenz attractor. Particularly, we expect the network to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the prediction length ahead.
dt = 0.01
data = bp.datasets.lorenz_series(100, dt=dt)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['x'].flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['y'].flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['z'].flatten()))
plt.ylabel('z')
plt.show()

def get_subset(data, start, end):
  res = {'x': data['x'][start: end],
         'y': data['y'][start: end],
         'z': data['z'][start: end]}
  res = bm.hstack([res['x'], res['y'], res['z']])
  return res.reshape((1, ) + res.shape)
To accomplish this task, we use the ridge regression method to train the network. Before that, we first initialize the network with a batch size of 1 and then construct a ridge regression trainer.
model.initialize(num_batch=1)
trainer = bp.nn.RidgeTrainer(model, beta=1e-6)
We warm up the network for 20 ms.
warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)
outs.shape
(1, 2000, 3)
The training data is the time series from 20 ms to 80 ms. We want the network to be able to forecast 1 time step ahead.
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)
trainer.fit([x_train, y_train])
Then we test the trained network with the next 20 ms.
x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))
predictions = trainer.predict(x_test)
bp.losses.mean_squared_error(y_test, predictions)
DeviceArray(0.00014552, dtype=float64)
def plot_difference(truths, predictions):
  truths = truths.numpy()
  predictions = predictions.numpy()
  plt.subplot(311)
  plt.plot(truths[0, :, 0], label='Ground Truth')
  plt.plot(predictions[0, :, 0], label='Prediction')
  plt.ylabel('x')
  plt.legend()
  plt.subplot(312)
  plt.plot(truths[0, :, 1], label='Ground Truth')
  plt.plot(predictions[0, :, 1], label='Prediction')
  plt.ylabel('y')
  plt.legend()
  plt.subplot(313)
  plt.plot(truths[0, :, 2], label='Ground Truth')
  plt.plot(predictions[0, :, 2], label='Prediction')
  plt.ylabel('z')
  plt.legend()
  plt.show()
plot_difference(y_test, predictions)

We can make the task harder by forecasting 10 time steps ahead.
warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])
x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)
plot_difference(y_test, predictions)

Or forecast 100 time steps ahead.
warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)
x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])
x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)
plot_difference(y_test, predictions)

As you can see, forecasting more time steps ahead makes the learning more difficult.
Next generation RC#
(Gauthier et al., Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model using nonlinear vector autoregression (NVAR).
(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using a linear weight of time-delay states of the time series data and nonlinear functionals of this data.
In BrainPy, we can easily implement this kind of network. Here, let's try to use NG-RC to infer the \(z\) variable from the \(x\) and \(y\) variables. This task is important for applications where it is possible to obtain high-quality information about a dynamical variable in a laboratory setting, but not in field deployment.
Let’s first initialize the data we need.
dt = 0.02
t_warmup = 10. # ms
t_train = 20. # ms
t_test = 50. # ms
num_warmup = int(t_warmup / dt) # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)
lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test,
dt=dt,
inits={'x': 17.67715816276679,
'y': 12.931379185960404,
'z': 43.91404334248268})
def get_subset(data, start, end):
  res = {'x': data['x'][start: end],
         'y': data['y'][start: end],
         'z': data['z'][start: end]}
  X = bm.hstack([res['x'], res['y']])
  X = X.reshape((1,) + X.shape)
  Y = res['z']
  Y = Y.reshape((1, ) + Y.shape)
  return X, Y
X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup)
X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
X_test, Y_test = get_subset(lorenz_series, num_warmup + num_train, num_warmup + num_train + num_test)
The network architecture is the same as that of the echo state network above: an input node, a reservoir node, and an output node. To accomplish this task, (Gauthier et al., Nature Communications, 2021) used 4 delayed history states with a stride of 5, together with their quadratic polynomial monomials. Therefore, we create the network as:
i = bp.nn.Input(2)
r = bp.nn.NVAR(delay=4, order=2, stride=5)
o = bp.nn.LinearReadout(1, trainable=True)
model = i >> r >> o
model.initialize(num_batch=1)
We also train the network using the ridge regression method.
trainer = bp.nn.RidgeTrainer(model, beta=0.05)
# warm-up
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
# training
trainer.fit([X_train, Y_train])
# prediction
outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
Warmup NMS: 10729.250973138222
Prediction NMS: 0.3374043793562189
X_test = bm.asarray(X_test).numpy()[0]
Y_test = bm.asarray(Y_test).numpy().flatten()
outputs = bm.asarray(outputs).numpy().flatten()
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(X_test[:, 0], color='b')
plt.ylabel('x')
plt.subplot(312)
plt.plot(X_test[:, 1], color='b')
plt.ylabel('y')
plt.subplot(313)
plt.plot(Y_test, color='b', label='Ground Truth')
plt.plot(outputs, color='r', label='Prediction')
plt.ylabel('y')
plt.legend()
plt.show()

Recurrent neural network#
In recent years, artificial recurrent neural networks trained with back-propagation through time (BPTT) have been a useful tool to study the network mechanisms of brain functions. To support training networks with BPTT, BrainPy provides the brainpy.nn.BPTT method.
Here, we demonstrate how to train an artificial recurrent neural network using a white noise integration task. In this task, we want our trained RNN model to be able to integrate white noise. For example, suppose we have a time series of noise data,
noises = bm.random.normal(0, 0.2, size=10)
plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()

Now, we want a model that can integrate the noise, i.e., compute bm.cumsum(noises) * dt:
dt = 0.1
integrals = bm.cumsum(noises) * dt
plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()

Here, we first define a task which generates the input data and the target integration results.
from functools import partial
dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128
@partial(bm.jit,
         dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
         static_argnames=['batch_size'])
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
  # Create the white noise input
  sample = bm.random.normal(size=(batch_size, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  samples = bm.random.normal(size=(batch_size, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets
def train_data():
  for _ in range(100):
    yield build_inputs_and_targets(batch_size=num_batch)
Then we create and initialize the model. Note that here we want the model to train its initial state, so we set state_trainable=True for the VanillaRNN instance.
model = (
bp.nn.Input(1)
>>
bp.nn.VanillaRNN(100, state_trainable=True)
>>
bp.nn.Dense(1)
)
model.initialize(num_batch=num_batch)
The brainpy.nn.BPTT trainer receives a loss function setting and an optimizer setting. The loss function can be selected from the brainpy.losses module, or it can be a callable function that receives the (predictions, targets) arguments. The optimizer setting must be an instance of brainpy.optim.Optimizer.
Here we define a loss function that uses the mean squared error (MSE) to measure the error between the targets and the predictions. We also apply an L2 regularization.
# define loss function
def loss(predictions, targets, l2_reg=2e-4):
  mse = bp.losses.mean_squared_error(predictions, targets)
  l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
  return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.nn.BPTT(model,
loss=loss,
optimizer=opt,
max_grad_norm=5.0)
# train the model
trainer.fit(train_data,
num_batch=num_batch,
num_train=30,
num_report=500)
Train 500 steps, use 9.3755 s, train loss 0.03093
Train 1000 steps, use 6.7661 s, train loss 0.0275
Train 1500 steps, use 6.9309 s, train loss 0.02998
Train 2000 steps, use 6.6827 s, train loss 0.02409
Train 2500 steps, use 6.6528 s, train loss 0.02289
Train 3000 steps, use 6.6663 s, train loss 0.02187
The training losses are recorded in the .train_losses attribute.
plt.figure(figsize=(8, 3))
plt.plot(trainer.train_losses.numpy())
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()

Finally, let’s try the trained network, and test whether it can generate the correct integration results.
model.initialize(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()

Further reading#
For more about node specifications, please see Node Specification.
For details about node operations, please see Node Operations.
To customize a node, please see Node Customization.
For more examples of training recurrent neural networks, please see BrainPy Examples.
Analyzing a Dynamical Model#
In BrainPy, defined models can not only be used for simulation, but also for automatic dynamics analysis.
BrainPy provides rich interfaces to support analysis, including
phase plane analysis, bifurcation analysis, and fast-slow bifurcation analysis for low-dimensional systems;
linearization analysis and fixed/slow point finding for high-dimensional systems.
Here we will introduce brief examples of 1-D bifurcation analysis and 2-D phase plane analysis. For more details and more examples, please refer to the tutorials of dynamics analysis.
import brainpy as bp
bp.math.set_platform('cpu')
bp.math.enable_x64() # Dynamics analysis in BrainPy requires 64-bit computation
Example 1: bifurcation analysis of a 1D model#
Here, we demonstrate how to perform a bifurcation analysis through a one-dimensional neuron model.
Let's try to analyze how the external input influences the dynamics of the Exponential Integrate-and-Fire (ExpIF) model. The ExpIF model is a one-variable neuron model whose dynamics is defined by:
$$
\tau {\dot {V}}= - (V - V_\mathrm{rest}) + \Delta_T \exp(\frac{V - V_T}{\Delta_T}) + RI \\
\mathrm{if}\, \, V > \theta, \quad V \gets V_\mathrm{reset}
$$
We can analyze the change of \({\dot {V}}\) with respect to \(V\). First, let’s generate an ExpIF model using the pre-defined modules in brainpy.dyn:
expif = bp.dyn.ExpIF(1, delta_T=1.)
The default values of the other parameters can be accessed directly by their names:
expif.V_rest, expif.V_T, expif.R, expif.tau
(-65.0, -59.9, 1.0, 10.0)
After defining the model, we can use it for bifurcation analysis.
bif = bp.analysis.Bifurcation1D(
model=expif,
target_vars={'V': [-70., -55.]},
target_pars={'I_ext': [0., 6.]},
resolutions=0.01
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...

In the Bifurcation1D analyzer, model refers to the model to be analyzed (essentially the analyzer will access the derivative function in the model), target_vars denotes the target variables, target_pars denotes the changing parameters, and resolutions determines the resolution of the analysis.
In the image above, there are two lines that “merge” together to form a bifurcation. The dots making up the lines refer to the fixed points of \(\mathrm{d}V/\mathrm{d}t\). On the left of the bifurcation point (where two lines merge together), there are two fixed points where \(\mathrm{d}V/\mathrm{d}t = 0\) given each external input \(I_\mathrm{ext}\). One of them is a stable point, and the other is an unstable one. When \(I_\mathrm{ext}\) increases, the two fixed points move closer to each other, overlap, and finally disappear.
Bifurcation analysis provides insights into the dynamics of the model, for it indicates the number of stable states and how they change with respect to different parameters.
Example 2: phase plane analysis of a 2D model#
Besides bifurcation analysis, another important tool is phase plane analysis, which displays the trajectory of the variable point in the vector field. Let's take the FitzHugh–Nagumo (FHN) neuron model as an example. The dynamics of the FHN model, in its standard form, is given by:
$$
\dot{V} = V - \frac{V^{3}}{3} - w + I_\mathrm{ext} \\
\tau \dot{w} = V + a - b w
$$
Users can easily define an FHN model, which is also provided as a built-in model by BrainPy:
fhn = bp.dyn.FHN(1)
Because there are two variables, \(V\) and \(w\), in the FHN model, we shall use 2-D phase plane analysis to visualize how these two variables change over time.
analyzer = bp.analysis.PhasePlane2D(
model=fhn,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
pars_update={'I_ext': 0.8},
resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
analyzer.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 866 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 V=-0.2738719079879798, w=0.5329731346879486 is a unstable node.
I am plotting the trajectory ...

In the PhasePlane2D analyzer, the parameters model, target_vars, and resolutions are the same as those in Bifurcation1D. pars_update specifies the parameters to be updated during analysis. After defining the analyzer, users can visualize the nullclines, the vector field, the fixed points, and the trajectory in the image. The phase plane gives users an intuitive interpretation of the changes of \(V\) and \(w\) guided by the vector field (violet arrows).
Further reading#
For more details about how to perform bifurcation analysis and phase plane analysis, please see the tutorial of Low-dimensional Analyzers.
A good example of phase plane analysis and bifurcation analysis is the decision-making model; please see the tutorial in Analysis of a Decision-making Model.
If you want to know how to analyze the slow points (or fixed points) of your high-dimensional dynamical models, please see the tutorial of High-dimensional Analyzers.
Math Basics#
brainpy.math
Overview#
The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just-in-time” for execution. Subsequently, such transformed code can run at native machine-code speed!
Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions, while most computational neuroscience models have too many parameters and variables to manage using functions alone. In contrast, object-oriented programming (OOP) based on class in Python makes coding more readable, controllable, flexible, and modular. Therefore, it is necessary to support JIT compilation on class objects for brain modeling.
In order to provide a platform that satisfies the needs of brain dynamics programming, we provide the brainpy.math module.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
import numpy as np
Why use brainpy.math?#
Specifically, brainpy.math makes the following contributions:
1. Numpy-like ndarray.#
Python users are familiar with NumPy, especially its ndarray. JAX has similar ndarray structures and operations. However, several basic features are fundamentally different from the NumPy ndarray. For example, a JAX ndarray does not support in-place mutating updates, like x[i] += y. To overcome these drawbacks, brainpy.math provides JaxArray, which can be used in the same way as a NumPy ndarray.
# ndarray in "numpy"
a = np.arange(5)
a
array([0, 1, 2, 3, 4])
a[0] += 5
a
array([5, 1, 2, 3, 4])
# ndarray in "brainpy.math"
b = bm.arange(5)
b
JaxArray([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
JaxArray([5, 1, 2, 3, 4], dtype=int32)
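As a brief aside (using plain jax.numpy, not part of the example above), a raw JAX array cannot be mutated in place; updates have to go through JAX's functional .at syntax:
# ndarray in raw "jax.numpy": in-place mutation is not allowed
import jax.numpy as jnp
c = jnp.arange(5)
# c[0] += 5          # raises an error: JAX arrays are immutable
c = c.at[0].add(5)   # the functional way: returns a new array with the update
c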
For more details, please see the Tensors tutorial.
2. Numpy-like random sampling.#
JAX has its own style of generating random numbers, which is very different from that of the original NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling, just like the numpy.random module. For example:
# random sampling in "numpy"
np.random.seed(12345)
np.random.random(5)
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
np.random.normal(0., 2., 5)
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
# random sampling in "brainpy.math.random"
bm.random.seed(12345)
bm.random.random(5)
JaxArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ], dtype=float32)
bm.random.normal(0., 2., 5)
JaxArray([-1.5375282, -0.5970201, -2.272839 , 3.233081 , -0.2738593], dtype=float32)
For more details, please see the Tensors tutorial.
3. JAX transformations on class objects.#
OOP is the essence of Python. However, JAX's excellent transformations (like JIT compilation) only support pure functions. To make them work with object-oriented coding in brain dynamics programming, brainpy.math extends JAX transformations to Python classes.
Example 1: JIT compilation performed on class objects.
class LogisticRegression(bp.Base):
  def __init__(self, dimension):
    super(LogisticRegression, self).__init__()
    # parameters
    self.dimension = dimension
    # variables
    self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

  def __call__(self, X, Y):
    u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
    self.w[:] = self.w - u
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr1 = LogisticRegression(num_dim)
%timeit lr1(points, labels)
255 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr2 = bm.jit(LogisticRegression(num_dim))
%timeit lr2(points, labels)
162 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Example 2: Autograd performed on variables of a class object.
class Linear(bp.Base):
  def __init__(self, num_hidden, num_input, **kwargs):
    super(Linear, self).__init__(**kwargs)
    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden
    # variables
    self.w = bm.random.random((num_input, num_hidden))
    self.b = bm.zeros((num_hidden,))

  def __call__(self, x):
    r = x @ self.w + self.b
    return r.mean()
l = Linear(num_hidden=3, num_input=2)
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
[0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32),
DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32))
What is the difference between brainpy.math and other frameworks?#
brainpy.math is not intended to be a reimplementation of the API of any other framework. All we are trying to do is to make a better brain dynamics programming framework for Python users.
However, there are important differences between brainpy.math and other frameworks. As stated above, JAX and many JAX-based frameworks follow a functional programming paradigm. Applying this kind of coding style to brain dynamics models becomes a huge problem due to the overwhelmingly large number of parameters and variables. In contrast, brainpy.math allows an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is Objax, which also supports OOP based on JAX, but it is more suited to the deep learning domain and cannot be used directly for brain dynamics programming.
Tensors and Variables#
In this section, we will briefly introduce two basic and important data structures: tensors and variables. They form the foundation for the mathematical operations of brain dynamics programming (BDP) in BrainPy.
Tensors#
Definition and Attributes#
A tensor is a data structure that organizes algebraic objects in a multidimensional vector space. Simply speaking, in BrainPy, a tensor is a multidimensional array that contains the same type of data, most commonly of a numeric or boolean type.
The dimensions of an array are called axes. In the following illustration, the 1-D array ([7, 2, 9, 10]) has only one axis. There are 4 elements in this axis, so the shape of the array is (4,).
By contrast, the 2-D array in the illustration has 2 axes. The first axis is of length 2 and the second of length 3. Therefore, the shape of the 2-D array is (2, 3).
Similarly, the 3-D array has 3 axes, with dimensions (4, 3, 2) along each axis, respectively.

To enable tensor operations, users should import the brainpy.math module:
import brainpy.math as bm
# bm.set_platform('cpu')
t1 = bm.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]],
[[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])
t1
JaxArray([[[ 0, 1, 2, 3],
[ 1, 2, 3, 4],
[ 4, 5, 6, 7]],
[[ 0, 0, 0, 0],
[-1, 1, -1, 1],
[ 2, -2, 2, -2]]], dtype=int32)
Here we create a 3-dimensional tensor with the shape (2, 3, 4) and the data type int32. Tensors created by brainpy.math are stored as JaxArray objects, so that their future operations can be accelerated by just-in-time (JIT) compilation.
A tensor has several important attributes:
.ndim: the number of axes (dimensions) of the tensor.
.shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, the shape will be (n, m). The length of the shape tuple is therefore the number of axes, ndim.
.size: the total number of elements of the tensor. This is equal to the product of the elements of shape.
.dtype: an object describing the type of the elements in the tensor. One can create or specify dtypes using standard Python types.
print('t1.ndim: {}'.format(t1.ndim))
print('t1.shape: {}'.format(t1.shape))
print('t1.size: {}'.format(t1.size))
print('t1.dtype: {}'.format(t1.dtype))
t1.ndim: 3
t1.shape: (2, 3, 4)
t1.size: 24
t1.dtype: int32
Below we will give a few examples of tensor operations that are commonly used in brain dynamics programming. For more details about tensor operations, please refer to the tensor tutorial.
Creating a tensor#
t2 = bm.arange(4)
# t2: JaxArray([0, 1, 2, 3], dtype=int32)
t3 = bm.ones((2, 4)) * 1.5
# t3: JaxArray([[1.5, 1.5, 1.5, 1.5],
# [1.5, 1.5, 1.5, 1.5]], dtype=float32)
Tensor operations#
# indexing and slicing
t3[1]
# DeviceArray([1.5, 1.5, 1.5, 1.5], dtype=float32)
t3[:, 2:]
# DeviceArray([[1.5, 1.5],
# [1.5, 1.5]], dtype=float32)
DeviceArray([[1.5, 1.5],
[1.5, 1.5]], dtype=float32)
# algebraic operations
t2 + t3[0]
# JaxArray([1.5, 2.5, 3.5, 4.5], dtype=float32)
t3[0] / t1[0, 1]
# DeviceArray([1.5 , 0.75 , 0.5 , 0.375], dtype=float32)
# broadcasting
t2 + 3
# JaxArray([3, 4, 5, 6], dtype=int32)
t2 + t3
# JaxArray([[1.5, 2.5, 3.5, 4.5],
# [1.5, 2.5, 3.5, 4.5]], dtype=float32)
JaxArray([[1.5, 2.5, 3.5, 4.5],
[1.5, 2.5, 3.5, 4.5]], dtype=float32)
# some functions
bm.dot(t2, t3.T)
# JaxArray([9., 9.], dtype=float32)
bm.max(t1, axis=2)
# JaxArray([[3, 4, 7],
# [0, 1, 2]], dtype=int32)
t3.flatten()
# JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)
JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)
In BrainPy, tensors can be used to store parameters related to dynamical models. For example, if we define a group of Leaky Integrate-and-Fire (LIF) neurons and wish to assign each neuron a different time constant \(\tau\), we can generate a tensor containing an array of time constants.
n = 6 # assume there are 6 LIF neurons
tau = bm.random.randn(n)*2. + 20.
tau
JaxArray([18.485964, 19.765427, 15.078529, 21.210836, 17.134335,
21.495173], dtype=float32)
Through the code above, a group of time constants is generated from a normal distribution with a mean of 20 and a standard deviation of 2.
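Such a tensor can then be used where a scalar parameter is expected. A minimal sketch, assuming the LIF constructor introduced earlier also accepts per-neuron parameter arrays:
import brainpy as bp

# hypothetical usage: each of the 6 LIF neurons gets its own time constant
group = bp.dyn.LIF(n, V_rest=-60., V_th=-50., V_reset=-60., tau=tau, tau_ref=5.)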
Variables#
We have talked about the definition, operations, and applications of tensors in BrainPy. There are some situations, however, where tensors are not applicable. Due to JIT compilation, once a tensor is given to the JIT compiler, the values inside the tensor cannot be changed. This gives rise to severe limitations, because some properties of a dynamical system, such as the membrane potential, change dynamically over time. Therefore, we need a new data structure to store such dynamic variables, and that is brainpy.math.Variable.
brainpy.math.Variable#
brainpy.math.Variable is a pointer referring to a tensor. The tensor is stored as its value. The data in a Variable can be changed during JIT compilation. If a tensor is labeled as a Variable, it means that it is a dynamical variable that changes over time.
To create a variable, or to change a tensor into a variable, users just need to wrap the tensor into brainpy.math.Variable:
v = bm.Variable(t2)
v
Variable([0, 1, 2, 3], dtype=int32)
Note that the array is contained in a “Variable” instead of a “JaxArray”.
Note
Tensors that are not marked as Variables will be JIT compiled as static data. In a JIT-compilation environment, modifications of such tensors will not take effect.
Users can access the value in a Variable through its attribute .value:
v.value
DeviceArray([0, 1, 2, 3], dtype=int32)
Since the data inside a Variable is a tensor, common operations on tensors can be directly grafted to Variables.
In-place updating#
Though the operations are the same, there are some requirements for updating a Variable. If we directly operate on a Variable, the returned data will be a tensor, not a Variable.
v2 = v + 2
v2
JaxArray([2, 3, 4, 5], dtype=int32)
To update the Variable, users are required to use in-place updating, which only modifies the value inside the Variable but does not change the reference pointing to the Variable. In-place updating operations include:
1. Indexing and slicing
Indexing: v[i] = a
Slicing: v[i:j] = b
Slicing specific values: v[[1, 3]] = c
Slicing all values: v[:] = d or v[...] = e
For more details, please refer to Array Objects Indexing.
v[0] = 10
v[1:3] = 9
v
Variable([10, 9, 9, 3], dtype=int32)
2. Augmented assignment
+= (add)
-= (subtract)
/= (divide)
*= (multiply)
//= (floor divide)
%= (modulo)
**= (power)
&= (and)
|= (or)
^= (xor)
<<= (left shift)
>>= (right shift)
v -= 3
v <<= 1
v
Variable([14, 12, 12, 0], dtype=int32)
3. .value assignment
v.value = bm.arange(4)
v
Variable([0, 1, 2, 3], dtype=int32)
.value assignment directly accesses the data stored in the JaxArray. When using .value, the new data should be of the same type and shape as the original ones.
try:
  v.value = bm.array([1., 1., 1., 0.])
except Exception as e:
  print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
4. .update() method
This method will also check if the new data is of the same type and shape as the original ones.
v.update(bm.array([3, 4, 5, 6]))
v
Variable([3, 4, 5, 6], dtype=int32)
For more details, such as the subtypes of Variables and more information about in-place updating, please see the advanced tutorial for Variables.
Just-In-Time Compilation#
One of the core ideas of BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables Python code to be compiled into machine code “just-in-time” for execution. Subsequently, such transformed code can run at native machine-code speed, which will not only compensate for the time spent on code transformation but also save more time. Therefore, it is necessary to understand how to code in a JIT-compatible environment.
This section will briefly introduce JIT compilation and its relation to BrainPy. For more details such as the JIT mechanism in BrainPy, please refer to the advanced Compilation tutorial.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
JIT Compilation for Functions#
To take advantage of JIT compilation, users just need to wrap their customized functions or objects into bm.jit() to instruct BrainPy to transform the Python code into machine code.
Take pure functions as an example. Here we implement a Gaussian Error Linear Unit (GELU) function:
def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
Let’s first try to run the function without JIT.
x = bm.random.random(100000)
%timeit gelu(x)
295 µs ± 3.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
After JIT compilation, the function significantly speeds up.
gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
JIT Compilation for Objects#
JIT compilation for functions is not enough for brain dynamics programming, since a multitude of dynamic variables and differential equations in a large system would make computation surprisingly complicated. Therefore, BrainPy enables JIT compilation to be performed on class objects, as long as users comply with the following rules:
The class object must be a subclass of brainpy.Base.
Dynamically changed variables must be labeled as brainpy.math.Variable.
Variable updating must be accomplished by in-place operations.
Below is a simple example of a logistic regression classifier. When wrapped into bm.jit(), the class object will be JIT compiled.
class LogisticRegression(bp.Base):
  def __init__(self, dimension):
    super(LogisticRegression, self).__init__()
    # parameters
    self.dimension = dimension
    # variables
    self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

  def __call__(self, X, Y):
    u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
    self.w.value = self.w - u
In this example, the model weights (self.w) will be modified during training, so they are marked as a bm.Variable. If not, in the compilation phase, all variables accessed through self. that are not instances of bm.Variable will be compiled as static constants.
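To make this distinction concrete, the following sketch (a hypothetical Counter class, not from the original tutorial) marks one attribute as a Variable and leaves another as a plain tensor; after JIT compilation, only the Variable keeps changing across calls:
class Counter(bp.Base):
  def __init__(self):
    super(Counter, self).__init__()
    self.step = bm.ones(1)                 # a plain tensor: compiled as a static constant
    self.count = bm.Variable(bm.zeros(1))  # a Variable: dynamic state updated in place

  def __call__(self):
    self.count += self.step
    return self.count

counter = bm.jit(Counter())
counter()  # first call  -> value [1.]
counter()  # second call -> value [2.]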
import time
def benchmark(model, points, labels, num_iter=30, name=''):
  t0 = time.time()
  for i in range(num_iter):
    model(points, labels)
  print(f'{name} used time {time.time() - t0} s')
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# without JIT
lr1 = LogisticRegression(num_dim)
benchmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 10.024710893630981 s
# with JIT
lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)
benchmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 5.015154838562012 s
From the above example, we can appreciate the acceleration of JIT compilation. This example, however, is too simplified to show the great difference between running with and without JIT. In fact, in a large brain model, the acceleration brought by JIT compilation is usually far more significant.
Automatic JIT Compilation in Runners#
In a large dynamical system where a large number of neurons and synapses are defined, it would be tedious to explicitly wrap every object into bm.jit(). Fortunately, in most cases users do not need to call bm.jit() themselves, as BrainPy performs JIT compilation automatically.
BrainPy provides a brainpy.Runner class that is inherited by the various runners used in simulation, training, and integration. When initializing a runner, it receives a parameter named jit, which is set to True by default. This means that the Runner will automatically JIT compile the target object as long as it is wrapped into the runner.
For example, when users perform a dynamic simulation of an HH model, they first need to wrap the model into a simulation runner:
model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.))
runner(1000) # running 1000 ms
0.6139698028564453
Here the model is wrapped into a runner, and it will be JIT compiled during simulation.
If users do not want to use JIT compilation (JIT compilation prohibits Python debugging), they can turn it off by setting jit=False:
model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.), jit=False)
runner(1000)
258.76088523864746
The output is the time (s) spent on simulation. We can see that the simulation is much slower without JIT compilation.
Besides simulation, runners are also used by integrators and trainers. For more details, please refer to the tutorial of runners.
Control Flows#
Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.
Python has two types of control structures:
Selection: used for decisions and branching.
Repetition: used for looping, i.e., repeating a piece of code multiple times.
In this section, we are going to talk about how to build effective control flows with BrainPy and JAX.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
1. Selection#
In Python, the selection statements are also known as Decision control statements or branching statements. The selection statement allows a program to test several conditions and execute instructions based on which condition is true. The commonly used control statements include:
if-else
nested if
if-elif-else
Non-Variable-based control statements#
BrainPy (based on JAX) allows you to write control flows normally, just like familiar Python programs, when the conditional statement depends on non-Variable instances. For example,
class OddEven(bp.Base):
  def __init__(self, type_=1):
    super(OddEven, self).__init__()
    self.type_ = type_
    self.a = bm.Variable(bm.zeros(1))

  def __call__(self):
    if self.type_ == 1:
      self.a += 1
    elif self.type_ == 2:
      self.a -= 1
    else:
      raise ValueError(f'Unknown type: {self.type_}')
    return self.a
In the above example, the condition in the if statement
relies on a scalar, which is not an instance of brainpy.math.Variable. In this case, the conditional statements can be arbitrarily complex. You can write your models with normal Python code, and these models will work very well with JIT compilation.
model = bm.jit(OddEven(type_=1))
model()
Variable([1.], dtype=float32)
model = bm.jit(OddEven(type_=2))
model()
Variable([-1.], dtype=float32)
try:
model = bm.jit(OddEven(type_=3))
model()
except ValueError as e:
print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
Variable-based control statements#
However, if the condition in an if ... else ...
statement relies on instances of brainpy.math.Variable, writing Pythonic control flows will cause errors when using JIT compilation.
class OddEvenCauseError(bp.Base):
def __init__(self):
super(OddEvenCauseError, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
if self.rand < 0.5: self.a += 1
else: self.a -= 1
return self.a
wrong_model = bm.jit(OddEvenCauseError())
try:
wrong_model()
except Exception as e:
print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: This problem may be caused by several ways:
1. Your if-else conditional statement relies on instances of brainpy.math.Variable.
2. Your if-else conditional statement relies on functional arguments which do not set in "static_argnames" when applying JIT compilation. More details please see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
3. The static variables which set in the "static_argnames" are provided as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].
To perform conditional statements on Variable instances, we need structured control flow syntax. Specifically, BrainPy provides several options (based on JAX):
brainpy.math.where: returns element-wise conditional comparison results.
brainpy.math.ifelse: conditional statements of if-else, if-elif-else, etc., for a scalar-typed value.
brainpy.math.where
#
where(condition, x, y)
function returns elements chosen from x or y depending on condition. It can perform well on scalars, vectors, and high-dimensional arrays.
a = 1.
bm.where(a < 0, 0., 1.)
JaxArray(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
JaxArray([1., 0., 0., 1., 1.], dtype=float32, weak_type=True)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
JaxArray([[0., 0., 1.],
[1., 1., 0.],
[0., 0., 0.]], dtype=float32, weak_type=True)
For the above example, we can rewrite it by using where
syntax as:
class OddEvenWhere(bp.Base):
def __init__(self):
super(OddEvenWhere, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
self.a += bm.where(self.rand < 0.5, 1., -1.)
return self.a
model = bm.jit(OddEvenWhere())
model()
Variable([-1.], dtype=float32)
brainpy.math.ifelse
#
Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.
In its simplest case, brainpy.math.ifelse(condition, branches, operands, dyn_vars=None)
is equivalent to:
def ifelse(condition, branches, operands, dyn_vars=None):
true_fun, false_fun = branches
if condition:
return true_fun(operands)
else:
return false_fun(operands)
Based on this function, we can rewrite the above example by using the ifelse
syntax as:
class OddEvenCond(bp.Base):
def __init__(self):
super(OddEvenCond, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
self.a += bm.ifelse(self.rand[0] < 0.5,
[lambda _: 1., lambda _: -1.])
return self.a
model = bm.jit(OddEvenCond())
model()
Variable([1.], dtype=float32)
If you want to write control flows with multiple branchings, brainpy.math.ifelse(conditions, branches, operands, dyn_vars=None)
can also help you accomplish this easily. Actually, the multiple-branching case is equivalent to:
def ifelse(conditions, branches, operands, dyn_vars=None):
pred1, pred2, ... = conditions
func1, func2, ..., funcN = branches
if pred1:
return func1(operands)
elif pred2:
return func2(operands)
...
else:
return funcN(operands)
For example, if you have the following code:
def f(a):
if a > 10:
return 1.
elif a > 5:
return 2.
elif a > 0:
return 3.
elif a > -5:
return 4.
else:
return 5.
It can be expressed as:
def f(a):
return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
branches=[1., 2., 3., 4., 5.])
f(11.)
DeviceArray(1., dtype=float32, weak_type=True)
f(6.)
DeviceArray(2., dtype=float32, weak_type=True)
f(1.)
DeviceArray(3., dtype=float32, weak_type=True)
f(-4.)
DeviceArray(4., dtype=float32, weak_type=True)
f(-6.)
DeviceArray(5., dtype=float32, weak_type=True)
A more complex example is:
def f2(a, x):
return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
branches=[lambda x: x*2,
2.,
lambda x: x**2 -1,
lambda x: x - 4.,
5.],
operands=x)
f2(11, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(6, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(1, 1.)
DeviceArray(0., dtype=float32, weak_type=True)
f2(-4, 1.)
DeviceArray(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
DeviceArray(5., dtype=float32, weak_type=True)
If instances of brainpy.math.Variable
are used in branching functions, you can declare them in the dyn_vars
argument.
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])
print('a:', a)
print('b:', b)
a: Variable([1., 1.], dtype=float32)
b: Variable([0., 0.], dtype=float32)
2. Repetition#
A repetition statement is used to repeat a group (block) of programming instructions.
In Python, we generally have two loops/repetitive statements:
for loop: Execute a set of statements once for each item in a sequence.
while loop: Execute a block of statements repeatedly until a given condition is satisfied.
Pythonic loop syntax#
Actually, JAX allows you to write Pythonic loops. You just need to iterate over your sequence data and then apply your logic to the iterated items. This kind of Pythonic loop syntax is compatible with JIT compilation, but it takes a long time to trace and compile. For example,
class LoopSimple(bp.Base):
def __init__(self):
super(LoopSimple, self).__init__()
rng = bm.random.RandomState(123)
self.seq = rng.random(1000)
self.res = bm.Variable(bm.zeros(1))
def __call__(self):
for s in self.seq:
self.res += s
return self.res.value
import time
def measure_time(f):
t0 = time.time()
r = f()
t1 = time.time()
print(f'Result: {r}, Time: {t1 - t0}')
model = bm.jit(LoopSimple())
# First time will trigger compilation
measure_time(model)
Result: [501.74673], Time: 2.7157142162323
# Second running
measure_time(model)
Result: [1003.49347], Time: 0.0
When the model is complex and the iteration is long, the compilation during the first run will become unbearable. For such cases, you need structured loop syntax.
JAX has provided several important loop primitives, such as jax.lax.fori_loop, jax.lax.while_loop, and jax.lax.scan.
BrainPy also provides its own loop syntax, which is especially suitable for the cases where users are using brainpy.math.Variable. Specifically, they are brainpy.math.make_loop() and brainpy.math.make_while().
In this section, we only talk about how to use our provided loop functions.
brainpy.math.make_loop()
#
brainpy.math.make_loop()
is used to generate a for-loop function when you use Variable
.
Suppose that you are using several JaxArrays (grouped as dyn_vars
) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars
). Sometimes the body function already returns something, and you also want to gather the returned values. With the Python syntax, it can be realized as
def for_loop_function(body_fun, dyn_vars, out_vars, xs):
ys = []
for x in xs:
# 'dyn_vars' and 'out_vars' are updated in 'body_fun()'
results = body_fun(x)
ys.append([out_vars, results])
return ys
In BrainPy, you can define this logic using brainpy.math.make_loop()
:
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)
hist_of_out_vars = loop_fun(xs)
Or,
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)
hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
For the above example, we can rewrite it by using brainpy.math.make_loop
as:
class LoopStruct(bp.Base):
def __init__(self):
super(LoopStruct, self).__init__()
rng = bm.random.RandomState(123)
self.seq = rng.random(1000)
self.res = bm.Variable(bm.zeros(1))
def add(s): self.res += s
self.loop = bm.make_loop(add, dyn_vars=[self.res])
def __call__(self):
self.loop(self.seq)
return self.res.value
model = bm.jit(LoopStruct())
# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 0.028011560440063477
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
brainpy.math.make_while()
#
brainpy.math.make_while()
is used to generate a while-loop function when you use JaxArray
. It supports the following loop logic:
while condition:
statements
When using brainpy.math.make_while()
, condition should be wrapped as a cond_fun
function which returns a boolean value, and statements should be packed as a body_fun
function which does not support returned values:
while cond_fun(x):
body_fun(x)
where x
is the external input that is not iterated. All the iterated variables should be marked as JaxArray
. All JaxArray
s used in cond_fun
and body_fun
should be declared as dyn_vars
variables.
Let’s look at an example:
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))
def cond_f(x):
return i[0] < 10
def body_f(x):
i.value += 1.
counter.value += i
loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])
In the above example, we try to implement a sum from 0 to 10 by using two JaxArrays i
and counter
.
loop()
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)
Dynamics Simulation#
Dynamical System Specification#
BrainPy enables modular programming and easy model debugging. To build a complex brain dynamics model, you just need to compose its building blocks. In this section, we are going to talk about what building blocks we provide and how to use them.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
Models in brainpy.dyn
#
brainpy.dyn
has provided many convenient neuron, synapse, and other models for users. The following figure is a glimpse of the provided models.
The arrows in the graph represent the inheritance relations between different models.
New models will be continuously updated in the page of API documentation.
Initializing a neuron model#
All neuron models implemented in brainpy are subclasses of brainpy.dyn.NeuGroup
. Initializing a neuron model only requires providing the geometry (size) of the neuron population.
hh = bp.dyn.HH(size=1) # only 1 neuron
hh = bp.dyn.HH(size=10) # 10 neurons in a group
hh = bp.dyn.HH(size=(10, 10)) # a grid of (10, 10) neurons in a group
hh = bp.dyn.HH(size=(5, 4, 2)) # a 3-dimensional group of (5, 4, 2) neurons
Generally speaking, there are two types of arguments that can be set by users:
parameters: the model parameters, like gNa, which refers to the maximum conductance of the sodium channel in the brainpy.dyn.HH model.
variables: the model variables, like V, which refers to the membrane potential of a neuron model.
By default, model parameters are homogeneous, i.e., they are just scalar values.
hh = bp.dyn.HH(5) # there are five neurons in this group
hh.gNa
120.0
However, neuron models support heterogeneous parameters when performing computations in a neuron group. One can initialize heterogeneous parameters in several ways.
1. Tensor
Users can directly provide a tensor as the parameter.
hh = bp.dyn.HH(5, gNa=bm.random.uniform(110, 130, size=5))
hh.gNa
JaxArray([114.53795, 127.13995, 119.036 , 110.91665, 117.91266], dtype=float32)
2. Initializer
BrainPy provides rich support for initialization. One can provide an initializer to the parameter to instruct the model to initialize heterogeneous parameters.
hh = bp.dyn.HH(5, ENa=bp.init.OneInit(50.))
hh.ENa
JaxArray([50., 50., 50., 50., 50.], dtype=float32)
3. Callable function
You can also directly provide a callable function which receives a shape
argument.
hh = bp.dyn.HH(5, ENa=lambda shape: bm.random.uniform(40, 60, shape))
hh.ENa
JaxArray([52.201824, 52.322166, 44.033783, 47.943596, 54.985268], dtype=float32)
Here, let’s see how the heterogeneous parameters influence our model simulation.
# we create 3 neurons in a group. Each neuron has a unique "gNa"
model = bp.dyn.HH(3, gNa=bp.init.Uniform(min_val=100, max_val=140))
runner = bp.dyn.DSRunner(model, monitors=['V'], inputs=['input', 5.])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 2], show=True)

Similarly, the setting of the initial values of a variable can also be realized through the above three ways: Tensor, Initializer, and Callable function. For example,
hh = bp.dyn.HH(
3,
V_initializer=bp.init.Uniform(-80., -60.), # Initializer
m_initializer=lambda shape: bm.random.random(shape), # function
h_initializer=bm.random.random(3), # Tensor
)
print('V: ', hh.V)
print('m: ', hh.m)
print('h: ', hh.h)
V: Variable([-77.707954, -73.94804 , -69.09014 ], dtype=float32)
m: Variable([0.4219371, 0.5383264, 0.8984035], dtype=float32)
h: Variable([0.61493886, 0.81473637, 0.3291837 ], dtype=float32)
Initializing a synapse model#
Initializing a synapse model needs to provide its pre-synaptic group (pre
), post-synaptic group (post
) and the connection method between them (conn
). Below is an example of creating an Exponential synapse model:
neu = bp.dyn.LIF(10)
# here we create a synaptic projection within a population
syn = bp.dyn.ExpCUBA(pre=neu, post=neu, conn=bp.conn.All2All())
BrainPy’s built-in synapse models support heterogeneous synaptic weights and delay steps by using Tensor, Initializer, and Callable functions. For example,
syn = bp.dyn.ExpCUBA(neu, neu, bp.conn.FixedProb(prob=0.1),
g_max=bp.init.Uniform(min_val=0.1, max_val=1.),
delay_step=lambda shape: bm.random.randint(10, 30, shape))
syn.g_max
JaxArray([0.9790364 , 0.18719104, 0.84017825, 0.31185275, 0.38157037,
0.80953383, 0.61926776, 0.73845625, 0.9679548 , 0.385096 ,
0.91454816], dtype=float32)
syn.delay_step
JaxArray([18, 19, 15, 21, 17, 24, 10, 27, 12, 20], dtype=int32)
However, in BrainPy, the built-in synapse models only support homogeneous synaptic parameters, like the time constant \(\tau\). Users can customize their synaptic models when they want heterogeneous synaptic parameters.
Similarly, the synaptic variables can be initialized heterogeneously by using Tensor, Initializer, and Callable functions.
Change model parameters during simulation#
In BrainPy, all dynamically changing variables (no matter whether they are changed inside or outside a jitted function) should be marked as brainpy.math.Variable
. BrainPy’s built-in models also support modifying model parameters during simulation.
For example, suppose you want to fix gNa
during the first 100 ms of the simulation and then decrease its value in the following simulations. In this case, we can provide gNa
as an instance of brainpy.math.Variable
when initializing the model.
hh = bp.dyn.HH(5, gNa=bm.Variable(bm.asarray([120.])))
runner = bp.dyn.DSRunner(hh, monitors=['V'], inputs=['input', 5.])
# the first running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

# change the gNa first
hh.gNa[:] = 100.
# the second running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

Examples of using built-in models#
Here we show users how to simulate a famous neuron model: the Morris-Lecar neuron model, which is a two-dimensional “reduced” excitation model applicable to systems having two non-inactivating voltage-sensitive conductances.
group = bp.dyn.MorrisLecar(1)
Then users can utilize various tools provided by BrainPy to easily simulate the Morris-Lecar neuron model. Here we are not going to dive into details so please read the corresponding tutorials if you want to learn more.
runner = bp.dyn.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.))
runner.run(1000)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)

Next, we will give users an intuitive understanding of how to build a network composed of different neuron and synapse models. Users can simply initialize these models as below and pass them into brainpy.dyn.Network
.
neu1 = bp.dyn.HH(1)
neu2 = bp.dyn.HH(1)
syn1 = bp.dyn.AMPA(neu1, neu2, bp.connect.All2All())
net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
By selecting proper runner, users can simulate the network efficiently and plot the simulation results.
runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
runner.run(150.)
import matplotlib.pyplot as plt
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
plt.legend()
fig.add_subplot(gs[1, 0])
plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
plt.legend()
plt.show()

Building Neuron Models#
The previous section shows the available models users can utilize by simply instantiating them. In the following sections, we will dive into the details of how to build a neuron model with brainpy.dyn.NeuGroup
. Neurons are the most basic components in neural dynamics simulation. In BrainPy, brainpy.dyn.NeuGroup
is used for neuron modeling.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
brainpy.dyn.NeuGroup
#
Generally, any neuron model can evolve continuously or discontinuously. Discontinuous evolution may be triggered by events, such as the reset of the membrane potential. Moreover, it is common in a neural system that a dynamical system has different states, such as the excitable or refractory state in a leaky integrate-and-fire (LIF) model. In this section, we will use two examples to illustrate how to capture these complexities in neuron modeling.
Defining a neuron model in BrainPy is simple. You just need to inherit from brainpy.dyn.NeuGroup
, and satisfy the following two requirements:
Providing the size
of the neural group in the constructor when initializing a new neural group class. size
can be an integer referring to the number of neurons or a tuple/list of integers referring to the geometry of the neural group in different dimensions. According to the provided group size
, NeuGroup will automatically calculate the total number num
of neurons in this group.
Creating an update(_t, _dt)
function. The update function provides the rule for how the neuron states evolve from the current time \(\mathrm{\_t}\) to the next time \(\mathrm{\_t + \_dt}\).
In the following part, a Hodgkin-Huxley (HH) model is used as an example for illustration.
Hodgkin–Huxley Model#
The Hodgkin-Huxley (HH) model is a continuous-time dynamical system. It is one of the most successful mathematical models of a complex biological process that has ever been formulated. Changes of the membrane potential influence the conductances of different channels, elaborately modeling the neural activities in biological systems. Mathematically, the model is given by:
\[
\begin{aligned}
C_m \frac{dV}{dt} &= -\bar{g}_{Na} m^3 h (V - E_{Na}) - \bar{g}_{K} n^4 (V - E_{K}) - \bar{g}_{l} (V - E_{l}) + I_{ext}, \\
\frac{dm}{dt} &= \alpha_m(V)(1 - m) - \beta_m(V)\, m, \\
\frac{dh}{dt} &= \alpha_h(V)(1 - h) - \beta_h(V)\, h, \\
\frac{dn}{dt} &= \alpha_n(V)(1 - n) - \beta_n(V)\, n.
\end{aligned}
\]
where \(V\) is the membrane potential, \(C_m\) is the membrane capacitance per unit area, \(E_K\) and \(E_{Na}\) are the potassium and sodium reversal potentials, respectively, \(E_l\) is the leak reversal potential, \(\bar{g}_K\) and \(\bar{g}_{Na}\) are the potassium and sodium conductances per unit area, respectively, and \(\bar{g}_l\) is the leak conductance per unit area. Because the potassium and sodium channels are voltage-sensitive, according to biological experiments, \(m\), \(n\) and \(h\) are used to simulate the activation of the channels. Specifically, \(n\) measures the activation of potassium channels, and \(m\) and \(h\) measure the activation and inactivation of sodium channels, respectively. \(\alpha_{x}\) and \(\beta_{x}\) are rate constants for the ion channel \(x\) and depend exclusively on the membrane potential.
To implement the HH model, variables should be specified. According to the above equations, the following five state variables change with respect to time:
V
: the membrane potentialm
: the activation of sodium channelsh
: the inactivation of sodium channelsn
: the activation of potassium channelsinput
: the external/synaptic input
Besides, the spiking state and the last spiking time can also be recorded for statistic analysis:
spike
: whether a spike is producedt_last_spike
: the last spiking time
Based on these state variables, the HH model can be implemented as below.
class HH(bp.dyn.NeuGroup):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
V_th=20., C=1.0, **kwargs):
# providing the group "size" information
super(HH, self).__init__(size=size, **kwargs)
# initialize parameters
self.ENa = ENa
self.EK = EK
self.EL = EL
self.gNa = gNa
self.gK = gK
self.gL = gL
self.C = C
self.V_th = V_th
# initialize variables
self.V = bm.Variable(bm.random.randn(self.num) - 70.)
self.m = bm.Variable(0.5 * bm.ones(self.num))
self.h = bm.Variable(0.6 * bm.ones(self.num))
self.n = bm.Variable(0.32 * bm.ones(self.num))
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
# integral functions
self.int_V = bp.odeint(f=self.dV, method='exp_auto')
self.int_m = bp.odeint(f=self.dm, method='exp_auto')
self.int_h = bp.odeint(f=self.dh, method='exp_auto')
self.int_n = bp.odeint(f=self.dn, method='exp_auto')
def dV(self, V, t, m, h, n, Iext):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
return dVdt
def dm(self, m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
return dmdt
def dh(self, h, t, V):
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
return dhdt
def dn(self, n, t, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
return dndt
def update(self, _t, _dt):
# compute V, m, h, n
V = self.int_V(self.V, _t, self.m, self.h, self.n, self.input, dt=_dt)
self.h.value = self.int_h(self.h, _t, self.V, dt=_dt)
self.m.value = self.int_m(self.m, _t, self.V, dt=_dt)
self.n.value = self.int_n(self.n, _t, self.V, dt=_dt)
# update the spiking state and the last spiking time
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
# update V
self.V.value = V
# reset the external input
self.input[:] = 0.
When defining the HH model, the above differential equations are implemented through brainpy.odeint as ODEIntegrators. The details are contained in the Numerical Solvers for ODEs tutorial.
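As a quick, minimal sketch of this workflow (the decay equation, parameter value, and step size below are only illustrative and are not part of the HH model):
import brainpy as bp
# an illustrative derivative function: dv/dt = -v / tau
decay = lambda v, t, tau: -v / tau
# wrap it into an ODE integrator, following the same pattern as self.int_V above
integral = bp.odeint(f=decay, method='exp_auto')
# advance "v" from t=0. by a single step of size dt, passing "tau" as an extra argument
v_next = integral(1., 0., 10., dt=0.1)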
The variables, which will be updated during dynamics simulation, should be packed as brainpy.math.Variable
and thus can be processed by JIT compilers to accelerate the simulation.
In the following part, a leaky integrate-and-fire (LIF) model is introduced as another example for illustration.
Leaky Integrate-and-Fire Model#
The LIF model is the classical neuron model which contains a continuous process and a discontinuous spike reset operation. Formally, it is given by:
\[
\tau_m \frac{dV}{dt} = - (V - V_{rest}) + R I(t), \quad \text{if} \quad V < V_{th} \qquad (1)
\]
\[
V \leftarrow V_{reset}, \ \text{and the neuron stays silent for a refractory period } \tau_{ref}, \quad \text{if} \quad V \geq V_{th} \qquad (2)
\]
where \(V\) is the membrane potential, \(V_{rest}\) is the rest membrane potential, \(V_{th}\) is the spike threshold, \(\tau_m\) is the time constant, \(\tau_{ref}\) is the refractory time period, and \(I\) is the time-variant synaptic inputs.
The above two equations model the continuous change and the spiking of neurons, respectively. Moreover, it has multiple states: subthreshold
state, and spiking
or refractory
state. The membrane potential \(V\) is integrated according to equation (1) when it is below \(V_{th}\). Once \(V\) reaches the threshold \(V_{th}\), according to equation (2), \(V\) is reset to \(V_{reset}\), and the neuron enters the refractory period, during which the membrane potential \(V\) remains constant for the following \(\tau_{ref}\) ms.
The neuronal variables, like the membrane potential and external input, can be captured by the following two variables:
V
: the membrane potentialinput
: the external/synaptic input
In order to define the different states of a LIF neuron, we define additional variables:
spike
: whether a spike is producedrefractory
: whether the neuron is in the refractory periodt_last_spike
: the last spiking time
Based on these state variables, the LIF model can be implemented as below.
class LIF(bp.dyn.NeuGroup):
def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., R=1., tau=10., t_ref=5., **kwargs):
super(LIF, self).__init__(size=size, **kwargs)
# initialize parameters
self.V_rest = V_rest
self.V_reset = V_reset
self.V_th = V_th
self.R = R
self.tau = tau
self.t_ref = t_ref
# initialize variables
self.V = bm.Variable(bm.random.randn(self.num) + V_reset)
self.input = bm.Variable(bm.zeros(self.num))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
# integral function
self.integral = bp.odeint(f=self.derivative, method='exp_auto')
def derivative(self, V, t, Iext):
dvdt = (-V + self.V_rest + self.R * Iext) / self.tau
return dvdt
def update(self, _t, _dt):
# Whether the neurons are in the refractory period
refractory = (_t - self.t_last_spike) <= self.t_ref
# compute the membrane potential
V = self.integral(self.V, _t, self.input, dt=_dt)
# computed membrane potential is valid only when the neuron is not in the refractory period
V = bm.where(refractory, self.V, V)
# update the spiking state
spike = self.V_th <= V
self.spike.value = spike
# update the last spiking time
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
# update the membrane potential and reset spiked neurons
self.V.value = bm.where(spike, self.V_reset, V)
# update the refractory state
self.refractory.value = bm.logical_or(refractory, spike)
# reset the external input
self.input[:] = 0.
In the above, the discontinuous resetting is implemented with the brainpy.math.where
operation.
Instantiation and running#
Here, let’s try to instantiate a HH
neuron group:
neu = HH(10)
in which a neural group containing 10 HH neurons is generated.
The details of the model simulation will be expanded in the Runners section. In brief, running any dynamical system instance should be accomplished with a runner, such as brainpy.DSRunner
and brainpy.ReportRunner
. The variables to be monitored and the input currents to be applied in the simulation can be provided when initializing the runner. The details are accessible in Monitors and Inputs.
runner = bp.dyn.DSRunner(
neu,
monitors=['V'],
inputs=('input', 22.) # constant external inputs of 22 mA to all neurons
)
Then the simulation can be performed with a given time period, and the simulation result can be visualized:
runner.run(200) # the running time is 200 ms
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

A LIF neural group can be instantiated and applied in simulation in a similar way:
group = LIF(10)
runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 22.), jit=True)
runner.run(200)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

Building Synapse Models#
Synaptic computation is the core of brain dynamics programming. This is because, in a real project, most of the simulation time is spent on the computation of synapses. In order to achieve efficient synaptic computation, BrainPy provides much useful support. Here, we are going to explore the details of these supports.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
Synapse Models in Math#
Before we talk about the implementation of synapses in BrainPy, it’s better to understand the targets (synapse models) we are going to implement. For different illustration purposes, we are going to implement two synapse models: exponential synapse model and AMPA synapse model.
1. The exponential synapse model#
The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state rises instantaneously and then decays with a certain time constant \(\tau_{decay}\). Its dynamics is given by:
\[
\frac{dg}{dt} = -\frac{g}{\tau_{decay}} + \sum_{k} \delta\left(t - t^{k} - D\right)
\]
where \(g\) is the synaptic state, \(t^{k}\) is the spike time of the pre-synaptic neuron, and \(D\) is the synaptic delay.
Afterwards, the current output onto the post-synaptic neuron is given in the conductance-based form:
\[
I_{syn}(t) = g_{max} \, g \, \left(E - V(t)\right)
\]
where \(E\) is the reversal potential of the synapse, \(V\) is the post-synaptic membrane potential, \(g_{max}\) is the maximum synaptic conductance.
2. The AMPA synapse model#
A classical model of the AMPA synapse uses a Markov process to model the ion channel switching. Here \(g\) represents the probability of the channel opening, \(1-g\) represents the probability of the ion channel closing, and \(\alpha\) and \(\beta\) are the transition probabilities. Specifically, its formula is given by
\[
\frac{dg}{dt} = \alpha [T] (1 - g) - \beta g
\]
where \(\alpha [T]\) denotes the transition probability from state \((1-g)\) to state \((g)\); and \(\beta\) represents the transition probability of the other direction. \(\alpha\) is the binding constant. \(\beta\) is the unbinding constant. \([T]\) is the neurotransmitter concentration, and has the duration of 0.5 ms.
Moreover, the post-synaptic current on the post-synaptic neuron is formulated as
\[
I_{syn} = g_{max} \, g \, (E - V)
\]
where \(g_{max}\) is the maximum conductance, and \(E\) is the reverse potential.
Synapse Models in Silicon#
The implementation of synapse models is accomplished by brainpy.dyn.TwoEndConn
interface. In this section, we talk about what support is provided for implementing synapse models in silicon.
1. brainpy.dyn.TwoEndConn
#
In BrainPy, brainpy.dyn.TwoEndConn
is used to model two-end synaptic computations.
To define a synapse model, two requirements should be satisfied:
1. Constructor function __init__()
, in which three key arguments are needed.
pre
: the pre-synaptic neural group. It should be an instance ofbrainpy.dyn.NeuGroup
.post
: the post-synaptic neural group. It should be an instance ofbrainpy.dyn.NeuGroup
.conn
(optional): the connection type between these two groups. BrainPy provides abundant connection types that are described in detail in Synaptic Connections.
2. Update function update(_t, _dt)
describes the updating rule from the current time \(\mathrm{\_t}\) to the next time \(\mathrm{\_t + \_dt}\).
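Putting these two requirements together, a minimal skeleton (a sketch only, with a placeholder variable g and an empty update rule; this is not one of BrainPy's built-in models) may look like:
class MinimalSyn(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn):
    # the three key arguments are passed to the base class
    super(MinimalSyn, self).__init__(pre=pre, post=post, conn=conn)
    # a synaptic state with one value per post-synaptic neuron (a placeholder choice)
    self.g = bm.Variable(bm.zeros(post.num))
  def update(self, _t, _dt):
    # evolve the synaptic state from _t to _t + _dt (details omitted in this sketch)
    pass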
2. Variable delays#
As seen in the above two synapse models, synaptic computations are usually involved with variable delays. A delay time (typically 0.3–0.5 ms) is usually required for a neurotransmitter to be released from a presynaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane.
BrainPy provides several kinds of delay variables for users, including:
brainpy.math.LengthDelay
: a delay variable which defines a constant number of steps for the delay.
brainpy.math.TimeDelay
: a delay variable which defines a constant time length for the delay.
Assume here we need a delay variable which has 1 ms delay. If the numerical integration precision dt
is 0.1 ms, then we can create a brainpy.math.LengthDelay
which has 10 delay time steps.
target_data_to_delay = bm.Variable(bm.zeros(10))
example_delay = bm.LengthDelay(target_data_to_delay,
delay_len=10) # delay 10 steps
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
example_delay(5) # call the delay data at 5 delay step
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(10) # call the delay data at 10 delay step
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
Alternatively, we can create an instance of brainpy.math.TimeDelay
, which uses the time t
as the index to retrieve the delay data.
t0 = 0.
example_delay = bm.TimeDelay(target_data_to_delay,
delay_len=1.0, t0=t0) # delay 1.0 ms
example_delay(t0 - 1.0) # the delay data at t-1. ms
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(t0 - 0.5) # the delay data at t-0.5 ms
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
3. Synaptic connections#
Synaptic computations usually need to create connections between groups. BrainPy provides rich support for constructing synaptic connections. Simply speaking, brainpy.conn.Connector
can create the various data structures you want through the require()
function. Take the random connection brainpy.conn.FixedProb
, which will be used below, as an example:
example_conn = bp.conn.FixedProb(0.2)(pre_size=(5,), post_size=(8, ))
we can require the connection matrix (which has the shape of (num_pre, num_post)):
example_conn.require('conn_mat')
JaxArray([[False, False, False, False, True, False, False, False],
[False, False, False, False, False, False, True, False],
[False, False, False, False, False, True, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False]], dtype=bool)
we can also require the connected indices of pre-synaptic neurons (pre_ids
) and post-synaptic neurons (post_ids
):
example_conn.require('pre_ids', 'post_ids')
(JaxArray([0, 0, 1, 1, 2, 2, 3, 4, 4, 4, 4], dtype=uint32),
JaxArray([1, 4, 4, 5, 2, 3, 6, 1, 5, 6, 7], dtype=uint32))
Or, we can require the connection structure of pre2post
which stores the information about how each pre-synaptic neuron connects to the post-synaptic neurons:
example_conn.require('pre2post')
(JaxArray([0, 3, 4, 1, 0, 2, 7], dtype=uint32),
JaxArray([0, 3, 3, 4, 7, 7], dtype=uint32))
For more details of the connection structures, please see the tutorial of Synaptic Connections.
Achieving efficient synaptic computation is difficult#
Synaptic computations usually need to transform data of the pre-synaptic dimension into data of the post-synaptic dimension, or into data with the shape of the synapse number. There does not exist a universal computation method that is efficient in all cases. Usually, we need different approaches for different connection situations to achieve efficient synaptic computation. In the next two sections, we will talk about how to define efficient synaptic models when your connections are sparse or dense.
Before we start, we need to define some useful helper functions to define and show synapse models. Then, we will highlight the key differences of model definition when using different synaptic connections.
# Basic Model to define the exponential synapse model. This class
# defines the basic parameters, variables, and integral functions.
class BaseExpSyn(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., method='exp_auto'):
super(BaseExpSyn, self).__init__(pre=pre, post=post, conn=conn)
# check whether the pre group has the needed attribute: "spike"
self.check_pre_attrs('spike')
# check whether the post group has the needed attribute: "input" and "V"
self.check_post_attrs('input', 'V')
# parameters
self.E = E
self.tau = tau
self.delay = delay
self.g_max = g_max
# use "LengthDelay" to store the spikes of the pre-synaptic neuron group
self.delay_step = int(delay/bm.get_dt())
self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)
# integral function
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
# Basic Model to define the AMPA synapse model. This class
# defines the basic parameters, variables, and integral functions.
class BaseAMPASyn(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.42, E=0., alpha=0.98,
beta=0.18, T=0.5, T_duration=0.5, method='exp_auto'):
super(BaseAMPASyn, self).__init__(pre=pre, post=post, conn=conn)
# check whether the pre group has the needed attribute: "spike"
self.check_pre_attrs('spike')
# check whether the post group has the needed attribute: "input" and "V"
self.check_post_attrs('input', 'V')
# parameters
self.delay = delay
self.g_max = g_max
self.E = E
self.alpha = alpha
self.beta = beta
self.T = T
self.T_duration = T_duration
# use "LengthDelay" to store the spikes of the pre-synaptic neuron group
self.delay_step = int(delay/bm.get_dt())
self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)
# store the arrival time of the pre-synaptic spikes
self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)
# integral function
self.integral = bp.odeint(self.derivative, method=method)
def derivative(self, g, t, TT):
dg = self.alpha * TT * (1 - g) - self.beta * g
return dg
# for more details of how to run a simulation please see the tutorials in "Dynamics Simulation"
def show_syn_model(model):
pre = bp.models.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
post = bp.models.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
syn = model(pre, post, conn=bp.conn.One2One())
net = bp.dyn.Network(pre=pre, post=post, syn=syn)
runner = bp.DSRunner(net,
monitors=['pre.V', 'post.V', 'syn.g'],
inputs=['pre.input', 22.])
runner.run(100.)
fig, gs = bp.visualize.get_figure(1, 2, 3, 4)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon['syn.g'], legend='syn.g')
fig.add_subplot(gs[0, 1])
bp.visualize.line_plot(runner.mon.ts, runner.mon['pre.V'], legend='pre.V')
bp.visualize.line_plot(runner.mon.ts, runner.mon['post.V'], legend='post.V', show=True)
Computation with Dense Connections#
Matrix-based synaptic computation is straightforward. Especially when your models are densely connected, using a matrix is highly efficient.
conn_mat
#
Assume two neuron groups are connected through a fixed probability of 0.7.
conn = bp.conn.FixedProb(0.7)(pre_size=6, post_size=8)
Then you can create the connection matrix through conn.require("conn_mat")
:
conn.require('conn_mat')
JaxArray([[False, True, False, True, True, False, True, True],
[ True, True, True, False, True, True, True, True],
[False, True, True, True, True, True, True, True],
[ True, True, True, False, True, True, True, True],
[False, True, False, True, True, True, True, False],
[ True, False, True, True, False, True, False, True]], dtype=bool)
conn_mat
has the shape of (num_pre, num_post)
. Therefore, transforming data of the pre-synaptic dimension into data of the post-synaptic dimension is very easy. You just need to make a matrix multiplication: brainpy.math.dot(pre_values, conn_mat)
(\(\mathbb{R}^\mathrm{num\_pre} @ \mathbb{R}^\mathrm{(num\_pre, num\_post)} \to \mathbb{R}^\mathrm{num\_post}\)).
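As a minimal sketch of this shape transformation (the sizes below are arbitrary toy values):
num_pre, num_post = 6, 8
pre_values = bm.random.random(num_pre)  # data in the pre-synaptic dimension
conn_mat = bm.random.random((num_pre, num_post)) < 0.7  # a boolean (num_pre, num_post) matrix
post_values = bm.dot(pre_values, conn_mat)  # data in the post-synaptic dimension, shape (num_post,)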
With the synaptic connection of conn_mat
defined above, we can define the exponential synapse model as follows. It is worth noting that the states output onto the same post-synaptic neuron can be superposed in exponential synapses. This means we can declare the synapse variables with the shape of the post-synaptic group, rather than with the total number of synapses.
class ExpConnMat(BaseExpSyn):
def __init__(self, *args, **kwargs):
super(ExpConnMat, self).__init__(*args, **kwargs)
# connection matrix
self.conn_mat = self.conn.require('conn_mat')
# synapse gating variable
# -------
# NOTE: Here the synapse number is the same with
# the post-synaptic neuron number. This is
# different from the AMPA synapse.
self.g = bm.Variable(bm.zeros(self.post.num))
def update(self, _t, _dt):
# pull the delayed pre spikes for computation
delayed_spike = self.pre_spike(self.delay_step)
# push the latest pre spikes into the bottom
self.pre_spike.update(self.pre.spike)
# integrate the synapse state
self.g.value = self.integral(self.g, _t, dt=_dt)
# update synapse states according to the pre spikes
post_sps = bm.dot(delayed_spike, self.conn_mat)
self.g += post_sps
# get the post-synaptic current
self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpConnMat)

We can also use conn_mat
to define an AMPA synapse model. Note here the shape of the synapse variable \(g\) is (num_pre, num_post)
, rather than self.post.num
in the above exponential synapse model. This is because the synaptic states of the AMPA model cannot be superposed.
class AMPAConnMat(BaseAMPASyn):
def __init__(self, *args, **kwargs):
super(AMPAConnMat, self).__init__(*args, **kwargs)
# connection matrix
self.conn_mat = self.conn.require('conn_mat')
# synapse gating variable
# -------
# NOTE: Here the synapse shape is (num_pre, num_post),
# in contrast to the ExpConnMat
self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))
def update(self, _t, _dt):
# pull the delayed pre spikes for computation
delayed_spike = self.pre_spike(self.delay_step)
# push the latest pre spikes into the bottom
self.pre_spike.update(self.pre.spike)
# get the time of pre spikes arrive at the post synapse
self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
# get the neurotransmitter concentration at the current time
TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
# integrate the synapse state
TT = TT.reshape((-1, 1)) * self.conn_mat  # NOTE: only keep the concentrations
# on the valid (connected) synapses
self.g.value = self.integral(self.g, _t, TT, dt=_dt)
# get the post-synaptic current
g_post = self.g.sum(axis=0)
self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAConnMat)

Special connections#
Sometimes, we can define some synapse models with special connection types, such as all-to-all connection, or one-to-one connection. For these special situations, even the connection information can be ignored, i.e., we do not need conn_mat
or other structures any more.
Assume the pre-synaptic group connects to the post-synaptic group in an all-to-all fashion. Then, the exponential synapse model can be defined as:
class ExpAll2All(BaseExpSyn):
def __init__(self, *args, **kwargs):
super(ExpAll2All, self).__init__(*args, **kwargs)
# synapse gating variable
# -------
# The synapse variable has the shape of the post-synaptic group
self.g = bm.Variable(bm.zeros(self.post.num))
def update(self, _t, _dt):
delayed_spike = self.pre_spike(self.delay_step)
self.pre_spike.update(self.pre.spike)
self.g.value = self.integral(self.g, _t, dt=_dt)
self.g += delayed_spike.sum() # NOTE: HERE is the difference
self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpAll2All)

Similarly, the AMPA synapse model can be defined as
class AMPAAll2All(BaseAMPASyn):
def __init__(self, *args, **kwargs):
super(AMPAAll2All, self).__init__(*args, **kwargs)
# synapse gating variable
# -------
# The synapse variable has the shape of the post-synaptic group
self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))
def update(self, _t, _dt):
delayed_spike = self.pre_spike(self.delay_step)
self.pre_spike.update(self.pre.spike)
self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
TT = TT.reshape((-1, 1)) # NOTE: here is the difference
self.g.value = self.integral(self.g, _t, TT, dt=_dt)
g_post = self.g.sum(axis=0) # NOTE: here is also different
self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAAll2All)

Actually, synaptic computation with these special connections can be very efficient! For a concrete example, please see the decision-making spiking model in BrainPy-Examples. This implementation achieves at least a four-fold acceleration compared to the implementations in other frameworks.
Computation with Sparse Connections#
However, in real neural systems, neurons are sparsely connected in essence.
Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need \(10^8\) floats to store the synaptic state, and at each update step, you need to do computation on \(10^8\) floats. However, the number of synapses that actually exist is only \(10^7\). There is thus a huge waste of memory and computing resources. Moreover, at a given time \(\mathrm{\_t}\), the number of pre-synaptic neurons in the spiking state is small on average. This means that we perform many useless computations when defining synaptic computations with matrix-based connections (zeros multiplied by the connection matrix yield zeros).
Therefore, we need new ways to define synapse models. Specifically, we use vectors to store the connected neuron indices, like the pre_ids
and post_ids
(see Synaptic Connections).
In the below, we assume you have learned the synaptic connection types detailed in the tutorial of Synaptic Connections.
The pre2post
operator#
A notable difference between brain dynamics models and deep learning models is that they are sparse and event-driven. In order to support this significantly different kind of computation, BrainPy has built many useful operators. In this section, we talk about a set of operators needed in pre2post
computations.
As noted before, the exponential synapse model can make computations at the dimension of the post-synaptic group. Therefore, we can directly transform the pre-synaptic data into data of the post-synaptic shape. brainpy.math.pre2post_event_sum(events, pre2post, post_num, values) can satisfy your requirements. This operator needs the synaptic structure of pre2post
(a tuple containing the post_ids
and idnptr
of pre-synaptic neurons).
If values
is a scalar, pre2post_event_sum
is equivalent to:
post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[j]] += values
If values
is a vector, pre2post_event_sum
is equivalent to:
post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[j]] += values[j]
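As a small usage sketch of this operator (toy sizes; the connection is random, so the exact result will vary between runs):
conn = bp.conn.FixedProb(0.5)(pre_size=3, post_size=4)
pre2post = conn.require('pre2post')  # the (post_ids, idnptr) structure
events = bm.asarray([True, False, True])  # spikes of the 3 pre-synaptic neurons
# every spiking pre-synaptic neuron adds the value 1. to all of its post-synaptic targets
post_val = bm.pre2post_event_sum(events, pre2post, 4, 1.)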
With this operator, the exponential synapse model can be defined as:
class ExpSparse(BaseExpSyn):
def __init__(self, *args, **kwargs):
super(ExpSparse, self).__init__(*args, **kwargs)
# connections
self.pre2post = self.conn.require('pre2post')
# synapse variable
self.g = bm.Variable(bm.zeros(self.post.num))
def update(self, _t, _dt):
delayed_spike = self.pre_spike(self.delay_step)
self.pre_spike.update(self.pre.spike)
self.g.value = self.integral(self.g, _t, dt=_dt)
# NOTE: update synapse states according to the pre spikes
post_sps = bm.pre2post_event_sum(delayed_spike, self.pre2post, self.post.num, 1.)
self.g += post_sps
self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpSparse)

This model will be very efficient when your synapses are connected sparsely.
The pre2syn
and syn2post
operators#
However, for the AMPA synapse model, the pre-synaptic values cannot be directly transformed into post-synaptic dimensional data. Instead, we need to first change the pre-synaptic data into data of the synapse dimension, and then transform the synapse-dimensional data into post-dimensional data.
Therefore, the core problem of synaptic computation is how to convert values among tensors of different shapes. Specifically, in the above AMPA synapse model, we have three kinds of tensor shapes (see the following figure): tensors with the dimension of the pre-synaptic group, tensors with the dimension of the post-synaptic group, and tensors with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state and grouping the synaptic variable into the post-synaptic current value are the central problems of synaptic computation.
Here BrainPy provides two operators brainpy.math.pre2syn(pre_values, pre_ids) and brainpy.math.syn2post(syn_values, post_ids, post_num) to convert vectors among different dimensions.
brainpy.math.pre2syn()
receives two arguments: “pre_values” (the variable of the pre-synaptic dimension) and “pre_ids” (the connected pre-synaptic neuron index).brainpy.math.syn2post()
receives three arguments: “syn_values” (the variable with the synaptic size), “post_ids” (the connected post-synaptic neuron index) and “post_num” (the number of the post-synaptic neurons).
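Here is a small toy sketch of these two operators. The connectivity below (3 pre-synaptic neurons, 4 synapses, 3 post-synaptic neurons) is made up for illustration, and we assume that syn2post sums the synaptic values onto each post-synaptic neuron, just as the dense AMPA example above does with g.sum(axis=0):
pre_ids = bm.asarray([0, 0, 1, 2])  # the pre-synaptic neuron index of each synapse
post_ids = bm.asarray([1, 2, 0, 2])  # the post-synaptic neuron index of each synapse
pre_values = bm.asarray([10., 20., 30.])  # one value per pre-synaptic neuron
syn_values = bm.pre2syn(pre_values, pre_ids)  # synapse-dimensional data: [10., 10., 20., 30.]
post_values = bm.syn2post(syn_values, post_ids, 3)  # per post-synaptic neuron: [20., 10., 40.]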
Based on these two operators, we can define the AMPA synapse model as:
class AMPASparse(BaseAMPASyn):
def __init__(self, *args, **kwargs):
super(AMPASparse, self).__init__(*args, **kwargs)
# connection matrix
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
# synapse gating variable
# -------
# NOTE: Here the synapse shape is (num_syn,)
self.g = bm.Variable(bm.zeros(len(self.pre_ids)))
def update(self, _t, _dt):
delayed_spike = self.pre_spike(self.delay_step)
self.pre_spike.update(self.pre.spike)
# get the time of pre spikes arrive at the post synapse
self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
# get the arrival time with the synapse dimension
arrival_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
# get the neurotransmitter concentration at the current time
TT = ((_t - arrival_times) < self.T_duration) * self.T
# integrate the synapse state
self.g.value = self.integral(self.g, _t, TT, dt=_dt)
# get the post-synaptic current
g_post = bm.syn2post(self.g, self.post_ids, self.post.num)
self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPASparse)

We hope this tutorial helps you define your synapse models efficiently.
Building Network Models#
In previous sections, it has been illustrated how to define neuron models by brainpy.dyn.NeuGroup
and synapse models by brainpy.dyn.TwoEndConn
. This section will introduce brainpy.dyn.Network
, which is the base class used to build network models.
In essence, brainpy.dyn.Network is a container, whose function is to compose the individual elements. It is a subclass of a more general class: brainpy.dyn.Container.
Below, we take an excitation-inhibition (E-I) balanced network model as an example to illustrate how to compose the LIF neurons and Exponential synapses defined in the previous tutorials into a network.
import brainpy as bp
bp.math.set_platform('cpu')
Excitation-Inhibition (E-I) Balanced Network#
The E-I balanced network was first proposed to explain the irregular firing patterns of cortical neurons and was later confirmed by experimental data. The network [1] we are going to implement consists of excitatory (E) neurons and inhibitory (I) neurons, the ratio of which is about 4 : 1. The biggest difference between excitatory and inhibitory neurons is the reversal potential: the reversal potential of inhibitory neurons is much lower than that of excitatory neurons. Besides, the membrane time constant of inhibitory neurons is longer than that of excitatory neurons, which indicates that inhibitory neurons have slower dynamics.
[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98.
# BrainPy has some built-in canonical neuron and synapse models
LIF = bp.dyn.neurons.LIF
ExpCOBA = bp.dyn.synapses.ExpCOBA
Two ways to define network models#
There are several ways to define a Network model.
1. Defining a network as a class#
The first way to define a network model is as follows.
class EINet(bp.dyn.Network):
def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
super(EINet, self).__init__(**kwargs)
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = LIF(num_exc, **pars, method=method)
I = LIF(num_inh, **pars, method=method)
E.V.value = bp.math.random.randn(num_exc) * 2 - 55.
I.V.value = bp.math.random.randn(num_inh) * 2 - 55.
# synapses
w_e = 0.6 # excitatory synaptic weight
w_i = 6.7 # inhibitory synaptic weight
E_pars = dict(E=0., g_max=w_e, tau=5.)
I_pars = dict(E=-80., g_max=w_i, tau=10.)
# Neurons connect to each other randomly with a connection probability of 2%
self.E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
self.E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
self.I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
self.I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
self.E = E
self.I = I
In an instance of brainpy.dyn.Network
, all self.
accessed elements can be gathered by the .child_ds()
function automatically.
EINet(8, 2).child_ds()
{'ExpCOBA0': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x1076db580>,
'ExpCOBA1': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x1584c3d30>,
'ExpCOBA2': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x158496a00>,
'ExpCOBA3': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x15880ce80>,
'LIF0': <brainpy.dyn.neurons.IF_models.LIF at 0x1583fd100>,
'LIF1': <brainpy.dyn.neurons.IF_models.LIF at 0x1583f6d30>,
'ConstantDelay0': <brainpy.dyn.base.ConstantDelay at 0x1584ba1c0>,
'ConstantDelay1': <brainpy.dyn.base.ConstantDelay at 0x1584c3d00>,
'ConstantDelay2': <brainpy.dyn.base.ConstantDelay at 0x158496370>,
'ConstantDelay3': <brainpy.dyn.base.ConstantDelay at 0x15880c340>}
Note in the above EINet
, we do not define the update()
function. This is because any subclass of brainpy.dyn.Network
has a default update function, in which it automatically gathers the elements defined in this network and sequentially runs the update function of each element.
If you have some special operations in your network, you can override the update function by yourself. Here is a simple example.
class ExampleToOverrideUpdate(EINet):
def update(self, _t, _dt):
for node in self.child_ds().values():
node.update(_t, _dt)
Let’s try to simulate our defined EINet
model.
net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method
runner = bp.dyn.DSRunner(net,
monitors=['E.spike', 'I.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
title='Spikes of Inhibitory Neurons', show=True)
Used time 0.35350608825683594 s


2. Instantiating a network directly#
Another way to instantiate a network model is to directly pass the elements into the constructor of brainpy.Network
. It receives *args
and **kwargs
arguments.
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = LIF(3200, **pars)
I = LIF(800, **pars)
E.V.value = bp.math.random.randn(E.num) * 2 - 55.
I.V.value = bp.math.random.randn(I.num) * 2 - 55.
# synapses
E_pars = dict(E=0., g_max=0.6, tau=5.)
I_pars = dict(E=-80., g_max=6.7, tau=10.)
E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), **E_pars)
E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), **E_pars)
I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), **I_pars)
I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), **I_pars)
# Network
net2 = bp.dyn.Network(E2E, E2I, I2E, I2I, exc_group=E, inh_group=I)
All elements passed as **kwargs
arguments can be accessed by the provided keys. This will affect the following dynamics simulation and will be discussed in greater detail in the tutorial of Runners.
net2.exc_group
<brainpy.dyn.neurons.IF_models.LIF at 0x159470c10>
net2.inh_group
<brainpy.dyn.neurons.IF_models.LIF at 0x159470d30>
After construction, the simulation goes the same way:
runner = bp.dyn.DSRunner(net2,
monitors=['exc_group.spike', 'inh_group.spike'],
inputs=[('exc_group.input', 20.), ('inh_group.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['exc_group.spike'],
title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['inh_group.spike'],
title='Spikes of Inhibitory Neurons', show=True)
Used time 0.3590219020843506 s


Above are some simulation examples showing the possible applications of network models. The detailed description of dynamics simulation is covered in the toolboxes, where the use of runners, monitors, and inputs will be explained in detail.
Building General Dynamical Systems#
The previous sections have shown how to build neuron models, synapse models, and network models. In fact, these brain objects all inherit from the base class brainpy.dyn.DynamicalSystem. brainpy.dyn.DynamicalSystem
is the universal language to define dynamical models in BrainPy.
To begin with, let’s make a brief summary of the previous dynamic models and give the definition of a dynamical system.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
What is a dynamical system?#
Looking back to the neuron and synapse models defined in the previous sections, they share a common feature that they all contain some variables that change over time. Because of these variables, the models become ‘dynamic’ and behave differently at different times.
Actually, a dynamical system is defined as a system with time-dependent states. These time-dependent states are displayed as variables in the previous models.
Mathematically, the change of a state \(X\) can be expressed as
\[\frac{dX}{dt} = f(X, t),\]
where \(X\) is the state of the system, \(t\) is the time, and \(f\) is a function describing the time dependence of the state.
Alternatively, the evolution of the system over time can be given by
\[X(t + dt) = F\left(X(t), t, dt\right),\]
where \(dt\) is the time step and \(F\) is the evolution rule to update the system's state.
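As a concrete illustration of how the evolution rule \(F\) is built from \(f\), the following minimal sketch integrates a toy system \(\frac{dx}{dt} = -x/\tau\) with the Euler rule \(x(t+dt) = x(t) + f(x, t)\,dt\); the toy system and its parameter values are chosen only for illustration.
import brainpy.math as bm

tau, dt = 10., 0.1
x = bm.Variable(bm.ones(1))      # the time-dependent state X

def update(_t, _dt):
    # Euler evolution rule: X(t + dt) = X(t) + f(X, t) * dt
    x.value = x + (-x / tau) * _dt

for i in range(1000):
    update(i * dt, dt)           # x decays towards 0 over time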
Customizing your dynamical systems#
According to the mathematical expression of a dynamical system, any subclass of brainpy.dyn.DynamicalSystem
must implement an updating rule in the update(self, t, dt)
function.
To define a dynamical system, the following requirements should be satisfied:
Inherit from brainpy.dyn.DynamicalSystem.
Implement the update(self, t, dt) function.
When defining variables, they should be declared as brainpy.math.Variable.
When updating the variables, it should be realized by in-place operations.
Below is a simple example of a dynamical system.
class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
def __init__(self, a=0.8, b=0.7, tau=12.5, **kwargs):
super(FitzHughNagumoModel, self).__init__(**kwargs)
# parameters
self.a = a
self.b = b
self.tau = tau
# variables should be packed by brainpy.math.Variable
self.v = bm.Variable([0.])
self.w = bm.Variable([0.])
self.I = bm.Variable([0.])
def update(self, _t, _dt):
# _t : the current time, the system keyword
# _dt : the time step, the system keyword
# in-place update
self.w += (self.v + self.a - self.b * self.w) / self.tau * _dt
self.v += (self.v - self.v ** 3 / 3 - self.w + self.I) * _dt
self.I[:] = 0.
Here, we have defined a dynamical system called the FitzHugh–Nagumo neuron model, whose dynamics is given by:
\[\frac{dv}{dt} = v - \frac{v^3}{3} - w + I, \qquad \frac{dw}{dt} = \frac{v + a - b w}{\tau}.\]
By using the Euler method, this system can be updated by the following rule:
\[w \leftarrow w + \frac{v + a - b w}{\tau} \Delta t, \qquad v \leftarrow v + \left(v - \frac{v^3}{3} - w + I\right) \Delta t.\]
Advantages of using brainpy.dyn.DynamicalSystem
#
There are several advantages of defining a dynamical system as brainpy.dyn.DynamicalSystem
.
1. A systematic naming system.#
First, every instance of DynamicalSystem
has its unique name.
fhn = FitzHughNagumoModel()
fhn.name # name for "fhn" instance
'FitzHughNagumoModel1'
Every instance has its unique name:
for _ in range(3):
print(FitzHughNagumoModel().name)
FitzHughNagumoModel2
FitzHughNagumoModel3
FitzHughNagumoModel4
Users can also specify the name of a dynamic system:
fhn2 = FitzHughNagumoModel(name='X')
fhn2.name
'X'
# same name will cause error
try:
FitzHughNagumoModel(name='X')
except bp.errors.UniqueNameError as e:
print(e)
In BrainPy, each object should have a unique name. However, we detect that <__main__.FitzHughNagumoModel object at 0x000001F75163C250> has a used name "X".
If you try to run multiple trials, you may need
>>> brainpy.base.clear_name_cache()
to clear all cached names.
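Following the hint in this error message, clearing the name cache allows the name to be reused in a new trial; a minimal sketch:
bp.base.clear_name_cache()
fhn3 = FitzHughNagumoModel(name='X')   # no UniqueNameError after the cache is cleared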
Second, variables, children nodes, etc. inside an instance can be easily accessed by their absolute or relative path.
# All variables can be acessed by
# 1). the absolute path
fhn2.vars()
{'X.I': Variable([0.], dtype=float32),
'X.v': Variable([0.], dtype=float32),
'X.w': Variable([0.], dtype=float32)}
# 2). or, the relative path
fhn2.vars(method='relative')
{'I': Variable([0.], dtype=float32),
'v': Variable([0.], dtype=float32),
'w': Variable([0.], dtype=float32)}
2. Convenient operations for simulation and analysis.#
BrainPy provides different runners for dynamics simulation and analyzers for dynamics analysis, both of which require the dynamic model to be a brainpy.dyn.DynamicalSystem
. For example, dynamic models can be wrapped by a runner for simulation:
runner = bp.dyn.DSRunner(fhn2, monitors=['v', 'w'], inputs=('I', 1.5))
runner(duration=100)
bp.visualize.line_plot(runner.mon.ts, runner.mon.v, legend='v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)

Please see Runners to know more about the operations in runners.
3. Efficient computation.#
brainpy.dyn.DynamicalSystem
is a subclass of brainpy.Base, and therefore, any instance of brainpy.dyn.DynamicalSystem
can be compiled just-in-time into efficient machine code targeting CPUs, GPUs, and TPUs.
runner = bp.dyn.DSRunner(fhn2, monitors=['v', 'w'], inputs=('I', 1.5), jit=True)
runner(duration=100)
bp.visualize.line_plot(runner.mon.ts, runner.mon.v, legend='v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)

4. Support composable programming.#
Instances of brainpy.dyn.DynamicalSystem
can be combined at will. The combined system is also a brainpy.dyn.DynamicalSystem
and enjoys all the properties, methods, and interfaces provided by brainpy.dyn.DynamicalSystem
.
For example, if the instances are wrapped into a container, i.e. brainpy.dyn.Network
, variables and nodes can also be accessed by their absolute or relative path.
fhn_net = bp.dyn.Network(f1=fhn, f2=fhn2)
# absolute access of variables
fhn_net.vars()
{'FitzHughNagumoModel1.I': Variable([0.], dtype=float32),
'FitzHughNagumoModel1.v': Variable([0.], dtype=float32),
'FitzHughNagumoModel1.w': Variable([0.], dtype=float32),
'X.I': Variable([0.], dtype=float32),
'X.v': Variable([1.492591], dtype=float32),
'X.w': Variable([1.9365357], dtype=float32)}
# relative access of variables
fhn_net.vars(method='relative')
{'f1.I': Variable([0.], dtype=float32),
'f1.v': Variable([0.], dtype=float32),
'f1.w': Variable([0.], dtype=float32),
'f2.I': Variable([0.], dtype=float32),
'f2.v': Variable([1.492591], dtype=float32),
'f2.w': Variable([1.9365357], dtype=float32)}
# absolute access of nodes
fhn_net.nodes()
{'FitzHughNagumoModel1': <__main__.FitzHughNagumoModel at 0x1f7515a74c0>,
'X': <__main__.FitzHughNagumoModel at 0x1f75164bd90>,
'Network0': <brainpy.dyn.base.Network at 0x1f7529e70d0>}
# relative access of nodes
fhn_net.nodes(method='relative')
{'': <brainpy.dyn.base.Network at 0x1f7529e70d0>,
'f1': <__main__.FitzHughNagumoModel at 0x1f7515a74c0>,
'f2': <__main__.FitzHughNagumoModel at 0x1f75164bd90>}
runner = bp.dyn.DSRunner(fhn_net,
monitors=['f1.v', 'X.v'],
inputs=[('f1.I', 1.5), # relative access to variable "I" in 'fhn1'
('X.I', 1.0),]) # absolute access to variable "I" in 'fhn2'
runner(duration=100)
bp.visualize.line_plot(runner.mon.ts, runner.mon['f1.v'], legend='fhn1.v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon['X.v'], legend='fhn2.v', show=True)

Dynamics Training#
This tutorial shows how to train a dynamical system from data or tasks, and how to customize your nodes or networks.
Node Specification#
Neural networks in BrainPy are used to build dynamical systems. The brainpy.nn module provides various classes representing the nodes of a neural network. All of them are subclasses of the brainpy.nn.Node
base class.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
What is a node?#
In BrainPy, the Node
instance is the basic element to form a network model. It is a unit on a graph, connected to other nodes by edges.
In general, each Node
instance in BrainPy has four components:
Feedforward inputs
Feedback inputs
State
Output
It is worth noting that each Node
instance may have multiple feedforward or feedback connections. However, it only has one state and one output. The output
component is used in both feedforward and feedback connections, which means the feedforward and feedback outputs are the same. However, customizing a different feedback output is also easy (see the Customization of a Node tutorial).

Each node has the following attributes (an inspection example follows below):
feedforward_shapes: the shapes of the feedforward inputs.
feedback_shapes: the shapes of the feedback inputs.
output_shape: the output shape of the node.
state: the state of the node. It can be None if the node has no state to hold.
fb_output: the feedback output of the node. It is None when no feedback connections are established to this node. By default, the value of fb_output is the forward() function output value.
It also has several boolean attributes:
trainable: whether the node is trainable.
is_initialized: whether the node has been initialized.
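For instance, these attributes can be inspected on any instantiated node. The following is a minimal sketch; the exact printed values depend on the node type and its initialization.
import brainpy as bp

l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize(num_batch=1)
print(l.feedforward_shapes)        # shapes of the feedforward inputs
print(l.output_shape)              # output shape of the node
print(l.state)                     # None for a node that holds no state
print(l.trainable, l.is_initialized)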
Creating a node#
A layer can be created as an instance of a brainpy.nn.Node
subclass. For example, a dense layer can be created as follows:
bp.nn.Dense(num_unit=100)
Dense(name=Dense0, forwards=None,
feedbacks=None, output=(None, 100))
This will create a dense layer with 100 units.
Of course, if you already know the shape of the feedforward connections, you can use input_shape
.
bp.nn.Dense(num_unit=100, input_shape=128)
Dense(name=Dense1, forwards=((None, 128),),
feedbacks=None, output=(None, 100))
This creates a densely connected layer connected to an input layer with 128 dimensions.
Naming a node#
For convenience, you can name a layer by specifying the name keyword argument:
bp.nn.Dense(num_unit=100, input_shape=128, name='hidden_layer')
Dense(name=hidden_layer, forwards=((None, 128),),
feedbacks=None, output=(None, 100))
Initializing parameters#
Many nodes have their own parameters. We can set the parameters of a node with the following methods.
Tensors
If a tensor variable instance is provided, this is used unchanged as the parameter variable. For example:
l = bp.nn.Dense(num_unit=50, input_shape=10,
weight_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.initialize(num_batch=1)
l.Wff.shape
(10, 50)
Callable function
If a callable function (which receives a shape
argument) is provided, the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:
def init(shape):
return bm.random.random(shape)
l = bp.nn.Dense(num_unit=30, input_shape=20, weight_initializer=init)
l.initialize(num_batch=1)
l.Wff.shape
(20, 30)
Instance of
brainpy.init.Initializer
If a brainpy.init.Initializer
instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:
l = bp.nn.Dense(num_unit=100, input_shape=20,
weight_initializer=bp.init.Normal(0.01))
l.initialize(num_batch=1)
l.Wff.shape
(20, 100)
The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).
None parameter
Some types of parameter variables can also be set to None
at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:
l = bp.nn.Dense(num_unit=100, input_shape=20, bias_initializer=None)
l.initialize(num_batch=1)
print(l.bias)
None
Calling the node#
The instantiation of a node builds an input-to-output function mapping. To get the mapped output, you can directly call the created node.
l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize()
l(bm.random.random((1, 20)))
JaxArray([[ 0.7788163 , 0.6352515 , 0.9846623 , 0.97518134,
-1.0947354 , 0.29821265, -0.9927582 , -0.00511351,
0.6623081 , 0.72418994]], dtype=float32)
l(bm.random.random((2, 20)))
JaxArray([[ 0.21428639, 0.5546448 , 0.5172446 , 1.2533414 ,
-0.54073226, 0.6578476 , -0.31080672, 0.25883573,
-0.0466502 , 0.50195456],
[ 0.91855824, 0.503054 , 1.1109638 , 0.707477 ,
-0.8442794 , -0.12064239, -0.81839114, -0.2828313 ,
-0.660355 , 0.20748737]], dtype=float32)
Moreover, JIT compilation of the created model is also applicable.
jit_l = bm.jit(l)
%timeit l(bm.random.random((2, 20)))
2.34 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jit_l(bm.random.random((2, 20)))
2.04 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
trainable
settings#
Setting the node to be trainable or non-trainable can be easily achieved. This is controlled by the trainable
argument when initializing a node.
For example, for a non-trainable dense layer, the weights and bias are JaxArray instances.
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=False)
l.initialize(num_batch=1)
l.Wff
JaxArray([[ 0.56564915, -0.70626205, 0.03569109],
[-0.10908064, -0.63869774, -0.37541717],
[-0.80857176, 0.22993006, 0.02752776],
[ 0.32151228, -0.45234612, 0.9239818 ]], dtype=float32)
When creating a layer with the trainable setting, TrainVar
instances will be created for the parameters and initialized automatically. For example:
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=True)
l.initialize(num_batch=1)
l.Wff
TrainVar([[-0.20390746, 0.7101851 , -0.2881384 ],
[ 0.07779109, -1.1979834 , 0.09109607],
[-0.41889605, 0.3983429 , -1.1674007 ],
[-0.14914905, -1.1085916 , -0.10857478]], dtype=float32)
Moreover, for a subclass of brainpy.nn.RecurrentNode
, the state
can be set to be trainable or not through the state_trainable
argument. When setting state_trainable=True
for an instance of brainpy.nn.RecurrentNode
, a new attribute .train_state will be created.
rnn = bp.nn.VanillaRNN(3, input_shape=(1,), state_trainable=True)
rnn.initialize(3)
rnn.train_state
TrainVar([0.7986958 , 0.3421112 , 0.24420719], dtype=float32)
Note the differences between .train_state and the original .state:
.train_state has no batch axis.
When using the node.init_state() or node.initialize() function, all values in .state will be filled with .train_state.
rnn.state
Variable([[0.7986958 , 0.3421112 , 0.24420719],
[0.7986958 , 0.3421112 , 0.24420719],
[0.7986958 , 0.3421112 , 0.24420719]], dtype=float32)
Node Operations#
To form a large network, you need to know the supported node operations. In this section, we are going to talk about this.
import brainpy as bp
The Node instance supports the following basic node operations:
feedforward connection: >>, >>=
feedback connection: <<, <<=
merging: & or &=
concatenating: [node1, node2, ...] or (node1, node2, ...)
wrapping a set of nodes: {node1, node2, ...}
selection: node[slice] (like “node[1, 2, 3]”, “node[:10]”)
Feedforward operator#
Feedforward connections are the backbone of network construction. To declare a feedforward connection between two nodes, you can use the >>
operator.
Users can use node1 >> node2
to create a feedforward connection between two nodes. Alternatively, node1 >>= node2
connects node1 to node2
in place.
i = bp.nn.Input(1)
r = bp.nn.VanillaRNN(10)
o = bp.nn.Dense(1)
model = i >> r >> o
model.plot_node_graph(fig_size=(6, 4),
node_size=1000)

Nodes can be combined in any way to create deeper structures. The >>
operator allows composing nodes into a sequential model, where data flows from node to node in sequence. Below is an example of a deep recurrent neural network.
model = (
bp.nn.Input(1)
>>
bp.nn.VanillaRNN(10)
>>
bp.nn.VanillaRNN(20)
>>
bp.nn.VanillaRNN(10)
>>
bp.nn.Dense(1)
)
model.plot_node_graph(fig_size=(6, 4), node_size=500, layout='shell_layout')

Note
The feedforward connections cannot form a cycle. Otherwise, an error will be raised.
try:
model = i >> r >> o >> i
except Exception as e:
print(f'{e.__class__.__name__}: {e}')
ValueError: We detect cycles in feedforward connections. Maybe you should replace some connection with as feedback ones.
Feedback operator#
Feedback connections are an important feature of reservoir computing. Once a feedback connection is established between two nodes, when running on a time series, BrainPy will send the output of the sender to the receiver with a time delay of one time step (though the feedback behavior can be customized by user settings).
To declare a feedback connection between two nodes, you can use the <<
operator.
model = (i >> r >> o) & (r << o)
model.plot_node_graph(fig_size=(4, 4), node_size=1000)

Merging operator#
The merging &
operator allows merging models together. Merging two networks will create a new network model containing all nodes and all connection edges in the two networks.
Some networks may have input-to-readout connections. This can be achieved using the merging operation &
.
model = (i >> r >> o) & (i >> o)
model.plot_node_graph(fig_size=(4, 4), node_size=1000)

Concatenating operator#
Concatenating operators []
and ()
will concatenate multiple nodes into one. They can be used on the sender side of a feedforward or feedback connection.
The above input-to-readout connection can be rewritten as:
model = [i >> r, i] >> o
# or
# model = (i >> r, i) >> o
model.plot_node_graph(fig_size=(4, 4), node_size=1000)

Note
Concatenating multiple nodes on the receiver side will cause errors.
# In the above network, "i" project to "r" and "o" simultaneously.
# However, we cannot express this node graph as
#
# i >> [r, o]
try:
model = i >> [r, o]
except Exception as e:
print(f'{e.__class__.__name__}: {e}')
ValueError: Cannot concatenate a list/tuple of receivers. Please use set to wrap multiple receivers instead.
Wrapping operator#
Wrapping a set of nodes with {}
means that these nodes are treated equally and can take part in the same operation simultaneously.
For example, if the input node “i” projects to the recurrent node “r” and the readout node “o” simultaneously, we can express this graph as
model = i >> {r, o}
model.plot_node_graph(fig_size=(4, 4), node_size=1000)

Similarly, if multiple nodes connect to the same node, we can wrap them first and then establish the connections.
model = {i >> r, i} >> o
model.plot_node_graph(fig_size=(4, 4), node_size=1000)

Selecting operator#
Sometimes, our input is just a subset of the output of a node. For this situation, we can use the selection operator node[]
.
For example, if we want to decode one half of the output of the recurrent node “r” with one readout node, and the other half with another readout node, we can express this graph as:
o1 = bp.nn.Dense(1)
o2 = bp.nn.Dense(2)
model = i >> r
model = (model[:, :5] >> o1) & (model[:, 5:] >> o2) # the first is the batch axis
model.plot_node_graph(fig_size=(5, 5), node_size=1000)

Network Running and Training#
To make your model powerful, you need to train your created network models. In this section, we are going to talk about how to train and run your network models.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
import matplotlib.pyplot as plt
RNN structural runner RNNRunner
#
For a feedforward network, predicting the output of the network just needs to call the instantiated model:
model = ... # your created model
output = model(inputs)
To accelerate the model running, you can jit the model by
import brainpy.math as bm
model = bm.jit(model) # jitted model
output = model(inputs)
For a recurrent network model, however, you need to call the instantiated model multiple times along the time axis, and looping in Python is very inefficient. Instead, BrainPy provides the structural runner brainpy.nn.RNNRunner
for running recurrent neural networks. Using brainpy.nn.RNNRunner
, the looping process will be JIT compiled into machine code, approaching the speed of native C++ code.
Here we have a reservoir model.
model = (bp.nn.Input(3) >>
bp.nn.Reservoir(100) >>
bp.nn.LinearReadout(3))
model.initialize()
And we have a Lorenz attractor data.
lorenz = bp.datasets.lorenz_series(100)
data = bm.hstack([lorenz['x'], lorenz['y'], lorenz['z']])
Our task is to predict the Lorenz data 5 time steps ahead.
X, Y = data[:-5], data[5:]
Note that all nn
models in BrainPy require the data to have a batch axis as the first dimension.
# here batch size is 1
X = bm.expand_dims(X, axis=0)
Y = bm.expand_dims(Y, axis=0)
X.shape
(1, 99995, 3)
We can output the model predictions according to the input data simply with
runner = bp.nn.RNNRunner(model, jit=True)
predictions = runner.predict(X)
predictions.shape
(1, 99995, 3)
bp.losses.mean_squared_error(predictions, Y)
DeviceArray(260.12122, dtype=float32)
Without training, the mean squared error (MSE) between the prediction and the target is large. We need to train the network so that it can produce the correct results.
Supported training algorithms#
Currently, BrainPy provides several kinds of methods to train recurrent neural networks, including
offline training algorithms, like ridge regression,
online training algorithms, like FORCE learning,
back-propagation based algorithms, like BPTT, etc.
For the full list of supported training algorithms, please see the API documentation. Here we only talk about a few of them.
RNN structural trainer RNNTrainer
#
RNNTrainer
is a structural trainer to train recurrent neural networks. Actually, it is a subclass of RNNRunner
. What’s different from RNNRunner
is that the former has one more function .fit()
to train the model.
The training data fed into the .fit()
function can be a tuple/list of an (X, Y)
pair, or a callable function which generates (X, Y)
data pairs.
If the provided training data is an (X, Y) data pair, X should be the input data with the shape of (num_sample, num_time, ...), and Y should be the target data with the shape of (num_sample, ...).
If the training data is a callable function, it should return a Python generator which yields (X, Y) data pairs for training. For example,
# when calling this function,
# it will create a Python generator.
def train_data():
num_data = 10
for _ in range(num_data):
# The (X, Y) data pair should be:
# - "X" is a tensor has the shape of
# "(num_batch, num_time, num_feature)"
# - "Y" is a tensor has the shape of
# "(num_batch, num_time, num_feature)"
# or "(num_batch, num_feature)"
xs = bm.random.rand(1, 20, 2)
ys = bm.random.random((1, 20, 2))
yield xs, ys
However, all these data constraints can be relaxed when you customize your training procedures. Please see Customization of a Network Training.
Offline training algorithms#
Offline learning means you train your network with the whole dataset at once. The supported offline learning algorithms are
bp.nn.algorithms.get_supported_offline_methods()
('ridge', 'linear', 'lstsq')
We will continue to add offline learning methods. For new advances, please check the corresponding API documentation.
Instantiating an offline learning method is simple. Once you have your network model, like the above reservoir model, you just need to pass this model into brainpy.nn.OfflineTrainer
as
model = (bp.nn.Input(3) >>
bp.nn.Reservoir(100) >>
bp.nn.LinearReadout(3))
model.initialize()
trainer = bp.nn.OfflineTrainer(
model,
fit_method=bp.nn.algorithms.RidgeRegression(beta=1e-6)
# or
# fit_method=dict(name='ridge', beta=1e-6)
)
trainer
OfflineTrainer(target=Network(LinearReadout1, Input1, Reservoir1),
jit={'fit': True, 'predict': True},
fit_method=RidgeRegression(beta=1e-06))
Let’s train the created model with the Lorenz attractor data series.
trainer.fit([X, Y])
predict = trainer.predict(X, reset=True)
predict1 = bm.as_numpy(predict)
fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict1[0, :, 0], predict1[0, :, 1], predict1[0, :, 2])
plt.title('Trained with Ridge Regression')
plt.show()

Online training algorithms#
BrainPy also supports flexible online training methods. Online learning means you train the model from a sequence of data instances one at a time. A representative online learning algorithm for recurrent neural networks is FORCE learning. Here, let's try to train the above reservoir model with the FORCE learning algorithm.
model = (bp.nn.Input(3) >>
bp.nn.Reservoir(100) >>
bp.nn.LinearReadout(3))
model.initialize()
trainer = bp.nn.OnlineTrainer(
model,
fit_method=bp.nn.algorithms.ForceLearning(alpha=0.1)
# or
# fit_method=dict(name='force', alpha=1e-1)
)
trainer
OnlineTrainer(target=Network(Input2, Reservoir2, LinearReadout2),
jit={'fit': True, 'predict': True},
fit_method=ForceLearning)
trainer.fit([X, Y])
predict2 = trainer.predict(X, reset=True)
predict2 = bm.as_numpy(predict2)
fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict2[0, :, 0], predict2[0, :, 1], predict2[0, :, 2])
plt.title('Trained with Force Learning')
plt.show()

Back-propagation algorithm#
In recent years, back-propagation has become a powerful method to train recurrent neural networks. BrainPy also supports training networks with back-propagation algorithms.
reservoir = (bp.nn.Input(3) >>
bp.nn.Reservoir(100))
reservoir.initialize()
# The reservoir node is not trainable.
# Therefore, we generate the reservoir output
# data as the input data to train readout node.
runner = bp.nn.RNNRunner(reservoir)
projections = runner.predict(X)
# The linear readout node is not a recurrent node,
# so there is no need to keep the time axis.
# Therefore, we treat the original time steps as samples.
projections = projections[0]
targets = Y[0]
readout = bp.nn.Dense(3, input_shape=100)
readout.initialize()
# Train the readout node with the back-propagation method.
# Because the Dense node is a feedforward node, we use the BPFF trainer.
trainer = bp.nn.BPFF(readout,
loss=bp.losses.mean_squared_error,
optimizer=bp.optim.Adam(lr=1e-3))
trainer
BPFF(target=Dense(name=Dense0, forwards=((None, 100),),
feedbacks=None, output=(None, 3)),
jit={'fit': True, 'predict': True, 'loss': True},
loss=<function mean_squared_error at 0x0000021BC021BC10>,
optimizer=Adam(lr=Constant(0.001), beta1=0.9, beta2=0.999, eps=1e-08))
trainer.fit([projections, targets],
num_report=2000,
num_batch=64,
num_train=40)
Train 2000 steps, use 3.4620 s, train loss 31.42259
Train 4000 steps, use 2.4532 s, train loss 14.43684
Train 6000 steps, use 2.3870 s, train loss 10.64222
Train 8000 steps, use 2.4791 s, train loss 13.16424
Train 10000 steps, use 2.3759 s, train loss 7.0941
Train 12000 steps, use 2.3584 s, train loss 7.70877
Train 14000 steps, use 2.3648 s, train loss 8.33284
Train 16000 steps, use 2.4334 s, train loss 3.79623
Train 18000 steps, use 2.3502 s, train loss 3.86504
Train 20000 steps, use 2.3463 s, train loss 3.96748
Train 22000 steps, use 2.4486 s, train loss 3.88499
Train 24000 steps, use 2.3902 s, train loss 2.47998
Train 26000 steps, use 2.3854 s, train loss 1.69119
Train 28000 steps, use 2.3613 s, train loss 1.85288
Train 30000 steps, use 2.4531 s, train loss 1.77884
Train 32000 steps, use 2.3742 s, train loss 1.95193
Train 34000 steps, use 2.3862 s, train loss 1.6745
Train 36000 steps, use 2.4662 s, train loss 1.20792
Train 38000 steps, use 2.3957 s, train loss 1.55736
Train 40000 steps, use 2.3752 s, train loss 1.36623
Train 42000 steps, use 2.3872 s, train loss 1.09453
Train 44000 steps, use 2.4989 s, train loss 0.97422
Train 46000 steps, use 2.3895 s, train loss 0.70705
Train 48000 steps, use 2.4091 s, train loss 0.8673
Train 50000 steps, use 2.3833 s, train loss 1.12951
Train 52000 steps, use 2.4962 s, train loss 1.20924
Train 54000 steps, use 2.3950 s, train loss 0.79635
Train 56000 steps, use 2.3883 s, train loss 0.62906
Train 58000 steps, use 2.4581 s, train loss 0.91307
Train 60000 steps, use 2.4038 s, train loss 0.74997
Train 62000 steps, use 2.4042 s, train loss 1.04045
# plot the training loss
plt.plot(trainer.train_losses.numpy())
plt.show()

# Finally, let's make the full model in which
# reservoir node generates the high-dimensional
# projection data, and then the linear readout
# node reads out the final value.
model = reservoir >> readout
model.initialize()
runner = bp.nn.RNNRunner(model)
predict3 = runner.predict(X)
predict3 = bm.as_numpy(predict3)
fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict3[0, :, 0], predict3[0, :, 1], predict3[0, :, 2])
plt.title('Trained with BPTT')
plt.show()

Node Customization#
To implement a custom node in BrainPy, you will have to write a Python class that subclasses brainpy.nn.Node
and implement several important methods.
import brainpy as bp
import brainpy.math as bm
from brainpy.tools.checking import check_shape_consistency
bp.math.set_platform('cpu')
Before we start, you need to know the minimal knowledge about the brainpy.nn.Node
. Please see the tutorial of Node Specification.
Customizing a feedforward node#
In general, the variable initialization and the logic computation of each node in brainpy.nn
module are separated from each other. If not, applying JIT compilation to these nodes will be difficult.
If your node only has feedforward connections,

you need to implement two functions:
init_ff(): This function aims to initialize the feedforward connections and compute the output shape according to the given feedforward_shapes.
forward(): This function implements the main computation logic of the node. It may calculate the new state of the node. Most importantly, this function should return the output value for the feedforward data flow.
To show how this can be used, here is a node that multiplies its input by a matrix W
(much like a typical fully connected layer in a neural network would). This matrix is a parameter of the layer. The shape of the matrix will be (num_input, num_unit), where num_input is the number of input features and num_unit is the number of output features.
class DotNode(bp.nn.Node):
def __init__(self, num_unit, W_initializer=bp.initialize.Normal(), **kwargs):
super(DotNode, self).__init__(**kwargs)
self.num_unit = num_unit
self.W_initializer = W_initializer
def init_ff(self):
# This function should compute the output shape and
# the feedforward (FF) connections
# 1. First, due to multiple FF shapes, we need to know
# the total shape when all FF inputs are concatenated.
# Function "check_shape_consistency()" may help you
# solve this problem quickly.
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
# 2. Initialize the weight W
weight_shape = (sum(free_sizes), self.num_unit)
self.W = bp.nn.init_param(self.W_initializer, weight_shape)
# If the user wants to train this node, we need to mark the
# weight as a "brainpy.math.TrainVar"
if self.trainable:
self.W = bm.TrainVar(self.W)
# 3. Set the output shape
self.set_output_shape(unique_size + (self.num_unit,))
def forward(self, ff):
# 1. First, we concatenate all FF inputs
ff = bm.concatenate(ff, axis=-1)
# 2. Then, we multiply the input with the weight
return bm.dot(ff, self.W)
A few things are worth noting here: when overriding the constructor, we need to call the superclass constructor on the first line. This is important to ensure the node functions properly. Note that we pass **kwargs
- although this is not strictly necessary, it enables some other cool features, such as making it possible to give the layer a name:
DotNode(10, name='my_dot_node')
DotNode(name=my_dot_node, trainable=False, forwards=None, feedbacks=None,
output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)
Or, set this node trainable:
DotNode(10, trainable=True)
DotNode(name=DotNode0, trainable=True, forwards=None, feedbacks=None,
output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)
Once we create this DotNode
, we can connect multiple feedforward nodes to its instance.
l = DotNode(10)
i1 = bp.nn.Input(1, name='i1')
i2 = bp.nn.Input(2, name='i2')
i3 = bp.nn.Input(3, name='i3')
net = {i1, i2, i3} >> l
net.plot_node_graph(fig_size=(4, 4), node_size=2000)

net.initialize(num_batch=1)
# given an input, let's compute its output
net({'i1': bm.ones((1, 1)),
'i2': bm.zeros((1, 2)),
'i3': bm.random.random((1, 3))})
JaxArray([[-0.41227022, -1.2145127 , 1.2915486 , -1.7037894 ,
0.47149402, -1.9161812 , 1.3631151 , -0.4410456 ,
1.9460022 , 0.54992586]], dtype=float32)
Customizing a recurrent node#
If your node is a recurrent node, which means it has its own state
and self-to-self connection weights,

this time, you need to implement one more function:
init_state(num_batch)
: This function aims to initialize the Node state which depends on the batch size.
Furthermore, we recommend that users' recurrent nodes inherit from brainpy.nn.RecurrentNode
, because this instructs BrainPy that the node has recurrent connections.
Here, let’s try to implement a Vanilla RNN model.
class VRNN(bp.nn.RecurrentNode):
def __init__(self, num_unit,
wi_initializer=bp.init.XavierNormal(),
wr_initializer=bp.init.XavierNormal(), **kwargs):
super(VRNN, self).__init__(**kwargs)
self.num_unit = num_unit
self.wi_initializer = wi_initializer
self.wr_initializer = wr_initializer
def init_ff(self):
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
num_input = sum(free_sizes)
self.wi = bp.nn.init_param(self.wi_initializer, (num_input, self.num_unit))
self.wr = bp.nn.init_param(self.wr_initializer, (self.num_unit, self.num_unit))
if self.trainable:
self.wi = bm.TrainVar(self.wi)
self.wr = bm.TrainVar(self.wr)
def init_state(self, num_batch=1):
state = bm.zeros((num_batch, self.num_unit))
self.set_state(state)
def forward(self, ff):
ff = bm.concatenate(ff, axis=-1)
state = ff @ self.wi + self.state @ self.wr
self.state.value = state
return state
Customizing a node with feedbacks#
Creating a layer that receives multiple feedback inputs is handled in the same way as feedforward connections.

Users need to implement one more function, that is:
init_fb()
: This function aims to initialize the feedback information, including the feedback connections, feedback weights, and others.
For the above DotNode
, if you want to support feedback connections, you can define the model as follows:
class FeedBackDotNode(bp.nn.Node):
def __init__(self, num_unit, W_initializer=bp.initialize.Normal(), **kwargs):
super(FeedBackDotNode, self).__init__(**kwargs)
self.num_unit = num_unit
self.W_initializer = W_initializer
def init_ff(self):
# 1. FF shapes
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
# 2. Initialize the feedforward weight Wff
weight_shape = (sum(free_sizes), self.num_unit)
self.Wff = bp.nn.init_param(self.W_initializer, weight_shape)
if self.trainable:
self.Wff = bm.TrainVar(self.Wff)
# 3. Set the output shape
self.set_output_shape(unique_size + (self.num_unit,))
def init_fb(self):
# 1. FB shapes
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)
# 2. Initialize the feedback weight Wfb
weight_shape = (sum(free_sizes), self.num_unit)
self.Wfb = bp.nn.init_param(self.W_initializer, weight_shape)
if self.trainable:
self.Wfb = bm.TrainVar(self.Wfb)
def forward(self, ff, fb=None):
ff = bm.concatenate(ff, axis=-1)
res = bm.dot(ff, self.Wff)
if fb is not None:
fb = bm.concatenate(fb, axis=-1)
res += bm.dot(fb, self.Wfb)
return res
Note the difference between DotNode
and FeedBackDotNode
. The forward()
function of the latter has one argument fb=None
, which means if this node has feedback connections, it will pass all feedback inputs to fb
argument.
Note
Establishing a feedback connection to a node that does not support feedback will raise an error.
try:
DotNode(1) << bp.nn.Input(1)
except Exception as e:
print(e.__class__)
print(e)
<class 'ValueError'>
Establish a feedback connection to
DotNode(name=DotNode2, trainable=False, forwards=None, feedbacks=None,
output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)
is not allowed. Because this node does not support feedback connections.
FeedBackDotNode(1) << bp.nn.Input(1)
Network(FeedBackDotNode0, Input1)
Customizing a node with multiple behaviors#
Some nodes can have multiple behaviors. For example, a node implementing dropout should be able to be switched on or off. During training, we want it to apply dropout noise to its input and scale up the remaining values, but during evaluation we don’t want it to do anything.
For this purpose, the forward()
method takes optional keyword arguments (kwargs
). When forward()
is called to compute an expression for the output of a network, all specified keyword arguments are passed to the forward()
methods of all layers in the network.
class Dropout(bp.nn.Node):
def __init__(self, prob, seed=None, **kwargs):
super(Dropout, self).__init__(**kwargs)
self.prob = prob
self.rng = bm.random.RandomState(seed=seed)
def init_ff(self):
assert len(self.feedforward_shapes) == 1, 'Only support one feedforward input.'
self.set_output_shape(self.feedforward_shapes[0])
def forward(self, ff, **kwargs):
assert len(ff) == 1, 'Only support one feedforward input.'
if kwargs.get('train', True):
keep_mask = self.rng.bernoulli(self.prob, ff[0].shape)
return bm.where(keep_mask, ff[0] / self.prob, 0.)
else:
return ff[0]
The Dropout
node only supports one feedforward input. Therefore, we add checks at the beginning of the init_ff()
and forward()
functions.
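As a hypothetical usage sketch, assuming (as described above) that keyword arguments given at call time are forwarded to forward(), and that input_shape is accepted the same way as in the built-in nodes, switching the behavior could look like:
l = Dropout(prob=0.8, input_shape=10)
l.initialize(num_batch=1)

x = bm.ones((1, 10))
y_train = l(x, train=True)    # dropout noise applied and the kept values rescaled
y_eval = l(x, train=False)    # identity mapping during evaluation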
Customization of a Network Training#
Dynamics Analysis#
Low-dimensional Analyzers#
We have talked about model simulation and training for dynamical systems with BrainPy. In this tutorial, we are going to dive into how to perform automatic analysis for your defined systems.
Dynamics analysis is necessary in neurodynamics, because blind simulation of nonlinear systems is likely to produce few or misleading results. BrainPy provides good support for low-dimensional systems, no matter how nonlinear your defined system is. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:
phase plane analysis;
codimension 1 or codimension 2 bifurcation analysis;
bifurcation analysis of the fast-slow system.
BrainPy will help you probe the dynamical mechanism of your defined systems rapidly.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
bp.math.enable_x64() # It's better to enable x64 when performing analysis
import numpy as np
import matplotlib.pyplot as plt
A simple case#
Here we test BrainPy with a simple case:
\[\frac{dx}{dt} = \sin(x) + I,\]
where \(x \in [-10, 10]\).
As is known to us all, this function has multiple fixed points (\(\frac{dx}{dt} = 0\)) when \(I=0\).
xs = np.arange(-10, 10, 0.01)
plt.plot(xs, np.sin(xs))
plt.scatter([-3*np.pi, -1*np.pi, 1*np.pi, 3 * np.pi], np.zeros(4), s=80, edgecolors='y')
plt.scatter([-2*np.pi, 0, 2*np.pi], np.zeros(3), s=80, facecolors='none', edgecolors='r')
plt.axhline(0)
plt.show()

According to dynamical theory, the red hollow points are unstable fixed points, while the solid ones are stable fixed points.
Now let’s come back to BrainPy, and test whether BrainPy can give us the right answer.
As the analysis interfaces in BrainPy only receive an ODEIntegrator or an instance of DynamicalSystem, we first define an integrator with BrainPy (if you want to know how to define an ODE integrator, please refer to the tutorial of Numerical Solvers for ODEs):
@bp.odeint
def int_x(x, t, Iext):
return bp.math.sin(x) + Iext
This is a one-dimensional dynamical system. So we are trying to use brainpy.analysis.PhasePlane1D for phase plane analysis. The usage of phase plane analysis will be detailed in the following section. Now, we just focus on the following four arguments:
model: It specifies the target system to analyze. It can be a list/tuple of ODEIntegrator, or an instance of DynamicalSystem. For a DynamicalSystem argument, we will use model.ints().subset(bp.ode.ODEIntegrator) to retrieve all instances of ODEIntegrator later.
target_vars: It specifies the variables to analyze. It must be a dict with the format of <var_name, var_interval>, where var_name is the variable name, and var_interval is the boundary of this variable.
pars_update: Parameters to update.
resolutions: The resolution to evaluate the fixed points.
Let’s try it.
pp = bp.analysis.PhasePlane1D(
model=int_x,
target_vars={'x': [-10, 10]},
pars_update={'Iext': 0.},
resolutions=0.001
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)
I am creating vector fields ...
I am searching fixed points ...
Fixed point #1 at x=-9.42477796076938 is a stable point.
Fixed point #2 at x=-6.283185307179586 is a unstable point.
Fixed point #3 at x=-3.141592653589793 is a stable point.
Fixed point #4 at x=9.237056486678452e-19 is a unstable point.
Fixed point #5 at x=3.141592653589793 is a stable point.
Fixed point #6 at x=6.283185307179586 is a unstable point.
Fixed point #7 at x=9.42477796076938 is a stable point.

Indeed, brainpy.analysis.PhasePlane1D
gives us the right fixed points and correctly evaluates their stability.
Phase plane analysis is important because it gives us an intuitive understanding of how the system evolves with the given parameters. However, in most cases where we care about how parameters affect the system's behavior, we should perform bifurcation analysis. brainpy.analysis.Bifurcation1D is a convenient interface to help you get insights into how the dynamics of a 1D system change with parameters.
Similar to brainpy.analysis.PhasePlane1D
, brainpy.analysis.Bifurcation1D
receives arguments like “model”, “target_vars”, “pars_update”, and “resolutions”. Besides, one more important argument “target_pars” should be provided, which specifies the range of the target parameter in bifurcation analysis.
Here, we systematically change the parameter “Iext” from 0 to 1.5. According to bifurcation theory, we know this simple system has a fold bifurcation at \(I=1.0\), because at \(I=1.0\) two fixed points collide into a saddle point and then disappear. Is BrainPy's analysis toolkit brainpy.analysis.Bifurcation1D
capable of performing this analysis? Let's give it a try.
bif = bp.analysis.Bifurcation1D(
model=int_x,
target_vars={'x': [-10, 10]},
target_pars={'Iext': [0., 1.5]},
resolutions=0.001
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...

Once again, BrainPy's analysis toolkit gives the right answer. It tells us how the fixed points evolve as the parameter \(I\) increases.
It is worth noting that bifurcation analysis in BrainPy can hardly find the saddle point (at \(I=1\) for this system). This is because the saddle point at the bifurcation only exists for a moment, and the numerical methods used in BrainPy's analysis toolkit can hardly evaluate the point exactly at the saddle. However, if the user has minimal knowledge of bifurcation theory, the saddle point (the collision point of two fixed points) can easily be inferred from the fixed point evolution.
BrainPy's analysis toolkit is highly useful, especially when the mathematical equations are too complex to obtain analytical solutions. For an example, please refer to the tutorial Analysis of a Decision Making Model.
Phase plane analysis#
Phase plane analysis is one of the most important techniques for studying the behavior of nonlinear systems, since there is usually no analytical solution for a nonlinear system. BrainPy can help users to plot the phase plane of 1D or 2D systems. Specifically, we provide brainpy.analysis.PhasePlane1D and brainpy.analysis.PhasePlane2D. They can help to plot:
Nullclines: The zero-growth isoclines, such as \(f(x, y)=0\) and \(g(x, y)=0\).
Fixed points: The equilibrium points of the system, which are located where all of the nullclines intersect.
Vector field: The vector field of the system.
Limit cycles: The limit cycles.
Trajectories: A simulation trajectory with the given initial values.
We have talked about brainpy.analysis.PhasePlane1D
above. Now we focus on brainpy.analysis.PhasePlane2D
by using the well-known FitzHugh-Nagumo neuron model.
The FitzHugh-Nagumo model is given by:
\[\frac{dV}{dt} = V - \frac{V^3}{3} - w + I_{ext}, \qquad \frac{dw}{dt} = \frac{V + a - b w}{\tau}.\]
There are two variables \(V\) and \(w\), so this is a two-dimensional system with three parameters \(a, b\) and \(\tau\).
Users can define the system to analyze either by using pure brainpy.odeint
or as a subclass of DynamicalSystem
. For this FitzHugh-Nagumo model, we define it as a class because later we will perform a simulation to verify the analysis results.
class FitzHughNagumoModel(bp.DynamicalSystem):
def __init__(self, method='exp_auto'):
super(FitzHughNagumoModel, self).__init__()
# parameters
self.a = 0.7
self.b = 0.8
self.tau = 12.5
# variables
self.V = bm.Variable(bm.zeros(1))
self.w = bm.Variable(bm.zeros(1))
self.Iext = bm.Variable(bm.zeros(1))
# functions
def dV(V, t, w, Iext=0.):
return V - V * V * V / 3 - w + Iext
def dw(w, t, V, a=0.7, b=0.8):
return (V + a - b * w) / self.tau
self.int_V = bp.odeint(dV, method=method)
self.int_w = bp.odeint(dw, method=method)
def update(self, _t, _dt):
self.V.value = self.int_V(self.V, _t, self.w, self.Iext, _dt)
self.w.value = self.int_w(self.w, _t, self.V, self.a, self.b, _dt)
self.Iext[:] = 0.
model = FitzHughNagumoModel()
Here we perform a phase plane analysis with parameters \(a=0.7, b=0.8, \tau=12.5\), and input \(I_{ext} = 0.8\).
pp = bp.analysis.PhasePlane2D(
model,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
pars_update={'Iext': 0.8},
resolutions=0.01,
)
# By default, nullclines will be plotted as points,
# but we can set the plot style to lines
pp.plot_nullcline(x_style={'fmt': '-'}, y_style={'fmt': '-'})
# The vector field can be plotted in two ways:
# - plot_method="streamplot" (default)
# - plot_method="quiver"
pp.plot_vector_field()
# There are many ways to search fixed points.
# By default, it will use the nullcline points of the first
# variable ("V") as the initial points to perform fixed point searching
pp.plot_fixed_point()
# Trajectory plotting receives the setting of the initial points.
# There may be multiple trajectories, therefore the initial points
# should be provided as a list/tuple/numpy.ndarray/JaxArray
pp.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
# show the phase plane figure
pp.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating vector fields ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 866 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 V=-0.2738719079879798, w=0.5329731346879486 is a unstable node.
I am plot trajectory ...

We can see an unstable node at the point (\(V=-0.27, w=0.53\)) inside a limit cycle.
We can run a simulation with the same parameters and initial values to verify the periodic activity that corresponds to the limit cycle.
runner = bp.StructRunner(model, monitors=['V', 'w'], inputs=['Iext', 0.8])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)

Understanding settings#
There are several key settings that need to be understood.
resolutions
#
resolutions
is one of the most important parameters in the PhasePlane and Bifurcation analysis toolkits of BrainPy, because it has a profound impact on the efficiency of model analysis.
We can set resolutions
in the following ways:
None. If we detect there is no resolution setting for any variable, the corresponding resolution for this variable will be \(\frac{\mathrm{max\_value} - \mathrm{min\_value}}{20}\).
A float. It sets the same resolution for every target variable and parameter.
A dict. It specifies different resolutions for individual variables/parameters. Each value can be a float, or a vector in the format of a JaxArray or numpy.ndarray (see the examples below).
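For example, a single float or a per-variable dict can be passed as follows; the model here is the FitzHugh-Nagumo class defined above, and the concrete values are only illustrative.
# a single float: the same resolution for every variable and parameter
pp = bp.analysis.PhasePlane2D(model,
                              target_vars={'V': [-3, 3], 'w': [-3., 3.]},
                              pars_update={'Iext': 0.8},
                              resolutions=0.05)

# a dict: different resolutions for individual variables/parameters
bif = bp.analysis.Bifurcation2D(model,
                                target_vars={'V': [-3, 3], 'w': [-3., 3.]},
                                target_pars={'Iext': [0., 1.]},
                                resolutions={'V': 0.1, 'w': 0.1, 'Iext': 0.002})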
Setting resolutions
with a tensor gives the user maximal flexibility. Usually, the numerical analysis does not work well at inflection points. Therefore, we can increase the granularity near the inflection points. For example, if there is an inflection point at \(1\), we can set the resolution with:
r1 = bm.arange(0.00, 0.95, 0.01)
r2 = bm.arange(0.95, 1.01, 0.001)
r3 = bm.arange(1.05, 1.50, 0.01)
resolution = bm.concatenate([r1, r2, r3])
Tips: For bifurcation analysis, we usually need to set a small resolution for parameters, leaving the resolutions of variables as the default. Please see the following examples.
vars
and pars
#
What can be set as variables *_vars
or parameters *_pars
(such as target_vars
or target_pars
) for further analysis? Actually, the variables and parameters are recognized in the same way as in the programming paradigm of ODE numerical integrators. Simply speaking, the arguments before t
are treated as variables, while the arguments after t
are treated as parameters.
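For instance, in the integrator int_x defined earlier, the split looks like this (the comments are annotations added here for illustration):
@bp.odeint
def int_x(x, t, Iext):
    # "x" comes before t   -> treated as a variable (it can appear in target_vars)
    # "Iext" comes after t -> treated as a parameter (it can appear in pars_update or target_pars)
    return bp.math.sin(x) + Iext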
BrainPy's analysis toolkit only supports one variable per differential equation. It cannot analyze a joint differential equation in which multiple variables are defined in the same function.
Moreover, the low-dimensional analyzers in BrainPy cannot analyze dynamical systems that depend on time \(t\).
Bifurcation analysis#
Nonlinear dynamical systems are characterized by their parameters. When a parameter changes, the system's behavior may change qualitatively. Therefore, we care about how the system changes with the smooth change of parameters.
Codimension 1 bifurcation analysis
We will first perform the codimension 1 bifurcation analysis of the model. For example, we vary the input \(I_{ext}\) between 0 and 1 and see how the system changes its stability.
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
target_pars={'Iext': [0., 1.]},
resolutions={'Iext': 0.002},
)
# "num_rank" specifies the number of initial poinits for
# fixed point optimization under a set of parameters
analyzer.plot_bifurcation(num_rank=10)
# show figure
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 5000 candidates
I am trying to filter out duplicate fixed points ...
Found 500 fixed points.


Codimension 2 bifurcation analysis
We simultaneously change \(I_{ext}\) and the parameter \(a\).
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars=dict(V=[-3, 3], w=[-3., 3.]),
target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
resolutions={'a': 0.01, 'Iext': 0.01},
)
analyzer.plot_bifurcation(num_rank=10, tol_aux=1e-9)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 50000 candidates
I am trying to filter out duplicate fixed points ...
Found 5000 fixed points.


Fast-slow system bifurcation#
BrainPy also provides tools for fast-slow system bifurcation analysis, namely brainpy.analysis.FastSlow1D and brainpy.analysis.FastSlow2D. This method was proposed by John Rinzel [1, 2, 3] (J Rinzel, 1985, 1986, 1987), who suggested that in a fast-slow dynamical system, we can treat the slow variables as bifurcation parameters and then study how different values of the slow variables affect the bifurcation of the fast sub-system.
Fast-slow bifurcation methods are very useful in bursting neuron analysis. We will illustrate this by using the Hindmarsh-Rose model. The Hindmarsh–Rose model of neuronal activity aims to study the spiking-bursting behavior of the membrane potential observed in experiments with a single neuron. Its dynamics are governed by:
\[\frac{dx}{dt} = y - a x^3 + b x^2 - z + I_{syn}, \qquad \frac{dy}{dt} = c - d x^2 - y, \qquad \frac{dz}{dt} = r\left[s(x - x_r) - z\right].\]
First of all, let's define the Hindmarsh–Rose model with BrainPy.
a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9
@bp.odeint
def int_x(x, t, y, z, Isyn):
return y - a * x ** 3 + b * x * x - z + Isyn
@bp.odeint
def int_y(y, t, x):
return c - d * x * x - y
@bp.odeint
def int_z(z, t, x):
return r * (s * (x - x_r) - z)
We can now start to analyze the underlying bifurcation mechanism.
analyzer = bp.analysis.FastSlow2D(
[int_x, int_y, int_z],
fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
slow_vars={'z': [-5., 5.]},
pars_update={'Isyn': 0.5},
resolutions={'z': 0.01}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 20000 candidates
I am trying to filter out duplicate fixed points ...
Found 1232 fixed points.


Advanced tutorial: how the analysis works#
In this section, we provide a basic tutorial to understand how brainpy.analysis.LowDimAnalyzer
works.
Terminology#
Given the above FitzHugh-Nagumo model, we define an analyzer,
analyzer = bp.analysis.PhasePlane2D(
[model.int_V, model.int_w],
target_vars={'V': [-3, 3], 'w': [-3., 3.]},
resolutions=0.01,
)
In this instance of brainpy.analysis.LowDimAnalyzer
, we use the following terminologies.
x_var and y_var are defined by the order of the user setting. If the user sets “target_vars” as “{‘V’: …, ‘w’: …}”, x_var and y_var will be “V” and “w”, respectively. Otherwise, if “target_vars” is “{‘w’: …, ‘V’: …}”, x_var and y_var will be “w” and “V”, respectively.
analyzer.x_var, analyzer.y_var
('V', 'w')
fx and fy are defined as differential equations of
x_var
andy_var
respectively, i.e.,
fx
is
def dV(V, t, w, Iext=0.):
return V - V * V * V / 3 - w + Iext
fy
is
def dw(w, t, V, a=0.7, b=0.8):
return (V + a - b * w) / self.tau
analyzer.F_fx, analyzer.F_fy
(<CompiledFunction of <function f_without_jaxarray_return.<locals>.f2 at 0x000002240FC114C0>>,
<CompiledFunction of <function f_without_jaxarray_return.<locals>.f2 at 0x000002240FC118B0>>)
int_x and int_y are defined as integral functions of the differential equations for
x_var
andy_var
respectively.
analyzer.F_int_x, analyzer.F_int_y
(functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x000002240FC11EE0>),
functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x000002240FC11F70>))
x_by_y_in_fx and y_by_x_in_fx: they denote that x_var and y_var can be separated from each other in the “fx” nullcline function. Specifically, x_by_y_in_fx or y_by_x_in_fx denotes \(x = F(y)\) or \(y = F(x)\) according to the \(f_x=0\) equation. For example, in the above FitzHugh-Nagumo model, \(w\) can easily be represented by \(V\) when \(\mathrm{dV(V, t, w, I_{ext})} = 0\), i.e., y_by_x_in_fx is \(w= V - V ^3 / 3 + I_{ext}\).
Similarly, x_by_y_in_fy (\(x=F(y)\)) and y_by_x_in_fy (\(y=F(x)\)) denote that x_var and y_var can be separated from each other in the “fy” nullcline function. For example, in the above FitzHugh-Nagumo model, y_by_x_in_fy is \(w= \frac{V + a}{b}\), and x_by_y_in_fy is \(V= b * w - a\).
x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy and y_by_x_in_fy can be set in the options argument.
Mechanism for 1D system analysis#
In order to understand the advantages and disadvantages of BrainPy's analysis toolkit, it is better to know the minimal mechanism of how brainpy.analysis
works.
The automatic model analysis in BrainPy heavily relies on numerical optimization methods, including Brent’s method and BFGS method. For example, for the above one-dimensional system (\(\frac{dx}{dt} = \mathrm{sin}(x) + I\)), after the user sets the resolution to 0.001
, we will get the evaluation points according to the variable boundary [-10, 10]
.
bp.math.arange(-10, 10, 0.001)
JaxArray(DeviceArray([-10. , -9.999, -9.998, ..., 9.997, 9.998, 9.999], dtype=float64))
Then, BrainPy filters out the candidate intervals in which the roots lie. Specifically, it tries to find all intervals like \([x_1, x_2]\) where \(f(x_1) * f(x_2) \le 0\) for the 1D system \(\frac{dx}{dt} = f(x)\).
For example, the following two points which have opposite signs are candidate points we want.
def plot_interval(x0, x1, f):
xs = np.linspace(x0, x1, 100)
plt.plot(xs, f(xs))
plt.scatter([x0, x1], f(np.asarray([x0, x1])), edgecolors='r')
plt.axhline(0)
plt.show()
plot_interval(-0.001, 0.001, lambda x: np.sin(x))

According to the intermediate value theorem, there must be a solution between \(x_1\) and \(x_2\) when \(f(x_1) * f(x_2) \le 0\).
Based on these candidate intervals, BrainPy uses Brent's method to find the roots of \(f(x) = 0\). Further, after obtaining the value of a root, BrainPy uses automatic differentiation to evaluate its stability.
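To make this mechanism concrete, here is a rough sketch of the idea using SciPy (this is not BrainPy's internal implementation, and the stability check via the sign of a numerical derivative is only illustrative):
import numpy as np
from scipy.optimize import brentq

f = lambda x: np.sin(x)             # right-hand side of dx/dt = sin(x) + I, with I = 0
xs = np.arange(-10, 10, 0.001)      # evaluation points from the resolution setting
ys = f(xs)

roots = []
for x1, x2, y1, y2 in zip(xs[:-1], xs[1:], ys[:-1], ys[1:]):
    if y1 * y2 < 0:                 # candidate interval with a sign change
        roots.append(brentq(f, x1, x2))

for root in roots:
    # a fixed point is stable if df/dx < 0 there (approximated numerically)
    slope = (f(root + 1e-6) - f(root - 1e-6)) / 2e-6
    kind = 'stable' if slope < 0 else 'unstable'
    print(f'x = {root:.4f} is a {kind} point')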
Overall, BrainPy's analysis toolkit has clear advantages and disadvantages.
Pros: Because BrainPy uses numerical methods to find roots and evaluate their stability, it does not care about how complex your function is. Therefore, it can be applied to general problems, including any 1D and 2D dynamical systems, and some low-dimensional (\(\ge 3\)) dynamical systems (see later sections). In particular, BrainPy's analysis toolkit is highly useful when the mathematical equations are too complex to obtain analytical solutions (for an example, please refer to the tutorial Analysis of a Decision Making Model).
Cons: However, the numerical methods used in BrainPy can hardly find fixed points that only exist for a moment. Moreover, when resolution
is small, there will be a large amount of computation. Users should pay attention to designing suitable resolution settings.
Mechanism for 2D system analysis#
plot_vector_field()
Plotting the vector field is simple. We just need to evaluate the value of each differential equation over a grid of points.
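For instance, a bare-bones version of this step for the FitzHugh-Nagumo system above could look like the following sketch (plain NumPy/Matplotlib, with the same parameter values as used earlier):
import numpy as np
import matplotlib.pyplot as plt

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.8
V, w = np.meshgrid(np.arange(-3, 3, 0.25), np.arange(-3, 3, 0.25))
dV = V - V ** 3 / 3 - w + Iext       # dV/dt on the grid
dw = (V + a - b * w) / tau           # dw/dt on the grid

plt.quiver(V, w, dV, dw)
plt.xlabel('V')
plt.ylabel('w')
plt.show()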
plot_nullcline()
Nullclines are evaluated through Brent's method. In order to get all \((x, y)\) values that satisfy fx=0
(i.e., \(f_x(x, y) = 0\)), we first fix \(y=y_0\), then apply Brent optimization to get all \(x'\) that satisfy \(f_x(x', y_0) = 0\) (alternatively, we can fix \(x\) and then optimize \(y\)). Therefore, we will perform Brent optimization many times, because we iterate over all \(y\) values according to the resolution setting.
plot_fixed_points()
The fixed point finding in BrainPy relies on the BFGS method. First, we define an auxiliary scalar function, the squared residual of the two derivative functions: \(L(x, y) = f_x(x, y)^2 + f_y(x, y)^2\).
\(L(x, y)\) is always non-negative. We use BFGS optimization to find all local minima, and finally we keep the minima whose losses are smaller than \(10^{-8}\) as the fixed points.
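A minimal sketch of this search (not BrainPy internals), again using the FitzHugh-Nagumo equations with illustrative parameters and random candidate points:
import numpy as np
from scipy.optimize import minimize
a, b, tau, Iext = 0.7, 0.8, 12.5, 0.5
f_V = lambda V, w: V - V ** 3 / 3 - w + Iext
f_w = lambda V, w: (V + a - b * w) / tau
L = lambda p: f_V(p[0], p[1]) ** 2 + f_w(p[0], p[1]) ** 2   # auxiliary function
fixed_points = []
for x0 in np.random.uniform(-3, 3, size=(20, 2)):           # candidate initial points
    res = minimize(L, x0, method='BFGS')                    # local minimization
    if res.fun < 1e-8:                                      # keep only true minima
        fixed_points.append(res.x)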
For this method, how to choose the initial points for optimization is the challenge, especially when the parameter resolutions are small. Generally, there are four methods provided in BrainPy:
fx-nullcline: Choose the points on the “fx” nullcline as the initial points for optimization.
fy-nullcline: Choose the points on the “fy” nullcline as the initial points for optimization.
nullclines: Choose the points on both the “fx” nullcline and the “fy” nullcline as the initial points for optimization.
aux_rank: For a given set of parameters, we evaluate the loss function at each point according to the resolution setting. Then we choose the first num_rank (default is 100) points which have the smallest losses.
However, if users provide one of the functions x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy or y_by_x_in_fy, things become very simple, because we can convert the 2D system into a 1D system and then only need to optimize the fixed points with Brent optimization.
For the given FitzHugh-Nagumo model, we can set
analyzer = bp.analysis.Bifurcation2D(
model,
target_vars=dict(V=[-3, 3], w=[-3., 3.]),
target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
resolutions={'a': 0.01, 'Iext': 0.01},
options={bp.analysis.C.y_by_x_in_fy: (lambda V, a=0.7, b=0.8: (V + a) / b)}
)
analyzer.plot_bifurcation()
analyzer.show_figure()
I am making bifurcation analysis ...
I am trying to find fixed points by brentq optimization ...
I am trying to filter out duplicate fixed points ...
Found 5000 fixed points.


References#
[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.
[2] Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.
[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.
High-dimensional Analyzers#
It is hard to analyze high-dimensional systems, yet we often have to.
Here, based on numerical optimization methods, BrainPy provides brainpy.analysis.SlowPointFinder to help users find slow points (or fixed points) [1] of high-dimensional dynamical systems.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
What are slow points?#
For a given system \(\dot{x} = f(x)\),
we wish to find values \(x^*\) around which the system is approximately linear. Using a Taylor series expansion around \(x^*\), we have
\(f(x^* + \delta x) = f(x^*) + f'(x^*)\, \delta x + \frac{1}{2} \delta x^{\top} f''(x^*)\, \delta x + \cdots\)
We want the first derivative term (i.e., the linear term) to be dominant, which means \(f(x^*) = 0\) or \(f(x^*) \approx 0\).
For \(f(x^*) \approx 0\), which is nonzero but small, we call the point \(x^*\) a slow point.
More specifically, if \(f(x^*) = 0\), \(x^*\) is a fixed point.
How to find slow points?#
In order to find slow points, we can first define an auxiliary scalar function for your continuous system \(\dot{x} = f(x)\): \(p(x) = \frac{1}{2} |f(x)|^2\).
Or, if your system is discrete, \(x_n = f(x_{n-1})\), the auxiliary scalar function can be defined as \(p(x) = \frac{1}{2} |x - f(x)|^2\).
If \(x^*\) is a slow point, \(p(x^*) \to 0\).
Then, by minimizing the scalar function \(p(x)\), we can get candidate slow points and linearize the system around them. The stability of each linearized system is evaluated by the eigenvalues of its Jacobian matrix.
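A plain-JAX sketch of these steps (the SlowPointFinder API below automates them); the 2D system f used here is a hypothetical example chosen only for illustration:
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize
def f(x):                                        # dx/dt = f(x), a hypothetical system
    return jnp.array([-x[0] + jnp.tanh(x[1]),
                      -x[1] + jnp.tanh(x[0])])
p = lambda x: 0.5 * jnp.sum(f(x) ** 2)           # auxiliary scalar function p(x)
grad_p = jax.grad(p)
res = minimize(lambda x: float(p(x)),            # minimize p(x) from a candidate point
               x0=np.array([0.1, -0.2]),
               jac=lambda x: np.asarray(grad_p(x)),
               method='BFGS')
J = jax.jacobian(f)(jnp.asarray(res.x))          # linearize around the slow point
print(res.x, np.linalg.eigvals(np.asarray(J)))   # stability from Jacobian eigenvalues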
Here, BrainPy provides brainpy.analysis.SlowPointFinder. It receives a function f_cell which defines \(f(x)\), and f_type which specifies the type of the function (it can be “continuous” or “discrete”). Then, brainpy.analysis.SlowPointFinder can help you:
optimize to find the fixed/slow points with gradient descent algorithms (find_fps_with_gd_method()) or with a nonlinear optimization solver (find_fps_with_opt_solver())
exclude any fixed points whose losses are above a threshold: filter_loss()
exclude any non-unique fixed points according to a tolerance: keep_unique()
exclude any far-away “outlier” fixed points: exclude_outliers()
compute the Jacobian matrices for the given fixed/slow points: compute_jacobians()
Example 1: Decision Making Model#
brainpy.analysis.SlowPointFinder
is aimed at finding slow/fixed points of high-dimensional systems. Of course, it can also be used to find fixed points of low-dimensional systems. We take the 2D decision-making system as an example.
# parameters
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154
JE = 0.3725 # self-coupling strength [nA]
JI = -0.1137 # cross-coupling strength [nA]
JAext = 0.00117 # Stimulus input strength [nA]
mu = 20. # Stimulus firing rate [spikes/sec]
coh = 0.5 # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
return - s1 / tau + (1. - s1) * gamma * r1
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
return - s2 / tau + (1. - s2) * gamma * r2
def step(s):
ds1 = int_s1.f(s[0], 0., s[1])
ds2 = int_s2.f(s[1], 0., s[0])
return bm.asarray([ds1.value, ds2.value])
We first use brainpy.analysis.PhasePlane2D to get the standard answer.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
resolutions=0.001,
)
analyzer.plot_fixed_point(select_candidates='aux_rank', with_plot=False)
I am searching fixed points ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 100 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.28276315331459045, s2=0.40635165572166443 is a saddle node.
#2 s1=0.013946513645350933, s2=0.6573889851570129 is a stable node.
#3 s1=0.7004519104957581, s2=0.004864314571022987 is a stable node.
Then, let’s check whether the high-dimensional analyzer also works.
finder = bp.analysis.SlowPointFinder(f_cell=step)
finder.find_fps_with_gd_method(
candidates=bm.random.random((1000, 2)), tolerance=1e-5, num_batch=200,
opt_setting=dict(method=bm.optimizers.Adam,
lr=bm.optimizers.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(1e-5)
finder.keep_unique()
Optimizing with Adam to find fixed points:
Batches 1-200 in 0.52 sec, Training loss 0.0576312058
Batches 201-400 in 0.52 sec, Training loss 0.0049517932
Batches 401-600 in 0.53 sec, Training loss 0.0007580096
Batches 601-800 in 0.52 sec, Training loss 0.0001687836
Batches 801-1000 in 0.51 sec, Training loss 0.0000421500
Batches 1001-1200 in 0.52 sec, Training loss 0.0000108371
Batches 1201-1400 in 0.52 sec, Training loss 0.0000027990
Stop optimization as mean training loss 0.0000027990 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
Kept 962/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
Kept 3/962 unique fixed points with uniqueness tolerance 0.025.
finder.fixed_points
array([[0.7004518 , 0.00486438],
[0.28276336, 0.40635186],
[0.01394662, 0.65738887]], dtype=float32)
Yeah, the fixed points found by brainpy.analysis.PhasePlane2D
and brainpy.analysis.SlowPointFinder
are nearly the same.
Example 2: Continuous-attractor Neural Network#
The continuous-attractor neural network (CANN) [2] proposed by Si Wu is a special model which has a line of attractors.
class CANN1D(bp.NeuGroup):
def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-bm.pi, z_max=bm.pi, name=None):
super(CANN1D, self).__init__(size=num, name=name)
# parameters
self.tau = tau # The synaptic time constant
self.k = k # Degree of the rescaled inhibition
self.a = a # Half-width of the range of excitatory connections
self.A = A # Magnitude of the external input
self.J0 = J0 # maximum connection value
# feature space
self.z_min = z_min
self.z_max = z_max
self.z_range = z_max - z_min
self.x = bm.linspace(z_min, z_max, num) # The encoded feature values
self.rho = num / self.z_range # The neural density
self.dx = self.z_range / num # The stimulus density
# variables
self.u = bm.Variable(bm.zeros(num))
self.input = bm.Variable(bm.zeros(num))
# The connection matrix
self.conn_mat = self.make_conn(self.x)
# function
self.integral = bp.odeint(self.derivative)
def derivative(self, u, t, Iext):
r1 = bm.square(u)
r2 = 1.0 + self.k * bm.sum(r1)
r = r1 / r2
Irec = bm.dot(self.conn_mat, r)
du = (-u + Irec + Iext) / self.tau
return du
def dist(self, d):
d = bm.remainder(d, self.z_range)
d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
return d
def make_conn(self, x):
assert bm.ndim(x) == 1
x_left = bm.reshape(x, (-1, 1))
x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
d = self.dist(x_left - x_right)
Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
return Jxx
def get_stimulus_by_pos(self, pos):
return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))
def update(self, _t, _dt):
self.u[:] = self.integral(self.u, _t, self.input)
self.input[:] = 0.
def cell(self, u):
return self.derivative(u, 0., 0.)
def visualize_fixed_points(fps, plot_ids=(0,), xs=None):
for i in plot_ids:
if xs is None:
plt.plot(fps[i], label=f'FP-{i}')
else:
plt.plot(xs, fps[i], label=f'FP-{i}')
plt.legend()
plt.show()
cann = CANN1D(num=512, k=0.1, A=30)
These attractors are a series of bumps. Therefore, we can initialize our candidate points with noisy bumps.
candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1)))
candidates += bm.random.normal(0., 0.01, candidates.shape)
finder = bp.analysis.SlowPointFinder(f_cell=cann.cell)
finder.find_fps_with_opt_solver(candidates)
finder.filter_loss(1e-6)
finder.keep_unique()
Optimizing to find fixed points:
Found 629 fixed points from 629 initial points.
Excluding fixed points with squared speed above tolerance 1e-06:
Kept 357/629 fixed points with tolerance under 1e-06.
Excluding non-unique fixed points:
Kept 357/357 unique fixed points with uniqueness tolerance 0.025.
pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points)
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()

visualize_fixed_points(finder.fixed_points, plot_ids=(10, 20, 30, 40, 50, 60, 70, 80), xs=cann.x)

num = 4
J = finder.compute_jacobians(finder.fixed_points[:num])
for i in range(num):
eigval, eigvec = np.linalg.eig(np.asarray(J[i]))
plt.figure()
plt.scatter(np.real(eigval), np.imag(eigval))
plt.plot([0, 0], [-1, 1], '--')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.show()




For more examples of dynamics analysis, such as analyzing the fixed points in a recurrent neural network, please see BrainPy Examples.
References#
[1] Sussillo, D. , and O. Barak . “Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks.” Neural computation 25.3(2013):626-649.
[2] Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.
Analysis of a Decision-making Model#
In this section, we are going to use the low-dimensional analyzers to perform phase plane and bifurcation analyses of the decision-making model proposed by Wong & Wang [1].
Decision making model#
This model considers two excitatory neural assemblies, populations 1 and 2, that compete with each other through a shared pool of inhibitory neurons. In our analysis, we use the following model equations.
Let \(r_1\) and \(r_2\) be the firing rates of populations 1 and 2. The total synaptic input current \(I_i\) and the resulting firing rate \(r_i\) of neural population \(i\) obey the following input-output relationship (\(F - I\) curve):
\(r_i = F(I_i) = \frac{a I_i - b}{1 - \exp(-d (a I_i - b))}\)
which captures the current-frequency function of a leaky integrate-and-fire neuron. The parameter values are \(a\) = 270 Hz/nA, \(b\) = 108 Hz, \(d\) = 0.154 sec.
Assume that the synaptic drive variables \(S_1\) and \(S_2\) obey
\(\frac{dS_i}{dt} = -\frac{S_i}{\tau_s} + (1 - S_i)\, \gamma r_i\)
where \(\gamma\) = 0.641. The net current into each population is given by
\(I_i = J_E S_i + J_I S_j + I_b + J_{ext}\, \mu_i \quad (j \ne i).\)
The synaptic time constant is \(\tau_s\) = 100 ms (NMDA time constant). The synaptic coupling strengths are \(J_E\) = 0.2609 nA, \(J_I\) = -0.0497 nA, and \(J_{ext}\) = 0.00052 nA. Stimulus-selective inputs to populations 1 and 2 are governed by unitless parameters \(\mu_1\) and \(\mu_2\), respectively.
For the decision-making paradigm, the input rates \(\mu_1\) and \(\mu_2\) are determined by the stimulus coherence \(c'\), which ranges between 0 (0%) and 1 (100%):
\(\mu_1 = \mu_0 (1 + c'), \qquad \mu_2 = \mu_0 (1 - c').\)
import brainpy as bp
import brainpy.math as bm
bp.math.enable_x64()
bp.math.set_platform('cpu')
Parameters#
gamma = 0.641 # Saturation factor for gating variable
tau = 0.1 # Synaptic time constant [sec]
a = 270. # Hz/nA
b = 108. # Hz
d = 0.154 # sec
I0 = 0.3255 # background current [nA]
JE = 0.2609 # self-coupling strength [nA]
JI = -0.0497 # cross-coupling strength [nA]
JAext = 0.00052 # Stimulus input strength [nA]
Ib = 0.3255 # The background input
Model implementation#
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
I1 = JE * s1 + JI * s2 + Ib + JAext * mu * (1. + coh)
r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
return - s1 / tau + (1. - s1) * gamma * r1
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
I2 = JE * s2 + JI * s1 + Ib + JAext * mu * (1. - coh)
r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
return - s2 / tau + (1. - s2) * gamma * r2
Phase plane analysis#
The advantage of the reduced model is that we can understand what dynamical behaviors the model generates for a particular parameter set using phase-plane analysis, and then explore how this behavior changes when the model parameters are varied (bifurcation analysis).
To this end, we will use brainpy.analysis
module.
We construct the phase portraits of the reduced model for different stimulus inputs (see Figure 4 and Figure 5 in (Wong & Wang, 2006) [1]).
No stimulus: \(\mu_0 =0\) Hz. In the absence of a stimulus, the two nullclines intersect with each other five times, producing five steady states, of which three are stable (attractors) and two are unstable
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 0.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 5 fixed points.
#1 s1=0.31384529046352583, s2=0.055785349496809286 is a saddle node.
#2 s1=0.5669871605297268, s2=0.03189141971571585 is a stable node.
#3 s1=0.10265144582200994, s2=0.10265095098913903 is a stable node.
#4 s1=0.05578534267632982, s2=0.3138449310808734 is a saddle node.
#5 s1=0.03189144636489117, s2=0.5669870352865436 is a stable node.

Symmetric stimulus: \(\mu_0=30\) Hz, \(c'=0\). When a stimulus is applied, the phase space of the model is reconfigured. The spontaneous state vanishes. At the same time, a saddle-type unstable steady state is created that separates the two asymmetrical attractors.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 0.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.658694222824541, s2=0.051807224657553524 is a stable node.
#2 s1=0.42445578984858473, s2=0.42445562837314144 is a saddle node.
#3 s1=0.05180717720080611, s2=0.6586942355713473 is a stable node.

Biased stimulus: \(\mu_0=30\) Hz, \(c' = 0.14\) (14 % coherence). The phase space changes when a weak motion stimulus is presented. The phase space is no longer symmetrical: the attractor state s1 (correct choice) has a larger basin of attraction than attractor s2.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 0.14},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 s1=0.6679776160203832, s2=0.04583015706228367 is a stable node.
#2 s1=0.38455860789855467, s2=0.45363090352898155 is a saddle node.
#3 s1=0.059110032802350915, s2=0.6481046659437734 is a stable node.

Stimulus to one population only: \(\mu_0=30\) Hz, \(c'=1.\) (100 % coherence). When \(c'\) is sufficiently large, the saddle steady state annihilates with the less favored attractor, leaving only one choice attractor.
analyzer = bp.analysis.PhasePlane2D(
model=[int_s1, int_s2],
target_vars={'s1': [0, 1], 's2': [0, 1]},
pars_update={'mu': 30., 'coh': 1.},
resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
x_style={'fmt': '-'},
y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1212 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 s1=0.7092805209334904, s2=0.0239636630419946 is a stable node.

Bifurcation analysis#
To see how the phase portrait of the system changes when we change the stimulus current, we will generate a bifurcation diagram for the reduced model. On the bifurcation diagram, the fixed points of the model are shown as a function of a changing parameter.
In the following, we generate bifurcation diagrams for different parameters.
Fix the coherence \(c'=0\), vary the stimulus strength \(\mu_0\). See Figure 10 in (Wong & Wang, 2006) [1].
analyzer = bp.analysis.Bifurcation2D(
model=[int_s1, int_s2],
target_vars={'s1': [0., 1.], 's2': [0., 1.]},
target_pars={'mu': [-30., 90.]},
pars_update={'coh': 0.},
resolutions={'mu': 0.2},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 30000 candidates
I am trying to filter out duplicate fixed points ...
Found 1744 fixed points.


Fix the stimulus strength \(\mu_0 = 30\) Hz, vary the coherence \(c'\).
analyzer = bp.analysis.Bifurcation2D(
model=[int_s1, int_s2],
target_vars={'s1': [0., 1.], 's2': [0., 1.]},
target_pars={'coh': [0., 1.]},
pars_update={'mu': 30.},
resolutions={'coh': 0.005},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 10000 candidates
I am trying to filter out duplicate fixed points ...
Found 475 fixed points.


References#
[1] Wong K-F and Wang X-J (2006). A recurrent network mechanism for time integration in perceptual decisions. J. Neurosci 26, 1314-1328.
Numerical Solvers for Ordinary Differential Equations#
The brain modeling toolkit provided in BrainPy is focused on differential equations. How to solve differential equations is the essence of neurodynamics simulation. Exact algebraic solutions are only available for low-order differential equations; for coupled, high-dimensional, nonlinear brain dynamical systems, we need to resort to numerical methods.
This section will illustrate how to define ordinary differential equations (ODEs) and how to define the numerical integration methods for ODEs in BrainPy.
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
bm.set_platform('cpu')
%matplotlib inline
How to define ODE functions?#
BrainPy provides a convenient and intuitive way to define ODE systems. For the ODEs
\(\frac{dx}{dt} = f_1(x, t, y, p_1), \qquad \frac{dy}{dt} = g_1(y, t, x, p_2),\)
we can define them in a Python function:
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
where t
denotes the current time, x
and y
passed before t
denote the dynamical variables, and p1
and p2
after t
denote the parameters needed in this system. In the function body, the derivatives f1 and g1 can be customized according to the user’s needs. Finally, the corresponding derivatives dx and dy are returned in the same order as the variables in the function arguments.
Each variable can be a scalar (var_type = bp.integrators.SCALAR_VAR), a vector/matrix (var_type = bp.integrators.POP_VAR), or a system (var_type = bp.integrators.SYSTEM_VAR). Here, “system” means that the argument x denotes an array of variables. Taking the above example again, we can redefine it as:
def diff(xy, t, p1, p2):
x, y = xy
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return bm.array([dx, dy])
How to define the numerical integration for ODEs?#
After the definition of ODE functions, it is very easy to define the numerical integration for these functions. We just need to put a decorator bp.odeint
above the ODE function.
@bp.odeint
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
After wrapping it with bp.odeint, the function becomes an instance of ODEIntegrator.
isinstance(diff, bp.ode.ODEIntegrator)
True
bp.odeint
receives several arguments:
“method”: A string, used to specify the numerical methods to integrate the ODE functions. The default method is Euler.
diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x17fe2b8da30>
“dt”: A float, used to set the default numerical precision. The default “dt” is 0.1.
diff.dt
0.1
“show_code”: A bool, indicating whether to show the numerical integration code. Let’s take the Euler method and the RK4 method as illustrative examples.
@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
dx_k1, dy_k1 = f(x, y, t, p1, p2)
x_new = x + dx_k1 * dt * 1
y_new = y + dy_k1 * dt * 1
return x_new, y_new
{'f': <function diff at 0x0000017FE2BABEE0>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x17fe2bc7700>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
dx = f1(x, t, y, p1)
dy = g1(y, t, x, p2)
return dx, dy
diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
dx_k1, dy_k1 = f(x, y, t, p1, p2)
k2_x_arg = x + dt * dx_k1 * 0.5
k2_y_arg = y + dt * dy_k1 * 0.5
k2_t_arg = t + dt * 0.5
dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
k3_x_arg = x + dt * dx_k2 * 0.5
k3_y_arg = y + dt * dy_k2 * 0.5
k3_t_arg = t + dt * 0.5
dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
k4_x_arg = x + dt * dx_k3
k4_y_arg = y + dt * dy_k3
k4_t_arg = t + dt
dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
return x_new, y_new
{'f': <function diff at 0x0000017FE2BCA430>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe2bc7e80>
Two Examples#
Example 1: FitzHugh–Nagumo model#
Now, let’s take the well-known FitzHugh–Nagumo (FHN) model as an example to illustrate how to define ODE solvers for brain modeling. The FHN model has two dynamical variables, which are governed by the following equations:
\(\frac{dV}{dt} = V - \frac{V^3}{3} - w + I_{ext}, \qquad \tau \frac{dw}{dt} = V + a - b w.\)
For this FHN model, we can code it in BrainPy like this:
@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
dw = (V + a - b * w) / tau
dV = V - V * V * V / 3 - w + Iext
return dV, dw
After defining the numerical solver, the solution of the ODE system in the given times can be easily solved. For example, for the given parameters,
a = 0.7; b = 0.8; tau = 12.5; Iext = 1.
the solution of the FHN model between 0 and 100 ms can be approximated by
hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
V, w = integral(V, w, t, Iext, a, b, tau)
hist_V.append(V)
plt.plot(hist_times, hist_V)
plt.show()

This manual loop in Python code is usually slow. In BrainPy, we provide a structural runner for integrators: brainpy.integrators.IntegratorRunner
, which can benefit from the JIT compilation.
runner = bp.integrators.IntegratorRunner(
integral,
monitors=['V'],
inits=dict(V=0., w=0.),
args=dict(a=a, b=b, tau=tau, Iext=Iext),
dt=0.01
)
runner.run(100.)
plt.plot(runner.mon.ts, runner.mon.V)
plt.show()

Example 2: Hodgkin–Huxley model#
Another, more complex example is the classical Hodgkin–Huxley neuron model. In the HH model, four dynamical variables (V, m, n, h) are used for modeling the initiation and propagation of the action potential. Specifically, they are governed by the following equations:
\(C \frac{dV}{dt} = -\bar{g}_{Na} m^3 h (V - E_{Na}) - \bar{g}_{K} n^4 (V - E_K) - g_{L} (V - E_L) + I_{ext},\)
\(\frac{dx}{dt} = \alpha_x(V)(1 - x) - \beta_x(V)\, x, \quad x \in \{m, h, n\}.\)
In BrainPy, such dynamical system can be coded as:
@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
I_K = (gK * n ** 4.0) * (V - EK)
I_leak = gL * (V - EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / C
return dVdt, dmdt, dhdt, dndt
As with the FHN model, we can also integrate the HH model with the given parameters over a time interval:
Iext = 10.; ENa = 50.; EK = -77.; EL = -54.387
C = 1.0; gNa = 120.; gK = 36.; gL = 0.03
runner = bp.integrators.IntegratorRunner(
integral,
monitors=list('Vmhn'),
inits=[0., 0., 0., 0.],
args=dict(Iext=Iext, gNa=gNa, ENa=ENa, gK=gK, EK=EK, gL=gL, EL=EL, C=C),
dt=0.01
)
runner.run(100.)
plt.subplot(211)
plt.plot(runner.mon.ts, runner.mon.V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(runner.mon.ts, runner.mon.m, label='m')
plt.plot(runner.mon.ts, runner.mon.h, label='h')
plt.plot(runner.mon.ts, runner.mon.n, label='n')
plt.legend()
plt.show()

Provided ODE Numerical Solvers#
BrainPy
provides several types of numerical methods for ODEs, including explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler methods.
1. Explicit Runge-Kutta (RK) methods for ODEs#
The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are listed in the following table:
The supported explicit RK methods can be selected with the following keywords in the method argument: euler, midpoint, heun2, ralston2, rk2, rk3, rk4, heun3, ralston3, ssprk3, ralston4, and rk4_38rule.
Users can utilize these methods by specifying the method
option in brainpy.odeint()
with their corresponding keyword. For example:
@bp.odeint(method='rk4')
def int_v(v, t, p):
# do something
return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4950d60>
Or, you can directly instantiate your favorite integrator:
@bp.ode.RK4
def int_v(v, t, p):
# do something
return v
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4950670>
def derivative(v, t, p):
# do something
return v
int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4b3af70>
2. Adaptive Runge-Kutta (RK) methods for ODEs#
The second category of ODE numerical support is the adaptive RK methods. What is different from the explicit RK methods is that adaptive methods are designed to produce an estimate of the local truncation error in a single Runge-Kutta step; this error can then be used to adaptively control the numerical step size. Specifically, if \(error > tol\), then \(dt\) is replaced with \(dt_{new}\) and the step is repeated. Therefore, adaptive RK methods allow a varied step size. The following adaptive RK methods are provided in BrainPy:
The supported adaptive RK methods and their keywords are: rkf45, rkf12, rkdp, ck, bs, and heun_euler.
By default, the above methods are not adaptive, unless users provide the keyword adaptive=True in brainpy.odeint(). When the adaptive RK methods are used for numerical integration, the instantaneously adjusted step size dt is appended to the functional arguments. Moreover, the tolerance tol for step-size adjustment can also be modified. Let’s take the Lorenz system as an example:
# adaptively adjust step-size
@bm.jit
@bp.odeint(method='rkf45',
adaptive=True, # active the "adaptive" option
tol=0.001) # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
# should provide one more argument "dt" when using the adaptive rk method
x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)
hist_x.append(x.value)
hist_y.append(y.value)
hist_z.append(z.value)
hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()


3. Exponential Euler methods for ODEs#
Finally, BrainPy provides exponential integrators for ODEs. For your ODE systems, we highly recommend the Exponential Euler method. The Exponential Euler method provided in BrainPy uses automatic differentiation to find the linear part.
Methods | Keywords
---|---
Exponential Euler | exp_euler
Let’s take a linear system as the theoretical demonstration. For such linear systems, the exponential Euler schema is nearly the exact solution; a worked sketch is given below.
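As a concrete sketch, consider a generic scalar linear ODE (our own choice for illustration, not necessarily the example originally used here):
\(\frac{dx}{dt} = a x + b.\)
Over a step \(\Delta t\), the exponential Euler update reads
\(x(t+\Delta t) = x(t)\, e^{a \Delta t} + \frac{b}{a}\left(e^{a \Delta t} - 1\right),\)
which coincides with the exact solution when \(a\) and \(b\) are constant. For a nonlinear equation, \(a\) is taken as the derivative of the right-hand side with respect to the state variable (this is what the automatic differentiation mentioned above provides), and \(b\) collects the remaining terms.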
However, using the Exponential Euler method requires us to write each derivative function separately. Otherwise, the automatic differentiation will lead to wrong results.
Interestingly, the computationally expensive Hodgkin–Huxley neuron model is a linear-like ODE system. You will find that by using the Exponential Euler method, the numerical time step can be greatly enlarged, saving computation time.
Iext=10.; ENa=50.; EK=-77.; EL=-54.387
C=1.0; gNa=120.; gK=36.; gL=0.03
def dm(m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
beta = 4.0 * bm.exp(-(V + 65) / 18)
dmdt = alpha * (1 - m) - beta * m
return dmdt
def dh(h, t, V):
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
dhdt = alpha * (1 - h) - beta * h
return dhdt
def dn(n, t, V):
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
beta = 0.125 * bm.exp(-(V + 65) / 80)
dndt = alpha * (1 - n) - beta * n
return dndt
def dV(V, t, m, h, n, Iext):
I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
I_K = (gK * n ** 4.0) * (V - EK)
I_leak = gL * (V - EL)
dVdt = (- I_Na - I_K - I_leak + Iext) / C
return dVdt
Although we define the HH differential equations as separate functions, we can numerically integrate these equations jointly by relying on brainpy.JointEq.
hh_derivative = bp.JointEq([dV, dm, dh, dn])
def run(method, Iext=10., dt=0.1):
integral = bp.odeint(hh_derivative, method=method)
runner = bp.integrators.IntegratorRunner(
integral,
monitors=list('Vmhn'),
inits=[0., 0., 0., 0.],
args=dict(Iext=Iext),
dt=dt
)
runner.run(100.)
plt.subplot(211)
plt.plot(runner.mon.ts, runner.mon.V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(runner.mon.ts, runner.mon.m, label='m')
plt.plot(runner.mon.ts, runner.mon.h, label='h')
plt.plot(runner.mon.ts, runner.mon.n, label='n')
plt.legend()
plt.show()
Euler method: not able to complete the integration when the time step is slightly larger.
run('euler', Iext=10, dt=0.02)

run('euler', Iext=10, dt=0.1)

RK4 method: better than the Euler method, but still requires the time step to be small.
run('rk4', Iext=10, dt=0.1)

run('rk4', Iext=10, dt=0.2)

Exponential Euler Method: allows larger time step and generates accurate results
run('exp_euler', Iext=10, dt=0.2)

Numerical Solvers for Stochastic Differential Equations#
BrainPy
provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.
import brainpy as bp
import matplotlib.pyplot as plt
%matplotlib inline
How to define SDE functions?#
For a one-dimensional stochastic differential equation (SDE) with scalar Wiener noise, the system is given by
\(\mathrm{d} X_t = f(X_t, t)\, \mathrm{d}t + g(X_t, t)\, \mathrm{d} W_t \qquad (1)\)
where \(X_t = X(t)\) is the realization of a stochastic process or random variable, \(f(X_t, t)\) is the drift coefficient, \(g(X_t, t)\) denotes the diffusion coefficient, and the stochastic process \(W_t\) is called the Wiener process.
For this SDE system, we can define two Python functions \(f\) and \(g\) to represent it.
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df
As with the ODE functions, the arguments before \(t\) denote the random variables, while the arguments defined after \(t\) represent the parameters. For an SDE function with scalar noise, the sizes of the returned data \(df\) and \(dg\) should be the same, for example, \(df \in R^d, dg \in R^d\).
However, a more general SDE system usually has a multi-dimensional driving Wiener process:
\(\mathrm{d} X_t = f(X_t, t)\, \mathrm{d}t + \sum_{\alpha=1}^{m} g_{\alpha}(X_t, t)\, \mathrm{d} W_t^{\alpha}\)
For such an \(m\)-dimensional noise system, the coding schema is the same as for the scalar case, with the difference that the data size of \(dg\) has one more dimension, for example, \(df \in R^{d}, dg \in R^{m \times d}\).
How to define the numerical integration for SDEs?#
Before the numerical integration of SDE functions, we should distinguish two kinds of SDE integrals. For the integration of system (1), we can get
\(X_t = X_{t_0} + \int_{t_0}^{t} f(X_s, s)\, \mathrm{d}s + \int_{t_0}^{t} g(X_s, s)\, \mathrm{d}W_s \qquad (2)\)
or
\(X_t = X_{t_0} + \int_{t_0}^{t} f(X_s, s)\, \mathrm{d}s + \int_{t_0}^{t} g(X_s, s) \circ \mathrm{d}W_s \qquad (3)\)
In the 1940s, the Japanese mathematician K. Ito defined a type of integral called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol “\(\circ\)” to distinguish it from the former Ito integral.
The difference between the Ito integral (2) and the Stratonovich integral (3) lies in the second integral term, which can be written in a general form as
\(\int_{t_0}^{t} g(X_s, s)\, \mathrm{d}W_s = \lim_{n \to \infty} \sum_{k=0}^{n-1} g(X_{\tau_k}, \tau_k)\, \big(W_{t_{k+1}} - W_{t_k}\big), \qquad \tau_k = (1-\lambda)\, t_k + \lambda\, t_{k+1}\)
In the stochastic integral of the Ito SDE, \(\lambda=0\), thus \(\tau_k=t_k\);
In the definition of the Stratonovich integral, \(\lambda=0.5\), thus \(\tau_k=(t_{k+1} + t_{k}) / 2\).
In BrainPy, these two different integrals can be easily implemented. All the users need to do is provide the keyword intg_type in bp.sdeint. intg_type can be “bp.integrators.STRA_SDE” or “bp.integrators.ITO_SDE” (default). Also, the different types of Wiener process can be easily distinguished by the wiener_type keyword. It can be “bp.integrators.SCALAR_WIENER” (default) or “bp.integrators.VECTOR_WIENER”.
Now, let’s numerically integrate the SDE (1) in the Ito sense with the Milstein method:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(d,)
@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
Or, it can be expressed as:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(d,)
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
integral = bp.sdeint(f=f_part, g=g_part, method='milstein')
However, if you try to numerically integrate the SDE with multi-dimensional Wiener process by the Stratonovich ways, you can code it like this:
def g_part(x, t, p1, p2):
dg = g(x, t, p2)
return dg # shape=(m, d)
def f_part(x, t, p1, p2):
df = f(x, t, p1)
return df # shape=(d,)
integral = bp.sdeint(f=f_part,
g=g_part,
method='milstein',
intg_type=bp.integrators.STRA_SDE,
wiener_type=bp.integrators.VECTOR_WIENER)
Example: Noisy Lorenz system#
Here, let’s demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:
sigma = 10; beta = 8/3;
rho = 28; p = 0.1
def lorenz_g(x, y, z, t):
return p * x, p * y, p * z
def lorenz_f(x, y, z, t):
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return dx, dy, dz
lorenz = bp.sdeint(f=lorenz_f,
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER)
To run this integrator, we use brainpy.integrators.IntegratorRunner, which can JIT compile the model to gain impressive speed.
runner = bp.integrators.IntegratorRunner(
lorenz,
monitors=['x', 'y', 'z'],
inits=[1., 1., 1.],
dt=0.001
)
runner.run(50.)
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

We can also rewrite the above differential equation as a JointEq of separable equations, so that the Exponential Euler method can be applied.
dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
lorenz_f = bp.JointEq(dx, dy, dz)
lorenz = bp.sdeint(f=lorenz_f,
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER,
method='exp_euler')
runner = bp.integrators.IntegratorRunner(
lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.], dt=0.001
)
runner.run(50.)
plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Supported SDE Numerical Methods#
BrainPy
provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.
Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support
---|---|---|---|---
srk1w1_scalar | Yes | | Yes |
srk2w1_scalar | Yes | | Yes |
KlPl_scalar | Yes | | Yes |
euler | Yes | Yes | Yes | Yes
heun | | Yes | Yes | Yes
milstein | Yes | Yes | Yes | Yes
exp_euler | Yes | | Yes | Yes
Numerical Solvers for Fractional Differential Equations#
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
Fractional differential equations have several definitions. The fractional derivative can be defined in a variety of different ways that often do not lead to the same result, even for smooth functions. In neuroscience, we usually use the following two definitions:
Grünwald-Letnikov derivative
Caputo fractional derivative
See Fractional calculus - Wikipedia for more details.
Methods for Caputo FDEs#
For a given fractional differential equation
\(\frac{d^{\alpha} x}{d t^{\alpha}} = f(x, t)\)
where the fractional order \(0<\alpha\le 1\), BrainPy provides two kinds of methods:
Euler method -
brainpy.fde.CaputoEuler
L1 schema integration -
brainpy.fde.CaputoL1Schema
brainpy.fde.CaputoEuler
#
brainpy.fde.CaputoEuler provides a one-step Euler method for integrating Caputo fractional differential equations.
Given a fractional-order Qi chaotic system
\(D^{\alpha} x = -a x + a y + y z, \quad D^{\alpha} y = c x - y - x z, \quad D^{\alpha} z = -b z + x y,\)
we can solve the equation system by:
a, b, c = 35, 8/3, 80
def qi_system(x, y, z, t):
dx = -a*x + a*y + y*z
dy = c*x - y - x*z
dz = -b*z + x*y
return dx, dy, dz
dt = 0.001
duration = 50
inits = [0.1, 0.2, 0.3]
# The numerical integration of an FDE needs to know all
# history information. Therefore, we need to provide
# the overall number of steps "num_step" to save
# all history values.
integrator = bp.fde.CaputoEuler(qi_system,
alpha=0.98, # fractional order
num_step=int(duration/dt),
inits=inits)
runner = bp.integrators.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()

brainpy.fde.CaputoL1Schema
#
brainpy.fde.CaputoL1Schema is another commonly used method to integrate Caputo derivative equations. Let’s try it with a fractional-order Lorenz system, which is given by:
a, b, c = 10, 28, 8 / 3
def lorenz_system(x, y, z, t):
dx = a * (y - x)
dy = x * (b - z) - y
dz = x * y - c * z
return dx, dy, dz
dt = 0.001
duration = 50
inits = [1, 2, 3]
integrator = bp.fde.CaputoL1Schema(lorenz_system,
alpha=0.99, # fractional order
num_step=int(duration/dt),
inits=inits)
runner = bp.integrators.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()

Methods for Grünwald-Letnikov FDEs#
The Grünwald-Letnikov FDE is another commonly used type in neuroscience. Here, we provide an efficient computation method based on the short-memory principle of the Grünwald-Letnikov method.
brainpy.fde.GLShortMemory
#
brainpy.fde.GLShortMemory is highly efficient, because it does not require an infinite memory length for the numerical solution. Due to the decay property of the coefficients, brainpy.fde.GLShortMemory uses a limited memory length to reduce the computational time. Specifically, it only relies on a memory window of num_memory length. As the width of the memory window increases, the accuracy of the numerical approximation increases.
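For reference, a textbook form of the truncated Grünwald-Letnikov scheme (a standard formulation, not copied from BrainPy’s internals) is
\(D^{\alpha} y(t_n) \approx h^{-\alpha} \sum_{j=0}^{N} (-1)^{j} \binom{\alpha}{j}\, y(t_{n-j}),\)
where \(h\) is the integration step. The sum in principle runs over the entire history, but it is truncated to the most recent \(N\) points (the memory window); this is justified because the coefficients \((-1)^{j}\binom{\alpha}{j}\) decay quickly with \(j\), as the coefficient plot at the end of this subsection shows.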
Here, we demonstrate it by using a fractional-order Chua system, which is defined as
a, b, c = 10.725, 10.593, 0.268
m0, m1 = -1.1726, -0.7872
def chua_system(x, y, z, t):
f = m1*x+0.5*(m0-m1)*(abs(x+1)-abs(x-1))
dx = a*(y-x-f)
dy = x - y + z
dz = -b*y - c*z
return dx, dy, dz
dt = 0.001
duration = 200
inits = [0.2, -0.1, 0.1]
integrator = bp.fde.GLShortMemory(chua_system,
alpha=[0.93, 0.99, 0.92],
num_step=1000,
inits=inits)
runner = bp.integrators.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()

plt.figure(figsize=(10, 8))
plt.plot(runner.mon.y, runner.mon.z)
plt.show()

Actually, the coefficients used in brainpy.fde.GLShortMemory can be inspected through:
plt.figure(figsize=(10, 6))
coef = integrator.binomial_coef
alphas = bm.as_numpy(integrator.alpha)
plt.subplot(211)
for i in range(3):
plt.plot(coef[:, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.subplot(212)
for i in range(3):
plt.plot(coef[:10, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.show()

As you see, the coefficients decay very quickly!
Further reading#
For more examples of how to use the numerical solvers for fractional differential equations defined in BrainPy, please see:
Numerical Solvers for Delay Differential Equations#
Delays are very often encountered in practical systems, such as automatic control, biology, economics, and long transmission lines. Delay differential equations (DDEs) are used to describe such dynamical systems.
Delay differential equations (DDEs) are a type of differential equation in which the derivative at a certain time is given in terms of the values of the function at previous times.
Let’s take delay ODEs as the example. The simplest constant-delay equations have the form
\(y'(t) = f\big(t, y(t), y(t-\tau_1), \ldots, y(t-\tau_k)\big),\)
where the time delays (lags) \(\tau_j\) are positive constants.
For neutral-type DDEs, delays appear in derivative terms:
\(y'(t) = f\big(t, y(t), y'(t-\tau_1), \ldots, y'(t-\tau_k)\big).\)
More generally, state-dependent delays may depend on the solution, that is, \(\tau_i = \tau_i(t, y(t))\).
In BrainPy, we support delay differential equations based on delay variables. Specifically, for state-dependent delays, we have:
brainpy.math.TimeDelay
brainpy.math.LengthDelay
For neutral-type delays, we use:
brainpy.math.NeuTimeDelay
brainpy.math.NeuLenDelay
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
Delay variables#
For an ODE system, the numerical methods need to know its initial condition \(y(t_0)=y_0\) and its derivative rule. However, for DDEs, it is not enough to give a set of initial values for the function and its derivatives at \(t_0\); one must also give a set of functions that provide the historical values for \(t_0 - \max(\tau) \leq t \leq t_0\).
Therefore, you need some delay variables to wrap the variable delays. brainpy.math.TimeDelay
can be used to define delay variables which depend on states, and brainpy.math.NeuTimeDelay
is used to define delay variables which depend on the derivative.
d = bm.TimeDelay(bm.zeros(2), delay_len=10, dt=1, t0=0, before_t0=lambda t: t)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
d(0.)
DeviceArray([0., 0.], dtype=float32)
d(-0.5)
DeviceArray([-0.5, -0.5], dtype=float32)
Requesting a time outside the window \([t_0 - \mathrm{max\_delay}, t_0]\) will cause an error.
try:
d(0.1)
except Exception as e:
print(e)
ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x0000025616D24BE0> threw exception
!!! Error in TimeDelay:
The request time should be less than the current time 0. But we got 0.10000000149011612 > 0.
!!! Error in TimeDelay:
The request time should be less than the current time 0. But we got 0.10000000149011612 > 0
Delay ODEs#
Here we illustrate how to numerically integrate delay ODEs with several examples. Before that, we define a general function to simulate a delay ODE.
def delay_odeint(duration, eq, args=None, inits=None,
state_delays=None, neutral_delays=None,
monitors=('x',), method='euler', dt=0.1):
# define integrators of ODEs based on `brainpy.odeint`
dde = bp.odeint(eq,
state_delays=state_delays,
neutral_delays=neutral_delays,
method=method)
# define IntegratorRunner
runner = bp.integrators.IntegratorRunner(dde,
args=args,
monitors=monitors,
dt=dt,
inits=inits)
runner.run(duration)
return runner.mon
Example #1: First-order DDE with one constant delay and a constant initial history function#
Let the following DDE be given:
\(y'(t) = -y(t-1),\)
where the delay is 1 s. The example compares the solutions of three different cases using three different constant history functions:
Case #1: \(\phi(t)=-1\)
Case #2: \(\phi(t)=0\)
Case #3: \(\phi(t)=1\)
def equation(x, t, xdelay):
return -xdelay(t-1)
case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round')
case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=0., interp_method='round')
case3_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=1., interp_method='round')
case1 = delay_odeint(20., equation, args={'xdelay': case1_delay},
state_delays={'x': case1_delay}) # delay for variable "x"
case2 = delay_odeint(20., equation, args={'xdelay': case2_delay}, state_delays={'x': case2_delay})
case3 = delay_odeint(20., equation, args={'xdelay': case3_delay}, state_delays={'x': case3_delay})
fig, axs = plt.subplots(3, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1)$")
axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')
axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=0$')
axs[2].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[2].set_title('$ihf(t)=1$')
plt.show()

Example #2: First-order DDE with one constant delay and a non constant initial history function#
Let the following DDE be given:
\(y'(t) = -y(t-2),\)
where the delay is 2 s. The example compares the solutions of four different cases using two different non-constant history functions and two different intervals of \(t\):
Case #1: \(\phi(t)=e^{-t} - 1, t \in [0, 4]\)
Case #2: \(\phi(t)=e^{t} - 1, t \in [0, 4]\)
Case #3: \(\phi(t)=e^{-t} - 1, t \in [0, 60]\)
Case #4: \(\phi(t)=e^{t} - 1, t \in [0, 60]\)
def eq(x, t, xdelay):
return -xdelay(t-2)
delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')
delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t)-1, dt=0.01, interp_method='round')
delay3 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')
delay4 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t)-1, dt=0.01, interp_method='round')
case1 = delay_odeint(4., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(4., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
case3 = delay_odeint(60., eq, args={'xdelay': delay3}, state_delays={'x': delay3}, dt=0.01)
case4 = delay_odeint(60., eq, args={'xdelay': delay4}, state_delays={'x': delay4}, dt=0.01)
fig, axs = plt.subplots(2, 2)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-2)$")
axs[0, 0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 4]$')
axs[0, 1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[0, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 4]$')
axs[1, 0].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[1, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 60]$')
axs[1, 1].plot(case4.ts, case4.x, color='red', linewidth=1)
axs[1, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 60]$')
plt.show()

Example #3: First-order DDE with two constant delays and a constant initial history function#
Let the following DDE be given:
\(y'(t) = -y(t-1) + 0.3\, y(t-2),\)
where there are two constant delays, equal to 1 s and 2 s respectively. The initial history function is also constant: \(\phi(t)=1\).
def eq(x, t):
return -delay(t-1) + 0.3*delay(t-2)
delay = bm.TimeDelay(bm.ones(1), 2., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(10., eq, inits=[1.], state_delays={'x': delay}, dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1) + 0.3\ y(t-2)$")
axs.plot(mon.ts, mon.x, color='red', linewidth=1)
axs.set_title('$ihf(t)=1$')
plt.show()

Example #4: System of two first-order DDEs with one constant delay and two constant initial history functions#
Let the following system of DDEs be given:
\(x'(t) = x(t)\, y(t-0.5), \qquad y'(t) = y(t)\, x(t-0.5),\)
where there is a single constant delay equal to 0.5 s. Since this is a system of two first-order equations, two constant initial history functions are needed, one for each unknown: \(y_1(t)=1, y_2(t)=-1\).
def eq(x, y, t):
dx = x * ydelay(t-0.5)
dy = y * xdelay(t-0.5)
return dx, dy
xdelay = bm.TimeDelay(bm.ones(1), 0.5, before_t0=1., dt=0.01, interp_method='round')
ydelay = bm.TimeDelay(-bm.ones(1), 0.5, before_t0=-1., dt=0.01, interp_method='round')
mon = delay_odeint(3., eq, inits=[1., -1], state_delays={'x': xdelay, 'y': ydelay},
dt=0.01, monitors=['x', 'y'])
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$x'(t)=x(t) y(t-d); y'(t)=y(t) x(t-d)$")
axs.plot(mon.ts, mon.x.flatten(), color='red', linewidth=1)
axs.plot(mon.ts, mon.y.flatten(), color='blue', linewidth=1)
axs.set_title('$ihf_x(t)=1; ihf_y(t)=-1; d=0.5$')
plt.show()

Example #5: Second-order DDE with one constant delay and two constant initial history functions#
Let the following DDE be given:
\(y''(t) = -y'(t) - 2\, y(t) - 0.5\, y(t-1),\)
where there is a single constant delay equal to 1 s. Since the DDE is of second order (the second derivative of the unknown function appears), two history functions are needed: one to give the values of the unknown \(y(t)\) for \(t \le 0\), and one to provide the values of the first derivative \(y'(t)\), also for \(t \le 0\).
In this example they are the following two constant functions: \(y(t)=1, y'(t)=0\).
Due to the properties of second-order equations, the given DDE is equivalent to the following system of first-order equations:
\(x'(t) = y(t), \qquad y'(t) = -y(t) - 2\, x(t) - 0.5\, x(t-1),\)
and so the implementation falls into the case of the previous example of systems of first-order equations.
def eq(x, y, t):
dx = y
dy = -y - 2*x - 0.5*xdelay(t-1)
return dx, dy
xdelay = bm.TimeDelay(bm.ones(1), 1., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(16., eq, inits=[1., 0.], state_delays={'x': xdelay}, monitors=['x', 'y'], dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y''(t)=-y'(t) - 2 y(t) - 0.5 y(t-1)$")
axs.plot(mon.ts, mon.x[:,0], color='red', linewidth=1)
axs.plot(mon.ts, mon.y[:,0], color='green', linewidth=1)
axs.set_title('$ih \, f_y(t)=1; ihf\,dy/dt(t)=0$')
plt.show()

Example #6: First-order DDE with one non constant delay and a constant initial history function#
Let the following DDE be given:
\(y'(t) = y\big(t - \mathrm{delay}(y, t)\big),\)
where the delay is not constant and is given by the function \(\mathrm{delay}(y, t)=|\frac{1}{10} t\, y(\frac{1}{10} t)|\). The example compares the solutions of two different cases using two different constant history functions:
Case #1: \(\phi(t)=-1\)
Case #2: \(\phi(t)=1\)
def eq(x, t, xdelay):
delay = abs(t*xdelay(t - 0.9 * t)/10) # a tensor with (1,)
delay = delay[0]
return xdelay(t-delay)
Note
Note that here we do not know the maximum length of the delay. Therefore, we can declare a fixed-length delay variable with delay_len equal to or even bigger than the running duration.
delay1 = bm.TimeDelay(bm.ones(1), 30., before_t0=-1, dt=0.01)
delay2 = bm.TimeDelay(-bm.ones(1), 30., before_t0=1, dt=0.01)
case1 = delay_odeint(30., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(30., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
fig, axs = plt.subplots(2, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=y(t-delay(y, t))$")
axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')
axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=1$')
plt.show()

Delay SDEs#
Same as for delay ODEs, state-dependent delay variables can be passed to the state_delays argument of the brainpy.sdeint function.
delay = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')
f = lambda x, t, xdelay: -xdelay(t-2)
g = lambda x, t, *args: 0.01
dt = 0.01
integral = bp.sdeint(f, g, state_delays={'x': delay})
runner= bp.integrators.IntegratorRunner(integral,
monitors=['x'],
args={'xdelay': delay},
dt=dt)
runner.run(100.)
plt.plot(runner.mon.ts, runner.mon.x)
plt.show()

Delay FDEs#
Fractional-order delayed differential equations, as a generalization of delayed differential equations, provide more freedom when describing these systems. Let’s see how we can use BrainPy to accelerate the simulation of fractional-order delayed differential equations.
A fractional delayed differential equation has the general form
\(D^{\alpha} y(t) = f\big(t, y(t), y(t-\tau)\big).\)
Lemmings’ population cycle#
The fractional-order version of the four-year life cycle of a population of lemmings is given by
\(D^{0.97} y(t) = 3.5\, y(t) \Big(1 - \frac{y(t-0.74)}{19}\Big), \qquad y(0) = 19.00001.\)
dt=0.05
delay = bm.TimeDelay(bm.asarray([19.00001]), 0.74, before_t0=19., dt=dt)
f = lambda y, t: 3.5 * y * (1 - delay(t-0.74)/19)
integral = bp.fde.GLShortMemory(f, alpha=0.97, inits=[19.00001], num_step=500,
state_delays={'y': delay})
runner = bp.integrators.IntegratorRunner(integral,
inits=bm.asarray([19.00001]),
monitors=['y'],
fun_monitors={'y(t-0.74)': lambda t, _: delay(t-0.74)},
dyn_vars=delay.vars(),
dt=dt)
runner.run(100.)
plt.plot(runner.mon['y'], runner.mon['y(t-0.74)'])
plt.xlabel('y(t)')
plt.ylabel('y(t-0.74)')
plt.show()

Time delay Chen system#
The time-delay Chen system, as a famous chaotic system with time delay, has important applications in many fields.
dt = 0.001
tau = 0.009
xdelay = bm.TimeDelay(bm.asarray([0.2]), tau, dt=dt)
zdelay = bm.TimeDelay(bm.asarray([0.5]), tau, dt=dt)
def derivative(x, y, z, t):
  a = 35; b = 3; c = 27
  dx = a * (y - xdelay(t - tau))
  dy = (c - a) * xdelay(t - tau) - x * z + c * y
  dz = x * y - b * zdelay(t - tau)
  return dx, dy, dz

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.94,
                                inits=[0.2, 0., 0.5],
                                num_step=500,
                                state_delays={'x': xdelay, 'z': zdelay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[0.2, 0., 0.5],
                                         monitors=['x', 'y', 'z'],
                                         dyn_vars=xdelay.vars() + zdelay.vars(),
                                         dt=dt)
runner.run(100.)
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
plt.show()

Enzyme kinetics#
Let’s see a more complex example of the fractional order version of enzyme kinetics with an inhibitor molecule:
dt = 0.01
tau = 4.
delay = bm.TimeDelay(bm.asarray([20.]), tau, before_t0=20, dt=dt)
def derivative(a, b, c, d, t):
  da = 10.5 - a / (1 + 0.0005 * delay(t - tau) ** 3)
  db = a / (1 + 0.0005 * delay(t - tau) ** 3) - b
  dc = b - c
  dd = c - 0.5 * d
  return da, db, dc, dd

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.95,
                                inits=[60, 10, 10, 20],
                                num_step=500,
                                state_delays={'d': delay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[60, 10, 10, 20],
                                         monitors=list('abcd'),
                                         dyn_vars=delay.vars(),
                                         dt=dt)
runner.run(200.)
plt.plot(runner.mon.ts, runner.mon.a, label='a')
plt.plot(runner.mon.ts, runner.mon.b, label='b')
plt.plot(runner.mon.ts, runner.mon.c, label='c')
plt.plot(runner.mon.ts, runner.mon.d, label='d')
plt.legend()
plt.xlabel('Time [ms]')
plt.show()

Fractional matrix delayed differential equations#
BrainPy is also capable of solving fractional matrix delayed differential equations:
Here \(x(t)\) is the state vector of the system and \(c(t)\) is a known disturbance function.
We explain the detailed usage with an example:
With initial condition:
dt = 0.01
tau = 3.1416
f = lambda t: bm.asarray([bm.sin(t) * bm.cos(t),
                          bm.sin(t) * bm.cos(t),
                          bm.cos(t) ** 2 - bm.sin(t) ** 2,
                          bm.cos(t) ** 2 - bm.sin(t) ** 2])
delay = bm.TimeDelay(f(0.), tau, before_t0=f, dt=dt)
A = bm.asarray([[0, 0, 1, 0], [0, 0, 0, 1], [0, -2, 0, 0], [-2, 0, 0, 0]])
B = bm.asarray([[0, 0, 0, 0], [0, 0, 0, 0], [-2, 0, 0, 0], [0, -2, 0, 0]])
c = bm.asarray([0, 0, 0, 0])
derivative = lambda x, t: A @ x + B @ delay(t - tau) + c
integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.4,
                                inits=[f(0.)],
                                num_step=500,
                                state_delays={'x': delay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[f(0.)],
                                         monitors=['x'],
                                         dyn_vars=delay.vars(),
                                         dt=dt)
runner.run(100.)
plt.plot(runner.mon.x[:, 0], runner.mon.x[:, 2])
plt.xlabel('x0')
plt.ylabel('x2')
plt.show()

Acknowledgement#
This tutorial is highly inspired by the work of Ettore Messina [1] and of Qingyu Qu [2].
[1] Computational Mindset by Ettore Messina, Solving delay differential equations using numerical methods in Python
Joint Differential Equations#
In a dynamical system, there may be multiple variables that change dynamically over time. Sometimes these variables are interconnected, and updating one variable requires others as the input. For example, in the widely known Hodgkin–Huxley model, the variables \(V\), \(m\), \(h\), and \(n\) are updated synchronously and interdependently (please refer to Building Neuron Models for details). To achieve higher integral accuracy, it is recommended to use brainpy.JointEq to jointly solve interconnected differential equations.
import brainpy as bp
brainpy.JointEq#
brainpy.JointEq is used to merge individual but interconnected differential equations into a single joint equation. For example, below are the two differential equations of the Izhikevich model:
a, b = 0.02, 0.20
dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
du = lambda u, t, V: a * (b * V - u)
Where updating \(V\) requires \(u\) as the input, and updating \(u\) requires \(V\) as the input. The joint equation can be defined as:
joint_eq = bp.JointEq(eqs=(dV, du))
brainpy.JointEq receives only one argument named eqs, which can be a list or tuple containing multiple differential equations. Then it can be packed into a numerical integrator that solves the equation with a specified method, just as what can be done to any individual differential equation.
itg = bp.odeint(joint_eq, method='rk2')
There are several requirements for defining a joint equation:
Every individual differential equation should follow the format of defining an ODE or SDE function in BrainPy. For example, the arguments before t denote the dynamical variables and the arguments after t denote the parameters.
The same variable in different equations should have the same name. Different variables should be named differently.
Note that brainpy.JointEq supports nested JointEq, which means an instance of JointEq can be used as an element to compose a new JointEq.
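For instance, a minimal sketch of nesting (the third equation dw and its auxiliary variable w are made up here purely to illustrate composition; they are not part of the Izhikevich model above):
dw = lambda w, t, V: -w + V          # a hypothetical auxiliary equation, for illustration only

inner = bp.JointEq(eqs=(dV, du))     # the joint equation of V and u defined above
outer = bp.JointEq(eqs=(inner, dw))  # a JointEq instance can itself be an element
itg_nested = bp.odeint(outer, method='rk2')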
Why use brainpy.JointEq?#
Users may be confused about the function of brainpy.JointEq, because multiple differential equations can also be written in a single function:
def diff(V, u, t, Iext):
  dV = 0.04 * V * V + 5 * V + 140 - u + Iext
  du = a * (b * V - u)
  return dV, du
itg_V_u = bp.odeint(diff, method='rk2')
or simply packed into integrators separately:
int_V = bp.odeint(dV, method='rk2')
int_u = bp.odeint(du, method='rk2')
To illustrate the difference between joint and separate differential equations, let's dive into the numerical integration code of these two types of equations.
If we make numerical solver for each derivative function, they will be solved independently:
import brainpy as bp
bp.odeint(dV, method='rk2', show_code=True)
def brainpy_itg_of_ode6(V, t, u, Iext, dt=0.1):
  dV_k1 = f(V, t, u, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  return V_new
{'f': <function <lambda> at 0x0000022948DD6A60>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x229660543a0>
As is shown in the output code, the variable \(V\) is integrated twice by the RK2 method. For the second differential value dV_k2
, the updated value of \(V\) (k2_V_arg
) and original \(u\) are used to calculate the differential value. This will generate a tiny error, since the values of \(V\) and \(u\) are taken at different times.
To eliminate this error, the differential equation of \(V\) and \(u\) should be solved jointly through brainpy.JointEq
:
eq = bp.JointEq(eqs=(dV, du))
bp.odeint(eq, method='rk2', show_code=True)
def brainpy_itg_of_ode12_joint_eq(V, u, t, Iext, dt=0.1):
  dV_k1, du_k1 = f(V, u, t, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_u_arg = u + dt * du_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
  return V_new, u_new
{'f': <brainpy.integrators.joint_eq.JointEq object at 0x0000022967EC0C40>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x22967ec0160>
It is shown in this output code that the second differential values of \(V\) and \(u\) are calculated by using the updated values (k2_V_arg and k2_u_arg) at the same time. This will result in a more accurate integral.
The figure below compares the simulation results of the Izhikevich model using joint and separate differential equations (\(dt = 0.2 ms\)). It is shown that as the simulation time increases, the integral error becomes greater.

Synaptic Connections#
Synaptic connections are an essential part of building a neural dynamic system. BrainPy provides several commonly used connection methods in the brainpy.connect module (which can be accessed by the shortcut bp.conn) that help users easily construct many types of synaptic connections, including built-in and self-customized connectors.
An Overview of BrainPy Connectors#
Here we provide an overview of BrainPy connectors.
Base class: bp.conn.Connector#
The base class of connectors is brainpy.connect.Connector. All connectors, built-in or customized, should inherit from the Connector class.
Two subclasses: TwoEndConnector and OneEndConnector#
There are two classes inheriting from the base class bp.conn.Connector:
bp.conn.TwoEndConnector: a connector to build synaptic connections between two neuron groups.
bp.conn.OneEndConnector: a connector to build synaptic connections within a population of neurons.
Users can click the link of each class above to look through the API documentation.
Connector.__init__()#
All connectors need to be initialized first. For each built-in connector, users need to pass in the corresponding parameters for initialization. For details, please see the specific connector type below.
Connector.__call__()#
After initialization, users should call the connector and pass in parameters depending on the specific connection type:
TwoEndConnector: it has two input parameters, pre_size and post_size, representing the sizes of the pre- and post-synaptic neuron groups. It will result in a connection matrix with the shape of (pre_num, post_num).
OneEndConnector: it has only one parameter, pre_size, which represents the size of the neuron group. It will result in a connection matrix with the shape of (pre_num, pre_num).
The __call__
function returns the class itself.
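For example, a brief sketch using two of the built-in connectors introduced later in this tutorial (the sizes here are arbitrary):
import brainpy as bp

two_end = bp.conn.FixedProb(prob=0.2)(pre_size=8, post_size=10)  # connection matrix shape: (8, 10)
one_end = bp.conn.GridFour()(pre_size=(4, 4))                    # connection matrix shape: (16, 16)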
Connector.build_conn()#
Users can customize the connection in the build_conn() function. Note that there are three connection types users can provide:

| Connection Types | Definition |
|---|---|
| 'mat' | Dense connection, including a connection matrix. |
| 'ij' | Index projection, including a pre-neuron index vector and a post-neuron index vector. |
| 'csr' | Sparse connection, including an index vector and an indptr vector. |
Return type can be either a dict
or a tuple
. Here are two examples of how to return your connection data:
Example 1:
def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)
  return dict(csr=(ind, indptr), mat=None, ij=None)
Example 2:
def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)
  return 'csr', (ind, indptr)
After creating the synaptic connection, users can use the require()
method to access some useful properties of the connection.
Connector.require()#
This method returns the connection properties required by users. The connection properties are elaborated in the following sections in detail. Here is a brief summary of the connection properties users can require.
| Connection properties | Structure | Definition |
|---|---|---|
| conn_mat | 2-D array (matrix) | Dense connection matrix |
| pre_ids | 1-D array (vector) | Indices of the pre-synaptic neuron group |
| post_ids | 1-D array (vector) | Indices of the post-synaptic neuron group |
| pre2post | tuple (vector, vector) | The post-synaptic neuron indices and the corresponding pre-synaptic neuron pointers |
| post2pre | tuple (vector, vector) | The pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers |
| pre2syn | tuple (vector, vector) | The synapse indices sorted by pre-synaptic neurons and the corresponding pre-synaptic neuron pointers |
| post2syn | tuple (vector, vector) | The synapse indices sorted by post-synaptic neurons and the corresponding post-synaptic neuron pointers |
Users can call this method as follows:
pre_ids, post_ids, pre2post, conn_mat = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
Note
Note that this method can return multiple connection properties.
Connection Properties#
There are multiple connection properties that can be required by users.
1. conn_mat#
The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through the function of connector.requires('conn_mat')
. Each connection matrix is an array with the shape of \((n_{pre}, n_{post})\):

2. pre_ids and post_ids#
Using vectors to store the connections between neuron groups is a much more efficient way to reduce memory when the connection matrix is sparse. For the connection matrix conn_mat defined above, we can align the connected pre-synaptic neurons and the post-synaptic neurons by two one-dimensional arrays: pre_ids and post_ids.

In this way, we only need two vectors (pre_ids
and post_ids
) to store the synaptic connection. syn_id
in the figure indicates the indices of each neuron pair, i.e. each synapse.
3. pre2post and post2pre#
Another two synaptic structures are pre2post
and post2pre
. They establish the mapping between the pre- and post-synaptic neurons.
pre2post
is a tuple containing two vectors, one of which is the post-synaptic neuron indices and the other is the corresponding pre-synaptic neuron pointers. For example, the following figure shows the indices of the pre-synaptic neurons and the post-synaptic neurons to which the pre-synaptic neurons project:

To record the connection, the post_ids are first concatenated into a single vector called the post-synaptic index vector (indices). Because the post-synaptic neuron indices have been sorted by the pre-synaptic neuron indices, it is sufficient to record only the starting position of each pre-synaptic neuron's targets. These starting positions, together with the end position of the last pre-synaptic neuron's targets, make up the pre-synaptic index pointer vector (indptr), which is illustrated in the figure below.

The post-synaptic neuron indices to which pre-synaptic neuron \(i\) projects can be obtained by array slicing:
indices[indptr[i]: indptr[i+1]]
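A small sketch of how this could be done with a built-in connector (the connector and its sizes are chosen arbitrarily for illustration):
import brainpy as bp

conn = bp.conn.FixedProb(prob=0.5, seed=0)(pre_size=4, post_size=6)
indices, indptr = conn.require('pre2post')
i = 2
print(indices[indptr[i]: indptr[i + 1]])  # post-synaptic targets of pre-synaptic neuron 2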
Similarly, post2pre is a 2-element tuple containing the pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers. Taking the connection in the illustration above as an example, the post-synaptic neuron indices and the pre-synaptic neuron indices connected to them are shown as:

The pre-synaptic index vector (indices) and the post-synaptic index pointer vector (indptr) are listed below:

When the connection is sparse, pre2post (or post2pre) is a very efficient way to store the connection, since the lengths of the two vectors in the tuple are \(n_{synapse}\) and \(n_{pre}+1\) (or \(n_{post}+1\)), respectively.
4. pre2syn and post2syn#
The last two properties are pre2syn
and post2syn
that record pre- and post-synaptic projection, respectively.
For pre2syn
, similar to pre2post
and post2pre
, there is a synapse index vector and a pre-synaptic index pointer vector that refers to the starting position of each pre-synaptic neuron index at the synapse index vector.
Below is the same example identifying the connection by pre-synaptic neuron indices and the synapses belonging to them.

For better understanding, the synapse indices and the pre- and post-synaptic neuron indices are shown below:

The pre-synaptic index pointer vector is computed in the same way as in pre2post
:

Similarly, post2syn is also a tuple containing the synapse indices and the corresponding post-synaptic neuron pointers.

The only difference from pre2syn is that the synapse indices are (most of the time) originally sorted by pre-synaptic neurons, but when computing post2syn, the synapses should be sorted by post-synaptic neuron indices:

The synapse index vector (the first row) and the post-synaptic index pointer vector (the last row) are listed below:

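As a short sketch, both synapse-oriented structures can be requested together (the connector below is arbitrary and used only for illustration):
import brainpy as bp

conn = bp.conn.FixedProb(prob=0.5, seed=1)(pre_size=4, post_size=4)
pre2syn, post2syn = conn.require('pre2syn', 'post2syn')
syn_ids_by_pre, pre_ptr = pre2syn     # synapse indices sorted by pre-synaptic neurons
syn_ids_by_post, post_ptr = post2syn  # synapse indices re-sorted by post-synaptic neurons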
import brainpy as bp
bp.math.set_platform('cpu')
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
Built-in regular connections#
brainpy.connect.One2One#
The neurons in the pre-synaptic neuron group only connect to the neurons in the same position of the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size. Otherwise, an error will occur.

conn = bp.connect.One2One()
conn(pre_size=10, post_size=10)
<brainpy.connect.regular_conn.One2One at 0x12d27d9a0>
where pre_size
denotes the size of the pre-synaptic neuron group, post_size
denotes the size of the post-synaptic neuron group.
Note that the size parameter can be an int, a tuple of ints, or a list of ints, where each element represents one dimension of the neuron group.
In One2One connection, particularly, pre_size
and post_size
must be the same.
Class One2One
is inherited from TwoEndConnector
. Users can use method require
or requires
to get specific connection properties.
Here is an example:
size = 5
conn = bp.connect.One2One()(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
pre_ids: JaxArray([0, 1, 2, 3, 4], dtype=uint32)
post_ids: JaxArray([0, 1, 2, 3, 4], dtype=uint32)
pre2post: (JaxArray([0, 1, 2, 3, 4], dtype=uint32), JaxArray([0, 1, 2, 3, 4, 5], dtype=uint32))
brainpy.connect.All2All#
All neurons of the post-synaptic population form connections with all
neurons of the pre-synaptic population (dense connectivity). Users can
choose whether to connect neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
conn(pre_size=size, post_size=size)
<brainpy.connect.regular_conn.All2All at 0x12d18ac10>
Class All2All
is inherited from TwoEndConnector
. Users can use method require
or requires
to get specific connection properties.
Here is an example:
conn = bp.connect.All2All(include_self=False)(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')
print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
print('conn_mat:', res[3])
pre_ids: JaxArray([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], dtype=uint32)
post_ids: JaxArray([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], dtype=uint32)
pre2post: (JaxArray([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], dtype=uint32), JaxArray([ 0, 4, 8, 12, 16, 20], dtype=uint32))
conn_mat: JaxArray([[False, True, True, True, True],
[ True, False, True, True, True],
[ True, True, False, True, True],
[ True, True, True, False, True],
[ True, True, True, True, False]], dtype=bool)
brainpy.connect.GridFour#
GridFour is the four nearest neighbors connection. Each neuron connects to its four nearest neighbors.

conn = bp.connect.GridFour(include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridFour at 0x12d0356a0>
Class GridFour is inherited from OneEndConnector, therefore there is only one parameter pre_size representing the size of the neuron group, which should be a two-dimensional geometry.
Here is an example:
size = (4, 4)
conn = bp.connect.GridFour(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')
print('pre_ids', res[0])
pre_ids JaxArray([ 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5,
5, 5, 6, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9,
9, 10, 10, 10, 10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14,
14, 15, 15], dtype=uint32)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.GridEight#
GridEight is the eight nearest neighbors connection. Each neuron connects to its eight nearest neighbors.

conn = bp.connect.GridEight(include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridEight at 0x143df5c10>
Class GridEight is inherited from GridN, which will be introduced below.
Here is an example:
size = (4, 4)
conn = bp.connect.GridEight(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')
print('pre_ids', res[0])
pre_ids JaxArray([ 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3,
3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6,
6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8,
8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 13, 13, 13, 13,
13, 14, 14, 14, 14, 14, 15, 15, 15], dtype=uint32)
Take the central point (id = 4) as an example: its neighbors are all the other points except itself.
Therefore, its row in conn_mat
has True
for all values except itself.
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.GridN#
GridN is also a nearest-neighbors connection. Each neuron connects to its nearest \((2N+1) \cdot (2N+1)\) neurons (if including itself).

Here are some examples to fully understand GridN
. It is slightly different from GridEight
: GridEight
is equivalent to GridN
when N = 1.
When N = 1: \(\begin{bmatrix} x & x & x\\ x & I & x\\ x & x & x \end{bmatrix}\)
When N = 2: \( \begin{bmatrix} x & x & x & x & x\\ x & x & x & x & x\\ x & x & I & x & x\\ x & x & x & x & x\\ x & x & x & x & x \end{bmatrix} \)
conn = bp.connect.GridN(N=2, include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridN at 0x143ebdf70>
Here is an example:
size = (4, 4)
conn = bp.connect.GridN(N=1, include_self=False)(pre_size=size)
res = conn.require('conn_mat')
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res)
nx.draw(G, with_labels=True)
plt.show()

Built-in random connections#
brainpy.connect.FixedProb#
For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all-to-all projection, except that some synapses are not created, making the projection sparser.

Class brainpy.connect.FixedProb
is inherited from TwoEndConnector
, and it receives three settings:
prob: the fixed probability of forming a connection with a pre-synaptic neuron, for each post-synaptic neuron.
include_self: whether to connect a neuron to itself.
seed: the seed of the random generator.
And two parameters are passed in when calling the instance of the class: pre_size and post_size.
conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=134)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[False, True, False, False],
[False, False, False, False],
[ True, True, False, True],
[ True, False, True, False]], dtype=bool)
brainpy.connect.FixedPreNum#
Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

Class brainpy.connect.FixedPreNum
is inherited from TwoEndConnector
, and it receives three settings:
num: the connection probability (if num is a float) or the fixed number of connections (if num is an int).
include_self: whether to connect a neuron to itself.
seed: the seed of the random generator.
And two parameters are passed in when calling the instance of the class: pre_size and post_size.
conn = bp.connect.FixedPreNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[ True, False, False, False],
[False, True, False, False],
[ True, True, True, True],
[False, False, True, True]], dtype=bool)
brainpy.connect.FixedPostNum#
Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

Class brainpy.connect.FixedPostNum
is inherited from TwoEndConnector
, and it receives three settings:
num: the connection probability (if num is a float) or the fixed number of connections (if num is an int).
include_self: whether to connect a neuron to itself.
seed: the seed of the random generator.
And two parameters are passed in when calling the instance of the class: pre_size and post_size.
conn = bp.connect.FixedPostNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[ True, False, True, False],
[False, True, True, False],
[False, False, True, True],
[False, False, True, True]], dtype=bool)
brainpy.connect.GaussianProb#
Builds a Gaussian connection pattern between the two populations, where the connection probability decays according to the Gaussian function.
Specifically,
where \((x, y)\) is the position of the pre-synaptic neuron and \((x_c,y_c)\) is the position of the post-synaptic neuron.
For example, in a \(30 \textrm{x} 30\) two-dimensional network, when \(\beta = \frac{1}{2\sigma^2} = 0.1\), the connection pattern is shown as follows:

GaussianProb is inherited from OneEndConnector, and it receives the following settings:
sigma: (float) the width of the Gaussian function.
encoding_values: (optional, list, tuple, int, float) the value ranges to encode for neurons at each axis.
periodic_boundary: (bool) whether the neurons encode the value space with a periodic boundary.
normalize: (bool) whether to normalize the connection probability.
include_self: (bool) whether to create a connection at the same position.
seed: the random seed.
conn = bp.connect.GaussianProb(sigma=2, periodic_boundary=True, normalize=True, include_self=True, seed=21)
conn(pre_size=10)
conn.require('conn_mat')
JaxArray([[ True, True, False, True, False, False, False, False,
True, True],
[ True, True, True, True, False, False, False, False,
False, True],
[ True, True, True, True, False, False, False, False,
False, True],
[False, True, True, True, True, False, False, False,
False, False],
[False, False, False, True, True, True, False, False,
False, False],
[False, False, True, False, True, True, True, True,
False, False],
[False, False, False, True, False, True, True, True,
False, False],
[ True, False, False, False, False, True, True, True,
False, True],
[False, True, False, False, False, False, True, True,
True, True],
[ True, False, True, False, False, True, False, True,
True, True]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.SmallWorld#
SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined to be a network where the typical distance \(L\) between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes \(N\) in the network, that is, \(L \propto \log N\).
[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.
Currently, SmallWorld only supports a one-dimensional network with a ring structure. It receives four settings:
num_neighbor: the number of the nearest neighbors to connect.
prob: the probability of rewiring each edge.
directed: whether the edge is a directed (“directed=True”) or undirected (“directed=False”) connection.
include_self: whether to allow a neuron to connect to itself.
conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, True, False, False, False, False, True, False,
True, False],
[ True, False, True, True, False, True, False, False,
False, True],
[False, True, False, True, True, False, False, False,
False, False],
[False, True, True, False, True, True, False, False,
False, False],
[False, False, True, True, False, True, True, False,
False, False],
[False, True, False, True, True, False, False, True,
False, True],
[ True, False, False, False, True, False, False, True,
True, False],
[False, False, False, False, False, True, True, False,
True, True],
[ True, False, False, False, False, False, True, True,
False, True],
[False, True, False, False, False, True, False, True,
True, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.ScaleFreeBA#
ScaleFreeBA
is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA
receives the following settings:
m: the number of edges to attach from a new node to existing nodes.
directed: whether the edge is a directed (“directed=True”) or undirected (“directed=False”) connection.
seed: the indicator of the random number generation state.
[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.
conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False, False, False, True, True, True,
False, True],
[False, False, False, False, False, True, True, True,
True, True],
[False, False, False, False, False, True, True, False,
False, False],
[False, False, False, False, False, True, False, False,
False, True],
[False, False, False, False, False, True, True, True,
True, False],
[ True, True, True, True, True, False, True, True,
True, True],
[ True, True, True, False, True, True, False, True,
True, True],
[ True, True, False, False, True, True, True, False,
True, False],
[False, True, False, False, True, True, True, True,
False, False],
[ True, True, False, True, False, True, True, False,
False, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.ScaleFreeBADual#
ScaleFreeBADual
is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:
p: the probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges).
m1: the number of edges to attach from a new node to existing nodes with probability \(p\).
m2: the number of edges to attach from a new node to existing nodes with probability \(1-p\).
directed: whether the edge is a directed (“directed=True”) or undirected (“directed=False”) connection.
seed: the indicator of the random number generation state.
[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.
conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False, False, False, True, False, True,
False, True],
[False, False, False, False, False, True, False, True,
True, False],
[False, False, False, False, False, True, True, False,
True, True],
[False, False, False, False, False, True, False, False,
False, False],
[False, False, False, False, False, True, True, True,
True, False],
[ True, True, True, True, True, False, True, True,
True, True],
[False, False, True, False, True, True, False, True,
True, False],
[ True, True, False, False, True, True, True, False,
False, True],
[False, True, True, False, True, True, True, False,
False, True],
[ True, False, True, False, False, True, False, True,
True, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()

brainpy.connect.PowerLaw#
PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:
m: the number of random edges to add for each new node.
p: the probability of adding a triangle after adding a random edge.
directed: whether the edge is a directed (“directed=True”) or undirected (“directed=False”) connection.
seed: the indicator of the random number generation state.
[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.
conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False, True, False, True, False, True,
True, False],
[False, False, False, True, True, False, False, True,
False, False],
[False, False, False, True, True, True, True, False,
True, True],
[ True, True, True, False, True, True, True, False,
False, True],
[False, True, True, True, False, False, False, True,
True, False],
[ True, False, True, True, False, False, True, False,
False, False],
[False, False, True, True, False, True, False, False,
False, True],
[ True, True, False, False, True, False, False, False,
False, False],
[ True, False, True, False, True, False, False, False,
False, False],
[False, False, True, True, False, False, True, False,
False, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()

Encapsulate your existing connections#
BrainPy also allows users to encapsulate existing connections with convenient class interfaces. Users can provide connection types as:
Index projection;
Dense matrix;
Sparse matrix.
Then users should provide pre_size
and post_size
information in order to instantiate the connection. In such a way, based on the following connection classes, users can generate any other synaptic structures (such as pre2post
, pre2syn
, conn_mat
, etc.) easily.
bp.conn.IJConn#
Here, let's take a simple connection as an example. In this example, we create a connection from a user-provided index projection by using bp.conn.IJConn.
pre_list = np.array([0, 1, 2])
post_list = np.array(([0, 0, 0]))
conn = bp.conn.IJConn(i=pre_list, j=post_list)
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
JaxArray([[ True, False, False],
[ True, False, False],
[ True, False, False],
[False, False, False],
[False, False, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 0, 0], dtype=uint32), JaxArray([0, 1, 2, 3, 3, 3], dtype=uint32))
conn.requires('pre2syn')
(JaxArray([0, 1, 2], dtype=uint32), JaxArray([0, 1, 2, 3, 3, 3], dtype=uint32))
bp.conn.MatConn#
In the next example, we create a connection from a user-provided dense connection matrix by using bp.conn.MatConn.
bp.math.random.seed(123)
conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_))
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
JaxArray([[ True, True, False],
[False, True, True],
[ True, False, True],
[False, False, True],
[ True, False, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 1, 1, 2, 0, 2, 2, 0], dtype=uint32),
JaxArray([0, 2, 4, 6, 7, 8], dtype=uint32))
conn.require('pre2syn')
(JaxArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=uint32),
JaxArray([0, 2, 4, 6, 7, 8], dtype=uint32))
bp.conn.SparseMatConn#
In the last example, we create a connection from a user-provided sparse connection matrix by using bp.conn.SparseMatConn.
from scipy.sparse import csr_matrix
conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
sparse_mat = csr_matrix(conn_mat)
conn = bp.conn.SparseMatConn(sparse_mat)
conn = conn(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1])
conn.requires('conn_mat')
JaxArray([[ True, False, True],
[ True, False, True],
[ True, False, False],
[False, True, True],
[False, True, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 2, 0, 2, 0, 1, 2, 1], dtype=uint32),
JaxArray([0, 2, 4, 5, 7, 8], dtype=uint32))
conn.requires('post2syn')
(JaxArray([0, 2, 4, 5, 7, 1, 3, 6], dtype=uint32),
JaxArray([0, 3, 5, 8], dtype=uint32))
Using NetworkX to provide connections and pass them into a Connector#
NetworkX is a Python package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks.
Users can design their own complex network by using NetworkX.
import networkx as nx
G = nx.Graph()
By definition, a Graph
is a collection of nodes (vertices) along with identified pairs of nodes (called edges, links, etc).
To learn more about NetworkX, please check the official documentation: NetworkX tutorial
Using class brainpy.connect.MatConn
to construct connections is recommended here.
Dense adjacency matrix: a two-dimensional ndarray.
Here is an example illustrating how to transform a random graph into synaptic connections by using a dense adjacency matrix.
G = nx.fast_gnp_random_graph(5, 0.5) # initialize a random graph G
B = nx.adjacency_matrix(G)
A = np.array(nx.adjacency_matrix(G).todense()) # get dense adjacency matrix of G
print('dense adjacency matrix:')
print(A)
nx.draw(G, with_labels=True)
plt.show()
dense adjacency matrix:
[[0 0 1 1 0]
[0 0 0 1 1]
[1 0 0 0 1]
[1 1 0 0 1]
[0 1 1 1 0]]

Users can use the class MatConn, inherited from TwoEndConnector, to construct connections. A dense adjacency matrix should be passed in when initializing the MatConn class.
Note that when calling the instance of the class, users should pass in two parameters: pre_size and post_size.
In this case, users can use the shape of the dense adjacency matrix as the parameters.
conn = bp.connect.MatConn(A)(pre_size=A.shape[0], post_size=A.shape[1])
res = conn.require('conn_mat')
print(res)
JaxArray([[False, False, True, True, False],
[False, False, False, True, True],
[ True, False, False, False, True],
[ True, True, False, False, True],
[False, True, True, True, False]], dtype=bool)
Customize your connections#
BrainPy allows users to customize their connections. The following requirements should be satisfied:
Your connection class should inherit from brainpy.connect.TwoEndConnector or brainpy.connect.OneEndConnector.
The __init__ function should be implemented and essential parameters should be initialized.
Users should also overwrite the build_conn() function to describe how to build the connection.
Let’s take an example to illustrate the details of customization.
class FixedProb(bp.connect.TwoEndConnector):
  """Connect the post-synaptic neurons with fixed probability.

  Parameters
  ----------
  prob : float
    The conn probability.
  include_self : bool
    Whether to create (i, i) connection.
  seed : optional, int
    Seed the random generator.
  """

  def __init__(self, prob, include_self=True, seed=None):
    super(FixedProb, self).__init__()
    assert 0. <= prob <= 1.
    self.prob = prob
    self.include_self = include_self
    self.seed = seed
    self.rng = np.random.RandomState(seed=seed)

  def build_conn(self):
    ind = []
    count = np.zeros(self.pre_num, dtype=np.uint32)

    def _random_prob_conn(rng, pre_i, num_post, prob, include_self):
      p = rng.random(num_post) <= prob
      if (not include_self) and pre_i < num_post:
        p[pre_i] = False
      conn_j = np.asarray(np.where(p)[0], dtype=np.uint32)
      return conn_j

    for i in range(self.pre_num):
      posts = _random_prob_conn(self.rng, pre_i=i, num_post=self.post_num,
                                prob=self.prob, include_self=self.include_self)
      ind.append(posts)
      count[i] = len(posts)

    ind = np.concatenate(ind)
    indptr = np.concatenate(([0], count)).cumsum()
    return 'csr', (ind, indptr)
Then users can initialize their own connection as below:
conn = FixedProb(prob=0.5, include_self=True)(pre_size=5, post_size=5)
conn.require('conn_mat')
JaxArray([[False, True, False, True, False],
[False, False, True, False, True],
[ True, False, True, False, True],
[False, False, False, False, True],
[ True, True, False, False, True]], dtype=bool)
Synaptic Weights#
In a brain model, synaptic weights, the strength of the connections between presynaptic and postsynaptic neurons, are crucial to the dynamics of the model. In this section, we will illustrate how to build synaptic weights in a synapse model.
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt
bp.math.set_platform('cpu')
Creating Static Weights#
Some computational models focus on the network structure and its influence on network dynamics, thus not modeling neural plasticity for simplicity. In this condition, synaptic weights are fixed and do not change in simulation. They can be stored as a scalar, a matrix or a vector depending on the connection strength and density.
1. Storing weights with a scalar#
If all synaptic weights are designed to be the same, the single weight value can be stored as a scalar in the synapse model to save memory space.
weight = 1.
The weight
can be stored in a synapse model. When updating the synapse, this weight
is assigned to all synapses by scalar multiplication.
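A tiny sketch of this idea (the spike vector below is hypothetical and serves only to illustrate the scalar multiplication):
import brainpy.math as bm

weight = 1.                               # one scalar weight shared by all synapses
pre_spike = bm.asarray([0., 1., 0., 1.])  # hypothetical pre-synaptic spikes
post_input = weight * pre_spike           # the same weight is applied to every synapse
print(post_input)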
2. Storing weights with a matrix#
When the synaptic connection is dense and the synapses are assigned with different weights, weights can be stored in a matrix \(W\), where \(W(i, j)\) refers to the weight of presynaptic neuron \(i\) to postsynaptic neuron \(j\).
BrainPy provides brainpy.initialize.Initializer
(or brainpy.init
for short) for weight initialization as a matrix. The tutorial of brainpy.init.Initializer
is introduced later.
For example, a weight matrix can be constructed using brainpy.init.Uniform
, which initializes weights with a random distribution:
pre_size = (4, 4)
post_size = (3, 3)
uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init((pre_size, post_size))
print('shape of weights: {}'.format(weights.shape))
shape of weights: (16, 9)
Then, the weights can be assigned to a group of connections with the same shape. For example, an all-to-all connection matrix can be obtained by brainpy.conn.All2All()
whose tutorial is contained in Synaptic Connections:
conn = bp.conn.All2All()
conn(pre_size, post_size)
conn_mat = conn.requires('conn_mat') # request the connection matrix
Therefore, weights[i, j]
refers to the weight of connection (i, j)
.
i, j = (2, 3)
print('whether (i, j) is connected: {}'.format(conn_mat[i, j]))
print('synaptic weights of (i, j): {}'.format(weights[i, j]))
whether (i, j) is connected: True
synaptic weights of (i, j): 0.48069751262664795
3. Storing weights with a vector#
When the synaptic connection is sparse, using a matrix to store synaptic weights is too wasteful. Instead, the weights can be stored in a vector which has the same length as the synaptic connections.

Weights can be assigned to the corresponding synapses as long as they are aligned with each other.
size = 5
conn = bp.conn.One2One()
conn(size, size)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')
print('presynaptic neuron ids: {}'.format(pre_ids))
print('postsynaptic neuron ids: {}'.format(post_ids))
print('synapse ids: {}'.format(bm.arange(size)))
presynaptic neuron ids: [0 1 2 3 4]
postsynaptic neuron ids: [0 1 2 3 4]
synapse ids: [0 1 2 3 4]
The weight vector is aligned with the synapse vector, i.e., the synapse ids:
weights = bm.random.uniform(0, 2, size)
for i in range(size):
  print('weight of synapse {}: {}'.format(i, weights[i]))
weight of synapse 0: 1.1737329959869385
weight of synapse 1: 1.2609302997589111
weight of synapse 2: 1.4217698574066162
weight of synapse 3: 1.781663179397583
weight of synapse 4: 1.1460866928100586
Conversion from a weight matrix to a weight vector#
For users who would like to obtain the weight vector from the weight matrix, they can first build a connection according to the non-zero elements in the weight matrix and then slice the weight matrix according to the connection:
weight_mat = np.array([[1., 1.5, 0., 0.5], [0., 2.5, 0., 0.], [2., 0., 3, 0.]])
print('weight matrix: \n{}'.format(weight_mat))
conn = bp.conn.MatConn(weight_mat)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')
weight_vec = weight_mat[pre_ids, post_ids]
print('weight_vector: \n{}'.format(weight_vec))
weight matrix:
[[1. 1.5 0. 0.5]
[0. 2.5 0. 0. ]
[2. 0. 3. 0. ]]
weight_vector:
[1. 1.5 0.5 2.5 2. 3. ]
Note
However, it is not recommended to use this approach when the connection is sparse and of a large scale, because generating the weight matrix will take up too much space.
Creating Dynamic Weights#
Sometimes users may want to realize neural plasticity in a brain model, which requires the synaptic weights to change during simulation. In this condition, weights should be considered as variables, thus defined as brainpy.math.Variable. If they are packed in a synapse model, weight updating should be realized in the update(_t, _dt) function of the synapse model.
weights = bm.Variable(bm.ones(10))
weights
Variable([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)
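For instance, a simplified sketch of a model that updates its weight variable inside update() (the decay rule here is made up purely to show in-place modification of a Variable; it is not a real plasticity rule):
import brainpy as bp
import brainpy.math as bm

class PlasticWeights(bp.DynamicalSystem):
  def __init__(self, num):
    super(PlasticWeights, self).__init__()
    self.w = bm.Variable(bm.ones(num))   # plastic synaptic weights

  def update(self, _t, _dt):
    # drift the weights toward 2.0 at each step (illustrative rule only)
    self.w.value = self.w + _dt * 0.1 * (2.0 - self.w)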
Built-in Weight Initializers#
Base Class: bp.init.Initializer#
The base class of weight initializers is brainpy.initialize.Initializer, which can be accessed by the shortcut bp.init. All initializers, built-in or customized, should inherit the Initializer class.
Weight initialization is implemented in the __call__
function, so that it can be realized automatically when the initializer is called. The __call__
function has a shape
parameter that has different meanings in the following two superclasses and returns a weight matrix.
Superclass 1: bp.init.InterLayerInitializer#
The InterLayerInitializer
is an abstract subclass of Initializer
. Its subclasses initialize the weights between two fully connected layers. The shape
parameter of the __call__
function should be a 2-element tuple \((m, n)\), which refers to the number of presynaptic neurons \(m\) and of postsynaptic neurons \(n\). The output of the __call__
function is a bp.math.ndarray
with the shape of \((m, n)\), where the value at \((i, j)\) is the initialized weight of the presynaptic neuron \(i\) to postsynpatic neuron \(j\).
Superclass 2: bp.init.IntraLayerInitializer#
The IntraLayerInitializer
is also an abstract subclass of Initializer
. Its subclasses initialize the weights within a single layer. The shape
parameter of the __call__
function refers to the structure of the neural population \((n_1, n_2, ..., n_d)\). The __call__
function returns a 2-D bp.math.ndarray
with the shape of \((\prod_{k=1}^d n_k, \prod_{k=1}^d n_k)\). In the 2-D array, the value at \((i, j)\) is the initialized weight of neuron \(i\) to neuron \(j\) of the flattened neural sequence.
1. Built-In Regular Initializers#
Regular initializers all belong to InterLayerInitializer
and initialize the connection weights between two layers with a regular pattern. There are ZeroInit
, OneInit
, and Identity
initializers in built-in regular initializers. Here we show how to use the OneInit
initializer. The remaining two classes are used in a similar way.
# visualization
def mat_visualize(matrix, cmap=plt.cm.get_cmap('coolwarm')):
  im = plt.matshow(matrix, cmap=cmap)
  plt.colorbar(mappable=im, shrink=0.8, aspect=15)
  plt.show()
# 'OneInit' initializes all the weights with the same value
shape = (5, 6)
one_init = bp.init.OneInit(value=2.5)
weights = one_init(shape)
print(weights)
JaxArray([[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[2.5, 2.5, 2.5, 2.5, 2.5, 2.5]], dtype=float32)
2. Built-In Random Initializers#
Random initializers all belong to InterLayerInitializer
and initialize the connection weights between two layers with a random distribution. There are Normal
, Uniform
, Orthogonal
and other initializers in built-in regular initializers. Here we show how to use the Normal
and Uniform
initializers.
bp.init.Normal
This initializer initializes the weights with a normal distribution. The variance of the distribution changes according to the scale
parameter. In the following example, 10 presynaptic neurons are fully connected to 20 postsynaptic neurons with random weight values:
shape = (10, 20)
normal_init = bp.init.Normal(scale=1.0)
weights = normal_init(shape)
mat_visualize(weights)

bp.init.Uniform
This initializer resembles brainpy.init.Normal
but initializes the weights with a uniform distribution.
uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init(shape)
mat_visualize(weights)

3. Built-In Decay Initializers#
Decay initializers all belong to IntraLayerInitializer
and initialize the connection weights within a layer with a decay function according to the neural distance. There are GaussianDecay
and DOGDecay
initializers in built-in decay initializers. Below are examples of how to use them.
brainpy.training.initialize.GaussianDecay
This initializer creates a Gaussian connectivity pattern within a population of neurons, where the weights decay with a gaussian function. Specifically, for any pair of neurons \( (i, j) \), the weight is computed as
where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).
The example below is a neural population with the size of \( 5 \times 5 \). Note that this shape is the structure of the target neural population, not the size of presynaptic and postsynaptic neurons.
size = (5, 5)
gaussian_init = bp.init.GaussianDecay(sigma=2., max_w=10., include_self=True)
weights = gaussian_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (25, 25)
Self-connections are created if include_self=True
. The connection weights of neuron \(i\) with others are stored in row \(i\) of weights
. For instance, the connection weights of neuron (1, 2) to other neurons are stored in weights[7]
(\(5 \times 1 + 2 = 7\)). After reshaping, the weights are:
mat_visualize(weights[7].reshape(size), cmap=plt.cm.get_cmap('Reds'))

brainpy.training.initialize.DOGDecay
This initializer creates a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons. Specifically, for the given pair of neurons \( (i, j) \), the weight between them is computed as
where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).

The example below is a neural population with the size of \( 10 \times 12 \):
size = (10, 12)
dog_init = bp.init.DOGDecay(sigmas=(1., 3.), max_ws=(10., 5.), min_w=0.1, include_self=True)
weights = dog_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (120, 120)
Weights smaller than min_w will not be created. If min_w is not assigned a value, it defaults to \(0.005 \times \min(\mathrm{max\_ws})\).
The organization of weights
is similar to that in the GaussianDecay
initializer. For instance, the connection weights of neuron (3, 4) to other neurons after reshaping are shown as below:
mat_visualize(weights[3*12+4].reshape(size), cmap=plt.cm.get_cmap('Reds'))

Customizing your initializers#
BrainPy also allows users to customize their own weight initializers. When customizing an initializer, users should follow the instructions below:
Your initializer should inherit brainpy.initialize.Initializer.
Override the __call__ function, to which the shape parameter should be given.
Here is an example of creating an inter-layer initializer that initializes the weights as follows:
class LinearDecay(bp.init.InterLayerInitializer):
  def __init__(self, max_w, sigma=1.):
    self.max_w = max_w
    self.sigma = sigma

  def __call__(self, shape, dtype=None):
    mat = bp.math.zeros(shape, dtype=dtype)
    n_pre, n_post = shape
    seq = np.arange(n_pre)
    current_w = self.max_w
    for i in range(max(n_pre, n_post)):
      if current_w <= 0:
        break
      seq_plus = ((seq + i) >= 0) & ((seq + i) < n_post)
      seq_minus = ((seq - i) >= 0) & ((seq - i) < n_post)
      mat[seq[seq_plus], (seq + i)[seq_plus]] = current_w
      mat[seq[seq_minus], (seq - i)[seq_minus]] = current_w
      current_w -= self.sigma
    return mat
shape = (10, 15)
lin_init = LinearDecay(max_w=5, sigma=1.)
weights = lin_init(shape)
mat_visualize(weights, cmap=plt.cm.get_cmap('Reds'))

Note
Note that customized initializers, or brainpy.init.Initializer, are not limited to returning a matrix. Although currently all the built-in initializers use matrices to store weights, they can also be designed to return a vector to store synaptic weights.
Gradient Descent Optimizers#
Gradient descent is one of the most popular optimization methods. At present, gradient descent optimizers, combined with the loss function, are the key to machine learning, especially deep learning. In this section, we are going to understand:
How to use optimizers in BrainPy?
How to customize your own optimizer?
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
import matplotlib.pyplot as plt
Optimizers in BrainPy#
The basic optimizer class in BrainPy is brainpy.optimizers.Optimizer
, which includes the following optimizers:
SGD
Momentum
Nesterov momentum
Adagrad
Adadelta
RMSProp
Adam
All supported optimizers can be inspected through the brainpy.math.optimizers APIs.
Generally, an optimizer initialization receives the learning rate lr
, the trainable variables train_vars
, and other hyperparameters for the specific optimizer.
lr can be a float, or an instance of brainpy.optim.Scheduler.
train_vars should be a dict of Variable.
Here we launch an SGD optimizer.
a = bm.Variable(bm.ones((5, 4)))
b = bm.Variable(bm.zeros((3, 3)))
op = bp.optim.SGD(lr=0.001, train_vars={'a': a, 'b': b})
When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update()
method.
op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})
print('a:', a)
print('b:', b)
a: Variable([[0.9998637 , 0.99998003, 0.99998164, 0.99945515],
[0.99987453, 0.9994551 , 0.9992173 , 0.99931335],
[0.99995905, 0.999395 , 0.9999578 , 0.9992222 ],
[0.9990663 , 0.9991484 , 0.99950355, 0.9991641 ],
[0.999546 , 0.99927944, 0.9995042 , 0.9993433 ]], dtype=float32)
b: Variable([[-2.2797585e-06, -8.2317321e-04, -3.6020565e-04],
[-4.7648646e-04, -4.1223815e-04, -4.2962815e-04],
[-6.4605832e-05, -9.0033154e-04, -2.7676427e-04]], dtype=float32)
You can process the gradients before applying them. For example, we clip the gradients by the maximum L2-norm.
grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}
grads_pre
{'a': JaxArray([[0.34498155, 0.5973849 , 0.3798045 , 0.9029752 ],
[0.48180568, 0.85238564, 0.9597529 , 0.19037902],
[0.09431374, 0.03752387, 0.05545044, 0.18625283],
[0.4265157 , 0.13127708, 0.44978166, 0.6449729 ],
[0.5855427 , 0.02250564, 0.9523196 , 0.317971 ]], dtype=float32),
'b': JaxArray([[0.4190004 , 0.7033491 , 0.13393831],
[0.11366987, 0.33574808, 0.37153232],
[0.39718974, 0.25615263, 0.08950627]], dtype=float32)}
grads_post = bm.clip_by_norm(grads_pre, 1.)
grads_post
{'a': JaxArray([[0.14619358, 0.2531551 , 0.16095057, 0.38265574],
[0.20417583, 0.3612173 , 0.40671656, 0.08067733],
[0.03996754, 0.01590157, 0.02349835, 0.07892876],
[0.18074548, 0.05563157, 0.19060494, 0.27332157],
[0.24813668, 0.00953726, 0.40356654, 0.13474725]], dtype=float32),
'b': JaxArray([[0.38518783, 0.6465901 , 0.12312973],
[0.10449692, 0.30865383, 0.34155035],
[0.36513725, 0.23548158, 0.08228327]], dtype=float32)}
op.update(grads_post)
print('a:', a)
print('b:', b)
a: Variable([[0.9997175 , 0.9997269 , 0.9998207 , 0.9990725 ],
[0.9996703 , 0.9990939 , 0.9988105 , 0.99923265],
[0.99991906, 0.9993791 , 0.9999343 , 0.9991433 ],
[0.9988856 , 0.9990928 , 0.99931294, 0.99889076],
[0.99929786, 0.9992699 , 0.9991006 , 0.9992085 ]], dtype=float32)
b: Variable([[-0.00038747, -0.00146976, -0.00048334],
[-0.00058098, -0.00072089, -0.00077118],
[-0.00042974, -0.00113581, -0.00035905]], dtype=float32)
Note
Optimizers usually have their own dynamically changed variables. If you JIT a function whose logic contains an optimizer update, the dyn_vars in bm.jit() should include the variables in Optimizer.vars(), as sketched after the examples below.
op.vars() # the SGD optimizer only has an iterable `step` variable to record the training step
{'Constant0.step': Variable([2], dtype=int32)}
bp.optim.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Momentum has velocity variables
{'Constant1.step': Variable([0], dtype=int32),
'Momentum0.a_v': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Momentum0.b_v': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)}
bp.optim.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars() # Adam has more variables
{'Constant2.step': Variable([0], dtype=int32),
'Adam0.a_m': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Adam0.b_m': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32),
'Adam0.a_v': Variable([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
'Adam0.b_v': Variable([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)}
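Following the note above, a rough sketch of passing both the trainable variables and the optimizer's own variables to bm.jit (the gradient values here are made up; in a real training loop they would come from bm.grad or a loss function):
def train_step():
  grads = {'a': bm.ones(a.shape) * 0.1, 'b': bm.ones(b.shape) * 0.1}  # placeholder gradients
  op.update(grads)

train_step = bm.jit(train_step, dyn_vars={'a': a, 'b': b, **op.vars()})
train_step()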
Creating A Self-Customized Optimizer#
To create your own optimization algorithm, simply inherit from bm.optimizers.Optimizer
class and override the following methods:
__init__(): the init function, which receives the learning rate (lr) and the trainable variables (train_vars). Do not forget to register your dynamically changed variables into implicit_vars.
update(grads): the update function that computes the updated parameters.
The general structure is shown below:
class CustomizeOp(bp.optim.Optimizer):
  def __init__(self, lr, train_vars, *params, **other_params):
    super(CustomizeOp, self).__init__(lr, train_vars)
    # customize your initialization

  def update(self, grads):
    # customize your update logic
    pass
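As a concrete (and deliberately simplified) illustration, the sketch below fills in this template with plain gradient descent. It keeps its own handle on the trainable variables and reads the current learning rate by calling self.lr, which is a Scheduler after initialization (see the Schedulers section below); there is no extra optimizer state here, so nothing needs to be registered into implicit_vars:
class PlainGD(bp.optim.Optimizer):
  """A toy optimizer: w <- w - lr * grad, with no extra state."""

  def __init__(self, lr, train_vars):
    super(PlainGD, self).__init__(lr, train_vars)
    self._train_vars = dict(train_vars)   # our own reference to the trainable Variables

  def update(self, grads):
    lr = self.lr()                        # current learning rate from the scheduler
    for key, p in self._train_vars.items():
      p.value = p - lr * grads[key]       # in-place update of each Variable

w = bm.Variable(bm.ones(3))
opt = PlainGD(lr=0.1, train_vars={'w': w})
opt.update({'w': bm.ones(3)})
print(w)  # each element is now 0.9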
Schedulers#
A scheduler adjusts the learning rate during training by reducing it according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.
Here we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.
sc = bp.optim.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
plt.plot(steps, rates)
plt.xlabel('Train Step')
plt.ylabel('Learning Rate')
plt.show()
steps = bm.arange(1000)
rates = sc(steps)
show(steps, rates)

After Optimizer initialization, the learning rate self.lr
will always be an instance of bm.optimizers.Scheduler
. A scalar float learning rate initialization will result in a Constant
scheduler.
op.lr
Constant(0.001)
One can get the current learning rate value by calling Scheduler.__call__(i=None).
If i is not provided, the learning rate value will be evaluated at the built-in training step.
Otherwise, the learning rate value will be evaluated at the given step i.
op.lr()
0.001
BrainPy provides several commonly used learning rate schedulers:
Constant
ExponentialDecay
InverseTimeDecay
PolynomialDecay
PiecewiseConstant
For more details, please see the brainpy.math.optimizers APIs.
# InverseTimeDecay scheduler
rates = bp.optim.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)

# PolynomialDecay scheduler
rates = bp.optim.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)

Creating a Self-Customized Scheduler#
To implement your own scheduler, simply inherit from the bm.optimizers.Scheduler class and override the following methods:
__init__(): the init function.
__call__(i=None): the learning rate value evaluation.
class CustomizeScheduler(bp.optim.Scheduler):
def __init__(self, lr, *params, **other_params):
super(CustomizeScheduler, self).__init__(lr)
# customize your initialization
def __call__(self, i=None):
# customize your update logic
pass
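For instance, here is a hedged sketch of a step-decay scheduler under this interface. It assumes the base Scheduler stores the initial rate as self.lr (as the Constant(0.001) example above suggests) and, for simplicity, does not reproduce the built-in step bookkeeping: when i is None it simply evaluates at step 0.
class StepDecay(bp.optim.Scheduler):
    """Scale the learning rate by `drop_rate` every `drop_every` training steps."""
    def __init__(self, lr, drop_every=100, drop_rate=0.5):
        super(StepDecay, self).__init__(lr)
        self.drop_every = drop_every
        self.drop_rate = drop_rate

    def __call__(self, i=None):
        i = 0 if i is None else i
        # works for a scalar step or an array of steps
        return self.lr * bm.power(self.drop_rate, i // self.drop_every)

# visualize it with the `show` helper defined above:
# show(steps, StepDecay(lr=0.1, drop_every=200)(steps))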
Runners#
Runners for Dynamical Systems#
The convenient simulation interfaces for dynamical systems in BrainPy are implemented in brainpy.simulation.runner. Currently, we implement two kinds of runners: DSRunner and ReportRunner. Each has its respective advantages.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
Initializing a runner#
Generally, we can initialize a runner with the format of:
SomeRunner(target=instance_of_dynamical_system,
inputs=inputs_for_target_variables,
monitors=interested_variables_to_monitor,
dyn_vars=dynamical_changed_variables,
jit=enable_jit_or_not)
In which:
target specifies the model to be simulated. It must be an instance of brainpy.DynamicalSystem.
monitors is used to define the target variables in the model. During the simulation, the history values of the monitored variables will be recorded. More information can be found in the Monitors tutorial.
inputs is used to define the input operations for specific variables. It will be expanded in the Inputs tutorial.
dyn_vars is used to specify all the dynamically changed variables used in the target model.
jit determines whether to use JIT compilation during the simulation.
Here we define an E/I balanced network as the simulation model.
class EINet(bp.Network):
def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = bp.dyn.LIF(num_exc, **pars, method=method)
I = bp.dyn.LIF(num_inh, **pars, method=method)
E.V[:] = bm.random.randn(num_exc) * 2 - 55.
I.V[:] = bm.random.randn(num_inh) * 2 - 55.
# synapses
E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
E=0., g_max=0.6, tau=5., method=method)
E2I = bp.dyn.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
E=0., g_max=0.6, tau=5., method=method)
I2E = bp.dyn.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
E=-80., g_max=6.7, tau=10., method=method)
I2I = bp.dyn.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
E=-80., g_max=6.7, tau=10., method=method)
super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
Then we will wrap it in different runners for dynamic simulation.
brainpy.DSRunner
#
brainpy.DSRunner
aims to provide model simulation with outstanding performance. It takes advantage of the structural loop primitive to lower the model onto XLA devices.
net = EINet()
runner = bp.DSRunner(net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner.run(100.)
1.2974658012390137
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)

Note that if the parameter jit
is set to True
, then all the variables will be JIT compiled and thus the system cannot be debugged by Python debugging tools. For debugging, users can set jit=False
.
brainpy.ReportRunner
#
brainpy.ReportRunner
aims to provide a Pythonic interface for model debugging. Users can use the standard Python debugging tools when simulating the model with ReportRunner
.
The drawback of brainpy.ReportRunner is that it is relatively slow, because it iterates over time steps in a Python loop during the simulation.
net = EINet()
runner = bp.ReportRunner(net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)],
jit=True)
runner.run(100.)
3.402564764022827
We can see from the output that the time spent for simulation through ReportRunner
is longer than that through DSRunner
.
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)

Runners for Neural Network Training#
Inputs#
In this section, we are going to talk about stimulus inputs.
import brainpy as bp
import brainpy.math as bm
Inputs in brainpy.dyn.DSRunner
#
In brain dynamics simulation, various inputs are usually given to different units of the dynamical system. In BrainPy, inputs can be specified to runners for dynamical systems. The aim of inputs is to mimic the input operations in experiments like Transcranial Magnetic Stimulation (TMS) and patch clamp recording.
inputs
should have the format like (target, value, [type, operation])
, where
target is the target variable to inject the input into.
value is the input value. It can be a scalar, a tensor, or an iterable object/function.
type is the type of the input value. It supports two types of input: fix and iter. The former means that the data is static; the latter denotes that the data can be iterated over, no matter whether the input value is a tensor or a function. The iter type must be explicitly stated.
operation is the input operation on the target variable. It should be one of { + , - , * , / , = }. If users do not provide this item explicitly, it will be set to '+' by default, which means that the target variable will be updated as val = val + input.
Users can also give multiple inputs for different target variables, like:
inputs=[(target1, value1, [type1, op1]),
(target2, value2, [type2, op2]),
... ]
The mechanism of inputs
is the same as monitors
. BrainPy finds the target variables for input operations through the absolute or relative path.
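As a hedged illustration, reusing the EINet network class from the Runners tutorial above, a fixed scalar current and a pre-computed time-varying current could be combined like this (the values here are made up for the example):
net = EINet()
# one value per time step: 100 ms / 0.1 ms dt = 1000 steps
i_current = bm.arange(1000) * 0.01

runner = bp.DSRunner(net,
                     inputs=[('E.input', 20., 'fix', '+'),          # static value, added at every step
                             ('I.input', i_current, 'iter', '=')],  # iterated value, assigned at every step
                     monitors=['E.spike'])
runner.run(100.)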
Input construction functions#
Like electrophysiological experiments, model simulation also needs various kinds of inputs. BrainPy provides several convenient input functions to help users construct input currents.
1. brainpy.inputs.section_input()
#
brainpy.inputs.section_input() is an updated version of the previous brainpy.inputs.constant_input() (see below).
Sometimes, we need input currents with different values in different periods. For example, if you want an input that is 0 in the first 100 ms, 1 in the next 300 ms, and 0 again in the last 100 ms, you can define:
current1, duration = bp.inputs.section_input(values=[0, 1., 0.],
durations=[100, 300, 100],
return_length=True,
dt=0.1)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Where values receives a list/array of the current values in each section and durations receives a list/array of the duration of each section. The function returns a tensor as the current, whose length is the total duration divided by \(\mathrm{d}t\) (if not specified, \(\mathrm{d}t=0.1\,\mathrm{ms}\)). We can visualize the current input by:
import numpy as np
import matplotlib.pyplot as plt
def show(current, duration, title):
ts = np.arange(0, duration, bm.get_dt())
plt.plot(ts, current)
plt.title(title)
plt.xlabel('Time [ms]')
plt.ylabel('Current Value')
plt.show()
show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')

2. brainpy.inputs.constant_input()
#
The brainpy.inputs.constant_input() function helps users format constant currents over several periods.
We can generate the above input current with constant_input()
by:
current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])
Where each tuple in the list contains the value and duration of the input in this section.
show(current2, duration, '[(0, 100), (1, 300), (0, 100)]')

3. brainpy.inputs.spike_input()
#
brainpy.inputs.spike_input() constructs an input containing a series of short-time spikes. It receives the following settings:
sp_times: the spike time points. Must be an iterable object, e.g., a list, tuple, or array.
sp_lens: the length of each point-current, mimicking the spike duration. It can be a scalar float to specify a unified duration, or a list/tuple/array of time lengths with the same length as sp_times.
sp_sizes: the current sizes. It can be a scalar value, or a list/tuple/array of spike current sizes with the same length as sp_times.
duration: the total current duration.
dt: the time step precision. The default is None (the default dt will be used).
For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, and 300 ms, where each spike lasts 1 ms and the current size of each spike is 0.5, then you can define the current by:
current3 = bp.inputs.spike_input(
sp_times=[10, 20, 30, 200, 300],
sp_lens=1., # can be a list to specify the spike length at each point
sp_sizes=0.5, # can be a list to specify the spike current size at each point
duration=400.)
show(current3, 400, 'Spike Input Example')

4. brainpy.inputs.ramp_input()
#
brainpy.inputs.ramp_input() mimics a ramp or a step current to the input of the circuit. It receives the following settings:
c_start: the minimum (or maximum) current size.
c_end: the maximum (or minimum) current size.
duration: the total duration.
t_start: the start time point of the ramped current.
t_end: the end time point of the ramped current. Default is None.
dt: the current precision.
We illustrate the usage of brainpy.inputs.ramp_input()
by two examples.
In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (500 ms).
duration = 500
current4 = bp.inputs.ramp_input(0, 1, duration)
show(current4, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                         r'$t_{start}$=0, $t_{end}$=None' % (duration))

In the second example, we increase the current size from 0. to 1. from the 100 ms to 400 ms.
duration, t_start, t_end = 500, 100, 400
current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)
show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))

5. brainpy.inputs.wiener_process
#
brainpy.inputs.wiener_process() is used to generate the basic Wiener process \(dW\), i.e. random numbers drawn from \(N(0, \sqrt{dt})\).
duration = 200
current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.)
show(current6, duration, 'Wiener Process')

6. brainpy.inputs.ou_process
#
brainpy.inputs.ou_process() is used to generate a noise time series from the Ornstein-Uhlenbeck process \(dx = (\mu - x)/\tau \cdot dt + \sigma \cdot dW\).
duration = 200
current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.)
show(current7, duration, 'Ornstein-Uhlenbeck Process')

7. brainpy.inputs.sinusoidal_input
#
brainpy.inputs.sinusoidal_input() can help to generate sinusoidal inputs.
duration = 2000
current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration, t_start=100., )
show(current8, duration, 'Sinusoidal Input')

8. brainpy.inputs.square_input
#
brainpy.inputs.square_input() can help to generate oscillatory square inputs.
duration = 2000
current9 = bp.inputs.square_input(amplitude=1., frequency=2.0,
duration=duration, t_start=100)
show(current9, duration, 'Square Input')

More complex inputs#
Because the current input is stored as a tensor, a complex input can be realized by the combination of several simple currents.
show(current1 + current5, 500, 'A Complex Current Input')

General properties of input functions#
1. Every input function receives a dt
specification.
If dt
is not provided, input functions will use the default dt
in the whole BrainPy system.
I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1)
I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01)
print('I1.shape: {}'.format(I1.shape))
print('I2.shape: {}'.format(I2.shape))
I1.shape: (600,)
I2.shape: (6000,)
2. All input functions can automatically broadcast the current shapes if they are heterogeneous among different periods.
For example, during period 1 we give an input with a scalar value, during period 2 we give an input with a vector shape, and during period 3 we give a matrix input value. Input functions will broadcast them to the maximum shape. For example:
current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
durations=[100, 300, 100])
current.shape
(5000, 3, 10)
Monitors#
BrainPy has a systematic naming system. Any model in BrainPy has a unique name. Thus, nodes, integrators, and variables can be easily accessed in a huge network. Based on this naming system, BrainPy provides a set of convenient monitoring tools. In this section, we are going to talk about them.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
bp.math.set_dt(0.02)
import numpy as np
import matplotlib.pyplot as plt
Initializing Monitors in a Runner#
In BrainPy, any instance of brainpy.Runner
has a built-in monitor. Users can set up a monitor when initializing a runner.
For example, if we want to simulate a Hodgkin-Huxley (HH) model and monitor its membrane potential \(V\) and the spikes it generates:
HH = bp.dyn.HH
model = HH(1)
After defining a HH neuron, we can add monitors while setting up the runner. When specifying the monitors
parameter, a monitor, which is an instance of brainpy.Monitor
, will be initialized. The first method to initialize a monitor is through a list/tuple of strings:
# set up a monitor using a list of str
runner1 = bp.StructRunner(model,
monitors=['V', 'spike'],
inputs=('input', 10))
type(runner1.mon)
brainpy.running.monitor.Monitor
where the strings 'V' and 'spike' correspond to the names of the variables in the HH model:
model.V, model.spike
(Variable(DeviceArray([0.], dtype=float32)),
Variable(DeviceArray([False], dtype=bool)))
Besides using a list/tuple of strings, users can also directly use the Monitor
class to initialize a monitor:
# set up a monitor using brainpy.Monitor
runner2 = bp.StructRunner(model, monitors=bp.Monitor(variables=['V', 'spike']))
Once we call the runner with a given time duration, the monitor will automatically record the variable evolutions in the corresponding models. Afterwards, users can access these variable trajectories by using .mon.[variable_name]
. The default history times .mon.ts
will also be generated after the model finishes its running. Let’s see an example.
runner1(100.)
bp.visualize.line_plot(runner1.mon.ts, runner1.mon.V, show=True)

The monitor in runner1
has recorded the evolution of V
. Therefore, it can be accessed by runner1.mon.V
or equivalently runner1.mon['V']
. Similarly, the recorded trajectory of variable spike
can also be obtained through runner1.mon.spike
.
runner1.mon.spike
array([[False],
[False],
[ True],
...,
[False],
[False],
[False]])
Where True
indicates a spike is generated at this time step.
The Mechanism of monitors
#
No matter whether we use a list/tuple or instantiate a Monitor class to generate a monitor, we specify the target variables by strings of their names. How does brainpy.Monitor find the target variables through these strings?
Actually, BrainPy first tries to find the target variables in the simulated model by the relative path. If the variables are not found, BrainPy checks whether they can be accessed by the absolute path. If they are still not found, an error will be raised.
net = bp.Network(HH(size=10, name='X'),
HH(size=20, name='Y'),
HH(size=30))
# it's ok
bp.StructRunner(net, monitors=['X.V', 'Y.spike']).build_monitors()
<function brainpy.dyn.runners.ds_runner.DSRunner.build_monitors.<locals>.func(_t, _dt)>
In the above net
, there are HH
instances named as “X” and “Y”. Therefore, trying to monitor “X.V” and “Y.spike” is successful.
However, in the following example, the node named with “Z” is not accessible in the generated net
, and the monitoring setup fails.
z = HH(size=30, name='Z')
net = bp.Network(HH(size=10), HH(size=20))
# node "Z" can not be accessed in the simulation target 'net'
try:
bp.StructRunner(net, monitors=['Z.V']).build_monitors()
except Exception as e:
print(type(e).__name__, ":", e)
RunningError : Cannot find target Z.V in monitor of <brainpy.compact.brainobjects.Network object at 0x000002827DA49E80>, please check.
BrainPy only supports monitoring Variables. Monitoring the trajectory of Variables is meaningful because they are dynamically changed. What is not marked as a Variable will be compiled as a constant.
try:
bp.StructRunner(HH(size=1), monitors=['gNa']).build_monitors()
except Exception as e:
print(type(e).__name__, ":", e)
RunningError : "gNa" in <brainpy.dyn.neurons.biological_models.HH object at 0x000002827DA5C430> is not a dynamically changed Variable, its value will not change, we think there is no need to monitor its trajectory.
The monitors in BrainPy only record the flattened tensor values. This means if the target variable is a matrix with the shape of (N, M)
, the resulting trajectory value in the monitor after running T
times will be a tensor with the shape of (T, N x M)
.
class MatrixVarModel(bp.DynamicalSystem):
def __init__(self, **kwargs):
super(MatrixVarModel, self).__init__(**kwargs)
self.a = bm.Variable(bm.zeros((4, 4)))
def update(self, _t, _dt):
self.a += 0.01
model = MatrixVarModel()
duration = 10
runner = bp.StructRunner(model, monitors=['a'])
runner.run(duration)
print(f'The expected shape of "model.mon.a" is: {(int(duration/bm.get_dt()), model.a.size)}')
print(f'The actual shape of "model.mon.a" is: {runner.mon.a.shape}')
The expected shape of "model.mon.a" is: (500, 16)
The actual shape of "model.mon.a" is: (500, 16)
Monitoring Variables at the Given Indices#
Sometimes we do not care about all the contents in a variable; we may only be interested in the values at certain indices. Moreover, for a huge network with a long-time simulation, monitors will consume a large amount of RAM. Therefore, monitoring variables only at selected indices is often more practical. BrainPy supports monitoring a part of the elements in a Variable with a tuple format like this:
runner = bp.StructRunner(
HH(10),
monitors=['V', # monitor all values of Variable 'V'
('spike', [1, 2, 3])], # monitor values of Variable at index of [1, 2, 3]
inputs=('input', 10.)
)
runner.run(100.)
print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)
Or we can use a dictionary to specify the target indices of a variable:
runner = bp.StructRunner(
HH(10),
monitors={'V': None, # 'None' means all values will be monitored
'spike': [1, 2, 3]}, # specifying the target indices
inputs=('input', 10.),
)
runner.run(100.)
print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)
Also, we can directly instantiate a brainpy.Monitor
class:
runner = bp.StructRunner(
HH(10),
monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])]),
inputs=('input', 10.),
)
runner.run(100.)
print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)
runner = bp.StructRunner(
HH(10),
monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]}),
inputs=('input', 10.),
)
runner.run(100.)
print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)
Note
Because brainpy.Monitor records a flattened tensor variable, if users want to record a part of a multi-dimensional variable, they must provide the indices corresponding to the flattened tensor.
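For example, in a hedged sketch reusing the MatrixVarModel defined above, to record only element (1, 2) of the (4, 4) variable a, the index passed to the monitor must be the flattened position 1 * 4 + 2 = 6:
runner = bp.StructRunner(MatrixVarModel(), monitors=[('a', [6])])
runner.run(10.)
print(runner.mon.a.shape)  # expected: (500, 1)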
Saving and Loading#
Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.
import brainpy as bp
bp.math.set_platform('cpu')
Saving and loading variables#
Model saving and loading in BrainPy are implemented with .save_states()
and .load_states()
functions.
BrainPy supports saving and loading model variables with various Python standard file formats, including
HDF5:
.h5
,.hdf5
.npz
(NumPy file format).pkl
(Python’s pickle utility).mat
(Matlab file format)
Here’s a simple example:
class EINet(bp.Network):
def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = bp.models.LIF(num_exc, **pars, method=method)
I = bp.models.LIF(num_inh, **pars, method=method)
E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.
# synapses
E2E = bp.models.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
E=0., g_max=0.6, tau=5., method=method)
E2I = bp.models.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
E=0., g_max=0.6, tau=5., method=method)
I2E = bp.models.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
E=-80., g_max=6.7, tau=10., method=method)
I2I = bp.models.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
E=-80., g_max=6.7, tau=10., method=method)
super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
net = EINet()
import os
if not os.path.exists('./data'):
os.makedirs('./data')
# model saving
net.save_states('./data/net.h5')
# model loading
net.load_states('./data/net.h5')
.save_states(filename, all_vars=None) receives a string to specify the output file name. If all_vars is not provided, BrainPy will retrieve all variables in the model through the relative path.
.load_states(filename, verbose, check_missing) receives several arguments. The first is a string of the output file name. The second, “verbose”, specifies whether to report the loading progress. The final argument, “check_missing”, will warn about model variables that are missing in the output file.
# model loading with warning and checking
net.load_states('./data/net.h5', verbose=True, check_missing=True)
WARNING:brainpy.base.io:There are variable states missed in ./data/net.h5. The missed variables are: ['ExpCOBA0.pre.V', 'ExpCOBA0.pre.input', 'ExpCOBA0.pre.refractory', 'ExpCOBA0.pre.spike', 'ExpCOBA0.pre.t_last_spike', 'ExpCOBA1.pre.V', 'ExpCOBA1.pre.input', 'ExpCOBA1.pre.refractory', 'ExpCOBA1.pre.spike', 'ExpCOBA1.pre.t_last_spike', 'ExpCOBA1.post.V', 'ExpCOBA1.post.input', 'ExpCOBA1.post.refractory', 'ExpCOBA1.post.spike', 'ExpCOBA1.post.t_last_spike', 'ExpCOBA2.pre.V', 'ExpCOBA2.pre.input', 'ExpCOBA2.pre.refractory', 'ExpCOBA2.pre.spike', 'ExpCOBA2.pre.t_last_spike', 'ExpCOBA2.post.V', 'ExpCOBA2.post.input', 'ExpCOBA2.post.refractory', 'ExpCOBA2.post.spike', 'ExpCOBA2.post.t_last_spike', 'ExpCOBA3.pre.V', 'ExpCOBA3.pre.input', 'ExpCOBA3.pre.refractory', 'ExpCOBA3.pre.spike', 'ExpCOBA3.pre.t_last_spike'].
Loading E.V ...
Loading E.input ...
Loading E.refractory ...
Loading E.spike ...
Loading E.t_last_spike ...
Loading ExpCOBA0.g ...
Loading ExpCOBA0.pre_spike.data ...
Loading ExpCOBA0.pre_spike.in_idx ...
Loading ExpCOBA0.pre_spike.out_idx ...
Loading ExpCOBA1.g ...
Loading ExpCOBA1.pre_spike.data ...
Loading ExpCOBA1.pre_spike.in_idx ...
Loading ExpCOBA1.pre_spike.out_idx ...
Loading ExpCOBA2.g ...
Loading ExpCOBA2.pre_spike.data ...
Loading ExpCOBA2.pre_spike.in_idx ...
Loading ExpCOBA2.pre_spike.out_idx ...
Loading ExpCOBA3.g ...
Loading ExpCOBA3.pre_spike.data ...
Loading ExpCOBA3.pre_spike.in_idx ...
Loading ExpCOBA3.pre_spike.out_idx ...
Loading I.V ...
Loading I.input ...
Loading I.refractory ...
Loading I.spike ...
Loading I.t_last_spike ...
Note
By default, the model variables are retrieved by the relative path. Relative path retrieval usually results in duplicate variables in the returned TensorCollector. Therefore, there will always be missing keys when loading the variables.
Custom saving and loading#
You can easily write your own saving and loading functions, because all variables in the model can be collected through .vars(). Saving variables is just a matter of transforming them into numpy.ndarray and storing them on disk. Similarly, to load variables, you just need to read the numpy arrays from disk and assign them back to the corresponding Variable instances.
The only gotcha to pay attention to is to avoid saving duplicated variables. A minimal sketch is shown below.
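Here is a minimal sketch of such custom saving and loading with NumPy's .npz format. The function names my_save and my_load are ours; the only BrainPy APIs used are .vars(), .unique(), and the .value assignment described in the Variables tutorial.
import numpy as np
import brainpy.math as bm

def my_save(model, filename):
    # collect the unique variables and store them as NumPy arrays
    all_vars = model.vars().unique()
    np.savez(filename, **{key: np.asarray(var.value) for key, var in all_vars.items()})

def my_load(model, filename):
    data = np.load(filename)
    for key, var in model.vars().unique().items():
        if key in data:
            # the .value assignment checks that shape and dtype stay consistent
            var.value = bm.asarray(data[key])

# usage (hypothetical file name):
# my_save(net, './data/net_custom.npz')
# my_load(net, './data/net_custom.npz')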
Variables#
In BrainPy, the JIT compilation for class objects relies on Variables. In this section, we are going to understand:
What is a Variable?
What are the subtypes of Variable?
How to update a Variable?
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
brainpy.math.Variable
#
brainpy.math.Variable
is a pointer referring to a tensor. It stores a tensor as its value. The data in a Variable can be changed during JIT compilation. If a tensor is labeled as a Variable, it means that it is a dynamical variable that changes over time.
Tensors that are not marked as Variables will be JIT compiled as static data. Modifications of these tensors will be invalid or cause an error.
Creating a Variable
Passing a tensor into the brainpy.math.Variable
creates a Variable. For example:
b1 = bm.random.random(5)
b1
JaxArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458 ], dtype=float32)
b2 = bm.Variable(b1)
b2
Variable([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458 ], dtype=float32)
Accessing the value in a Variable
The data in a Variable can be obtained through .value
.
b2.value
DeviceArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458 ], dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
Supported operations on Variables
Variables support almost all the operations for tensors. Actually, brainpy.math.Variable
is a subclass of brainpy.math.ndarray
.
isinstance(b2, bm.ndarray)
True
isinstance(b2, bm.JaxArray)
True
# `bp.math.ndarray` is an alias for `bp.math.JaxArray` in 'jax' backend
bm.ndarray is bm.JaxArray
True
Note
After performing any operation on a Variable, the resulting value will be a JaxArray (brainpy.math.ndarray
is an alias for brainpy.math.JaxArray
). This means that the Variable can only be used to refer to a single value.
b2 + 1.
JaxArray([1.9116168, 1.6901083, 1.4392058, 1.1322064, 1.771458 ], dtype=float32)
b2 ** 2
JaxArray([0.8310452 , 0.47624946, 0.1929017 , 0.01747854, 0.5951475 ], dtype=float32)
bm.floor(b2)
JaxArray([0., 0., 0., 0., 0.], dtype=float32)
Subtypes of Variable
#
brainpy.math.Variable
has several subtypes, including brainpy.math.TrainVar
, brainpy.math.Parameter
, and brainpy.math.RandomState
. Subtypes can also be customized and extended by users.
1. TrainVar#
brainpy.math.TrainVar
is a trainable variable and a subclass of brainpy.math.Variable
. Usually, trainable variables are those whose gradients need to be computed in order to obtain the corresponding update values. However, users can also use TrainVar for other purposes.
b = bm.random.rand(4)
b
JaxArray([0.59062696, 0.618052 , 0.84173155, 0.34012556], dtype=float32)
bm.TrainVar(b)
TrainVar([0.59062696, 0.618052 , 0.84173155, 0.34012556], dtype=float32)
2. Parameter#
brainpy.math.Parameter is used to label a dynamically changed parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the Collector.subset method (please see Base class).
b = bm.random.rand(1)
b
JaxArray([0.14782536], dtype=float32)
bm.Parameter(b)
Parameter([0.14782536], dtype=float32)
3. RandomState#
brainpy.math.random.RandomState
is also a subclass of brainpy.math.Variable
. RandomState must store the dynamically changed key information (see JAX random number designs). Every time a RandomState performs random sampling, the “key” changes. Therefore, it is worth labeling a RandomState as a Variable.
state = bm.random.RandomState(seed=1234)
state
RandomState([ 0, 1234], dtype=uint32)
# perform a "random" sampling
state.random(1)
state # the value changed
RandomState([2113592192, 1902136347], dtype=uint32)
# perform a "sample" sampling
state.sample(1)
state # the value changed too
RandomState([1076515368, 3893328283], dtype=uint32)
Every instance of RandomState can create a new seed from the current seed with .split_key()
.
state.split_key()
DeviceArray([3028232624, 826525938], dtype=uint32)
It can also create multiple seeds from the current seed with .split_keys(n)
. This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.
state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
[1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
[2548793527, 3057026599],
[ 874320145, 4142002431],
[3368470122, 3462971882],
[1756854521, 1662729797]], dtype=uint32)
There is a default RandomState in brainpy.math.random
module: DEFAULT
.
bm.random.DEFAULT
RandomState([601887926, 339370966], dtype=uint32)
The built-in random methods like randint(), rand(), shuffle(), etc. use this DEFAULT state. If you want to change the default RandomState, please use the seed() method.
bm.random.seed(654321)
bm.random.DEFAULT
RandomState([ 0, 654321], dtype=uint32)
In-place updating#
In BrainPy, the transformations (like JIT) usually need to update variables or tensors in-place. In-place updating does not change the reference pointing to the variable while changing the data stored in the variable.
For example, here we have a variable a
.
a = bm.Variable(bm.zeros(5))
The ids of the variable and the data stored in the variable are:
id_of_a = id(a)
id_of_data = id(a.value)
assert id_of_a != id_of_data
print('id(a) = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a) = 2101001001088
id(a.value) = 2101018127136
In-place update (here we use [:]) does not change the pointer referring to the variable but changes its data:
a[:] = 1.
print('id(a) = ', id(a))
print('id(a.value) = ', id(a.value))
id(a) = 2101001001088
id(a.value) = 2101019514880
print('(id(a) == id_of_a) =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a) = True
(id(a.value) == id_of_data) = False
However, if you do not use in-place operators to assign data, the name a will be rebound to a new object and its id will change. This will cause serious errors when using transformations in BrainPy.
a = 10.
print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) = 2101001187280
(id(a) == id_of_a) = False
The following in-place operators are not limited to brainpy.math.Variable and its subclasses. They also apply to brainpy.math.JaxArray.
Here, we list several commonly used in-place operators.
v = bm.Variable(bm.arange(10))
old_id = id(v)
def check_no_change(new_v):
assert id(new_v) == old_id, 'Variable has been changed.'
1. Indexing and slicing#
Indexing and slicing are the two most commonly used operators. The details of indexing and slicing are in Array Objects Indexing.
Indexing: v[i] = a
or v[(1, 3)] = c
(index multiple values)
v[0] = 1
check_no_change(v)
Slicing: v[i:j] = b
v[1: 2] = 1
check_no_change(v)
Slicing all values: v[:] = d
, v[...] = e
v[:] = 0
check_no_change(v)
v[...] = bm.arange(10)
check_no_change(v)
2. Augmented assignment#
All augmented assignments are in-place operations, which include:
add: +=
subtract: -=
divide: /=
multiply: *=
floor divide: //=
modulo: %=
power: **=
and: &=
or: |=
xor: ^=
left shift: <<=
right shift: >>=
v += 1
check_no_change(v)
v *= 2
check_no_change(v)
v |= bm.random.randint(0, 2, 10)
check_no_change(v)
v **= 2
check_no_change(v)
v >>= 2
check_no_change(v)
3. .value
assignment#
Another way to in-place update a variable is to assign new data to .value
. This operation is very safe, because it will check whether the type and shape of the new data are consistent with the current ones.
v.value = bm.arange(10)
check_no_change(v)
try:
v.value = bm.asarray(1.)
except Exception as e:
print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
v.value = bm.random.random(10)
except Exception as e:
print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
4. .update()
method#
Actually, the .value
assignment is the same operation as the .update()
method. Users who want a safe assignment can choose this method too.
v.update(bm.random.randint(0, 20, size=10))
try:
v.update(bm.asarray(1.))
except Exception as e:
print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
v.update(bm.random.random(10))
except Exception as e:
print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
Base Class#
In this section, we are going to talk about:
The Base class for the BrainPy ecosystem.
The Collector to facilitate variable collection and manipulation.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
brainpy.Base
#
The foundation of BrainPy is brainpy.Base. A Base instance is an object which has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class object that will be JIT compiled or automatically differentiated must inherit from brainpy.Base.
A Base object can have many variables, children Base objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.
class FHN(bp.Base):
def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
super(FHN, self).__init__(name=name)
# parameters
self.num = num
self.a = a
self.b = b
self.tau = tau
self.Vth = Vth
# variables
self.V = bm.Variable(bm.zeros(num))
self.w = bm.Variable(bm.zeros(num))
self.spike = bm.Variable(bm.zeros(num, dtype=bool))
# integral
self.integral = bp.odeint(method='rk4', f=self.derivative)
def derivative(self, V, w, t, Iext):
dw = (V + self.a - self.b * w) / self.tau
dV = V - V * V * V / 3 - w + Iext
return dV, dw
def update(self, _t, _dt, x):
V, w = self.integral(self.V, self.w, _t, x)
self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
self.w[:] = w
self.V[:] = V
Note this model has three variables: self.V
, self.w
, and self.spike
. It also has an integrator self.integral
.
The naming system#
Every Base object has a unique name. Users can specify a unique name when instantiating a Base class. Reusing an existing name will cause an error.
FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
FHN(10, name='Y').name
except Exception as e:
print(type(e).__name__, ':', e)
If a name is not specified for the Base object, BrainPy will assign a name to it automatically. The rule for generating the object name is class_name + number_of_instances. For example, FHN0, FHN1, etc.
FHN(10).name
'FHN0'
FHN(10).name
'FHN1'
Therefore, in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.
Collection functions#
Two important collection functions are implemented for each Base object. Specifically, they are:
nodes(): to collect all instances of Base objects, including the children nodes of a node.
vars(): to collect all variables defined in the Base node and in its children nodes.
fhn = FHN(10)
All variables in a Base object can be collected through Base.vars()
. The returned container is a TensorCollector (a subclass of Collector
).
vars = fhn.vars()
vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(vars)
brainpy.base.collector.TensorCollector
All nodes in the model can also be collected through one method Base.nodes()
. The result container is an instance of Collector.
nodes = fhn.nodes()
nodes # note: integrator is also a node
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>,
'FHN2': <__main__.FHN at 0x2155a7a65e0>}
type(nodes)
brainpy.base.collector.Collector
All integrators can be collected by:
ints = fhn.nodes().subset(bp.integrators.Integrator)
ints
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>}
type(ints)
brainpy.base.collector.Collector
Now, let’s make a more complicated model by using the previously defined model FHN
.
class FeedForwardCircuit(bp.Base):
def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
super(FeedForwardCircuit, self).__init__(name=name)
self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
self.conn = bm.ones((num1, num2), dtype=bool) * w
bm.fill_diagonal(self.conn, 0.)
def update(self, _t, _dt, x):
self.pre.update(_t, _dt, x)
x2 = self.pre.spike @ self.conn
self.post.update(_t, _dt, x2)
This model FeedForwardCircuit
defines two layers. Each layer is modeled as a FitzHugh-Nagumo model (FHN
). The first layer is densely connected to the second layer. The input to the second layer is the product of the first layer’s spike and the connection strength w
.
net = FeedForwardCircuit(8, 5)
We can retrieve all variables by .vars()
:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
And retrieve all nodes (instances of the Base class) by .nodes()
:
net.nodes()
{'FHN3': <__main__.FHN at 0x2155ace3130>,
'FHN4': <__main__.FHN at 0x2155a798d30>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>,
'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x2155a743c40>}
If we only care about a subtype of class, we can retrieve them through:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
Absolute paths#
It's worth noting that there are two ways to access variables, integrators, and nodes: "absolute" paths and "relative" paths. The default way is the absolute path.
For absolute paths, all keys in the resulting Collector (Base.nodes()) have the format of key = node_name [+ field_name].
.nodes() example 1: In the above fhn instance, there are two nodes: "fhn" and its integrator "fhn.integral".
fhn.integral.name, fhn.name
('RK45', 'FHN2')
Calling .nodes()
returns their names and models.
fhn.nodes().keys()
dict_keys(['RK45', 'FHN2'])
.nodes() example 2: In the above net
instance, there are five nodes:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0')
Calling .nodes()
also returns the names and instances of all models.
net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0'])
.vars() example 1: In the above fhn instance, there are three variables: "V", "w", and "spike". Calling .vars() returns a dict of <node_name + var_name, var_value>.
fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
.vars() example 2: This also applies in the net
instance:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Relative paths#
Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre
instance in the net
can be accessed by
net.pre
<__main__.FHN at 0x2155ace3130>
Relative paths preserve the dependence relationship. For example, all nodes retrieved from the perspective of net
are:
net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x2155a743c40>,
'pre': <__main__.FHN at 0x2155ace3130>,
'post': <__main__.FHN at 0x2155a798d30>,
'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
However, nodes retrieved from the start point of net.pre
will be:
net.pre.nodes('relative')
{'': <__main__.FHN at 0x2155ace3130>,
'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>}
Variables can also be relatively inferred from the model. For example, variables that can be relatively accessed from net include:
net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'post.spike': Variable([False, False, False, False, False], dtype=bool),
'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
While variables relatively accessed from net.post
are:
net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'spike': Variable([False, False, False, False, False], dtype=bool),
'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Elements in containers#
One drawback of collection functions is that they do not look for elements in a list, dict, or any other container structure.
class ATest(bp.Base):
def __init__(self):
super(ATest, self).__init__()
self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()
The above class defines a list of variables, and a dict of children nodes, but the variables and children nodes cannot be retrieved from the collection functions vars()
and nodes()
.
t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x2155ae60a00>}
To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of "dict") to hold variables and nodes in container structures. Variables registered in implicit_vars, and integrators and nodes registered in implicit_nodes, can be retrieved by collection functions.
class AnotherTest(bp.Base):
def __init__(self):
super(AnotherTest, self).__init__()
self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})  # the input must be a dict
self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})  # the input must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator.
# Therefore, there are five Base objects.
t2.nodes()
{'T1': <__main__.FHN at 0x2155ae6bca0>,
'T2': <__main__.FHN at 0x2155ae6b3a0>,
'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae6f8e0>,
'RK411': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae66610>,
'AnotherTest0': <__main__.AnotherTest at 0x2155ae6f0d0>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.
t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T1.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'T2.spike': Variable([False, False, False, False, False], dtype=bool),
'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32),
'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32)}
Saving and loading#
Because Base.vars()
returns a Python dictionary object Collector, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, please see Saving and Loading). Specifically, they are implemented by Base.save_states()
and Base.load_states()
.
Save#
Base.save_states(PATH, [vars])
Models exported from BrainPy support various Python standard file formats, including
HDF5:
.h5
,.hdf5
.npz
(NumPy file format).pkl
(Python’spickle
utility).mat
(Matlab file format)
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
Load#
Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
Collector#
Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.
subset()
#
Collector.subset(cls)
returns a part of elements whose type is the given cls
. For example, Base.nodes()
returns all instances of Base class. If you are only interested in one type, like ODEIntegrator
, you can use:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
Actually, Collector.subset(cls) traverses all the elements in this collection and finds the elements whose type matches the given cls.
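A hedged aside: because TensorCollector (introduced below) is also a Collector, subset() can likewise be applied to the result of .vars(), e.g. to pick out only TrainVar instances. For the FHN-based net above this is expected to be empty, since none of its variables were declared as bm.TrainVar.
net.vars().subset(bm.TrainVar)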
unique()
#
It is common in machine learning that weights are shared by several objects, or that the same weight can be accessed through various dependence relationships. Collection functions of Base usually return a collection in which the same value has multiple keys. Duplicate elements will not be automatically excluded. However, it is important not to apply operations such as gradient descent to the same elements more than once.
Therefore, the Collector provides Collector.unique()
to handle this problem automatically. Collector.unique()
returns a copy of collection in which all elements are unique.
class ModelA(bp.Base):
def __init__(self):
super(ModelA, self).__init__()
self.a = bm.Variable(bm.zeros(5))
class SharedA(bp.Base):
def __init__(self, source):
super(SharedA, self).__init__()
self.source = source
self.a = source.a # shared variable
class Group(bp.Base):
def __init__(self):
super(Group, self).__init__()
self.A = ModelA()
self.A_shared = SharedA(self.A)
g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique() # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative') # "ModelA" is accessed twice
{'': <__main__.Group at 0x2155b13b550>,
'A': <__main__.ModelA at 0x2155b13b460>,
'A_shared': <__main__.SharedA at 0x2155a7a6580>,
'A_shared.source': <__main__.ModelA at 0x2155b13b460>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x2155b13b550>,
'A': <__main__.ModelA at 0x2155b13b460>,
'A_shared': <__main__.SharedA at 0x2155a7a6580>}
update()
#
The Collector can also catch potential conflicts during the assignment. The bracket assignment of a Collector ([key]
) and Collector.update()
will check whether the same key is mapped to a different value. If it is, an error will occur.
tc = bp.Collector({'a': bm.zeros(10)})
tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
tc['a'] = bm.zeros(1) # same key "a", different tensor
except Exception as e:
print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
tc.update({'a': bm.ones(1)}) # same key "a", different tensor
except Exception as e:
print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
replace()
#
Collector.replace(old_key, new_value)
is used to update the value of a key.
tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))
tc
{'a': JaxArray([1., 1., 1.], dtype=float32)}
__add__()
#
Two Collectors can be merged.
a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})
a + b
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'b': JaxArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
TensorCollector#
TensorCollector is a subclass of Collector, but it is specifically used to collect tensors.
Compilation#
In this section, we are going to talk about code compilation that can accelerate your model running performance.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
brainpy.math.jit()
#
JAX provides JIT compilation jax.jit() for pure functions. In most cases, however, we code with Python classes. brainpy.math.jit() is intended to extend just-in-time compilation to class objects.
JIT compilation for class objects#
The constraints for class-object JIT compilation include:
The JIT target must be a subclass of brainpy.Base.
Dynamically changed variables must be labeled as brainpy.math.Variable.
Updating Variables must be accomplished by in-place operations.
class LogisticRegression(bp.Base):
def __init__(self, dimension):
super(LogisticRegression, self).__init__()
# parameters
self.dimension = dimension
# variables
self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)
def __call__(self, X, Y):
u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
self.w[:] = self.w - u
# The above line can also be expressed as:
#
# self.w.value = self.w - u
#
# or,
#
# self.w.update(self.w - u)
In this example, weight self.w is a dynamically changed variable, thus marked as Variable
. During the update phase __call__()
, self.w is in-place updated through self.w[:] = ...
. Alternatively, one can replace the data in the variable by self.w.value = ...
or self.w.update(...)
.
Now this logistic regression can be accelerated by JIT compilation.
num_dim, num_points = 10, 200000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr = LogisticRegression(10)
%timeit lr(points, labels)
3.73 ms ± 589 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr_jit = bm.jit(lr)
%timeit lr_jit(points, labels)
1.75 ms ± 57.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
JIT mechanism#
The mechanism of JIT compilation is that BrainPy automatically transforms your class methods into functions.
brainpy.math.jit()
receives a dyn_vars
argument, which denotes the dynamically changed variables. If it is not provided by users, BrainPy will automatically detect them by calling Base.vars()
(only variables labeled as Variable
will be automatically retrieved by Base.vars()). Once receiving “dyn_vars”, BrainPy will treat “dyn_vars” as function arguments and then transform class objects into functions.
import types
isinstance(lr_jit, types.FunctionType) # "lr" is class, while "lr_jit" is a function
True
Therefore, the secret of brainpy.math.jit() is providing "dyn_vars". No matter whether your target is a class object, a method in a class object, or a pure function, if there are dynamically changed variables, you just pack them into brainpy.math.jit() as "dyn_vars". Then, all the compilation and acceleration will be handled by BrainPy automatically. Let's illustrate this with several examples.
Example 1: JIT compiled methods in a class#
In this example, we JIT compile a method of a class, in which the object variables are used to compute the final results.
class Linear(bp.Base):
def __init__(self, n_in, n_out):
super(Linear, self).__init__()
self.w = bm.random.random((n_in, n_out))
self.b = bm.zeros(n_out)
def update(self, x):
return x @ self.w + self.b
x = bm.zeros(10) # the input data
l = Linear(10, 3) # the class we need
First, we mark “w” and “b” as dynamically changed variables. Changing “w” or “b” will change the final results.
update1 = bm.jit(
l.update, dyn_vars=[l.w, l.b] # make 'w' and 'b' dynamically change
)
update1(x) # x is 0., b is 0., therefore y is 0.
JaxArray([0., 0., 0.], dtype=float32)
l.b[:] = 1. # change b to 1, we expect y will be 1 too
update1(x)
JaxArray([1., 1., 1.], dtype=float32)
This time, we only mark “w” as a dynamically changed variable. We will find that no matter how “b” is modified, the results will not change.
update2 = bm.jit(
l.update, dyn_vars=[l.w] # only make 'w' dynamically change
)
update2(x)
JaxArray([1., 1., 1.], dtype=float32)
l.b[:] = 2. # change b to 2
update2(x) # while y will not be 2
JaxArray([1., 1., 1.], dtype=float32)
Example 2: JIT compiled functions#
Now, we change the above “Linear” object to a function.
n_in = 10; n_out = 3
w = bm.random.random((n_in, n_out))
b = bm.zeros(n_out)
def update(x):
return x @ w + b
If we do not provide dyn_vars
, “w” and “b” will be compiled as constant values.
update1 = bm.jit(update)
update1(x)
JaxArray([0., 0., 0.], dtype=float32)
b[:] = 1. # modify the value of 'b' will not
# change the result, because in the
# jitted function, 'b' is already
# a constant
update1(x)
JaxArray([0., 0., 0.], dtype=float32)
Providing “w” and “b” as dyn_vars
will make them dynamically changed again.
update2 = bm.jit(update, dyn_vars=(w, b))
update2(x)
JaxArray([1., 1., 1.], dtype=float32)
b[:] = 2.  # change b to 2; since 'b' is dynamic here, y becomes 2 as well
update2(x)
JaxArray([2., 2., 2.], dtype=float32)
Example 3: JIT compiled neural networks#
Now, let’s use SGD to train a neural network with JIT acceleration. Here we use the autograd function brainpy.math.grad()
, which will be discussed in detail in the next section.
class LinearNet(bp.Base):
def __init__(self, n_in, n_out):
super(LinearNet, self).__init__()
# weights
self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
self.b = bm.TrainVar(bm.zeros(n_out))
self.r = bm.TrainVar(bm.random.random((n_out, 1)))
def update(self, x):
h = x @ self.w + self.b
return h @ self.r
def loss(self, x, y):
predict = self.update(x)
return bm.mean((predict - y) ** 2)
ln = LinearNet(100, 200)
# provide the variables we want to update
opt = bm.optimizers.SGD(lr=1e-6, train_vars=ln.vars())
# provide the variables that require gradients
f_grad = bm.grad(ln.loss, grad_vars=ln.vars(), return_value=True)
def train(X, Y):
grads, loss = f_grad(X, Y)
opt.update(grads)
return loss
# JIT the train function
train_jit = bm.jit(train, dyn_vars=ln.vars() + opt.vars())
xs = bm.random.random((1000, 100))
ys = bm.random.random((1000, 1))
for i in range(30):
loss = train_jit(xs, ys)
print(f'Train {i}, loss = {loss:.2f}')
Train 0, loss = 6649731.50
Train 1, loss = 3748688.50
Train 2, loss = 2126231.00
Train 3, loss = 1210147.88
Train 4, loss = 690106.50
Train 5, loss = 393984.28
Train 6, loss = 225071.75
Train 7, loss = 128625.49
Train 8, loss = 73524.97
Train 9, loss = 42035.37
Train 10, loss = 24035.91
Train 11, loss = 13746.33
Train 12, loss = 7863.82
Train 13, loss = 4500.70
Train 14, loss = 2577.91
Train 15, loss = 1478.59
Train 16, loss = 850.07
Train 17, loss = 490.72
Train 18, loss = 285.26
Train 19, loss = 167.80
Train 20, loss = 100.63
Train 21, loss = 62.24
Train 22, loss = 40.28
Train 23, loss = 27.73
Train 24, loss = 20.55
Train 25, loss = 16.45
Train 26, loss = 14.10
Train 27, loss = 12.76
Train 28, loss = 11.99
Train 29, loss = 11.56
RandomState#
We have talked about RandomState in the Variables section. RandomState is also a Variable. Therefore, if the default RandomState (brainpy.math.random.DEFAULT) is used in your function, you should mark it as one of the dyn_vars of the function. Otherwise, it will be treated as a constant and the jitted function will always return the same value.
def function():
return bm.random.normal(0, 1, size=(10,))
f1 = bm.jit(function)
f1() == f1()
JaxArray([ True, True, True, True, True, True, True, True,
True, True], dtype=bool)
The correct way to JIT this function is:
bm.random.seed(1234)
f2 = bm.jit(function, dyn_vars=bm.random.DEFAULT)
f2() == f2()
JaxArray([False, False, False, False, False, False, False, False,
False, False], dtype=bool)
Static arguments#
Static arguments are treated as static/constant in the jitted function.
Two kinds of arguments must be marked as static: numerical arguments used in conditional syntax (boolean values or values that produce booleans) and strings. Otherwise, an error will be raised.
@bm.jit
def f(x):
if x < 3: # this will cause error
return 3. * x ** 2
else:
return -4 * x
try:
f(1.)
except Exception as e:
print(type(e), e)
<class 'jax._src.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at C:\Users\adadu\AppData\Local\Temp\ipykernel_44816\1408095738.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Simply speaking, arguments resulting in boolean values must be declared as static arguments. In the brainpy.math.jit() function, we can set the names of static arguments.
def f(x):
if x < 3: # this will cause error
return 3. * x ** 2
else:
return -4 * x
f_jit = bm.jit(f, static_argnames=('x',))
f_jit(x=1.)
DeviceArray(3., dtype=float32, weak_type=True)
However, it is worth noting that calling the jitted function with different values for these static arguments will trigger recompilation. Therefore, declaring static arguments is suitable in the following situations:
Boolean arguments.
Arguments that can only take a few possible values.
If an argument takes many different values, it is better not to declare it as static; the sketch below illustrates why.
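Here is a minimal sketch (our own illustration, with hypothetical names) of the recompilation cost: every new static value triggers a fresh trace and compilation, which we can observe because the Python print only executes while the function is being traced.
import brainpy.math as bm

def piecewise(x, power):
    print(f'tracing with power={power}')  # executes only while tracing/compiling
    return x ** power if power == 2 else -x

piecewise_jit = bm.jit(piecewise, static_argnames=('power',))

piecewise_jit(bm.ones(3), power=2)  # prints the tracing message and compiles
piecewise_jit(bm.ones(3), power=2)  # cached: no message, no recompilation
piecewise_jit(bm.ones(3), power=3)  # new static value: traced and compiled again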
For more information, please refer to the jax.jit API.
Differentiation#
In this section, we are going to talk about how to realize automatic differentiation of variables in a function or a class object. In modern machine learning systems, gradients are needed in many situations. Therefore, we should understand:
How to calculate derivatives of arbitrarily complex functions?
How to compute higher-order gradients?
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
Preliminary#
Every autograd function in BrainPy has several keyword arguments. All examples below are illustrated through brainpy.math.grad(); other autograd functions have the same settings.
argnums and grad_vars#
The autograd functions in BrainPy can compute derivatives of function arguments (specified by argnums
) or non-argument variables (specified by grad_vars
). For instance, the following is a linear readout model:
class Linear(bp.Base):
def __init__(self):
super(Linear, self).__init__()
self.w = bm.random.random((1, 10))
self.b = bm.zeros(1)
def update(self, x):
r = bm.dot(self.w, x) + self.b
return r.sum()
l = Linear()
If we want the derivative with respect to the argument “x” when calling the update function, we can specify this through argnums:
grad = bm.grad(l.update, argnums=0)
grad(bm.ones(10))
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468 ], dtype=float32)
By contrast, if we want the derivatives of the parameters “self.w” and “self.b”, we should label them with grad_vars:
grad = bm.grad(l.update, grad_vars=(l.w, l.b))
grad(bm.ones(10))
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
DeviceArray([1.], dtype=float32))
If we need the derivatives of both the argument “x” and the parameters “self.w” and “self.b”, argnums and grad_vars can be used together. In this case, the gradient function returns gradients in the format (var_grads, arg_grads), where arg_grads refers to the gradients of “argnums” and var_grads refers to the gradients of “grad_vars”.
grad = bm.grad(l.update, grad_vars=(l.w, l.b), argnums=0)
var_grads, arg_grads = grad(bm.ones(10))
var_grads
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
DeviceArray([1.], dtype=float32))
arg_grads
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468 ], dtype=float32)
return_value#
As mentioned above, autograd functions return a new function that computes gradients but, by default, discards the original function's output. Sometimes, however, we care about the value the function returns, not just the gradients. In this case, you can set return_value=True in the autograd function.
grad = bm.grad(l.update, argnums=0, return_value=True)
gradient, value = grad(bm.ones(10))
gradient
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468 ], dtype=float32)
value
DeviceArray(4.6318984, dtype=float32)
has_aux#
In some situations, we are interested in the intermediate values in a function, and has_aux=True
can be of great help. The constraint is that you must return values with the format of (loss, aux_data)
. For instance,
class LinearAux(bp.Base):
def __init__(self):
super(LinearAux, self).__init__()
self.w = bm.random.random((1, 10))
self.b = bm.zeros(1)
def update(self, x):
dot = bm.dot(self.w, x)
r = (dot + self.b).sum()
return r, (r, dot) # here the aux data is a tuple including the loss and the dot value;
# however, the aux data can be arbitrarily complex.
l2 = LinearAux()
grad = bm.grad(l2.update, argnums=0, has_aux=True)
gradient, aux = grad(bm.ones(10))
gradient
JaxArray([0.20289445, 0.4745227 , 0.36053288, 0.94524395, 0.8360598 ,
0.06507981, 0.7748591 , 0.8377187 , 0.5767547 , 0.47604012], dtype=float32)
aux
(DeviceArray(5.5497055, dtype=float32), JaxArray([5.5497055], dtype=float32))
When multiple keywords (argnums, grad_vars, has_aux, or return_value) are set simultaneously, the return format of the gradient function can be inspected in the corresponding API documentation of brainpy.math.grad().
brainpy.math.grad()#
brainpy.math.grad() takes a function/object (\(f : \mathbb{R}^n \to \mathbb{R}\)) as input and returns a new function (\(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)) which computes the gradient of the original function/object. It is worth noting that brainpy.math.grad() only supports functions returning scalar values.
Pure functions#
For a pure function, the gradient is taken with respect to the first argument:
def f(a, b):
return a * 2 + b
grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32, weak_type=True)
However, this can be controlled via the argnums
argument.
grad_f2 = bm.grad(f, argnums=(0, 1))
grad_f2(2., 1.)
(DeviceArray(2., dtype=float32, weak_type=True),
DeviceArray(1., dtype=float32, weak_type=True))
Class objects#
For a class object or a class bound function, the gradient is taken with respect to the provided grad_vars
and argnums
setting:
class F(bp.Base):
def __init__(self):
super(F, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab2 = ab * 2
vv = ab2 + c
return vv.mean()
f = F()
The grad_vars
can be a JaxArray, or a list/tuple/dict of JaxArray.
bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': DeviceArray([2.], dtype=float32),
'F0.b': DeviceArray([2.], dtype=float32)}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
(DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
If there are dynamically changed values in the gradient function, you can provide them in the dyn_vars
argument.
class F2(bp.Base):
def __init__(self):
super(F2, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab = ab * 2
self.a.value = ab
return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
DeviceArray([2.], dtype=float32)
Besides, if you are interested in the gradient of the input value, please use the argnums
argument. Then, the gradient function will return (grads_of_grad_vars, grads_of_args)
.
class F3(bp.Base):
def __init__(self):
super(F3, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c, d):
ab = self.a * self.b
ab = ab * 2
return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : 3.0
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : (DeviceArray(3., dtype=float32, weak_type=True), DeviceArray(10., dtype=float32, weak_type=True))
Actually, it is recommended to provide all dynamically changed variables, whether or not they are updated in the gradient function, in the dyn_vars
argument.
Auxiliary data#
Usually, we want to get the loss value, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data and set return_value=True to return the loss value.
# return loss
grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)
print('grad: ', grad)
print('loss: ', loss)
grad: [2.]
loss: 12.0
class F4(bp.Base):
def __init__(self):
super(F4, self).__init__()
self.a = bm.TrainVar(bm.ones(1))
self.b = bm.TrainVar(bm.ones(1))
def __call__(self, c):
ab = self.a * self.b
ab2 = ab * 2
loss = (ab + c).mean()
return loss, (ab, ab2)
f4 = F4()
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)
print('grad: ', grad)
print('aux_data: ', aux_data)
grad: [1.]
aux_data: (JaxArray([1.], dtype=float32), JaxArray([2.], dtype=float32))
Any function used to compute gradients through brainpy.math.grad() must return a scalar value. Otherwise, an error will be raised.
try:
bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output had shape: (2,).
# this is right
bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray([0.5, 0.5], dtype=float32)
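The questions at the beginning of this section also mentioned higher-order gradients. As a minimal sketch (our own example, assuming gradient transforms of pure functions can be nested just like jax.grad), a second derivative can be obtained by applying bm.grad twice:
def cube(x):
    return x ** 3 + 2. * x

d_cube = bm.grad(cube)             # first derivative: 3 * x ** 2 + 2
d2_cube = bm.grad(bm.grad(cube))   # second derivative: 6 * x

d_cube(2.)   # expected 14.0
d2_cube(2.)  # expected 12.0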
brainpy.math.vector_grad()#
If users want to take gradients of vector-valued outputs, please use the brainpy.math.vector_grad() function. For example,
def f(a, b):
return bm.sin(b) * a
Gradients for vectors#
# vectors
a = bm.arange(5.)
b = bm.random.random(5)
bm.vector_grad(f)(a, b)
JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ], dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ], dtype=float32),
JaxArray([0. , 0.9801371, 1.7597246, 2.741662 , 3.9158623], dtype=float32))
Gradients for matrices#
# matrix
a = bm.arange(6.).reshape((2, 3))
b = bm.random.random((2, 3))
bm.vector_grad(f, argnums=1)(a, b)
JaxArray([[0. , 0.8662993, 1.1221857],
[2.9322515, 2.3293345, 3.024507 ]], dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([[0.45055482, 0.49952534, 0.8277529 ],
[0.21131878, 0.8129499 , 0.79630035]], dtype=float32),
JaxArray([[0. , 0.8662993, 1.1221857],
[2.9322515, 2.3293345, 3.024507 ]], dtype=float32))
Similar to brainpy.math.grad() , brainpy.math.vector_grad()
also supports derivatives of variables in a class object. Here is a simple example.
class Test(bp.Base):
def __init__(self):
super(Test, self).__init__()
self.x = bm.ones(5)
self.y = bm.ones(5)
def __call__(self):
return self.x ** 2 + self.y ** 3 + 10
t = Test()
bm.vector_grad(t, grad_vars=t.x)()
DeviceArray([2., 2., 2., 2., 2.], dtype=float32)
bm.vector_grad(t, grad_vars=(t.x, ))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),)
bm.vector_grad(t, grad_vars=(t.x, t.y))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),
DeviceArray([3., 3., 3., 3., 3.], dtype=float32))
Other operations like return_value
and has_aux
in brainpy.math.vector_grad() are the same as those in brainpy.math.grad() .
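For instance, here is a minimal sketch (our own example, reusing the function f and the matrices a and b defined above) that also returns the function outputs:
grads, values = bm.vector_grad(f, return_value=True)(a, b)
# 'grads' contains the gradients with respect to the first argument,
# 'values' contains the outputs of f(a, b)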
brainpy.math.jacobian()#
Another way to take gradients of a vector-valued output is to use brainpy.math.jacobian(). brainpy.math.jacobian() aims to automatically compute the Jacobian matrix \(\partial f(x) \in \mathbb{R}^{m \times n}\) of a given function \(f : \mathbb{R}^n \to \mathbb{R}^m\) at a given point \(x \in \mathbb{R}^n\). Here we will not go into the details of the implementation and usage of brainpy.math.jacobian(); instead, we only show two examples, one with a pure function and one with a class object.
Given the following function,
import jax.numpy as jnp
def f1(x, y):
a = 4 * x[1] ** 2 - 2 * x[2]
r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
return r, a
_x = bm.array([1., 2., 3.])
_y = bm.array([10., 5.])
grads, vec, aux = bm.jacobian(f1, return_value=True, has_aux=True)(_x, _y)
grads
JaxArray([[10. , 0. , 0. ],
[ 0. , 0. , 25. ],
[ 0. , 16. , -2. ],
[ 1.6209068 , 0. , 0.84147096]], dtype=float32)
vec
DeviceArray([10. , 75. , 10. , 2.5244129], dtype=float32)
aux
DeviceArray(10., dtype=float32)
Given the following class objects,
class Test(bp.Base):
def __init__(self):
super(Test, self).__init__()
self.x = bm.array([1., 2., 3.])
def __call__(self, y):
a = self.x[0] * y[0]
b = 5 * self.x[2] * y[1]
c = 4 * self.x[1] ** 2 - 2 * self.x[2]
d = self.x[2] * jnp.sin(self.x[0])
r = jnp.asarray([a, b, c, d])
return r, (c, d)
t = Test()
f_grad = bm.jacobian(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)
(var_grads, arg_grads), value, aux = f_grad(_y)
var_grads
DeviceArray([[10. , 0. , 0. ],
[ 0. , 0. , 25. ],
[ 0. , 16. , -2. ],
[ 1.6209068 , 0. , 0.84147096]], dtype=float32)
arg_grads
JaxArray([[ 1., 0.],
[ 0., 15.],
[ 0., 0.],
[ 0., 0.]], dtype=float32)
value
DeviceArray([10. , 75. , 10. , 2.5244129], dtype=float32)
aux
(DeviceArray(10., dtype=float32), DeviceArray(2.5244129, dtype=float32))
For more details on automatic differentiation, please see our API documentation.
Control Flows#
In this section, we are going to talk about how to build structured control flows with the BrainPy data structure JaxArray
. These control flows include
the for loop syntax,
the while loop syntax,
and the condition syntax.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
In JAX, control flow must be written with structured control-flow constructs. The JaxArray in BrainPy provides an easier syntax for building such control flows.
Note
The control flow functions below are not re-implementations of JAX's control-flow APIs. We only guarantee that the following APIs are useful and intuitive when you use brainpy.math.JaxArray.
brainpy.math.make_loop()#
brainpy.math.make_loop()
is used to generate a for-loop function when you use JaxArray
.
Suppose that you are using several JaxArrays (grouped as dyn_vars
) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars
). Sometimes the body function already returns something, and you also want to gather the returned values. With the Python syntax, it can be realized as
def for_loop_function(body_fun, dyn_vars, out_vars, xs):
ys = []
for x in xs:
# 'dyn_vars' and 'out_vars' are updated in 'body_fun()'
results = body_fun(x)
ys.append([out_vars, results])
return ys
In BrainPy, you can define this logic using brainpy.math.make_loop()
:
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)
hist_of_out_vars = loop_fun(xs)
Or,
loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)
hist_of_out_vars, hist_of_return_vars = loop_fun(xs)
Let’s implement a recurrent network to illustrate how to use this function.
class RNN(bp.dyn.DynamicalSystem):
def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
super(RNN, self).__init__(**kwargs)
# parameters
self.n_in = n_in
self.n_h = n_h
self.n_out = n_out
self.n_batch = n_batch
self.g = g
# weights
self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
self.b_ro = bm.TrainVar(bm.zeros((n_out,)))
# variables
self.h = bm.Variable(bm.random.random((n_batch, n_h)))
# function
self.predict = bm.make_loop(self.cell,
dyn_vars=self.vars(),
out_vars=self.h,
has_return=True)
def cell(self, x):
self.h.value = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
o = self.h @ self.w_ro + self.b_ro
return o
rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)
In the above RNN model, we define a body function RNN.cell for the later for-loop over input values. The loop function is defined as self.predict with bm.make_loop(). We care about the history values of “self.h” and the readout value “o”, so we set out_vars=self.h and has_return=True.
xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape # the shape should be (num_time, ) + o.shape
(100, 5, 3)
If you have multiple input values, you should wrap them as a container and call the loop function with loop_fun(xs)
, where “xs” can be a JaxArray or a list/tuple/dict of JaxArray. For example:
a = bm.Variable(bm.zeros(10))
def body(x):
x1, x2 = x # "x" is a tuple/list of JaxArray
a.value += (x1 + x2)
loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
Variable([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32)
a = bm.Variable(bm.zeros(10))
def body(x): # "x" is a dict of JaxArray
a.value += x['a'] + x['b']
loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
Variable([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
[15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
[21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
[28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
[36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
[45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
[55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]], dtype=float32)
dyn_vars, out_vars, xs, and the body function's returns can all be arrays organized in container structures like tuple/list/dict. The history outputs will preserve the container structure of out_vars and of the body function's returns. If has_return=True, the loop function returns a tuple (hist_of_out_vars, hist_of_fun_returns). If you do not need the history of any variable, set out_vars=None; the loop function then only returns the history of the body function's returns.
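As a minimal sketch (our own example) of this container behaviour, out_vars can be a dict, and the gathered history keeps the same dict structure:
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def body(x):
    a.value += x
    b.value *= x

loop = bm.make_loop(body, dyn_vars=[a, b], out_vars={'a': a, 'b': b})
hist = loop(bm.arange(1., 4.))
# 'hist' is a dict: hist['a'] stacks the history of 'a', hist['b'] that of 'b'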
brainpy.math.make_while()#
brainpy.math.make_while()
is used to generate a while-loop function when you use JaxArray
. It supports the following loop logic:
while condition:
statements
When using brainpy.math.make_while(), the condition should be wrapped into a cond_fun function which returns a boolean value, and the statements should be packed into a body_fun function which does not return values:
while cond_fun(x):
body_fun(x)
where x
is the external input that is not iterated. All the iterated variables should be marked as JaxArray
. All JaxArray
s used in cond_fun
and body_fun
should be declared as dyn_vars
variables.
Let's look at an example:
i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))
def cond_f(x):
return i[0] < 10
def body_f(x):
i.value += 1.
counter.value += i
loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])
In the above example, we compute the sum 1 + 2 + ... + 10 using two JaxArrays, i and counter.
loop()
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)
brainpy.math.make_cond()#
brainpy.math.make_cond() is used to generate a conditional function when you use JaxArray. It supports the following conditional logic:
if True:
true statements
else:
false statements
When using brainpy.math.make_cond(), the true statements should be wrapped into a true_fun function which implements the logic of the true branch, and the false statements should be wrapped into a false_fun function which implements the logic of the false branch. Neither function supports returning values.
if True:
true_fun(x)
else:
false_fun(x)
All the JaxArray
s used in true_fun
and false_fun
should be declared in the dyn_vars
argument. x
is used to receive the external input value.
Let's give it a try:
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
Here we have two tensors. If the predicate is true, tensor a is incremented by 1; if it is false, tensor b is decremented by 1.
cond(pred=True)
a, b
(Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(True)
a, b
(Variable([2., 2.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(False)
a, b
(Variable([2., 2.], dtype=float32), Variable([0., 0.], dtype=float32))
cond(False)
a, b
(Variable([2., 2.], dtype=float32), Variable([-1., -1.], dtype=float32))
Alternatively, we can define a conditional case which depends on an external input.
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += x
def false_f(x): b.value -= x
cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)
a, b
(Variable([10., 10.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(False, 5.)
a, b
(Variable([10., 10.], dtype=float32), Variable([-4., -4.], dtype=float32))
Low-level Operator Customization#
BrainPy is built on JAX and accelerates model execution through just-in-time (JIT) compilation. To further enhance performance on CPU and GPU, we publish another package, BrainPyLib, which provides several built-in low-level operators for synaptic computation. These operators are written in C++ and wrapped as JAX primitives using XLA. However, users cannot easily customize their own operators unless they have the relevant background. To solve this problem, we introduce numba.cfunc here and provide convenient interfaces for customizing operators without touching the underlying logic.
import brainpy as bp
import brainpy.math as bm
from jax import jit
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray
bm.set_platform('cpu')
In the Computation with Sparse Connections section, we formally discuss the benefits of computation with our built-in operators. These operators are provided by the brainpylib package and can be accessed through the brainpy.math module. Specifically, to speed up sparse synaptic computation, we customize several low-level operators for CPU and GPU, which are written in C++ and converted into JAX/XLA-compatible primitives using Pybind11.
It is not easy to write a C++ operator and implement the required conversions: users have to learn how to write a C++ operator, how to define a customized JAX primitive, and how to convert the C++ operator into that primitive. Here are some links for users who prefer to dive into the details: JAX primitives, XLA custom calls.
However, we can only provide a limited number of operators, and it would be great if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface, register_op, to register customized operators on CPU and GPU. Users no longer need to do any C++ programming or XLA compilation. This is accomplished with the help of numba.cfunc, which wraps Python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback so that it can be passed into XLA and registered as a jittable JAX primitive. The parameters and return types of register_op are listed in the API docs. Here is an example of using register_op on CPU.
How to customize operators?#
CPU version#
First, users can write a simple operator in Python. Notice that this Python operator will be jitted in nopython mode, so some language features are not available inside Numba-compiled functions. Please consult the Numba documentation for details.
def custom_op(outs, ins):
y, y1 = outs
x, x2 = ins
y[:] = x + 1
y1[:] = x2 + 2
There are some restrictions that users should know:
The parameters of the operator are outs and ins, corresponding to the output variable(s) and the input variable(s). This order cannot be changed.
The function cannot have any return value.
In the GPU version, users should write the kernel function according to the numba cuda.jit documentation. When applying a CPU function to the GPU, users only need to implement the CPU operator.
Then users should describe the shapes and types of the outputs, because JAX can deduce the shapes and types of the inputs when the operator is called, but it cannot infer those of the outputs. This argument can be:
a ShapedArray,
a sequence of ShapedArray, or
a function that returns the correct output shapes as ShapedArray.
Here we use a function to describe the output shapes and types. Its arguments are all the inputs of the custom operator, but only their shapes and dtypes are accessible.
def abs_eval_1(*ins):
# ins: input arguments; only their shapes and dtypes are accessible.
# Because the outputs of custom_op have exactly the same shapes and
# dtypes as its inputs, we can simply return the inputs here.
return ins
The function above may look abstract, so we provide an alternative below for passing the shape information. Note that abs_eval_1 and abs_eval_2 do exactly the same thing.
def abs_eval_2(*ins):
return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)
Now we are ready to register a CPU operator. register_op will be called to wrap your operator and return a jittable JAX primitive. Here are the parameters users should define:
op_name: the name of the operator.
cpu_func: the customized CPU operator.
out_shapes: the shapes and types of the outputs.
z = jnp.ones((1, 2), dtype=jnp.float32)
# Users could try out_shapes=abs_eval_2 and see if the result is different
op = bm.register_op(
op_name='add',
cpu_func=custom_op,
out_shapes=abs_eval_1,
apply_cpu_func_to_gpu=False)
jit_op = jit(op)
print(jit_op(z, z))
[DeviceArray([[2., 2.]], dtype=float32), DeviceArray([[3., 3.]], dtype=float32)]
GPU version#
We have discussed how to customize a CPU operator above; next we will talk about the GPU operator, which is slightly different from the CPU version. There are two additional parameters users need to provide:
gpu_func: the customized GPU operator.
apply_cpu_func_to_gpu: whether to run the CPU kernel function as a fallback for the GPU version.
Warning
GPU operators would be wrapped by cuda.jit in numba, but numba currently does not support launching CUDA kernels from cfuncs. For this reason, gpu_func is None by default, and an error will be raised if users pass a GPU operator to gpu_func.
Therefore, BrainPy enables users to set apply_cpu_func_to_gpu to True as a backup method. All the inputs are initialized on the GPU and transferred to the CPU for computing; the user-defined operator is executed on the CPU, and the results are transferred back to the GPU for further tasks.
Performance#
To illustrate the effectiveness of this approach, we compare the customized operators with BrainPy's built-in operators. Here we use event_sum as an example. The implementation of event_sum using our customization interface is shown below:
def abs_eval(events, indices, indptr, post_size, values):
return post_size
def event_sum_op(outs, ins):
post_val = outs
events, indices, indptr, post_size, values = ins
for i in range(len(events)):
if events[i]:
for j in range(indptr[i], indptr[i+1]):
index = indices[j]
old_value = post_val[index]
post_val[index] = values + old_value
event_sum = bm.register_op(op_name='event_sum', cpu_func=event_sum_op, out_shapes=abs_eval)
jit_event_sum = jit(event_sum)
An exponential COBA network will be our benchmark for testing the speed. We first use the built-in operator event_sum.
class ExpCOBA(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
method='exp_auto'):
super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')
# parameters
self.E = E
self.tau = tau
self.delay = delay
self.g_max = g_max
self.pre2post = self.conn.require('pre2post')
# variables
self.g = bm.Variable(bm.zeros(self.post.num))
# function
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
def update(self, _t, _dt):
self.g.value = self.integral(self.g, _t, dt=_dt)
# Built-in operator
# --------------------------------------------------------------------------------------
self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
# --------------------------------------------------------------------------------------
self.post.input += self.g * (self.E - self.post.V)
class EINet(bp.dyn.Network):
def __init__(self, scale=1.0, method='exp_auto'):
# network size
num_exc = int(3200 * scale)
num_inh = int(800 * scale)
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = bp.models.LIF(num_exc, **pars, method=method)
I = bp.models.LIF(num_inh, **pars, method=method)
E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.
# synapses
we = 0.6 / scale # excitatory synaptic weight (voltage)
wi = 6.7 / scale # inhibitory synaptic weight
E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
net = EINet(scale=10., method='euler')
# simulation
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.628559827804565
The total time is 15.62 seconds. Next we use our customized operator.
class ExpCOBA(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
method='exp_auto'):
super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')
# parameters
self.E = E
self.tau = tau
self.delay = delay
self.g_max = g_max
self.pre2post = self.conn.require('pre2post')
# variables
self.g = bm.Variable(bm.zeros(self.post.num))
# function
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
def update(self, _t, _dt):
self.g.value = self.integral(self.g, _t, dt=_dt)
post_size = bm.zeros(self.post.num)
# Customized operator
# ------------------------------------------------------------------------------------------------------------
self.g += jit_event_sum(self.pre.spike, self.pre2post[0].value, self.pre2post[1].value, post_size, self.g_max)
# ------------------------------------------------------------------------------------------------------------
self.post.input += self.g * (self.E - self.post.V)
class EINet(bp.dyn.Network):
def __init__(self, scale=1.0, method='exp_auto'):
# network size
num_exc = int(3200 * scale)
num_inh = int(800 * scale)
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = bp.models.LIF(num_exc, **pars, method=method)
I = bp.models.LIF(num_inh, **pars, method=method)
E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.
# synapses
we = 0.6 / scale # excitatory synaptic weight (voltage)
wi = 6.7 / scale # inhibitory synaptic weight
E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
net = EINet(scale=10., method='euler')
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.703513145446777
The comparison shows that the customized operator is almost as fast as the built-in one. Users can thus build their own operators without worrying about a loss of computation speed.
Interoperation with other JAX frameworks#
import brainpy.math as bm
BrainPy can be easily interoperated with other JAX frameworks.
1. Data are exchangeable in different frameworks.#
This can be realized because JaxArray can be directly converted to a JAX ndarray or a NumPy ndarray.
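The array b used in the snippets below is assumed to have been created earlier in this tutorial; for a self-contained run, it could be built, for example, as:
import numpy as np
import jax.numpy as jnp
import brainpy.math as bm

b = bm.JaxArray(jnp.array([5, 1, 2, 3, 4]))  # hypothetical setup matching the outputs below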
Convert a JaxArray
into a JAX ndarray.
# JaxArray.value is a JAX ndarray
b.value
DeviceArray([5, 1, 2, 3, 4], dtype=int32)
Convert a JaxArray
into a numpy ndarray.
# JaxArray can be easily converted to a numpy ndarray
np.asarray(b)
array([5, 1, 2, 3, 4])
Convert a numpy ndarray into a JaxArray
.
bm.asarray(np.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
Convert a JAX ndarray into a JaxArray
.
import jax.numpy as jnp
bm.asarray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
bm.JaxArray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
2. Transformations in brainpy.math also work on functions.#
APIs in other JAX frameworks can be naturally integrated in BrainPy. Let’s take the gradient-based optimization library Optax as an example to illustrate how to use other JAX frameworks in BrainPy.
import optax
# First create several useful functions.
network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))
def compute_loss(params, x, y):
y_pred = network(params, x)
loss = bm.mean(optax.l2_loss(y_pred, y))
return loss
@bm.jit
def train(params, opt_state, xs, ys):
grads = bm.grad(compute_loss)(params, xs.value, ys)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state
# Generate some data
bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer
params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)
# A simple update loop
for _ in range(1000):
params, opt_state = train(params, opt_state, xs, ys)
assert bm.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'
brainpy.base module#
The base module for the whole BrainPy ecosystem.
This module provides the most fundamental class, Base, and its associated helper classes Collector and ArrayCollector. For each instance of the Base class, users can retrieve all the variables (or trainable variables), integrators, and nodes.
This module also provides a Function class to wrap user-defined functions. A function may use several nodes, and users can initialize a Function by providing the nodes used in the function. Note, however, that the Function class cannot gather nodes automatically.
This module provides io helper functions to help users save/load model states, or share their customized models with others.
This module provides naming tools to guarantee a unique name for each Base object.
Details are given below.
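As a minimal, self-contained sketch (our own example) of how these pieces fit together, a Base subclass declares its variables, and the collectors gather them:
import brainpy as bp
import brainpy.math as bm

class MyModel(bp.Base):
    def __init__(self):
        super(MyModel, self).__init__()
        self.w = bm.TrainVar(bm.ones(3))   # trainable variable
        self.r = bm.Variable(bm.zeros(3))  # non-trainable dynamical variable

model = MyModel()
all_vars = model.vars()          # a Collector-like dict of all variables
train_vars = model.train_vars()  # only the TrainVar instances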
Base Class#
The Base class for the whole BrainPy ecosystem.
Function Wrapper#
The wrapper for Python functions.
Collectors#
A Collector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.
An ArrayCollector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.
Exporting and Loading#
Helper functions for saving model variables into, and loading them from, HDF5, NumPy, and pickle files.
Naming Tools#
Check the uniqueness of the name for the object type.
Get the unique name for the given object type.
Clear the cached names.
brainpy.math module#
The math module for the whole BrainPy ecosystem.
This module provides basic mathematical operations, including:
numpy-like array operations
linear algebra functions
random sampling functions
discrete fourier transform functions
just-in-time compilation for class objects
automatic differentiation for class objects
dedicated operators for brain dynamics
activation functions
device switching
default type switching
and others
Details in the following.
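A minimal sketch (our own example) combining a few of these pieces:
import brainpy.math as bm

x = bm.arange(4.)              # NumPy-like array creation
y = bm.sin(x) * 2.             # elementwise math on JaxArray
v = bm.Variable(bm.zeros(4))   # a dynamically changed variable

loss = bm.jit(lambda a: (a ** 2).sum())    # JIT compilation of a pure function
dloss = bm.grad(lambda a: (a ** 2).sum())  # automatic differentiation
dloss(y)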
Math Variables#
|
Multiple-dimensional array in JAX backend. |
alias of |
|
|
The pointer to specify the dynamical variable. |
|
The pointer to specify the trainable variable. |
|
The pointer to specify the parameter. |
Delay Variables#
|
|
|
Delay variable which has a fixed delay time length. |
|
Delay variable which has a fixed delay length. |
|
Neutral Time Delay. |
|
Neutral Length Delay. |
JIT Compilation#
The JIT compilation tools for JAX backend.
Just-In-Time compilation is implemented by the ‘jit()’ function
|
JIT (Just-In-Time) compilation for class objects. |
Operators#
|
The pre-to-post synaptic summation. |
|
The pre-to-post synaptic production. |
|
The pre-to-post synaptic maximization. |
|
The pre-to-post synaptic minimization. |
|
The pre-to-post synaptic mean computation. |
|
The pre-to-syn computation. |
|
The syn-to-post summation computation. |
|
The syn-to-post summation computation. |
|
The syn-to-post product computation. |
|
The syn-to-post maximum computation. |
|
The syn-to-post minimization computation. |
|
The syn-to-post mean computation. |
|
The syn-to-post softmax computation. |
|
The pre-to-post synaptic computation with event-driven summation. |
|
The pre-to-post synaptic computation with event-driven production. |
|
Sparse matrix multiplication. |
|
|
|
|
|
|
|
|
|
Converting the numba-jitted function in a Jax/XLA compatible primitive. |
Control Flows#
|
Make a for-loop function, which iterate over inputs. |
|
Make a while-loop function. |
|
Make a condition (if-else) function. |
|
Simple conditional statement (if-else) with instance of |
|
|
|
|
|
|
Automatic Differentiation#
|
Automatic gradient computation for functions or class objects. |
|
Take vector-valued gradients for function |
|
Extending automatic Jacobian (reverse-mode) of |
|
Extending automatic Jacobian (reverse-mode) of |
|
Extending automatic Jacobian (forward-mode) of |
|
Hessian of |
Activation Functions#
This module provides commonly used activation functions.
Activation functions are a critical part of the design of a neural network. The choice of activation function in the hidden layer will control how well the network model learns the training dataset. The choice of activation function in the output layer will define the type of predictions the model can make.
|
Continuously-differentiable exponential linear unit activation. |
|
Exponential linear unit activation function. |
|
Gaussian error linear unit activation function. |
|
Gated linear unit activation function. |
|
Hard \(\mathrm{tanh}\) activation function. |
|
Hard Sigmoid activation function. |
|
Hard SiLU activation function |
|
Hard SiLU activation function |
|
Leaky rectified linear unit activation function. |
|
Log-sigmoid activation function. |
|
Log-Softmax function. |
|
One-hot encodes the given indices. |
|
Normalizes an array by subtracting mean and dividing by sqrt(var). |
|
|
|
Rectified Linear Unit 6 activation function. |
|
Sigmoid activation function. |
|
Soft-sign activation function. |
|
Softmax function. |
|
Softplus activation function. |
|
SiLU activation function. |
|
SiLU activation function. |
|
Scaled exponential linear unit activation. |
|
|
|
Function#
Comparison Table#
Here is a list of NumPy APIs and their corresponding BrainPy implementations. A "-" in the BrainPy column denotes that the implementation is not provided yet. We welcome contributions for these functions.
Multi-dimensional Array#
Summary
Number of NumPy functions: 56
Number of functions covered by brainpy.math: 46
Number of functions unique in brainpy.math: 8
Number of functions covered by jax.numpy: 42
Array Operations#
Summary
Number of NumPy functions: 399
Number of functions covered by brainpy.math: 338
Number of functions unique in brainpy.math: 33
Number of functions covered by jax.numpy: 314
Linear Algebra#
Summary
Number of NumPy functions: 19
Number of functions covered by brainpy.math: 19
Number of functions unique in brainpy.math: 0
Number of functions covered by jax.numpy: 19
Discrete Fourier Transform#
Summary
Number of NumPy functions: 18
Number of functions covered by brainpy.math: 18
Number of functions unique in brainpy.math: 0
Number of functions covered by jax.numpy: 18
Random Sampling#
Summary
Number of NumPy functions: 51
Number of functions covered by brainpy.math: 48
Number of functions unique in brainpy.math: 13
Number of functions covered by jax.numpy: 16
Setting#
|
|
|
Changes platform to CPU, GPU, or TPU. |
By default, XLA considers all CPU cores as one device. |
|
|
Set the numerical integrator precision. |
|
Get the numerical integrator precision. |
|
|
alias of |
|
alias of |
|
alias of |
brainpy.math.compat module#
Optimizers#
|
SGD optimizer. |
|
Momentum optimizer. |
|
MomentumNesterov optimizer. |
|
Adagrad optimizer. |
|
Adadelta optimizer. |
|
RMSProp optimizer. |
|
Adam optimizer. |
|
Constant scheduler. |
|
ExponentialDecay scheduler. |
|
InverseTimeDecay scheduler. |
|
PolynomialDecay scheduler. |
|
PiecewiseConstant scheduler. |
Losses#
|
Cross entropy loss. |
|
L1 loss. |
|
L2 loss. |
|
L2 norm. |
|
Huber loss. |
|
mean absolute error loss. |
|
Mean squared error loss. |
|
Mean squared log error loss. |
brainpy.dyn module#
Dynamics simulation module.
Base Class#
|
Base Dynamical System class. |
|
Container object which is designed to add other instances of DynamicalSystem. |
|
Base class to model network objects, an alias of Container. |
|
Class used to model constant delay variables. |
|
Base class to model neuronal groups. |
|
Base class to model conductance-based neuron group. |
|
Base class to model two-end synaptic connections. |
|
Abstract channel model. |
|
Channel Models#
Base Class#
|
Base class for ions. |
|
Base class for ion channels. |
Sodium Channel Models#
|
The sodium current model. |
|
Potassium Channel Models#
|
Base class for potassium channel. |
|
The delayed rectifier potassium channel current. |
|
Calcium Channel Models#
|
The base calcium dynamics. |
|
Fixed Calcium dynamics. |
|
Dynamical Calcium model. |
|
The first-order calcium concentration model. |
|
Base class for Calcium ion channels. |
|
The calcium-dependent potassium current model. |
|
The calcium-activated non-selective cation channel model. |
|
The low-threshold T-type calcium current model. |
|
The low-threshold T-type calcium current model in thalamic reticular nucleus. |
|
The high-threshold T-type calcium current model. |
|
The L-type calcium channel model. |
Ih Channel Models#
|
Base class for Ih channel models. |
|
The hyperpolarization-activated cation current model. |
Leaky Channel Models#
|
Base class for leaky channel. |
|
The leakage channel current. |
|
The potassium leak channel current. |
Neuron Models#
Biological Models#
|
Hodgkin–Huxley neuron model. |
|
The Morris-Lecar neuron model. |
|
The Pinsky and Rinsel (1994) model. |
|
Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. |
Fractional-order Models#
|
Fractional-order neuron model. |
|
The fractional-order FH-R model [1]_. |
|
Fractional-order Izhikevich model [10]_. |
Reduced Models#
|
Leaky integrate-and-fire neuron model. |
|
Exponential integrate-and-fire neuron model. |
|
Adaptive exponential integrate-and-fire neuron model. |
|
Quadratic Integrate-and-Fire neuron model. |
|
Adaptive quadratic integrate-and-fire neuron model. |
|
Generalized Integrate-and-Fire model. |
|
The Izhikevich neuron model. |
|
Hindmarsh-Rose neuron model. |
|
FitzHugh-Nagumo neuron model. |
Synapse Models#
Biological Models#
|
AMPA conductance-based synapse model. |
|
GABAa conductance-based synapse model. |
Abstract Models#
|
Voltage Jump Synapse Model, or alias of Delta Synapse Model. |
|
Current-based exponential decay synapse model. |
|
Conductance-based exponential decay synapse model. |
|
Current-based dual exponential synapse model. |
|
Conductance-based dual exponential synapse model. |
|
Current-based alpha synapse model. |
|
Conductance-based alpha synapse model. |
|
Conductance-based NMDA synapse model. |
Learning Rule Models#
|
Short-term plasticity model. |
Rate Models#
Population Models#
|
|
|
FitzHugh-Nagumo system used in [1]_. |
|
FitzHugh-Nagumo model with recurrent neural feedback. |
|
A mean-field model of a quadratic integrate-and-fire neuron population. |
|
Stuart-Landau model with Hopf bifurcation. |
|
Wilson-Cowan population model. |
|
A threshold linear rate model. |
Coupling Models#
|
Delay coupling. |
|
Diffusive coupling. |
|
Additive coupling. |
Helper Models#
Noise Models#
|
The Ornstein–Uhlenbeck process. |
Input Models#
|
Spike Time Input. |
|
The input neuron group characterized by spikes emitting at given times. |
|
Poisson Group Input. |
|
Poisson Neuron Group. |
Runners#
|
The runner for dynamical systems. |
|
The runner provides convenient interface for debugging. |
|
The runner with the structural for-loop. |
brainpy.nn module#
Neural Networks (nn)
Base Classes#
This module provides the basic Node class for the whole brainpy.nn system.
brainpy.nn.Node: the fundamental class representing a node or element.
brainpy.nn.RecurrentNode: a recurrent node which has a self-connection.
brainpy.nn.Network: a network model composed of multiple node elements. Once a Network instance receives a node operation, the wrapped elements, the new elements, and their connection edges are formed into another Network instance. This means brainpy.nn.Network is only used to pack element nodes; it will never be an element node itself.
brainpy.nn.FrozenNetwork: the whole network, which can be represented as a basic elementary node when composing a larger network (TODO).
|
Basic Node class for neural network building in BrainPy. |
|
Basic Network class for neural network building in BrainPy. |
|
Basic class for recurrent node. |
|
A FrozenNetwork is a Network that can not be linked to other nodes or networks. |
Node Operations#
This module provides basic operations for constructing node graphs.
It supports the following operations:
feedforward connection: “>>”, “>>=”
feedback connection: “<<”, “<<=”
merge two nodes: “&”, “&=”
select subsets of one node: “[:]”
concatenate a sequence of nodes: “[node1, node2, …]”, “(node1, node2, …)”
wrap a set of nodes: “{node1, node2, …}”
However, all operations should satisfy the following assumptions:
Feedback connection of (node1, node2) should have a feedforward path from node2 to node1.
Feedforward or feedback connections cannot generate a cycle.
Cannot concatenate multiple receiver nodes, e.g., a >> [b, c] is forbidden, but a >> {b, c} is allowed.
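As a minimal sketch (our own example; the Input and Dense node names are assumed from the basic nodes listed below), feedforward composition with ">>" looks like:
import brainpy as bp

i = bp.nn.Input(3)    # hypothetical input node with 3 features
l1 = bp.nn.Dense(10)  # hypothetical linear transformation node
l2 = bp.nn.Dense(2)
net = i >> l1 >> l2   # '>>' chains the nodes into a Network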
|
Connect two sequences of |
|
Create a feedback connection from |
|
Merge different |
|
|
|
Node Graph Tools#
This module provides basic tool for graphs, including
detect the senders and receivers in the network graph,
find input and output nodes in a given graph,
detect the cycle in the graph,
detect the path between two nodes.
|
Find all senders and receivers in the given graph. |
|
Find input nodes and output nodes. |
|
Detect whether a cycle exists in the defined graph. |
|
Detect whether there is a path exist in the defined graph from |
Runners and Trainers#
This module provides various running and training algorithms for various neural networks.
The supported training algorithms include
offline training methods, like ridge regression, linear regression, etc.
online training methods, like recursive least squares (RLS, or Force Learning), least mean squares (LMS), etc.
back-propagation learning method
and others
The supported neural networks include
reservoir computing networks,
artificial recurrent neural networks,
and others.
Base RNN Runner#
|
Structural Runner for Recurrent Neural Networks. |
Base RNN Trainer#
|
Structural Trainer for Models with Recurrent Dynamics. |
Online RNN Trainer#
|
Online trainer for models with recurrent dynamics. |
|
Force learning. |
Offline RNN Trainer#
|
Offline trainer for models with recurrent dynamics. |
|
Trainer of ridge regression, also known as regression with Tikhonov regularization. |
Back-propagation Trainer#
|
The trainer implementing back propagation through time (BPTT) algorithm for recurrent neural networks. |
|
The trainer implementing back propagation algorithm for feedforward neural networks. |
Training Algorithms#
Online Training Algorithms#
Get all supported online training methods. |
|
|
Register a new online learning method. |
|
Base class for online training algorithm. |
|
|
|
The recursive least squares (RLS). |
|
The least mean squares (LMS). |
Offline Training Algorithms#
Get all supported offline training methods. |
|
|
Register a new offline learning method. |
|
Base class for offline training algorithm. |
|
Training algorithm of ridge regression. |
|
Training algorithm of least-square regression. |
Data Types#
|
Base class for data type. |
Pass the only one data into the node. |
|
|
Pass a list/tuple of data into the node. |
Nodes: basic#
|
Activation node. |
|
A linear transformation applied over the last dimension of the input. |
|
A linear transformation. |
|
The input node. |
|
Concatenate multiple inputs into one. |
|
Select a subset of the given input. |
|
Reshape the input tensor to another tensor. |
|
Sum all input tensors into one. |
Nodes: artificial neural network#
Artificial neural network (ANN) nodes
|
Applies a convolution to the inputs. |
|
|
|
|
|
|
|
A layer that stochastically ignores a subset of inputs each training step. |
|
Basic fully-connected RNN core. |
|
Gated Recurrent Unit. |
|
Long short-term memory (LSTM) RNN core. |
|
|
|
Pools the input by taking the maximum over a window. |
|
Pools the input by taking the average over a window. |
|
Pools the input by taking the minimum over a window. |
|
Batch Normalization node. |
|
1-D batch normalization. |
|
2-D batch normalization. |
|
3-D batch normalization. |
|
Group normalization layer. |
|
Layer normalization (https://arxiv.org/abs/1607.06450). |
|
Instance normalization layer. |
Nodes: reservoir computing#
Reservoir computing (RC) nodes
|
Linear readout node. |
|
Nonlinear vector auto-regression (NVAR) node. |
|
Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_. |
brainpy.analysis module#
This module provides analysis tools for differential equations.
The symbolic module uses SymPy symbolic inference to analyze low-dimensional dynamical systems (only ODEs are supported).
The numeric module uses numerical optimization functions to analyze high-dimensional dynamical systems (ODEs and discrete systems are supported).
The continuation module is the analysis package based on numerical continuation methods.
Moreover, we provide several useful functions in the stability module which may help your dynamical system analysis.
Details are given in the following sections.
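For example, a two-dimensional phase-plane analysis can be set up roughly as follows. This is a minimal sketch assuming the brainpy.analysis API of BrainPy 2.x; the argument names (model, target_vars, resolutions) and the plotting method names are assumptions to be verified against the installed version.

import brainpy as bp

# FitzHugh-Nagumo model as a simple 2D system.
@bp.odeint
def fhn(V, w, t, Iext=0.8, a=0.7, b=0.8, tau=12.5):
    dV = V - V ** 3 / 3 - w + Iext
    dw = (V + a - b * w) / tau
    return dV, dw

# Phase-plane analysis: nullclines, vector field, and fixed points.
analyzer = bp.analysis.PhasePlane2D(
    model=fhn,
    target_vars={'V': [-3., 3.], 'w': [-1., 3.]},
    resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.show_figure()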
Low-dimensional Analyzers#
|
Bifurcation analysis of 1D system. |
|
Bifurcation analysis of 2D system. |
|
|
|
|
|
Phase plane analyzer for 1D dynamical system. |
|
Phase plane analyzer for 2D dynamical system. |
High-dimensional Analyzers#
|
Find fixed/slow points by numerical optimization. |
Stability Analysis#
Get the stability types of 1D system. |
|
Get the stability types of 2D system. |
|
Get the stability types of 3D system. |
|
|
Stability analysis of fixed points for low-dimensional system. |
brainpy.integrators
module#
This module provides numerical solvers for various differential equations, including:
ordinary differential equations (ODEs)
stochastic differential equations (SDEs)
fractional differential equations (FDEs)
delay differential equations (DDEs)
For details, please see the following.
Integrator Runner#
|
Structural runner for numerical integrators in brainpy. |
Joint Equation#
|
Make a joint equation from multiple derivation functions. |
Numerical Methods for ODEs#
Numerical methods for ordinary differential equations (ODEs).
Base Integrator#
|
Numerical Integrator for Ordinary Differential Equations (ODEs). |
Generic Functions#
|
Numerical integration for ODEs. |
|
Set the default ODE numerical integrator method for differential equations. |
Get the default ODE numerical integrator method. |
|
|
Register a new ODE integrator. |
Get all supported numerical methods for ODEs. |
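A typical workflow is to wrap a derivative function with brainpy.odeint and then advance it step by step, either manually as below or through the IntegratorRunner. This is a minimal sketch; the method string 'rk4' and the dt keyword follow the documented odeint interface, but the defaults may differ across versions.

import brainpy as bp

# Exponential decay: dy/dt = -y / tau.
def decay(y, t, tau=10.):
    return -y / tau

# Build a classical Runge-Kutta 4 integrator with a fixed time step.
integral = bp.odeint(decay, method='rk4', dt=0.1)

# Step the integrator manually over 100 ms.
y, t, dt = 1.0, 0., 0.1
while t < 100.:
    y = integral(y, t)
    t += dt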
Explicit Runge-Kutta Methods#
This module provides explicit Runge-Kutta methods for ODEs.
Given an initial value problem specified as
\[y'(t) = f(t, y(t)), \qquad y(t_0) = y_0,\]
let the step-size \(h > 0\).
Then, the general schema of explicit Runge–Kutta methods is [1]
\[y_{n+1} = y_n + h \sum_{i=1}^{s} b_i k_i,\]
where
\[\begin{aligned}
k_1 &= f(t_n, y_n), \\
k_i &= f\Big(t_n + c_i h, \; y_n + h \sum_{j=1}^{i-1} a_{ij} k_j\Big), \qquad i = 2, 3, \cdots, s.
\end{aligned}\]
To specify a particular method, one needs to provide the integer \(s\) (the number of stages), and the coefficients \(a_{ij}\) (for \(1 \le j < i \le s\)), \(b_i\) (for \(i = 1, 2, \cdots, s\)) and \(c_i\) (for \(i = 2, 3, \cdots, s\)).
The matrix \([a_{ij}]\) is called the Runge–Kutta matrix, while the \(b_i\) and \(c_i\) are known as the weights and the nodes. These data are usually arranged in a mnemonic device, known as a Butcher tableau (named after John C. Butcher):
\[\begin{array}{c|cccc}
0      &        &        &        &          \\
c_2    & a_{21} &        &        &          \\
c_3    & a_{31} & a_{32} &        &          \\
\vdots & \vdots &        & \ddots &          \\
c_s    & a_{s1} & a_{s2} & \cdots & a_{s,s-1} \\ \hline
       & b_1    & b_2    & \cdots & b_s
\end{array}\]
A Taylor series expansion shows that the Runge–Kutta method is consistent if and only if
\[\sum_{i=1}^{s} b_i = 1.\]
Another popular condition for determining coefficients is
\[\sum_{j=1}^{i-1} a_{ij} = c_i, \qquad i = 2, 3, \cdots, s.\]
For more details, please see references [2], [3] and [4].
- 1
Press, W. H., B. P. Flannery, S. A. Teukolsky, and W. T. Vetterling. “Section 17.1 Runge-Kutta Method.” Numerical Recipes: The Art of Scientific Computing (2007).
- 2
- 3
Butcher, John Charles. Numerical methods for ordinary differential equations. John Wiley & Sons, 2016.
- 4
Iserles, A., 2009. A first course in the numerical analysis of differential equations (No. 44). Cambridge university press.
|
Explicit Runge–Kutta methods for ordinary differential equation. |
|
The Euler method for ODEs. |
|
Explicit midpoint method for ODEs. |
|
Heun's method for ODEs. |
|
Ralston's method for ODEs. |
|
Generic second order Runge-Kutta method for ODEs. |
|
Classical third-order Runge-Kutta method for ODEs. |
|
Heun's third-order method for ODEs. |
|
Ralston's third-order method for ODEs. |
|
Third-order Strong Stability Preserving Runge-Kutta (SSPRK3). |
|
Classical fourth-order Runge-Kutta method for ODEs. |
|
Ralston's fourth-order method for ODEs. |
|
3/8-rule fourth-order method for ODEs. |
Adaptive Runge-Kutta Methods#
This module provides adaptive Runge-Kutta methods for ODEs.
Adaptive methods are designed to produce an estimate of the local truncation error of a single Runge–Kutta step. This is done by having two methods, one with order \(p\) and one with order \(p-1\). These methods are interwoven, i.e., they have common intermediate steps. Thanks to this, estimating the error has little or negligible computational cost compared to a step with the higher-order method.
During the integration, the step size is adapted such that the estimated error stays below a user-defined threshold: If the error is too high, a step is repeated with a lower step size; if the error is much smaller, the step size is increased to save time. This results in an (almost) optimal step size, which saves computation time. Moreover, the user does not have to spend time on finding an appropriate step size.
The lower-order step is given by
\[y^{*}_{n+1} = y_n + h \sum_{i=1}^{s} b^{*}_i k_i,\]
where the \(k_{i}\) are the same as for the higher-order method. Then the error is
\[e_{n+1} = y_{n+1} - y^{*}_{n+1} = h \sum_{i=1}^{s} (b_i - b^{*}_i) k_i,\]
which is \(O(h^{p})\).
The Butcher tableau for this kind of method is extended to give the values of \(b_{i}^{*}\):
\[\begin{array}{c|ccccc}
0      &          &          &        &             &         \\
c_2    & a_{21}   &          &        &             &         \\
\vdots & \vdots   & \ddots   &        &             &         \\
c_s    & a_{s1}   & a_{s2}   & \cdots & a_{s,s-1}   &         \\ \hline
       & b_1      & b_2      & \cdots & b_{s-1}     & b_s     \\
       & b^{*}_1  & b^{*}_2  & \cdots & b^{*}_{s-1} & b^{*}_s
\end{array}\]
For more details, please see references [1], [2] and [3].
- 1
- 2
Press, W.H., Press, W.H., Flannery, B.P., Teukolsky, S.A., Vetterling, W.T., Flannery, B.P. and Vetterling, W.T., 1989. Numerical recipes in Pascal: the art of scientific computing (Vol. 1). Cambridge university press.
- 3
Press, W. H., & Teukolsky, S. A. (1992). Adaptive Stepsize Runge‐Kutta Integration. Computers in Physics, 6(2), 188-191.
|
Adaptive Runge-Kutta method for ordinary differential equations. |
|
The Fehlberg RK1(2) method for ODEs. |
|
The Runge–Kutta–Fehlberg method for ODEs. |
|
The Dormand–Prince method for ODEs. |
|
The Cash–Karp method for ODEs. |
|
The Bogacki–Shampine method for ODEs. |
|
The Heun–Euler method for ODEs. |
Exponential Integrators#
This module provides exponential integrators for ODEs.
Exponential integrators are a large class of methods from numerical analysis that are based on the exact integration of the linear part of the initial value problem. Because the linear part is integrated exactly, this can help to mitigate the stiffness of a differential equation.
We consider initial value problems of the form
\[u'(t) = f(u(t)), \qquad u(0) = u_0,\]
which can be decomposed as
\[u'(t) = L\,u(t) + N(u(t)),\]
where \(L={\frac {\partial f}{\partial u}}\) (the Jacobian of f) is composed of linear terms, and \(N=f(u)-Lu\) is composed of the non-linear terms.
This procedure enjoys the advantage, in each step, that \({\frac {\partial N_{n}}{\partial u}}(u_{n})=0\). This considerably simplifies the derivation of the order conditions and improves the stability when integrating the nonlinearity \(N(u(t))\).
Exact integration of this problem from time 0 to a later time \(t\) can be performed using matrix exponentials to define an integral equation for the exact solution:
\[u(t) = e^{tL} u_0 + \int_0^{t} e^{(t-\tau)L}\, N(u(\tau))\, d\tau.\]
This representation of the exact solution is also called the variation-of-constants formula. In the case of \(N\equiv 0\), this formulation is the exact solution to the linear differential equation.
Exponential Rosenbrock methods
Exponential Rosenbrock methods were shown to be very efficient in solving large systems of stiff ODEs. Applying the variation-of-constants formula gives the exact solution at time \(t_{n+1}\) with the numerical solution \(u_n\) as
\[u(t_{n+1}) = e^{h_n L} u(t_n) + \int_0^{h_n} e^{(h_n-\tau)L}\, N(u(t_n+\tau))\, d\tau, \tag{1}\]
where \(h_n=t_{n+1}-t_n\).
The idea now is to approximate the integral in (1) by some quadrature rule with nodes \(c_{i}\) and weights \(b_{i}(h_{n}L)\) (\(1\leq i\leq s\)). This yields the following class of s-stage explicit exponential Rosenbrock methods:
\[\begin{aligned}
U_{ni} &= e^{c_i h_n L} u_n + h_n \sum_{j=1}^{i-1} a_{ij}(h_n L)\, N(U_{nj}), \\
u_{n+1} &= e^{h_n L} u_n + h_n \sum_{i=1}^{s} b_{i}(h_n L)\, N(U_{ni}),
\end{aligned}\]
where \(U_{ni}\approx u(t_{n}+c_{i}h_{n})\).
The coefficients \(a_{ij}(z),b_{i}(z)\) are usually chosen as linear combinations of the entire functions \(\varphi _{k}(c_{i}z),\varphi _{k}(z)\), respectively, where
\[\varphi_0(z) = e^{z}, \qquad \varphi_{k}(z) = \int_0^1 e^{(1-\theta)z}\, \frac{\theta^{k-1}}{(k-1)!}\, d\theta, \quad k \geq 1.\]
By introducing the difference \(D_{ni}=N(U_{ni})-N(u_{n})\), they can be reformulated in a more efficient way for implementation as
\[\begin{aligned}
U_{ni} &= u_n + c_i h_n \varphi_1(c_i h_n L)\, f(u_n) + h_n \sum_{j=2}^{i-1} a_{ij}(h_n L)\, D_{nj}, \\
u_{n+1} &= u_n + h_n \varphi_1(h_n L)\, f(u_n) + h_n \sum_{i=2}^{s} b_{i}(h_n L)\, D_{ni},
\end{aligned}\]
where \(\varphi_{1}(z)=\frac{e^z-1}{z}\).
In order to implement this scheme with adaptive step size, one can consider, for the purpose of local error estimation, the following embedded methods
\[\bar{u}_{n+1} = u_n + h_n \varphi_1(h_n L)\, f(u_n) + h_n \sum_{i=2}^{s} \bar{b}_{i}(h_n L)\, D_{ni},\]
which use the same stages \(U_{ni}\) but with weights \({\bar {b}}_{i}\).
For convenience, the coefficients of the explicit exponential Rosenbrock methods together with their embedded methods can be represented by using the so-called reduced Butcher tableau as follows:
\[\begin{array}{c|cccc}
c_2    &            &        &                &            \\
c_3    & a_{32}     &        &                &            \\
\vdots & \vdots     & \ddots &                &            \\
c_s    & a_{s2}     & \cdots & a_{s,s-1}      &            \\ \hline
       & b_2        & \cdots & b_{s-1}        & b_s        \\
       & \bar{b}_2  & \cdots & \bar{b}_{s-1}  & \bar{b}_s
\end{array}\]
- 1
- 2
Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
|
Exponential Euler method using automatic differentiation. |
Error Analysis of Numerical Methods#
In order to identify the essential properties of numerical methods, we define basic notions [1].
For the given ODE system
\[\frac{dy}{dt} = f(t, y), \qquad y(t_0) = y_0,\]
we define \(y(t_n)\) as the solution of the IVP evaluated at \(t=t_n\), and \(y_n\) as a numerical approximation of \(y(t_n)\) at the same location obtained by a generic numerical scheme (whether explicit, implicit, or multi-step):
\[y_{n+1} = y_n + h\,\phi(t_n, y_n, h),\]
where \(h\) is the discretization step for \(t\), i.e., \(h=t_{n+1}-t_n\), and \(\phi(t_n,y_n,h)\) is the increment function. We say that the defined numerical scheme is consistent if \(\lim_{h\to0} \phi(t,y,h) = \phi(t,y,0) = f(t,y)\).
Then, the approximation error is defined as \(e_n = y(t_n) - y_n\), the absolute error as \(|y(t_n) - y_n|\), and the relative error as \(|y(t_n) - y_n| / |y(t_n)|\).
The exact differential operator is defined as
The approximate differential operator is defined as
Finally, the local truncation error (LTE) is defined as
In practice, the evaluation of the exact solution for different \(t\) around \(t_n\) (required by \(L_a\)) is performed using a Taylor series expansion.
Finally, we can state that a scheme is \(p\)-th order accurate by examining its LTE and observing its leading term
where \(C\) is a constant, independent of \(h\), and \(H.O.T.\) are the higher order terms of the LTE.
Example: LTE for Euler’s scheme
Consider the IVP defined by \(y' = \lambda y\), with initial condition \(y(0)=1\).
The approximation operator for Euler’s scheme is
then the LTE can be computed by
where we assume \(y_n = y(t_n)\).
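Written out, using the convention that the LTE is obtained by inserting the exact solution into the scheme (some texts divide this quantity by \(h\), which shifts the order by one), the computation for \(y' = \lambda y\) with exact solution \(y(t) = e^{\lambda t}\) reads
\[\begin{aligned}
y_{n+1} &= y_n + h\,\lambda\, y_n, \\
\tau_{n+1} &= y(t_{n+1}) - y(t_n) - h\,\lambda\, y(t_n)
            = y(t_n)\left(e^{\lambda h} - 1 - \lambda h\right)
            = y(t_n)\left(\frac{\lambda^{2} h^{2}}{2} + O(h^{3})\right),
\end{aligned}\]
so the LTE of Euler's scheme is \(O(h^{2})\) and the scheme is first-order accurate.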
Numerical Methods for SDEs#
Numerical methods for stochastic differential equations.
Base Integrator#
|
SDE Integrator. |
Generic Functions#
|
Numerical integration for SDEs. |
|
Set the default SDE numerical integrator method for differential equations. |
Get the default SDE numerical integrator method. |
|
|
Register a new SDE integrator. |
Get all supported numerical methods for SDEs. |
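In practice an SDE integrator is built from a drift function f and a diffusion function g. The sketch below simulates an Ornstein–Uhlenbeck process; it is a minimal sketch in which the keyword names (f, g, method, dt) and the method string 'euler' are assumptions to be checked against the installed version.

import brainpy as bp

# Ornstein-Uhlenbeck process: dx = -(x / tau) dt + sigma dW.
def f(x, t, tau=10., sigma=0.1):
    return -x / tau

def g(x, t, tau=10., sigma=0.1):
    return sigma

# Euler-Maruyama-style integrator with a fixed time step.
integral = bp.sdeint(f=f, g=g, method='euler', dt=0.1)

x, t = 0., 0.
for _ in range(1000):
    x = integral(x, t)
    t += 0.1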
Normal Methods#
|
|
|
|
|
|
|
First order, explicit exponential Euler method. |
SRK methods for scalar Wiener process#
|
Order 2.0 weak SRK methods for SDEs with scalar Wiener process. |
|
Order 1.5 Strong SRK Methods for SDEs with Scalar Noise. |
|
Numerical Methods for FDEs#
Numerical methods for fractional differential equations.
Base Integrator#
|
Numerical integrator for fractional differential equations (FDEs). |
Generic Functions#
|
Numerical integration for FDEs. |
|
Set the default FDE numerical integrator method for fractional differential equations. |
Get the default FDE numerical integrator method. |
|
|
Register a new FDE integrator. |
Get all supported numerical methods for FDEs. |
Methods for Caputo Fractional Derivative#
This module provides numerical methods for integrating Caputo fractional derivative equations.
|
One-step Euler method for Caputo fractional differential equations. |
|
The L1 scheme method for the numerical approximation of the Caputo fractional-order derivative equations [3]_. |
Methods for Riemann-Liouville Fractional Derivative#
This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.
|
Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_. |
brainpy.datasets
module#
Chaotic Systems#
|
The Hénon map time series. |
|
The logistic map time series. |
|
Modified Lu Chen attractor. |
|
The Mackey-Glass time series. |
|
Rabinovich-Fabrikant equations. |
|
Chen attractor. |
|
Lu Chen attractor. |
|
Chua’s system. |
|
Modified Chua chaotic attractor. |
|
The Lorenz system. |
|
Modified Lorenz chaotic system. |
|
Double-scroll electronic circuit attractor. |
|
PWL Duffing chaotic attractor. |
brainpy.inputs
module#
This module provides various methods to form current inputs.
You can access them through brainpy.inputs.XXX.
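For example, a sectioned step current can be built as follows. This is a minimal sketch; the keyword names values, durations and dt follow the section_input entry in the table below, but should be verified against the installed version.

import brainpy as bp

# 0 pA for 100 ms, 1 pA for 300 ms, then 0 pA for 100 ms, with dt = 0.1 ms.
current = bp.inputs.section_input(values=[0., 1., 0.],
                                  durations=[100., 300., 100.],
                                  dt=0.1)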
|
Format an input current with different sections. |
|
Format constant input in durations. |
|
Format constant input in durations. |
|
Format current input like a series of short-time spikes. |
|
Format current input like a series of short-time spikes. |
|
Get the gradually changed input current. |
|
Get the gradually changed input current. |
|
Stimulus sampled from a Wiener process, i.e. drawn from standard normal distribution N(0, sqrt(dt)). |
|
Ornstein–Uhlenbeck input. |
|
Sinusoidal input. |
|
Oscillatory square input. |
brainpy.connect
module#
This module provides methods to construct connectivity between neuron groups.
You can access them through brainpy.connect.XXX.
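For example, a random connection pattern can be constructed and then queried for the data structures a synapse model needs. This is a minimal sketch assuming the 2.x connector interface, in which a connector is first called with the pre- and post-synaptic sizes and then asked for structures by name through require(); the structure names used here ('pre_ids', 'post_ids', 'conn_mat') are assumptions.

import brainpy as bp

# 10% random connectivity from 100 pre-synaptic to 80 post-synaptic neurons.
conn = bp.connect.FixedProb(prob=0.1, seed=1234)
conn(pre_size=100, post_size=80)

# Ask the connector for the structures needed downstream.
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')
conn_mat = conn.require('conn_mat')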
Base Class#
|
Set the default dtype. |
|
Convert csr to csc. |
|
Convert (indices, indptr) to a dense matrix. |
|
Convert a dense matrix to (indices, indptr). |
|
Convert pre_ids, post_ids to (indices, indptr). |
alias of |
|
alias of |
|
Base Synaptic Connector Class. |
|
Synaptic connector to build synapse connections between two neuron groups. |
|
Synaptic connector to build synapse connections within a population of neurons. |
|
Custom Connections#
|
Connector built from the dense connection matrix. |
|
Connector built from the |
|
Connector built from the sparse connection matrix.
Random Connections#
|
Connect the post-synaptic neurons with fixed probability. |
|
Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron. |
|
Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron. |
|
Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function. |
|
Build a Watts–Strogatz small-world graph. |
|
Build a random graph according to the Barabási–Albert preferential attachment model. |
|
Build a random graph according to the dual Barabási–Albert preferential attachment model. |
|
Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering. |
Regular Connections#
|
Connect two neuron groups one by one. |
|
Connect each neuron in the first group to all neurons in the post-synaptic neuron group. |
|
The nearest four neighbors conn method. |
|
The nearest eight neighbors conn method. |
|
The nearest (2*N+1) * (2*N+1) neighbors conn method. |
Connect two neuron groups one by one. |
|
Connect each neuron in the first group to all neurons in the post-synaptic neuron group. |
|
The nearest four neighbors conn method. |
|
The nearest eight neighbors conn method. |
brainpy.initialize
module#
This module provides methods to initialize weights.
You can access them through brainpy.init.XXX.
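An initializer instance is typically callable with the desired weight shape. This is a minimal sketch; the argument names scale, min_val and max_val are assumptions to be checked against the installed version.

import brainpy as bp

init = bp.init.Normal(scale=0.1)        # weights drawn from N(0, 0.1^2)
W = init((100, 50))                     # a 100 x 50 weight matrix

uniform = bp.init.Uniform(min_val=-0.05, max_val=0.05)
W2 = uniform((100, 50))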
Base Class#
Base Initialization Class. |
|
The superclass of Initializers that initialize the weights between two layers. |
|
The superclass of Initializers that initialize the weights within a layer. |
Regular Initializers#
|
Zero initializer. |
|
One initializer. |
|
Returns the identity matrix. |
Random Initializers#
|
Initialize weights with normal distribution. |
|
Initialize weights with uniform distribution. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Construct an initializer for uniformly distributed orthogonal matrices. |
|
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393. |
Decay Initializers#
|
Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay according to a Gaussian function. |
|
Builds a Difference-of-Gaussian (DOG) connectivity pattern within a population of neurons. |
brainpy.losses
module#
This module implements several loss functions.
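The loss functions take the predictions and the targets as array arguments, for example (a minimal sketch; the argument order and the reduction behaviour are assumptions to be checked against the installed version):

import brainpy as bp
import brainpy.math as bm

predicts = bm.random.normal(size=(32, 10))
targets = bm.random.normal(size=(32, 10))

mse = bp.losses.mean_squared_error(predicts, targets)   # mean squared error
mae = bp.losses.mean_absolute_error(predicts, targets)  # mean absolute error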
|
This criterion combines |
|
Creates a criterion that measures the mean absolute error (MAE) between each element in the logits \(x\) and targets \(y\). |
|
Computes the L2 loss. |
|
Computes the L2 loss. |
|
Huber loss. |
|
Computes the mean absolute error between x and y. |
|
Computes the mean squared error between x and y. |
|
Computes the mean squared logarithmic error between y_true and y_pred. |
brainpy.optimizers
module#
Optimizers#
|
Base Optimizer Class. |
|
Stochastic gradient descent optimizer. |
|
Momentum optimizer. |
|
Nesterov accelerated gradient optimizer [2]_. |
|
Optimizer that implements the Adagrad algorithm. |
|
Optimizer that implements the Adadelta algorithm. |
|
Optimizer that implements the RMSprop algorithm. |
|
Optimizer that implements the Adam algorithm. |
|
Layer-wise adaptive rate scaling (LARS) optimizer. |
Schedulers#
|
|
|
The learning rate scheduler. |
|
|
|
|
|
|
|
|
|
brainpy.measure
module#
This module aims to provide commonly used analysis methods for simulated neuronal data.
You can access them through brainpy.measure.XXX.
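Most of these functions operate on a spike matrix of shape (num_time, num_neuron) or on recorded membrane potentials, for example (a minimal sketch; the width and dt keywords of firing_rate are assumptions to be checked against the installed version):

import brainpy as bp
import brainpy.math as bm

# A random binary spike matrix of shape (num_time, num_neuron).
spikes = bm.random.random((1000, 100)) < 0.02

# Population firing rate smoothed with a 10 ms window (dt = 0.1 ms).
rates = bp.measure.firing_rate(spikes, width=10., dt=0.1)

# A synchronization index from membrane-potential variance would instead use
# bp.measure.voltage_fluctuation(potentials).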
|
Calculate cross correlation index between neurons. |
|
Pearson correlation of the lower triangular part of two matrices. |
|
Functional connectivity matrix of timeseries activities. |
|
Get the spike raster plot, which displays the spiking activity of a group of neurons over time. |
|
Calculate the mean firing rate of a neuron group. |
|
Calculate neuronal synchronization via voltage variance. |
|
Weighted Pearson correlation of two data series. |
brainpy.running
module#
This module provides APIs for brain simulations.
Monitors#
|
The basic Monitor class to store the past variable trajectories. |
Parallel Pool#
|
Run multiple models in multiple processes. |
|
Run multiple models in multiple processes with a lock. |
Runners#
|
Base Runner. |
brainpy.tools
module#
Type Checking#
|
|
|
Check whether the given shapes are broadcastable. |
|
Check whether two shapes are compatible except the batch size axis. |
|
|
|
Check the dictionary data. |
|
Check the initializer. |
|
Check the connector. |
|
Check float type. |
|
Check integer type. |
|
Check string type. |
|
Code Tools#
|
|
|
|
|
Return all the identifiers in a given string |
|
|
|
|
|
Applies a dict of word substitutions. |
|
Check whether the function is a |
|
Get the main function _code string. |
|
|
|
Error Tools#
|
Check errors in a jit function. |
Other Tools#
|
|
|
|
|
Add a timeout parameter to a function and return it. |
|
Set up a progress bar. |
|
|
|
Python dictionaries with advanced dot notation access. |
brainpy.compat
module#
Brain Objects#
|
Dynamical System. |
|
Container. |
|
Network. |
|
Constant Delay. |
|
Neuron group. |
|
Two-end synaptic connection. |
Integrators#
|
Set default ode integrator. |
|
Set default sde integrator. |
Get default ode integrator. |
|
Get default sde integrator. |
Layers#
|
Basic module class. |
Models#
|
LIF neuron model. |
|
AdExIF neuron model. |
|
Izhikevich neuron model. |
|
ExpCOBA synapse model. |
|
ExpCUBA synapse model. |
|
Delta synapse model. |
Runners#
|
Integrator runner class. |
|
Dynamical system runner class. |
|
Dynamical system runner class. |
|
Dynamical system runner class. |
Monitor#
|
Monitor class. |
Release notes (brainpy)#
brainpy 2.x (LTS)#
Version 2.1.11 (2022.05.15)#
What’s Changed#
update apis, test and docs of numpy ops by @chaoming0625 in #202
update control flow, integrators, operators, and docs by @chaoming0625 in #205
improve oo-to-function transformation speed by @chaoming0625 in #208
Full Changelog: V2.1.10…V2.1.11
Version 2.1.10 (2022.05.05)#
What’s Changed#
update control flow APIs and Docs by @chaoming0625 in #192
doc: update docs of dynamics simulation by @chaoming0625 in #193
fix #125: add channel models and two-compartment Pinsky-Rinzel model by @chaoming0625 in #194
JIT errors do not change Variable values by @chaoming0625 in #195
Functionality improvements by @chaoming0625 in #197
update rate docs by @chaoming0625 in #198
update brainpy.dyn doc by @chaoming0625 in #199
Full Changelog: V2.1.8…V2.1.10
Version 2.1.8 (2022.04.26)#
What’s Changed#
Fix #120 by @chaoming0625 in #178
feat: brainpy.Collector supports addition and subtraction by @chaoming0625 in #179
feat: delay variables support “indices” and “reset()” function by @chaoming0625 in #180
Support reset functions in neuron and synapse models by @chaoming0625 in #181
update() function no longer needs _t and _dt by @chaoming0625 in #183
small updates by @chaoming0625 in #188
feat: easier control flows with brainpy.math.ifelse by @chaoming0625 in #189
feat: update delay couplings of DiffusiveCoupling and AdditiveCoupling by @chaoming0625 in #190
update version and changelog by @chaoming0625 in #191
Full Changelog: V2.1.7…V2.1.8
Version 2.1.7 (2022.04.22)#
What’s Changed#
synapse models support heterogeneous weights by @chaoming0625 in #170
more efficient synapse implementation by @chaoming0625 in #171
fix input models in brainpy.dyn by @chaoming0625 in #172
update README: ‘brain-py’ to ‘brainpy’ by @chaoming0625 in #174
fix: fix the updating rules in the STP model by @c-xy17 in #176
Updates and fixes by @chaoming0625 in #177
Full Changelog: V2.1.5…V2.1.7
Version 2.1.5 (2022.04.18)#
What’s Changed#
brainpy.math.random.shuffle is numpy like by @chaoming0625 in #153
update LICENSE by @chaoming0625 in #155
compatible apis of ‘brainpy.math’ with those of ‘jax.numpy’ in most modules by @chaoming0625 in #156
Important updates by @chaoming0625 in #157
Updates by @chaoming0625 in #159
Add LayerNorm, GroupNorm, and InstanceNorm as nn_nodes in normalization.py by @c-xy17 in #162
update setup.py by @chaoming0625 in #165
update synapses by @chaoming0625 in #167
get the deserved name: brainpy by @chaoming0625 in #168
update tests by @chaoming0625 in #169
Full Changelog: V2.1.4…V2.1.5
Version 2.1.4 (2022.04.04)#
What’s Changed#
fix doc parsing bug by @chaoming0625 in #127
Reorganization of brainpylib.custom_op and adding interface in brainpy.math by @ztqakita in #128
Fix: modify register_op and brainpy.math interface by @ztqakita in #130
new features about RNN training and delay differential equations by @chaoming0625 in #132
Fix #123: Add low-level operators docs and modify register_op by @ztqakita in #134
fix #133, support batch size training with offline algorithms by @chaoming0625 in #136
fix #84: support online training algorithms by @chaoming0625 in #137
fix: fix shape checking error by @chaoming0625 in #139
solve #131, support efficient synaptic computation for special connection types by @chaoming0625 in #140
feat: update the API and test for batch normalization by @c-xy17 in #142
Node is default trainable by @chaoming0625 in #143
Updates training apis and docs by @chaoming0625 in #145
fix: add dependencies and update version by @ztqakita in #147
update requirements by @chaoming0625 in #146
data pass of the Node is default SingleData by @chaoming0625 in #148
Full Changelog: V2.1.3…V2.1.4
Version 2.1.3 (2022.03.27)#
This release improves the functionality and usability of BrainPy. Core changes include
support customization of low-level operators by using Numba
fix bugs
What’s Changed#
Provide custom operators written in numba for jax jit by @ztqakita in #122
fix DOGDecay bugs, add more features by @chaoming0625 in #124
fix bugs by @chaoming0625 in #126
Full Changelog: V2.1.2…V2.1.3
Version 2.1.2 (2022.03.23)#
This release improves the functionality and usability of BrainPy. Core changes include
support rate-based whole-brain modeling
add more neuron models, including rate neurons/synapses
support Python 3.10
improve delays etc. APIs
What’s Changed#
fix matplotlib dependency on “brainpy.analysis” module by @chaoming0625 in #110
add py3.6 test & delete multiple macos env by @ztqakita in #112
update python version by @chaoming0625 in #114
Enhance measure/input/brainpylib by @chaoming0625 in #117
fix #105: Add customize connections docs by @ztqakita in #118
fix bugs by @chaoming0625 in #119
Whole brain modeling by @chaoming0625 in #121
Full Changelog: V2.1.1…V2.1.2
Version 2.1.1 (2022.03.18)#
This release continues to update the functionality of BrainPy. Core changes include
numerical solvers for fractional differential equations
more standard
brainpy.nn
interfaces
New Features#
- Numerical solvers for fractional differential equations
brainpy.fde.CaputoEuler
brainpy.fde.CaputoL1Schema
brainpy.fde.GLShortMemory
- Fractional neuron models
brainpy.dyn.FractionalFHR
brainpy.dyn.FractionalIzhikevich
support
shared_kwargs
in RNNTrainer and RNNRunner
Version 2.1.0 (2022.03.14)#
Highlights#
We are excited to announce the release of BrainPy 2.1.0. This release is composed of nearly 270 commits since 2.0.2, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang .
BrainPy 2.1.0 updates are focused on improving usability, functionality, and stability of BrainPy. Highlights of version 2.1.0 include:
New module brainpy.dyn for dynamics building and simulation. It is composed of many neuron models, synapse models, and others.
New module brainpy.nn for neural network building and training. It supports the definition of reservoir models, artificial neural networks, ridge regression training, and back-propagation through time training.
New module brainpy.datasets for convenient dataset construction and initialization.
New module brainpy.integrators.dde for numerical integration of delay differential equations.
Add more numpy-like operators in the brainpy.math module.
Add automatic continuous integration on Linux, Windows, and MacOS platforms.
Fully update brainpy documentation.
Fix bugs on brainpy.analysis and brainpy.math.autograd.
Incompatible changes#
Remove brainpy.math.numpy module.
Remove numba requirements.
Remove matplotlib requirements.
Remove steps in brainpy.dyn.DynamicalSystem.
Remove travis CI.
New Features#
brainpy.ddeint for numerical integration of delay differential equations, the supported methods include:
Euler
MidPoint
Heun2
Ralston2
RK2
RK3
Heun3
Ralston3
SSPRK3
RK4
Ralston4
RK4Rule38
- set default int/float/complex types
brainpy.math.set_dfloat()
brainpy.math.set_dint()
brainpy.math.set_dcomplex()
- Delay variables
brainpy.math.FixedLenDelay
brainpy.math.NeutralDelay
- Dedicated operators
brainpy.math.sparse_matmul()
More numpy-like operators
Neural network building
brainpy.nn
Dynamics model building and simulation
brainpy.dyn
Version 2.0.2 (2022.02.11)#
There are important updates by Chaoming Wang in BrainPy 2.0.2.
provide pre2post_event_prod operator
support array creation from a list/tuple of JaxArray in brainpy.math.asarray and brainpy.math.array
update brainpy.ConstantDelay, add .latest and .oldest attributes
add brainpy.IntegratorRunner support for efficient simulation of brainpy integrators
support auto finding of RandomState when JIT SDE integrators
fix bugs in SDE exponential_euler method
move parallel running APIs into brainpy.simulation
add brainpy.math.syn2post_mean, brainpy.math.syn2post_softmax, brainpy.math.pre2post_mean and brainpy.math.pre2post_softmax operators
Version 2.0.1 (2022.01.31)#
Today we release BrainPy 2.0.1. This release is composed of over 70 commits since 2.0.0, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang .
BrainPy 2.0.0 updates are focused on improving documentation and operators. Core changes include:
Improve brainpylib operators
Complete documentation for programming system
Add more numpy APIs
Add jaxfwd in autograd module
And other changes
Version 2.0.0.1 (2022.01.05)#
Add progress bar in
brainpy.StructRunner
Version 2.0.0 (2021.12.31)#
Start a new version of BrainPy.
Highlight#
We are excited to announce the release of BrainPy 2.0.0. This release is composed of over 260 commits since 1.1.7, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang .
BrainPy 2.0.0 updates are focused on improving performance, usability and consistence of BrainPy.
All the computations are migrated into JAX. Model building, simulation, training and analysis are all based on JAX. Highlights of version 2.0.0 include:
brainpylib provides dedicated operators for brain dynamics programming.
Connection APIs in the brainpy.conn module are more efficient.
Update analysis tools for low-dimensional and high-dimensional systems in the brainpy.analysis module.
Support more general Exponential Euler methods based on automatic differentiation.
Improve the usability and consistency of the brainpy.math module.
Remove JIT compilation based on Numba.
Separate brain building from brain simulation.
Incompatible changes#
remove brainpy.math.use_backend()
remove brainpy.math.numpy module
no longer support .run() in brainpy.DynamicalSystem (see New Features)
remove brainpy.analysis.PhasePlane (see New Features)
remove brainpy.analysis.Bifurcation (see New Features)
remove brainpy.analysis.FastSlowBifurcation (see New Features)
New Features#
- Exponential Euler method based on automatic differentiation
brainpy.ode.ExpEulerAuto
- Numerical optimization based low-dimensional analyzers:
brainpy.analysis.PhasePlane1D
brainpy.analysis.PhasePlane2D
brainpy.analysis.Bifurcation1D
brainpy.analysis.Bifurcation2D
brainpy.analysis.FastSlow1D
brainpy.analysis.FastSlow2D
- Numerical optimization based high-dimensional analyzer:
brainpy.analysis.SlowPointFinder
- Dedicated operators in the brainpy.math module:
brainpy.math.pre2post_event_sum
brainpy.math.pre2post_sum
brainpy.math.pre2post_prod
brainpy.math.pre2post_max
brainpy.math.pre2post_min
brainpy.math.pre2syn
brainpy.math.syn2post
brainpy.math.syn2post_prod
brainpy.math.syn2post_max
brainpy.math.syn2post_min
- Conversion APIs in the brainpy.math module:
brainpy.math.as_device_array()
brainpy.math.as_variable()
brainpy.math.as_jaxarray()
- New autograd APIs in the brainpy.math module:
brainpy.math.vector_grad()
- Simulation runners:
brainpy.ReportRunner
brainpy.StructRunner
brainpy.NumpyRunner
- Commonly used models in the brainpy.models module:
brainpy.models.LIF
brainpy.models.Izhikevich
brainpy.models.AdExIF
brainpy.models.SpikeTimeInput
brainpy.models.PoissonInput
brainpy.models.DeltaSynapse
brainpy.models.ExpCUBA
brainpy.models.ExpCOBA
brainpy.models.AMPA
brainpy.models.GABAa
Naming cache clean: brainpy.clear_name_cache
add safe in-place operations of update() method and .value assignment for JaxArray
Documentation#
Complete tutorials for quickstart
Complete tutorials for dynamics building
Complete tutorials for dynamics simulation
Complete tutorials for dynamics training
Complete tutorials for dynamics analysis
Complete tutorials for API documentation
brainpy 1.1.x (LTS)#
If you are using brainpy==1.x
, you can find documentation, examples, and models through the following links:
Documentation: https://brainpy.readthedocs.io/en/brainpy-1.x/
Examples from papers: https://brainpy-examples.readthedocs.io/en/brainpy-1.x/
Canonical brain models: https://brainmodels.readthedocs.io/en/brainpy-1.x/
Version 1.1.7 (2021.12.13)#
fix bugs on
numpy_array()
conversion in brainpy.math.utils module
Version 1.1.5 (2021.11.17)#
API changes:
fix bugs on ndarray import in brainpy.base.function.py
convenient ‘get_param’ interface brainpy.simulation.layers
add more weight initialization methods
Doc changes:
add more examples in README
Version 1.1.4#
API changes:
add .struct_run() in DynamicalSystem
add numpy_array() conversion in brainpy.math.utils module
add Adagrad, Adadelta, RMSProp optimizers
remove setting methods in brainpy.math.jax module
remove import jax in brainpy.__init__.py and enable jax setting, including
enable_x64()
set_platform()
set_host_device_count()
enable b=None as no bias in brainpy.simulation.layers
set int_ and float_ as default 32 bits
remove dtype setting in Initializer constructor
Doc changes:
add optimizer in “Math Foundation”
add dynamics training docs
improve others
Version 1.1.3#
fix bugs of JAX parallel API imports
fix bugs of post_slice structure construction
update docs
Version 1.1.2#
add pre2syn and syn2post operators
add verbose and check option to Base.load_states()
fix bugs on JIT DynamicalSystem (numpy backend)
Version 1.1.1#
fix bugs on symbolic analysis: model trajectory
change absolute access in the variable saving and loading to the relative access
add UnexpectedTracerError hints in JAX transformation functions
Version 1.1.0 (2021.11.08)#
This package releases a new version of BrainPy.
Highlights of core changes:
math
module#
support numpy backend
support JAX backend
support jit, vmap and pmap on class objects on JAX backend
support grad, jacobian, hessian on class objects on JAX backend
support make_loop, make_while, and make_cond on JAX backend
support jit (based on numba) on class objects on numpy backend
unified numpy-like ndarray operation APIs
numpy-like random sampling APIs
FFT functions
gradient descent optimizers
activation functions
loss function
backend settings
base
module#
Base for the whole BrainPy ecosystem
Function to wrap functions
Collector and TensorCollector to collect variables, integrators, nodes and others
integrators
module#
class integrators for ODE numerical methods
class integrators for SDE numerical methods
simulation
module#
support modular and composable programming
support multi-scale modeling
support large-scale modeling
support simulation on GPUs
fix bugs on firing_rate()
remove _i in update() function, replace _i with _dt, meaning the dynamic system has the canonical equation form of \(dx/dt = f(x, t, dt)\)
reimplement the input_step and monitor_step in a more intuitive way
support setting dt at the single object level (i.e., a single instance of DynamicSystem)
commonly used DNN layers
weight initializations
refine synaptic connections
brainpy 1.0.x#
Version 1.0.3 (2021.08.18)#
Fix bugs on
firing rate measurement
stability analysis
Version 1.0.2#
This release continues to improve the user-friendliness.
Highlights of core changes:
Remove support for Numba-CUDA backend
Super initialization super(XXX, self).__init__() can be done anywhere (it is not required to be at the bottom of the __init__() function).
Add an output message when a step function raises a running error.
More powerful support for Monitoring
More powerful support for running order scheduling
Remove unsqueeze() and squeeze() operations in
brainpy.ops
Add reshape() operation in
brainpy.ops
Improve docs for numerical solvers
Improve tests for numerical solvers
Add keywords checking in ODE numerical solvers
Add more unified operations in brainpy.ops
Support “@every” in steps and monitor functions
Fix ODE solver bugs for class bounded function
Add build phase in Monitor
Version 1.0.1#
Fix bugs
Version 1.0.0#
NEW VERSION OF BRAINPY
Change the coding style into the object-oriented programming
Systematically improve the documentation
brainpy 0.x#
Version 0.3.5#
Add ‘timeout’ in sympy solver in neuron dynamics analysis
Reconstruct and generalize phase plane analysis
Generalize the repeat mode of Network to different running duration between two runs
Update benchmarks
Update detailed documentation
Version 0.3.1#
Add a more flexible way for NeuState/SynState initialization
Fix bugs of “is_multi_return”
Add “hand_overs”, “requires” and “satisfies”.
Update documentation
Auto-transform range to numba.prange
Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models
Version 0.3.0#
Computation API#
Rename “brainpy.numpy” to “brainpy.backend”
Delete “pytorch”, “tensorflow” backends
Add “numba” requirement
Add GPU support
Profile setting#
Delete “backend” profile setting, add “jit”
Core systems#
Delete “autopepe8” requirement
Delete the format code prefix
Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”
Change the “ST” declaration out of “requires”
Add “repeat” mode run in Network
Change “vector-based” to “mode” in NeuType and SynType definition
Package installation#
Remove “pypi” installation, installation now only relies on “conda”
Version 0.2.4#
API changes#
Fix bugs
Version 0.2.3#
API changes#
Add “animate_1D” in visualization module
Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in inputs module
Update phase_portrait_analyzer.py
Models and examples#
Add CANN examples
Version 0.2.2#
API changes#
Redesign visualization
Redesign connectivity
Update docs
Version 0.2.1#
API changes#
Fix bugs in numba import
Fix bugs in numpy mode with scalar model
Version 0.2.0#
API changes#
For computation:
numpy
,numba
For model definition:
NeuType
,SynConn
For model running:
Network
,NeuGroup
,SynConn
,Runner
For numerical integration:
integrate
,Integrator
,DiffEquation
For connectivity:
One2One
,All2All
,GridFour
,grid_four
,GridEight
,grid_eight
,GridN
,FixedPostNum
,FixedPreNum
,FixedProb
,GaussianProb
,GaussianWeight
,DOG
For visualization:
plot_value
,plot_potential
,plot_raster
,animation_potential
For measurement:
cross_correlation
,voltage_fluctuation
,raster_plot
,firing_rate
For inputs:
constant_current
,spike_current
,ramp_current
.
Models and examples#
Neuron models:
HH model
,LIF model
,Izhikevich model
Synapse models:
AMPA
,GABA
,NMDA
,STP
,GapJunction
Network models:
gamma oscillation
Release notes (brainpylib)#
Version 0.0.5#
Support operator customization on GPU by
numba
Version 0.0.4#
Support operator customization on CPU by
numba
Version 0.0.3#
Support event_sum() operator on GPU
Support event_prod() operator on CPU
Support atomic_sum() operator on GPU
Support atomic_prod() operator on CPU and GPU
Version 0.0.2#
Support event_sum() operator on CPU
Support event_sum2() operator on CPU
Support atomic_sum() operator on CPU