BrainPy documentation#

BrainPy is a highly flexible and extensible framework targeting high-performance Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:

  • JIT compilation and automatic differentiation for class objects.

  • Numerical methods for ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), fractional differential equations (FDEs), etc.

  • Dynamics simulation tools for various brain objects, like neurons, synapses, networks, soma, dendrites, channels, and even more.

  • Dynamics training tools with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.

  • Dynamics analysis tools for differential equations, including phase plane analysis, bifurcation analysis, linearization analysis, and fixed/slow point finding.

  • And much more …

For comprehensive examples of BrainPy, please see:

The code of BrainPy is open-sourced at GitHub:

Installation#

BrainPy is designed to run cross-platform, including Windows, GNU/Linux, and macOS. It relies only on Python libraries.

Installation with pip#

You can install BrainPy from PyPI. To do so, use:

pip install brainpy

To update BrainPy to the latest version, use:

pip install -U brainpy

If you want to install the pre-release version (the latest development version) of BrainPy, you can use:

pip install --pre brainpy

Installation from source#

If you decide not to use the released packages on PyPI, you can install BrainPy from GitHub or OpenI.

To do so, use:

pip install git+https://github.com/PKU-NIP-Lab/BrainPy

# or

pip install git+https://git.openi.org.cn/OpenI/BrainPy

Dependency 1: NumPy#

To make BrainPy work properly, users should install several dependent Python packages.

The basic functionality of BrainPy relies only on NumPy, which is easy to install through pip or conda:

pip install numpy

# or

conda install numpy

Dependency 2: JAX#

BrainPy relies on JAX. JAX is a high-performance JIT compiler which enables users to run Python code on CPU, GPU, and TPU devices. Core functionalities of BrainPy (>=2.0.0) have been migrated to the JAX backend.

Linux & MacOS#

Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later). The binary releases of JAX for Linux and macOS are available at https://storage.googleapis.com/jax-releases/jax_releases.html.

To install a CPU-only version of JAX, you can run

pip install --upgrade "jax[cpu]"

If you want to install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN if they have not already been installed. Next, run

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Alternatively, you can download the preferred release “.whl” file for jaxlib, and install it via pip:

pip install xxxx.whl

pip install jax

Warning

For Apple M1 macOS users, you should run your Python environment natively on Apple silicon rather than through Rosetta 2 (x86 emulation), since Rosetta 2 cannot translate jaxlib. One suggestion is to uninstall Miniconda3 and install Miniforge3 to manage your Python environment.
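
A quick way to check which architecture your Python interpreter runs on (a minimal check using only the standard library; it should report 'arm64' for a native Apple-silicon environment and 'x86_64' under Rosetta 2):

import platform

print(platform.machine())  # expect 'arm64' on a native Apple-silicon Python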

Windows#

For Windows users, JAX can be installed in a similar way: first install a jaxlib wheel (".whl") built for Windows, and then install jax:

pip install xxxx.whl

pip install jax
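
Whichever route you take, a quick sanity check (the exact device listing depends on your installation) confirms that JAX is importable and shows which backend it will use:

import jax

print(jax.__version__)   # the installed JAX version
print(jax.devices())     # e.g. [CpuDevice(id=0)] on a CPU-only installation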

Other Dependency#

To get full support for all BrainPy features, we recommend you also install the following packages:

  • Numba: needed for some NumPy-based computations

pip install numba

# or

conda install numba

  • brainpylib: needed for dedicated operators

pip install brainpylib

  • matplotlib: required by some visualization functions; it is now recommended that users explicitly import matplotlib for visualization

pip install matplotlib

# or

conda install matplotlib

  • NetworkX: needed for the visualization of network training

pip install networkx

# or

conda install networkx

Simulating a Spiking Neural Network#

@Xiaoyu Chen

Spiking neural networks (SNNs) are one of the most important tools for studying brain dynamics in computational neuroscience. They simulate the biological processes of information transmission in the brain, including the change of membrane potentials, neuronal firing, and synaptic transmission. In this section, we will illustrate how to build and simulate an SNN.

Before we start, the BrainPy package should be imported:

import brainpy as bp

bp.math.set_platform('cpu')

Building an E-I balance network#

Let’s try to build an E-I balance network. The structure of an E-I balance network is as follows:

Illustration of an E-I Balance Network

An E-I balance network is composed of two neuron groups and the synaptic connections between them. Specifically, these include:

  1. a group of excitatory neurons (E),

  2. a group of inhibitory neurons (I),

  3. synaptic connections within the excitatory and inhibitory neuron groups, respectively, and

  4. the inter-connections between these two groups.

To construct the network, we need to define these components one by one. BrainPy provides plenty of handy built-in models for brain dynamics simulation, which are contained in brainpy.dyn. Let’s choose the simplest yet most canonical neuron model, the Leaky Integrate-and-Fire (LIF) model, to build the excitatory and inhibitory neuron groups:

E = bp.dyn.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')
I = bp.dyn.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')

E.V[:] = bp.math.random.randn(3200) * 2 - 60.
I.V[:] = bp.math.random.randn(800) * 2 - 60.

When defining the LIF neuron groups, the parameters can be tuned according to users’ needs. The first parameter denotes the number of neurons; here the ratio of excitatory to inhibitory neurons is set to 4:1. V_rest denotes the resting potential, V_th the firing threshold, V_reset the reset value after firing, tau the membrane time constant, and tau_ref the duration of the refractory period. method refers to the numerical integration method to be used in the simulation.
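
Most of these parameters can also be specified per neuron. Below is a minimal sketch (assuming this BrainPy version accepts array-valued parameters, as the tensor tutorial later in this documentation suggests) that gives each excitatory neuron its own time constant:

taus = bp.math.random.randn(3200) * 2. + 20.   # one time constant per neuron
E_hetero = bp.dyn.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60.,
                      tau=taus, tau_ref=5., method='exp_auto')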

Then the synaptic connections between these two groups can be defined as follows:

E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
E2I = bp.dyn.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')
I2E = bp.dyn.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')
I2I = bp.dyn.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7, tau=10., method='exp_auto')

Here we use the exponential conductance-based synapse model (ExpCOBA) to simulate synaptic connections. Among the parameters of the model, the first two denote the pre- and post-synaptic neuron groups, respectively. The third one specifies the connection type; in this example, we use bp.conn.FixedProb, which connects the presynaptic neurons to the postsynaptic neurons with a given probability (detailed information is available in Synaptic Connection). The following three parameters describe the dynamic properties of the synapse, and the last one is the numerical integration method, as in the LIF model.

After defining all the components, they can be combined to form a network:

net = bp.dyn.Network(E2E, E2I, I2E, I2I, E=E, I=I)

In the definition, both neurons and synapses are passed to the network. The excitatory and inhibitory neuron groups (E and I) are passed with explicit names because they will be specifically accessed in the simulation (here they will receive input currents).
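
As a minimal illustration (assuming that named components are registered as attributes of the network in this BrainPy version), the named groups can then be reached directly from the network object:

print(net.E)   # the excitatory LIF group registered under the name 'E'
print(net.I)   # the inhibitory LIF group registered under the name 'I'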

We have successfully constructed an E-I balance network by using BrainPy’s built-in models. BrainPy also enables users to flexibly customize their own dynamic models, such as neuron groups, synapses, and networks. In fact, brainpy.dyn.Network() is a simple example of customizing a network model. Please refer to Dynamic Simulation for more information.

Running a simulation#

After building an SNN, we can use it for dynamic simulation. To run a simulation, we first need to wrap the network model into a runner. Currently, BrainPy provides DSRunner and ReportRunner in brainpy.dyn, which will be expanded in the Runners tutorial. Here we use DSRunner as an example:

runner = bp.dyn.DSRunner(net,
                         monitors=['E.spike', 'I.spike'],
                         inputs=[('E.input', 20.), ('I.input', 20.)],
                         dt=0.1)

To make dynamic simulation more applicable and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the spike variable in the E and I LIF models, which refers to the spiking status of the neuron groups, and give a constant input to both groups. The time step of numerical integration dt (with a default value of 0.1) can also be specified.
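
For instance, a variant of the runner above could also record the membrane potential of the excitatory group (assuming 'E.V' follows the same 'name.variable' convention as 'E.spike'):

runner_v = bp.dyn.DSRunner(net,
                           monitors=['E.spike', 'E.V'],   # also record E's membrane potential
                           inputs=[('E.input', 20.), ('I.input', 20.)],
                           dt=0.1)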

After creating the runner, we can run a simulation by calling the runner:

runner(100)
0.779956579208374

where the calling function receives the simulation time (usually in milliseconds) as the input and returns the time (seconds) spent on simulation. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to Just-In-Time Compilation for more details.

The simulation results are stored as NumPy arrays in the monitors, and can be visualized easily:

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4.5))

plt.subplot(121)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=False)
plt.subplot(122)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], show=True)

In the code above, brainpy.visualize contains some useful functions to visualize simulation results based on the matplotlib package. Since the simulation results are stored as NumPy arrays, users can also use matplotlib directly for visualization.
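
For example, the following minimal sketch builds a raster plot directly with matplotlib, without bp.visualize; it only assumes that runner.mon stores NumPy-compatible arrays, as stated above:

import numpy as np

ts = np.asarray(runner.mon.ts)                  # simulation time points
e_spikes = np.asarray(runner.mon['E.spike'])    # boolean matrix of shape (num_steps, num_neurons)
step_ids, neuron_ids = np.where(e_spikes)       # indices of all spike events
plt.scatter(ts[step_ids], neuron_ids, s=1)
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
plt.show()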

Building a decision making network#

Simulating a Firing Rate Network Model#

@Chaoming Wang

Whole-brain modeling is a grand challenge of computational neuroscience. Simulating a whole-brain model with spiking neurons is still nearly impossible for ordinary users. However, by using rate-based neural mass models, in which each brain region is approximated by a few simple variables, we can build an abstract whole-brain model. In recent years, whole-brain models have been used to address a wide range of problems. In this section, we are going to talk about how to simulate a whole-brain neural mass model with BrainPy.

import brainpy as bp
import brainpy.math as bm
from brainpy.dyn import rates

import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'plasma'

Neural mass model#

A neural mass model is a low-dimensional population model of spiking neural networks. It aims to describe the coarse-grained activity of large populations of neurons and synapses. Mathematically, it is a dynamical system of nonlinear ODEs. A classical neural mass model is the two-dimensional Wilson–Cowan model, which tracks the activity of an excitatory population of neurons coupled to an inhibitory population. Augmented with more realistic forms of synaptic and network interaction, such models have proved especially successful in providing fits to neuroimaging data.

Here, let’s try the Wilson-Cowan model.

wc = rates.WilsonCowanModel(2,
                            wEE=16., wIE=15., wEI=12., wII=3.,
                            E_a=1.5, I_a=1.5, E_theta=3., I_theta=3.,
                            method='exp_euler_auto')
wc.x[:] = [-0.2, 1.]
wc.y[:] = [0.0, 1.]

runner = bp.dyn.DSRunner(wc, monitors=['x', 'y'], inputs=['input', -0.5])
runner.run(10.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.x,
                       plot_ids=[0, 1], legend='e', show=True)

We can see that this model has at least two stable states.

Bifurcation diagram

With the automatic analysis module in BrainPy, we can easily inspect the bifurcation diagram of the model. Bifurcation diagrams give us an overview of how different parameters affect the model’s dynamics (for details of BrainPy’s automatic analysis support, please see the introduction in Analyzing a Dynamical Model and the tutorials in Dynamics Analysis). In this case, we take x_ext as the bifurcation parameter and examine how the system’s behavior changes with x_ext.

bf = bp.analysis.Bifurcation2D(
    wc,
    target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
    target_pars={'x_ext': [-2, 2]},
    pars_update={'y_ext': 0.},
    resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 40000 candidates
I am trying to filter out duplicate fixed points ...
	Found 579 fixed points.
I am plotting the limit cycle ...

Similarly, simulating and analyzing a rate-based FitzHugh-Nagumo model is also a piece of cake by using BrainPy.

fhn = rates.FHN(1, method='exp_auto')

bf = bp.analysis.Bifurcation2D(
    fhn,
    target_vars={'x': [-2, 2], 'y': [-2, 2]},
    target_pars={'x_ext': [0, 2]},
    pars_update={'y_ext': 0.},
    resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 20000 candidates
I am trying to filter out duplicate fixed points ...
	Found 200 fixed points.
I am plotting the limit cycle ...

In this model, we find that when the external input x_ext takes a value in [0.72, 1.4], the model generates limit cycles. We can verify this by simulation.

runner = bp.dyn.DSRunner(fhn, monitors=['x', 'y'], inputs=['input', 1.0])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y', show=True)

Whole-brain model#

A rate-based whole-brain model is a network model consisting of coupled brain regions. Each brain region is represented by a neural mass model and is connected to other brain regions according to the underlying network structure of the brain, also known as the connectome. To illustrate how to use BrainPy’s support for whole-brain modeling, we provide processed data at the following link:

Please download the dataset and place it in your favorite PATH.

PATH = './data/hcp.npz'

In general, a dataset for whole-brain modeling consists of the following parts:

1. A structural connectivity matrix which captures the synaptic connection strengths between brain areas. It is often derived from DTI tractography of the whole brain. The connectome is then typically parcellated according to a preferred atlas (for example, the AAL2 atlas), and the number of axonal fibers connecting each brain area with every other area is counted. This number serves as an indication of the synaptic coupling strengths between the areas of the brain.

2. A delay matrix which is calculated from the average length of the axonal fibers connecting each brain area with another.

3. A set of functional data that can act as a target for model optimization. Resting-state fMRI offers an easy and fairly unbiased way for calibrating whole-brain models. EEG data could be used as well.

Now, let’s load the dataset.

data = bm.load(PATH)
# The structural connectivity matrix

data['Cmat'].shape
(80, 80)
# The fiber length matrix

data['Dmat'].shape
(80, 80)
# The functional data for 7 subjects

data['FCs'].shape
(7, 80, 80)

Let’s have a look at what the data looks like.

fig, axs = plt.subplots(1, 3, figsize=(15,5))
fig.subplots_adjust(wspace=0.28)

im = axs[0].imshow(data['Cmat'])
axs[0].set_title("Connection matrix")
fig.colorbar(im, ax=axs[0],fraction=0.046, pad=0.04)
im = axs[1].imshow(data['Dmat'], cmap='inferno')
axs[1].set_title("Fiber length matrix")
fig.colorbar(im, ax=axs[1],fraction=0.046, pad=0.04)
im = axs[2].imshow(data['FCs'][0], cmap='inferno')
axs[2].set_title("Empirical FC of subject 1")
fig.colorbar(im, ax=axs[2],fraction=0.046, pad=0.04)
plt.show()

Let’s first get the delay matrix according to the fiber length matrix, the signal transmission speed between areas, and the numerical integration step dt. Here, we assume the axonal transmission speed is 20 and the simulation time step dt=0.1 ms.

signal_speed = 20.

# the number of delay steps
delay_mat = data['Dmat'] / signal_speed / bm.get_dt()
delay_mat = bm.asarray(delay_mat, dtype=bm.int_)

The connectivity matrix can be obtained directly from the structural connectivity matrix multiplied by a global coupling strength parameter gc.

gc = 1.

conn_mat = bm.asarray(data['Cmat'] * gc)

# It is necessary to exclude the self-connections
bm.fill_diagonal(conn_mat, 0)

We are now ready to instantiate a whole-brain model with the neural mass model and the dataset processed before.

class WholeBrainNet(bp.dyn.Network):
  def __init__(self, Cmat, Dmat):
    super(WholeBrainNet, self).__init__()

    self.fhn = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01,
                         name='fhn', method='exp_auto')
    self.syn = rates.DiffusiveCoupling(self.fhn.x, self.fhn.x, self.fhn.input,
                                       conn_mat=Cmat,
                                       delay_steps=Dmat.astype(bm.int_),
                                       initial_delay_data=bp.init.Uniform(0, 0.05))

  def update(self, _t, _dt):
    self.syn.update(_t, _dt)
    self.fhn.update(_t, _dt)

net = WholeBrainNet(conn_mat, delay_mat)

runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
6.346501350402832

The simulated results can be used to estimate the functional correlation matrix.

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
fc = bp.measure.functional_connectivity(runner.mon['fhn.x'])
ax = axs[0].imshow(fc)
plt.colorbar(ax, ax=axs[0])
axs[1].plot(runner.mon.ts, runner.mon['fhn.x'][:, ::5], alpha=0.8)
plt.tight_layout()
plt.show()

We can compute the element-wise Pearson correlation between the functional connectivity matrix of the simulated data and those of the empirical data, to estimate how well the model captures the inter-areal functional correlations found in empirical resting-state recordings.

scores = [bp.measure.matrix_correlation(fc, fcemp)
          for fcemp in data['FCs']]
print("Correlation per subject:", [f"{s:.2}" for s in scores])
print("Mean FC/FC correlation: {:.2f}".format(bm.mean(bm.asarray(scores))))
Correlation per subject: ['0.58', '0.45', '0.55', '0.49', '0.54', '0.5', '0.45']
Mean FC/FC correlation: 0.51

Training a Recurrent Neural Network#

@Chaoming Wang

In recent years, we have seen that training dynamical systems from data or tasks can provide important insights for understanding brain functions. To support this, BrainPy provides various interfaces to help users train dynamical systems.

import brainpy as bp
import brainpy.math as bm

bm.enable_x64()
bm.set_dfloat(bm.float64)

# bm.set_platform('cpu')
import matplotlib.pyplot as plt

General usage#

In BrainPy, we provide a general interface to build neural networks, supporting feedforward, recurrent, and feedback connections.

Model Building#

In general, each model is treated as a node. Based on node operations, such as feedforward (>>) and feedback (<<), we can create any node graph we want. For example,


feedforward_net = data >> reservoir >> readout

creates a simple network in which data is first fed forward to the reservoir node, and then the output of reservoir is read out by the readout node. Further, if we want to create a feedback connection from readout to reservoir, we can use


feedback_net = reservoir << readout

After merging it with the previously defined feedforward_net, we can create a network with both feedforward and feedback connections:


model = feedforward_net & feedback_net
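
A concrete version of the snippets above, sketched with node types that appear later in this tutorial (bp.nn.Input, bp.nn.Reservoir, bp.nn.LinearReadout), might look like this:

data = bp.nn.Input(3)               # input node with 3 features
reservoir = bp.nn.Reservoir(100)    # recurrent reservoir for dimension expansion
readout = bp.nn.LinearReadout(3)    # linear readout node

feedforward_net = data >> reservoir >> readout   # feedforward chain
feedback_net = reservoir << readout              # feedback from readout to reservoir
model = feedforward_net & feedback_net           # merge both into one network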

Model running & training#

Moreover, BrainPy provides various interfaces for network running and training, including the commonly used ridge regression method, the FORCE learning method, and the back-propagation through time (BPTT) algorithm. Users can create these runners and trainers with the following code:


runner = bp.nn.RNNRunner(model, ...)

or,


trainer = bp.nn.RidgeTrainer(model, ...)

trainer = bp.nn.FORCELearning(model, ...)

trainer = bp.nn.BPTT(model, ...)

Below, we demonstrate these supports with several examples.

Echo state network#

We first illustrate the training interface of BrainPy using an echo state network.

For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.

# create the components we need

i = bp.nn.Input(3)
r = bp.nn.Reservoir(400, spectral_radius=1.4)
o = bp.nn.LinearReadout(3)
# create the model we need

model = i >> r >> o
model.plot_node_graph(fig_size=(5, 5), node_size=2000)

We use this network to predict a chaotic time series known as the Lorenz attractor. Particularly, we expect the network to have the ability to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the prediction length ahead.

dt = 0.01
data = bp.datasets.lorenz_series(100, dt=dt)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['x'].flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['y'].flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['z'].flatten()))
plt.ylabel('z')
plt.show()
def get_subset(data, start, end):
    res = {'x': data['x'][start: end],
           'y': data['y'][start: end],
           'z': data['z'][start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    return res.reshape((1, ) + res.shape)

To accomplish this task, we use the ridge regression method to train the network. Before that, we first initialize the network with a batch size of 1 and then construct a ridge regression trainer.

model.initialize(num_batch=1)

trainer = bp.nn.RidgeTrainer(model, beta=1e-6)

We warm up the network with the first 20 ms of data.

warmup_data = get_subset(data, 0, int(20/dt))

outs = trainer.predict(warmup_data)

outs.shape
(1, 2000, 3)

The training data is the time series from 20 ms to 80 ms. We want the network to have the ability to forecast 1 time step ahead.

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)

trainer.fit([x_train, y_train])

Then we test the trained network with the next 20 ms.

x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))

predictions = trainer.predict(x_test)

bp.losses.mean_squared_error(y_test, predictions)
DeviceArray(0.00014552, dtype=float64)
def plot_difference(truths, predictions):
    truths = truths.numpy()
    predictions = predictions.numpy()

    plt.subplot(311)
    plt.plot(truths[0, :, 0], label='Ground Truth')
    plt.plot(predictions[0, :, 0], label='Prediction')
    plt.ylabel('x')
    plt.legend()
    plt.subplot(312)
    plt.plot(truths[0, :, 1], label='Ground Truth')
    plt.plot(predictions[0, :, 1], label='Prediction')
    plt.ylabel('y')
    plt.legend()
    plt.subplot(313)
    plt.plot(truths[0, :, 2], label='Ground Truth')
    plt.plot(predictions[0, :, 2], label='Prediction')
    plt.ylabel('z')
    plt.legend()
    plt.show()
plot_difference(y_test, predictions)

We can make the task harder by forecasting 10 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)

Or forecast 100 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)

As you can see, forecasting a larger time step ahead makes learning more difficult.

Next generation RC#

(Gauthier et al., Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model based on nonlinear vector autoregression (NVAR).

(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using a linear weight of time-delay states of the time series data and nonlinear functionals of this data.

In BrainPy, we can easily implement this kind of network. Here, let’s try to use NG-RC to infer the \(z\) variable from the \(x\) and \(y\) variables. This task is important for applications where it is possible to obtain high-quality information about a dynamical variable in a laboratory setting, but not in field deployment.

Let’s first initialize the data we need.

dt = 0.02
t_warmup = 10.  # ms
t_train = 20.  # ms
t_test = 50.  # ms
num_warmup = int(t_warmup / dt)  # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)

lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test,
                                          dt=dt,
                                          inits={'x': 17.67715816276679,
                                                 'y': 12.931379185960404,
                                                 'z': 43.91404334248268})
def get_subset(data, start, end):
  res = {'x': data['x'][start: end],
         'y': data['y'][start: end],
         'z': data['z'][start: end]}
  X = bm.hstack([res['x'], res['y']])
  X = X.reshape((1,) + X.shape)
  Y = res['z']
  Y = Y.reshape((1, ) + Y.shape)
  return X, Y


X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup)
X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
X_test, Y_test = get_subset(lorenz_series, num_warmup + num_train, num_warmup + num_train + num_test)

The network architecture is the same as that of the above echo state network: an input node, a reservoir node, and an output node. To accomplish this task, (Gauthier et al., Nature Communications, 2021) used 4 delayed history states with a stride of 5, together with their quadratic polynomial monomials. Therefore, we create the network as:

i = bp.nn.Input(2)
r = bp.nn.NVAR(delay=4, order=2, stride=5)
o = bp.nn.LinearReadout(1, trainable=True)
model = i >> r >> o
model.initialize(num_batch=1)

We train the network using the Ridge Regression method too.

trainer = bp.nn.RidgeTrainer(model, beta=0.05)

# warm-up
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))

# training
trainer.fit([X_train, Y_train])

# prediction
outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
Warmup NMS:  10729.250973138222
Prediction NMS:  0.3374043793562189
X_test = bm.asarray(X_test).numpy()[0]
Y_test = bm.asarray(Y_test).numpy().flatten()
outputs = bm.asarray(outputs).numpy().flatten()

plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(X_test[:, 0], color='b')
plt.ylabel('x')
plt.subplot(312)
plt.plot(X_test[:, 1], color='b')
plt.ylabel('y')
plt.subplot(313)
plt.plot(Y_test, color='b', label='Ground Truth')
plt.plot(outputs, color='r', label='Prediction')
plt.ylabel('z')
plt.legend()
plt.show()

Recurrent neural network#

In recent years, artificial recurrent neural networks trained with back-propagation through time (BPTT) have become a useful tool to study the network mechanisms of brain functions. To support training networks with BPTT, BrainPy provides the brainpy.nn.BPTT method.

Here, we demonstrate how to train an artificial recurrent neural network on a white noise integration task. In this task, we want the trained RNN model to have the ability to integrate white noise. For example, suppose we have a time series of noise data:

noises = bm.random.normal(0, 0.2, size=10)

plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()

Now, we want a model that can compute the integral of the noise, bm.cumsum(noises) * dt:

dt = 0.1
integrals = bm.cumsum(noises) * dt

plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()

Here, we first define a task which generates the input data and the target integration results.

from functools import partial

dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128


@partial(bm.jit,
         dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
         static_argnames=['batch_size'])
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
  # Create the white noise input
  sample = bm.random.normal(size=(batch_size, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  samples = bm.random.normal(size=(batch_size, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets


def train_data():
  for _ in range(100):
    yield build_inputs_and_targets(batch_size=num_batch)

Then, we create and initialize the model. Note that here we want the model to train its initial state, so we set state_trainable=True for the VanillaRNN instance.

model = (
    bp.nn.Input(1)
    >>
    bp.nn.VanillaRNN(100, state_trainable=True)
    >>
    bp.nn.Dense(1)
)
model.initialize(num_batch=num_batch)

The brainpy.nn.BPTT trainer receives a loss function setting and an optimizer setting. The loss function can be selected from the brainpy.losses module, or it can be a callable function that receives (predictions, targets) arguments. The optimizer setting must be an instance of brainpy.optim.Optimizer.

Here we define a loss function that uses the mean squared error (MSE) to measure the error between the targets and the predictions. We also apply an L2 regularization.

# define loss function
def loss(predictions, targets, l2_reg=2e-4):
    mse = bp.losses.mean_squared_error(predictions, targets)
    l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
    return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.nn.BPTT(model,
                     loss=loss,
                     optimizer=opt,
                     max_grad_norm=5.0)
# train the model
trainer.fit(train_data,
            num_batch=num_batch,
            num_train=30,
            num_report=500)
Train 500 steps, use 9.3755 s, train loss 0.03093
Train 1000 steps, use 6.7661 s, train loss 0.0275
Train 1500 steps, use 6.9309 s, train loss 0.02998
Train 2000 steps, use 6.6827 s, train loss 0.02409
Train 2500 steps, use 6.6528 s, train loss 0.02289
Train 3000 steps, use 6.6663 s, train loss 0.02187

The training losses are recorded in the .train_losses attribute.

plt.figure(figsize=(8, 3))
plt.plot(trainer.train_losses.numpy())
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()

Finally, let’s try the trained network, and test whether it can generate the correct integration results.

model.initialize(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()

Further reading#

Analyzing a Dynamical Model#

@Xiaoyu Chen @Chaoming Wang

In BrainPy, defined models can be used not only for simulation but also for automatic dynamics analysis.

BrainPy provides rich interfaces to support analysis, including

Here we will introduce two brief examples: 1-D bifurcation analysis and 2-D phase plane analysis. For more details and more examples, please refer to the tutorials on dynamics analysis.

import brainpy as bp

bp.math.set_platform('cpu')

bp.math.enable_x64()  # Dynamics analysis in BrainPy requires 64-bit computation

Example 1: bifurcation analysis of a 1D model#

Here, we demonstrate how to perform a bifurcation analysis through a one-dimensional neuron model.

Let's try to analyze how the external input influences the dynamics of the Exponential Integrate-and-Fire (ExpIF) model. The ExpIF model is a one-variable neuron model whose dynamics is defined by:

$$
\tau {\dot {V}}= - (V - V_\mathrm{rest}) + \Delta_T \exp(\frac{V - V_T}{\Delta_T}) + RI \\
\mathrm{if}\, \, V > \theta, \quad  V \gets  V_\mathrm{reset}
$$

We can analyze the change of \({\dot {V}}\) with respect to \(V\). First, let’s generate an ExpIF model using pre-defined modules in brainpy.dyn:

expif = bp.dyn.ExpIF(1, delta_T=1.)

The default value of other parameters can be accessed directly by their names:

expif.V_rest, expif.V_T, expif.R, expif.tau
(-65.0, -59.9, 1.0, 10.0)

After defining the model, we can use it for bifurcation analysis.

bif = bp.analysis.Bifurcation1D(
    model=expif,
    target_vars={'V': [-70., -55.]},
    target_pars={'I_ext': [0., 6.]},
    resolutions=0.01
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...

In the Bifurcation1D analyzer, model refers to the model to be analyzed (essentially, the analyzer will access the derivative function in the model), target_vars denotes the target variables, target_pars denotes the changing parameters, and resolutions determines the resolution of the analysis.

In the image above, there are two lines that “merge” together to form a bifurcation. The dots making up the lines refer to the fixed points of \(\mathrm{d}V/\mathrm{d}t\). On the left of the bifurcation point (where two lines merge together), there are two fixed points where \(\mathrm{d}V/\mathrm{d}t = 0\) given each external input \(I_\mathrm{ext}\). One of them is a stable point, and the other is an unstable one. When \(I_\mathrm{ext}\) increases, the two fixed points move closer to each other, overlap, and finally disappear.

Bifurcation analysis provides insights for the dynamics of the model, for it indicates the number and the change of stable states with respect to different parameters.

Example 2: phase plane analysis of a 2D model#

Besides bifurcation analysis, another important tool is phase plane analysis, which displays the trajectory of the variable point in the vector field. Let’s take the FitzHugh–Nagumo (FHN) neuron model as an example. The dynamics of the FHN model is given by:

\[\begin{split} {\dot {v}}=v-{\frac {v^{3}}{3}}-w+I, \\ \tau {\dot {w}}=v+a-bw. \end{split}\]

Users can easily define an FHN model, which is also provided by BrainPy:

fhn = bp.dyn.FHN(1)

Because there are two variables, \(v\) and \(w\), in the FHN model, we shall use 2-D phase plane analysis to visualize how these two variables change over time.

analyzer = bp.analysis.PhasePlane2D(
  model=fhn,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  pars_update={'I_ext': 0.8}, 
  resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
analyzer.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 866 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 V=-0.2738719079879798, w=0.5329731346879486 is a unstable node.
I am plotting the trajectory ...

In the PhasePlane2D analyzer, the parameters model, target_vars, and resolutions are the same as those in Bifurcation1D. pars_update specifies the parameters to be updated during analysis. After defining the analyzer, users can visualize the nullclines, the vector field, the fixed points, and a trajectory in the same figure. The phase plane gives users an intuitive interpretation of the changes of \(v\) and \(w\) guided by the vector field (violet arrows).

Further reading#

  • For more details about how to perform bifurcation analysis and phase plane analysis, please see the tutorial of Low-dimensional Analyzers.

  • A good example of phase plane analysis and bifurcation analysis is the decision-making model; please see the tutorial in Analysis of a Decision-making Model.

  • If you want to know how to analyze the slow points (or fixed points) of your high-dimensional dynamical models, please see the tutorial of High-dimensional Analyzers.

Math Basics#

brainpy.math Overview#

@Chaoming Wang @Xiaoyu Chen

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just in time” for execution. Subsequently, such transformed code can run at native machine-code speed!

Excellent JIT compilers such as JAX and Numba are available in Python. However, they are designed to work only on pure Python functions, while most computational neuroscience models have too many parameters and variables to manage with functions alone. In contrast, object-oriented programming (OOP) based on Python classes makes coding more readable, controllable, flexible, and modular. Therefore, it is necessary to support JIT compilation on class objects for programming in brain modeling.

To provide a platform that satisfies the needs of brain dynamics programming, we provide the brainpy.math module.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
import numpy as np

Why use brainpy.math?#

Specifically, brainpy.math makes the following contributions:

1. Numpy-like ndarray.#

Python users are familiar with NumPy, especially its ndarray. JAX has similar ndarray structures and operations. However, several basic features are fundamentally different from the numpy ndarray. For example, a JAX ndarray does not support in-place mutating updates, like x[i] += y. To overcome these drawbacks, brainpy.math provides JaxArray, which can be used in the same way as a numpy ndarray.
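
The difference can be illustrated with a small sketch (the raw JAX part uses the standard jax.numpy .at interface):

import jax.numpy as jnp
import brainpy.math as bm

x = jnp.arange(5)
x = x.at[0].add(5)   # raw JAX: functional update; plain `x[0] += 5` would raise an error

y = bm.arange(5)
y[0] += 5            # brainpy.math JaxArray: NumPy-style in-place update works directly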

# ndarray in "numpy"

a = np.arange(5)
a
array([0, 1, 2, 3, 4])
a[0] += 5
a
array([5, 1, 2, 3, 4])
# ndarray in "brainpy.math"

b = bm.arange(5)
b
JaxArray([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
JaxArray([5, 1, 2, 3, 4], dtype=int32)

For more details, please see the Tensors tutorial.

2. Numpy-like random sampling.#

JAX has its own style of generating random numbers, which is very different from the original NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling, just like the numpy.random module. For example:

# random sampling in "numpy"

np.random.seed(12345)
np.random.random(5)
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
np.random.normal(0., 2., 5)
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
# random sampling in "brainpy.math.random"

bm.random.seed(12345)
bm.random.random(5)
JaxArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ],            dtype=float32)
bm.random.normal(0., 2., 5)
JaxArray([-1.5375282, -0.5970201, -2.272839 ,  3.233081 , -0.2738593],            dtype=float32)

For more details, please see the Tensors tutorial.

3. JAX transformations on class objects.#

OOP is the essence of Python. However, JAX’s excellent transformations (like JIT compilation) only support pure functions. To make them work with object-oriented coding in brain dynamics programming, brainpy.math extends JAX transformations to Python classes.

Example 1: JIT compilation performed on class objects.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters    
        self.dimension = dimension
    
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u
        
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr1 = LogisticRegression(num_dim)

%timeit lr1(points, labels)
255 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr2 = bm.jit(LogisticRegression(num_dim))

%timeit lr2(points, labels)
162 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Example 2: Autograd performed on variables of a class object.

class Linear(bp.Base):
  def __init__(self, num_hidden, num_input, **kwargs):
    super(Linear, self).__init__(**kwargs)

    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden

    # variables
    self.w = bm.random.random((num_input, num_hidden))
    self.b = bm.zeros((num_hidden,))

  def __call__(self, x):
    r = x @ self.w + self.b
    return r.mean()
l = Linear(num_hidden=3, num_input=2)
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
              [0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32),
 DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32))

What is the difference between brainpy.math and other frameworks?#

brainpy.math is not intended to be a reimplementation of the API of any other frameworks. All we are trying to do is to make a better brain dynamics programming framework for Python users.

However, there are important differences between brainpy.math and other frameworks. As stated above, JAX and many JAX-based frameworks follow a functional programming paradigm. Applying this coding style to brain dynamics models becomes a huge problem due to the overwhelmingly large number of parameters and variables. In contrast, brainpy.math allows an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is Objax, which also supports OOP on top of JAX, but it is geared toward the deep learning domain and cannot be used directly for brain dynamics programming.

Tensors and Variables#

@Xiaoyu Chen @Chaoming Wang

In this section, we will briefly introduce two basic and important data structures: tensors and variables. They form the foundation for the mathematical operations of brain dynamics programming (BDP) in BrainPy.

Tensors#

Definition and Attributes#

A tensor is a data structure that organizes algebraic objects in a multidimensional vector space. Simply speaking, in BrainPy, a tensor is a multidimensional array that contains the same type of data, most commonly of a numeric or boolean type.

The dimensions of an array are called axes. In the following illustration, the 1-D array ([7, 2, 9, 10]) only has one axis. There are 4 elements in this axis, so the shape of the array is (4,).

By contrast, the 2-D array in the illustration has 2 axes. The first axis is of length 2 and the second of length 3. Therefore, the shape of the 2-D array is (2, 3).

Similarly, the 3-D array has 3 axes, with lengths (4, 3, 2) along the axes, respectively.

To enable tensor operations, users should import the brainpy.math module:

import brainpy.math as bm

# bm.set_platform('cpu')
t1 = bm.array([[[0, 1, 2, 3], [1, 2, 3, 4], [4, 5, 6, 7]], 
               [[0, 0, 0, 0], [-1, 1, -1, 1], [2, -2, 2, -2]]])
t1
JaxArray([[[ 0,  1,  2,  3],
           [ 1,  2,  3,  4],
           [ 4,  5,  6,  7]],

          [[ 0,  0,  0,  0],
           [-1,  1, -1,  1],
           [ 2, -2,  2, -2]]], dtype=int32)

Here we create a 3-dimensional tensor with the shape of (2, 3, 4) and the type of int32. Tensors created by brainpy.math will be stored in JaxArray, for their future operations will be accelerated by just-in-time (JIT) compilation.

A tensor has several important attributes:

  • .ndim: the number of axes (dimensions) of the tensor.

  • .shape: the dimensions of the tensor. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, the shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

  • .size: the total number of elements of the tensor. This is equal to the product of the elements of shape.

  • .dtype: an object describing the type of the elements in the tensor. One can create or specify dtypes using standard Python types.

print('t1.ndim: {}'.format(t1.ndim))
print('t1.shape: {}'.format(t1.shape))
print('t1.size: {}'.format(t1.size))
print('t1.dtype: {}'.format(t1.dtype))
t1.ndim: 3
t1.shape: (2, 3, 4)
t1.size: 24
t1.dtype: int32

Below we will give a few examples of tensor operations that are commonly used in brain dynamics programming. For more details about tensor operations, please refer to the tensor tutorial.

Creating a tensor#
t2 = bm.arange(4)
# t2: JaxArray([0, 1, 2, 3], dtype=int32)

t3 = bm.ones((2, 4)) * 1.5
# t3: JaxArray([[1.5, 1.5, 1.5, 1.5],
#               [1.5, 1.5, 1.5, 1.5]], dtype=float32)
Tensor operations#
# indexing and slicing
t3[1]
# DeviceArray([1.5, 1.5, 1.5, 1.5], dtype=float32)

t3[:, 2:]
# DeviceArray([[1.5, 1.5],
#              [1.5, 1.5]], dtype=float32)
DeviceArray([[1.5, 1.5],
             [1.5, 1.5]], dtype=float32)
# algebraic operations
t2 + t3[0]
# JaxArray([1.5, 2.5, 3.5, 4.5], dtype=float32)

t3[0] / t1[0, 1]
# DeviceArray([1.5  , 0.75 , 0.5  , 0.375], dtype=float32)

# broadcasting
t2 + 3
# JaxArray([3, 4, 5, 6], dtype=int32)

t2 + t3
# JaxArray([[1.5, 2.5, 3.5, 4.5],
#           [1.5, 2.5, 3.5, 4.5]], dtype=float32)
JaxArray([[1.5, 2.5, 3.5, 4.5],
          [1.5, 2.5, 3.5, 4.5]], dtype=float32)
# some functions
bm.dot(t2, t3.T)
# JaxArray([9., 9.], dtype=float32)

bm.max(t1, axis=2)
# JaxArray([[3, 4, 7],
#           [0, 1, 2]], dtype=int32)

t3.flatten()
# JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)
JaxArray([1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32)

In BrainPy, tensors can be used to store parameters related to dynamical models. For example, if we define a group of Leaky Integrate-and-Fire (LIF) neurons and wish to assign each neuron a different time constant \(\tau\), we can generate a tensor containing an array of time constants.

n = 6  # assume there are 6 LIF neurons
tau = bm.random.randn(n)*2. + 20.
tau
JaxArray([18.485964, 19.765427, 15.078529, 21.210836, 17.134335,
          21.495173], dtype=float32)

Through the code above, a group of time constants is generated from a normal distribution with a mean of 20 and a standard deviation of 2.

Variables#

We have talked about the definition, operations, and applications of tensors in BrainPy. There are some situations, however, where tensors are not applicable. Due to JIT compilation, once a tensor is given to the JIT compiler, the values inside the tensor cannot be changed. This gives rise to severe limitations, because some properties of a dynamical system, such as the membrane potential, dynamically change over time. Therefore, we need a new data structure to store such dynamic variables, and that is brainpy.math.Variable.

brainpy.math.Variable#

brainpy.math.Variable is a pointer referring to a tensor. The tensor is stored as its value. The data in a Variable can be changed during JIT compilation. If a tensor is labeled as a Variable, it means that it is a dynamical variable that changes over time.

To create or change a tensor into a variable, users just need to wrap the tensor into brainpy.math.Variable:

v = bm.Variable(t2)
v
Variable([0, 1, 2, 3], dtype=int32)

Note that the array is contained in a “Variable” instead of a “JaxArray”.

Note

Tensors that are not marked as Variables will be treated as static data during JIT compilation. As shown in Just-In-Time Compilation, modifications of such static tensors are invalid in a JIT-compilation environment.

Users can access the value in the Variable through its attribute .value:

v.value
DeviceArray([0, 1, 2, 3], dtype=int32)

Since the data inside a Variable is a tensor, common operations on tensors can be directly grafted to Variables.

In-place updating#

Though the operations are the same, there are some requirements for updating a Variable. If we operate on a Variable in the ordinary way, the returned data will be a tensor, not a Variable.

v2 = v + 2
v2
JaxArray([2, 3, 4, 5], dtype=int32)

To update the Variable, users are required to use in-place updating, which only modifies the value inside the Variable but does not change the reference pointing to the Variable. In-place updating operations include:

1. Indexing and slicing

  • Indexing: v[i] = a

  • Slicing: v[i:j] = b

  • Slicing the specific values: v[[1, 3]] = c

  • Slicing all values, v[:] = d, v[...] = e

For more details, please refer to Array Objects Indexing.

v[0] = 10
v[1:3] = 9
v
Variable([10,  9,  9,  3], dtype=int32)

2. Augmented assignment

  • += (add)

  • -= (subtract)

  • /= (divide)

  • *= (multiply)

  • //= (floor divide)

  • %= (modulo)

  • **= (power)

  • &= (and)

  • |= (or)

  • ^= (xor)

  • <<= (left shift)

  • >>= (right shift)

v -= 3
v <<= 1
v
Variable([14, 12, 12,  0], dtype=int32)

3. .value assignment

v.value = bm.arange(4)
v
Variable([0, 1, 2, 3], dtype=int32)

.value assignment directly accesses the data stored in the JaxArray. When using .value, the new data should be of the same type and shape as the original ones.

try:
    v.value = bm.array([1., 1., 1., 0.])
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
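
A minimal fix for the failure above is to provide data with the matching integer dtype (or to cast it explicitly before the assignment):

v.value = bm.array([1, 1, 1, 0])   # integer data keeps the original int32 dtype (with x64 disabled)
v
# expected: Variable([1, 1, 1, 0], dtype=int32)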

4. .update() method

This method will also check if the new data is of the same type and shape as the original ones.

v.update(bm.array([3, 4, 5, 6]))
v
Variable([3, 4, 5, 6], dtype=int32)

For more details, such as the subtypes of Variables and more information about in-place updating, please see the advanced tutorial for Variables.

Just-In-Time Compilation#

@Chaoming Wang @Xiaoyu Chen

One of the core ideas of BrainPy is Just-In-Time (JIT) compilation. JIT compilation enables Python code to be compiled into machine code “just in time” for execution. Subsequently, such transformed code can run at native machine-code speed, which not only compensates for the time spent on code transformation but also saves more time. Therefore, it is necessary to understand how to code in a JIT-compatible environment.

This section will briefly introduce JIT compilation and its relation to BrainPy. For more details such as the JIT mechanism in BrainPy, please refer to the advanced Compilation tutorial.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

JIT Compilation for Functions#

To take advantage of the JIT compilation, users just need to wrap their customized functions or objects into bm.jit() to instruct BrainPy to transform Python code into machine code.

Take pure functions as an example. Here we implement the Gaussian Error Linear Unit (GELU) function:

def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y

Let’s first try to run the function without JIT.

x = bm.random.random(100000)
%timeit gelu(x)
295 µs ± 3.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

After JIT compilation, the function significantly speeds up.

gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
66 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

JIT Compilation for Objects#

JIT compilation for functions is not enough for brain dynamics programming, since a multitude of dynamic variables and differential equations in a large system would make computation surprisingly complicated. Therefore, BrainPy enables JIT compilation to be performed on class objects, as long as users comply with the following rules:

  1. The class object must be a subclass of brainpy.Base.

  2. Dynamically changed variables must be labeled as brainpy.math.Variable.

  3. Variable updating must be accomplished by in-place operations.

Below is a simple example of a logistic regression classifier. When wrapped into bm.jit(), the class object will be JIT compiled.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters    
        self.dimension = dimension
    
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w.value = self.w - u

In this example, the model weights (self.w) will be modified during training, so they are marked as a bm.Variable. If not marked, then in the compilation phase, all attributes accessed through self. that are not instances of bm.Variable will be compiled as static constants.
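
The following minimal sketch (a hypothetical Counter class, not part of BrainPy itself) illustrates this point: the bm.Variable attribute keeps accumulating across JIT-compiled calls.

class Counter(bp.Base):
    def __init__(self):
        super(Counter, self).__init__()
        self.count = bm.Variable(bm.zeros(1))   # dynamic state must be a Variable

    def __call__(self):
        self.count += 1                          # valid in-place update of a Variable

counter = Counter()
jitted = bm.jit(counter)
jitted()
jitted()
print(counter.count)   # expected: Variable([2.], dtype=float32)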

import time

def benchmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for i in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
# without JIT

lr1 = LogisticRegression(num_dim)

benchmark(lr1, points, labels, name='Logistic Regression (without jit)')
Logistic Regression (without jit) used time 10.024710893630981 s
# with JIT

lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benchmark(lr2, points, labels, name='Logistic Regression (with jit)')
Logistic Regression (with jit) used time 5.015154838562012 s

From the above example, we can appreciate the acceleration of JIT compilation. This example, however, is too simplified to show the great difference between running with and without JIT. In fact, in a large brain model, the acceleration brought by JIT compilation is usually far more significant.

Automatic JIT Compilation in Runners#

In a large dynamical system where a large number of neurons and synapses are defined, it would be a little tedious to explicitly wrap every object into bm.jit(). Fortunately, in most conditions, users do not need to call bm.jit(), as BrainPy will make JIT compilation automatically.

BrainPy provides a brainpy.Runner class that is inherited by various runners used in simulation, training, and integration. When initializing a runner, it receives a parameter named jit, which is set to True by default. This means that the runner will automatically JIT compile the target object as long as it is wrapped into the runner.

For example, when users perform dynamic simulation on a HH model, they first need to wrap the model into a simulation runner:

model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.))
runner(1000)  # running 1000 ms
0.6139698028564453

Here, the model is wrapped into a runner and will be JIT compiled during the simulation.

If users do not want to use JIT compilation (JIT compilation prohibits Python debugging), they can turn it off by setting jit=False:

model = bp.dyn.HH(1000)
runner = bp.DSRunner(target=model, inputs=('input', 10.), jit=False)
runner(1000)
258.76088523864746

The output is the time (s) spent on simulation. We can see that the simulation is much slower without JIT compilation.

Besides simulation, runners are also used by integrators and trainers. For more details, please refer to the tutorial of runners.

Control Flows#

@Chaoming Wang

Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.

Python has two types of control structures:

  • Selection: used for decisions and branching.

  • Repetition: used for looping, i.e., repeating a piece of code multiple times.

In this section, we are going to talk about how to build effective control flows with BrainPy and JAX.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

1. Selection#

In Python, selection statements are also known as decision control statements or branching statements. A selection statement allows a program to test several conditions and execute instructions based on which condition is true. Commonly used selection statements include:

  • if-else

  • nested if

  • if-elif-else

Non-Variable-based control statements#

Actually, BrainPy (based on JAX) allows you to write control flows in the normal Python way when the conditional statement depends on non-Variable instances. For example,

class OddEven(bp.Base):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a

In the above example, the condition in the if statement relies on a scalar, which is not an instance of brainpy.math.Variable. In this case, the conditional statements can be arbitrarily complex. You can write your models with normal Python code, and these models will work very well with JIT compilation.

model = bm.jit(OddEven(type_=1))

model()
Variable([1.], dtype=float32)
model = bm.jit(OddEven(type_=2))

model()
Variable([-1.], dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
Variable-based control statements#

However, if the condition in an if ... else ... statement relies on instances of brainpy.math.Variable, writing Pythonic control flows will cause errors when using JIT compilation.

class OddEvenCauseError(bp.Base):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.rand < 0.5:  self.a += 1
        else:  self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())

try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: This problem may be caused by several ways:
1. Your if-else conditional statement relies on instances of brainpy.math.Variable. 
2. Your if-else conditional statement relies on functional arguments which do not set in "static_argnames" when applying JIT compilation. More details please see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
3. The static variables which set in the "static_argnames" are provided as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].

To perform conditional statements on Variable instances, we need structural control flow syntax. Specifically, BrainPy provides several options (based on JAX):

brainpy.math.where#

The where(condition, x, y) function returns elements chosen from x or y depending on condition. It performs well on scalars, vectors, and high-dimensional arrays.

a = 1.
bm.where(a < 0, 0., 1.)
JaxArray(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
JaxArray([1., 0., 0., 1., 1.], dtype=float32, weak_type=True)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
JaxArray([[0., 0., 1.],
          [1., 1., 0.],
          [0., 0., 0.]], dtype=float32, weak_type=True)

For the above example, we can rewrite it by using where syntax as:

class OddEvenWhere(bp.Base):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable([-1.], dtype=float32)
brainpy.math.ifelse#

Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.

In its simplest case, brainpy.math.ifelse(condition, branches, operands, dyn_vars=None) is equivalent to:

def ifelse(condition, branches, operands, dyn_vars=None):
  true_fun, false_fun = branches
  if condition:
    return true_fun(operands)
  else:
    return false_fun(operands)

Based on this function, we can rewrite the above example by using cond syntax as:

class OddEvenCond(bp.Base):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda _: 1., lambda _: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable([1.], dtype=float32)

If you want to write control flows with multiple branches, brainpy.math.ifelse(conditions, branches, operands, dyn_vars=None) can also help you accomplish this easily. Actually, the multiple-branching case is equivalent to:

def ifelse(conditions, branches, operands, dyn_vars=None):
  pred1, pred2, ... = conditions
  func1, func2, ..., funcN = branches
  if pred1:
    return func1(operands)
  elif pred2:
    return func2(operands)
  ...
  else:
    return funcN(operands)

For example, if you have the following code:

def f(a):
  if a > 10:
    return 1.
  elif a > 5:
    return 2.
  elif a > 0:
    return 3.
  elif a > -5:
    return 4.
  else:
    return 5.

It can be expressed as:

def f(a):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[1., 2., 3., 4., 5.])
f(11.)
DeviceArray(1., dtype=float32, weak_type=True)
f(6.)
DeviceArray(2., dtype=float32, weak_type=True)
f(1.)
DeviceArray(3., dtype=float32, weak_type=True)
f(-4.)
DeviceArray(4., dtype=float32, weak_type=True)
f(-6.)
DeviceArray(5., dtype=float32, weak_type=True)

A more complex example is:

def f2(a, x):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[lambda x: x*2,
                             2.,
                             lambda x: x**2 -1,
                             lambda x: x - 4.,
                             5.],
                   operands=x)
f2(11, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(6, 1.)
DeviceArray(2., dtype=float32, weak_type=True)
f2(1, 1.)
DeviceArray(0., dtype=float32, weak_type=True)
f2(-4, 1.)
DeviceArray(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
DeviceArray(5., dtype=float32, weak_type=True)

If instances of brainpy.math.Variable are used in branching functions, you can declare them in the dyn_vars argument.

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x):  a.value += 1
def false_f(x): b.value -= 1

bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])

print('a:', a)
print('b:', b)
a: Variable([1., 1.], dtype=float32)
b: Variable([0., 0.], dtype=float32)

2. Repetition#

A repetition statement is used to repeat a group (block) of programming instructions.

In Python, we generally have two loops/repetitive statements:

  • for loop: Execute a set of statements once for each item in a sequence.

  • while loop: Execute a block of statements repeatedly until a given condition is satisfied.

Pythonic loop syntax#

Actually, JAX enables you to write Pythonic loops. You just need to iterate over your sequence data and then apply your logic to the iterated items. This kind of Pythonic loop syntax is compatible with JIT compilation, but it can take a long time to trace and compile. For example,

class LoopSimple(bp.Base):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time

def measure_time(f):
    t0 = time.time()
    r = f()
    t1 = time.time()
    print(f'Result: {r}, Time: {t1 - t0}')
model = bm.jit(LoopSimple())

# First time will trigger compilation
measure_time(model)
Result: [501.74673], Time: 2.7157142162323
# Second running
measure_time(model)
Result: [1003.49347], Time: 0.0

When the model is complex and the iteration is long, the compilation during the first running will become unbearable. For such cases, you need structural loop syntax.

JAX provides several important structural loop primitives, including jax.lax.fori_loop, jax.lax.while_loop, and jax.lax.scan.

BrainPy also provides its own loop syntax, which is especially suitable for cases where users are using brainpy.math.Variable. Specifically, it includes brainpy.math.make_loop() and brainpy.math.make_while(), introduced below.

In this section, we only talk about how to use our provided loop functions.

brainpy.math.make_loop()#

brainpy.math.make_loop() is used to generate a for-loop function when you use Variable.

Suppose that you are using several JaxArrays (grouped as dyn_vars) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars). Sometimes the body function already returns something, and you also want to gather the returned values. With Python syntax, this can be realized as


def for_loop_function(body_fun, dyn_vars, out_vars, xs):
  ys = []
  for x in xs:
    # 'dyn_vars' and 'out_vars' are updated in 'body_fun()'
    results = body_fun(x)
    ys.append([out_vars, results])
  return ys

In BrainPy, you can define this logic using brainpy.math.make_loop():


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)

hist_of_out_vars = loop_fun(xs)

Or,


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)

hist_of_out_vars, hist_of_return_vars = loop_fun(xs)

For the above example, we can rewrite it by using brainpy.math.make_loop as:

class LoopStruct(bp.Base):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

        def add(s): self.res += s
        self.loop = bm.make_loop(add, dyn_vars=[self.res])

    def __call__(self):
        self.loop(self.seq)
        return self.res.value
model = bm.jit(LoopStruct())

# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 0.028011560440063477
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0
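If you also want to gather the history of the running sum at every loop step, the out_vars argument introduced above can be used. Below is a minimal, hedged sketch written at the function level (the variable names are ours, and the exact structure of the returned history follows the out_vars you pass):

rng = bm.random.RandomState(123)
seq = rng.random(1000)
res = bm.Variable(bm.zeros(1))

def add(s):
    res.value = res + s   # in-place update of the Variable

loop_with_history = bm.make_loop(add, dyn_vars=[res], out_vars=[res])

hist = loop_with_history(seq)   # history of `res`, one entry per element of `seq`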
brainpy.math.make_while()#

brainpy.math.make_while() is used to generate a while-loop function when you use JaxArray. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.make_while() , condition should be wrapped as a cond_fun function which returns a boolean value, and statements should be packed as a body_fun function which does not support returned values:


while cond_fun(x):
    body_fun(x)

where x is the external input that is not iterated. All the iterated variables should be marked as JaxArray. All JaxArrays used in cond_fun and body_fun should be declared as dyn_vars variables.

Let’s look at an example:

i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f(x):
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])

In the above example, we try to implement a sum from 0 to 10 by using two JaxArrays i and counter.

loop()
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)

Dynamics Simulation#

Dynamical System Specification#

@Tianqiu Zhang @Chaoming Wang

BrainPy enables modular programming and easy model debugging. To build a complex brain dynamics model, you just need to group its building blocks. In this section, we are going to talk about what building blocks we have provided, and how to use these building blocks.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

Models in brainpy.dyn#

brainpy.dyn has provided many convenient neuron, synapse, and other models for users. The following figure is a glimpse of the provided models.

The arrows in the graph represent the inheritance relations between different models.

New models will be continuously updated in the page of API documentation.

Initializing a neuron model#

All neuron models implemented in brainpy are subclasses of brainpy.dyn.NeuGroup. The initialization of a neuron model just needs to provide the geometry size of neurons in a population group.

hh = bp.dyn.HH(size=1)  # only 1 neuron

hh = bp.dyn.HH(size=10)  # 10 neurons in a group

hh = bp.dyn.HH(size=(10, 10))  # a grid of (10, 10) neurons in a group

hh = bp.dyn.HH(size=(5, 4, 2))  # a column of (5, 4, 2) neurons in a group

Generally speaking, there are two types of arguments that can be set by users:

  • parameters: the model parameters, like gNa, which refers to the maximum conductance of the sodium channel in the brainpy.dyn.HH model.

  • variables: the model variables, like V, which refers to the membrane potential of a neuron model.

By default, model parameters are homogeneous, i.e., they are just scalar values.

hh = bp.dyn.HH(5)  # there are five neurons in this group

hh.gNa
120.0

However, neuron models support heterogeneous parameters when performing computations in a neuron group. One can initialize heterogeneous parameters in several ways.

1. Tensor

Users can directly provide a tensor as the parameter.

hh = bp.dyn.HH(5, gNa=bm.random.uniform(110, 130, size=5))

hh.gNa
JaxArray([114.53795, 127.13995, 119.036  , 110.91665, 117.91266], dtype=float32)

2. Initializer

BrainPy provides rich support for initializations. One can provide an initializer to the parameter to instruct the model to initialize heterogeneous parameters.

hh = bp.dyn.HH(5, ENa=bp.init.OneInit(50.))

hh.ENa
JaxArray([50., 50., 50., 50., 50.], dtype=float32)

3. Callable function

You can also directly provide a callable function which receives a shape argument.

hh = bp.dyn.HH(5, ENa=lambda shape: bm.random.uniform(40, 60, shape))

hh.ENa
JaxArray([52.201824, 52.322166, 44.033783, 47.943596, 54.985268], dtype=float32)

Here, let’s see how the heterogeneous parameters influence our model simulation.

# we create 3 neurons in a group. Each neuron has a unique "gNa"

model = bp.dyn.HH(3, gNa=bp.init.Uniform(min_val=100, max_val=140))
runner = bp.dyn.DSRunner(model, monitors=['V'], inputs=['input', 5.])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 2], show=True)
_images/26ea3542f81b10db73896f0eb858c152d4ad9156f3e4ab82d1f323bff7d87ad6.png

Similarly, the initial values of a variable can also be set through the above three ways: Tensor, Initializer, and Callable function. For example,

hh = bp.dyn.HH(
   3,
   V_initializer=bp.init.Uniform(-80., -60.),  # Initializer
   m_initializer=lambda shape: bm.random.random(shape),  # function
   h_initializer=bm.random.random(3),  # Tensor
)
print('V: ', hh.V)
print('m: ', hh.m)
print('h: ', hh.h)
V:  Variable([-77.707954, -73.94804 , -69.09014 ], dtype=float32)
m:  Variable([0.4219371, 0.5383264, 0.8984035], dtype=float32)
h:  Variable([0.61493886, 0.81473637, 0.3291837 ], dtype=float32)

Initializing a synapse model#

Initializing a synapse model requires providing its pre-synaptic group (pre), post-synaptic group (post), and the connection method between them (conn). Below is an example of creating an Exponential synapse model:

neu = bp.dyn.LIF(10)

# here we create a synaptic projection within a population
syn = bp.dyn.ExpCUBA(pre=neu, post=neu, conn=bp.conn.All2All())

BrainPy’s built-in synapse models support heterogeneous synaptic weights and delay steps by using Tensor, Initializer, and Callable function. For example,

syn = bp.dyn.ExpCUBA(neu, neu, bp.conn.FixedProb(prob=0.1),
                     g_max=bp.init.Uniform(min_val=0.1, max_val=1.),
                     delay_step=lambda shape: bm.random.randint(10, 30, shape))
syn.g_max
JaxArray([0.9790364 , 0.18719104, 0.84017825, 0.31185275, 0.38157037,
          0.80953383, 0.61926776, 0.73845625, 0.9679548 , 0.385096  ,
          0.91454816], dtype=float32)
syn.delay_step
JaxArray([18, 19, 15, 21, 17, 24, 10, 27, 12, 20], dtype=int32)

However, in BrainPy, the built-in synapse models only support homogeneous synaptic parameters, like the time constant \(\tau\). Users can customize their synaptic models when they want heterogeneous synaptic parameters.

Similarly, the synaptic variables can be initialized heterogeneously by using Tensor, Initializer, and Callable functions.

Changing model parameters during simulation#

In BrainPy, all dynamically changed variables (no matter whether they are changed inside or outside a JIT-compiled function) should be marked as brainpy.math.Variable. BrainPy’s built-in models also support modifying model parameters during simulation.

For example, suppose you want to fix gNa during the first 100 ms of the simulation and then decrease its value in the following simulations. In this case, we can provide gNa as an instance of brainpy.math.Variable when initializing the model.

hh = bp.dyn.HH(5, gNa=bm.Variable(bm.asarray([120.])))

runner = bp.dyn.DSRunner(hh, monitors=['V'], inputs=['input', 5.])
# the first running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/7ee69a0a4c67a1100a5e8ba14b2ad965d0314b86ed6a81af89e8bf20c2f70534.png
# change the gNa first
hh.gNa[:] = 100.

# the second running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/44a021c45bb7d1a7c50bfad36e8d2b7ef439831aaea58bbd49da932509d3ab2a.png

Examples of using built-in models#

Here we show users how to simulate a famous neuron model: the Morris-Lecar neuron model, which is a two-dimensional “reduced” excitation model applicable to systems having two non-inactivating voltage-sensitive conductances.

group = bp.dyn.MorrisLecar(1)

Then users can utilize various tools provided by BrainPy to easily simulate the Morris-Lecar neuron model. Here we are not going to dive into details so please read the corresponding tutorials if you want to learn more.

runner = bp.dyn.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.))
runner.run(1000)

fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)
_images/522c9c7ba066101ab1ad96b035b7e9ff5b5439ddbd3cab29c19a1a863a195f83.png

Next, we will also give users an intuitive understanding of how to build a network composed of different neuron and synapse models. Users can simply initialize these models as below and pass them into brainpy.dyn.Network.

neu1 = bp.dyn.HH(1)
neu2 = bp.dyn.HH(1)
syn1 = bp.dyn.AMPA(neu1, neu2, bp.connect.All2All())
net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)

By selecting a proper runner, users can simulate the network efficiently and plot the simulation results.

runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
runner.run(150.)

import matplotlib.pyplot as plt

fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
plt.legend()

fig.add_subplot(gs[1, 0])
plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
plt.legend()
plt.show()
_images/9a368e2e575981bcba4dbb8ae3f3365fcc468b3326027dc5825c559722a1468f.png

Building Neuron Models#

@Xiaoyu Chen @Chaoming Wang

The previous section shows all the available models users can utilize by simply instantiating the abstract model. In the following sections, we will dive into the details to illustrate how to build a neuron model with brainpy.dyn.NeuGroup. Neurons are the most basic components in neural dynamics simulation. In BrainPy, brainpy.dyn.NeuGroup is used for neuron modeling.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.dyn.NeuGroup#

Generally, any neuron model can evolve continuously or discontinuously. Discontinuous evolution may be triggered by events, such as the reset of the membrane potential. Moreover, it is common in a neural system that a dynamical system has different states, such as the excitable or refractory state in a leaky integrate-and-fire (LIF) model. In this section, we will use two examples to illustrate how to capture this complexity in neuron modeling.

Defining a neuron model in BrainPy is simple. You just need to inherit from brainpy.dyn.NeuGroup, and satisfy the following two requirements:

  • Providing the size of the neural group in the constructor when initializing a new neural group class. size can be an integer referring to the number of neurons or a tuple/list of integers referring to the geometry of the neural group in different dimensions. According to the provided group size, NeuGroup will automatically calculate the total number num of neurons in this group.

  • Creating an update(_t, _dt) function. The update function provides the rule for how the neuron states evolve from the current time \(\mathrm{\_t}\) to the next time \(\mathrm{\_t + \_dt}\).

In the following part, a Hodgkin-Huxley (HH) model is used as an example for illustration.

Hodgkin–Huxley Model#

The Hodgkin-Huxley (HH) model is a continuous-time dynamical system. It is one of the most successful mathematical models of a complex biological process that has ever been formulated. Changes of the membrane potential influence the conductances of different channels, elaborately modeling the neural activities in biological systems. Mathematically, the model is given by:

\[\begin{split} \begin{aligned} C_m \frac {dV} {dt} &= -(\bar{g}_{Na} m^3 h (V -E_{Na}) + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) \quad\quad(1) \\ \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x x, \quad x\in {\rm{\{m, h, n\}}} \quad\quad(2) \\ &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} \quad\quad(3) \\ &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) \quad\quad(4) \\ &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) \quad\quad(5) \\ &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} \quad\quad(6) \\ &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} \quad\quad(7) \\ &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) \quad\quad(8) \\ \end{aligned} \end{split}\]

where \(V\) is the membrane potential, \(C_m\) is the membrane capacitance per unit area, \(E_K\) and \(E_{Na}\) are the potassium and sodium reversal potentials, respectively, \(E_{leak}\) is the leak reversal potential, \(\bar{g}_K\) and \(\bar{g}_{Na}\) are the potassium and sodium conductances per unit area, respectively, and \(g_{leak}\) is the leak conductance per unit area. Because the potassium and sodium channels are voltage-sensitive, according to the biological experiments, \(m\), \(n\) and \(h\) are used to simulate the activation of the channels. Specifically, \(n\) measures the activation of potassium channels, and \(m\) and \(h\) measure the activation and inactivation of sodium channels, respectively. \(\alpha_{x}\) and \(\beta_{x}\) are rate constants for the ion channel \(x\) and depend exclusively on the membrane potential.

To implement the HH model, variables should be specified. According to the above equations, the following five state variables change with respect to time:

  • V: the membrane potential

  • m: the activation of sodium channels

  • h: the inactivation of sodium channels

  • n: the activation of potassium channels

  • input: the external/synaptic input

Besides, the spiking state and the last spiking time can also be recorded for statistical analysis:

  • spike: whether a spike is produced

  • t_last_spike: the last spiking time

Based on these state variables, the HH model can be implemented as below.

class HH(bp.dyn.NeuGroup):
  def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
               V_th=20., C=1.0, **kwargs):
    # providing the group "size" information
    super(HH, self).__init__(size=size, **kwargs)

    # initialize parameters
    self.ENa = ENa
    self.EK = EK
    self.EL = EL
    self.gNa = gNa
    self.gK = gK
    self.gL = gL
    self.C = C
    self.V_th = V_th

    # initialize variables
    self.V = bm.Variable(bm.random.randn(self.num) - 70.)
    self.m = bm.Variable(0.5 * bm.ones(self.num))
    self.h = bm.Variable(0.6 * bm.ones(self.num))
    self.n = bm.Variable(0.32 * bm.ones(self.num))
    self.input = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

    # integral functions
    self.int_V = bp.odeint(f=self.dV, method='exp_auto')
    self.int_m = bp.odeint(f=self.dm, method='exp_auto')
    self.int_h = bp.odeint(f=self.dh, method='exp_auto')
    self.int_n = bp.odeint(f=self.dn, method='exp_auto')

  def dV(self, V, t, m, h, n, Iext):
    I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    I_K = (self.gK * n ** 4.0) * (V - self.EK)
    I_leak = self.gL * (V - self.EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
    return dVdt

  def dm(self, m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    return dmdt
  
  def dh(self, h, t, V):
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    return dhdt

  def dn(self, n, t, V):
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    return dndt

  def update(self, _t, _dt):
    # compute V, m, h, n
    V = self.int_V(self.V, _t, self.m, self.h, self.n, self.input, dt=_dt)
    self.h.value = self.int_h(self.h, _t, self.V, dt=_dt)
    self.m.value = self.int_m(self.m, _t, self.V, dt=_dt)
    self.n.value = self.int_n(self.n, _t, self.V, dt=_dt)

    # update the spiking state and the last spiking time
    self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
    self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)

    # update V
    self.V.value = V

    # reset the external input
    self.input[:] = 0.

When defining the HH model, the differential equations are solved with brainpy.odeint, which wraps each derivative function into an ODEIntegrator. The details are given in the Numerical Solvers for ODEs tutorial.
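As a side note on the brainpy.odeint interface used above, the following minimal sketch (with an arbitrary toy equation of our own) shows how an integrator is created and then advanced by one step with a given dt, mirroring the calls made inside update():

# toy decay equation dv/dt = -v / tau, integrated with the same interface as dV, dm, dh, dn above
def dv(v, t, tau):
  return -v / tau

integral = bp.odeint(f=dv, method='exp_auto')

v_next = integral(1.0, 0.0, 10.0, dt=0.1)  # advance v=1.0 at t=0.0 with tau=10.0 by one 0.1 ms step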

The variables that will be updated during dynamics simulation should be packed as brainpy.math.Variable so that they can be processed by JIT compilers to accelerate the simulation.

In the following part, a leaky integrate-and-fire (LIF) model is introduced as another example for illustration.

Leaky Integrate-and-Fire Model#

The LIF model is a classical neuron model which contains a continuous process and a discontinuous spike-reset operation. Formally, it is given by:

\[\begin{split} \begin{aligned} \tau_m \frac{dV}{dt} = - (V(t) - V_{rest}) + R I(t) \quad\quad (1) \\ \text{if} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{after} \, \tau_{ref} \, \text{ms} \quad\quad (2) \end{aligned} \end{split}\]

where \(V\) is the membrane potential, \(V_{rest}\) is the rest membrane potential, \(V_{reset}\) is the reset potential, \(V_{th}\) is the spike threshold, \(R\) is the membrane resistance, \(\tau_m\) is the time constant, \(\tau_{ref}\) is the refractory time period, and \(I\) is the time-variant synaptic input.

The above two equations model the continuous change and the spiking of neurons, respectively. Moreover, the model has multiple states: the subthreshold state, and the spiking or refractory state. The membrane potential \(V\) is integrated according to equation (1) when it is below \(V_{th}\). Once \(V\) reaches the threshold \(V_{th}\), according to equation (2), \(V\) is reset to \(V_{reset}\), and the neuron enters the refractory period, where the membrane potential \(V\) will remain constant for the following \(\tau_{ref}\) ms.

The neuronal variables, like the membrane potential and external input, can be captured by the following two variables:

  • V: the membrane potential

  • input: the external/synaptic input

In order to define the different states of a LIF neuron, we define additional variables:

  • spike: whether a spike is produced

  • refractory: whether the neuron is in the refractory period

  • t_last_spike: the last spiking time

Based on these state variables, the LIF model can be implemented as below.

class LIF(bp.dyn.NeuGroup):
  def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., R=1., tau=10., t_ref=5., **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)

    # initialize parameters
    self.V_rest = V_rest
    self.V_reset = V_reset
    self.V_th = V_th
    self.R = R
    self.tau = tau
    self.t_ref = t_ref

    # initialize variables
    self.V = bm.Variable(bm.random.randn(self.num) + V_reset)
    self.input = bm.Variable(bm.zeros(self.num))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # integral function
    self.integral = bp.odeint(f=self.derivative, method='exp_auto')

  def derivative(self, V, t, Iext):
    dvdt = (-V + self.V_rest + self.R * Iext) / self.tau
    return dvdt

  def update(self, _t, _dt):
    # Whether the neurons are in the refractory period
    refractory = (_t - self.t_last_spike) <= self.t_ref
    
    # compute the membrane potential
    V = self.integral(self.V, _t, self.input, dt=_dt)
    
    # computed membrane potential is valid only when the neuron is not in the refractory period 
    V = bm.where(refractory, self.V, V)
    
    # update the spiking state
    spike = self.V_th <= V
    self.spike.value = spike
    
    # update the last spiking time
    self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
    
    # update the membrane potential and reset spiked neurons
    self.V.value = bm.where(spike, self.V_reset, V)
    
    # update the refractory state
    self.refractory.value = bm.logical_or(refractory, spike)
    
    # reset the external input
    self.input[:] = 0.

In the above, the discontinuous resetting is implemented with the brainpy.math.where operation.

Instantiation and running#

Here, let’s try to instantiate a HH neuron group:

neu = HH(10)

in which a neural group containing 10 HH neurons is generated.

The details of the model simulation will be expanded in the Runners section. In brief, running any dynamical system instance should be accomplished with a runner, such as brainpy.DSRunner and brainpy.ReportRunner. The variables to be monitored and the input currents to be applied during the simulation can be provided when initializing the runner. The details are accessible in Monitors and Inputs.

runner = bp.dyn.DSRunner(
    neu, 
    monitors=['V'], 
    inputs=('input', 22.)  # constant external inputs of 22 mA to all neurons
)

Then the simulation can be performed with a given time period, and the simulation result can be visualized:

runner.run(200)  # the running time is 200 ms

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/8901e068faf7f731e655ca2ee003d398e85542ba560e2921a8baa7ceff4ac3d4.png

A LIF neural group can be instantiated and applied in simulation in a similar way:

group = LIF(10)

runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 22.), jit=True)
runner.run(200)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/67b3aa7a0437c8bc5fa85e27c29bd05f9fef5f3282716195abc3fa1974e33599.png

Building Synapse Models#

@Chaoming Wang @Xiaoyu Chen

Synaptic computation is the core of brain dynamics programming. This is because, in a real project, most of the simulation time is spent on the computation of synapses. In order to achieve efficient synaptic computation, BrainPy provides much useful support. Here, we are going to explore the details of this support.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')

Synapse Models in Math#

Before we talk about the implementation of synapses in BrainPy, it’s better to understand the targets (synapse models) we are going to implement. For different illustration purposes, we are going to implement two synapse models: exponential synapse model and AMPA synapse model.

1. The exponential synapse model#

The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state arises instantaneously, then decays with a certain time constant \(\tau_{decay}\). Its dynamics is given by:

\[ \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-D-t^{k}) \]

where \(g\) is the synaptic state, \(t^{k}\) is the spike time of the pre-synaptic neuron, and \(D\) is the synaptic delay.

Afterward, the current output onto the post-synaptic neuron is given in the conductance-based form:

\[ I_{syn}(t) = g_{max} g \left( V-E \right) \]

where \(E\) is the reversal potential of the synapse, \(V\) is the post-synaptic membrane potential, and \(g_{max}\) is the maximum synaptic conductance.

2. The AMPA synapse model#

A classical model of the AMPA synapse uses a Markov process to model the ion channel switching. Here \(g\) represents the probability of channel opening, \(1-g\) represents the probability of ion channel closing, and \(\alpha\) and \(\beta\) are the transition probabilities. Specifically, its formula is given by

\[ \frac{dg}{dt} =\alpha[T](1-g)-\beta g \]

where \(\alpha [T]\) denotes the transition probability from state \((1-g)\) to state \((g)\), and \(\beta\) represents the transition probability in the other direction. \(\alpha\) is the binding constant, \(\beta\) is the unbinding constant, and \([T]\) is the neurotransmitter concentration, which has a duration of 0.5 ms.

Moreover, the post-synaptic current on the post-synaptic neuron is formulated as

\[I_{syn} = g_{max} g (V-E)\]

where \(g_{max}\) is the maximum conductance, and \(E\) is the reverse potential.

Synapse Models in Silicon#

The implementation of synapse models is accomplished through the brainpy.dyn.TwoEndConn interface. In this section, we talk about what support is provided for implementing synapse models in silicon.

1. brainpy.dyn.TwoEndConn#

In BrainPy, brainpy.dyn.TwoEndConn is used to model two-end synaptic computations.

To define a synapse model, two requirements should be satisfied:

1. Constructor function __init__(), in which three key arguments are needed.

  • pre: the pre-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.

  • post: the post-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.

  • conn (optional): the connection type between these two groups. BrainPy has provided abundant connection types that are described in detail in Synaptic Connections.

2. Update function update(_t, _dt) describes the updating rule from the current time \(\mathrm{\_t}\) to the next time \(\mathrm{\_t + \_dt}\).
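Putting these two requirements together, a minimal skeleton might look like the following (the class name DummySyn is hypothetical and the update rule intentionally does nothing; it only shows the required structure):

class DummySyn(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn):
    super(DummySyn, self).__init__(pre=pre, post=post, conn=conn)
    # the parameters and variables of a real synapse model would be created here

  def update(self, _t, _dt):
    # a real model would evolve its synaptic states here and add the resulting
    # current to self.post.input; this skeleton intentionally does nothing
    pass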

2. Variable delays#

As seen in the above two synapse models, synaptic computations usually involve variable delays. A delay time (typically 0.3–0.5 ms) is usually required for a neurotransmitter to be released from a presynaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane.

BrainPy provides several kinds of delay variables for users, including:

  • brainpy.math.LengthDelay: a delay variable which defines a constant number of delay steps.

  • brainpy.math.TimeDelay: a delay variable which defines a constant time length for delay.

Assume here we need a delay variable which has 1 ms delay. If the numerical integration precision dt is 0.1 ms, then we can create a brainpy.math.LengthDelay which has 10 delay time steps.

target_data_to_delay = bm.Variable(bm.zeros(10))

example_delay = bm.LengthDelay(target_data_to_delay,
                               delay_len=10)  # delay 10 steps
example_delay(5)  # call the delay data at 5 delay step
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(10)  # call the delay data at 10 delay step
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

Alternatively, we can create an instance of brainpy.math.TimeDelay, which uses the time t as the index to retrieve the delayed data.

t0 = 0.
example_delay = bm.TimeDelay(target_data_to_delay,
                             delay_len=1.0, t0=t0)  # delay 1.0 ms
example_delay(t0 - 1.0)  # the delay data at t-1. ms
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(t0 - 0.5)  # the delay data at t-0.5 ms
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
3. Synaptic connections#

Synaptic computations usually need to create connections between groups. BrainPy provides rich support for constructing synaptic connections. Simply speaking, brainpy.conn.Connector can create the various data structures you want through the require() function. Take the random connection brainpy.conn.FixedProb, which will be used in the following, as an example:

example_conn = bp.conn.FixedProb(0.2)(pre_size=(5,), post_size=(8, ))

we can require the connection matrix (which has the shape of (num_pre, num_post)):

example_conn.require('conn_mat')
JaxArray([[False, False, False, False,  True, False, False, False],
          [False, False, False, False, False, False,  True, False],
          [False, False, False, False, False,  True, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False]],            dtype=bool)

we can also require the connected indices of pre-synaptic neurons (pre_ids) and post-synaptic neurons (post_ids):

example_conn.require('pre_ids', 'post_ids')
(JaxArray([0, 0, 1, 1, 2, 2, 3, 4, 4, 4, 4], dtype=uint32),
 JaxArray([1, 4, 4, 5, 2, 3, 6, 1, 5, 6, 7], dtype=uint32))

Or, we can require the connection structure of pre2post, which stores the information of how each pre-synaptic neuron connects to post-synaptic neurons:

example_conn.require('pre2post')
(JaxArray([0, 3, 4, 1, 0, 2, 7], dtype=uint32),
 JaxArray([0, 3, 3, 4, 7, 7], dtype=uint32))

For more details of the connection structures, please see the tutorial of Synaptic Connections.

Achieving efficient synaptic computation is difficult#

Synaptic computations usually need to transform data of the pre-synaptic dimension into data of the post-synaptic dimension, or into data with the shape of the synapse number. There does not exist a universal computation method that is efficient in all cases. Usually, we need different ways for different connection situations to achieve efficient synaptic computation. In the next two sections, we will talk about how to define efficient synaptic models when your connections are sparse or dense.

Before we start, we need to define some useful helper functions to define and show synapse models. Then, we will highlight the key differences of model definition when using different synaptic connections.

# Basic Model to define the exponential synapse model. This class 
# defines the basic parameters, variables, and integral functions. 


class BaseExpSyn(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., method='exp_auto'):
    super(BaseExpSyn, self).__init__(pre=pre, post=post, conn=conn)

    # check whether the pre group has the needed attribute: "spike"
    self.check_pre_attrs('spike')

    # check whether the post group has the needed attribute: "input" and "V"
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max

    # use "LengthDelay" to store the spikes of the pre-synaptic neuron group
    self.delay_step = int(delay/bm.get_dt())
    self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)

    # integral function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
# Basic Model to define the AMPA synapse model. This class 
# defines the basic parameters, variables, and integral functions. 


class BaseAMPASyn(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, delay=0., g_max=0.42, E=0., alpha=0.98,
               beta=0.18, T=0.5, T_duration=0.5, method='exp_auto'):
    super(BaseAMPASyn, self).__init__(pre=pre, post=post, conn=conn)

    # check whether the pre group has the needed attribute: "spike"
    self.check_pre_attrs('spike')

    # check whether the post group has the needed attribute: "input" and "V"
    self.check_post_attrs('input', 'V')

    # parameters
    self.delay = delay
    self.g_max = g_max
    self.E = E
    self.alpha = alpha
    self.beta = beta
    self.T = T
    self.T_duration = T_duration

    # use "LengthDelay" to store the spikes of the pre-synaptic neuron group
    self.delay_step = int(delay/bm.get_dt())
    self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)

    # store the arrival time of the pre-synaptic spikes
    self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)

    # integral function
    self.integral = bp.odeint(self.derivative, method=method)

  def derivative(self, g, t, TT):
    dg = self.alpha * TT * (1 - g) - self.beta * g
    return dg
# for more details of how to run a simulation please see the tutorials in "Dynamics Simulation"

def show_syn_model(model):
  pre = bp.models.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
  post = bp.models.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
  syn = model(pre, post, conn=bp.conn.One2One())
  net = bp.dyn.Network(pre=pre, post=post, syn=syn)

  runner = bp.DSRunner(net,
                       monitors=['pre.V', 'post.V', 'syn.g'],
                       inputs=['pre.input', 22.])
  runner.run(100.)

  fig, gs = bp.visualize.get_figure(1, 2, 3, 4)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['syn.g'], legend='syn.g')
  fig.add_subplot(gs[0, 1])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['pre.V'], legend='pre.V')
  bp.visualize.line_plot(runner.mon.ts, runner.mon['post.V'], legend='post.V', show=True)

Computation with Dense Connections#

Matrix-based synaptic computation is straightforward. Especially when your models are connected densely, using a matrix is highly efficient.

conn_mat#

Assume two neuron groups are connected through a fixed probability of 0.7.

conn = bp.conn.FixedProb(0.7)(pre_size=6, post_size=8)

Then you can create the connection matrix through conn.require("conn_mat"):

conn.require('conn_mat')
JaxArray([[False,  True, False,  True,  True, False,  True,  True],
          [ True,  True,  True, False,  True,  True,  True,  True],
          [False,  True,  True,  True,  True,  True,  True,  True],
          [ True,  True,  True, False,  True,  True,  True,  True],
          [False,  True, False,  True,  True,  True,  True, False],
          [ True, False,  True,  True, False,  True, False,  True]],            dtype=bool)

conn_mat has the shape of (num_pre, num_post). Therefore, transforming data with the pre-synaptic dimension into data of the post-synaptic dimension is very easy. You just need to make a matrix multiplication: brainpy.math.dot(pre_values, conn_mat) (\(\mathbb{R}^\mathrm{num\_pre} @ \mathbb{R}^\mathrm{(num\_pre, num\_post)} \to \mathbb{R}^\mathrm{num\_post}\)).
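For instance, a minimal sketch of this transformation (with arbitrary small group sizes chosen by us) might look like:

conn = bp.conn.FixedProb(0.7)(pre_size=6, post_size=8)
conn_mat = conn.require('conn_mat')                # boolean matrix of shape (6, 8)

pre_spikes = bm.asarray([1., 0., 1., 0., 0., 1.])  # 1. marks a spiking pre-synaptic neuron
post_input = bm.dot(pre_spikes, conn_mat)          # shape (8,): summed input arriving at each post neuron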

With the synaptic connection of conn_mat above, we can define the exponential synapse model as follows. It is worth noting that the evolution of the states output onto the same post-synaptic neuron in exponential synapses can be superposed. This means we can declare the synapse variables with the shape of the post-synaptic group, rather than the number of total synapses.

class ExpConnMat(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpConnMat, self).__init__(*args, **kwargs)

    # connection matrix
    self.conn_mat = self.conn.require('conn_mat')

    # synapse gating variable
    # -------
    # NOTE: Here the synapse number is the same with 
    #       the post-synaptic neuron number. This is 
    #       different from the AMPA synapse.
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, _t, _dt):
    # pull the delayed pre spikes for computation
    delayed_spike = self.pre_spike(self.delay_step)
    # push the latest pre spikes into the bottom
    self.pre_spike.update(self.pre.spike)
    # integrate the synapse state
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # update synapse states according to the pre spikes
    post_sps = bm.dot(delayed_spike, self.conn_mat)
    self.g += post_sps
    # get the post-synaptic current
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpConnMat)
_images/5bfb2760e001f4b7733d3c7c16bc22c26134ebbd862db0438b04a5387129f72e.png

We can also use conn_mat to define an AMPA synapse model. Note that here the shape of the synapse variable \(g\) is (num_pre, num_post), rather than self.post.num in the above exponential synapse model. This is because the synaptic states of the AMPA model cannot be superposed.

class AMPAConnMat(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPAConnMat, self).__init__(*args, **kwargs)

    # connection matrix
    self.conn_mat = self.conn.require('conn_mat')

    # synapse gating variable
    # -------
    # NOTE: Here the synapse shape is (num_pre, num_post),
    #       in contrast to the ExpConnMat
    self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

  def update(self, _t, _dt):
    # pull the delayed pre spikes for computation
    delayed_spike = self.pre_spike(self.delay_step)
    # push the latest pre spikes into the bottom
    self.pre_spike.update(self.pre.spike)
    # get the time of pre spikes arrive at the post synapse
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    # get the neurotransmitter concentration at the current time
    TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
    # integrate the synapse state
    TT = TT.reshape((-1, 1)) * self.conn_mat  # NOTE: only keep the concentrations
                                              #       on the valid connections
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    # get the post-synaptic current
    g_post = self.g.sum(axis=0)
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAConnMat)
_images/69c076275f01603ec59024c4cb55a99701003ee054fbad7dfc035287fa08b1e8.png
Special connections#

Sometimes, we can define synapse models with special connection types, such as all-to-all connections or one-to-one connections. In these special situations, even the connection information can be ignored, i.e., we do not need conn_mat or other structures anymore.

Assume the pre-synaptic group connects to the post-synaptic group in an all-to-all fashion. Then, the exponential synapse model can be defined as,

class ExpAll2All(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpAll2All, self).__init__(*args, **kwargs)

    # synapse gating variable
    # -------
    # The synapse variable has the shape of the post-synaptic group
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, _t, _dt):
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.g.value = self.integral(self.g, _t, dt=_dt)
    self.g += delayed_spike.sum()  # NOTE: HERE is the difference
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpAll2All)
_images/5bfb2760e001f4b7733d3c7c16bc22c26134ebbd862db0438b04a5387129f72e.png

Similarly, the AMPA synapse model can be defined as

class AMPAAll2All(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPAAll2All, self).__init__(*args, **kwargs)

    # synapse gating variable
    # -------
    # The synapse variable has the shape of the post-synaptic group
    self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

  def update(self, _t, _dt):
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
    TT = TT.reshape((-1, 1))  # NOTE: here is the difference
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    g_post = self.g.sum(axis=0) # NOTE: here is also different
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAAll2All)
_images/69c076275f01603ec59024c4cb55a99701003ee054fbad7dfc035287fa08b1e8.png

Actually, synaptic computation with these special connections can be very efficient! For a concrete example, please see the decision-making spiking model in BrainPy-Examples. This implementation achieves at least a four-fold acceleration compared to implementations in other frameworks.

Computation with Sparse Connections#

However, in the real neural system, the neurons are connected sparsely in essence.

Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need \(10^8\) floats to save the synaptic state, and at each update step, you need to do computation on \(10^8\) floats. However, the number of synapses you actually connect is only \(10^7\), so there is a huge waste of memory and computing resources. Moreover, at a given time \(\mathrm{\_t}\), the number of pre-synaptic neurons in the spiking state is small on average. This means we make many useless computations when defining synaptic computations with matrix-based connections (zeros dotted with the connection matrix result in zeros).

Therefore, we need new ways to define synapse models. Specifically, we use vectors to store the connected neuron indices, like the pre_ids and post_ids (see Synaptic Connections).

In the below, we assume you have learned the synaptic connection types detailed in the tutorial of Synaptic Connections.

The pre2post operator#

A notable difference between brain dynamics models and deep learning models is that the former are sparse and event-driven. In order to support this significantly different kind of computation, BrainPy has built many useful operators. In this section, we talk about a set of operators needed in pre2post computations.

Note that, as said before, the exponential synapse model can make computations at the dimension of the post-synaptic group. Therefore, we can directly transform the pre-synaptic data into data of the post-synaptic shape. brainpy.math.pre2post_event_sum(events, pre2post, post_num, values) can satisfy your requirements. This operator needs the synaptic structure of pre2post (a tuple containing the post_ids and indptr of the pre-synaptic neurons).

If values is a scalar, pre2post_event_sum is equivalent to:

post_val = np.zeros(post_num)

post_ids, indptr = pre2post
for i in range(pre_num):
  if events[i]:
    for j in range(indptr[i], indptr[i+1]):
      post_val[post_ids[j]] += values

If values is a vector, pre2post_event_sum is equivalent to:

post_val = np.zeros(post_num)

post_ids, indptr = pre2post
for i in range(pre_num):
  if events[i]:
    for j in range(indptr[i], indptr[i+1]):
      post_val[post_ids[j]] += values[j]
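As a standalone illustration (with small group sizes chosen by us; note that this operator may require the brainpylib package mentioned in the installation section), the operator can be used as:

conn = bp.conn.FixedProb(0.2)(pre_size=5, post_size=8)
pre2post = conn.require('pre2post')                        # (post_ids, indptr) pair

events = bm.random.random(5) < 0.5                         # boolean spike events of the pre group
post_val = bm.pre2post_event_sum(events, pre2post, 8, 1.)  # shape (8,): summed value per post neuron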

With this operator, exponential synapse model can be defined as:

class ExpSparse(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpSparse, self).__init__(*args, **kwargs)

    # connections
    self.pre2post = self.conn.require('pre2post')

    # synapse variable
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, _t, _dt):
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # NOTE: update synapse states according to the pre spikes
    post_sps = bm.pre2post_event_sum(delayed_spike, self.pre2post, self.post.num, 1.)
    self.g += post_sps
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpSparse)
_images/5bfb2760e001f4b7733d3c7c16bc22c26134ebbd862db0438b04a5387129f72e.png

This model will be very efficient when your synapses are connected sparsely.

The pre2syn and syn2post operators#

However, for the AMPA synapse model, the pre-synaptic values cannot be directly transformed into post-synaptic dimensional data. Therefore, we need to first change the pre-synaptic data into data of the synapse dimension, and then transform the synapse-dimensional data into post-dimensional data.

Therefore, the core problem of synaptic computation is how to convert values among tensors of different shapes. Specifically, in the above AMPA synapse model, we have three kinds of tensor shapes (see the following figure): tensors with the dimension of the pre-synaptic group, tensors with the dimension of the post-synaptic group, and tensors with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state and aggregating the synaptic variable into the post-synaptic current value are the central problems of synaptic computation.

Here BrainPy provides two operators brainpy.math.pre2syn(pre_values, pre_ids) and brainpy.math.syn2post(syn_values, post_ids, post_num) to convert vectors among different dimensions.

  • brainpy.math.pre2syn() receives two arguments: “pre_values” (the variable of the pre-synaptic dimension) and “pre_ids” (the connected pre-synaptic neuron index).

  • brainpy.math.syn2post() receives three arguments: “syn_values” (the variable with the synaptic size), “post_ids” (the connected post-synaptic neuron index) and “post_num” (the number of the post-synaptic neurons).
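As a standalone, hedged sketch (with small group sizes chosen by us), the two operators can be chained as follows to move data from the pre-synaptic dimension to the synapse dimension and then to the post-synaptic dimension:

conn = bp.conn.FixedProb(0.2)(pre_size=5, post_size=8)
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')

pre_values = bm.random.random(5)                    # one value per pre-synaptic neuron
syn_values = bm.pre2syn(pre_values, pre_ids)        # one value per synapse
post_values = bm.syn2post(syn_values, post_ids, 8)  # aggregated back to the post-synaptic dimension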

Based on these two operators, we can define the AMPA synapse model as:

class AMPASparse(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPASparse, self).__init__(*args, **kwargs)

    # connection matrix
    self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')

    # synapse gating variable
    # -------
    # NOTE: Here the synapse shape is (num_syn,)
    self.g = bm.Variable(bm.zeros(len(self.pre_ids)))

  def update(self, _t, _dt):
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    # get the time of pre spikes arrive at the post synapse
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    # get the arrival time with the synapse dimension
    arrival_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
    # get the neurotransmitter concentration at the current time
    TT = ((_t - arrival_times) < self.T_duration) * self.T
    # integrate the synapse state
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    # get the post-synaptic current
    g_post = bm.syn2post(self.g, self.post_ids, self.post.num)
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPASparse)
[Output figure: simulation of the AMPASparse synapse model]

We hope this tutorial helps you define your synapse models efficiently.

Building Network Models#

@Xiaoyu Chen @Chaoming Wang

In previous sections, it has been illustrated how to define neuron models by brainpy.dyn.NeuGroup and synapse models by brainpy.dyn.TwoEndConn. This section will introduce brainpy.dyn.Network, which is the base class used to build network models.

In essence, brainpy.dyn.Network is a container, whose function is to compose the individual elements. It is a subclass of a more general class: brainpy.dyn.Container.

Below, we take an excitation-inhibition (E-I) balanced network model as an example to illustrate how to compose the LIF neurons and exponential synapses defined in previous tutorials into a network.

import brainpy as bp

bp.math.set_platform('cpu')

Excitation-Inhibition (E-I) Balanced Network#

The E-I balanced network was first proposed to explain the irregular firing patterns of cortical neurons and was later confirmed by experimental data. The network [1] we are going to implement consists of excitatory (E) neurons and inhibitory (I) neurons with a ratio of about 4:1. The biggest difference between excitation and inhibition lies in the reversal potential: the reversal potential of inhibitory synapses is much lower than that of excitatory synapses. Besides, the membrane time constant of inhibitory neurons is longer than that of excitatory neurons, which indicates that inhibitory neurons have slower dynamics.

[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98.

# BrainPy has some built-in canonical neuron and synapse models

LIF = bp.dyn.neurons.LIF
ExpCOBA = bp.dyn.synapses.ExpCOBA

Two ways to define network models#

There are two ways to define a network model.

1. Defining a network as a class#

The first way to define a network model is as follows.

class EINet(bp.dyn.Network):
  def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
    super(EINet, self).__init__(**kwargs)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = LIF(num_exc, **pars, method=method)
    I = LIF(num_inh, **pars, method=method)
    E.V.value = bp.math.random.randn(num_exc) * 2 - 55.
    I.V.value = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    w_e = 0.6  # excitatory synaptic weight
    w_i = 6.7  # inhibitory synaptic weight
    E_pars = dict(E=0., g_max=w_e, tau=5.)
    I_pars = dict(E=-80., g_max=w_i, tau=10.)
    
    # Neurons connect to each other randomly with a connection probability of 2%
    self.E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
    self.E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
    self.I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
    self.I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)

    self.E = E
    self.I = I

In an instance of brainpy.dyn.Network, all elements assigned as self. attributes can be gathered automatically by the .child_ds() function.

EINet(8, 2).child_ds()
{'ExpCOBA0': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x1076db580>,
 'ExpCOBA1': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x1584c3d30>,
 'ExpCOBA2': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x158496a00>,
 'ExpCOBA3': <brainpy.dyn.synapses.abstract_models.ExpCOBA at 0x15880ce80>,
 'LIF0': <brainpy.dyn.neurons.IF_models.LIF at 0x1583fd100>,
 'LIF1': <brainpy.dyn.neurons.IF_models.LIF at 0x1583f6d30>,
 'ConstantDelay0': <brainpy.dyn.base.ConstantDelay at 0x1584ba1c0>,
 'ConstantDelay1': <brainpy.dyn.base.ConstantDelay at 0x1584c3d00>,
 'ConstantDelay2': <brainpy.dyn.base.ConstantDelay at 0x158496370>,
 'ConstantDelay3': <brainpy.dyn.base.ConstantDelay at 0x15880c340>}

Note that in the above EINet, we do not define the update() function. This is because any subclass of brainpy.dyn.Network has a default update function, which automatically gathers the elements defined in the network and sequentially runs the update function of each element.

If your network needs some special operations, you can override the update function yourself. Here is a simple example.

class ExampleToOverrideUpdate(EINet):
    def update(self, _t, _dt):
        for node in self.child_ds().values():
            node.update(_t, _dt)

Let’s try to simulate our defined EINet model.

net = EINet(3200, 800, method='exp_auto')  # "method": the numerical integrator method

runner = bp.dyn.DSRunner(net,
                         monitors=['E.spike', 'I.spike'],
                         inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
                         title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
                         title='Spikes of Inhibitory Neurons', show=True)
Used time 0.35350608825683594 s
[Output figures: 'Spikes of Excitatory Neurons' and 'Spikes of Inhibitory Neurons' raster plots]
2. Instantiating a network directly#

Another way to instantiate a network model is to directly pass the elements into the constructor of brainpy.dyn.Network, which receives both *args and **kwargs arguments.

# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = LIF(3200, **pars)
I = LIF(800, **pars)
E.V.value = bp.math.random.randn(E.num) * 2 - 55.
I.V.value = bp.math.random.randn(I.num) * 2 - 55.

# synapses
E_pars = dict(E=0., g_max=0.6, tau=5.)
I_pars = dict(E=-80., g_max=6.7, tau=10.)
E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), **E_pars)
E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), **E_pars)
I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), **I_pars)
I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), **I_pars)


# Network
net2 = bp.dyn.Network(E2E, E2I, I2E, I2I, exc_group=E, inh_group=I)

All elements passed as **kwargs arguments can be accessed by the provided keys. This will affect the subsequent dynamics simulation and will be discussed in greater detail in the tutorial of Runners.

net2.exc_group
<brainpy.dyn.neurons.IF_models.LIF at 0x159470c10>
net2.inh_group
<brainpy.dyn.neurons.IF_models.LIF at 0x159470d30>

After construction, the simulation goes the same way:

runner = bp.dyn.DSRunner(net2,
                         monitors=['exc_group.spike', 'inh_group.spike'],
                         inputs=[('exc_group.input', 20.), ('inh_group.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['exc_group.spike'],
                         title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['inh_group.spike'],
                         title='Spikes of Inhibitory Neurons', show=True)
Used time 0.3590219020843506 s
[Output figures: 'Spikes of Excitatory Neurons' and 'Spikes of Inhibitory Neurons' raster plots]

Above are some simulation examples showing possible applications of network models. The detailed description of dynamics simulation is covered in the toolboxes, where the use of runners, monitors, and inputs will be explained in detail.

Building General Dynamical Systems#

@Xiaoyu Chen @Chaoming Wang

The previous sections have shown how to build neuron models, synapse models, and network models. In fact, these brain objects all inherit from the base class brainpy.dyn.DynamicalSystem, which is the universal language for defining dynamical models in BrainPy.

To begin with, let's make a brief summary of the previous dynamic models and give the definition of a dynamical system.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

What is a dynamical system?#

Looking back at the neuron and synapse models defined in the previous sections, they share a common feature: they all contain variables that change over time. Because of these variables, the models become ‘dynamic’ and behave differently at different times.

Actually, a dynamical system is defined as a system with time-dependent states. These time-dependent states are represented as variables in the previous models.

Mathematically, the change of a state \(X\) can be expressed as

\[ \dot{X} = f(X, t) \]

where \(X\) is the state of the system, \(t\) is the time, and \(f\) is a function describing the time dependence of the state.

Alternatively, the evolution of the system over time can be given by

\[ X(t+dt) = F\left(X(t), t, dt\right) \]

where \(dt\) is the time step and \(F\) is the evolution rule to update the system’s state.
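
For instance, with the simple Euler scheme the evolution rule \(F\) reduces to

\[ X(t+dt) = X(t) + f\left(X(t), t\right) \, dt, \]

which is exactly how the update functions in the examples below advance their variables by one time step dt.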

Customizing your dynamical systems#

According to the mathematical expression of a dynamical system, any subclass of brainpy.dyn.DynamicalSystem must implement an updating rule in the update(self, t, dt) function.

To define a dynamical system, the following requirements should be satisfied:

  • Inherit from brainpy.dyn.DynamicalSystem.

  • Implement the update(self, t, dt) function.

  • When defining variables, they should be declared as brainpy.math.Variable.

  • When updating variables, in-place operations should be used (see the sketch below).

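As a quick illustration of the in-place requirement, here is a minimal sketch (the variable v is made up for illustration):

v = bm.Variable(bm.zeros(10))

# in-place updates: the Variable object itself is modified,
# so the changes are tracked correctly under JIT compilation
v += 1.
v[:] = 0.
v.value = bm.ones(10)

# NOT in-place: this rebinds the Python name "v" to a new array,
# and the update will be lost inside JIT-compiled functions
# v = v + 1.
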
Below is a simple example of a dynamical system.

class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
    def __init__(self, a=0.8, b=0.7, tau=12.5, **kwargs):
        super(FitzHughNagumoModel, self).__init__(**kwargs)
        
        # parameters
        self.a = a
        self.b = b
        self.tau = tau
        
        # variables should be packed by brainpy.math.Variable
        self.v = bm.Variable([0.])
        self.w = bm.Variable([0.])
        self.I = bm.Variable([0.])
        
    def update(self, _t, _dt):
        # _t : the current time, the system keyword 
        # _dt : the time step, the system keyword 
        
        # in-place update
        self.w += (self.v + self.a - self.b * self.w) / self.tau * _dt
        self.v += (self.v - self.v ** 3 / 3 - self.w + self.I) * _dt
        self.I[:] = 0.

Here we have defined a dynamical system, the FitzHugh–Nagumo neuron model, whose dynamics are given by:

\[\begin{split} {\dot {v}}=v-{\frac {v^{3}}{3}}-w+I, \\ \tau {\dot {w}}=v+a-bw. \end{split}\]

By using the Euler method, this system can be updated by the following rule:

\[\begin{split} \begin{aligned} v(t+dt) &= v(t) + [v(t)-{v(t)^{3}/3}-w(t)+I] * dt, \\ w(t + dt) &= w(t) + [v(t) + a - b w(t)] * dt. \end{aligned} \end{split}\]

Advantages of using brainpy.dyn.DynamicalSystem#

There are several advantages of defining a dynamical system as brainpy.dyn.DynamicalSystem.

1. A systematic naming system.#

First, every instance of DynamicalSystem has its unique name.

fhn = FitzHughNagumoModel()
fhn.name  # name for "fhn" instance
'FitzHughNagumoModel1'

Every instance has its unique name:

for _ in range(3):
    print(FitzHughNagumoModel().name)
FitzHughNagumoModel2
FitzHughNagumoModel3
FitzHughNagumoModel4

Users can also specify the name of a dynamical system:

fhn2 = FitzHughNagumoModel(name='X')

fhn2.name
'X'
# same name will cause error

try:
    FitzHughNagumoModel(name='X')
except bp.errors.UniqueNameError as e:
    print(e)
In BrainPy, each object should have a unique name. However, we detect that <__main__.FitzHughNagumoModel object at 0x000001F75163C250> has a used name "X". 
If you try to run multiple trials, you may need 

>>> brainpy.base.clear_name_cache() 

to clear all cached names. 

Second, variables, child nodes, etc. inside an instance can be easily accessed by their absolute or relative path.

# All variables can be accessed by
# 1). the absolute path

fhn2.vars()
{'X.I': Variable([0.], dtype=float32),
 'X.v': Variable([0.], dtype=float32),
 'X.w': Variable([0.], dtype=float32)}
# 2). or, the relative path

fhn2.vars(method='relative')
{'I': Variable([0.], dtype=float32),
 'v': Variable([0.], dtype=float32),
 'w': Variable([0.], dtype=float32)}
2. Convenient operations for simulation and analysis.#

BrainPy provides different runners for dynamics simulation and analyzers for dynamics analysis, both of which require the model to be a brainpy.dyn.DynamicalSystem. For example, dynamic models can be packed by a runner for simulation:

runner = bp.dyn.DSRunner(fhn2, monitors=['v', 'w'], inputs=('I', 1.5))
runner(duration=100)

bp.visualize.line_plot(runner.mon.ts, runner.mon.v, legend='v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)
[Output figure: time courses of v and w]

Please see Runners to know more about the operations in runners.

3. Efficient computation.#

brainpy.dyn.DynamicalSystem is a subclass of brainpy.Base. Therefore, any instance of brainpy.dyn.DynamicalSystem can be compiled just-in-time into efficient machine code targeting CPUs, GPUs, and TPUs.

runner = bp.dyn.DSRunner(fhn2, monitors=['v', 'w'], inputs=('I', 1.5), jit=True)
runner(duration=100)

bp.visualize.line_plot(runner.mon.ts, runner.mon.v, legend='v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)
[Output figure: time courses of v and w]
4. Support composable programming.#

Instances of brainpy.dyn.DynamicalSystem can be combined at will. The combined system is also a brainpy.dyn.DynamicalSystem and enjoys all the properties, methods, and interfaces provided by brainpy.dyn.DynamicalSystem.

For example, if the instances are wrapped into a container, i.e. brainpy.dyn.Network, variables and nodes can also be accessed by their absolute or relative path.

fhn_net = bp.dyn.Network(f1=fhn, f2=fhn2)
# absolute access of variables

fhn_net.vars()
{'FitzHughNagumoModel1.I': Variable([0.], dtype=float32),
 'FitzHughNagumoModel1.v': Variable([0.], dtype=float32),
 'FitzHughNagumoModel1.w': Variable([0.], dtype=float32),
 'X.I': Variable([0.], dtype=float32),
 'X.v': Variable([1.492591], dtype=float32),
 'X.w': Variable([1.9365357], dtype=float32)}
# relative access of variables

fhn_net.vars(method='relative')
{'f1.I': Variable([0.], dtype=float32),
 'f1.v': Variable([0.], dtype=float32),
 'f1.w': Variable([0.], dtype=float32),
 'f2.I': Variable([0.], dtype=float32),
 'f2.v': Variable([1.492591], dtype=float32),
 'f2.w': Variable([1.9365357], dtype=float32)}
# absolute access of nodes

fhn_net.nodes()
{'FitzHughNagumoModel1': <__main__.FitzHughNagumoModel at 0x1f7515a74c0>,
 'X': <__main__.FitzHughNagumoModel at 0x1f75164bd90>,
 'Network0': <brainpy.dyn.base.Network at 0x1f7529e70d0>}
# relative access of nodes

fhn_net.nodes(method='relative')
{'': <brainpy.dyn.base.Network at 0x1f7529e70d0>,
 'f1': <__main__.FitzHughNagumoModel at 0x1f7515a74c0>,
 'f2': <__main__.FitzHughNagumoModel at 0x1f75164bd90>}
runner = bp.dyn.DSRunner(fhn_net,
                         monitors=['f1.v', 'X.v'], 
                         inputs=[('f1.I', 1.5),   # relative access to variable "I" in 'fhn1'
                                 ('X.I', 1.0),])  # absolute access to variable "I" in 'fhn2'
runner(duration=100)

bp.visualize.line_plot(runner.mon.ts, runner.mon['f1.v'], legend='fhn1.v', show=False)
bp.visualize.line_plot(runner.mon.ts, runner.mon['X.v'], legend='fhn2.v', show=True)
[Output figure: time courses of fhn1.v and fhn2.v]

Dynamics Training#

This tutorial shows how to train a dynamical system from data or task, and how to customize your nodes or networks.

Node Specification#

@Chaoming Wang

Neural networks in BrainPy are used to build dynamical systems. The brainpy.nn module provides various classes representing the nodes of a neural network. All of them are subclasses of the brainpy.nn.Node base class.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

What is a node?#

In BrainPy, the Node instance is the basic element to form a network model. It is a unit on a graph, connected to other nodes by edges.

In general, each Node instance in BrainPy has four components:

  • Feedforward inputs

  • Feedback inputs

  • State

  • Output

It is worth noting that each Node instance may have multiple feedforward or feedback connections. However, it has only one state and one output. The output component is used in both feedforward and feedback connections, which means the feedforward and feedback outputs are the same by default. However, customizing a different feedback output is also easy (see the Node Customization tutorial).

Each node has the following attributes:

  • feedforward_shapes: the shapes of the feedforward inputs.

  • feedback_shapes: the shapes of the feedback inputs.

  • output_shape: the output shape of the node.

  • state: the state of the node. It can be None if the node has no state to hold.

  • fb_output: the feedback output of the node. It is None when no feedback connections are established to this node. By default, the value of fb_output is the output value of the forward() function.

It also has several boolean attributes:

  • trainable: whether the node is trainable.

  • is_initialized: whether the node has been initialized.
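
These attributes can be inspected directly on a node instance. Below is a minimal sketch based on the attribute names listed above (the exact values printed, especially before initialization, may differ between BrainPy versions):

l = bp.nn.Dense(num_unit=100, input_shape=128)
print(l.trainable)           # whether the node is trainable
print(l.feedforward_shapes)  # the shapes of the feedforward inputs
print(l.output_shape)        # the output shape of the node
print(l.is_initialized)      # False before initialization

l.initialize(num_batch=1)
print(l.is_initialized)      # True after initialization
print(l.state)               # expected to be None, since a Dense layer holds no state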

Creating a node#

A layer can be created as an instance of a brainpy.nn.Node subclass. For example, a dense layer can be created as follows:

bp.nn.Dense(num_unit=100) 
Dense(name=Dense0, forwards=None, 
      feedbacks=None, output=(None, 100))

This will create a dense layer with 100 units.

Of course, if you already know the shapes of the feedforward connections, you can specify input_shape.

bp.nn.Dense(num_unit=100, input_shape=128) 
Dense(name=Dense1, forwards=((None, 128),), 
      feedbacks=None, output=(None, 100))

This creates a densely connected layer that is connected to an input layer with 128 dimensions.

Naming a node#

For convenience, you can name a layer by specifying the name keyword argument:

bp.nn.Dense(num_unit=100, input_shape=128, name='hidden_layer')
Dense(name=hidden_layer, forwards=((None, 128),), 
      feedbacks=None, output=(None, 100))

Initializing parameters#

Many nodes have their own parameters. We can set the parameters of a node with the following methods.

  • Tensors

If a tensor variable instance is provided, this is used unchanged as the parameter variable. For example:

l = bp.nn.Dense(num_unit=50, input_shape=10, 
                weight_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.initialize(num_batch=1)

l.Wff.shape
(10, 50)
  • Callable function

If a callable function (which receives a shape argument) is provided, the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:

def init(shape):
    return bm.random.random(shape)

l = bp.nn.Dense(num_unit=30, input_shape=20, weight_initializer=init)
l.initialize(num_batch=1)

l.Wff.shape
(20, 30)
  • Instance of brainpy.init.Initializer

If a brainpy.init.Initializer instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:

l = bp.nn.Dense(num_unit=100, input_shape=20, 
                weight_initializer=bp.init.Normal(0.01))
l.initialize(num_batch=1)

l.Wff.shape
(20, 100)

The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).

  • None parameter

Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:

l = bp.nn.Dense(num_unit=100, input_shape=20, bias_initializer=None)
l.initialize(num_batch=1)

print(l.bias)
None

Calling the node#

Instantiating a node builds an input-to-output function mapping. To get the mapped output, you can directly call the created node.

l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize()
l(bm.random.random((1, 20)))
JaxArray([[ 0.7788163 ,  0.6352515 ,  0.9846623 ,  0.97518134,
           -1.0947354 ,  0.29821265, -0.9927582 , -0.00511351,
            0.6623081 ,  0.72418994]], dtype=float32)
l(bm.random.random((2, 20)))
JaxArray([[ 0.21428639,  0.5546448 ,  0.5172446 ,  1.2533414 ,
           -0.54073226,  0.6578476 , -0.31080672,  0.25883573,
           -0.0466502 ,  0.50195456],
          [ 0.91855824,  0.503054  ,  1.1109638 ,  0.707477  ,
           -0.8442794 , -0.12064239, -0.81839114, -0.2828313 ,
           -0.660355  ,  0.20748737]], dtype=float32)

Moreover, JIT compilation of the created node is also applicable.

jit_l = bm.jit(l)
%timeit l(bm.random.random((2, 20)))
2.34 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jit_l(bm.random.random((2, 20)))
2.04 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Trainable settings#

Setting the node to be trainable or non-trainable can be easily achieved. This is controlled by the trainable argument when initializing a node.

For example, for a non-trainable dense layer, the weights and bias are JaxArray instances.

l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=False)
l.initialize(num_batch=1)

l.Wff
JaxArray([[ 0.56564915, -0.70626205,  0.03569109],
          [-0.10908064, -0.63869774, -0.37541717],
          [-0.80857176,  0.22993006,  0.02752776],
          [ 0.32151228, -0.45234612,  0.9239818 ]], dtype=float32)

When creating a layer with trainable=True, TrainVar instances will be created for the parameters and initialized automatically. For example:

l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=True)
l.initialize(num_batch=1)

l.Wff
TrainVar([[-0.20390746,  0.7101851 , -0.2881384 ],
          [ 0.07779109, -1.1979834 ,  0.09109607],
          [-0.41889605,  0.3983429 , -1.1674007 ],
          [-0.14914905, -1.1085916 , -0.10857478]], dtype=float32)

Moreover, for a subclass of brainpy.nn.RecurrentNode, the state can be set to be trainable or not by the state_trainable argument. When setting state_trainable=True for an instance of brainpy.nn.RecurrentNode, a new attribute .train_state will be created.

rnn = bp.nn.VanillaRNN(3, input_shape=(1,), state_trainable=True)
rnn.initialize(3)

rnn.train_state
TrainVar([0.7986958 , 0.3421112 , 0.24420719], dtype=float32)

Note the difference between the .train_state and the original .state:

  1. .train_state has no batch axis.

  2. When using node.init_state() or node.initialize() function, all values in the .state will be filled with .train_state.

rnn.state
Variable([[0.7986958 , 0.3421112 , 0.24420719],
          [0.7986958 , 0.3421112 , 0.24420719],
          [0.7986958 , 0.3421112 , 0.24420719]], dtype=float32)

Node Operations#

@Chaoming Wang

To form a large network, you need to know the supported node operations. In this section, we are going to talk about this.

import brainpy as bp

The Node instance supports the following basic node operations:

  1. feedforward connection: >>, >>=

  2. feedback connection: <<, <<=

  3. merging: & or &=

  4. concatenating: [node1, node2, ...] or (node1, node2, ...)

  5. wrapping a set of nodes: {node1, node2, ...}

  6. selection: node[slice] (like “node[1, 2, 3]”, “node[:10]”)

Feedforward operator#

Feedforward connections are the backbone of network construction. To declare a feedforward connection between two nodes, you can use the >> operator.

Users can use node1 >> node2 to create a feedforward connection between two nodes, or node1 >>= node2 to connect node2 in place.

i = bp.nn.Input(1)
r = bp.nn.VanillaRNN(10)
o = bp.nn.Dense(1)

model = i >> r >> o

model.plot_node_graph(fig_size=(6, 4),
                      node_size=1000)
[Output figure: node graph of the model]

Nodes can be combined in any way to create deeper structures. The >> operator allows composing nodes into a sequential model, in which data flows from node to node in sequence. Below is an example of a deep recurrent neural network.

model = (
    bp.nn.Input(1)
    >>
    bp.nn.VanillaRNN(10)
    >>
    bp.nn.VanillaRNN(20)
    >>
    bp.nn.VanillaRNN(10)
    >>
    bp.nn.Dense(1)
)

model.plot_node_graph(fig_size=(6, 4), node_size=500, layout='shell_layout')
[Output figure: node graph of the model]

Note

The feedforward connections cannot form a cycle. Otherwise, an error will be raised.

try:
    model = i >> r >> o >> i
except Exception as e:
    print(f'{e.__class__.__name__}: {e}')
ValueError: We detect cycles in feedforward connections. Maybe you should replace some connection with as feedback ones.

Feedback operator#

Feedback connections are an important feature of reservoir computing. Once a feedback connection is established between two nodes, when running on a time series, BrainPy will send the output of the sender to the receiver with a time delay of one time step (although the feedback behavior can be customized by user settings).

To declare a feedback connection between two nodes, you can use the << operator.

model = (i >> r >> o) & (r << o)

model.plot_node_graph(fig_size=(4, 4), node_size=1000)
[Output figure: node graph of the model]

Merging operator#

The merging operator & allows merging models together. Merging two networks creates a new network model containing all nodes and all connection edges of the two networks.

Some networks may have input-to-readout connections. This can be achieved using the merging operation &.

model = (i >> r >> o) & (i >> o)

model.plot_node_graph(fig_size=(4, 4), node_size=1000)
[Output figure: node graph of the model]

Concatenating operator#

The concatenating operators [] and () concatenate multiple nodes into one. They can be used on the sender side of a feedforward or feedback connection.

For the above input-to-readout connection, we can rewrite it as:

model = [i >> r, i] >> o
# or  
# model = (i >> r, i) >> o

model.plot_node_graph(fig_size=(4, 4), node_size=1000)
[Output figure: node graph of the model]

Note

Concatenating multiple nodes on the receiver side will cause errors.

# In the above network, "i" project to "r" and "o" simultaneously.
# However, we cannot express this node graph as
#
#    i >> [r, o]

try:
    model = i >> [r, o]
except Exception as e:
    print(f'{e.__class__.__name__}: {e}')
ValueError: Cannot concatenate a list/tuple of receivers. Please use set to wrap multiple receivers instead.

Wrapping operator#

Wrapping a set of nodes with {} means that these nodes are treated equally and take part in the same operation simultaneously.

For example, if the input node “i” projects to the recurrent node “r” and the readout node “o” simultaneously, we can express this graph as

model = i >> {r, o}

model.plot_node_graph(fig_size=(4, 4), node_size=1000)
[Output figure: node graph of the model]

Similarly, if multiple nodes connect to the same node, we can wrap them first and then establish the connections.

model = {i >> r, i} >> o

model.plot_node_graph(fig_size=(4, 4), node_size=1000)
[Output figure: node graph of the model]

Selecting operator#

Sometimes, the input we need is just a subset of a node's output. In this situation, we can use the selection operator node[].

For example, if we want to decode one half of the output of the recurrent node “r” with one readout node, and the other half with another readout node, we can express this graph as:

o1 = bp.nn.Dense(1)
o2 = bp.nn.Dense(2)

model = i >> r
model = (model[:, :5] >> o1)  & (model[:, 5:] >> o2) # the first is the batch axis

model.plot_node_graph(fig_size=(5, 5), node_size=1000)
[Output figure: node graph of the model]

Network Running and Training#

@Chaoming Wang

To make your model powerful, you need to train the created network models. In this section, we are going to talk about how to run and train your network models.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
import matplotlib.pyplot as plt

RNN structural runner RNNRunner#

For a feedforward network, predicting the output just requires calling the instantiated model:

model = ... # your created model

output = model(inputs)

To accelerate model execution, you can JIT the model by


import brainpy.math as bm
model = bm.jit(model)  # jitted model

output = model(inputs)

However, for a recurrent network model, you need to call the instantiated model multiple times along the time axis, and looping in Python is very inefficient. Instead, BrainPy provides the structural runner brainpy.nn.RNNRunner for running recurrent neural networks. With brainpy.nn.RNNRunner, the looping process is JIT compiled into machine code, approaching the speed of native C++ code.

Here we have a reservoir model.

model = (bp.nn.Input(3) >>
         bp.nn.Reservoir(100) >>
         bp.nn.LinearReadout(3))
model.initialize()

And we have a Lorenz attractor data.

lorenz = bp.datasets.lorenz_series(100)
data = bm.hstack([lorenz['x'], lorenz['y'], lorenz['z']])

Our task is to predict the Lorenz data 5 time steps ahead.

X, Y = data[:-5], data[5:]

Note that all nn models in BrainPy require a batch axis as the first dimension of the data.

# here batch size is 1
X = bm.expand_dims(X, axis=0)
Y = bm.expand_dims(Y, axis=0)

X.shape
(1, 99995, 3)

We can get the model predictions for the input data simply with

runner = bp.nn.RNNRunner(model, jit=True)

predictions = runner.predict(X)
predictions.shape
(1, 99995, 3)
bp.losses.mean_squared_error(predictions, Y)
DeviceArray(260.12122, dtype=float32)

Without training, the mean squared error (MSE) between the prediction and the target is large. We need to train the network so that it can produce the correct results.

Supported training algorithms#

Currently, BrainPy provides several kinds of methods to train recurrent neural networks, including

  • offline training algorithms, like ridge regression,

  • online training algorithms, like FORCE learning,

  • back-propagation based algorithms, like BPTT, etc.

For the full list of supported training algorithms, please see the API documentation. Here we only talk about a few of them.

RNN structural trainer RNNTrainer#

RNNTrainer is a structural trainer for recurrent neural networks. It is a subclass of RNNRunner; the difference is that the trainer has one more function, .fit(), to train the model.

The training data fed into the .fit() function can be a tuple/list of an (X, Y) pair, or a callable function which generates (X, Y) data pairs.

  • If the provided training data is an (X, Y) data pair, X should be the input data with the shape (num_sample, num_time, ...), and Y should be the target data with the shape (num_sample, ...).

  • If the training data is a callable function, it should return a Python generator which yields (X, Y) data pairs for training. For example,


# when calling this function,
# it will create a Python generator.

def train_data():
  num_data = 10
  for _ in range(num_data):
     # The (X, Y) data pair should be:
     # - "X" is a tensor has the shape of
     #   "(num_batch, num_time, num_feature)"
     # - "Y" is a tensor has the shape of
     #   "(num_batch, num_time, num_feature)"
     #   or "(num_batch, num_feature)"
     xs = bm.random.rand(1, 20, 2)
     ys = bm.random.random((1, 20, 2))
     yield xs, ys

However, all these data constraints can be relaxed when you customize your training procedure. Please see Customization of a Network Training.

Offline training algorithms#

Offline learning means you train your network with the whole dataset at once. The supported offline learning algorithms are

bp.nn.algorithms.get_supported_offline_methods()
('ridge', 'linear', 'lstsq')

We will continue to add offline learning methods. For new advances, please check the corresponding API documentation.

Instantiating an offline learning method is simple. Once you have your network model, like the above reservoir model, you just need to pass this model to brainpy.nn.OfflineTrainer as

model = (bp.nn.Input(3) >>
         bp.nn.Reservoir(100) >>
         bp.nn.LinearReadout(3))
model.initialize()
trainer = bp.nn.OfflineTrainer(
    model,
    fit_method=bp.nn.algorithms.RidgeRegression(beta=1e-6)
    # or
    # fit_method=dict(name='ridge', beta=1e-6)
)

trainer
OfflineTrainer(target=Network(LinearReadout1, Input1, Reservoir1), 
	              jit={'fit': True, 'predict': True}, 
	              fit_method=RidgeRegression(beta=1e-06))

Let’s train the created model with the Lorenz attractor data series.

trainer.fit([X, Y])
predict = trainer.predict(X, reset=True)
predict1 = bm.as_numpy(predict)

fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict1[0, :, 0], predict1[0, :, 1], predict1[0, :, 2])
plt.title('Trained with Ridge Regression')
plt.show()
[Output figure: 'Trained with Ridge Regression' 3D trajectory]

Online training algorithms#

BrainPy also supports flexible online training methods. Online learning means you train the model from a sequence of data instances one at a time. A representative online learning algorithm for recurrent neural networks is FORCE learning. Here, let's try to train the above reservoir model with the FORCE learning algorithm.

model = (bp.nn.Input(3) >>
         bp.nn.Reservoir(100) >>
         bp.nn.LinearReadout(3))
model.initialize()
trainer = bp.nn.OnlineTrainer(
    model,
    fit_method=bp.nn.algorithms.ForceLearning(alpha=0.1)
    # or
    # fit_method=dict(name='force', alpha=1e-1)
)

trainer
OnlineTrainer(target=Network(Input2, Reservoir2, LinearReadout2), 
	             jit={'fit': True, 'predict': True}, 
	             fit_method=ForceLearning)
trainer.fit([X, Y])
predict2 = trainer.predict(X, reset=True)
predict2 = bm.as_numpy(predict2)

fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict2[0, :, 0], predict2[0, :, 1], predict2[0, :, 2])
plt.title('Trained with Force Learning')
plt.show()
[Output figure: 'Trained with Force Learning' 3D trajectory]

Back-propagation algorithm#

In recent years, back-propagation has become a powerful method to train recurrent neural networks. BrainPy also supports training networks with back-propagation algorithms.

reservoir = (bp.nn.Input(3) >>
             bp.nn.Reservoir(100))
reservoir.initialize()
# The reservoir node is not trainable.
# Therefore, we generate the reservoir output
# data as the input data to train the readout node.

runner = bp.nn.RNNRunner(reservoir)
projections = runner.predict(X)
# The linear readout node is not a recurrent node,
# so there is no need to keep the time axis.
# Therefore, we treat the original time steps as the sample size.

projections = projections[0]
targets = Y[0]
readout = bp.nn.Dense(3, input_shape=100)
readout.initialize()
# Training the readout node with the back-propagation method.
# Because the Dense node is a feedforward node, we use the BPFF trainer.

trainer = bp.nn.BPFF(readout,
                     loss=bp.losses.mean_squared_error,
                     optimizer=bp.optim.Adam(lr=1e-3))
trainer
BPFF(target=Dense(name=Dense0, forwards=((None, 100),), 
      feedbacks=None, output=(None, 3)), 
	    jit={'fit': True, 'predict': True, 'loss': True}, 
	    loss=<function mean_squared_error at 0x0000021BC021BC10>, 
	    optimizer=Adam(lr=Constant(0.001), beta1=0.9, beta2=0.999, eps=1e-08))
trainer.fit([projections, targets],
            num_report=2000,
            num_batch=64,
            num_train=40)
Train 2000 steps, use 3.4620 s, train loss 31.42259
Train 4000 steps, use 2.4532 s, train loss 14.43684
Train 6000 steps, use 2.3870 s, train loss 10.64222
Train 8000 steps, use 2.4791 s, train loss 13.16424
Train 10000 steps, use 2.3759 s, train loss 7.0941
Train 12000 steps, use 2.3584 s, train loss 7.70877
Train 14000 steps, use 2.3648 s, train loss 8.33284
Train 16000 steps, use 2.4334 s, train loss 3.79623
Train 18000 steps, use 2.3502 s, train loss 3.86504
Train 20000 steps, use 2.3463 s, train loss 3.96748
Train 22000 steps, use 2.4486 s, train loss 3.88499
Train 24000 steps, use 2.3902 s, train loss 2.47998
Train 26000 steps, use 2.3854 s, train loss 1.69119
Train 28000 steps, use 2.3613 s, train loss 1.85288
Train 30000 steps, use 2.4531 s, train loss 1.77884
Train 32000 steps, use 2.3742 s, train loss 1.95193
Train 34000 steps, use 2.3862 s, train loss 1.6745
Train 36000 steps, use 2.4662 s, train loss 1.20792
Train 38000 steps, use 2.3957 s, train loss 1.55736
Train 40000 steps, use 2.3752 s, train loss 1.36623
Train 42000 steps, use 2.3872 s, train loss 1.09453
Train 44000 steps, use 2.4989 s, train loss 0.97422
Train 46000 steps, use 2.3895 s, train loss 0.70705
Train 48000 steps, use 2.4091 s, train loss 0.8673
Train 50000 steps, use 2.3833 s, train loss 1.12951
Train 52000 steps, use 2.4962 s, train loss 1.20924
Train 54000 steps, use 2.3950 s, train loss 0.79635
Train 56000 steps, use 2.3883 s, train loss 0.62906
Train 58000 steps, use 2.4581 s, train loss 0.91307
Train 60000 steps, use 2.4038 s, train loss 0.74997
Train 62000 steps, use 2.4042 s, train loss 1.04045
# plot the training loss

plt.plot(trainer.train_losses.numpy())
plt.show()
[Output figure: training loss curve]
# Finally, let's build the full model in which the
# reservoir node generates the high-dimensional
# projection data, and the linear readout
# node reads out the final value.

model = reservoir >> readout
model.initialize()

runner = bp.nn.RNNRunner(model)
predict3 = runner.predict(X)
predict3 = bm.as_numpy(predict3)

fig = plt.figure(figsize=(5, 5))
fig.add_subplot(111, projection='3d')
plt.plot(predict3[0, :, 0], predict3[0, :, 1], predict3[0, :, 2])
plt.title('Trained with BPTT')
plt.show()
[Output figure: 'Trained with BPTT' 3D trajectory]

Shared parameters#

Sometimes, there are global parameters which are shared across all nodes, for example, the training/testing phase control parameter train=True/False. Here, we use a simple model to demonstrate how to provide shared parameters when calling models.

model = (
    bp.nn.Input(1)
    >>
    bp.nn.VanillaRNN(100)
    >>
    bp.nn.Dropout(0.3)
    >>
    bp.nn.Dense(1)
)
model.initialize(3)

These shared parameters can be provided in two ways:

  • When you are using the instantiated model directly, you can provide them when calling this model.

model(bm.random.rand(3, 1), train=True)
JaxArray([[-1.2080045],
          [-0.962251 ],
          [ 0.246601 ]], dtype=float32)
model(bm.random.rand(3, 1), train=False)
JaxArray([[-0.18471804],
          [-0.11392485],
          [-0.13624835]], dtype=float32)
  • When you are using structural runners like brainpy.nn.RNNRunner or the brainpy.nn.BPTT trainer, you can wrap all shared parameters in the shared_kwargs argument.

runner = bp.nn.RNNRunner(model)
runner.predict(bm.random.random((3, 10, 1)),
               shared_kwargs={'train': True})
JaxArray([[[-0.3159347 ],
           [-0.69149274],
           [-0.04672527],
           [ 0.03180977],
           [-0.06807568],
           [-0.13523842],
           [ 0.01571239],
           [-0.11823184],
           [ 0.12058208],
           [-0.17275347]],

          [[ 0.0180111 ],
           [-0.12634276],
           [-0.32290417],
           [-0.16321549],
           [-0.05132714],
           [ 0.08687519],
           [-0.12866825],
           [-0.3837371 ],
           [-0.3020746 ],
           [-0.1423104 ]],

          [[-0.414655  ],
           [-0.496073  ],
           [-0.4937666 ],
           [-0.04079266],
           [ 0.04316711],
           [ 0.11759105],
           [-0.59218377],
           [ 0.14002447],
           [-0.27708793],
           [-0.10970033]]], dtype=float32)
runner.predict(bm.random.random((3, 10, 1)),
               shared_kwargs={'train': False})
JaxArray([[[-0.2004511 ],
           [-0.02842245],
           [-0.05291707],
           [-0.00817785],
           [-0.1658831 ],
           [-0.0154308 ],
           [-0.08032076],
           [-0.02801216],
           [-0.06928631],
           [-0.02795052]],

          [[-0.01163506],
           [-0.12711151],
           [-0.01078814],
           [-0.04324045],
           [-0.14794606],
           [-0.09333474],
           [-0.0649181 ],
           [-0.02171569],
           [-0.07023487],
           [-0.06169168]],

          [[-0.02836527],
           [-0.02502684],
           [-0.16531822],
           [-0.02565872],
           [ 0.00313345],
           [-0.02255425],
           [-0.20593695],
           [-0.11946493],
           [-0.13288933],
           [-0.07359352]]], dtype=float32)

However, it is worth noting that shared_kwargs should only take a small number of distinct values, because each different value of shared_kwargs triggers recompilation. If a parameter changes significantly and frequently, you had better not declare it in shared_kwargs.

Node Customization#

@Chaoming Wang

To implement a custom node in BrainPy, you will have to write a Python class that subclasses brainpy.nn.Node and implement several important methods.

import brainpy as bp
import brainpy.math as bm

from brainpy.tools.checking import check_shape_consistency

bp.math.set_platform('cpu')

Before we start, you need some minimal knowledge about brainpy.nn.Node. Please see the Node Specification tutorial.

Customizing a feedforward node#

In general, the variable initialization and the computation logic of each node in the brainpy.nn module are separated from each other. Otherwise, applying JIT compilation to these nodes would be difficult.

If your node only has feedforward connections, you need to implement two functions:

  • init_ff(): This function aims to initialize the feedforward connections and compute the output shape according to the given feedforward_shapes.

  • forward(): This function implements the main computation logic of the node. It may calculate the new state of the node. Most importantly, this function should return the output value for the feedforward data flow.

To show how this can be used, here is a node that multiplies its input by a matrix W (much like a typical fully connected layer in a neural network would). This matrix is a parameter of the layer. The shape of the matrix will be (num_input, num_unit), where num_input is the number of input features and num_unit is the number of output features.

class DotNode(bp.nn.Node):
    def __init__(self, num_unit, W_initializer=bp.initialize.Normal(), **kwargs):
        super(DotNode, self).__init__(**kwargs)
        self.num_unit = num_unit
        self.W_initializer = W_initializer
    
    def init_ff(self):
        # This function should compute the output shape and  
        # the feedforward (FF) connections
        
        # 1. First, due to multiple FF shapes, we need to know 
        #    the total shape when all FF inputs are concatenated. 
        #    Function "check_shape_consistency()" may help you 
        #    solve this problem quickly.
        
        unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
        
        # 2. Initialize the weight W
        weight_shape = (sum(free_sizes), self.num_unit)
        self.W = bp.nn.init_param(self.W_initializer, weight_shape)
        #   If the user wants to train this node, we need to mark the
        #   weight as a "brainpy.math.TrainVar"
        if self.trainable:
            self.W = bm.TrainVar(self.W)
        
        # 3. Set the output shape 
        self.set_output_shape(unique_size + (self.num_unit,))
        
    def forward(self, ff):
        # 1. First, we concatenate all FF inputs
        ff = bm.concatenate(ff, axis=-1)
        
        # 2. Then, we multiply the input with the weight
        return bm.dot(ff, self.W)

A few things are worth noting here: when overriding the constructor, we need to call the superclass constructor on the first line. This is important to ensure the node functions properly. Note that we pass **kwargs; although this is not strictly necessary, it enables some other useful features, such as making it possible to give the layer a name:

DotNode(10, name='my_dot_node')
DotNode(name=my_dot_node, trainable=False, forwards=None, feedbacks=None, 
        output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)

Or, set this node trainable:

DotNode(10, trainable=True)
DotNode(name=DotNode0, trainable=True, forwards=None, feedbacks=None, 
        output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)

Once we create this DotNode, we can connect multiple feedforward nodes to its instance.

l = DotNode(10)
i1 = bp.nn.Input(1, name='i1')
i2 = bp.nn.Input(2, name='i2')
i3 = bp.nn.Input(3, name='i3')

net = {i1, i2, i3} >> l

net.plot_node_graph(fig_size=(4, 4), node_size=2000)
[Output figure: node graph of the network]
net.initialize(num_batch=1)

# given an input, let's compute its output
net({'i1': bm.ones((1, 1)), 
     'i2': bm.zeros((1, 2)), 
     'i3': bm.random.random((1, 3))})
JaxArray([[-0.41227022, -1.2145127 ,  1.2915486 , -1.7037894 ,
            0.47149402, -1.9161812 ,  1.3631151 , -0.4410456 ,
            1.9460022 ,  0.54992586]], dtype=float32)

Customizing a recurrent node#

If your node is a recurrent node, which means it has its own state and self-to-self connection weights, you need to implement one more function:

  • init_state(num_batch): This function aims to initialize the Node state which depends on the batch size.

Furthermore, we recommend that users' recurrent nodes inherit from brainpy.nn.RecurrentNode, because this tells BrainPy that the node has recurrent connections.

Here, let’s try to implement a Vanilla RNN model.

class VRNN(bp.nn.RecurrentNode):
  def __init__(self, num_unit, 
               wi_initializer=bp.init.XavierNormal(),
               wr_initializer=bp.init.XavierNormal(), **kwargs):
    super(VRNN, self).__init__(**kwargs)
    
    self.num_unit = num_unit
    self.wi_initializer = wi_initializer
    self.wr_initializer = wr_initializer
    
  def init_ff(self):
    unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
    num_input = sum(free_sizes)
    self.wi = bp.nn.init_param(self.wi_initializer, (num_input, self.num_unit))
    self.wr = bp.nn.init_param(self.wr_initializer, (self.num_unit, self.num_unit))
    if self.trainable:
      self.wi = bm.TrainVar(self.wi)
      self.wr = bm.TrainVar(self.wr)
    
  def init_state(self, num_batch=1):
    state = bm.zeros((num_batch, self.num_unit))
    self.set_state(state)
  
  def forward(self, ff):
    ff = bm.concatenate(ff, axis=-1)
    state = ff @ self.wi + self.state @ self.wr
    self.state.value = state
    return state

Customizing a node with feedbacks#

Creating a layer that receives multiple feedback inputs works the same way as with feedforward connections.

Users need to implement one more function, that is:

  • init_fb(): This function aims to initialize the feedback information, including the feedback connections, feedback weights, and others.

For the above DotNode, if you want to support feedback connections, you can define the model like:

class FeedBackDotNode(bp.nn.Node):
  def __init__(self, num_unit, W_initializer=bp.initialize.Normal(), **kwargs):
    super(FeedBackDotNode, self).__init__(**kwargs)
    self.num_unit = num_unit
    self.W_initializer = W_initializer

  def init_ff(self):
    # 1. FF shapes
    unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)
    # 2. Initialize the feedforward weight Wff
    weight_shape = (sum(free_sizes), self.num_unit)
    self.Wff = bp.nn.init_param(self.W_initializer, weight_shape)
    if self.trainable:
      self.Wff = bm.TrainVar(self.Wff)
    # 3. Set the output shape 
    self.set_output_shape(unique_size + (self.num_unit,))
  
  def init_fb(self):
    # 1. FB shapes
    unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)
    # 2. Initialize the feedback weight Wfb
    weight_shape = (sum(free_sizes), self.num_unit)
    self.Wfb = bp.nn.init_param(self.W_initializer, weight_shape)
    if self.trainable:
      self.Wfb = bm.TrainVar(self.Wfb)
    
  def forward(self, ff, fb=None):
    ff = bm.concatenate(ff, axis=-1)
    res = bm.dot(ff, self.Wff)
    if fb is not None:
      fb = bm.concatenate(fb, axis=-1)
      res += bm.dot(fb, self.Wfb)
    return res

Note the difference between DotNode and FeedBackDotNode. The forward() function of the latter has one more argument, fb=None, which means that if this node has feedback connections, all feedback inputs will be passed to the fb argument.

Note

Connecting feedback to a node which does not support feedback will raise an error.

try:
    DotNode(1) << bp.nn.Input(1)
except Exception as e:
    print(e.__class__)
    print(e)
<class 'ValueError'>
Establish a feedback connection to 
DotNode(name=DotNode2, trainable=False, forwards=None, feedbacks=None, 
        output=None, support_feedback=False, data_pass_type=PASS_SEQUENCE)
is not allowed. Because this node does not support feedback connections.
FeedBackDotNode(1) << bp.nn.Input(1)
Network(FeedBackDotNode0, Input1)

Customizing a node with multiple behaviors#

Some nodes can have multiple behaviors. For example, a node implementing dropout should be able to be switched on or off. During training, we want it to apply dropout noise to its input and scale up the remaining values, but during evaluation we don’t want it to do anything.

For this purpose, the forward() method takes optional keyword arguments (kwargs). When forward() is called to compute an expression for the output of a network, all specified keyword arguments are passed to the forward() methods of all layers in the network.

class Dropout(bp.nn.Node):
    def __init__(self, prob, seed=None, **kwargs):
        super(Dropout, self).__init__(**kwargs)
        self.prob = prob
        self.rng = bm.random.RandomState(seed=seed)

    def init_ff(self):
        assert len(self.feedforward_shapes) == 1, 'Only support one feedforward input.'
        self.set_output_shape(self.feedforward_shapes[0])

    def forward(self, ff, **kwargs):
        assert len(ff) == 1, 'Only support one feedforward input.'
        if kwargs.get('train', True):
            keep_mask = self.rng.bernoulli(self.prob, ff[0].shape)
            return bm.where(keep_mask, ff[0] / self.prob, 0.)
        else:
            return ff[0]

The Dropout node only supports one feedforward input. Therefore, we have some checks at the beginning of the init_ff() and forward() functions.

Customization of a Network Training#

Dynamics Analysis#

Low-dimensional Analyzers#

@Chaoming Wang

We have talked about model simulation and training for dynamical systems with BrainPy. In this tutorial, we are going to dive into how to perform automatic analysis for your defined systems.

Dynamics analysis is essential in neurodynamics, because blind simulation of nonlinear systems is likely to produce few or misleading results. BrainPy has good support for low-dimensional systems, no matter how nonlinear your defined system is. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:

  1. phase plane analysis;

  2. codimension 1 or codimension 2 bifurcation analysis;

  3. bifurcation analysis of the fast-slow system.

BrainPy will help you probe the dynamical mechanism of your defined systems rapidly.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
bp.math.enable_x64()  # It's better to enable x64 when performing analysis
import numpy as np
import matplotlib.pyplot as plt

A simple case#

Here we test BrainPy with a simple case:

\[ \frac{dx}{dt} = \mathrm{sin}(x) + I, \]

where \(x \in [-10, 10]\).

This function has multiple fixed points (\(\frac{dx}{dt} = 0\)) when \(I=0\).

xs = np.arange(-10, 10, 0.01)

plt.plot(xs, np.sin(xs))
plt.scatter([-3*np.pi, -1*np.pi, 1*np.pi, 3 * np.pi], np.zeros(4), s=80, edgecolors='y')
plt.scatter([-2*np.pi, 0, 2*np.pi], np.zeros(3), s=80, facecolors='none', edgecolors='r')
plt.axhline(0)
plt.show()
[Output figure: sin(x) with stable (solid) and unstable (hollow) fixed points marked]

According to dynamical systems theory, the red hollow points are unstable fixed points, while the solid ones are stable.
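
The stability can be read off from the sign of the derivative of the right-hand side at each fixed point. With \(I=0\),

\[ f(x) = \mathrm{sin}(x), \qquad f'(x) = \mathrm{cos}(x), \]

so the fixed points at \(x = \pm\pi, \pm 3\pi\) (where \(\mathrm{cos}(x) < 0\)) are stable, while those at \(x = 0, \pm 2\pi\) (where \(\mathrm{cos}(x) > 0\)) are unstable.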

Now let’s come back to BrainPy, and test whether BrainPy can give us the right answer.

As the analysis interfaces in BrainPy only receive an ODEIntegrator or an instance of DynamicalSystem, we first define an integrator with BrainPy (if you want to know how to define an ODE integrator, please refer to the tutorial Numerical Solvers for ODEs):

@bp.odeint
def int_x(x, t, Iext):
    return bp.math.sin(x) + Iext

This is a one-dimensional dynamical system, so we use brainpy.analysis.PhasePlane1D for phase plane analysis. The usage of phase plane analysis will be detailed in the following section. For now, we just focus on the following four arguments:

  • model: It specifies the target system to analyze. It can be a list/tuple of ODEIntegrator, or an instance of DynamicalSystem. For a DynamicalSystem argument, model.ints().subset(bp.ode.ODEIntegrator) will be used to retrieve all instances of ODEIntegrator.

  • target_vars: It specifies the variables to analyze. It must be a dict with the format of <var_name, var_interval>, where var_name is the variable name, and var_interval is the boundary of this variable.

  • pars_update: Parameters to update.

  • resolutions: The resolution to evaluate the fixed points.

Let’s try it.

pp = bp.analysis.PhasePlane1D(
  model=int_x,
  target_vars={'x': [-10, 10]},
  pars_update={'Iext': 0.},
  resolutions=0.001
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)
I am creating vector fields ...
I am searching fixed points ...
Fixed point #1 at x=-9.42477796076938 is a stable point.
Fixed point #2 at x=-6.283185307179586 is a unstable point.
Fixed point #3 at x=-3.141592653589793 is a stable point.
Fixed point #4 at x=9.237056486678452e-19 is a unstable point.
Fixed point #5 at x=3.141592653589793 is a stable point.
Fixed point #6 at x=6.283185307179586 is a unstable point.
Fixed point #7 at x=9.42477796076938 is a stable point.
[Output figure: 1D phase plane with vector field and fixed points]

Indeed, brainpy.analysis.PhasePlane1D gives us the right fixed points and correctly evaluates their stability.

The phase plane is important because it gives us an intuitive understanding of how the system evolves with the given parameters. However, in most cases we care about how the parameters affect the system behavior, and then we should perform bifurcation analysis. brainpy.analysis.Bifurcation1D is a convenient interface to help you gain insights into how the dynamics of a 1D system change with parameters.

Similar to brainpy.analysis.PhasePlane1D, brainpy.analysis.Bifurcation1D receives arguments like “model”, “target_vars”, “pars_update”, and “resolutions”. In addition, one more important argument, “target_pars”, should be provided, which specifies the range of the target parameter in the bifurcation analysis.

Here, we systematically change the parameter “Iext” from 0 to 1.5. According to bifurcation theory, we know this simple system has a fold bifurcation at \(I=1.0\), because at \(I=1.0\) two fixed points collide into a saddle point and then disappear. Is BrainPy's analysis toolkit brainpy.analysis.Bifurcation1D capable of performing this analysis? Let's give it a try.
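
The fold can also be seen directly from the fixed-point equation:

\[ \mathrm{sin}(x) + I = 0 \quad \Leftrightarrow \quad \mathrm{sin}(x) = -I, \]

which has solutions only when \(|I| \le 1\). At \(|I| = 1\) the stable and unstable branches meet at a point where \(\mathrm{cos}(x) = 0\) and annihilate, which is exactly the fold bifurcation.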

bif = bp.analysis.Bifurcation1D(
    model=int_x,
    target_vars={'x': [-10, 10]},
    target_pars={'Iext': [0., 1.5]},
    resolutions=0.001
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...
[Output figure: bifurcation diagram of x as a function of Iext]

Once again, the BrainPy analysis toolkit gives the right answer. It tells us how the fixed points evolve as the parameter \(I\) increases.

It is worth noting that it is hard for the bifurcation analysis in BrainPy to find the saddle point (at \(I=1\) for this system). This is because the saddle point at the bifurcation only exists for an instant, and the numerical method used in the BrainPy analysis toolkit can hardly evaluate the point exactly at the saddle. However, with minimal knowledge of bifurcation theory, the saddle point (the collision point of the two fixed points) can be easily inferred from the fixed-point evolution.

BrainPy's analysis toolkit is highly useful, especially when the mathematical equations are too complex to solve analytically. For an example, please refer to the tutorial Analysis of a Decision-making Model.

Phase plane analysis#

Phase plane analysis is one of the most important techniques for studying the behavior of nonlinear systems, since there is usually no analytical solution for a nonlinear system. BrainPy can help users plot the phase plane of 1D or 2D systems. Specifically, we provide brainpy.analysis.PhasePlane1D and brainpy.analysis.PhasePlane2D. They can help to plot:

  • Nullcline: The zero-growth isoclines, such as \(f(x, y)=0\) and \(g(x, y)=0\).

  • Fixed points: The equilibrium points of the system, which are located where the nullclines intersect.

  • Vector field: The vector field of the system.

  • Limit cycles: The limit cycles.

  • Trajectories: A simulation trajectory with the given initial values.

We have discussed brainpy.analysis.PhasePlane1D above. Now we focus on brainpy.analysis.PhasePlane2D, using the well-known FitzHugh-Nagumo neuron model.

The FitzHugh-Nagumo model is given by:

\[\begin{split} \frac {dV} {dt} = V(1 - \frac {V^2} 3) - w + I_{ext} \\ \tau \frac {dw} {dt} = V + a - b w \end{split}\]

There are two variables \(V\) and \(w\), so this is a two-dimensional system with three parameters \(a, b\) and \(\tau\).

Users can define the system to analyze with pure brainpy.odeint integrators, or as a class of DynamicalSystem. For this FitzHugh-Nagumo model, we define it as a class because we will later run a simulation to verify the analysis results.

class FitzHughNagumoModel(bp.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(FitzHughNagumoModel, self).__init__()

    # parameters
    self.a = 0.7
    self.b = 0.8
    self.tau = 12.5

    # variables
    self.V = bm.Variable(bm.zeros(1))
    self.w = bm.Variable(bm.zeros(1))
    self.Iext = bm.Variable(bm.zeros(1))

    # functions
    def dV(V, t, w, Iext=0.): 
        return V - V * V * V / 3 - w + Iext
    def dw(w, t, V, a=0.7, b=0.8): 
        return (V + a - b * w) / self.tau
    self.int_V = bp.odeint(dV, method=method)
    self.int_w = bp.odeint(dw, method=method)

  def update(self, _t, _dt):
    self.V.value = self.int_V(self.V, _t, self.w, self.Iext, _dt)
    self.w.value = self.int_w(self.w, _t, self.V, self.a, self.b, _dt)
    self.Iext[:] = 0.
model = FitzHughNagumoModel()

Here we perform a phase plane analysis with parameters \(a=0.7, b=0.8, \tau=12.5\), and input \(I_{ext} = 0.8\).

pp = bp.analysis.PhasePlane2D(
  model,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  pars_update={'Iext': 0.8}, 
  resolutions=0.01,
)
# By default, nullclines will be plotted as points, 
# but we can set the plot style to lines
pp.plot_nullcline(x_style={'fmt': '-'}, y_style={'fmt': '-'})

# The vector field can be plotted in two ways:
# - plot_method="streamplot" (default)
# - plot_method="quiver"
pp.plot_vector_field()

# There are many ways to search fixed points. 
# By default, it will use the nullcline points of the first 
# variable ("V") as the initial points to perform fixed point searching
pp.plot_fixed_point()

# Trajectory plotting receives the setting of the initial points.
# There may be multiple trajectories, therefore the initial points 
# should be provided as a list/tuple/numpy.ndarray/JaxArray
pp.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)

# show the phase plane figure
pp.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating vector fields ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 866 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 V=-0.2738719079879798, w=0.5329731346879486 is a unstable node.
I am plot trajectory ...
_images/77639afa5df92809100351151f15d0770ddc10a71469f4b8c27956d1264024f6.png

We can see an unstable node at the point (\(V=-0.27, w=0.53\)) inside a limit cycle.

We can run a simulation with the same parameters and initial values to verify the periodic activity that corresponds to the limit cycle.

runner = bp.StructRunner(model, monitors=['V', 'w'], inputs=['Iext', 0.8])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)
_images/7d51ac89ecb476f69ff1171eba4da262d764d7d3a9cbbcf3cbf142458a3fb9d1.png

Understanding settings#

There are several key settings that need to be understood.

resolutions#

resolutions is one of the most important parameters in the PhasePlane and Bifurcation analysis toolkits of BrainPy, because it has a profound impact on the efficiency of model analysis.

We can set resolutions in the following ways.

  1. None. If no resolution is set for a variable, its resolution defaults to \(\frac{\mathrm{max\_value} - \mathrm{min\_value}}{20}\).

  2. A float. It sets the same resolution for every target variable and parameter.

  3. A dict. It specifies different resolutions for individual variables/parameters. Each resolution can be a float, or a vector in the format of JaxArray or numpy.ndarray.

Setting resolutions with a tensor gives the user maximal flexibility. Usually, the numerical analysis does not work well near inflection points, so we can increase the granularity there. For example, if there is an inflection point at \(1\), we can set the resolution with:

r1 = bm.arange(0.00, 0.95, 0.01)
r2 = bm.arange(0.95, 1.01, 0.001)
r3 = bm.arange(1.05, 1.50, 0.01)
resolution = bm.concatenate([r1, r2, r3])
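
For instance, here is a minimal sketch reusing the earlier one-dimensional bifurcation example (so int_x and the parameter name 'Iext' are taken from above): the non-uniform resolution tensor is assigned to the bifurcation parameter through the dictionary form.

bif = bp.analysis.Bifurcation1D(
  model=int_x,
  target_vars={'x': [-10, 10]},
  target_pars={'Iext': [0., 1.5]},
  # finer grid near the inflection point, default resolution elsewhere
  resolutions={'Iext': resolution},
)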

Tips: For bifurcation analysis, we usually need to set a small resolution for parameters, leaving the resolutions of variables at their defaults. Please see the following examples.

vars and pars#

What can be set as variables (*_vars) or parameters (*_pars), such as target_vars or target_pars, for further analysis? The variables and parameters are recognized in the same way as in the programming paradigm of the ODE numerical integrators: simply speaking, the arguments before t are treated as variables, while the arguments after t are treated as parameters.
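
For example (a minimal sketch; the names x, alpha, and beta below are purely illustrative), in the following integrator x is recognized as a variable because it appears before t, while alpha and beta are recognized as parameters because they appear after t:

@bp.odeint
def dx(x, t, alpha=1., beta=0.5):
    # "x" (before t) is recognized as a variable;
    # "alpha" and "beta" (after t) are recognized as parameters
    return -alpha * x + beta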

BrainPy's analysis toolkit only supports one variable per differential equation. It cannot analyze a joint differential equation in which multiple variables are defined in the same function.

Moreover, the low-dimensional analyzers in BrainPy cannot analyze dynamical systems that depend on time \(t\).

Bifurcation analysis#

Nonlinear dynamical systems are characterized by their parameters. When a parameter changes, the system's behavior can change qualitatively. Therefore, we care about how the system changes with a smooth change of its parameters.

Codimension 1 bifurcation analysis

We will first look at the codimension-1 bifurcation analysis of the model. For example, we vary the input \(I_{ext}\) between 0 and 1 and see how the system changes its stability.

analyzer = bp.analysis.Bifurcation2D(
  model,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  target_pars={'Iext': [0., 1.]},
  resolutions={'Iext': 0.002},
)

# "num_rank" specifies the number of initial poinits for
# fixed point optimization under a set of parameters
analyzer.plot_bifurcation(num_rank=10)

# show figure
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 5000 candidates
I am trying to filter out duplicate fixed points ...
	Found 500 fixed points.
_images/91bdaeef5979d7a9a02cd3e71ec4560a1e7f71cda3b67b96ccf6c8c27226af09.png _images/0ecc33e3d1c9af65305da753599d91220116bc9ed1e1ead466651088f1070c01.png

Codimension 2 bifurcation analysis

We simultaneously change \(I_{ext}\) and the parameter \(a\).

analyzer = bp.analysis.Bifurcation2D(
    model,
    target_vars=dict(V=[-3, 3], w=[-3., 3.]),
    target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
    resolutions={'a': 0.01, 'Iext': 0.01},
)
analyzer.plot_bifurcation(num_rank=10, tol_aux=1e-9)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 50000 candidates
I am trying to filter out duplicate fixed points ...
	Found 5000 fixed points.
_images/d713d7e2adc2c991cf9d70422cf89a2ef42f4b702e2ef2a22c69e292d6544859.png _images/990151444e374d13ede63f991cef48b9a3870d630be7135c46bc5216c4201928.png

Fast-slow system bifurcation#

BrainPy also provides tools for fast-slow system bifurcation analysis: brainpy.analysis.FastSlow1D and brainpy.analysis.FastSlow2D. This method was proposed by John Rinzel (1985, 1986, 1987) [1, 2, 3], who suggested that in a fast-slow dynamical system we can treat the slow variables as bifurcation parameters and then study how different values of the slow variables affect the bifurcation of the fast subsystem.

Fast-slow bifurcation methods are very useful for analyzing bursting neurons. We illustrate this with the Hindmarsh-Rose model. The Hindmarsh–Rose model of neuronal activity aims to study the spiking-bursting behavior of the membrane potential observed in experiments with a single neuron. Its dynamics are governed by:

\[\begin{split} \begin{aligned} \frac{d V}{d t} &= y - a V^3 + b V^2 - z + I\\ \frac{d y}{d t} &= c - d V^2 - y\\ \frac{d z}{d t} &= r (s (V - V_{rest}) - z) \end{aligned} \end{split}\]

First of all, let’s define the Hindmarsh–Rose model with BrainPy.

a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9


@bp.odeint
def int_x(x, t, y, z, Isyn):
    return y - a * x ** 3 + b * x * x - z + Isyn

@bp.odeint
def int_y(y, t, x):
    return c - d * x * x - y

@bp.odeint
def int_z(z, t, x):
    return r * (s * (x - x_r) - z)

We can now start to analyze the underlying bifurcation mechanism.

analyzer = bp.analysis.FastSlow2D(
  [int_x, int_y, int_z],
  fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
  slow_vars={'z': [-5., 5.]},
  pars_update={'Isyn': 0.5},
  resolutions={'z': 0.01}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 20000 candidates
I am trying to filter out duplicate fixed points ...
	Found 1232 fixed points.
_images/e2f6b929d476e8afa7d6207280a4747832aa26596764d9681ebce3f6cb34ff71.png _images/fb5b319e2069eac66318941471d3673198595f3efc95458a0432ff0e51594c64.png

Advanced tutorial: how the analysis works#

In this section, we provide a basic tutorial to understand how brainpy.analysis.LowDimAnalyzer works.

Terminology#

Given the above FitzHugh-Nagumo model, we define an analyzer,

analyzer = bp.analysis.PhasePlane2D(
  [model.int_V, model.int_w],
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  resolutions=0.01,
)

In this instance of brainpy.analysis.LowDimAnalyzer, we use the following terminologies.

  • x_var and y_var are determined by the order in which the user sets them. If the user sets "target_vars" to "{'V': …, 'w': …}", x_var and y_var will be "V" and "w", respectively. Otherwise, if "target_vars" is "{'w': …, 'V': …}", x_var and y_var will be "w" and "V", respectively.

analyzer.x_var, analyzer.y_var
('V', 'w')
  • fx and fy are defined as differential equations of x_var and y_var respectively, i.e.,

fx is

def dV(V, t, w, Iext=0.): 
    return V - V * V * V / 3 - w + Iext

fy is

def dw(w, t, V, a=0.7, b=0.8): 
    return (V + a - b * w) / self.tau
analyzer.F_fx, analyzer.F_fy
(<CompiledFunction of <function f_without_jaxarray_return.<locals>.f2 at 0x000002240FC114C0>>,
 <CompiledFunction of <function f_without_jaxarray_return.<locals>.f2 at 0x000002240FC118B0>>)
  • int_x and int_y are defined as integral functions of the differential equations for x_var and y_var respectively.

analyzer.F_int_x, analyzer.F_int_y
(functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x000002240FC11EE0>),
 functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x000002240FC11F70>))
  • x_by_y_in_fx and y_by_x_in_fx: They denote that x_var and y_var can be separated from each other in the "fx" nullcline function. Specifically, x_by_y_in_fx or y_by_x_in_fx denotes \(x = F(y)\) or \(y = F(x)\) according to the equation \(f_x=0\). For example, in the above FitzHugh-Nagumo model, \(w\) can easily be represented by \(V\) when \(\mathrm{dV(V, t, w, I_{ext})} = 0\), i.e., y_by_x_in_fx is \(w = V - V^3 / 3 + I_{ext}\).

  • Similarly, x_by_y_in_fy (\(x=F(y)\)) and y_by_x_in_fy (\(y=F(x)\)) denote x_var and y_var can be separated from each other in “fy” nullcline function. For example, in the above FitzHugh-Nagumo model, y_by_x_in_fy is \(w= \frac{V + a}{b}\), and x_by_y_in_fy is \(V= b * w - a\).

  • x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy and y_by_x_in_fy can be set in the options argument.

Mechanism for 1D system analysis#

In order to understand the advantages and disadvantages of BrainPy's analysis toolkit, it is helpful to know the minimal mechanism of how brainpy.analysis works.

The automatic model analysis in BrainPy heavily relies on numerical optimization methods, including Brent’s method and BFGS method. For example, for the above one-dimensional system (\(\frac{dx}{dt} = \mathrm{sin}(x) + I\)), after the user sets the resolution to 0.001, we will get the evaluation points according to the variable boundary [-10, 10].

bp.math.arange(-10, 10, 0.001)
JaxArray(DeviceArray([-10.   ,  -9.999,  -9.998, ...,   9.997,   9.998,   9.999],            dtype=float64))

Then, BrainPy identifies the candidate intervals in which the roots lie. Specifically, it tries to find all intervals \([x_1, x_2]\) where \(f(x_1) \cdot f(x_2) \le 0\) for the 1D system \(\frac{dx}{dt} = f(x)\).

For example, the following two points, which have opposite signs, bound a candidate interval.

def plot_interval(x0, x1, f):
    xs = np.linspace(x0, x1, 100)
    plt.plot(xs, f(xs))
    plt.scatter([x0, x1], f(np.asarray([x0, x1])), edgecolors='r')
    plt.axhline(0)
    plt.show()
plot_interval(-0.001, 0.001, lambda x: np.sin(x))
_images/c8a4f124cdaf5e8efa90adf262a133dc8b303724c2c343bae0487fcbf329452f.png

According to the intermediate value theorem, there must be a solution between \(x_1\) and \(x_2\) when \(f(x_1) * f(x_2) \le 0\).

Based on these candidate intervals, BrainPy uses Brent's method to find the roots of \(f(x) = 0\). Further, after obtaining the value of a root, BrainPy uses automatic differentiation to evaluate its stability.
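
The following sketch mimics this two-step procedure on \(f(x) = \sin(x) + I\) with \(I\) fixed to 0, using SciPy's brentq as a stand-in for BrainPy's internal implementation: scan the grid for sign changes, refine each candidate interval with Brent's method, and classify stability from the sign of \(f'(x)\).

import numpy as np
from scipy.optimize import brentq

I = 0.
f = lambda x: np.sin(x) + I
df = lambda x: np.cos(x)            # derivative of f, used for stability

xs = np.arange(-10, 10, 0.001)      # evaluation points from the resolution setting
fs = f(xs)

# candidate intervals [x_k, x_{k+1}] where the sign of f changes
candidates = np.where(fs[:-1] * fs[1:] <= 0)[0]

for i in candidates:
    root = brentq(f, xs[i], xs[i + 1])      # refine the root with Brent's method
    stability = 'stable' if df(root) < 0 else 'unstable'
    print(f'fixed point at x={root:.4f} is {stability}')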

Overall, BrainPy’s analysis toolkit shows significant advantages and disadvantages.

Pros: Because BrainPy uses numerical methods to find roots and evaluate their stability, it does not care how complex your function is. Therefore, it applies to general problems, including any 1D and 2D dynamical systems, and even some low-dimensional (\(\ge 3\)) dynamical systems (see later sections). In particular, BrainPy's analysis toolkit is highly useful when the mathematical equations are too complex to solve analytically (for an example, please refer to the tutorial Analysis of a Decision-making Model).

Cons: The numerical methods used in BrainPy have difficulty finding fixed points that exist only at a single parameter value (e.g., the saddle at a bifurcation). Moreover, when the resolution is small, the amount of computation becomes large. Users should pay attention to designing suitable resolution settings.

Mechanism for 2D system analysis#

plot_vector_field()

Plotting the vector field is simple: we just need to evaluate each differential equation on the grid.

plot_nullcline()

Nullclines are evaluated with Brent's method. In order to get all \((x, y)\) values that satisfy fx=0 (i.e., \(f_x(x, y) = 0\)), we first fix \(y=y_0\), then apply Brent optimization to get all \(x'\) that satisfy \(f_x(x', y_0) = 0\) (alternatively, we can fix \(x\) and then optimize \(y\)). Therefore, we perform Brent optimization many times, because we iterate over all \(y\) values according to the resolution setting.
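
As a minimal sketch of this procedure (again with SciPy's brentq as a stand-in, using the FitzHugh-Nagumo \(f_V\) from above and assuming \(I_{ext}=0.8\)), for each fixed \(w_0\) on the grid we locate all \(V'\) with \(f_V(V', w_0) = 0\):

import numpy as np
from scipy.optimize import brentq

Iext = 0.8
f_V = lambda V, w: V - V ** 3 / 3 - w + Iext     # dV/dt of the FitzHugh-Nagumo model

V_grid = np.arange(-3, 3, 0.01)
nullcline = []                                   # collected (V, w) points on the fx-nullcline
for w0 in np.arange(-3, 3, 0.01):                # fix w = w0 ...
    fs = f_V(V_grid, w0)
    for i in np.where(fs[:-1] * fs[1:] <= 0)[0]: # ... and bracket the roots of f_V(V, w0) = 0
        V_root = brentq(lambda V: f_V(V, w0), V_grid[i], V_grid[i + 1])
        nullcline.append((V_root, w0))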

plot_fixed_points()

The fixed point finding in BrainPy relies on the BFGS method. First, we define an auxiliary function \(L(x, y)\):

\[ L(x, y) = f_x^2(x, y) + f_y^2(x, y). \]

\(L(x, y)\) is always non-negative. We use BFGS optimization to find all local minima. Finally, we keep the minima whose losses are smaller than \(10^{-8}\) and treat them as fixed points.
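
A minimal sketch of this idea, using scipy.optimize.minimize with the BFGS method as a stand-in, with the FitzHugh-Nagumo right-hand sides and \(a=0.7\), \(b=0.8\), \(\tau=12.5\), \(I_{ext}=0.8\) assumed from above:

import numpy as np
from scipy.optimize import minimize

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.8

def loss(p):
    V, w = p
    fx = V - V ** 3 / 3 - w + Iext      # dV/dt
    fy = (V + a - b * w) / tau          # dw/dt
    return fx ** 2 + fy ** 2            # L(V, w)

# start BFGS from a few initial points (in practice, taken from a nullcline)
candidates = [(-1.0, 0.0), (0.0, 0.5), (1.0, 1.0)]
fixed_points = []
for x0 in candidates:
    res = minimize(loss, x0, method='BFGS')
    if res.fun < 1e-8:                  # keep only minima with a tiny loss
        fixed_points.append(res.x)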

For this method, the challenge is how to choose the initial points for the optimization, especially when the parameter resolutions are small. Generally, there are four methods provided in BrainPy (a usage sketch follows the list).

  • fx-nullcline: Choose the points in “fx” nullcline as the initial points for optimization.

  • fy-nullcline: Choose the points in “fy” nullcline as the initial points for optimization.

  • nullclines: Choose both the points in “fx” nullcline and “fy” nullcline as the initial points for optimization.

  • aux_rank: For a given set of parameters, we evaluate the loss function at each point according to the resolution setting, then choose the first num_rank (default 100) points that have the smallest losses.
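
For instance, a minimal usage sketch reusing the PhasePlane2D instance pp defined above ('aux_rank' is the only option string shown elsewhere in this tutorial, and passing num_rank to this method is an assumption):

# select the "aux_rank" strategy; whether num_rank is accepted here is an assumption
pp.plot_fixed_point(select_candidates='aux_rank', num_rank=100)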

However, if users provide one of the functions x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy, or y_by_x_in_fy, things become much simpler, because we can reduce the 2D system to a 1D system and then only need to optimize the fixed points with the Brent method.

For the given FitzHugh-Nagumo model, we can set

analyzer = bp.analysis.Bifurcation2D(
    model,
    target_vars=dict(V=[-3, 3], w=[-3., 3.]),
    target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
    resolutions={'a': 0.01, 'Iext': 0.01},
    options={bp.analysis.C.y_by_x_in_fy: (lambda V, a=0.7, b=0.8: (V + a) / b)}
)
analyzer.plot_bifurcation()
analyzer.show_figure()
I am making bifurcation analysis ...
I am trying to find fixed points by brentq optimization ...
I am trying to filter out duplicate fixed points ...
	Found 5000 fixed points.
_images/d98912f65cba87bcfdb410adaa031d5bb71a2eff503cece0af90a5414d165be6.png _images/6adf334dc6d029c64ebd0d2ced47c142bc61118f89b94c3c81b9c0baae029c56.png

References#

[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.

[2] Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.

[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.

High-dimensional Analyzers#

@Chaoming Wang

It is hard to analyze high-dimensional systems, but sometimes we have to.

Here, based on numerical optimization methods, BrainPy provides brainpy.analysis.SlowPointFinder to help users find slow points (or fixed points) [1] for your high-dimensional dynamical systems.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

What are slow points?#

For the given system,

\[ \dot{x} = f(x), \]

we wish to find values \(x^∗\) around which the system is approximately linear. Using Taylor series expansion, we have

\[ f(x^* + \delta x) = f(x^*) + f'(x^*)\delta x + 1/2 \delta x f''(x^*) \delta x + \cdots \]

We want the first derivative term (i.e., the linear term) to be dominant, which means \(f(x^*) = 0\) or \(f(x^*) \approx 0\).

  • For \(f(x^*) \approx 0\) which is nonzero but small, we call the point \(x^*\) a slow point.

  • More specifically, if \(f(x^*) = 0\), then \(x^*\) is a fixed point.

How to find slow points?#

In order to find slow points, we can first define an auxiliary scalar function for the continuous system \(\dot{x} = f(x)\),

\[ p(x) = |f(x)|^2. \]

Or, if your system is discrete \(x_n = f(x_{n-1})\), the auxiliary scalar function can be defined as

\[ p(x) = |x - f(x)|^2. \]

If \(x^*\) is a slow point, \(p(x^*) \to 0\).

Then, by minimizing the scalar function \(p(x)\), we can get candidate slow points for further linearization. For the linearized system, its stability is evaluated by the eigenvalues of the Jacobian matrix.
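
A minimal sketch of this idea in plain JAX (not BrainPy's implementation): define \(p(x) = |f(x)|^2\) for a toy continuous system and run a few gradient-descent steps on it.

import jax
import jax.numpy as jnp

def f(x):                          # a toy continuous system dx/dt = f(x)
    return jnp.sin(x)

def p(x):                          # auxiliary scalar function p(x) = |f(x)|^2
    return jnp.sum(f(x) ** 2)

grad_p = jax.grad(p)

x = jnp.array([2.5])               # an initial candidate
for _ in range(200):               # plain gradient descent on p(x)
    x = x - 0.1 * grad_p(x)
# x is now close to a fixed point of f (here, x = pi)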

Here, BrainPy provides brainpy.analysis.SlowPointFinder. It receives f_cell, a function defining \(f(x)\), and f_type, which specifies the type of the function (it can be "continuous" or "discrete"). Then, brainpy.analysis.SlowPointFinder can help you:

  • optimize to find the fixed/slow points with gradient descent algorithms (find_fps_with_gd_method()) or nonlinear optimization solver (find_fps_with_opt_solver())

  • exclude any fixed points whose losses are above threshold: filter_loss()

  • exclude any non-unique fixed points according to a tolerance: keep_unique()

  • exclude any far-away “outlier” fixed points: exclude_outliers()

  • compute the Jacobian matrix for the given fixed/slow points: compute_jacobians()

Example 1: Decision Making Model#

brainpy.analysis.SlowPointFinder is aimed at finding slow/fixed points of high-dimensional systems. Of course, it can also find fixed points of low-dimensional systems. We take the 2D decision-making system as an example.

# parameters

gamma = 0.641  # Saturation factor for gating variable
tau = 0.06  # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154

JE = 0.3725  # self-coupling strength [nA]
JI = -0.1137  # cross-coupling strength [nA]
JAext = 0.00117  # Stimulus input strength [nA]

mu = 20.  # Stimulus firing rate [spikes/sec]
coh = 0.5  # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
  I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
  r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
  return - s1 / tau + (1. - s1) * gamma * r1

@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
  I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
  r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
  return - s2 / tau + (1. - s2) * gamma * r2

def step(s):
    ds1 = int_s1.f(s[0], 0., s[1])
    ds2 = int_s2.f(s[1], 0., s[0])
    return bm.asarray([ds1.value, ds2.value])

We first use brainpy.analysis.PhasePlane2D to get the standard answer.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    resolutions=0.001,
)
analyzer.plot_fixed_point(select_candidates='aux_rank', with_plot=False)
I am searching fixed points ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 100 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.28276315331459045, s2=0.40635165572166443 is a saddle node.
	#2 s1=0.013946513645350933, s2=0.6573889851570129 is a stable node.
	#3 s1=0.7004519104957581, s2=0.004864314571022987 is a stable node.

Then, let’s check whether the high-dimensional analyzer also works.

finder = bp.analysis.SlowPointFinder(f_cell=step)
finder.find_fps_with_gd_method(
    candidates=bm.random.random((1000, 2)), tolerance=1e-5, num_batch=200,
    opt_setting=dict(method=bm.optimizers.Adam,
                     lr=bm.optimizers.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(1e-5)
finder.keep_unique()
Optimizing with Adam to find fixed points:
    Batches 1-200 in 0.52 sec, Training loss 0.0576312058
    Batches 201-400 in 0.52 sec, Training loss 0.0049517932
    Batches 401-600 in 0.53 sec, Training loss 0.0007580096
    Batches 601-800 in 0.52 sec, Training loss 0.0001687836
    Batches 801-1000 in 0.51 sec, Training loss 0.0000421500
    Batches 1001-1200 in 0.52 sec, Training loss 0.0000108371
    Batches 1201-1400 in 0.52 sec, Training loss 0.0000027990
    Stop optimization as mean training loss 0.0000027990 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
    Kept 962/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
    Kept 3/962 unique fixed points with uniqueness tolerance 0.025.
finder.fixed_points
array([[0.7004518 , 0.00486438],
       [0.28276336, 0.40635186],
       [0.01394662, 0.65738887]], dtype=float32)

Yeah, the fixed points found by brainpy.analysis.PhasePlane2D and brainpy.analysis.SlowPointFinder are nearly the same.

Example 2: Continuous-attractor Neural Network#

The continuous-attractor neural network [2] proposed by Si Wu is a special model that has a line of attractors.

class CANN1D(bp.NeuGroup):
  def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-bm.pi, z_max=bm.pi, name=None):
    super(CANN1D, self).__init__(size=num, name=name)

    # parameters
    self.tau = tau  # The synaptic time constant
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
    self.rho = num / self.z_range  # The neural density
    self.dx = self.z_range / num  # The stimulus density

    # variables
    self.u = bm.Variable(bm.zeros(num))
    self.input = bm.Variable(bm.zeros(num))

    # The connection matrix
    self.conn_mat = self.make_conn(self.x)

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, u, t, Iext):
    r1 = bm.square(u)
    r2 = 1.0 + self.k * bm.sum(r1)
    r = r1 / r2
    Irec = bm.dot(self.conn_mat, r)
    du = (-u + Irec + Iext) / self.tau
    return du

  def dist(self, d):
    d = bm.remainder(d, self.z_range)
    d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
    return d

  def make_conn(self, x):
    assert bm.ndim(x) == 1
    x_left = bm.reshape(x, (-1, 1))
    x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
    d = self.dist(x_left - x_right)
    Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
    return Jxx

  def get_stimulus_by_pos(self, pos):
    return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

  def update(self, _t, _dt):
    self.u[:] = self.integral(self.u, _t, self.input)
    self.input[:] = 0.

  def cell(self, u):
    return self.derivative(u, 0., 0.)
def visualize_fixed_points(fps, plot_ids=(0,), xs=None):
  for i in plot_ids:
    if xs is None:
      plt.plot(fps[i], label=f'FP-{i}')
    else:
      plt.plot(xs, fps[i], label=f'FP-{i}')
  plt.legend()
  plt.show()
cann = CANN1D(num=512, k=0.1, A=30)

These attractors are a series of bumps. Therefore, we can initialize our candidate points with noisy bumps.

candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1)))
candidates += bm.random.normal(0., 0.01, candidates.shape)
finder = bp.analysis.SlowPointFinder(f_cell=cann.cell)
finder.find_fps_with_opt_solver(candidates)
finder.filter_loss(1e-6)
finder.keep_unique()
Optimizing to find fixed points:
    Found 629 fixed points from 629 initial points.
Excluding fixed points with squared speed above tolerance 1e-06:
    Kept 357/629 fixed points with tolerance under 1e-06.
Excluding non-unique fixed points:
    Kept 357/357 unique fixed points with uniqueness tolerance 0.025.
pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points)
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()
_images/625980dcef553628ac8bb7d208836a5b5e8a8d27b4953d5b7008576be3234bb6.png
visualize_fixed_points(finder.fixed_points, plot_ids=(10, 20, 30, 40, 50, 60, 70, 80), xs=cann.x)
_images/fb0da8e8febf92362225463b225556dc29a349634f20c7a51d4d15340739d548.png
num = 4
J = finder.compute_jacobians(finder.fixed_points[:num])
for i in range(num):
    eigval, eigvec = np.linalg.eig(np.asarray(J[i]))
    plt.figure()
    plt.scatter(np.real(eigval), np.imag(eigval))
    plt.plot([0, 0], [-1, 1], '--')
    plt.xlabel('Real')
    plt.ylabel('Imaginary')
    plt.show()
_images/e7a645d7b6c4c227d03fc592c0ff2f909c9c0304ab5f5672528d448f58db1c1c.png _images/0fa242a098027cf1bfbb9aa9265d214f0ff42847878b8fb5f2e280c45f4dcaaf.png _images/0fa242a098027cf1bfbb9aa9265d214f0ff42847878b8fb5f2e280c45f4dcaaf.png _images/0fa242a098027cf1bfbb9aa9265d214f0ff42847878b8fb5f2e280c45f4dcaaf.png

For more examples of dynamics analysis, such as analyzing the fixed points of a recurrent neural network, please see BrainPy Examples.

References#

[1] Sussillo, D. , and O. Barak . “Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks.” Neural computation 25.3(2013):626-649.

[2] Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.

Analysis of a Decision-making Model#

@Chaoming Wang

In this section, we are going to use the low-dimensional analyzers to perform phase plane and bifurcation analyses of the decision-making model proposed by Wong & Wang [1].

Decision making model#

This model considers two excitatory neural assemblies, populations 1 and 2, that compete with each other through a shared pool of inhibitory neurons. In our analysis, we use the following model equations.

Let \(r_1\) and \(r_2\) be the firing rates of populations 1 and 2. The total synaptic input current \(I_i\) and the resulting firing rate \(r_i\) of neural population \(i\) obey the following input-output relationship (\(F - I\) curve):

\[ r_i = F(I_i) = \frac{aI_i - b}{1-\exp(-d(a I_i - b))} \]

which captures the current-frequency function of a leaky integrate-and-fire neuron. The parameter values are \(a\) = 270 Hz/nA, \(b\) = 108 Hz, \(d\) = 0.154 sec.

Assume that the synaptic drive variables \(S_1\) and \(S_2\) obey

\[\begin{split} \frac{dS_1}{dt} = F(I_1)\,\gamma(1-S_1)-S_1/\tau_s\\ \frac{dS_2}{dt} = F(I_2)\,\gamma(1-S_2)-S_2/\tau_s \end{split}\]

where \(\gamma\) = 0.641. The net current into each population is given by

\[\begin{split} I_1 = J_E S_1 + J_I S_2 + I_{b1} + J_{ext}\mu_1 \\ I_2 = J_E S_2 + J_I S_1 +I_{b2} +J_{ext}\mu_2. \end{split}\]

The synaptic time constant is \(\tau_s\) = 100 ms (NMDA time constant). The synaptic coupling strengths are \(J_E\) = 0.2609 nA, \(J_I\) = -0.0497 nA, and \(J_{ext}\) = 0.00052 nA. Stimulus-selective inputs to populations 1 and 2 are governed by unitless parameters \(\mu_1\) and \(\mu_2\), respectively.

For the decision-making paradigm, the input rates \(\mu_1\) and \(\mu_2\) are determined by the stimulus coherence \(c'\) which ranges between 0 (0%) and 1 (100%):

\[\begin{split} \mu_1 =\mu_0(1+c')\\ \mu_2 =\mu_0(1-c') \end{split}\]
import brainpy as bp
import brainpy.math as bm

bp.math.enable_x64()
bp.math.set_platform('cpu')

Parameters#

gamma = 0.641  # Saturation factor for gating variable
tau = 0.1  # Synaptic time constant [sec]
a = 270.  #  Hz/nA
b = 108.  # Hz
d = 0.154  # sec

I0 = 0.3255  # background current [nA]
JE = 0.2609  # self-coupling strength [nA]
JI = -0.0497  # cross-coupling strength [nA]
JAext = 0.00052  # Stimulus input strength [nA]
Ib = 0.3255  # The background input

Model implementation#

@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
    I1 = JE * s1 + JI * s2 + Ib + JAext * mu * (1. + coh)
    r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
    return - s1 / tau + (1. - s1) * gamma * r1

@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
    I2 = JE * s2 + JI * s1 + Ib + JAext * mu * (1. - coh)
    r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
    return - s2 / tau + (1. - s2) * gamma * r2

Phase plane analysis#

The advantage of the reduced model is that we can understand what dynamical behaviors the model generates for a particular parameter set using phase-plane analysis, and then explore how this behavior changes when the model parameters are varied (bifurcation analysis).

To this end, we will use brainpy.analysis module.

We construct the phase portraits of the reduced model for different stimulus inputs (see Figure 4 and Figure 5 in (Wong & Wang, 2006) [1]).

No stimulus: \(\mu_0 =0\) Hz. In the absence of a stimulus, the two nullclines intersect with each other five times, producing five steady states, of which three are stable (attractors) and two are unstable.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 0.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 5 fixed points.
	#1 s1=0.31384529046352583, s2=0.055785349496809286 is a saddle node.
	#2 s1=0.5669871605297268, s2=0.03189141971571585 is a stable node.
	#3 s1=0.10265144582200994, s2=0.10265095098913903 is a stable node.
	#4 s1=0.05578534267632982, s2=0.3138449310808734 is a saddle node.
	#5 s1=0.03189144636489117, s2=0.5669870352865436 is a stable node.
_images/a8f024c3d3de8f13d555cf36e42ad870a8fa3883e568e8e1c0921514ca78ad8a.png

Symmetric stimulus: \(\mu_0=30\) Hz, \(c'=0\). When a stimulus is applied, the phase space of the model is reconfigured. The spontaneous state vanishes. At the same time, a saddle-type unstable steady state is created that separates the two asymmetrical attractors.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 0.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.658694222824541, s2=0.051807224657553524 is a stable node.
	#2 s1=0.42445578984858473, s2=0.42445562837314144 is a saddle node.
	#3 s1=0.05180717720080611, s2=0.6586942355713473 is a stable node.
_images/b309030431343b2f2d99adea02137976d9adb3a17a5898af6cd17cb62c95ec39.png

Biased stimulus: \(\mu_0=30\) Hz, \(c' = 0.14\) (14 % coherence). The phase space changes when a weak motion stimulus is presented. The phase space is no longer symmetrical: the attractor state s1 (correct choice) has a larger basin of attraction than attractor s2.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 0.14},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.6679776160203832, s2=0.04583015706228367 is a stable node.
	#2 s1=0.38455860789855467, s2=0.45363090352898155 is a saddle node.
	#3 s1=0.059110032802350915, s2=0.6481046659437734 is a stable node.
_images/40f79985839d788d916f1d1a94ef6d503960293d1fdfd11e1696798b40472ac6.png

Stimulus to one population only: \(\mu_0=30\) Hz, \(c'=1.\) (100 % coherence). When \(c'\) is sufficiently large, the saddle steady state annihilates with the less favored attractor, leaving only one choice attractor.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 1.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating vector fields ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 s1=0.7092805209334904, s2=0.0239636630419946 is a stable node.
_images/6a8c07a7938b5ce3e3738a7afbea80ec65631c82b4d76fdd1e48c55354a51d2c.png

Bifurcation analysis#

To see how the phase portrait of the system changes when we change the stimulus current, we will generate a bifurcation diagram for the reduced model. On the bifurcation diagram, the fixed points of the model are shown as a function of the changing parameter.

Next, we generate bifurcation diagrams for different parameters.

Fix the coherence \(c'=0\), vary the stimulus strength \(\mu_0\). See Figure 10 in (Wong & Wang, 2006) [1].

analyzer = bp.analysis.Bifurcation2D(
  model=[int_s1, int_s2],
  target_vars={'s1': [0., 1.], 's2': [0., 1.]},
  target_pars={'mu': [-30., 90.]},
  pars_update={'coh': 0.},
  resolutions={'mu': 0.2},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 30000 candidates
I am trying to filter out duplicate fixed points ...
	Found 1744 fixed points.
_images/524b4110b160fc1e071a9b907854edd7a97169e52bc187e4aca3b424aa0c3a52.png _images/1425c54549f7dfec5994c1101871cd8830c09e79dcfd8a68c8c707fbb8e4a936.png

Fix the stimulus strength \(\mu_0 = 30\) Hz, vary the coherence \(c'\).

analyzer = bp.analysis.Bifurcation2D(
  model=[int_s1, int_s2],
  target_vars={'s1': [0., 1.], 's2': [0., 1.]},
  target_pars={'coh': [0., 1.]},
  pars_update={'mu': 30.},
  resolutions={'coh': 0.005},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 10000 candidates
I am trying to filter out duplicate fixed points ...
	Found 475 fixed points.
_images/313efdf2ba974c8be8e23a253bad3c45d5c224a5784ab179a1d869beeb4c593e.png _images/04c53b0e20d99e55edb846e1a9e250388b22768d371934d03af1cf9209c0b93f.png

References#

[1] Wong K-F and Wang X-J (2006). A recurrent network mechanism for time integration in perceptual decisions. J. Neurosci 26, 1314-1328.

Numerical Solvers for Ordinary Differential Equations#

@Chaoming Wang @Xiaoyu Chen

The brain modeling toolkit provided in BrainPy is focused on differential equations, and solving differential equations is the essence of neurodynamics simulation. Exact algebraic solutions are only available for low-order differential equations. For coupled, high-dimensional, nonlinear brain dynamical systems, we need to resort to numerical methods.

This section will illustrate how to define ordinary differential equations (ODEs) and how to define numerical integration methods for ODEs in BrainPy.

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

bm.set_platform('cpu')

%matplotlib inline

How to define ODE functions?#

BrainPy provides a convenient and intuitive way to define ODE systems. For the ODEs

\[\begin{split} {dx \over dt} = f_1(x, t, y, p_1)\\ {dy \over dt} = g_1(y, t, x, p_2) \end{split}\]

we can define them in a Python function:

def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

where t denotes the current time, x and y (passed before t) denote the dynamical variables, and p1 and p2 (after t) denote the parameters of the system. In the function body, the derivatives f1 and g1 can be customized to the user's needs. Finally, the corresponding derivatives dx and dy are returned in the same order as the variables appear in the function arguments.

Each variable can be a scalar (var_type = bp.integrators.SCALAR_VAR), a vector/matrix (var_type = bp.integrators.POP_VAR), or a system (var_type = bp.integrators.SYSTEM_VAR). "System" means that the argument x denotes an array of variables. Taking the above example as a demonstration again, we can redefine it as:

def diff(xy, t, p1, p2):
    x, y = xy
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return bm.array([dx, dy])

How to define the numerical integration for ODEs?#

After the definition of ODE functions, it is very easy to define the numerical integration for these functions. We just need to put a decorator bp.odeint above the ODE function.

@bp.odeint
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

After being wrapped by bp.odeint, the function becomes an instance of ODEIntegrator.

isinstance(diff, bp.ode.ODEIntegrator)
True

bp.odeint receives several arguments:

  • "method": A string used to specify the numerical method for integrating the ODE functions. The default method is Euler.

diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x17fe2b8da30>
  • “dt”: A float, used to set the default numerical precision. The default “dt” is 0.1.

diff.dt
0.1
  • "show_code": A bool indicating whether to show the numerical integration code. Let's take the Euler method and the RK4 method as illustrative examples.

@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  x_new = x + dx_k1 * dt * 1
  y_new = y + dy_k1 * dt * 1
  return x_new, y_new

{'f': <function diff at 0x0000017FE2BABEE0>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x17fe2bc7700>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  k2_x_arg = x + dt * dx_k1 * 0.5
  k2_y_arg = y + dt * dy_k1 * 0.5
  k2_t_arg = t + dt * 0.5
  dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
  k3_x_arg = x + dt * dx_k2 * 0.5
  k3_y_arg = y + dt * dy_k2 * 0.5
  k3_t_arg = t + dt * 0.5
  dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
  k4_x_arg = x + dt * dx_k3
  k4_y_arg = y + dt * dy_k3
  k4_t_arg = t + dt
  dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
  x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
  y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
  return x_new, y_new

{'f': <function diff at 0x0000017FE2BCA430>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe2bc7e80>

Two Examples#

Example 1: FitzHugh–Nagumo model#

Now, let's take the well-known FitzHugh–Nagumo model as an example to illustrate how to define ODE solvers for brain modeling. The FitzHugh–Nagumo (FHN) model has two dynamical variables, which are governed by the following equations:

\[\begin{split} \begin{split} \tau {\dot {w}}&=v+a-bw\\ {\dot {v}} &=v-{\frac {v^{3}}{3}}-w+I_{\rm {ext}} \end{split} \end{split}\]

For this FHN model, we can code it in BrainPy like this:

@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
    dw = (V + a - b * w) / tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

After defining the numerical solver, the solution of the ODE system over a given time span can easily be computed. For example, with the parameters

a = 0.7;   b = 0.8;   tau = 12.5;   Iext = 1.

the solution of the FHN model between 0 and 100 ms can be approximated by

hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
    V, w = integral(V, w, t, Iext, a, b, tau)
    hist_V.append(V)

plt.plot(hist_times, hist_V)
plt.show()
_images/52d96ce30b228c8a4b88d88e5d52a49e8b34aeb17bb3f92959c9ea0fcbb7f584.png

This manual loop in Python code is usually slow. In BrainPy, we provide a structural runner for integrators, brainpy.integrators.IntegratorRunner, which can benefit from JIT compilation.

runner = bp.integrators.IntegratorRunner(
    integral,
    monitors=['V'],
    inits=dict(V=0., w=0.),
    args=dict(a=a, b=b, tau=tau, Iext=Iext),
    dt=0.01
)
runner.run(100.)

plt.plot(runner.mon.ts, runner.mon.V)
plt.show()
_images/090f3bfcbfa718656a050c620e345d6a0a02b1c3a51c8f61b7dd18b5deb959dc.png

Example 2: Hodgkin–Huxley model#

Another, more complex example is the classical Hodgkin–Huxley neuron model. In the HH model, four dynamical variables (V, m, n, h) are used to model the initiation and propagation of the action potential. Specifically, they are governed by the following equations:

\[\begin{split} \begin{aligned} C_{m} \frac{d V}{d t} &=-\bar{g}_{\mathrm{K}} n^{4}\left(V-V_{K}\right)- \bar{g}_{\mathrm{Na}} m^{3} h\left(V-V_{N a}\right)-\bar{g}_{l}\left(V-V_{l}\right)+I_{s y n} \\ \frac{d m}{d t} &=\alpha_{m}(V)(1-m)-\beta_{m}(V) m \\ \frac{d h}{d t} &=\alpha_{h}(V)(1-h)-\beta_{h}(V) h \\ \frac{d n}{d t} &=\alpha_{n}(V)(1-n)-\beta_{n}(V) n \end{aligned} \end{split}\]

In BrainPy, such a dynamical system can be coded as:

@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C

    return dVdt, dmdt, dhdt, dndt

Same as the FHN model, we can also integrate the HH model in the given parameters and time interval:

Iext = 10.;   ENa = 50.;   EK = -77.;   EL = -54.387
C = 1.0;      gNa = 120.;  gK = 36.;    gL = 0.03
runner = bp.integrators.IntegratorRunner(
    integral,
    monitors=list('Vmhn'),
    inits=[0., 0., 0., 0.],
    args=dict(Iext=Iext, gNa=gNa, ENa=ENa, gK=gK, EK=EK, gL=gL, EL=EL, C=C),
    dt=0.01
)
runner.run(100.)

plt.subplot(211)
plt.plot(runner.mon.ts, runner.mon.V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(runner.mon.ts, runner.mon.m, label='m')
plt.plot(runner.mon.ts, runner.mon.h, label='h')
plt.plot(runner.mon.ts, runner.mon.n, label='n')
plt.legend()
plt.show()
_images/7f57697e26ae1a3f40b139181e9f7c68bbd57956934726c5f0ce394dcb528ffd.png

Provided ODE Numerical Solvers#

BrainPy provides several types of numerical methods for ODEs, including explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler methods.

1. Explicit Runge-Kutta (RK) methods for ODEs#

The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are listed in the following table:

Methods

Keywords

Euler

euler

Midpoint

midpoint

Heun’s second-order method

heun2

Ralston’s second-order method

ralston2

RK2

rk2

RK3

rk3

RK4

rk4

Heun’s third-order method

heun3

Ralston’s third-order method

ralston3

Third-order Strong Stability Preserving Runge-Kutta

ssprk3

Ralston’s fourth-order method

ralston4

Runge-Kutta 3/8-rule fourth-order method

rk4_38rule

Users can utilize these methods by specifying the method option in brainpy.odeint() with their corresponding keyword. For example:

@bp.odeint(method='rk4')
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4950d60>

Or, you can directly instantiate your favorite integrator:

@bp.ode.RK4
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4950670>
def derivative(v, t, p):
    # do something
    return v

int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x17fe4b3af70>

2. Adaptive Runge-Kutta (RK) methods for ODEs#

The second category of ODE numerical support is the adaptive RK methods. The difference from the explicit RK methods is that adaptive methods produce an estimate of the local truncation error in a single Runge-Kutta step; this error can then be used to adaptively control the numerical step size. Specifically, if \(error > tol\), \(dt\) is replaced with \(dt_{new}\) and the step is repeated. Therefore, adaptive RK methods allow a variable step size. BrainPy provides the following adaptive RK methods:

Methods

keywords

Runge–Kutta–Fehlberg 4(5)

rkf45

Runge–Kutta–Fehlberg 1(2)

rkf12

Dormand–Prince method

rkdp

Cash–Karp method

ck

Bogacki–Shampine method

bs

Heun–Euler method

heun_euler

By default, the above methods are not adaptive, unless users set the keyword adaptive=True in brainpy.odeint(). When the adaptive RK methods are used for numerical integration, the instantaneously adjusted step size dt is appended to the function arguments. Moreover, the tolerance tol for step-size adjustment can also be modified. Let's take the Lorenz system as an example:

# adaptively adjust step-size

@bm.jit
@bp.odeint(method='rkf45', 
           adaptive=True, # active the "adaptive" option
           tol=0.001) # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
    # should provide one more argument "dt" when using the adaptive rk method
    x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)  
    hist_x.append(x.value)
    hist_y.append(y.value)
    hist_z.append(z.value)
    hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()
_images/43046aae1fc76bfe29db5f9f4c7e192463de073427b6421fcb750303049cef1e.png _images/4fd7e2fb1ecb541162d914cc77c9a3af236beb402f52ff06d81cb78b2ac273d6.png

3. Exponential Euler methods for ODEs#

Finally, BrainPy provides exponential integrators for ODEs. For your ODE systems, we highly recommend the Exponential Euler method. The Exponential Euler method provided in BrainPy uses automatic differentiation to find the linear part.

Methods

keywords

Exponential Euler

exp_euler

Let’s take a linear system as the theoretical demonstration,

\[ {dy \over dt} = A - By \]

the exponential Euler scheme is given by:

\[ y(t+dt) = y(t) e^{-B*dt} + {A \over B}(1 - e^{-B*dt}) \]

As you can see, for such a linear system the exponential Euler scheme reproduces the exact solution.
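
As a quick check (a standalone sketch with illustrative constants, not BrainPy internals), we can compare one exponential-Euler step with the exact solution of \({dy}/{dt} = A - By\):

import numpy as np

A, B, dt, y0 = 2.0, 0.5, 0.5, 1.0

# one exponential Euler step
y_exp = y0 * np.exp(-B * dt) + A / B * (1 - np.exp(-B * dt))

# exact solution y(t) = A/B + (y0 - A/B) * exp(-B t), evaluated at t = dt
y_exact = A / B + (y0 - A / B) * np.exp(-B * dt)

print(np.isclose(y_exp, y_exact))  # True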

However, using the Exponential Euler method requires us to write each derivative function separately. Otherwise, the automatic differentiation may lead to wrong results.

Interestingly, the computationally expensive Hodgkin–Huxley neuron model is a linear-like ODE system. You will find that by using the Exponential Euler method, the numerical step can be greatly enlarged to save computation time.

\[\begin{split} \begin{aligned} C_{m}{\frac {d V}{dt}}&= -\left[{\bar {g}}_{\text{K}}n^{4} + {\bar {g}}_{\text{Na}}m^{3}h + {\bar {g}}_{l} \right] V +{\bar {g}}_{\text{K}}n^{4} V_{K} + {\bar {g}}_{\text{Na}}m^{3}h V_{Na} + {\bar {g}}_{l} V_{l} + I_{syn} \\ {\frac {dm}{dt}} &= \left[-\alpha _{m}(V)-\beta _{m}(V)\right]m + \alpha _{m}(V) \\ {\frac {dh}{dt}} &= \left[-\alpha _{h}(V)-\beta _{h}(V)\right]h + \alpha _{h}(V) \\ {\frac {dn}{dt}} &= \left[-\alpha _{n}(V)-\beta _{n}(V)\right]n + \alpha _{n}(V) \\ \end{aligned} \end{split}\]
Iext=10.;   ENa=50.;   EK=-77.;   EL=-54.387
C=1.0;      gNa=120.;  gK=36.;    gL=0.03
def dm(m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    return dmdt
def dh(h, t, V):
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    return dhdt
def dn(n, t, V):
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    return dndt
def dV(V, t, m, h, n, Iext):
    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C
    return dVdt

Although we define the HH differential equations as separate functions, we can numerically integrate them jointly by relying on brainpy.JointEq.

hh_derivative = bp.JointEq([dV, dm, dh, dn])
def run(method, Iext=10., dt=0.1):
    integral = bp.odeint(hh_derivative, method=method)

    runner = bp.integrators.IntegratorRunner(
        integral,
        monitors=list('Vmhn'),
        inits=[0., 0., 0., 0.],
        args=dict(Iext=Iext),
        dt=dt
    )
    runner.run(100.)

    plt.subplot(211)
    plt.plot(runner.mon.ts, runner.mon.V, label='V')
    plt.legend()
    plt.subplot(212)
    plt.plot(runner.mon.ts, runner.mon.m, label='m')
    plt.plot(runner.mon.ts, runner.mon.h, label='h')
    plt.plot(runner.mon.ts, runner.mon.n, label='n')
    plt.legend()
    plt.show()

Euler method: it is unable to complete the integration when the time step is a bit larger.

run('euler', Iext=10, dt=0.02)
_images/de222923360f9d24d38120ecca601c725c8b6f89f84a54d14bcc238704d3a69f.png
run('euler', Iext=10, dt=0.1)
_images/e0679241c36ea0361f4aa18057dbafc4f8dd31fde09246d843079c3d76e7d7de.png

RK4 method: better than the Euler method, but it still requires the time step to be small.

run('rk4', Iext=10, dt=0.1)
_images/57d061edd9d84dbac988f76879a833b30e2a9dcf3ab3b2f568b869acfcf79fea.png
run('rk4', Iext=10, dt=0.2)
_images/89c4310055f371026b0de4035ba6611279fd8e23c2ee12a0c172a8cec31429e3.png

Exponential Euler method: it allows a larger time step and generates accurate results.

run('exp_euler', Iext=10, dt=0.2)
_images/f46591828be3772abc24664a2f6fba87bd22b6edf2be8ee914f2b901a28f292c.png

Numerical Solvers for Stochastic Differential Equations#

@Chaoming Wang

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and the exponential Euler method for SDE numerical integration.

import brainpy as bp
import matplotlib.pyplot as plt

%matplotlib inline

How to define SDE functions?#

A one-dimensional stochastic differential equation (SDE) with scalar Wiener noise is given by

\[ \begin{aligned} d X_{t}&=f\left(X_{t}, t, p_1\right) d t+g\left(X_{t}, t, p_2\right) d W_{t} \quad (1) \end{aligned} \]

where \(X_t = X(t)\) is the realization of a stochastic process or random variable, \(f(X_t, t)\) is the drift coefficient, \(g(X_t, t)\) denotes the diffusion coefficient, and the stochastic process \(W_t\) is called the Wiener process.

For this SDE system, we can define two Python functions, \(f\) and \(g\), to represent it.

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df

Same as with ODE functions, the arguments before \(t\) denote the random variables, while the arguments defined after \(t\) represent the parameters. For an SDE function with scalar noise, the return data \(dg\) and \(df\) should have the same size. For example, \(df \in R^d, dg \in R^d\).
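As a concrete scalar-noise illustration (a minimal sketch added here, not part of the original tutorial), consider an Ornstein-Uhlenbeck process \(dX_t = -\theta X_t \, dt + \sigma \, dW_t\), where \(\theta\) and \(\sigma\) are illustrative parameters:

import brainpy as bp
import brainpy.math as bm

def ou_f(x, t, theta, sigma):
    return -theta * x                  # drift, df has the same shape as x

def ou_g(x, t, theta, sigma):
    return sigma * bm.ones_like(x)     # diffusion, dg has the same shape as x (scalar noise)

ou = bp.sdeint(f=ou_f, g=ou_g, method='euler')

runner = bp.integrators.IntegratorRunner(ou, monitors=['x'], inits=[1.],
                                         args=dict(theta=1., sigma=0.2), dt=0.01)
runner.run(10.)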

However, for a more general SDE system, it usually has multi-dimensional driving Wiener process:

\[ dX_t=f(X_t)dt+\sum_{\alpha=1}^{m}g_{\alpha }(X_t)dW_t ^{\alpha} \]

For such an \(m\)-dimensional noise system, the coding schema is the same as in the scalar case, except that the returned data \(dg\) has one more dimension. For example, \(df \in R^{d}, dg \in R^{m \times d}\).

How to define the numerical integration for SDEs?#

Before the numerical integration of SDE functions, we should distinguish two kinds of SDE integrals. For the integration of system (1), we can get

\[ \begin{aligned} X_{t}&=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} \quad (2) \end{aligned} \]

In the 1940s, the Japanese mathematician K. Ito defined a type of integral called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol “\(\circ\)” to distinguish it from the former Ito integral.

\[\begin{split} \begin{aligned} d X_{t} &=f\left(X_{t}, t\right) d t+g\left(X_{t}, t\right) \circ d W_{t} \\ X_{t} &=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) \circ d W_{s} \quad (3) \end{aligned} \end{split}\]

The difference between the Ito integral (2) and the Stratonovich integral (3) lies in the second integral term, which can be written in a general form as

\[\begin{split} \begin{split} \int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} &=\lim _{h \rightarrow 0} \sum_{k=0}^{m-1} g\left(X_{\tau_{k}}, \tau_{k}\right)\left(W\left(t_{k+1}\right)-W\left(t_{k}\right)\right) \\ \mathrm{where} \quad & h = t_{k+1} - t_{k} \\ & \tau_k = (1-\lambda)t_k +\lambda t_{k+1} \end{split} \end{split}\]
  • In the stochastic integral of the Ito SDE, \(\lambda=0\), thus \(\tau_k=t_k\);

  • In the definition of the Stratonovich integral, \(\lambda=0.5\), thus \(\tau_k=(t_{k+1} + t_{k}) / 2\).
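As a reminder (this conversion formula is added here for clarity and is not part of the original text), for a scalar SDE the two interpretations can be converted into each other by a drift correction: the Stratonovich equation \(dX_t = f\,dt + g \circ dW_t\) is equivalent to the Ito equation

\[ d X_{t}=\left[f\left(X_{t}, t\right)+\frac{1}{2} g\left(X_{t}, t\right) \frac{\partial g\left(X_{t}, t\right)}{\partial x}\right] d t+g\left(X_{t}, t\right) d W_{t} \]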

In BrainPy, these two different integrals can be easily implemented. All users need to do is to provide the keyword intg_type in bp.sdeint. intg_type can be “bp.integrators.STRA_SDE” or “bp.integrators.ITO_SDE” (default). Also, the different types of Wiener process can be easily distinguished by the wiener_type keyword, which can be “bp.integrators.SCALAR_WIENER” (default) or “bp.integrators.VECTOR_WIENER”.

Now, let’s numerically integrate the SDE (1) by the Ito way with the Milstein method:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

Or, it can be expressed as:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, g=g_part, method='milstein')

However, if you want to numerically integrate an SDE with a multi-dimensional Wiener process in the Stratonovich sense, you can code it like this:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(m, d)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, 
                     g=g_part, 
                     method='milstein', 
                     intg_type=bp.integrators.STRA_SDE, 
                     wiener_type=bp.integrators.VECTOR_WIENER)

Example: Noisy Lorenz system#

Here, let’s demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:

\[\begin{split} \begin{aligned} \frac{d x}{dt}&=\sigma(y-x) + p x \xi_x \\ \frac{d y}{dt}&=x(\rho-z)-y + p y \xi_y\\ \frac{d z}{dt}&=x y-\beta z + p z \xi_z \end{aligned} \end{split}\]
sigma = 10; beta = 8/3; 
rho = 28;   p = 0.1

def lorenz_g(x, y, z, t):
    return p * x, p * y, p * z

def lorenz_f(x, y, z, t):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz

lorenz = bp.sdeint(f=lorenz_f, 
                   g=lorenz_g, 
                   intg_type=bp.integrators.ITO_SDE,
                   wiener_type=bp.integrators.SCALAR_WIENER)

To run this integrator, we use brainpy.integrators.IntegratorRunner, which can JIT compile the model to gain impressive speed.

runner = bp.integrators.IntegratorRunner(
    lorenz,
    monitors=['x', 'y', 'z'],
    inits=[1., 1., 1.],
    dt=0.001
)
runner.run(50.)

fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
_images/ea0d22eacd22b758f77daefdcc1df7bd5ab3f85839f3cd9a351cb19a01a4093c.png

We can also rewrite the above differential equation as a JointEq of separate equations, so that the Exponential Euler method can be applied.

dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
lorenz_f = bp.JointEq(dx, dy, dz)
lorenz = bp.sdeint(f=lorenz_f,
                   g=lorenz_g,
                   intg_type=bp.integrators.ITO_SDE,
                   wiener_type=bp.integrators.SCALAR_WIENER,
                   method='exp_euler')

runner = bp.integrators.IntegratorRunner(
    lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.], dt=0.001
)
runner.run(50.)

plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
_images/6cf6811e512949089e83322affe4450e757fd512f60d6891148606f5b94991f6.png

Supported SDE Numerical Methods#

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.

| Methods | Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support |
| --- | --- | --- | --- | --- | --- |
| Strong SRK scheme: SRI1W1 | srk1w1_scalar | Yes | | Yes | |
| Strong SRK scheme: SRI2W1 | srk2w1_scalar | Yes | | Yes | |
| Strong SRK scheme: KlPl | KlPl_scalar | Yes | | Yes | |
| Euler method | euler | Yes | Yes | Yes | Yes |
| Heun method | heun | | Yes | Yes | Yes |
| Derivative-free Milstein | milstein | Yes | Yes | Yes | Yes |
| Exponential Euler | exp_euler | Yes | | Yes | Yes |

Numerical Solvers for Fractional Differential Equations#

@Chaoming Wang

import brainpy as bp
import brainpy.math as bm

import matplotlib.pyplot as plt

Fractional differential equations have several definitions. They can be defined in a variety of different ways that often do not all lead to the same result, even for smooth functions. In neuroscience, we usually use the following two definitions:

  • Grünwald-Letnikov derivative

  • Caputo fractional derivative

See Fractional calculus - Wikipedia for more details.

Methods for Caputo FDEs#

For a given fractional differential equation

\[ \frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t) \]

where the fractional order satisfies \(0<\alpha\le 1\), BrainPy provides two kinds of methods:

  • Euler method - brainpy.fde.CaputoEuler

  • L1 schema integration - brainpy.fde.CaputoL1Schema
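For reference (this definition is added here and is not part of the original text), the Caputo fractional derivative of order \(0<\alpha<1\) appearing on the left-hand side is defined as

\[ D^{\alpha} x(t)=\frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{x^{\prime}(s)}{(t-s)^{\alpha}} d s \]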

brainpy.fde.CaputoEuler#

brainpy.fde.CaputoEuler provides a one-step Euler method for integrating Caputo fractional differential equations.

Given a fractional-order Qi chaotic system

\[\begin{split} \left\{\begin{array}{l} D^{\alpha} x_{1}=a\left(x_{2}-x_{1}\right)+x_{2} x_{3} \\ D^{\alpha} x_{2}=c x_{1}-x_{2}-x_{1} x_{3} \\ D^{\alpha} x_{3}=x_{1} x_{2}-b x_{3} \end{array}\right. \end{split}\]

we can solve the equation system by:

a, b, c = 35, 8/3, 80

def qi_system(x, y, z, t):
    dx = -a*x + a*y + y*z
    dy = c*x - y - x*z
    dz = -b*z + x*y
    return dx, dy, dz
dt = 0.001
duration = 50
inits = [0.1, 0.2, 0.3]

# The numerical integration of an FDE needs to know all
# history information. Therefore, we need to provide
# the total number of simulation steps "num_step" to save
# all history values.
integrator = bp.fde.CaputoEuler(qi_system,
                                alpha=0.98,  # fractional order
                                num_step=int(duration/dt),
                                inits=inits)

runner = bp.integrators.IntegratorRunner(integrator,
                                         monitors=list('xyz'),
                                         inits=inits,
                                         dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()
_images/f95d9de76c81a6572ee75a67e9a656604aa0da2d36a24467a7ab77b0de14bfbe.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/910e2f23eb5aca7388f5db5e36f5c1ba1e1c26e9f5096802d7acdeea6b9d46b9.png

brainpy.fde.CaputoL1Schema#

brainpy.fde.CaputoL1Schema is another commonly used method to integrate Caputo derivative equations. Let’s try it with a fractional-order Lorenz system, which is given by:

\[\begin{split} \left\{\begin{array}{l} D^{\alpha} x=a\left(y-x\right) \\ D^{\alpha} y= x (b - z) - y \\ D^{\alpha} z = x y - c z \end{array}\right. \end{split}\]
a, b, c = 10, 28, 8 / 3

def lorenz_system(x, y, z, t):
    dx = a * (y - x)
    dy = x * (b - z) - y
    dz = x * y - c * z
    return dx, dy, dz
dt = 0.001
duration = 50
inits = [1, 2, 3]

integrator = bp.fde.CaputoL1Schema(lorenz_system,
                                   alpha=0.99,  # fractional order
                                   num_step=int(duration/dt),
                                   inits=inits)

runner = bp.integrators.IntegratorRunner(integrator,
                                         monitors=list('xyz'),
                                         inits=inits,
                                         dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()
_images/60f7408bffdfc5111c92580e598c05455082d516406e9952411c7b060400efb3.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/bf44ff13321d76b798f263ecc1d1eefbf9983f3eb019737d5feaa378fc106bcb.png

Methods for Grünwald-Letnikov FDEs#

The Grünwald-Letnikov FDE is another commonly used type in neuroscience. Here, we provide an efficient computation method based on the short-memory principle of the Grünwald-Letnikov method.

brainpy.fde.GLShortMemory#

brainpy.fde.GLShortMemory is highly efficient because it does not require an infinite memory length for the numerical solution. Due to the decay property of the coefficients, brainpy.fde.GLShortMemory implements a limited memory length to reduce the computational time. Specifically, it only relies on a memory window of num_memory length. As the width of the memory window increases, the accuracy of the numerical approximation increases.
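To sketch why a limited window works (this reasoning step is added here for clarity and is not part of the original text), the Grünwald-Letnikov derivative can be discretized with step size \(h\) as

\[ D^{\alpha} y\left(t_{n}\right) \approx h^{-\alpha} \sum_{k=0}^{n}(-1)^{k}\binom{\alpha}{k} y\left(t_{n-k}\right) \]

Because the binomial coefficients \((-1)^{k}\binom{\alpha}{k}\) decay quickly with \(k\), the sum can be truncated to the most recent terms, which is exactly the short-memory principle implemented by brainpy.fde.GLShortMemory.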

Here, we demonstrate it by using a fractional-order Chua system, which is defined as

\[\begin{split} \left\{\begin{array}{l} D^{\alpha_{1}} x=a\left[y- (1+m_1) x-\frac{1}{2}(m_0-m_1)\left(|x+1|-|x-1|\right)\right] \\ D^{\alpha_{2}} y=x-y+z \\ D^{\alpha_{3}} z=-b y-c z \end{array}\right. \end{split}\]
a, b, c = 10.725, 10.593, 0.268
m0, m1 = -1.1726, -0.7872

def chua_system(x, y, z, t):
    f = m1*x+0.5*(m0-m1)*(abs(x+1)-abs(x-1))
    dx = a*(y-x-f)
    dy = x - y + z
    dz = -b*y - c*z
    return dx, dy, dz
dt = 0.001
duration = 200
inits = [0.2, -0.1, 0.1]

integrator = bp.fde.GLShortMemory(chua_system,
                                  alpha=[0.93, 0.99, 0.92],
                                  num_step=1000,
                                  inits=inits)

runner = bp.integrators.IntegratorRunner(integrator,
                                         monitors=list('xyz'),
                                         inits=inits,
                                         dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/e540ca614b30533afb874d0004f52bbd136e804cb47055353a9e82720bf922ea.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.y, runner.mon.z)
plt.show()
_images/d4f35745b711c56911206cbaceb476ce92a7a115d370a2c32151423c39137d0e.png

Actually, the coefficients used in brainpy.fde.GLShortMemory can be inspected through:

plt.figure(figsize=(10, 6))
coef = integrator.binomial_coef
alphas = bm.as_numpy(integrator.alpha)

plt.subplot(211)
for i in range(3):
    plt.plot(coef[:, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.subplot(212)
for i in range(3):
    plt.plot(coef[:10, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.show()
_images/fe2131e6d2e8caad99147e218dfb546631034522edc37cf8ede2bb968afabd67.png

As you see, the coefficients decay very quickly!

Further reading#

For more examples of how to use the numerical solvers for fractional differential equations defined in BrainPy, please see:

Numerical Solvers for Delay Differential Equations#

@Chaoming Wang

Delays are very often encountered in real-world systems, such as automatic control, biology, economics, and long transmission lines. Delay differential equations (DDEs) are used to describe these dynamical systems.

Delay differential equations (DDEs) are a type of differential equation in which the derivative at a certain time is given in terms of the values of the function at previous times.

Let’s take delay ODEs as an example. The simplest constant-delay equations have the form

\[ y'(t) = f(t, y(t), y(t-\tau_1), y(t-\tau_2),\ldots, y(t-\tau_k)) \]

where the time delays (lags) \(\tau_j\) are positive constants.

For neutral-type DDEs, delays appear in the derivative terms:

\[ y'(t) = f(t, y(t), y'(t-\tau_1), y'(t-\tau_2),\ldots, y'(t-\tau_k)) \]

More generally, state dependent delays may depend on the solution, that is \(\tau_i = \tau_i (t,y(t))\).

In BrainPy, we support delay differential equations based on delay variables. Specifically, for state-dependent delays, we have:

  • brainpy.math.TimeDelay

  • brainpy.math.LengthDelay

For neutral-type delays, we use:

  • brainpy.math.NeuTimeDelay

  • brainpy.math.NeuLenDelay

import brainpy as bp
import brainpy.math as bm

import matplotlib.pyplot as plt

Delay variables#

For an ODE system, the numerical methods need to know its initial condition \(y(t_0)=y_0\) and its derivative rule \(y'(t_0) = y'_0\). However, for DDEs, it is not enough to give a set of initial values for the function and its derivatives at \(t_0\), but one must give a set of functions to provide the historical values for \(t_0 - max(\tau) \leq t \leq t_0\).

Therefore, you need some delay variables to wrap the variable delays. brainpy.math.TimeDelay can be used to define delay variables which depend on states, and brainpy.math.NeuTimeDelay is used to define delay variables which depend on the derivative.

d = bm.TimeDelay(bm.zeros(2), delay_len=10, dt=1, t0=0, before_t0=lambda t: t)
d(0.)
DeviceArray([0., 0.], dtype=float32)
d(-0.5)
DeviceArray([-0.5, -0.5], dtype=float32)

Requesting a time outside the interval \([t_0 - \mathrm{delay\_len}, t_0]\) will cause an error.

try:
    d(0.1)
except Exception as e:
    print(e)
!!! Error in TimeDelay: 
The request time should be less than the current time 0. But we got 0.10000000149011612 > 0.

Delay ODEs#

Here we illustrate how to numerically integrate delay ODEs with several examples. Before that, we define a general function to simulate a delay ODE.

def delay_odeint(duration, eq, args=None, inits=None,
                 state_delays=None, neutral_delays=None,
                 monitors=('x',), method='euler', dt=0.1):
    # define integrators of ODEs based on `brainpy.odeint`
    dde = bp.odeint(eq,
                    state_delays=state_delays,
                    neutral_delays=neutral_delays,
                    method=method)
    # define IntegratorRunner
    runner = bp.integrators.IntegratorRunner(dde,
                                             args=args,
                                             monitors=monitors,
                                             dt=dt,
                                             inits=inits)
    runner.run(duration)
    return runner.mon

Example #1: First-order DDE with one constant delay and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t-1) \]

where the delay is 1 s. The example compares the solutions of three different cases using three different constant history functions:

  • Case #1: \(\phi(t)=-1\)

  • Case #2: \(\phi(t)=0\)

  • Case #3: \(\phi(t)=1\)

def equation(x, t, xdelay):
    return -xdelay(t-1)

case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round')
case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=0., interp_method='round')
case3_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=1., interp_method='round')
case1 = delay_odeint(20., equation,  args={'xdelay': case1_delay},
                     state_delays={'x': case1_delay}) # delay for variable "x"
case2 = delay_odeint(20., equation, args={'xdelay': case2_delay}, state_delays={'x': case2_delay})
case3 = delay_odeint(20., equation, args={'xdelay': case3_delay}, state_delays={'x': case3_delay})
fig, axs = plt.subplots(3, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1)$")

axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')
 
axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=0$')

axs[2].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[2].set_title('$ihf(t)=1$')

plt.show()
_images/9e1f1ab4cf06d563244d281a92a9c418667f4f3a9a1b002800d2ec55bae18bf6.png

Example #2: First-order DDE with one constant delay and a non-constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t-2) \]

where the delay is 2 s. The example compares the solutions of four different cases using two different non-constant history functions and two different intervals of \(t\):

  • Case #1: \(\phi(t)=e^{-t} - 1, t \in [0, 4]\)

  • Case #2: \(\phi(t)=e^{t} - 1, t \in [0, 4]\)

  • Case #3: \(\phi(t)=e^{-t} - 1, t \in [0, 60]\)

  • Case #4: \(\phi(t)=e^{t} - 1, t \in [0, 60]\)

def eq(x, t, xdelay): 
    return -xdelay(t-2)

delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')
delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t)-1, dt=0.01, interp_method='round')
delay3 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')
delay4 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t)-1, dt=0.01, interp_method='round')
case1 = delay_odeint(4., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(4., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
case3 = delay_odeint(60., eq, args={'xdelay': delay3}, state_delays={'x': delay3}, dt=0.01)
case4 = delay_odeint(60., eq, args={'xdelay': delay4}, state_delays={'x': delay4}, dt=0.01)
fig, axs = plt.subplots(2, 2)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-2)$")

axs[0, 0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 4]$')

axs[0, 1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[0, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 4]$')

axs[1, 0].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[1, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 60]$')

axs[1, 1].plot(case4.ts, case4.x, color='red', linewidth=1)
axs[1, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 60]$')

plt.show()
_images/645e839181c75083bb3cbd77d10cfa9d3596b4253a494956816d613adfe9580d.png

Example #3: First-order DDE with two constant delays and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t - 1) + 0.3 y(t - 2) \]

where there are two delays, both constant and equal to 1 s and 2 s, respectively. The initial history function is also constant: \(\phi(t)=1\).

def eq(x, t): 
    return -delay(t-1) + 0.3*delay(t-2)

delay = bm.TimeDelay(bm.ones(1), 2., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(10., eq, inits=[1.], state_delays={'x': delay}, dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1) + 0.3\ y(t-2)$")

axs.plot(mon.ts, mon.x, color='red', linewidth=1)
axs.set_title('$ihf(t)=1$')

plt.show()
_images/077a1f4721ab421d4e72f18b1c32dba8f416f9f69cfbcad70d9ff8a02ee3221e.png

Example #4: System of two first-order DDEs with one constant delay and two constant initial history functions#

Let the following system of DDEs be given:

\[\begin{split} \begin{cases} y_1'(t) = y_1(t) y_2(t-0.5) \\ y_2'(t) = y_2(t) y_1(t-0.5) \end{cases} \end{split}\]

where there is only one delay, constant and equal to 0.5 s, and the initial history functions are also constant. As explained above, there must be two of them: since this is a system of two first-order equations, one history function is needed for each unknown. They are \(y_1(t)=1\) and \(y_2(t)=-1\).

def eq(x, y, t):
    dx = x * ydelay(t-0.5)
    dy = y * xdelay(t-0.5)
    return dx, dy

xdelay = bm.TimeDelay(bm.ones(1), 0.5, before_t0=1., dt=0.01, interp_method='round')
ydelay = bm.TimeDelay(-bm.ones(1), 0.5, before_t0=-1., dt=0.01, interp_method='round')

mon = delay_odeint(3., eq, inits=[1., -1], state_delays={'x': xdelay, 'y': ydelay},
             dt=0.01, monitors=['x', 'y'])
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$x'(t)=x(t) y(t-d); y'(t)=y(t) x(t-d)$")

axs.plot(mon.ts, mon.x.flatten(), color='red', linewidth=1)
axs.plot(mon.ts, mon.y.flatten(), color='blue', linewidth=1)
axs.set_title('$ihf_x(t)=1; ihf_y(t)=-1; d=0.5$')

plt.show()
_images/d51484bb86872a9f864ebde4fcb78aee9f9225a499d9b7fd99a4949e3974067d.png

Example #5: Second-order DDE with one constant delay and two constant initial history functions#

Let the following DDE be given:

\[ y''(t) = -y'(t) - 2y(t) - 0.5 y(t-1) \]

where there is only one delay, constant and equal to 1 s. Since the DDE is second order (the second derivative of the unknown function appears), there must be two history functions: one to give the values of the unknown \(y(t)\) for \(t \le 0\), and one to provide the values of the first derivative \(y'(t)\), also for \(t \le 0\).

In this example they are the following two constant functions: \(y(t)=1, y'(t)=0\).

Due to the properties of the second-order equations, the given DDE is equivalent to the following system of first-order equations:

\[\begin{split} \begin{cases} y_1'(t) = y_2(t) \\ y_2'(t) = -y_2(t) - 2y_1(t) - 0.5 y_1(t-1) \end{cases} \end{split}\]

and so the implementation falls into the case of the previous example of systems of first-order equations.

def eq(x, y, t):
    dx = y
    dy = -y - 2*x - 0.5*xdelay(t-1)
    return dx, dy

xdelay = bm.TimeDelay(bm.ones(1), 1., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(16., eq, inits=[1., 0.], state_delays={'x': xdelay}, monitors=['x', 'y'], dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y''(t)=-y'(t) - 2 y(t) - 0.5 y(t-1)$")
axs.plot(mon.ts, mon.x[:,0], color='red', linewidth=1)
axs.plot(mon.ts, mon.y[:,0], color='green', linewidth=1)
axs.set_title('$ih \, f_y(t)=1; ihf\,dy/dt(t)=0$')

plt.show()
_images/9017dbb659f2bd767df450321138eec86f00f5f875b3cce2a91abd5293c86d57.png

Example #6: First-order DDE with one non constant delay and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=y(t-\mathrm{delay}(y, t)) \]

where the delay is not constant but is given by the function \(\mathrm{delay}(y, t)=|\frac{1}{10} t \, y(\frac{1}{10} t)|\). The example compares the solutions of two different cases using two different constant history functions:

  • Case #1: \(\phi(t)=-1\)

  • Case #2: \(\phi(t)=1\)

def eq(x, t, xdelay):
    delay = abs(t*xdelay(t - 0.9 * t)/10) # a tensor with (1,)
    delay = delay[0]
    return xdelay(t-delay)

Note

Note that here we do not know the maximum length of the delay. Therefore, we can declare a fixed-length delay variable with delay_len equal to or even bigger than the running duration.

delay1 = bm.TimeDelay(bm.ones(1), 30., before_t0=-1, dt=0.01)
delay2 = bm.TimeDelay(-bm.ones(1), 30., before_t0=1, dt=0.01)
case1 = delay_odeint(30., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(30., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
fig, axs = plt.subplots(2, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=y(t-delay(y, t))$")

axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')

axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=1$')

plt.show()
_images/71d7d104667867186b4f311fb4e25d7bda74df5634319550eb2dff6581f6f648.png

Delay SDEs#

Same as for delay ODEs, state-dependent delay variables can be passed into the state_delays argument of the brainpy.sdeint function.

delay = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t)-1, dt=0.01, interp_method='round')

f = lambda x, t, xdelay: -xdelay(t-2)
g = lambda x, t, *args: 0.01

dt = 0.01
integral = bp.sdeint(f, g, state_delays={'x': delay})
runner= bp.integrators.IntegratorRunner(integral,
                                        monitors=['x'],
                                        args={'xdelay': delay},
                                        dt=dt)
runner.run(100.)

plt.plot(runner.mon.ts, runner.mon.x)
plt.show()
_images/fb6ba6d3ed8b6b9cf7fc206c9062286c0cde5f5a12e6276bfc8279598ecf948c.png

Delay FDEs#

Fractional-order delay differential equations, as a generalization of delay differential equations, provide more freedom when describing these systems. Let’s see how we can use BrainPy to accelerate the simulation of fractional-order delay differential equations.

Fractional delay differential equations have the general form:

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y(t)=f(t, y(t), y(t-\tau)), \quad t \geq \xi \\ y(t)=\phi(t), \quad t \in[\xi-\tau, \xi] \end{gathered} \end{split}\]

Lemmings’ population cycle#

The fractional order version of the four-year life cycle of a population of lemmings is given by

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y(t)=3.5 y(t)\left(1-\frac{y(t-0.74)}{19}\right), \\ y(0)=19.00001 \\ y(t)=19, t<0 \end{gathered} \end{split}\]
dt=0.05
delay = bm.TimeDelay(bm.asarray([19.00001]), 0.74, before_t0=19., dt=dt)
f = lambda y, t: 3.5 * y * (1 - delay(t-0.74)/19)
integral = bp.fde.GLShortMemory(f, alpha=0.97, inits=[19.00001], num_step=500,
                                state_delays={'y': delay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=bm.asarray([19.00001]),
                                         monitors=['y'],
                                         fun_monitors={'y(t-0.74)': lambda t, _: delay(t-0.74)},
                                         dyn_vars=delay.vars(),
                                         dt=dt)
runner.run(100.)

plt.plot(runner.mon['y'], runner.mon['y(t-0.74)'])
plt.xlabel('y(t)')
plt.ylabel('y(t-0.74)')
plt.show()
_images/ec412d00b98b18022477c58b83f6888a70eb87fa7b98522612f8cacc745f8101.png

Time delay Chen system#

The time-delay Chen system, a famous chaotic system with time delay, has important applications in many fields.

\[\begin{split} \left\{\begin{array}{l} D^{\alpha_{1}} x=a(y(t)-x(t-\tau)) \\ D^{\alpha_{2}} y=(c-a) x(t-\tau)-x(t) z(t)+c y(t) \\ D^{\alpha_{3}} z=x(t) y(t)-b z(t-\tau) \end{array}\right. \end{split}\]
dt = 0.001
tau = 0.009
xdelay = bm.TimeDelay(bm.asarray([0.2]), tau, dt=dt)
zdelay = bm.TimeDelay(bm.asarray([0.5]), tau, dt=dt)

def derivative(x, y, z, t):
    a=35; b=3; c=27
    dx = a*(y-xdelay(t-tau))
    dy = (c-a)*xdelay(t-tau)-x*z+c*y
    dz = x*y-b*zdelay(t-tau)
    return dx, dy, dz

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.94,
                                inits=[0.2, 0., 0.5],
                                num_step=500,
                                state_delays={'x': xdelay, 'z': zdelay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[0.2, 0., 0.5],
                                         monitors=['x', 'y', 'z'],
                                         dyn_vars=xdelay.vars() + zdelay.vars(),
                                         dt=dt)
runner.run(100.)
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
plt.show()
_images/6c33949c77cddd6f3fee8c54106fe03f85834357588a8255ee45a5bd110cd7c3.png

Enzyme kinetics#

Let’s see a more complex example of the fractional order version of enzyme kinetics with an inhibitor molecule:

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y_{1}(t)=10.5-\frac{y_{1}(t)}{1+0.0005 y_{4}^{3}(t-4)} \\ D_{t}^{\alpha} y_{2}(t)=\frac{y_{1}(t)}{1+0.0005 y_{4}^{3}(t-4)}-y_{2}(t) \\ D_{t}^{\alpha} y_{3}(t)=y_{2}(t)-y_{3}(t) \\ D_{t}^{\alpha} y_{4}(t)=y_{3}(t)-0.5 y_{4}(t) \\ y(t)=[60,10,10,20], t \leq 0 \end{gathered} \end{split}\]
dt = 0.01
tau = 4.
delay = bm.TimeDelay(bm.asarray([20.]), tau, before_t0=20, dt=dt)

def derivative(a, b, c, d, t):
    da = 10.5-a/(1+ 0.0005 * delay(t-tau)**3)
    db = a/(1+0.0005 * delay(t-tau)**3)-b
    dc = b-c
    dd = c-0.5*d
    return da, db, dc, dd

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.95,
                                inits=[60, 10, 10, 20],
                                num_step=500,
                                state_delays={'d': delay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[60, 10, 10, 20],
                                         monitors=list('abcd'),
                                         dyn_vars=delay.vars(),
                                         dt=dt)
runner.run(200.)
plt.plot(runner.mon.ts, runner.mon.a, label='a')
plt.plot(runner.mon.ts, runner.mon.b, label='b')
plt.plot(runner.mon.ts, runner.mon.c, label='c')
plt.plot(runner.mon.ts, runner.mon.d, label='d')
plt.legend()
plt.xlabel('Time [ms]')
plt.show()
_images/8b76b1330c060a0dd906cad233a0819301fb4fef66eeddd349772104d59ff377.png

Fractional matrix delayed differential equations#

BrainPy is also capable of solving fractional matrix delayed differential equations:

\[ D_{t_{0}}^{\alpha} \mathbf{x}(t)=\mathbf{A}(t) \mathbf{x}(t)+\mathbf{B}(t) \mathbf{x}(t-\tau)+\mathbf{c}(t) \]

Here \(\mathbf{x}(t)\) is the state vector of the system, and \(\mathbf{c}(t)\) is a known disturbance function.

We explain the detailed usage by using an example:

\[\begin{split} \mathbf{x}(t)=\left(\begin{array}{l} x_{1}(t) \\ x_{2}(t) \\ x_{3}(t) \\ x_{4}(t) \end{array}\right) \end{split}\]
\[\begin{split} \mathbf{A}=\left(\begin{array}{cccc} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & -2 & 0 & 0 \\ -2 & 0 & 0 & 0 \end{array}\right) \end{split}\]
\[\begin{split} \mathbf{B}=\left(\begin{array}{cccc} 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ -2 & 0 & 0 & 0 \\ 0 & -2 & 0 & 0 \end{array}\right) \end{split}\]

With initial condition:

\[\begin{split} \mathbf{x}_{0}(t)=\left(\begin{array}{c} \sin (t) \cos (t) \\ \sin (t) \cos (t) \\ \cos ^{2}(t)-\sin ^{2}(t) \\ \cos ^{2}(t)-\sin ^{2}(t) \end{array}\right) \end{split}\]
dt = 0.01
tau = 3.1416
f = lambda t: bm.asarray([bm.sin(t)*bm.cos(t),
                          bm.sin(t)*bm.cos(t),
                          bm.cos(t)**2-bm.sin(t)**2,
                          bm.cos(t)**2-bm.sin(t)**2])
delay = bm.TimeDelay(f(0.), tau, before_t0=f, dt=dt)

A = bm.asarray([[0, 0, 1, 0], [0, 0, 0, 1], [0, -2, 0, 0], [-2, 0, 0, 0]])
B = bm.asarray([[0, 0, 0, 0], [0, 0, 0, 0], [-2, 0, 0, 0], [0, -2, 0, 0]])
c = bm.asarray([0, 0, 0, 0])
derivative = lambda x, t: A @ x + B @ delay(t - tau) + c

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.4,
                                inits=[f(0.)],
                                num_step=500,
                                state_delays={'x': delay})
runner = bp.integrators.IntegratorRunner(integral,
                                         inits=[f(0.)],
                                         monitors=['x'],
                                         dyn_vars=delay.vars(),
                                         dt=dt)
runner.run(100.)
plt.plot(runner.mon.x[:, 0], runner.mon.x[:, 2])
plt.xlabel('x0')
plt.ylabel('x2')
plt.show()
_images/1280505827f3040286fa1198a3aa112d13a8ac6df8dcb169ea89228891b2419b.png

Acknowledgement#

This tutorial is highly inspired by the work of Ettore Messina [1] and of Qingyu Qu [2].

Joint Differential Equations#

@Xiaoyu Chen

In a dynamical system, there may be multiple variables that change dynamically over time. Sometimes these variables are interconnected, and updating one variable requires others as the input. For example, in the widely known Hodgkin–Huxley model, the variables \(V\), \(m\), \(h\), and \(n\) are updated synchronously and interdependently (please refer to Building Neuron Models for details). To achieve higher integral accuracy, it is recommended to use brainpy.JointEq to jointly solve interconnected differential equations.

import brainpy as bp

brainpy.JointEq#

brainpy.JointEq is used to merge individual but interconnected differential equations into a single joint equation. For example, below are the two differential equations of the Izhikevich model:

a, b = 0.02, 0.20
dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
du = lambda u, t, V: a * (b * V - u)

Where updating \(V\) requires \(u\) as the input, and updating \(u\) requires \(V\) as the input. The joint equation can be defined as:

joint_eq = bp.JointEq(eqs=(dV, du))

brainpy.JointEq receives only one argument named eqs, which can be a list or tuple containing multiple differential equations. Then it can be packed into a numerical integrator that solves the equation with a specified method, just as can be done with any individual differential equation.

itg = bp.odeint(joint_eq, method='rk2')

There are several requirements for defining a joint equation:

  1. Every individual differential equation should follow the format of defining an ODE or SDE function in BrainPy. For example, the arguments before t denote the dynamical variables and the arguments after t denote the parameters.

  2. The same variable in different equations should have the same name. Different variables should be named differently.

Note that brainpy.JointEq supports making nested JointEq instances, which means an instance of JointEq can be an element used to compose a new JointEq.
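For instance, a nested joint equation might be sketched as follows (the extra variable dw and its parameter 0.01 are hypothetical and only for illustration):

# a hypothetical third variable coupled to V
dw = lambda w, t, V: 0.01 * (V - w)

# the JointEq defined above is itself used as an element of a new JointEq
nested_eq = bp.JointEq(eqs=(joint_eq, dw))
itg_nested = bp.odeint(nested_eq, method='rk2')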

Why use brainpy.JointEq?#

Users may be confused about the purpose of brainpy.JointEq, because multiple differential equations can be written in a single function:

def diff(V, u, t, Iext):
    dV = 0.04 * V * V + 5 * V + 140 - u + Iext
    du = a * (b * V - u)
    return dV, du

itg_V_u = bp.odeint(diff, method='rk2')

or simply packed into integrators separately:

int_V = bp.odeint(dV, method='rk2')
int_u = bp.odeint(du, method='rk2')

To illustrate the difference between joint and separate differential equations, let’s dive into the numerical integration code generated for these two types of equations.

If we make a numerical solver for each derivative function, they will be solved independently:

import brainpy as bp
bp.odeint(dV, method='rk2', show_code=True)
def brainpy_itg_of_ode6(V, t, u, Iext, dt=0.1):
  dV_k1 = f(V, t, u, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  return V_new

{'f': <function <lambda> at 0x0000022948DD6A60>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x229660543a0>

As shown in the output code, the derivative of \(V\) is evaluated twice by the RK2 method. For the second derivative value dV_k2, the updated value of \(V\) (k2_V_arg) and the original \(u\) are used to calculate the derivative. This generates a tiny error, since the values of \(V\) and \(u\) are taken at different times.

To eliminate this error, the differential equation of \(V\) and \(u\) should be solved jointly through brainpy.JointEq:

eq = bp.JointEq(eqs=(dV, du))
bp.odeint(eq, method='rk2', show_code=True)
def brainpy_itg_of_ode12_joint_eq(V, u, t, Iext, dt=0.1):
  dV_k1, du_k1 = f(V, u, t, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_u_arg = u + dt * du_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
  return V_new, u_new

{'f': <brainpy.integrators.joint_eq.JointEq object at 0x0000022967EC0C40>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x22967ec0160>

This output code shows that the second derivative values of \(V\) and \(u\) are calculated using the updated values (k2_V_arg and k2_u_arg) taken at the same time. This results in a more accurate integral.

The figure below compares the simulation results of the Izhikevich model using joint and separate differential equations (\(dt = 0.2 ms\)). It is shown that as the simulation time increases, the integral error becomes greater.

Synaptic Connections#

@Tianqiu Zhang @Xiaoyu Chen

Synaptic connections are an essential part of building a neural dynamic system. BrainPy provides several commonly used connection methods in the brainpy.connect module (which can be accessed by the shortcut bp.conn) that help users easily construct many types of synaptic connections, including built-in and self-customized connectors.

An Overview of BrainPy Connectors#

Here we provide an overview of BrainPy connectors.

Base class: bp.conn.Connector#

The base class of connectors is brainpy.connect.Connector. All connectors, built-in or customized, should inherit from the Connector class.

Two subclasses: TwoEndConnector and OneEndConnector#

There are two classes inheriting from the base class bp.conn.Connector:

  • bp.conn.TwoEndConnector: a connector to build synaptic connections between two neuron groups.

  • bp.conn.OneEndConnector: a connector to build synaptic connections within a population of neurons.

Users can click the link of each class above to look through the API documentation.

Connector.__init__()#

All connectors need to be initialized first. For each built-in connector, users need to pass in the corresponding parameters for initialization. For details, please see the specific connector types below.

Connector.__call__()#

After initialization, users should call the connector and pass in parameters depending on specific connection types:

  • TwoEndConnector: It has two input parameters pre_size and post_size, each representing the size of the pre- and post-synaptic neuron group. It will result in a connection matrix with the shape of (pre_num, post_num).

  • OneEndConnector: It has only one parameter, pre_size, which represents the size of the neuron group. It will result in a connection matrix with the shape of (pre_num, pre_num).

The __call__ function returns the class itself.
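As a quick sketch of this calling convention (using two built-in connectors that are introduced later in this tutorial):

import brainpy as bp

# TwoEndConnector subclasses are called with both pre_size and post_size
conn_two_end = bp.connect.FixedProb(prob=0.2, seed=123)(pre_size=10, post_size=8)

# OneEndConnector subclasses are called with pre_size only
conn_one_end = bp.connect.GridFour(include_self=False)(pre_size=(4, 4))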

Connector.build_conn()#

Users can customize the connection in the build_conn() function. Notice that there are three connection types users can provide:

| Connection Types | Definition |
| --- | --- |
| ‘mat’ | Dense connection, including a connection matrix. |
| ‘ij’ | Index projection, including a pre-neuron index vector and a post-neuron index vector. |
| ‘csr’ | Sparse connection, including an index vector and an indptr vector. |

The return type can be either a dict or a tuple. Here are two examples of how to return your connection data:

Example 1:

def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)

  return dict(csr=(ind, indptr), mat=None, ij=None)

Example 2:

def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)

  return 'csr', (ind, indptr)
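Putting these pieces together, a minimal customized connector might look like the following sketch (the class name and the one-to-one-like pattern are made up for illustration; it assumes, as in the examples above, that self.pre_num is available after the connector has been called):

import numpy as np
import brainpy as bp

class MyOne2One(bp.connect.TwoEndConnector):
  """A made-up connector in which neuron i connects only to neuron i."""
  def build_conn(self):
    ind = np.arange(self.pre_num)          # post-synaptic index of each connection
    indptr = np.arange(self.pre_num + 1)   # one connection per pre-synaptic neuron
    return 'csr', (ind, indptr)

conn = MyOne2One()(pre_size=5, post_size=5)
print(conn.require('pre_ids', 'post_ids'))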

After creating the synaptic connection, users can use the require() method to access some useful properties of the connection.

Connector.require()#

This method returns the connection properties required by users. The connection properties are elaborated in the following sections in detail. Here is a brief summary of the connection properties users can require.

| Connection properties | Structure | Definition |
| --- | --- | --- |
| conn_mat | 2-D array (matrix) | Dense connection matrix |
| pre_ids | 1-D array (vector) | Indices of the pre-synaptic neuron group |
| post_ids | 1-D array (vector) | Indices of the post-synaptic neuron group |
| pre2post | tuple (vector, vector) | The post-synaptic neuron indices and the corresponding pre-synaptic neuron pointers |
| post2pre | tuple (vector, vector) | The pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers |
| pre2syn | tuple (vector, vector) | The synapse indices sorted by pre-synaptic neurons and corresponding pre-synaptic neuron pointers |
| post2syn | tuple (vector, vector) | The synapse indices sorted by post-synaptic neurons and corresponding post-synaptic neuron pointers |

Users can call this method as in the following example:

pre_ids, post_ids, pre2post, conn_mat = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

Note

Note that this method can return multiple connection properties.

Connection Properties#

There are multiple connection properties that can be required by users.

1. conn_mat#

The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through connector.require('conn_mat'). Each connection matrix is an array with the shape \((n_{pre}, n_{post})\):
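For example, a small sketch using the FixedProb connector introduced later in this tutorial:

import brainpy as bp

conn = bp.connect.FixedProb(prob=0.5, seed=123)(pre_size=3, post_size=4)
mat = conn.require('conn_mat')
print(mat.shape)   # (3, 4), i.e. (n_pre, n_post)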

2. pre_ids and post_ids#

Using vectors to store the connection between neuron groups is a much more efficient way to reduce memory when the connection matrix is sparse. For the connection matrix conn_mat defined above, we can align the connected pre-synaptic neurons and the post-synaptic neurons by two one-dimensional arrays: pre_ids and post_ids.

In this way, we only need two vectors (pre_ids and post_ids) to store the synaptic connection. syn_id in the figure indicates the indices of each neuron pair, i.e. each synapse.
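A small sketch of requiring these two index vectors (using the All2All connector introduced later):

import brainpy as bp

conn = bp.connect.All2All(include_self=False)(pre_size=3, post_size=3)
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')
# pre_ids : [0 0 1 1 2 2]   -> pre-synaptic neuron of each synapse
# post_ids: [1 2 0 2 0 1]   -> post-synaptic neuron of each synapse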

3. pre2post and post2pre#

Another two synaptic structures are pre2post and post2pre. They establish the mapping between the pre- and post-synaptic neurons.

pre2post is a tuple containing two vectors, one of which is the post-synaptic neuron indices and the other is the corresponding pre-synaptic neuron pointers. For example, the following figure shows the indices of the pre-synaptic neurons and the post-synaptic neurons to which the pre-synaptic neurons project:

To record the connection, firstly the post_ids are concatenated into a single vector called the post-synaptic index vector (indices). Because the post-synaptic neuron indices have been sorted by the pre-synaptic neuron indices, it is sufficient to record only the starting position of each pre-synaptic neuron index. Therefore, the starting positions of the pre-synaptic neuron indices and the end position of the last pre-synaptic neuron index together make up the pre-synaptic index pointer vector (indptr), which is illustrated in the figure below.

The post-synaptic neuron indices to which pre-synaptic neuron \(i\) projects can be obtained by array slicing:

indices[indptr[i]: indptr[i+1]]
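A plain-NumPy sketch of this layout (the toy connection below is made up only for illustration):

import numpy as np

# toy connection with 5 synapses, already sorted by the pre-synaptic index
pre_ids  = np.array([0, 0, 1, 2, 2])
post_ids = np.array([1, 2, 0, 0, 1])

indices = post_ids                   # post-synaptic index vector
indptr  = np.array([0, 2, 3, 5])     # start position of each pre-synaptic neuron

i = 2
print(indices[indptr[i]: indptr[i + 1]])   # [0 1]: targets of pre-synaptic neuron 2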

Similarly, post2pre is a 2-element tuple containing the pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers. Taking the connection in the illustration above as an example, the post-synaptic neuron indices and the pre-synaptic neuron indices to which the post-synaptic neurons connect are shown as:

The pre-synaptic index vector (indices) and the post-synaptic index pointer vector (indptr) are listed below:

When the connection is sparse, pre2post (or post2pre) is a very efficient way to store the connection, since the lengths of the two vectors in the tuple are \(n_{synapse}\) and \(n_{pre}\) (\(n_{post}\)), respectively.

4. pre2syn and post2syn#

The last two properties are pre2syn and post2syn that record pre- and post-synaptic projection, respectively.

For pre2syn, similar to pre2post and post2pre, there is a synapse index vector and a pre-synaptic index pointer vector that refers to the starting position of each pre-synaptic neuron index at the synapse index vector.

Below is the same example identifying the connection by pre-synaptic neuron indices and the synapses belonging to them.

For better understanding, the synapse indices and the pre- and post-synaptic neuron indices are shown below:

The pre-synaptic index pointer vector is computed in the same way as in pre2post:

Similarly, post2syn is also a tuple containing the synapse indices and the corresponding post-synaptic neuron pointers.

The only difference from pre2syn is that the synapse indices are (most of the time) originally sorted by pre-synaptic neurons, but when computing post2syn, the synapses should be sorted by post-synaptic neuron indices:

The synapse index vector (the first row) and the post-synaptic index pointer vector (the last row) are listed below:

import brainpy as bp

bp.math.set_platform('cpu')
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

Built-in regular connections#

brainpy.connect.One2One#

The neurons in the pre-synaptic neuron group only connect to the neurons in the same position of the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size. Otherwise, an error will occur.

conn = bp.connect.One2One()
conn(pre_size=10, post_size=10)
<brainpy.connect.regular_conn.One2One at 0x12d27d9a0>

where pre_size denotes the size of the pre-synaptic neuron group and post_size denotes the size of the post-synaptic neuron group. Note that a size parameter can be an int, a tuple of ints, or a list of ints, where each element represents one dimension of the neuron group.

In One2One connection, particularly, pre_size and post_size must be the same.

Class One2One is inherited from TwoEndConnector. Users can use method require or requires to get specific connection properties.

Here is an example:

size = 5
conn = bp.connect.One2One()(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
pre_ids: JaxArray([0, 1, 2, 3, 4], dtype=uint32)
post_ids: JaxArray([0, 1, 2, 3, 4], dtype=uint32)
pre2post: (JaxArray([0, 1, 2, 3, 4], dtype=uint32), JaxArray([0, 1, 2, 3, 4, 5], dtype=uint32))

brainpy.connect.All2All#

All neurons of the post-synaptic population form connections with all neurons of the pre-synaptic population (dense connectivity). Users can choose whether to connect neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
conn(pre_size=size, post_size=size)
<brainpy.connect.regular_conn.All2All at 0x12d18ac10>

Class All2All is inherited from TwoEndConnector. Users can use method require or requires to get specific connection properties.

Here is an example:

conn = bp.connect.All2All(include_self=False)(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
print('conn_mat:', res[3])
pre_ids: JaxArray([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],            dtype=uint32)
post_ids: JaxArray([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3],            dtype=uint32)
pre2post: (JaxArray([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3],            dtype=uint32), JaxArray([ 0,  4,  8, 12, 16, 20], dtype=uint32))
conn_mat: JaxArray([[False,  True,  True,  True,  True],
          [ True, False,  True,  True,  True],
          [ True,  True, False,  True,  True],
          [ True,  True,  True, False,  True],
          [ True,  True,  True,  True, False]], dtype=bool)

brainpy.connect.GridFour#

GridFour is the four-nearest-neighbors connection. Each neuron connects to its four nearest neighbors.

conn = bp.connect.GridFour(include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridFour at 0x12d0356a0>

Class GridFour is inherited from OneEndConnector; therefore, there is only one parameter, pre_size, representing the size of the neuron group, which should be a two-dimensional geometry.

Here is an example:

size = (4, 4)
conn = bp.connect.GridFour(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')

print('pre_ids', res[0])
pre_ids JaxArray([ 0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  4,  4,  4,  5,  5,
           5,  5,  6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9,
           9, 10, 10, 10, 10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14,
          14, 15, 15], dtype=uint32)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()
_images/ffb1155315300f87a6408947a4c1bd304be8be1264f97798eac7eb512accf7be.png

brainpy.connect.GridEight#

GridEight is the eight-nearest-neighbors connection. Each neuron connects to its eight nearest neighbors.

conn = bp.connect.GridEight(include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridEight at 0x143df5c10>

Class GridEight is inherited from GridN, which will be introduced next.

Here is an example:

size = (4, 4)
conn = bp.connect.GridEight(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')

print('pre_ids', res[0])
pre_ids JaxArray([ 0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,
           3,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  6,
           6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,
           8,  8,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10,
          10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 13, 13, 13, 13,
          13, 14, 14, 14, 14, 14, 15, 15, 15], dtype=uint32)

Take the point with id = 4 as an example: since it lies on the border of the 4 × 4 grid, it connects to its five nearest neighbors. Therefore, its row in conn_mat is True only for those neighbors and False elsewhere, including for itself.

# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()
_images/a07e012f6370540ac2c8acdcc2470022b2b65819f43c4fef5907342cf31aa695.png

brainpy.connect.GridN#

GridN is also a nearest-neighbors connection. Each neuron connects to its nearest \((2N+1) \cdot (2N+1)\) neurons (if including itself).

Here are some examples to fully understand GridN. It is slightly different from GridEight: GridEight is equivalent to GridN when N = 1.

  • When N = 1: \(\begin{bmatrix} x & x & x\\ x & I & x\\ x & x & x \end{bmatrix}\)

  • When N = 2: \( \begin{bmatrix} x & x & x & x & x\\ x & x & x & x & x\\ x & x & I & x & x\\ x & x & x & x & x\\ x & x & x & x & x \end{bmatrix} \)

conn = bp.connect.GridN(N=2, include_self=False)
conn(pre_size=size)
<brainpy.connect.regular_conn.GridN at 0x143ebdf70>

Here is an example:

size = (4, 4)
conn = bp.connect.GridN(N=1, include_self=False)(pre_size=size)
res = conn.require('conn_mat')
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res)
nx.draw(G, with_labels=True)
plt.show()
_images/e7ff00553119ad3916a14b86281a73b4d6cd18b766a79e44d2434bfd1d44cd05.png

Built-in random connections#

brainpy.connect.FixedProb#

For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all_to_all projection, except that some synapses are not created, making the projection sparser.

Class brainpy.connect.FixedProb is inherited from TwoEndConnector, and it receives three settings:

  • prob: Fixed probability for connection with a pre-synaptic neuron for each post-synaptic neuron.

  • include_self: Whether to connect to itself.

  • seed: Seed the random generator.

And there are two parameters passed in when calling the instance of the class: pre_size and post_size.

conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=134)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[False,  True, False, False],
          [False, False, False, False],
          [ True,  True, False,  True],
          [ True, False,  True, False]], dtype=bool)

brainpy.connect.FixedPreNum#

Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

Class brainpy.connect.FixedPreNum is inherited from TwoEndConnector, and it receives three settings:

  • num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).

  • include_self: Whether to connect to itself.

  • seed: Seed the random generator.

And there are two parameters passed in when calling the instance of the class: pre_size and post_size.

conn = bp.connect.FixedPreNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[ True, False, False, False],
          [False,  True, False, False],
          [ True,  True,  True,  True],
          [False, False,  True,  True]], dtype=bool)

brainpy.connect.FixedPostNum#

Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

Class brainpy.connect.FixedPostNum is inherited from TwoEndConnector, and it receives three settings:

  • num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).

  • include_self: Whether to connect to itself.

  • seed: Seed the random generator.

And there are two parameters passed in when calling the instance of the class: pre_size and post_size.

conn = bp.connect.FixedPostNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
JaxArray([[ True, False,  True, False],
          [False,  True,  True, False],
          [False, False,  True,  True],
          [False, False,  True,  True]], dtype=bool)

brainpy.connect.GaussianProb#

GaussianProb builds a Gaussian connection pattern between the two populations, where the connection probability decays according to the Gaussian function.

Specifically,

\[ p=\exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right) \]

where \((x, y)\) is the position of the pre-synaptic neuron and \((x_c,y_c)\) is the position of the post-synaptic neuron.

For example, in a \(30 \times 30\) two-dimensional network, when \(\beta = \frac{1}{2\sigma^2} = 0.1\), the connection pattern is shown as follows:

GaussianProb is inherited from OneEndConnector, and it receives the following settings:

  • sigma: (float) Width of the Gaussian function.

  • encoding_values: (optional, list, tuple, int, float) The value ranges to encode for neurons at each axis.

  • periodic_boundary: (bool) Whether the neurons encode the value space with a periodic boundary.

  • normalize: (bool) Whether to normalize the connection probability.

  • include_self: (bool) Whether to create connections at the same position.

  • seed: (int) The random seed.

conn = bp.connect.GaussianProb(sigma=2, periodic_boundary=True, normalize=True, include_self=True, seed=21)
conn(pre_size=10)
conn.require('conn_mat')
JaxArray([[ True,  True, False,  True, False, False, False, False,
            True,  True],
          [ True,  True,  True,  True, False, False, False, False,
           False,  True],
          [ True,  True,  True,  True, False, False, False, False,
           False,  True],
          [False,  True,  True,  True,  True, False, False, False,
           False, False],
          [False, False, False,  True,  True,  True, False, False,
           False, False],
          [False, False,  True, False,  True,  True,  True,  True,
           False, False],
          [False, False, False,  True, False,  True,  True,  True,
           False, False],
          [ True, False, False, False, False,  True,  True,  True,
           False,  True],
          [False,  True, False, False, False, False,  True,  True,
            True,  True],
          [ True, False,  True, False, False,  True, False,  True,
            True,  True]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/7b771de25e61f615e708edfe140c94fd6624c412287d32757427bc78e2661982.png

brainpy.connect.SmallWorld#

SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined as a network in which the typical distance L between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes N in the network, that is:

\[ L\propto \log N \]

[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.

Currently, SmallWorld only supports a one-dimensional network with a ring structure. It receives four settings:

  • num_neighbor: the number of the nearest neighbors to connect.

  • prob: the probability of rewiring each edge.

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • include_self: whether to allow self-connections.

conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False,  True, False, False, False, False,  True, False,
            True, False],
          [ True, False,  True,  True, False,  True, False, False,
           False,  True],
          [False,  True, False,  True,  True, False, False, False,
           False, False],
          [False,  True,  True, False,  True,  True, False, False,
           False, False],
          [False, False,  True,  True, False,  True,  True, False,
           False, False],
          [False,  True, False,  True,  True, False, False,  True,
           False,  True],
          [ True, False, False, False,  True, False, False,  True,
            True, False],
          [False, False, False, False, False,  True,  True, False,
            True,  True],
          [ True, False, False, False, False, False,  True,  True,
           False,  True],
          [False,  True, False, False, False,  True, False,  True,
            True, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/249a56adefa305746261bb697b834cb94af9265cb6809db66ac743ca9b48493b.png

brainpy.connect.ScaleFreeBA#

ScaleFreeBA is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA receives the following settings:

  • m: Number of edges to attach from a new node to existing nodes.

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed: Indicator of random number generation state.

[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.

conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False, False, False,  True,  True,  True,
           False,  True],
          [False, False, False, False, False,  True,  True,  True,
            True,  True],
          [False, False, False, False, False,  True,  True, False,
           False, False],
          [False, False, False, False, False,  True, False, False,
           False,  True],
          [False, False, False, False, False,  True,  True,  True,
            True, False],
          [ True,  True,  True,  True,  True, False,  True,  True,
            True,  True],
          [ True,  True,  True, False,  True,  True, False,  True,
            True,  True],
          [ True,  True, False, False,  True,  True,  True, False,
            True, False],
          [False,  True, False, False,  True,  True,  True,  True,
           False, False],
          [ True,  True, False,  True, False,  True,  True, False,
           False, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/2eeff374939332b75b11b6303b542811259d9c3e2166302f1d5aaa923fe2cd42.png

brainpy.connect.ScaleFreeBADual#

ScaleFreeBADual is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:

  • p: The probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges).

  • m1 : Number of edges to attach from a new node to existing nodes with probability \(p\).

  • m2: Number of edges to attach from a new node to existing nodes with probability \(1-p\).

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed: Indicator of random number generation state.

[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.

conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False, False, False,  True, False,  True,
           False,  True],
          [False, False, False, False, False,  True, False,  True,
            True, False],
          [False, False, False, False, False,  True,  True, False,
            True,  True],
          [False, False, False, False, False,  True, False, False,
           False, False],
          [False, False, False, False, False,  True,  True,  True,
            True, False],
          [ True,  True,  True,  True,  True, False,  True,  True,
            True,  True],
          [False, False,  True, False,  True,  True, False,  True,
            True, False],
          [ True,  True, False, False,  True,  True,  True, False,
           False,  True],
          [False,  True,  True, False,  True,  True,  True, False,
           False,  True],
          [ True, False,  True, False, False,  True, False,  True,
            True, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/b3e11de5315e559f541b8c2f8ccfb2c7e2c3fc6e9bc56d37ad279f3b674c5fdb.png

brainpy.connect.PowerLaw#

PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:

  • m : the number of random edges to add for each new node

  • p : Probability of adding a triangle after adding a random edge

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed : Indicator of random number generation state.

[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.

conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
JaxArray([[False, False, False,  True, False,  True, False,  True,
            True, False],
          [False, False, False,  True,  True, False, False,  True,
           False, False],
          [False, False, False,  True,  True,  True,  True, False,
            True,  True],
          [ True,  True,  True, False,  True,  True,  True, False,
           False,  True],
          [False,  True,  True,  True, False, False, False,  True,
            True, False],
          [ True, False,  True,  True, False, False,  True, False,
           False, False],
          [False, False,  True,  True, False,  True, False, False,
           False,  True],
          [ True,  True, False, False,  True, False, False, False,
           False, False],
          [ True, False,  True, False,  True, False, False, False,
           False, False],
          [False, False,  True,  True, False, False,  True, False,
           False, False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/b55b1e8a1c65b875385db23c38d26f80b84477671d4ade4e91b3cf644c7b0f36.png

Encapsulate your existing connections#

BrainPy also allows users to encapsulate existing connections with convenient class interfaces. Users can provide connection types as:

  • Index projection;

  • Dense matrix;

  • Sparse matrix.

Then users should provide pre_size and post_size information in order to instantiate the connection. In this way, based on the following connection classes, users can easily generate any other synaptic structures (such as pre2post, pre2syn, conn_mat, etc.).

bp.conn.IJConn#

Here, let’s take a simple connection as an example. In this example, we create a connection from a user-provided index projection by using bp.conn.IJConn.

pre_list = np.array([0, 1, 2])
post_list = np.array(([0, 0, 0]))
conn = bp.conn.IJConn(i=pre_list, j=post_list)
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
JaxArray([[ True, False, False],
          [ True, False, False],
          [ True, False, False],
          [False, False, False],
          [False, False, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 0, 0], dtype=uint32), JaxArray([0, 1, 2, 3, 3, 3], dtype=uint32))
conn.requires('pre2syn')
(JaxArray([0, 1, 2], dtype=uint32), JaxArray([0, 1, 2, 3, 3, 3], dtype=uint32))

bp.conn.MatConn#

In the next example, we create a connection from a user-provided dense connection matrix by using bp.conn.MatConn.

bp.math.random.seed(123)
conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_))
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
JaxArray([[ True,  True, False],
          [False,  True,  True],
          [ True, False,  True],
          [False, False,  True],
          [ True, False, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 1, 1, 2, 0, 2, 2, 0], dtype=uint32),
 JaxArray([0, 2, 4, 6, 7, 8], dtype=uint32))
conn.require('pre2syn')
(JaxArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=uint32),
 JaxArray([0, 2, 4, 6, 7, 8], dtype=uint32))

bp.conn.SparseMatConn#

In the last example, we create a connection from a user-provided sparse connection matrix by using bp.conn.SparseMatConn.

from scipy.sparse import csr_matrix

conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
sparse_mat = csr_matrix(conn_mat)
conn = bp.conn.SparseMatConn(sparse_mat)
conn = conn(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1])
conn.requires('conn_mat')
JaxArray([[ True, False,  True],
          [ True, False,  True],
          [ True, False, False],
          [False,  True,  True],
          [False,  True, False]], dtype=bool)
conn.requires('pre2post')
(JaxArray([0, 2, 0, 2, 0, 1, 2, 1], dtype=uint32),
 JaxArray([0, 2, 4, 5, 7, 8], dtype=uint32))
conn.requires('post2syn')
(JaxArray([0, 2, 4, 5, 7, 1, 3, 6], dtype=uint32),
 JaxArray([0, 3, 5, 8], dtype=uint32))

Using NetworkX to provide connections and pass into Connector#

NetworkX is a Python package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks.

Users can design their own complex networks by using NetworkX.

import networkx as nx
G = nx.Graph()

By definition, a Graph is a collection of nodes (vertices) along with identified pairs of nodes (called edges, links, etc).

To learn more about NetworkX, please check the official documentation: NetworkX tutorial

Using class brainpy.connect.MatConn to construct connections is recommended here.

  • Dense adjacency matrix: a two-dimensional ndarray.

Here is an example illustrating how to transform a random graph into synaptic connections by using a dense adjacency matrix.

G = nx.fast_gnp_random_graph(5, 0.5) # initialize a random graph G
B = nx.adjacency_matrix(G)
A = np.array(nx.adjacency_matrix(G).todense()) # get dense adjacency matrix of G

print('dense adjacency matrix:')
print(A)
nx.draw(G, with_labels=True)
plt.show()
dense adjacency matrix:
[[0 0 1 1 0]
 [0 0 0 1 1]
 [1 0 0 0 1]
 [1 1 0 0 1]
 [0 1 1 1 0]]
_images/a108e6b4ca837c175c7d7b9dd965de7b8349fbaa1f9f51a103a0ef665a6e4082.png

Users can use the class MatConn, inherited from TwoEndConnector, to construct connections. A dense adjacency matrix should be passed in when initializing the MatConn class. Note that when calling the instance of the class, users should pass in two parameters: pre_size and post_size. In this case, the shape of the dense adjacency matrix can be used for these parameters.

conn = bp.connect.MatConn(A)(pre_size=A.shape[0], post_size=A.shape[1])
res = conn.require('conn_mat')

print(res)
JaxArray([[False, False,  True,  True, False],
          [False, False, False,  True,  True],
          [ True, False, False, False,  True],
          [ True,  True, False, False,  True],
          [False,  True,  True,  True, False]], dtype=bool)

Customize your connections#

BrainPy allows users to customize their connections. The following requirements should be satisfied:

  • Your connection class should inherit from brainpy.connect.TwoEndConnector or brainpy.connect.OneEndConnector.

  • __init__ function should be implemented and essential parameters should be initialized.

  • Users should also override the build_conn() function to describe how to build the connection.

Let’s take an example to illustrate the details of customization.

class FixedProb(bp.connect.TwoEndConnector):
  """Connect the post-synaptic neurons with fixed probability.

  Parameters
  ----------
  prob : float
      The conn probability.
  include_self : bool
      Whether to create (i, i) connection.
  seed : optional, int
      Seed the random generator.
  """

  def __init__(self, prob, include_self=True, seed=None):
    super(FixedProb, self).__init__()
    assert 0. <= prob <= 1.
    self.prob = prob
    self.include_self = include_self
    self.seed = seed
    self.rng = np.random.RandomState(seed=seed)

  def build_conn(self):
    ind = []
    count = np.zeros(self.pre_num, dtype=np.uint32)

    def _random_prob_conn(rng, pre_i, num_post, prob, include_self):
      p = rng.random(num_post) <= prob
      if (not include_self) and pre_i < num_post:
        p[pre_i] = False
      conn_j = np.asarray(np.where(p)[0], dtype=np.uint32)
      return conn_j

    for i in range(self.pre_num):
      posts = _random_prob_conn(self.rng, pre_i=i, num_post=self.post_num,
                                prob=self.prob, include_self=self.include_self)
      ind.append(posts)
      count[i] = len(posts)

    ind = np.concatenate(ind)
    indptr = np.concatenate(([0], count)).cumsum()

    return 'csr', (ind, indptr)

Then users can initialize their own connections as below:

conn = FixedProb(prob=0.5, include_self=True)(pre_size=5, post_size=5)
conn.require('conn_mat')
JaxArray([[False,  True, False,  True, False],
          [False, False,  True, False,  True],
          [ True, False,  True, False,  True],
          [False, False, False, False,  True],
          [ True,  True, False, False,  True]], dtype=bool)

Synaptic Weights#

@Xiaoyu Chen

In a brain model, synaptic weights, the strength of the connection between presynaptic and postsynaptic neurons, are crucial to the dynamics of the model. In this section, we will illustrate how to build synaptic weights in a synapse model.

import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt

bp.math.set_platform('cpu')

Creating Static Weights#

Some computational models focus on the network structure and its influence on network dynamics, and thus do not model neural plasticity for simplicity. In this condition, synaptic weights are fixed and do not change during simulation. They can be stored as a scalar, a matrix, or a vector depending on the connection strength and density.

1. Storing weights with a scalar#

If all synaptic weights are designed to be the same, the single weight value can be stored as a scalar in the synapse model to save memory space.

weight = 1.

The weight can be stored in a synapse model. When updating the synapse, this weight is assigned to all synapses by scalar multiplication.
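
For instance, a minimal sketch of how a scalar weight may enter a synaptic update (the variable names here are illustrative and not a specific BrainPy API):

weight = 1.
pre_spike = np.array([0., 1., 0., 1., 1.])   # spikes of 5 presynaptic neurons
post_input = weight * pre_spike.sum()        # the same weight scales every synapse
print(post_input)  # 3.0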

2. Storing weights with a matrix#

When the synaptic connection is dense and the synapses are assigned with different weights, weights can be stored in a matrix \(W\), where \(W(i, j)\) refers to the weight of presynaptic neuron \(i\) to postsynaptic neuron \(j\).

BrainPy provides brainpy.initialize.Initializer (or brainpy.init for short) for weight initialization as a matrix. The tutorial of brainpy.init.Initializer is introduced later.

For example, a weight matrix can be constructed using brainpy.init.Uniform, which initializes weights with a random distribution:

pre_size = (4, 4)
post_size = (3, 3)

uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init((pre_size, post_size))
print('shape of weights: {}'.format(weights.shape))
shape of weights: (16, 9)

Then, the weights can be assigned to a group of connections with the same shape. For example, an all-to-all connection matrix can be obtained by brainpy.conn.All2All() whose tutorial is contained in Synaptic Connections:

conn = bp.conn.All2All()
conn(pre_size, post_size)
conn_mat = conn.requires('conn_mat')  # request the connection matrix

Therefore, weights[i, j] refers to the weight of connection (i, j).

i, j = (2, 3)
print('whether (i, j) is connected: {}'.format(conn_mat[i, j]))
print('synaptic weights of (i, j): {}'.format(weights[i, j]))
whether (i, j) is connected: True
synaptic weights of (i, j): 0.48069751262664795

3. Storing weights with a vector#

When the synaptic connection is sparse, using a matrix to store synaptic weights is too wasteful. Instead, the weights can be stored in a vector which has the same length as the synaptic connections.

Weights can be assigned to the corresponding synapses as long as they are aligned with each other.

size = 5

conn = bp.conn.One2One()
conn(size, size)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')

print('presynaptic neuron ids: {}'.format(pre_ids))
print('postsynaptic neuron ids: {}'.format(post_ids))
print('synapse ids: {}'.format(bm.arange(size)))
presynaptic neuron ids: [0 1 2 3 4]
postsynaptic neuron ids: [0 1 2 3 4]
synapse ids: [0 1 2 3 4]

The weight vector is aligned with the synapse vector, i.e., the synapse ids:

weights = bm.random.uniform(0, 2, size)

for i in range(size):
    print('weight of synapse {}: {}'.format(i, weights[i]))
weight of synapse 0: 1.1737329959869385
weight of synapse 1: 1.2609302997589111
weight of synapse 2: 1.4217698574066162
weight of synapse 3: 1.781663179397583
weight of synapse 4: 1.1460866928100586
Conversion from a weight matrix to a weight vector#

Users who would like to obtain the weight vector from the weight matrix can first build a connection according to the non-zero elements in the weight matrix and then slice the weight matrix according to the connection:

weight_mat = np.array([[1., 1.5, 0., 0.5], [0., 2.5, 0., 0.], [2., 0., 3, 0.]])
print('weight matrix: \n{}'.format(weight_mat))

conn = bp.conn.MatConn(weight_mat)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')

weight_vec = weight_mat[pre_ids, post_ids]
print('weight_vector: \n{}'.format(weight_vec))
weight matrix: 
[[1.  1.5 0.  0.5]
 [0.  2.5 0.  0. ]
 [2.  0.  3.  0. ]]
weight_vector: 
[1.  1.5 0.5 2.5 2.  3. ]

Note

However, this approach is not recommended when the connection is sparse and large-scale, because generating the full weight matrix will take up too much memory.

Creating Dynamic Weights#

Sometimes users may want to realize neural plasticity in a brain model, which requires the synaptic weights to change during simulation. In this condition, weights should be considered as variables, thus defined as brainpy.math.Variable. If the weights are packed in a synapse model, weight updating should be implemented in the update(_t, _dt) function of the synapse model; a sketch is given after the short example below.

weights = bm.Variable(bm.ones(10))
weights
Variable([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)
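
Below is a minimal sketch of a model with plastic weights: the weights are declared as a Variable and modified inside update(_t, _dt). The decay rule used here is purely illustrative and does not correspond to a specific plasticity model.

class PlasticWeights(bp.DynamicalSystem):
    def __init__(self, num, **kwargs):
        super(PlasticWeights, self).__init__(**kwargs)
        self.w = bm.Variable(bm.ones(num))   # plastic synaptic weights

    def update(self, _t, _dt):
        self.w += -0.01 * self.w * _dt       # toy rule: weights decay a little at every step

model = PlasticWeights(10)
model.update(_t=0., _dt=0.1)
print(model.w)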

Built-in Weight Initializers#

Base Class: bp.init.Initializer#

The base class of weight initializers is brainpy.initialize.Initializer, which can be accessed by the shortcut bp.init. All initializers, built-in or customized, should inherit from the Initializer class.

Weight initialization is implemented in the __call__ function, so that it can be realized automatically when the initializer is called. The __call__ function has a shape parameter that has different meanings in the following two superclasses and returns a weight matrix.

Superclass 1: bp.init.InterLayerInitializer#

The InterLayerInitializer is an abstract subclass of Initializer. Its subclasses initialize the weights between two fully connected layers. The shape parameter of the __call__ function should be a 2-element tuple \((m, n)\), which refers to the number of presynaptic neurons \(m\) and of postsynaptic neurons \(n\). The output of the __call__ function is a bp.math.ndarray with the shape of \((m, n)\), where the value at \((i, j)\) is the initialized weight of the presynaptic neuron \(i\) to postsynaptic neuron \(j\).

Superclass 2: bp.init.IntraLayerInitializer#

The IntraLayerInitializer is also an abstract subclass of Initializer. Its subclasses initialize the weights within a single layer. The shape parameter of the __call__ function refers to the structure of the neural population \((n_1, n_2, ..., n_d)\). The __call__ function returns a 2-D bp.math.ndarray with the shape of \((\prod_{k=1}^d n_k, \prod_{k=1}^d n_k)\). In the 2-D array, the value at \((i, j)\) is the initialized weight of neuron \(i\) to neuron \(j\) of the flattened neural sequence.

1. Built-In Regular Initializers#

Regular initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a regular pattern. The built-in regular initializers include ZeroInit, OneInit, and Identity. Here we show how to use the OneInit initializer; the remaining two classes are used in a similar way.

# visualization
def mat_visualize(matrix, cmap=plt.cm.get_cmap('coolwarm')):
    im = plt.matshow(matrix, cmap=cmap)
    plt.colorbar(mappable=im, shrink=0.8, aspect=15)
    plt.show()
# 'OneInit' initializes all the weights with the same value
shape = (5, 6)
one_init = bp.init.OneInit(value=2.5)
weights = one_init(shape)
print(weights)
JaxArray([[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
          [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
          [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
          [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
          [2.5, 2.5, 2.5, 2.5, 2.5, 2.5]], dtype=float32)

2. Built-In Random Initializers#

Random initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a random distribution. The built-in random initializers include Normal, Uniform, Orthogonal, and others. Here we show how to use the Normal and Uniform initializers.

bp.init.Normal

This initializer initializes the weights with a normal distribution. The variance of the distribution changes according to the scale parameter. In the following example, 10 presynaptic neurons are fully connected to 20 postsynaptic neurons with random weight values:

shape = (10, 20)
normal_init = bp.init.Normal(scale=1.0)
weights = normal_init(shape)
mat_visualize(weights)
_images/59322b503691cb2c40ff7a0eb6a41c356ec8ae379390ac9c5b85e0c4b06b0e7b.png

bp.init.Uniform

This initializer resembles brainpy.init.Normal but initializes the weights with a uniform distribution.

uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init(shape)
mat_visualize(weights)
_images/d23b206851bb84996585333d49f8a7fbbb836b45b9f626f35f8f69627d4292a2.png

3. Built-In Decay Initializers#

Decay initializers all belong to IntraLayerInitializer and initialize the connection weights within a layer with a function that decays with the distance between neurons. The built-in decay initializers are GaussianDecay and DOGDecay. Below are examples of how to use them.

brainpy.training.initialize.GaussianDecay

This initializer creates a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function. Specifically, for any pair of neurons \( (i, j) \), the weight is computed as

\[ w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) \]

where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).

The example below is a neural population with the size of \( 5 \times 5 \). Note that this shape is the structure of the target neural population, not the size of presynaptic and postsynaptic neurons.

size = (5, 5)
gaussian_init = bp.init.GaussianDecay(sigma=2., max_w=10., include_self=True)
weights = gaussian_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (25, 25)

Self-connections are created if include_self=True. The connection weights of neuron \(i\) to the other neurons are stored in row \(i\) of weights. For instance, the connection weights of neuron (1, 2) to other neurons are stored in weights[7] (\(5 \times 1 + 2 = 7\)). Below, the weights of neuron (0, 0), i.e. weights[0], are visualized after reshaping:

mat_visualize(weights[0].reshape(size), cmap=plt.cm.get_cmap('Reds'))
_images/5fe0b48607bd2582389458479e65ede69574ab83ed80ac4a550bc0c4f97005c1.png

brainpy.training.initialize.DOGDecay

This initializer creates a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons. Specifically, for the given pair of neurons \( (i, j) \), the weight between them is computed as

\[ w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}) - w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2}) \]

where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).

The example below is a neural population with the size of \( 10 \times 12 \):

size = (10, 12)
dog_init = bp.init.DOGDecay(sigmas=(1., 3.), max_ws=(10., 5.), min_w=0.1, include_self=True)
weights = dog_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (120, 120)

Weights smaller than min_w will not be created. If min_w is not specified, it defaults to \(0.005 \times \min(\mathrm{max\_ws})\). The organization of weights is similar to that in the GaussianDecay initializer. For instance, the connection weights of neuron (3, 4) to other neurons after reshaping are shown below:

mat_visualize(weights[3*12+4].reshape(size), cmap=plt.cm.get_cmap('Reds'))
_images/caff95d04a21986bba8875f504c418559dd8599212b616a152ac58115353baa9.png

Customizing your initializers#

BrainPy also allows users to customize their own weight initializers. When customizing an initializer, users should follow the instructions below:

  • Your initializer should inherit brainpy.initialize.Initializer.

  • Override the __call__ function, which takes a shape parameter.

Here is an example of creating an inter-layer initializer that initializes the weights as follows:

\[ w(i, j) = \max(w_{max} - \sigma |v_i - v_j|, 0) \]

class LinearDecay(bp.init.InterLayerInitializer):
    def __init__(self, max_w, sigma=1.):
        self.max_w = max_w
        self.sigma = sigma
    
    def __call__(self, shape, dtype=None):
        mat = bp.math.zeros(shape, dtype=dtype)
        n_pre, n_post = shape
        seq = np.arange(n_pre)
        current_w = self.max_w
        
        for i in range(max(n_pre, n_post)):
            if current_w <= 0:
                break
            seq_plus = ((seq + i) >= 0) & ((seq + i) < n_post)
            seq_minus = ((seq - i) >= 0) & ((seq - i) < n_post)
            mat[seq[seq_plus], (seq + i)[seq_plus]] = current_w
            mat[seq[seq_minus], (seq - i)[seq_minus]] = current_w
            current_w -= self.sigma
        
        return mat
shape = (10, 15)
lin_init = LinearDecay(max_w=5, sigma=1.)
weights = lin_init(shape)
mat_visualize(weights, cmap=plt.cm.get_cmap('Reds'))
_images/752c8a7e5ab6128080f540224fc5da13525d33ce0330b960842cd05233498fbb.png

Note

Note that customized initializers, i.e. subclasses of brainpy.init.Initializer, are not limited to returning a matrix. Although all the built-in initializers currently use matrices to store weights, customized initializers can also be designed to return a vector of synaptic weights.

Gradient Descent Optimizers#

@Chaoming Wang @Xiaoyu Chen

Gradient descent is one of the most popular optimization methods. At present, gradient descent optimizers, combined with the loss function, are the key to machine learning, especially deep learning. In this section, we are going to understand:

  • how to use optimizers in BrainPy?

  • how to customize your own optimizer?

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
import matplotlib.pyplot as plt

Optimizers in BrainPy#

The basic optimizer class in BrainPy is brainpy.optimizers.Optimizer, which includes the following optimizers:

  • SGD

  • Momentum

  • Nesterov momentum

  • Adagrad

  • Adadelta

  • RMSProp

  • Adam

All supported optimizers can be inspected through the brainpy.math.optimizers APIs.

Generally, an optimizer initialization receives the learning rate lr, the trainable variables train_vars, and other hyperparameters for the specific optimizer.

  • lr can be a float, or an instance of brainpy.optim.Scheduler.

  • train_vars should be a dict of Variable.

Here we launch an SGD optimizer.

a = bm.Variable(bm.ones((5, 4)))
b = bm.Variable(bm.zeros((3, 3)))

op = bp.optim.SGD(lr=0.001, train_vars={'a': a, 'b': b})

When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update() method.

op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
a: Variable([[0.9998637 , 0.99998003, 0.99998164, 0.99945515],
          [0.99987453, 0.9994551 , 0.9992173 , 0.99931335],
          [0.99995905, 0.999395  , 0.9999578 , 0.9992222 ],
          [0.9990663 , 0.9991484 , 0.99950355, 0.9991641 ],
          [0.999546  , 0.99927944, 0.9995042 , 0.9993433 ]],            dtype=float32)
b: Variable([[-2.2797585e-06, -8.2317321e-04, -3.6020565e-04],
          [-4.7648646e-04, -4.1223815e-04, -4.2962815e-04],
          [-6.4605832e-05, -9.0033154e-04, -2.7676427e-04]],            dtype=float32)

You can process the gradients before applying them. For example, we clip the gradients by their maximum L2-norm.

grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
{'a': JaxArray([[0.34498155, 0.5973849 , 0.3798045 , 0.9029752 ],
           [0.48180568, 0.85238564, 0.9597529 , 0.19037902],
           [0.09431374, 0.03752387, 0.05545044, 0.18625283],
           [0.4265157 , 0.13127708, 0.44978166, 0.6449729 ],
           [0.5855427 , 0.02250564, 0.9523196 , 0.317971  ]],            dtype=float32),
 'b': JaxArray([[0.4190004 , 0.7033491 , 0.13393831],
           [0.11366987, 0.33574808, 0.37153232],
           [0.39718974, 0.25615263, 0.08950627]], dtype=float32)}
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
{'a': JaxArray([[0.14619358, 0.2531551 , 0.16095057, 0.38265574],
           [0.20417583, 0.3612173 , 0.40671656, 0.08067733],
           [0.03996754, 0.01590157, 0.02349835, 0.07892876],
           [0.18074548, 0.05563157, 0.19060494, 0.27332157],
           [0.24813668, 0.00953726, 0.40356654, 0.13474725]],            dtype=float32),
 'b': JaxArray([[0.38518783, 0.6465901 , 0.12312973],
           [0.10449692, 0.30865383, 0.34155035],
           [0.36513725, 0.23548158, 0.08228327]], dtype=float32)}
op.update(grads_post)

print('a:', a)
print('b:', b)
a: Variable([[0.9997175 , 0.9997269 , 0.9998207 , 0.9990725 ],
          [0.9996703 , 0.9990939 , 0.9988105 , 0.99923265],
          [0.99991906, 0.9993791 , 0.9999343 , 0.9991433 ],
          [0.9988856 , 0.9990928 , 0.99931294, 0.99889076],
          [0.99929786, 0.9992699 , 0.9991006 , 0.9992085 ]],            dtype=float32)
b: Variable([[-0.00038747, -0.00146976, -0.00048334],
          [-0.00058098, -0.00072089, -0.00077118],
          [-0.00042974, -0.00113581, -0.00035905]], dtype=float32)

Note

Optimizers usually have their own dynamically changed variables. If you JIT a function whose logic contains an optimizer update, the dyn_vars passed to bm.jit() should include the variables in Optimizer.vars(); see the sketch after the examples below.

op.vars()  # the SGD optimizer only has a `step` variable to record the training step
{'Constant0.step': Variable([2], dtype=int32)}
bp.optim.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
{'Constant1.step': Variable([0], dtype=int32),
 'Momentum0.a_v': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Momentum0.b_v': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32)}
bp.optim.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
{'Constant2.step': Variable([0], dtype=int32),
 'Adam0.a_m': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Adam0.b_m': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32),
 'Adam0.a_v': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Adam0.b_v': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32)}
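
As a minimal sketch of this note (the gradients below are fixed placeholders rather than real gradients; the point is only that the optimizer's own variables, op.vars(), are declared in dyn_vars together with the trained variables):

def train_step():
    # placeholder gradients; in practice they would come from automatic differentiation
    grads = {'a': bm.ones(a.shape) * 0.1, 'b': bm.ones(b.shape) * 0.1}
    op.update(grads)

# declare both the trained variables and the optimizer's variables for JIT
train_step_jit = bm.jit(train_step, dyn_vars={'a': a, 'b': b, **op.vars()})
train_step_jit()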

Creating A Self-Customized Optimizer#

To create your own optimization algorithm, simply inherit from bm.optimizers.Optimizer class and override the following methods:

  • __init__(): init function that receives the learning rate (lr) and trainable variables (train_vars). Do not forget to register your dynamically changed variables into implicit_vars.

  • update(grads): update function that computes the updated parameters.

The general structure is shown below:

class CustomizeOp(bp.optim.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        
        # customize your initialization
        
    def update(self, grads):
        # customize your update logic
        pass
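
As a concrete illustration, below is a hedged sketch of a plain SGD-like optimizer built on this structure. The attribute self.train_vars is our own bookkeeping rather than a documented internal of brainpy.optim.Optimizer, and self.lr() returns the current learning rate from the scheduler as shown earlier in this section.

class SimpleSGD(bp.optim.Optimizer):
    def __init__(self, lr, train_vars):
        super(SimpleSGD, self).__init__(lr, train_vars)
        self.train_vars = train_vars   # our own reference to the trained variables

    def update(self, grads):
        lr = self.lr()                 # current learning rate from the scheduler
        for key, grad in grads.items():
            self.train_vars[key] -= lr * grad   # in-place SGD step on each Variable
        # built-in optimizers also advance the scheduler's step counter; omitted here for brevity

It can then be used like the built-in optimizers, e.g. SimpleSGD(lr=0.01, train_vars={'a': a, 'b': b}).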

Schedulers#

A scheduler adjusts the learning rate during training by reducing it according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay, and exponential decay.

Here we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.

sc = bp.optim.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
_images/d4c6e12211f0bb94f7f05195c0bb03fa1b2d186c95e208f8959d7281d009042a.png

After Optimizer initialization, the learning rate self.lr will always be an instance of bm.optimizers.Scheduler. A scalar float learning rate initialization will result in a Constant scheduler.

op.lr
Constant(0.001)

One can get the current learning rate value by calling Scheduler.__call__(i=None).

  • If i is not provided, the learning rate value will be evaluated at the built-in training step.

  • Otherwise, the learning rate value will be evaluated at the given step i.

op.lr()
0.001

BrainPy provides several commonly used learning rate schedulers:

  • Constant

  • ExponentialDecay

  • InverseTimeDecay

  • PolynomialDecay

  • PiecewiseConstant

For more details, please see the brainpy.math.optimizers APIs.

# InverseTimeDecay scheduler

rates = bp.optim.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)
_images/e4e5977b81108bffea5108caed258da31cb0243fdff98d530cea6e070365787c.png
# PolynomialDecay scheduler

rates = bp.optim.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
_images/6741c56d6b37a41cb910962467af3413fc2dbfe1a4846f3692fc772ef6576d14.png

Creating a Self-Customized Scheduler#

To implement your own scheduler, simply inherit from the bm.optimizers.Scheduler class and override the following methods:

  • __init__(): the init function.

  • __call__(i=None): the learning rate evaluation.

class CustomizeScheduler(bp.optim.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        
        # customize your initialization
        
    def __call__(self, i=None):
        # customize your update logic
        pass
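
Similarly, here is a hedged sketch of a step-decay scheduler that halves the learning rate every decay_steps steps. It assumes the base rate is kept in self.lr after the parent constructor and, for simplicity, only evaluates at an explicitly given step i (the built-in step tracking is not used):

class StepDecay(bp.optim.Scheduler):
    def __init__(self, lr, decay_steps=100, decay_rate=0.5):
        super(StepDecay, self).__init__(lr)
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate

    def __call__(self, i=None):
        i = 0 if i is None else i
        # scale the base rate down once every 'decay_steps' steps
        return self.lr * self.decay_rate ** (i // self.decay_steps)

sc = StepDecay(lr=0.1, decay_steps=200, decay_rate=0.5)
print(sc(0), sc(199), sc(200), sc(400))   # 0.1 0.1 0.05 0.025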

Runners#

@Chaoming Wang @Xiaoyu Chen

Runners for Dynamical Systems#

The convenient simulation interfaces for dynamical systems in BrainPy are implemented in brainpy.simulation.runner. Currently, two kinds of runners are implemented: DSRunner and ReportRunner. Each has its own advantages.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Initializing a runner#

Generally, we can initialize a runner with the format of:

SomeRunner(target=instance_of_dynamical_system,
           inputs=inputs_for_target_variables,
           monitors=interested_variables_to_monitor,
           dyn_vars=dynamical_changed_variables,
           jit=enable_jit_or_not)

In which

  • target specifies the model to be simulated. It must be an instance of brainpy.DynamicalSystem.

  • monitors is used to define target variables in the model. During the simulation, the history values of the monitored variables will be recorded. More information can be found in the Monitors tutorial.

  • inputs is used to define the input operations for specific variables. It will be expanded in the Inputs tutorial.

  • dyn_vars is used to specify all the dynamically changed variables used in the target model.

  • jit determines whether to use JIT compilation during the simulation.

Here we define an E/I balanced network as the simulation model.

class EINet(bp.Network):
  def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = bp.dyn.LIF(num_exc, **pars, method=method)
    I = bp.dyn.LIF(num_inh, **pars, method=method)
    E.V[:] = bm.random.randn(num_exc) * 2 - 55.
    I.V[:] = bm.random.randn(num_inh) * 2 - 55.

    # synapses
    E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
                         E=0., g_max=0.6, tau=5., method=method)
    E2I = bp.dyn.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
                         E=0., g_max=0.6, tau=5., method=method)
    I2E = bp.dyn.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
                         E=-80., g_max=6.7, tau=10., method=method)
    I2I = bp.dyn.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
                         E=-80., g_max=6.7, tau=10., method=method)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)

Then we will wrap it in different runners for dynamic simulation.

brainpy.DSRunner#

brainpy.DSRunner aims to provide model simulation with outstanding performance. It takes advantage of structural loop primitives to lower the model onto XLA devices.

net = EINet()

runner = bp.DSRunner(net, 
                     monitors=['E.spike'],
                     inputs=[('E.input', 20.), ('I.input', 20.)],
                     jit=True)
runner.run(100.)
1.2974658012390137
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
_images/6f97aaf5cdab57ead10dedccee48becfc6ba55a5de967ce2ed4e24473b901eed.png

Note that if the parameter jit is set to True, then all the variables will be JIT compiled and thus the system cannot be debugged by Python debugging tools. For debugging, users can set jit=False.

brainpy.ReportRunner#

brainpy.ReportRunner aims to provide a Pythonic interface for model debugging. Users can use the standard Python debugging tools when simulating the model with ReportRunner.

The drawback of brainpy.ReportRunner is that it is relatively slow: it iterates the simulation loop over time steps in Python.

net = EINet()

runner = bp.ReportRunner(net, 
                         monitors=['E.spike'],
                         inputs=[('E.input', 20.), ('I.input', 20.)],
                         jit=True)
runner.run(100.)
3.402564764022827

We can see from the output that the time spent for simulation through ReportRunner is longer than that through DSRunner.

bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
_images/f73d1361efada020ffe1f657ae25d7b9ed94bb9abe0c83993c7749f15564543b.png

Runners for Neural Network Training#

Inputs#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about stimulus inputs.

import brainpy as bp
import brainpy.math as bm

Inputs in brainpy.dyn.DSRunner#

In brain dynamics simulation, various inputs are usually given to different units of the dynamical system. In BrainPy, inputs can be specified to runners for dynamical systems. The aim of inputs is to mimic the input operations in experiments like Transcranial Magnetic Stimulation (TMS) and patch clamp recording.

inputs should have the format like (target, value, [type, operation]), where

  • target is the target variable to inject the input.

  • value is the input value. It can be a scalar, a tensor, or an iterable object/function.

  • type is the type of the input value. It supports two types of input: fix and iter. The first means that the data is static; the second means the data is iterable, no matter whether the input value is a tensor or a function. The iter type must be stated explicitly.

  • operation is the input operation on the target variable. It should be set as one of { + , - , * , / , = }, and if users do not provide this item explicitly, it will be set to ‘+’ by default, which means that the target variable will be updated as val = val + input.

Users can also give multiple inputs for different target variables, like:


inputs=[(target1, value1, [type1, op1]),  
        (target2, value2, [type2, op2]),
              ... ]

The mechanism of inputs is the same as that of monitors: BrainPy finds the target variables for input operations through their absolute or relative paths. A sketch combining a fixed input and an iterable input is given below.
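
For instance, a hedged sketch that combines a fixed scalar input with an iterable (pre-computed) current, reusing the EINet model defined in the Runners section; the ramp current here is built with one of the input construction functions described below and is purely illustrative:

ramp = bp.inputs.ramp_input(0., 20., 100.)   # one current value per time step over 100 ms

runner = bp.DSRunner(EINet(),
                     monitors=['E.spike'],
                     inputs=[('E.input', 20.),                  # 'fix' type: a static scalar
                             ('I.input', ramp, 'iter', '+')],   # 'iter' type: one value per step, added to the target
                     jit=True)
runner.run(100.)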

Input construction functions#

Like electrophysiological experiments, model simulation also needs various kinds of inputs. BrainPy provides several convenient input functions to help users construct input currents.

1. brainpy.inputs.section_input()#

brainpy.inputs.section_input() is an updated version of the previous brainpy.inputs.constant_input() (see below).

Sometimes, we need input currents with different values in different periods. For example, if you want an input that is 0 in the first 100 ms, 1 in the next 300 ms, and 0 again in the last 100 ms, you can define:

current1, duration = bp.inputs.section_input(values=[0, 1., 0.],
                                             durations=[100, 300, 100],
                                             return_length=True,
                                             dt=0.1)

Here values receives a list/array of the current values in each section, and durations receives a list/array of the durations of each section. The function returns a tensor as the current, whose length is duration\(/\mathrm{d}t\) (if not specified, \(\mathrm{d}t=0.1\,\mathrm{ms}\)). We can visualize the current input by:

import numpy as np
import matplotlib.pyplot as plt

def show(current, duration, title):
    ts = np.arange(0, duration, bm.get_dt())
    plt.plot(ts, current)
    plt.title(title)
    plt.xlabel('Time [ms]')
    plt.ylabel('Current Value')
    plt.show()

show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')
_images/03235e8a63ec96fb9788fba27da0ef6ec04f8ad0e7deb14e6a72562e7d42afaf.png

2. brainpy.inputs.constant_input()#

brainpy.inputs.constant_input() function helps users to format constant currents in several periods.

We can generate the above input current with constant_input() by:

current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])

Here each tuple in the list contains the value and the duration of the input in that section.

show(current2, duration, '[(0, 100), (1, 300), (0, 100)]')
_images/ff2d03e0c52506fd98becad9b35157df276f77d26fc0d84684b458944ff33a7c.png

3. brainpy.inputs.spike_input()#

brainpy.inputs.spike_input() constructs an input containing a series of short-time spikes. It receives the following settings:

  • sp_times : The spike time-points. Must be an iterable object. For example, list, tuple, or arrays.

  • sp_lens : The length of each point-current, mimicking the spike durations. It can be a scalar float to specify the unified duration. Or, it can be a list/tuple/array of time lengths with the same length as sp_times.

  • sp_sizes : The current sizes. It can be a scalar value. Or, it can be a list/tuple/array of spike current sizes with the same length as sp_times.

  • duration : The total current duration.

  • dt : The time step precision. The default is None (will be initialized as the default dt step).

For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, where each spike lasts 1 ms and the average value for each spike is 0.5, then you can define the current by:

current3 = bp.inputs.spike_input(
    sp_times=[10, 20, 30, 200, 300],
    sp_lens=1.,  # can be a list to specify the spike length at each point
    sp_sizes=0.5,  # can be a list to specify the spike current size at each point
    duration=400.)

show(current3, 400, 'Spike Input Example')
_images/4c0acb5509536b9855dcca4ab481f2fa3aac78314a02f781a54f9d3df70560cc.png

4. brainpy.inputs.ramp_input()#

brainpy.inputs.ramp_input() mimics a ramp or a step current to the input of the circuit. It receives the following settings:

  • c_start : The minimum (or maximum) current size.

  • c_end : The maximum (or minimum) current size.

  • duration : The total duration.

  • t_start : The ramped current start time-point.

  • t_end : The ramped current end time-point. Default is None.

  • dt : The current precision.

We illustrate the usage of brainpy.inputs.ramp_input() by two examples.

In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (500 ms).

duration = 500
current4 = bp.inputs.ramp_input(0, 1, duration)

show(current4, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                         r'$t_{start}$=0, $t_{end}$=None' % (duration))
_images/7732b97b88f2ba590741a5221b6c8e9bed5b9f975cb11c7c7a597de0e540970c.png

In the second example, we increase the current size from 0. to 1. from the 100 ms to 400 ms.

duration, t_start, t_end = 500, 100, 400
current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)

show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))
_images/089ffaab9f16d234441edb51a0ebbf0d6755a05b522f0ad2a17ff4d5717dd0f4.png

5. brainpy.inputs.wiener_process#

brainpy.inputs.wiener_process() is used to generate the basic Wiener process \(dW\), i.e. random numbers drawn from \(N(0, \sqrt{dt})\).

duration = 200
current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.)
show(current6, duration, 'Wiener Process')
_images/67f2178490596f62baa1ed62aafd9c72f633f78ae35e9ef6867be2d3625bc13e.png

6. brainpy.inputs.ou_process#

brainpy.inputs.ou_process() is used to generate a noise time series from an Ornstein-Uhlenbeck process \(\dot{x} = (\mu - x)/\tau \cdot dt + \sigma\cdot dW\).

duration = 200
current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.)
show(current7, duration, 'Ornstein-Uhlenbeck Process')
_images/acb6154897bd9867cce0ecb9ab3f227969e59a1035a1cf77ed469d33d5f79027.png

7. brainpy.inputs.sinusoidal_input#

brainpy.inputs.sinusoidal_input() can help to generate sinusoidal inputs.

duration = 2000
current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration,  t_start=100., )
show(current8, duration, 'Sinusoidal Input')
_images/aebed872bc114708001181641e00c3754ee84ed8567ae3ec23ed8f11b15c23d7.png

8. brainpy.inputs.square_input#

brainpy.inputs.square_input() can help to generate oscillatory square inputs.

duration = 2000
current9 = bp.inputs.square_input(amplitude=1., frequency=2.0,
                                  duration=duration, t_start=100)
show(current9, duration, 'Square Input')
_images/120b2131a740d3e0f3ab35909e6902a3c3330211faefa1822ced0dd697e5ff77.png

More complex inputs#

Because the current input is stored as a tensor, a complex input can be realized by the combination of several simple currents.

show(current1 + current5, 500, 'A Complex Current Input')
_images/a3c81671ae8e4b0908f118d41a11d0771b69fc5e6532d985c258284b616a4eca.png

General properties of input functions#

1. Every input function receives a dt specification.

If dt is not provided, input functions will use the default dt in the whole BrainPy system.

I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1)
I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01)
print('I1.shape: {}'.format(I1.shape))
print('I2.shape: {}'.format(I2.shape))
I1.shape: (600,)
I2.shape: (6000,)

2. All input functions can automatically broadcast the current shapes if they are heterogeneous among different periods.

For example, during period 1 we give a scalar input value, during period 2 a vector-shaped input, and during period 3 a matrix input value. Input functions will broadcast them to the maximum shape:

current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
                                  durations=[100, 300, 100])

current.shape
(5000, 3, 10)

Monitors#

@Chaoming Wang @Xiaoyu Chen

BrainPy has a systematic naming system. Any model in BrainPy has a unique name. Thus, nodes, integrators, and variables can be easily accessed even in a huge network. Based on this naming system, BrainPy provides a set of convenient monitoring supports. In this section, we are going to talk about them.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
bp.math.set_dt(0.02)
import numpy as np
import matplotlib.pyplot as plt

Initializing Monitors in a Runner#

In BrainPy, any instance of brainpy.Runner has a built-in monitor. Users can set up a monitor when initializing a runner.

For example, if we want to simulate a Hodgkin-Huxley (HH) model and monitor its membrane potential \(V\) and the spikes it generates:

HH = bp.dyn.HH
model  = HH(1)

After defining an HH neuron, we can add monitors while setting up the runner. When specifying the monitors parameter, a monitor, which is an instance of brainpy.Monitor, will be initialized. The first method to initialize a monitor is through a list/tuple of strings:

# set up a monitor using a list of str
runner1 = bp.StructRunner(model, 
                          monitors=['V', 'spike'], 
                          inputs=('input', 10))

type(runner1.mon)
brainpy.running.monitor.Monitor

where the strings 'V' and 'spike' correspond to the names of the variables in the HH model:

model.V, model.spike
(Variable(DeviceArray([0.], dtype=float32)),
 Variable(DeviceArray([False], dtype=bool)))

Besides using a list/tuple of strings, users can also directly use the Monitor class to initialize a monitor:

# set up a monitor using brainpy.Monitor
runner2 = bp.StructRunner(model, monitors=bp.Monitor(variables=['V', 'spike']))

Once we call the runner with a given time duration, the monitor will automatically record the variable evolutions in the corresponding models. Afterwards, users can access these variable trajectories by using .mon.[variable_name]. The default history times .mon.ts will also be generated after the model finishes running. Let’s see an example.

runner1(100.)

bp.visualize.line_plot(runner1.mon.ts, runner1.mon.V, show=True)
_images/dc4eba5b0eb5995c1bd71a9d6e37b91b1a56e469685674bc55d9194a3db8001f.png

The monitor in runner1 has recorded the evolution of V. Therefore, it can be accessed by runner1.mon.V or equivalently runner1.mon['V']. Similarly, the recorded trajectory of variable spike can also be obtained through runner1.mon.spike.

runner1.mon.spike
array([[False],
       [False],
       [ True],
       ...,
       [False],
       [False],
       [False]])

Here True indicates that a spike is generated at that time step.

The Mechanism of monitors#

Whether we use a list/tuple or instantiate a Monitor class to generate a monitor, we specify the target variables by strings of their names. How does brainpy.Monitor find the target variables through these strings?

Actually, BrainPy first tries to find the target variables in the simulated model by the relative path. If the variables are not found, BrainPy checks whether they can be accessed by the absolute path. If they are still not found, an error will be raised.

net = bp.Network(HH(size=10, name='X'), 
                 HH(size=20, name='Y'), 
                 HH(size=30))

# it's ok
bp.StructRunner(net, monitors=['X.V', 'Y.spike']).build_monitors()  
<function brainpy.dyn.runners.ds_runner.DSRunner.build_monitors.<locals>.func(_t, _dt)>

In the above net, there are HH instances named “X” and “Y”. Therefore, monitoring “X.V” and “Y.spike” succeeds.

However, in the following example, the node named “Z” is not accessible in the generated net, so the monitoring setup fails.

z = HH(size=30, name='Z')
net = bp.Network(HH(size=10), HH(size=20))

# node "Z" can not be accessed in the simulation target 'net'
try:
    bp.StructRunner(net, monitors=['Z.V']).build_monitors()  
except Exception as e:
    print(type(e).__name__, ":", e)
RunningError : Cannot find target Z.V in monitor of <brainpy.compact.brainobjects.Network object at 0x000002827DA49E80>, please check.

BrainPy only supports monitoring Variables. Monitoring the trajectories of Variables is meaningful because they change dynamically; whatever is not marked as a Variable will be compiled as a constant.

try:
    bp.StructRunner(HH(size=1), monitors=['gNa']).build_monitors() 
except Exception as e:
    print(type(e).__name__, ":", e)
RunningError : "gNa" in <brainpy.dyn.neurons.biological_models.HH object at 0x000002827DA5C430> is not a dynamically changed Variable, its value will not change, we think there is no need to monitor its trajectory.

The monitors in BrainPy only record flattened tensor values. This means that if the target variable is a matrix with the shape (N, M), the resulting trajectory stored in the monitor after running for T steps will be a tensor with the shape (T, N x M).

class MatrixVarModel(bp.DynamicalSystem):
    def __init__(self, **kwargs):
        super(MatrixVarModel, self).__init__(**kwargs)
        
        self.a = bm.Variable(bm.zeros((4, 4)))
    
    def update(self, _t, _dt):
        self.a += 0.01
        
        
model = MatrixVarModel()
duration = 10
runner = bp.StructRunner(model, monitors=['a'])
runner.run(duration)

print(f'The expected shape of "model.mon.a" is: {(int(duration/bm.get_dt()), model.a.size)}')
print(f'The actual shape of "model.mon.a" is: {runner.mon.a.shape}')
The expected shape of "model.mon.a" is: (500, 16)
The actual shape of "model.mon.a" is: (500, 16)

Monitoring Variables at the Given Indices#

Sometimes we do not care about all the contents of a variable; we may only be interested in the values at certain indices. Moreover, for a huge network with a long simulation, monitors will consume a large amount of RAM. Therefore, monitoring variables only at selected indices is more practical. BrainPy supports monitoring a part of the elements in a Variable with a tuple format like this:

runner = bp.StructRunner(
    HH(10),
    monitors=['V',  # monitor all values of Variable 'V' 
              ('spike', [1, 2, 3])], # monitor values of Variable at index of [1, 2, 3]
    inputs=('input', 10.)
)
runner.run(100.)

print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)

Or we can use a dictionary to specify the target indices of a variable:

runner = bp.StructRunner(
    HH(10),
    monitors={'V': None,  # 'None' means all values will be monitored
              'spike': [1, 2, 3]},  # specifying the target indices
    inputs=('input', 10.),
)
runner.run(100.)

print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)

Also, we can directly instantiate a brainpy.Monitor class:

runner = bp.StructRunner(
    HH(10),
    monitors=bp.Monitor(variables=['V', ('spike', [1, 2, 3])]), 
    inputs=('input', 10.),
)
runner.run(100.)

print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)
runner = bp.StructRunner(
    HH(10),
    monitors=bp.Monitor(variables={'V': None, 'spike': [1, 2, 3]}),
    inputs=('input', 10.),
)
runner.run(100.)

print(f'The monitor shape of "V" is (run length, variable size) = {runner.mon.V.shape}')
print(f'The monitor shape of "spike" is (run length, index size) = {runner.mon.spike.shape}')
The monitor shape of "V" is (run length, variable size) = (5000, 10)
The monitor shape of "spike" is (run length, index size) = (5000, 3)

Note

Because brainpy.Monitor records a flattened tensor variable, if users want to record a part of a multi-dimensional variable, they must provide the indices corresponding to the flattened tensor.
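For instance, to monitor only the element at position (1, 2) of the (4, 4) variable a defined in the MatrixVarModel above, the flattened index can be computed with NumPy. This is a minimal sketch assuming row-major (C-order) flattening:

import numpy as np

flat_index = int(np.ravel_multi_index((1, 2), dims=(4, 4)))  # = 6 for a (4, 4) variable in C order
runner = bp.StructRunner(MatrixVarModel(), monitors={'a': [flat_index]})
runner.run(10.)
runner.mon.a.shape  # (number of time steps, 1)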

Saving and Loading#

@Chaoming Wang

Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.

import brainpy as bp

bp.math.set_platform('cpu')

Saving and loading variables#

Model saving and loading in BrainPy are implemented with .save_states() and .load_states() functions.

BrainPy supports saving and loading model variables with various Python standard file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

Here’s a simple example:

class EINet(bp.Network):
    def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        E = bp.models.LIF(num_exc, **pars, method=method)
        I = bp.models.LIF(num_inh, **pars, method=method)
        E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
        I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

        # synapses
        E2E = bp.models.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        E2I = bp.models.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        I2E = bp.models.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)
        I2I = bp.models.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)

        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
        
        
net = EINet()
import os
if not os.path.exists('./data'): 
    os.makedirs('./data')
# model saving

net.save_states('./data/net.h5')
# model loading

net.load_states('./data/net.h5')
  • .save_states(filename, all_vars=None) receives a string specifying the output file name. If all_vars is not provided, BrainPy will retrieve all variables in the model through the relative path (a small example is sketched after this list).

  • .load_states(filename, verbose, check_missing) receives several arguments. The first is a string specifying the file to load. The second, “verbose”, controls whether to report the loading progress. The final argument, “check_missing”, warns about model variables that are missing from the file.
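If you want to control exactly which variables are stored (for instance, to avoid the duplicated entries mentioned in the note below), you can pass an explicit collection of variables. This is a minimal sketch; it assumes the all_vars argument accepts a Collector such as the one returned by .vars().unique():

net.save_states('./data/net_unique.h5', all_vars=net.vars().unique())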

# model loading with warning and checking

net.load_states('./data/net.h5', verbose=True, check_missing=True)
WARNING:brainpy.base.io:There are variable states missed in ./data/net.h5. The missed variables are: ['ExpCOBA0.pre.V', 'ExpCOBA0.pre.input', 'ExpCOBA0.pre.refractory', 'ExpCOBA0.pre.spike', 'ExpCOBA0.pre.t_last_spike', 'ExpCOBA1.pre.V', 'ExpCOBA1.pre.input', 'ExpCOBA1.pre.refractory', 'ExpCOBA1.pre.spike', 'ExpCOBA1.pre.t_last_spike', 'ExpCOBA1.post.V', 'ExpCOBA1.post.input', 'ExpCOBA1.post.refractory', 'ExpCOBA1.post.spike', 'ExpCOBA1.post.t_last_spike', 'ExpCOBA2.pre.V', 'ExpCOBA2.pre.input', 'ExpCOBA2.pre.refractory', 'ExpCOBA2.pre.spike', 'ExpCOBA2.pre.t_last_spike', 'ExpCOBA2.post.V', 'ExpCOBA2.post.input', 'ExpCOBA2.post.refractory', 'ExpCOBA2.post.spike', 'ExpCOBA2.post.t_last_spike', 'ExpCOBA3.pre.V', 'ExpCOBA3.pre.input', 'ExpCOBA3.pre.refractory', 'ExpCOBA3.pre.spike', 'ExpCOBA3.pre.t_last_spike'].
Loading E.V ...
Loading E.input ...
Loading E.refractory ...
Loading E.spike ...
Loading E.t_last_spike ...
Loading ExpCOBA0.g ...
Loading ExpCOBA0.pre_spike.data ...
Loading ExpCOBA0.pre_spike.in_idx ...
Loading ExpCOBA0.pre_spike.out_idx ...
Loading ExpCOBA1.g ...
Loading ExpCOBA1.pre_spike.data ...
Loading ExpCOBA1.pre_spike.in_idx ...
Loading ExpCOBA1.pre_spike.out_idx ...
Loading ExpCOBA2.g ...
Loading ExpCOBA2.pre_spike.data ...
Loading ExpCOBA2.pre_spike.in_idx ...
Loading ExpCOBA2.pre_spike.out_idx ...
Loading ExpCOBA3.g ...
Loading ExpCOBA3.pre_spike.data ...
Loading ExpCOBA3.pre_spike.in_idx ...
Loading ExpCOBA3.pre_spike.out_idx ...
Loading I.V ...
Loading I.input ...
Loading I.refractory ...
Loading I.spike ...
Loading I.t_last_spike ...

Note

By default, the model variables are retrieved by the relative path. Relative path retrieval usually results in duplicate variables in the returned TensorCollector. Therefore, there will always be missing keys when loading the variables.

Custom saving and loading#

You can easily write your own saving and loading functions, because all variables in the model can be collected through .vars(). Saving variables is just a matter of transforming them into numpy.ndarray objects and storing them on disk. Similarly, to load variables, you just need to read the numpy arrays from disk and assign them back to the corresponding Variables.

The only gotcha is to avoid saving duplicated variables.
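As an example, a minimal sketch of custom saving and loading using NumPy arrays and Python's pickle module might look like the following. It relies only on .vars().unique() and the safe .value assignment described later in this documentation; the function names are purely illustrative.

import pickle
import numpy as np
import brainpy.math as bm

def my_save(model, filename):
    # collect unique variables to avoid saving duplicates
    arrays = {key: np.asarray(var.value) for key, var in model.vars().unique().items()}
    with open(filename, 'wb') as f:
        pickle.dump(arrays, f)

def my_load(model, filename):
    with open(filename, 'rb') as f:
        arrays = pickle.load(f)
    for key, var in model.vars().unique().items():
        if key in arrays:
            var.value = bm.asarray(arrays[key])  # checks that shape and dtype are consistent

my_save(net, './data/net_custom.pkl')
my_load(net, './data/net_custom.pkl')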

Variables#

@Chaoming Wang @Xiaoyu Chen

In BrainPy, the JIT compilation for class objects relies on Variables. In this section, we are going to understand:

  • What is a Variable?

  • What are the subtypes of Variable?

  • How to update a Variable?

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

brainpy.math.Variable#

brainpy.math.Variable is a pointer referring to a tensor. It stores a tensor as its value. The data in a Variable can be changed during JIT compilation. If a tensor is labeled as a Variable, it means that it is a dynamical variable that changes over time.

Tensors that are not marked as Variables will be JIT compiled as static data. Modifications of these tensors will be invalid or cause an error.

  • Creating a Variable

Passing a tensor into the brainpy.math.Variable creates a Variable. For example:

b1 = bm.random.random(5)
b1
JaxArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ],            dtype=float32)
b2 = bm.Variable(b1)
b2
Variable([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ],            dtype=float32)
  • Accessing the value in a Variable

The data in a Variable can be obtained through .value.

b2.value
DeviceArray([0.9116168 , 0.6901083 , 0.43920577, 0.13220644, 0.771458  ],            dtype=float32)
(b2.value == b1).all()
DeviceArray(True, dtype=bool)
  • Supported operations on Variables

Variables support almost all the operations for tensors. Actually, brainpy.math.Variable is a subclass of brainpy.math.ndarray.

isinstance(b2, bm.ndarray)
True
isinstance(b2, bm.JaxArray)
True
# `bp.math.ndarray` is an alias for `bp.math.JaxArray` in 'jax' backend

bm.ndarray is bm.JaxArray
True

Note

After performing any operation on a Variable, the resulting value will be a JaxArray (brainpy.math.ndarray is an alias for brainpy.math.JaxArray). This means that the Variable can only be used to refer to a single value.

b2 + 1.
JaxArray([1.9116168, 1.6901083, 1.4392058, 1.1322064, 1.771458 ], dtype=float32)
b2 ** 2
JaxArray([0.8310452 , 0.47624946, 0.1929017 , 0.01747854, 0.5951475 ],            dtype=float32)
bm.floor(b2)
JaxArray([0., 0., 0., 0., 0.], dtype=float32)

Subtypes of Variable#

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.RandomState. Subtypes can also be customized and extended by users.

1. TrainVar#

brainpy.math.TrainVar is a trainable variable and a subclass of brainpy.math.Variable. Usually, the trainable variables are meant to require their gradients and compute the corresponding update values. However, users can also use TrainVar for other purposes.

b = bm.random.rand(4)

b
JaxArray([0.59062696, 0.618052  , 0.84173155, 0.34012556], dtype=float32)
bm.TrainVar(b)
TrainVar([0.59062696, 0.618052  , 0.84173155, 0.34012556], dtype=float32)

2. Parameter#

brainpy.math.Parameter is used to label a dynamically changed parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved with the Collector.subset method (please see Base class).

b = bm.random.rand(1)

b
JaxArray([0.14782536], dtype=float32)
bm.Parameter(b)
Parameter([0.14782536], dtype=float32)
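The retrieval mentioned above can be illustrated with a short sketch. The class and variable names here are made up for illustration only:

class WithParam(bp.Base):
    def __init__(self):
        super(WithParam, self).__init__()
        self.p = bm.Parameter(bm.zeros(3))   # a labeled parameter
        self.v = bm.Variable(bm.zeros(3))    # an ordinary variable

m = WithParam()
m.vars().subset(bm.Parameter)   # collects only the Parameter 'p', not the plain Variable 'v'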

3. RandomState#

brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. RandomState must store the dynamically changed key information (see JAX random number designs). Every time a RandomState performs a random sampling, the “key” changes. Therefore, it is worthwhile to label a RandomState as a Variable.

state = bm.random.RandomState(seed=1234)

state
RandomState([   0, 1234], dtype=uint32)
# perform a "random" sampling 
state.random(1)

state  # the value changed
RandomState([2113592192, 1902136347], dtype=uint32)
# perform a "sample" sampling 
state.sample(1)

state  # the value changed too
RandomState([1076515368, 3893328283], dtype=uint32)

Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()
DeviceArray([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)
DeviceArray([[4198471980, 1111166693],
             [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
DeviceArray([[3244149147, 2659778815],
             [2548793527, 3057026599],
             [ 874320145, 4142002431],
             [3368470122, 3462971882],
             [1756854521, 1662729797]], dtype=uint32)

There is a default RandomState in brainpy.math.random module: DEFAULT.

bm.random.DEFAULT
RandomState([601887926, 339370966], dtype=uint32)

The built-in random methods like randint(), rand(), shuffle(), etc. use this DEFAULT state. If you want to change the default RandomState, please use the seed() method.

bm.random.seed(654321)

bm.random.DEFAULT
RandomState([     0, 654321], dtype=uint32)

In-place updating#

In BrainPy, the transformations (like JIT) usually need to update variables or tensors in-place. In-place updating does not change the reference pointing to the variable while changing the data stored in the variable.

For example, here we have a variable a.

a = bm.Variable(bm.zeros(5))

The ids of the variable and the data stored in the variable are:

id_of_a = id(a)
id_of_data = id(a.value)

assert id_of_a != id_of_data

print('id(a)       = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a)       =  2101001001088
id(a.value) =  2101018127136

An in-place update (here we use [:]) does not change the reference to the variable but changes its data:

a[:] = 1.

print('id(a)       = ', id(a))
print('id(a.value) = ', id(a.value))
id(a)       =  2101001001088
id(a.value) =  2101019514880
print('(id(a) == id_of_a)          =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a)          = True
(id(a.value) == id_of_data) = False

However, if you assign data without in-place operators, the name a will be rebound to a new object and no longer refer to the original Variable (its id changes). This will cause serious errors when using transformations in BrainPy.

a = 10.

print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) =  2101001187280
(id(a) == id_of_a) = False

Note

The following in-place operators are not limited to brainpy.math.Variable and its subclasses; they also apply to brainpy.math.JaxArray.

Here, we list several commonly used in-place operators.

v = bm.Variable(bm.arange(10))
old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'

1. Indexing and slicing#

Indexing and slicing are the two most commonly used operators. The details of indexing and slicing are in Array Objects Indexing.

Indexing: v[i] = a or v[(1, 3)] = c (index multiple values)

v[0] = 1

check_no_change(v)

Slicing: v[i:j] = b

v[1: 2] = 1

check_no_change(v)

Slicing all values: v[:] = d, v[...] = e

v[:] = 0

check_no_change(v)
v[...] = bm.arange(10)

check_no_change(v)

2. Augmented assignment#

All augmented assignments are in-place operations, including

  • add: +=

  • subtract: -=

  • divide: /=

  • multiply: *=

  • floor divide: //=

  • modulo: %=

  • power: **=

  • and: &=

  • or: |=

  • xor: ^=

  • left shift: <<=

  • right shift: >>=

v += 1

check_no_change(v)
v *= 2

check_no_change(v)
v |= bm.random.randint(0, 2, 10)

check_no_change(v)
v **= 2

check_no_change(v)
v >>= 2

check_no_change(v)

3. .value assignment#

Another way to in-place update a variable is to assign new data to .value. This operation is very safe, because it will check whether the type and shape of the new data are consistent with the current ones.

v.value = bm.arange(10)

check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.

4. .update() method#

Actually, the .value assignment is the same operation as the .update() method. Users who want a safe assignment can choose this method too.

v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got ().
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.

Base Class#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about:

  • The Base class for the BrainPy ecosystem

  • The Collector to facilitate variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.Base#

The foundation of BrainPy is brainpy.Base. A Base instance is an object which has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class object that will be JIT compiled or automatically differentiated must inherit from brainpy.Base.

A Base object can have many variables, children Base objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.

class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note this model has three variables: self.V, self.w, and self.spike. It also has an integrator self.integral.

The naming system#

Every Base object has a unique name. Users can specify a unique name when instantiating a Base class. Using a name that has already been taken will cause an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)

If a name is not specified for the Base object, BrainPy will assign one automatically. The rule for generating object names is class_name + number_of_instances, for example FHN0, FHN1, etc.

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore, in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.

Collection functions#

Two important collection functions are implemented for each Base object. Specifically, they are:

  • nodes(): to collect all instances of Base objects, including children nodes in a node.

  • vars(): to collect all variables defined in the Base node and in its children nodes.

fhn = FHN(10)

All variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).

vars = fhn.vars()

vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(vars)
brainpy.base.collector.TensorCollector

All nodes in the model can also be collected through one method Base.nodes(). The result container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>,
 'FHN2': <__main__.FHN at 0x2155a7a65e0>}
type(nodes)
brainpy.base.collector.Collector

All integrators can be collected by:

ints = fhn.nodes().subset(bp.integrators.Integrator)

ints
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>}
type(ints)
brainpy.base.collector.Collector

Now, let’s make a more complicated model by using the previously defined model FHN.

class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        self.conn = bm.ones((num1, num2), dtype=bool) * w
        bm.fill_diagonal(self.conn, 0.)

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

This model FeedForwardCircuit defines two layers. Each layer is modeled as a FitzHugh-Nagumo model (FHN). The first layer is densely connected to the second layer. The input to the second layer is the product of the first layer’s spike and the connection strength w.

net = FeedForwardCircuit(8, 5)

We can retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

And retrieve all nodes (instances of the Base class) by .nodes():

net.nodes()
{'FHN3': <__main__.FHN at 0x2155ace3130>,
 'FHN4': <__main__.FHN at 0x2155a798d30>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x2155a743c40>}

If we only care about a subtype of class, we can retrieve them through:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
Absolute paths#

It’s worth noting that there are two ways to access variables, integrators, and nodes: “absolute” paths and “relative” paths. The default is the absolute path.

For absolute paths, all keys in the resulting Collector (Base.nodes()) have the format key = node_name [+ field_name].

.nodes() example 1: In the above fhn instance, there are two nodes: “fhn” and its integrator “fhn.integral”.

fhn.integral.name, fhn.name
('RK45', 'FHN2')

Calling .nodes() returns their names and models.

fhn.nodes().keys()
dict_keys(['RK45', 'FHN2'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all models.

net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0'])

.vars() example 1: In the above fhn instance, there are three variables: “V”, “w”, and “spike”. Calling .vars() returns a dict of <node_name + var_name, var_value>.

fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Relative paths#

Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in the net can be accessed by

net.pre
<__main__.FHN at 0x2155ace3130>

Relative paths preserve the dependence relationship. For example, all nodes retrieved from the perspective of net are:

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x2155a743c40>,
 'pre': <__main__.FHN at 0x2155ace3130>,
 'post': <__main__.FHN at 0x2155a798d30>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}

However, nodes retrieved from the start point of net.pre will be:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x2155ace3130>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>}

Variables can also be accessed relatively from the model. For example, variables that can be relatively accessed from net include:

net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'post.spike': Variable([False, False, False, False, False], dtype=bool),
 'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

While variables relatively accessed from net.post are:

net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'spike': Variable([False, False, False, False, False], dtype=bool),
 'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Elements in containers#

One drawback of the collection functions is that they do not look for elements inside a list, dict, or any other container structure.

class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The above class defines a list of variables, and a dict of children nodes, but the variables and children nodes cannot be retrieved from the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x2155ae60a00>}

To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of “dict”) to hold variables and nodes stored in container structures. Variables registered in implicit_vars, as well as integrators and nodes registered in implicit_nodes, can be retrieved by the collection functions.

class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})  # the input must be a dict
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})  # the input must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator. 
# Therefore, there are five Base objects. 

t2.nodes()
{'T1': <__main__.FHN at 0x2155ae6bca0>,
 'T2': <__main__.FHN at 0x2155ae6b3a0>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae6f8e0>,
 'RK411': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae66610>,
 'AnotherTest0': <__main__.AnotherTest at 0x2155ae6f0d0>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.

t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T1.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'T2.spike': Variable([False, False, False, False, False], dtype=bool),
 'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32)}

Saving and loading#

Because Base.vars() returns a Python dictionary-like Collector, the collected variables can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, please see Saving and Loading). Specifically, they are implemented by Base.save_states() and Base.load_states().

Save#
Base.save_states(PATH, [vars])

Models exported from BrainPy support various Python standard file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
Load#

Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')

Collector#

Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.

subset()#

Collector.subset(cls) returns a part of elements whose type is the given cls. For example, Base.nodes() returns all instances of Base class. If you are only interested in one type, like ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}

Actually, Collector.subset(cls) traverses all the elements in the collection and picks out the elements whose type matches the given cls.
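Conceptually, this filtering can be sketched as follows. This is only an illustration of the behavior, not the actual BrainPy implementation:

def subset_sketch(collector, cls):
    # keep only the (key, value) pairs whose value is an instance of `cls`
    return {key: value for key, value in collector.items() if isinstance(value, cls)}

subset_sketch(net.nodes(), bp.ode.ODEIntegrator)  # should select the same integrators as .subset()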

unique()#

It is common in machine learning that weights are shared among several objects, or that the same weight can be accessed through different dependence relationships. The collection functions of Base therefore often return a collection in which the same value has multiple keys; duplicate elements are not automatically excluded. However, it is important not to apply operations such as gradient descent more than once to the same element.

Therefore, the Collector provides Collector.unique() to handle this problem automatically. Collector.unique() returns a copy of the collection in which all elements are unique.

class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x2155b13b550>,
 'A': <__main__.ModelA at 0x2155b13b460>,
 'A_shared': <__main__.SharedA at 0x2155a7a6580>,
 'A_shared.source': <__main__.ModelA at 0x2155b13b460>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x2155b13b550>,
 'A': <__main__.ModelA at 0x2155b13b460>,
 'A_shared': <__main__.SharedA at 0x2155a7a6580>}

update()#

The Collector can also catch potential conflicts during the assignment. The bracket assignment of a Collector ([key]) and Collector.update() will check whether the same key is mapped to a different value. If it is, an error will occur.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].

replace()#

Collector.replace(old_key, new_value) is used to update the value of a key.

tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))

tc
{'a': JaxArray([1., 1., 1.], dtype=float32)}

__add__()#

Two Collectors can be merged.

a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})

a + b
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'b': JaxArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}

TensorCollector#

TensorCollector is a subclass of Collector that is specifically designed to collect tensors.

Compilation#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about code compilation that can accelerate your model running performance.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.math.jit()#

JAX provides JIT compilation through jax.jit() for pure functions. In most cases, however, we code with Python classes. brainpy.math.jit() is intended to extend just-in-time compilation to class objects.

JIT compilation for class objects#

The constraints for class-object JIT compilation include:

  • The JIT target must be a subclass of brainpy.Base.

  • Dynamically changed variables must be labeled as brainpy.math.Variable.

  • Updating Variables must be accomplished by in-place operations.

class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u 
        # The above line can also be expressed as: 
        # 
        #   self.w.value = self.w - u 
        # 
        # or, 
        # 
        #   self.w.update(self.w - u)

In this example, weight self.w is a dynamically changed variable, thus marked as Variable. During the update phase __call__(), self.w is in-place updated through self.w[:] = .... Alternatively, one can replace the data in the variable by self.w.value = ... or self.w.update(...).

Now this logistic regression can be accelerated by JIT compilation.

num_dim, num_points = 10, 200000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
lr = LogisticRegression(10)

%timeit lr(points, labels)
3.73 ms ± 589 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
lr_jit = bm.jit(lr)

%timeit lr_jit(points, labels)
1.75 ms ± 57.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

JIT mechanism#

The mechanism of JIT compilation is that BrainPy automatically transforms your class methods into functions.

brainpy.math.jit() receives a dyn_vars argument, which denotes the dynamically changed variables. If it is not provided by the user, BrainPy automatically detects these variables by calling Base.vars() (only data labeled as Variable will be retrieved by Base.vars()). Once “dyn_vars” is known, BrainPy treats these variables as function arguments and transforms the class object into a function.
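For example, the automatic detection described above should be equivalent to passing the collected variables explicitly (a small sketch reusing the lr object defined earlier):

lr_jit2 = bm.jit(lr, dyn_vars=lr.vars())
lr_jit2(points, labels)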

import types

isinstance(lr_jit, types.FunctionType)  # "lr" is class, while "lr_jit" is a function
True

Therefore, the secret of brainpy.math.jit() is providing “dyn_vars”. No matter whether your target is a class object, a method of a class object, or a pure function, if there are dynamically changed variables, you just pack them into brainpy.math.jit() as “dyn_vars”. Then, all the compilation and acceleration will be handled by BrainPy automatically. Let’s illustrate this with several examples.

Example 1: JIT compiled methods in a class#

In this example, we JIT compile a method of a class, in which the object’s variables are used to compute the final results.

class Linear(bp.Base):
    def __init__(self, n_in, n_out):
        super(Linear, self).__init__()
        self.w = bm.random.random((n_in, n_out))
        self.b = bm.zeros(n_out)
    
    def update(self, x):
        return x @ self.w + self.b
x = bm.zeros(10)  # the input data
l = Linear(10, 3)  # the class we need

First, we mark “w” and “b” as dynamically changed variables. Changing “w” or “b” will change the final results.

update1 = bm.jit(
    l.update, dyn_vars=[l.w, l.b]  # make 'w' and 'b' dynamically change
)  

update1(x)  # x is 0., b is 0., therefore y is 0.
JaxArray([0., 0., 0.], dtype=float32)
l.b[:] = 1.  # change b to 1, we expect y will be 1 too

update1(x)
JaxArray([1., 1., 1.], dtype=float32)

This time, we only mark “w” as a dynamically changed variable. We will find that no matter how “b” is modified, the results will not change.

update2 = bm.jit(
    l.update, dyn_vars=[l.w]  # only make 'w' dynamically change
)

update2(x)
JaxArray([1., 1., 1.], dtype=float32)
l.b[:] = 2.  # change b to 2

update2(x)  # while y will not be 2
JaxArray([1., 1., 1.], dtype=float32)
Example 2: JIT compiled functions#

Now, we change the above “Linear” object to a function.

n_in = 10;  n_out = 3

w = bm.random.random((n_in, n_out))
b = bm.zeros(n_out)

def update(x):
    return x @ w + b

If we do not provide dyn_vars, “w” and “b” will be compiled as constant values.

update1 = bm.jit(update)
update1(x)
JaxArray([0., 0., 0.], dtype=float32)
b[:] = 1.  # modify the value of 'b' will not 
           # change the result, because in the 
           # jitted function, 'b' is already 
           # a constant
update1(x)
JaxArray([0., 0., 0.], dtype=float32)

Providing “w” and “b” as dyn_vars will make them dynamically changed again.

update2 = bm.jit(update, dyn_vars=(w, b))
update2(x)
JaxArray([1., 1., 1.], dtype=float32)
b[:] = 2.  # change b to 2; y will be 2 as well, because 'b' is dynamic here
update2(x)
JaxArray([2., 2., 2.], dtype=float32)
Example 3: JIT compiled neural networks#

Now, let’s use SGD to train a neural network with JIT acceleration. Here we use the autograd function brainpy.math.grad(), which will be discussed in detail in the next section.

class LinearNet(bp.Base):
    def __init__(self, n_in, n_out):
        super(LinearNet, self).__init__()

        # weights
        self.w = bm.TrainVar(bm.random.random((n_in, n_out)))
        self.b = bm.TrainVar(bm.zeros(n_out))
        self.r = bm.TrainVar(bm.random.random((n_out, 1)))
    
    def update(self, x):
        h = x @ self.w + self.b
        return h @ self.r
    
    def loss(self, x, y):
        predict = self.update(x)
        return bm.mean((predict - y) ** 2)


ln = LinearNet(100, 200)

# provide the variables want to update
opt = bm.optimizers.SGD(lr=1e-6, train_vars=ln.vars()) 

# provide the variables require graidents
f_grad = bm.grad(ln.loss, grad_vars=ln.vars(), return_value=True)  


def train(X, Y):
    grads, loss = f_grad(X, Y)
    opt.update(grads)
    return loss

# JIT the train function 
train_jit = bm.jit(train, dyn_vars=ln.vars() + opt.vars())
xs = bm.random.random((1000, 100))
ys = bm.random.random((1000, 1))

for i in range(30):
    loss  = train_jit(xs, ys)
    print(f'Train {i}, loss = {loss:.2f}')
Train 0, loss = 6649731.50
Train 1, loss = 3748688.50
Train 2, loss = 2126231.00
Train 3, loss = 1210147.88
Train 4, loss = 690106.50
Train 5, loss = 393984.28
Train 6, loss = 225071.75
Train 7, loss = 128625.49
Train 8, loss = 73524.97
Train 9, loss = 42035.37
Train 10, loss = 24035.91
Train 11, loss = 13746.33
Train 12, loss = 7863.82
Train 13, loss = 4500.70
Train 14, loss = 2577.91
Train 15, loss = 1478.59
Train 16, loss = 850.07
Train 17, loss = 490.72
Train 18, loss = 285.26
Train 19, loss = 167.80
Train 20, loss = 100.63
Train 21, loss = 62.24
Train 22, loss = 40.28
Train 23, loss = 27.73
Train 24, loss = 20.55
Train 25, loss = 16.45
Train 26, loss = 14.10
Train 27, loss = 12.76
Train 28, loss = 11.99
Train 29, loss = 11.56

RandomState#

We have talked about RandomState in the Variables section. RandomState is also a Variable. Therefore, if the default RandomState (brainpy.math.random.DEFAULT) is used in your function, you should mark it as one of the dyn_vars in the function. Otherwise, it will be treated as a constant and the jitted function will always return the same value.

def function():
    return bm.random.normal(0, 1, size=(10,))
f1 = bm.jit(function)

f1() == f1()
JaxArray([ True,  True,  True,  True,  True,  True,  True,  True,
           True,  True], dtype=bool)

The correct way to make JIT for this function is:

bm.random.seed(1234)

f2 = bm.jit(function, dyn_vars=bm.random.DEFAULT)

f2() == f2()
JaxArray([False, False, False, False, False, False, False, False,
          False, False], dtype=bool)

Static arguments#

Static arguments are treated as static/constant in the jitted function.

Two things must be marked as static: numerical arguments used in conditional syntax (bool values or values that result in bools) and strings. Otherwise, an error will be raised.

@bm.jit
def f(x):
  if x < 3:  # this will cause error
    return 3. * x ** 2
  else:
    return -4 * x
try:
    f(1.)
except Exception as e:
    print(type(e), e)
<class 'jax._src.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at C:\Users\adadu\AppData\Local\Temp\ipykernel_44816\1408095738.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Simply speaking, arguments whose values are used in boolean conditions must be declared as static. In the brainpy.math.jit() function, we can set the names of static arguments through static_argnames.

def f(x):
  if x < 3:  # 'x' is declared as static below, so this is fine
    return 3. * x ** 2
  else:
    return -4 * x

f_jit = bm.jit(f, static_argnames=('x',))
f_jit(x=1.)
DeviceArray(3., dtype=float32, weak_type=True)

However, it’s worth noting that calling the jitted function with different values for these static arguments will trigger recompilation. Therefore, declaring static arguments may be suitable in the following situations:

  1. Boolean arguments.

  2. Arguments that only have several possible values.

If an argument takes many different values across calls, you’d better not declare it as static.
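For example, in the following sketch (the function and argument names are made up for illustration), each new value of the static argument triggers a fresh compilation, while repeated values reuse the cached one:

def g(x, scale_up):
  # 'scale_up' is a Python bool used in a condition, so it must be static
  return 3. * x ** 2 if scale_up else -4 * x

g_jit = bm.jit(g, static_argnames=('scale_up',))

g_jit(1., scale_up=True)    # compiles for scale_up=True
g_jit(1., scale_up=False)   # compiles again for scale_up=False
g_jit(2., scale_up=True)    # reuses the compilation cached for scale_up=True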

For more information, please refer to the jax.jit API.

Differentiation#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about how to realize automatic differentiation on your variables in a function or a class object. In current machine learning systems, gradients are commonly used in various situations. Therefore, we should understand:

  • How to calculate derivatives of arbitrarily complex functions?

  • How to compute high-order gradients?

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Preliminary#

Every autograd function in BrainPy has several keywords. All examples below are illustrated through brainpy.math.grad(). Other autograd functions have the same settings.

argnums and grad_vars#

The autograd functions in BrainPy can compute derivatives of function arguments (specified by argnums) or non-argument variables (specified by grad_vars). For instance, the following is a linear readout model:

class Linear(bp.Base):
    def __init__(self):
        super(Linear, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        r = bm.dot(self.w, x) + self.b
        return r.sum()
    
l = Linear()

If we try to focus on the derivative of the argument “x” when calling the update function, we can set this through argnums:

grad = bm.grad(l.update, argnums=0)

grad(bm.ones(10))
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ],            dtype=float32)

By contrast, if you focus on the derivatives of parameters “self.w” and “self.b”, we should label them with grad_vars:

grad = bm.grad(l.update, grad_vars=(l.w, l.b))

grad(bm.ones(10))
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
 DeviceArray([1.], dtype=float32))

If we pay attention to the derivatives of both argument “x” and parameters “self.w” and “self.b”, argnums and grad_vars can be used together. In this condition, the gradient function will return gradients with the format of (var_grads, arg_grads), where arg_grads refers to the gradients of “argnums” and var_grads refers to the gradients of “grad_vars”.

grad = bm.grad(l.update, grad_vars=(l.w, l.b), argnums=0)

var_grads, arg_grads = grad(bm.ones(10))
var_grads
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
 DeviceArray([1.], dtype=float32))
arg_grads
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ],            dtype=float32)

return_value#

As mentioned above, autograd functions return a function that computes gradients but does not return the original output value. Sometimes, however, we care about the value the function returns, not just the gradients. In this case, you can set return_value=True in the autograd function.

grad = bm.grad(l.update, argnums=0, return_value=True)

gradient, value = grad(bm.ones(10))
gradient
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ],            dtype=float32)
value
DeviceArray(4.6318984, dtype=float32)

has_aux#

In some situations, we are interested in the intermediate values in a function, and has_aux=True can be of great help. The constraint is that you must return values with the format of (loss, aux_data). For instance,

class LinearAux(bp.Base):
    def __init__(self):
        super(LinearAux, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        dot = bm.dot(self.w, x)
        r = (dot + self.b).sum()
        return r, (r, dot)  # here the aux data is a tuple, which includes the loss and the dot value;
                            # however, aux can be arbitrarily complex.
    
l2 = LinearAux()
grad = bm.grad(l2.update, argnums=0, has_aux=True)

gradient, aux = grad(bm.ones(10))
gradient
JaxArray([0.20289445, 0.4745227 , 0.36053288, 0.94524395, 0.8360598 ,
          0.06507981, 0.7748591 , 0.8377187 , 0.5767547 , 0.47604012],            dtype=float32)
aux
(DeviceArray(5.5497055, dtype=float32), JaxArray([5.5497055], dtype=float32))

When multiple keywords (argnums, grad_vars, has_aux or return_value) are set simultaneously, the return format of the gradient function can be inspected through the corresponding API documentation of brainpy.math.grad().

brainpy.math.grad()#

brainpy.math.grad() takes a function/object (\(f : \mathbb{R}^n \to \mathbb{R}\)) as the input and returns a new function (\(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)) which computes the gradient of the original function/object. It’s worth noting that brainpy.math.grad() only supports functions that return scalar values.

Pure functions#

For a pure function, the gradient is taken with respect to the first argument:

def f(a, b):
    return a * 2 + b

grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32, weak_type=True)

However, this can be controlled via the argnums argument.

grad_f2 = bm.grad(f, argnums=(0, 1))

grad_f2(2., 1.)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(1., dtype=float32, weak_type=True))

Class objects#

For a class object or a class bound function, the gradient is taken with respect to the provided grad_vars and argnums setting:

class F(bp.Base):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()
    
f = F()

The grad_vars can be a JaxArray, or a list/tuple/dict of JaxArray.

bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': DeviceArray([2.], dtype=float32),
 'F0.b': DeviceArray([2.], dtype=float32)}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
(DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))

If there are dynamically changed values in the gradient function, you can provide them in the dyn_vars argument.

class F2(bp.Base):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
DeviceArray([2.], dtype=float32)

Besides, if you are interested in the gradient of the input value, please use the argnums argument. Then, the gradient function will return (grads_of_grad_vars, grads_of_args).

class F3(bp.Base):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : 3.0
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : (DeviceArray(3., dtype=float32, weak_type=True), DeviceArray(10., dtype=float32, weak_type=True))

Actually, it is recommended to provide all dynamically changed variables, whether or not they are updated in the gradient function, in the dyn_vars argument.

Auxiliary data#

Usually, we want to get the loss value, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data and return_value=True to return the loss value.

# return loss

grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)

print('grad: ', grad)
print('loss: ', loss)
grad:  [2.]
loss:  12.0
class F4(bp.Base):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)
    

f4 = F4()
    
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)

print('grad: ', grad)
print('aux_data: ', aux_data)
grad:  [1.]
aux_data:  (JaxArray([1.], dtype=float32), JaxArray([2.], dtype=float32))

Note

Any function used to compute gradients through brainpy.math.grad() must return a scalar value; otherwise an error will be raised.
try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output had shape: (2,).
# this is right

bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray([0.5, 0.5], dtype=float32)

brainpy.math.vector_grad()#

If users want to take gradients of vector-valued outputs, please use the brainpy.math.vector_grad() function. For example,

def f(a, b): 
    return bm.sin(b) * a

Gradients for vectors#

# vectors

a = bm.arange(5.)
b = bm.random.random(5)
bm.vector_grad(f)(a, b)
JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ],            dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ],            dtype=float32),
 JaxArray([0.       , 0.9801371, 1.7597246, 2.741662 , 3.9158623], dtype=float32))

Gradients for matrices#

# matrix

a = bm.arange(6.).reshape((2, 3))
b = bm.random.random((2, 3))
bm.vector_grad(f, argnums=1)(a, b)
JaxArray([[0.       , 0.8662993, 1.1221857],
          [2.9322515, 2.3293345, 3.024507 ]], dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([[0.45055482, 0.49952534, 0.8277529 ],
           [0.21131878, 0.8129499 , 0.79630035]], dtype=float32),
 JaxArray([[0.       , 0.8662993, 1.1221857],
           [2.9322515, 2.3293345, 3.024507 ]], dtype=float32))

Similar to brainpy.math.grad(), brainpy.math.vector_grad() also supports taking derivatives with respect to variables in a class object. Here is a simple example.

class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.ones(5)
    self.y = bm.ones(5)

  def __call__(self):
    return self.x ** 2 + self.y ** 3 + 10

t = Test()
bm.vector_grad(t, grad_vars=t.x)()
DeviceArray([2., 2., 2., 2., 2.], dtype=float32)
bm.vector_grad(t, grad_vars=(t.x, ))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),)
bm.vector_grad(t, grad_vars=(t.x, t.y))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),
 DeviceArray([3., 3., 3., 3., 3.], dtype=float32))

Other operations like return_value and has_aux in brainpy.math.vector_grad() are the same as those in brainpy.math.grad().

brainpy.math.jacobian()#

Another way to take gradients of a vector-valued output is to use brainpy.math.jacobian(). brainpy.math.jacobian() automatically computes the Jacobian matrix \(\partial f(x) \in \mathbb{R}^{m \times n}\) of a given function \(f : \mathbb{R}^n \to \mathbb{R}^m\) at a given point \(x \in \mathbb{R}^n\). Here, we will not go into the details of the implementation and usage of brainpy.math.jacobian(). Instead, we only show two examples: one with a pure function and one with a class function.

Given the following function,

import jax.numpy as jnp

def f1(x, y):
    a = 4 * x[1] ** 2 - 2 * x[2]
    r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
    return r, a
_x = bm.array([1., 2., 3.])
_y = bm.array([10., 5.])
    
grads, vec, aux = bm.jacobian(f1, return_value=True, has_aux=True)(_x, _y)
grads
JaxArray([[10.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        , 25.        ],
          [ 0.        , 16.        , -2.        ],
          [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32)
vec
DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)
aux
DeviceArray(10., dtype=float32)

Given the following class objects,

class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.array([1., 2., 3.])

  def __call__(self, y):
    a = self.x[0] * y[0]
    b = 5 * self.x[2] * y[1]
    c = 4 * self.x[1] ** 2 - 2 * self.x[2]
    d = self.x[2] * jnp.sin(self.x[0])
    r = jnp.asarray([a, b, c, d])
    return r, (c, d)
t = Test()
f_grad = bm.jacobian(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)

(var_grads, arg_grads), value, aux = f_grad(_y)
var_grads
DeviceArray([[10.        ,  0.        ,  0.        ],
             [ 0.        ,  0.        , 25.        ],
             [ 0.        , 16.        , -2.        ],
             [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32)
arg_grads
JaxArray([[ 1.,  0.],
          [ 0., 15.],
          [ 0.,  0.],
          [ 0.,  0.]], dtype=float32)
value
DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)
aux
(DeviceArray(10., dtype=float32), DeviceArray(2.5244129, dtype=float32))

For more details on automatic differentiation, please see our API documentation.

Control Flows#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about how to build structured control flows with the BrainPy data structure JaxArray. These control flows include

  • the for loop syntax,

  • the while loop syntax,

  • and the condition syntax.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

In JAX, control flow must be written with structured control-flow primitives. The JaxArray in BrainPy provides an easier syntax for building such control flows.

Note

The control flow syntax below is not a re-implementation of JAX’s control flow APIs. We only guarantee that the following APIs are useful and intuitive when you use brainpy.math.JaxArray.

brainpy.math.make_loop()#

brainpy.math.make_loop() is used to generate a for-loop function when you use JaxArray.

Suppose that you are using several JaxArrays (grouped as dyn_vars) to implement your body function “body_fun”, and you want to gather the history values of several of them (grouped as out_vars). Sometimes the body function also returns something, and you want to gather the returned values as well. In plain Python, this logic could be written as


def for_loop_function(body_fun, dyn_vars, out_vars, xs):
  ys = []
  for x in xs:
    # 'dyn_vars' and 'out_vars' are updated in 'body_fun()'
    results = body_fun(x)
    ys.append([out_vars, results])
  return ys

In BrainPy, you can define this logic using brainpy.math.make_loop():


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=False)

hist_of_out_vars = loop_fun(xs)

Or,


loop_fun = brainpy.math.make_loop(body_fun, dyn_vars, out_vars, has_return=True)

hist_of_out_vars, hist_of_return_vars = loop_fun(xs)

Let’s implement a recurrent network to illustrate how to use this function.

class RNN(bp.dyn.DynamicalSystem):
  def __init__(self, n_in, n_h, n_out, n_batch, g=1.0, **kwargs):
    super(RNN, self).__init__(**kwargs)

    # parameters
    self.n_in = n_in
    self.n_h = n_h
    self.n_out = n_out
    self.n_batch = n_batch
    self.g = g

    # weights
    self.w_ir = bm.TrainVar(bm.random.normal(scale=1 / n_in ** 0.5, size=(n_in, n_h)))
    self.w_rr = bm.TrainVar(bm.random.normal(scale=g / n_h ** 0.5, size=(n_h, n_h)))
    self.b_rr = bm.TrainVar(bm.zeros((n_h,)))
    self.w_ro = bm.TrainVar(bm.random.normal(scale=1 / n_h ** 0.5, size=(n_h, n_out)))
    self.b_ro = bm.TrainVar(bm.zeros((n_out,)))

    # variables
    self.h = bm.Variable(bm.random.random((n_batch, n_h)))

    # function
    self.predict = bm.make_loop(self.cell,
                                dyn_vars=self.vars(),
                                out_vars=self.h,
                                has_return=True)

  def cell(self, x):
    self.h.value = bm.tanh(self.h @ self.w_rr + x @ self.w_ir + self.b_rr)
    o = self.h @ self.w_ro + self.b_ro
    return o


rnn = RNN(n_in=10, n_h=100, n_out=3, n_batch=5)

In the above RNN model, we define a body function RNN.cell, which will later be iterated over the input values in a for loop. The loop function is defined as self.predict with bm.make_loop(). We care about the history values of “self.h” and the readout value “o”, so we set out_vars=self.h and has_return=True.

xs = bm.random.random((100, rnn.n_in))
hist_h, hist_o = rnn.predict(xs)
hist_h.shape  # the shape should be (num_time,) + h.shape
(100, 5, 100)
hist_o.shape  # the shape should be (num_time, ) + o.shape
(100, 5, 3)

If you have multiple input values, you should wrap them as a container and call the loop function with loop_fun(xs), where “xs” can be a JaxArray or a list/tuple/dict of JaxArray. For example:

a = bm.Variable(bm.zeros(10))

def body(x):
    x1, x2 = x  # "x" is a tuple/list of JaxArray
    a.value += (x1 + x2)

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs=[bm.arange(10), bm.ones(10)])
Variable([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
          [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
          [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
          [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
          [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
          [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
          [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
          [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
          [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
          [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32)
a = bm.Variable(bm.zeros(10))

def body(x):  # "x" is a dict of JaxArray
    a.value += x['a'] + x['b']

loop = bm.make_loop(body, dyn_vars=[a], out_vars=a)
loop(xs={'a': bm.arange(10), 'b': bm.ones(10)})
Variable([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
          [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
          [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
          [10., 10., 10., 10., 10., 10., 10., 10., 10., 10.],
          [15., 15., 15., 15., 15., 15., 15., 15., 15., 15.],
          [21., 21., 21., 21., 21., 21., 21., 21., 21., 21.],
          [28., 28., 28., 28., 28., 28., 28., 28., 28., 28.],
          [36., 36., 36., 36., 36., 36., 36., 36., 36., 36.],
          [45., 45., 45., 45., 45., 45., 45., 45., 45., 45.],
          [55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]],            dtype=float32)

dyn_vars, out_vars, xs and the body function returns can be arrays organized in container structures like tuple/list/dict. The history output values will preserve the container structure of out_vars and of the body function returns. If has_return=True, the loop function returns a tuple (hist_of_out_vars, hist_of_fun_returns); otherwise it only returns hist_of_out_vars. If no output variables are of interest, set out_vars=None.
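For example, a small sketch with a dict-structured out_vars (assuming container structures are supported as described above):

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def body(x):
    a.value += x
    b.value *= x

# the gathered history keeps the dict structure of out_vars
loop = bm.make_loop(body, dyn_vars=[a, b], out_vars={'a': a, 'b': b})
hist = loop(bm.arange(1., 4.))
# expected: hist['a'].shape == (3, 2) and hist['b'].shape == (3, 2)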

brainpy.math.make_while()#

brainpy.math.make_while() is used to generate a while-loop function when you use JaxArray. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.make_while(), the condition should be wrapped as a cond_fun function which returns a boolean value, and the statements should be packed as a body_fun function which does not return any value:


while cond_fun(x):
    body_fun(x)

where x is the external input that is not iterated. All the iterated variables should be marked as JaxArray. All JaxArrays used in cond_fun and body_fun should be declared in the dyn_vars argument.

Let’s look at an example:

i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f(x): 
    return i[0] < 10

def body_f(x):
    i.value += 1.
    counter.value += i

loop = bm.make_while(cond_f, body_f, dyn_vars=[i, counter])

In the above example, we compute the sum from 0 to 10 using two JaxArrays, i and counter.

loop()
counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)

brainpy.math.make_cond()#

brainpy.math.make_cond() is used to generate a condition function when you use JaxArray. It supports the following conditional logic:


if True:
    true statements 
else: 
    false statements

When using brainpy.math.make_cond(), the true statements should be wrapped as a true_fun function which implements the logic run when the predicate is true, and the false statements should be wrapped as a false_fun function which implements the logic run when the predicate is false. Neither function supports returning values.


if True:
    true_fun(x)
else:
    false_fun(x)

All the JaxArrays used in true_fun and false_fun should be declared in the dyn_vars argument. x is used to receive the external input value.

Let’s give it a try:

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def true_f(x):  a.value += 1

def false_f(x): b.value -= 1

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])

Here, we have two tensors. If the predicate is true, tensor a is increased by 1; if it is false, tensor b is decreased by 1.

cond(pred=True)

a, b
(Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(True)

a, b
(Variable([2., 2.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(False)

a, b
(Variable([2., 2.], dtype=float32), Variable([0., 0.], dtype=float32))
cond(False)

a, b
(Variable([2., 2.], dtype=float32), Variable([-1., -1.], dtype=float32))

Alternatively, we can define a conditional case which depends on the external input.

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))

def true_f(x):  a.value += x

def false_f(x): b.value -= x

cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
cond(True, 10.)

a, b
(Variable([10., 10.], dtype=float32), Variable([1., 1.], dtype=float32))
cond(False, 5.)

a, b
(Variable([10., 10.], dtype=float32), Variable([-4., -4.], dtype=float32))

Low-level Operator Customization#

@Tianqiu Zhang

BrainPy is built on JAX and can accelerate model execution through Just-In-Time (JIT) compilation. In order to enhance performance on CPU and GPU, we publish another package, BrainPyLib, which provides several built-in low-level operators for synaptic computation. These operators are written in C++ and wrapped as JAX primitives using XLA. However, users cannot easily customize their own operators unless they have the relevant background. To solve this problem, we introduce numba.cfunc here and provide convenient interfaces for users to customize operators without touching the underlying logic.

import brainpy as bp
import brainpy.math as bm
from jax import jit
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray

bm.set_platform('cpu')

In the Computation with Sparse Connections section, we formally discuss the benefits of computation with our built-in operators. These operators are provided by the brainpylib package and can be accessed through the brainpy.math module. To be more specific, in order to speed up sparse synaptic computation, we customize several low-level operators for CPU and GPU, which are written in C++ and converted into JAX/XLA-compatible primitives using Pybind11.

It is not easy to write a C++ operator and implement the required conversions. Users have to learn how to write a C++ operator, how to write a customized JAX primitive, and how to convert the C++ operator into a JAX primitive. Here are some links for users who prefer to dive into the details: Jax primitives, XLA custom calls.

However, we can only provide a limited number of operators, and it would be great if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface, register_op, to register customized operators on CPU and GPU. Users no longer need to do any C++ programming or XLA compilation. This is accomplished with the help of numba.cfunc, which wraps Python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback so that it can be passed into XLA and registered as a jittable JAX primitive. The parameters and return types of register_op are listed in the API docs. Here is an example of using register_op on CPU.

How to customize operators?#

CPU version#

First, users can customize a simple operator written in Python. Notice that this Python operator will be jitted in nopython mode, and some language features are not available inside Numba-compiled functions. Please look up the Numba documentation for details.

def custom_op(outs, ins):
  y, y1 = outs
  x, x2 = ins
  y[:] = x + 1
  y1[:] = x2 + 2

There are some restrictions that users should know:

  • Parameters of the operators are outs and ins, corresponding to output variable(s) and input variable(s). The order cannot be changed.

  • The function cannot have any return value.

  • Notice that in the GPU version, users should write the kernel function according to the numba cuda.jit documentation. When applying a CPU function to the GPU, users only need to implement the CPU operator.

Then users should describe the shapes and types of the outputs, because JAX/Python can deduce the shapes and types of the inputs when the operator is called, but it cannot infer the shapes and types of the outputs. The argument can be:

  • a ShapedArray,

  • a sequence of ShapedArray,

  • a function that returns the correct output shapes as ShapedArray instances.

Here we use a function to describe the output shapes and types. Its arguments include all the inputs of the custom operator, but only their shapes and types are accessible.

def abs_eval_1(*ins):
  # ins: input arguments; only their shapes and dtypes are accessible.
  # Because the outputs of custom_op have exactly the same shapes and
  # dtypes as its inputs, we can simply return the inputs here.
  return ins

The function above may look somewhat abstract, so we give an alternative function below for passing the shape information. Note that abs_eval_1 and abs_eval_2 do the same thing.

def abs_eval_2(*ins):
  return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)

Now we are ready to register a CPU operator. register_op will be called to wrap your operator and return a jittable JAX primitive. Here are the parameters users should define:

  • op_name: Name of the operator.

  • cpu_func: Customized operator of CPU version.

  • out_shapes: The shapes and types of the outputs.

z = jnp.ones((1, 2), dtype=jnp.float32)
# Users could try out_shapes=abs_eval_2 and see if the result is different
op = bm.register_op(
  op_name='add',
  cpu_func=custom_op,
  out_shapes=abs_eval_1,
  apply_cpu_func_to_gpu=False)
jit_op = jit(op)
print(jit_op(z, z))
[DeviceArray([[2., 2.]], dtype=float32), DeviceArray([[3., 3.]], dtype=float32)]

GPU version#

We have discussed how to customize a CPU operator above; next we will talk about GPU operators, which are slightly different from the CPU version. There are two additional parameters users need to provide:

  • gpu_func: Customized operator of the GPU version.

  • apply_cpu_func_to_gpu: Whether to run the CPU kernel function as an alternative way to provide the GPU version.

Warning

GPU operators will be wrapped by cuda.jit in numba, but numba currently does not support launching CUDA kernels from cfuncs. For this reason, gpu_func is None by default, and an error will be raised if users pass a GPU operator to gpu_func.

Therefore, BrainPy enables users to set apply_cpu_func_to_gpu to True as a backup method. All the inputs will be initialized on the GPU and transferred to the CPU for computing. The user-defined operator will be executed on the CPU, and the results will be transferred back to the GPU for further tasks.
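Reusing the custom_op and abs_eval_1 defined earlier, a registration sketch with this fallback enabled might look like this (the operator name is hypothetical):

op_fallback = bm.register_op(
  op_name='add_via_cpu_fallback',   # hypothetical name
  cpu_func=custom_op,
  out_shapes=abs_eval_1,
  apply_cpu_func_to_gpu=True)       # run the CPU kernel when no gpu_func is given
jit_op_fallback = jit(op_fallback)  # `jit` was imported from jax above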

Performance#

To illustrate the effectiveness of this approach, we compare the customized operators with BrainPy’s built-in operators. Here we use event_sum as an example. The implementation of event_sum using our customization interface is shown below:

def abs_eval(events, indices, indptr, post_size, values):
  return post_size


def event_sum_op(outs, ins):
  post_val = outs
  events, indices, indptr, post_size, values = ins

  for i in range(len(events)):
      if events[i]:
        for j in range(indptr[i], indptr[i+1]):
          index = indices[j]
          old_value = post_val[index]
          post_val[index] = values + old_value


event_sum = bm.register_op(op_name='event_sum', cpu_func=event_sum_op, out_shapes=abs_eval)
jit_event_sum = jit(event_sum)

An exponential COBA network will be our benchmark for testing the speed. We use the built-in operator event_sum first.

class ExpCOBA(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
               method='exp_auto'):
    super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def update(self, _t, _dt):
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # Built-in operator
    # --------------------------------------------------------------------------------------
    self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
    # --------------------------------------------------------------------------------------
    self.post.input += self.g * (self.E - self.post.V)


class EINet(bp.dyn.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = bp.models.LIF(num_exc, **pars, method=method)
    I = bp.models.LIF(num_inh, **pars, method=method)
    E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
    I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
    I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)


net = EINet(scale=10., method='euler')
# simulation
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.628559827804565

The total time is about 15.6 seconds. Next we use our customized operator.

class ExpCOBA(bp.dyn.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
               method='exp_auto'):
    super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def update(self, _t, _dt):
    self.g.value = self.integral(self.g, _t, dt=_dt)
    post_size = bm.zeros(self.post.num)
    # Customized operator
    # ------------------------------------------------------------------------------------------------------------
    self.g += jit_event_sum(self.pre.spike, self.pre2post[0].value, self.pre2post[1].value, post_size, self.g_max)
    # ------------------------------------------------------------------------------------------------------------
    self.post.input += self.g * (self.E - self.post.V)


class EINet(bp.dyn.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = bp.models.LIF(num_exc, **pars, method=method)
    I = bp.models.LIF(num_inh, **pars, method=method)
    E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
    I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method)
    I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)
    I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)


net = EINet(scale=10., method='euler')
runner = bp.dyn.DSRunner(net, inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(10000.)
print(t)
15.703513145446777

After comparison, the customized operator is almost as fast as the built-in one. Users can simply build their own operators without worrying about a loss in computation speed.

Interoperation with other JAX frameworks#

import brainpy.math as bm

BrainPy can be easily interoperated with other JAX frameworks.

1. Data are exchangeable among different frameworks.#

This is possible because a JaxArray can be directly converted to a JAX ndarray or a NumPy ndarray.
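In the snippets below, b is assumed to be a JaxArray created earlier in this tutorial; one way to reproduce the values shown is:

import numpy as np   # used by np.asarray below

b = bm.arange(5)     # a JaxArray: [0, 1, 2, 3, 4]
b[0] = 5             # JaxArray supports in-place item assignment -> [5, 1, 2, 3, 4]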

Convert a JaxArray into a JAX ndarray.

# JaxArray.value is a JAX ndarray
b.value
DeviceArray([5, 1, 2, 3, 4], dtype=int32)

Convert a JaxArray into a numpy ndarray.

# JaxArray can be easily converted to a numpy ndarray
np.asarray(b)
array([5, 1, 2, 3, 4])

Convert a numpy ndarray into a JaxArray.

bm.asarray(np.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

Convert a JAX ndarray into a JaxArray.

import jax.numpy as jnp
bm.asarray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
bm.JaxArray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

2. Transformations in brainpy.math also work on functions.#

APIs in other JAX frameworks can be naturally integrated with BrainPy. Let’s take the gradient-based optimization library Optax as an example to illustrate how to use other JAX frameworks within BrainPy.

import optax
# First create several useful functions.

network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

@bm.jit
def train(params, opt_state, xs, ys):
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state
# Generate some data

bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer

params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'

brainpy.base module#

The base module for the whole BrainPy ecosystem.

  • This module provides the most fundamental class Base, and its associated helper classes Collector and ArrayCollector.

  • For each instance of the “Base” class, users can retrieve all of its variables (or trainable variables), integrators, and nodes.

  • This module also provides a Function class to wrap user-defined functions. A function may use several nodes, and users can initialize a Function by providing the nodes it uses. Unfortunately, the Function class cannot gather nodes automatically.

  • This module provides IO helper functions to help users save/load model states, or share their customized models with others.

  • This module provides naming tools to guarantee unique naming for each Base object.

Details are given in the listings below; a minimal usage sketch comes first.
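The sketch below assumes a standard brainpy installation; the class and attribute names are illustrative.

import brainpy as bp
import brainpy.math as bm

class MyNode(bp.Base):
  def __init__(self):
    super(MyNode, self).__init__()
    self.a = bm.Variable(bm.zeros(3))      # a dynamical variable
    self.w = bm.TrainVar(bm.ones((3, 3)))  # a trainable variable

node = MyNode()
node.vars()        # a Collector with all variables owned by this object
node.train_vars()  # a Collector with only the TrainVar instances
node.name          # each Base object receives a unique name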

Base Class#

Base([name])

The Base class for the whole BrainPy ecosystem.

Function Wrapper#

Function(f[, nodes, dyn_vars, name])

The wrapper for Python functions.

Collectors#

Collector

A Collector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

TensorCollector

An ArrayCollector is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

Exporting and Loading#

save_as_h5(filename, variables)

Save variables into an HDF5 file.

save_as_npz(filename, variables[, compressed])

Save variables into a numpy file.

save_as_pkl(filename, variables)

Save variables into a pickle file.

save_as_mat(filename, variables)

Save variables into a Matlab .mat file.

load_by_h5(filename, target[, verbose])

Load variables from an HDF5 file.

load_by_npz(filename, target[, verbose])

Load variables from a numpy file.

load_by_pkl(filename, target[, verbose])

Load variables from a pickle file.

load_by_mat(filename, target[, verbose])

Load variables from a Matlab .mat file.

SUPPORTED_FORMATS

Built-in mutable sequence.

Naming Tools#

check_name_uniqueness(name, obj)

Check the uniqueness of the name for the object type.

get_unique_name(type_)

Get the unique name for the given object type.

clear_name_cache()

Clear the cached names.

brainpy.math module#

The math module for the whole BrainPy ecosystem. This module provides basic mathematical operations, including:

  • numpy-like array operations

  • linear algebra functions

  • random sampling functions

  • discrete fourier transform functions

  • just-in-time compilation for class objects

  • automatic differentiation for class objects

  • dedicated operators for brain dynamics

  • activation functions

  • device switching

  • default type switching

  • and others

Details are given in the following.

Math Variables#

JaxArray(value)

Multiple-dimensional array in JAX backend.

ndarray

alias of brainpy.math.jaxarray.JaxArray

Variable(value)

The pointer to specify the dynamical variable.

TrainVar(value)

The pointer to specify the trainable variable.

Parameter(value)

The pointer to specify the parameter.

Delay Variables#

AbstractDelay([name])

TimeDelay(delay_target, delay_len[, ...])

Delay variable which has a fixed delay time length.

LengthDelay(delay_target, delay_len[, ...])

Delay variable which has a fixed delay length.

NeuTimeDelay(delay_target, delay_len[, ...])

Neutral Time Delay.

NeuLenDelay(delay_target, delay_len[, ...])

Neutral Length Delay.

JIT Compilation#

The JIT compilation tools for the JAX backend.

  1. Just-In-Time compilation is implemented by the ‘jit()’ function

jit(func[, dyn_vars, static_argnames, ...])

JIT (Just-In-Time) compilation for class objects.
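A tiny usage sketch (bm.jit wraps plain Python functions as well as class objects):

import brainpy.math as bm

@bm.jit
def f(x):
  return x ** 2 + 1.

f(bm.arange(3.))   # compiled on the first call with this input signature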

Operators#

pre2post_sum(pre_values, post_num, post_ids)

The pre-to-post synaptic summation.

pre2post_prod(pre_values, post_num, post_ids)

The pre-to-post synaptic production.

pre2post_max(pre_values, post_num, post_ids)

The pre-to-post synaptic maximization.

pre2post_min(pre_values, post_num, post_ids)

The pre-to-post synaptic minimization.

pre2post_mean(pre_values, post_num, post_ids)

The pre-to-post synaptic mean computation.

pre2syn(pre_values, pre_ids)

The pre-to-syn computation.

syn2post_sum(syn_values, post_ids, post_num)

The syn-to-post summation computation.

syn2post(syn_values, post_ids, post_num[, ...])

The syn-to-post summation computation.

syn2post_prod(syn_values, post_ids, post_num)

The syn-to-post product computation.

syn2post_max(syn_values, post_ids, post_num)

The syn-to-post maximum computation.

syn2post_min(syn_values, post_ids, post_num)

The syn-to-post minimization computation.

syn2post_mean(syn_values, post_ids, post_num)

The syn-to-post mean computation.

syn2post_softmax(syn_values, post_ids, post_num)

The syn-to-post softmax computation.

pre2post_event_sum(events, pre2post, post_num)

The pre-to-post synaptic computation with event-driven summation.

pre2post_event_prod(events, pre2post, post_num)

The pre-to-post synaptic computation with event-driven production.

sparse_matmul(A, B)

Sparse matrix multiplication.

segment_sum(data, segment_ids[, ...])

rtype

JaxArray

segment_prod(data, segment_ids[, ...])

rtype

JaxArray

segment_max(data, segment_ids[, ...])

rtype

JaxArray

segment_min(data, segment_ids[, ...])

rtype

JaxArray

register_op(op_name, cpu_func[, gpu_func, ...])

Convert a numba-jitted function into a JAX/XLA-compatible primitive.

Control Flows#

make_loop(body_fun, dyn_vars[, out_vars, ...])

Make a for-loop function, which iterates over inputs.

make_while(cond_fun, body_fun, dyn_vars)

Make a while-loop function.

make_cond(true_fun, false_fun[, dyn_vars])

Make a condition (if-else) function.

cond(pred, true_fun, false_fun, operands[, ...])

Simple conditional statement (if-else) with instance of Variable.

ifelse(conditions, branches[, operands, ...])

If-else control flow that looks like native Pythonic programming.

for_loop(body_fun, dyn_vars, operands[, ...])

for-loop control flow with Variable.

while_loop(body_fun, cond_fun, dyn_vars, ...)

while-loop control flow with Variable.

Automatic Differentiation#

grad(func[, grad_vars, dyn_vars, argnums, ...])

Automatic gradient computation for functions or class objects.

vector_grad(func[, dyn_vars, grad_vars, ...])

Take vector-valued gradients for function func.

jacobian(func[, grad_vars, dyn_vars, ...])

Extending automatic Jacobian (reverse-mode) of func to classes.

jacrev(func[, grad_vars, dyn_vars, argnums, ...])

Extending automatic Jacobian (reverse-mode) of func to classes.

jacfwd(func[, grad_vars, dyn_vars, argnums, ...])

Extending automatic Jacobian (forward-mode) of func to classes.

hessian(func[, dyn_vars, grad_vars, ...])

Hessian of func as a dense array.

Activation Functions#

This module provides commonly used activation functions.

Activation functions are a critical part of the design of a neural network. The choice of activation function in the hidden layer will control how well the network model learns the training dataset. The choice of activation function in the output layer will define the type of predictions the model can make.
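As a small usage sketch, these functions are called directly from brainpy.math:

import brainpy.math as bm

x = bm.array([-2., -0.5, 0., 0.5, 2.])
bm.relu(x)      # rectified linear unit
bm.sigmoid(x)   # logistic sigmoid
bm.softmax(x)   # softmax (normalized exponential)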

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU activation function

hard_swish(x)

Hard SiLU activation function

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis])

Log-Softmax function.

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.

normalize(x[, axis, mean, variance, epsilon])

Normalizes an array by subtracting mean and dividing by sqrt(var).

relu(x)

relu6(x)

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis])

Softmax function.

softplus(x)

Softplus activation function.

silu(x)

SiLU activation function.

swish(x)

SiLU activation function.

selu(x)

Scaled exponential linear unit activation.

identity(x)

tanh(x)

Function#

function([f, nodes, dyn_vars, name])

Comparison Table#

Here is a list of NumPy APIs and their corresponding BrainPy implementations.

A “-” in the brainpy.math column denotes that the implementation is not provided yet. We welcome contributions for these functions.

Multi-dimensional Array#

NumPy

brainpy.math

jax.numpy

numpy.ndarray.all()

brainpy.math.ndarray.all()

jax.numpy.ndarray.all()

numpy.ndarray.any()

brainpy.math.ndarray.any()

jax.numpy.ndarray.any()

numpy.ndarray.argmax()

brainpy.math.ndarray.argmax()

jax.numpy.ndarray.argmax()

numpy.ndarray.argmin()

brainpy.math.ndarray.argmin()

jax.numpy.ndarray.argmin()

numpy.ndarray.argpartition()

brainpy.math.ndarray.argpartition()

jax.numpy.ndarray.argpartition()

numpy.ndarray.argsort()

brainpy.math.ndarray.argsort()

jax.numpy.ndarray.argsort()

numpy.ndarray.astype()

brainpy.math.ndarray.astype()

jax.numpy.ndarray.astype()

numpy.ndarray.byteswap()

brainpy.math.ndarray.byteswap()

-

numpy.ndarray.choose()

brainpy.math.ndarray.choose()

jax.numpy.ndarray.choose()

numpy.ndarray.clip()

brainpy.math.ndarray.clip()

jax.numpy.ndarray.clip()

numpy.ndarray.compress()

brainpy.math.ndarray.compress()

jax.numpy.ndarray.compress()

numpy.ndarray.conj()

brainpy.math.ndarray.conj()

jax.numpy.ndarray.conj()

numpy.ndarray.conjugate()

brainpy.math.ndarray.conjugate()

jax.numpy.ndarray.conjugate()

numpy.ndarray.copy()

brainpy.math.ndarray.copy()

jax.numpy.ndarray.copy()

numpy.ndarray.cumprod()

brainpy.math.ndarray.cumprod()

jax.numpy.ndarray.cumprod()

numpy.ndarray.cumsum()

brainpy.math.ndarray.cumsum()

jax.numpy.ndarray.cumsum()

numpy.ndarray.diagonal()

brainpy.math.ndarray.diagonal()

jax.numpy.ndarray.diagonal()

numpy.ndarray.dot()

brainpy.math.ndarray.dot()

jax.numpy.ndarray.dot()

numpy.ndarray.dump()

-

-

numpy.ndarray.dumps()

-

-

numpy.ndarray.fill()

brainpy.math.ndarray.fill()

-

numpy.ndarray.flatten()

brainpy.math.ndarray.flatten()

jax.numpy.ndarray.flatten()

numpy.ndarray.getfield()

-

-

numpy.ndarray.item()

brainpy.math.ndarray.item()

jax.numpy.ndarray.item()

numpy.ndarray.itemset()

-

-

numpy.ndarray.max()

brainpy.math.ndarray.max()

jax.numpy.ndarray.max()

numpy.ndarray.mean()

brainpy.math.ndarray.mean()

jax.numpy.ndarray.mean()

numpy.ndarray.min()

brainpy.math.ndarray.min()

jax.numpy.ndarray.min()

numpy.ndarray.newbyteorder()

-

-

numpy.ndarray.nonzero()

brainpy.math.ndarray.nonzero()

jax.numpy.ndarray.nonzero()

numpy.ndarray.partition()

-

-

numpy.ndarray.prod()

brainpy.math.ndarray.prod()

jax.numpy.ndarray.prod()

numpy.ndarray.ptp()

brainpy.math.ndarray.ptp()

jax.numpy.ndarray.ptp()

numpy.ndarray.put()

brainpy.math.ndarray.put()

-

numpy.ndarray.ravel()

brainpy.math.ndarray.ravel()

jax.numpy.ndarray.ravel()

numpy.ndarray.repeat()

brainpy.math.ndarray.repeat()

jax.numpy.ndarray.repeat()

numpy.ndarray.reshape()

brainpy.math.ndarray.reshape()

jax.numpy.ndarray.reshape()

numpy.ndarray.resize()

brainpy.math.ndarray.resize()

-

numpy.ndarray.round()

brainpy.math.ndarray.round()

jax.numpy.ndarray.round()

numpy.ndarray.searchsorted()

brainpy.math.ndarray.searchsorted()

jax.numpy.ndarray.searchsorted()

numpy.ndarray.setfield()

-

-

numpy.ndarray.setflags()

-

-

numpy.ndarray.sort()

brainpy.math.ndarray.sort()

jax.numpy.ndarray.sort()

numpy.ndarray.squeeze()

brainpy.math.ndarray.squeeze()

jax.numpy.ndarray.squeeze()

numpy.ndarray.std()

brainpy.math.ndarray.std()

jax.numpy.ndarray.std()

numpy.ndarray.sum()

brainpy.math.ndarray.sum()

jax.numpy.ndarray.sum()

numpy.ndarray.swapaxes()

brainpy.math.ndarray.swapaxes()

jax.numpy.ndarray.swapaxes()

numpy.ndarray.take()

brainpy.math.ndarray.take()

jax.numpy.ndarray.take()

numpy.ndarray.tobytes()

brainpy.math.ndarray.tobytes()

jax.numpy.ndarray.tobytes()

numpy.ndarray.tofile()

-

-

numpy.ndarray.tolist()

brainpy.math.ndarray.tolist()

jax.numpy.ndarray.tolist()

numpy.ndarray.tostring()

-

-

numpy.ndarray.trace()

brainpy.math.ndarray.trace()

jax.numpy.ndarray.trace()

numpy.ndarray.transpose()

brainpy.math.ndarray.transpose()

jax.numpy.ndarray.transpose()

numpy.ndarray.var()

brainpy.math.ndarray.var()

jax.numpy.ndarray.var()

numpy.ndarray.view()

brainpy.math.ndarray.view()

jax.numpy.ndarray.view()

-

brainpy.math.ndarray.block_host_until_ready()

-

-

brainpy.math.ndarray.block_until_ready()

-

-

brainpy.math.ndarray.numpy()

-

-

brainpy.math.ndarray.split()

-

-

brainpy.math.ndarray.tile()

-

-

brainpy.math.ndarray.to_jax()

-

-

brainpy.math.ndarray.to_numpy()

-

-

brainpy.math.ndarray.update()

-

Summary

  • Number of NumPy functions: 56

  • Number of functions covered by brainpy.math: 46

  • Number of functions unique in brainpy.math: 8

  • Number of functions covered by jax.numpy: 42

Array Operations#

NumPy

brainpy.math

jax.numpy

numpy.abs

brainpy.math.abs

jax.numpy.abs

numpy.absolute

brainpy.math.absolute

jax.numpy.absolute

numpy.add

brainpy.math.add

jax.numpy.add

numpy.add_docstring

brainpy.math.add_docstring

-

numpy.add_newdoc

brainpy.math.add_newdoc

-

numpy.add_newdoc_ufunc

brainpy.math.add_newdoc_ufunc

-

numpy.alen

-

-

numpy.all

brainpy.math.all

jax.numpy.all

numpy.allclose

brainpy.math.allclose

jax.numpy.allclose

numpy.alltrue

brainpy.math.alltrue

jax.numpy.alltrue

numpy.amax

brainpy.math.amax

jax.numpy.amax

numpy.amin

brainpy.math.amin

jax.numpy.amin

numpy.angle

brainpy.math.angle

jax.numpy.angle

numpy.any

brainpy.math.any

jax.numpy.any

numpy.append

brainpy.math.append

jax.numpy.append

numpy.apply_along_axis

brainpy.math.apply_along_axis

jax.numpy.apply_along_axis

numpy.apply_over_axes

brainpy.math.apply_over_axes

jax.numpy.apply_over_axes

numpy.arange

brainpy.math.arange

jax.numpy.arange

numpy.arccos

brainpy.math.arccos

jax.numpy.arccos

numpy.arccosh

brainpy.math.arccosh

jax.numpy.arccosh

numpy.arcsin

brainpy.math.arcsin

jax.numpy.arcsin

numpy.arcsinh

brainpy.math.arcsinh

jax.numpy.arcsinh

numpy.arctan

brainpy.math.arctan

jax.numpy.arctan

numpy.arctan2

brainpy.math.arctan2

jax.numpy.arctan2

numpy.arctanh

brainpy.math.arctanh

jax.numpy.arctanh

numpy.argmax

brainpy.math.argmax

jax.numpy.argmax

numpy.argmin

brainpy.math.argmin

jax.numpy.argmin

numpy.argpartition

-

-

numpy.argsort

brainpy.math.argsort

jax.numpy.argsort

numpy.argwhere

brainpy.math.argwhere

jax.numpy.argwhere

numpy.around

brainpy.math.around

jax.numpy.around

numpy.array

brainpy.math.array

jax.numpy.array

numpy.array2string

brainpy.math.array2string

-

numpy.array_equal

brainpy.math.array_equal

jax.numpy.array_equal

numpy.array_equiv

brainpy.math.array_equiv

jax.numpy.array_equiv

numpy.array_repr

brainpy.math.array_repr

jax.numpy.array_repr

numpy.array_split

brainpy.math.array_split

jax.numpy.array_split

numpy.array_str

brainpy.math.array_str

jax.numpy.array_str

numpy.asanyarray

brainpy.math.asanyarray

-

numpy.asarray

brainpy.math.asarray

jax.numpy.asarray

numpy.asarray_chkfinite

-

-

numpy.ascontiguousarray

brainpy.math.ascontiguousarray

-

numpy.asfarray

brainpy.math.asfarray

-

numpy.asfortranarray

-

-

numpy.asmatrix

-

-

numpy.asscalar

brainpy.math.asscalar

-

numpy.atleast_1d

brainpy.math.atleast_1d

jax.numpy.atleast_1d

numpy.atleast_2d

brainpy.math.atleast_2d

jax.numpy.atleast_2d

numpy.atleast_3d

brainpy.math.atleast_3d

jax.numpy.atleast_3d

numpy.average

brainpy.math.average

jax.numpy.average

numpy.bartlett

brainpy.math.bartlett

jax.numpy.bartlett

numpy.base_repr

-

-

numpy.binary_repr

-

-

numpy.bincount

brainpy.math.bincount

jax.numpy.bincount

numpy.bitwise_and

brainpy.math.bitwise_and

jax.numpy.bitwise_and

numpy.bitwise_not

brainpy.math.bitwise_not

jax.numpy.bitwise_not

numpy.bitwise_or

brainpy.math.bitwise_or

jax.numpy.bitwise_or

numpy.bitwise_xor

brainpy.math.bitwise_xor

jax.numpy.bitwise_xor

numpy.blackman

brainpy.math.blackman

jax.numpy.blackman

numpy.block

brainpy.math.block

jax.numpy.block

numpy.bmat

-

-

numpy.broadcast_arrays

brainpy.math.broadcast_arrays

jax.numpy.broadcast_arrays

numpy.broadcast_shapes

brainpy.math.broadcast_shapes

jax.numpy.broadcast_shapes

numpy.broadcast_to

brainpy.math.broadcast_to

jax.numpy.broadcast_to

numpy.busday_count

-

-

numpy.busday_offset

-

-

numpy.byte_bounds

-

-

numpy.can_cast

brainpy.math.can_cast

jax.numpy.can_cast

numpy.cbrt

brainpy.math.cbrt

jax.numpy.cbrt

numpy.ceil

brainpy.math.ceil

jax.numpy.ceil

numpy.choose

brainpy.math.choose

jax.numpy.choose

numpy.clip

brainpy.math.clip

jax.numpy.clip

numpy.column_stack

brainpy.math.column_stack

jax.numpy.column_stack

numpy.common_type

brainpy.math.common_type

-

numpy.compare_chararrays

-

-

numpy.compress

brainpy.math.compress

jax.numpy.compress

numpy.concatenate

brainpy.math.concatenate

jax.numpy.concatenate

numpy.conj

brainpy.math.conj

jax.numpy.conj

numpy.conjugate

brainpy.math.conjugate

jax.numpy.conjugate

numpy.convolve

brainpy.math.convolve

jax.numpy.convolve

numpy.copy

brainpy.math.copy

jax.numpy.copy

numpy.copysign

brainpy.math.copysign

jax.numpy.copysign

numpy.copyto

-

-

numpy.corrcoef

brainpy.math.corrcoef

jax.numpy.corrcoef

numpy.correlate

brainpy.math.correlate

jax.numpy.correlate

numpy.cos

brainpy.math.cos

jax.numpy.cos

numpy.cosh

brainpy.math.cosh

jax.numpy.cosh

numpy.count_nonzero

brainpy.math.count_nonzero

jax.numpy.count_nonzero

numpy.cov

brainpy.math.cov

jax.numpy.cov

numpy.cross

brainpy.math.cross

jax.numpy.cross

numpy.cumprod

brainpy.math.cumprod

jax.numpy.cumprod

numpy.cumproduct

brainpy.math.cumproduct

jax.numpy.cumproduct

numpy.cumsum

brainpy.math.cumsum

jax.numpy.cumsum

numpy.datetime_as_string

-

-

numpy.datetime_data

-

-

numpy.deg2rad

brainpy.math.deg2rad

jax.numpy.deg2rad

numpy.degrees

brainpy.math.degrees

jax.numpy.degrees

numpy.delete

brainpy.math.delete

jax.numpy.delete

numpy.deprecate

-

-

numpy.deprecate_with_doc

-

-

numpy.diag

brainpy.math.diag

jax.numpy.diag

numpy.diag_indices

brainpy.math.diag_indices

jax.numpy.diag_indices

numpy.diag_indices_from

brainpy.math.diag_indices_from

jax.numpy.diag_indices_from

numpy.diagflat

brainpy.math.diagflat

jax.numpy.diagflat

numpy.diagonal

brainpy.math.diagonal

jax.numpy.diagonal

numpy.diff

brainpy.math.diff

jax.numpy.diff

numpy.digitize

brainpy.math.digitize

jax.numpy.digitize

numpy.disp

brainpy.math.disp

-

numpy.divide

brainpy.math.divide

jax.numpy.divide

numpy.divmod

brainpy.math.divmod

jax.numpy.divmod

numpy.dot

brainpy.math.dot

jax.numpy.dot

numpy.dsplit

brainpy.math.dsplit

jax.numpy.dsplit

numpy.dstack

brainpy.math.dstack

jax.numpy.dstack

numpy.ediff1d

brainpy.math.ediff1d

jax.numpy.ediff1d

numpy.einsum

brainpy.math.einsum

jax.numpy.einsum

numpy.einsum_path

brainpy.math.einsum_path

jax.numpy.einsum_path

numpy.empty

brainpy.math.empty

jax.numpy.empty

numpy.empty_like

brainpy.math.empty_like

jax.numpy.empty_like

numpy.equal

brainpy.math.equal

jax.numpy.equal

numpy.exp

brainpy.math.exp

jax.numpy.exp

numpy.exp2

brainpy.math.exp2

jax.numpy.exp2

numpy.expand_dims

brainpy.math.expand_dims

jax.numpy.expand_dims

numpy.expm1

brainpy.math.expm1

jax.numpy.expm1

numpy.extract

brainpy.math.extract

jax.numpy.extract

numpy.eye

brainpy.math.eye

jax.numpy.eye

numpy.fabs

brainpy.math.fabs

jax.numpy.fabs

numpy.fastCopyAndTranspose

-

-

numpy.fill_diagonal

brainpy.math.fill_diagonal

-

numpy.find_common_type

-

-

numpy.fix

brainpy.math.fix

jax.numpy.fix

numpy.flatnonzero

brainpy.math.flatnonzero

jax.numpy.flatnonzero

numpy.flip

brainpy.math.flip

jax.numpy.flip

numpy.fliplr

brainpy.math.fliplr

jax.numpy.fliplr

numpy.flipud

brainpy.math.flipud

jax.numpy.flipud

numpy.float_power

brainpy.math.float_power

jax.numpy.float_power

numpy.floor

brainpy.math.floor

jax.numpy.floor

numpy.floor_divide

brainpy.math.floor_divide

jax.numpy.floor_divide

numpy.fmax

brainpy.math.fmax

jax.numpy.fmax

numpy.fmin

brainpy.math.fmin

jax.numpy.fmin

numpy.fmod

brainpy.math.fmod

jax.numpy.fmod

numpy.format_float_positional

-

-

numpy.format_float_scientific

-

-

numpy.frexp

brainpy.math.frexp

jax.numpy.frexp

numpy.frombuffer

brainpy.math.frombuffer

jax.numpy.frombuffer

numpy.fromfile

brainpy.math.fromfile

jax.numpy.fromfile

numpy.fromfunction

brainpy.math.fromfunction

jax.numpy.fromfunction

numpy.fromiter

brainpy.math.fromiter

jax.numpy.fromiter

numpy.frompyfunc

-

-

numpy.fromregex

-

-

numpy.fromstring

brainpy.math.fromstring

jax.numpy.fromstring

numpy.full

brainpy.math.full

jax.numpy.full

numpy.full_like

brainpy.math.full_like

jax.numpy.full_like

numpy.gcd

brainpy.math.gcd

jax.numpy.gcd

numpy.genfromtxt

brainpy.math.genfromtxt

-

numpy.geomspace

brainpy.math.geomspace

jax.numpy.geomspace

numpy.get_array_wrap

-

-

numpy.get_include

-

-

numpy.get_printoptions

brainpy.math.get_printoptions

jax.numpy.get_printoptions

numpy.getbufsize

-

-

numpy.geterr

-

-

numpy.geterrcall

-

-

numpy.geterrobj

-

-

numpy.gradient

brainpy.math.gradient

jax.numpy.gradient

numpy.greater

brainpy.math.greater

jax.numpy.greater

numpy.greater_equal

brainpy.math.greater_equal

jax.numpy.greater_equal

numpy.hamming

brainpy.math.hamming

jax.numpy.hamming

numpy.hanning

brainpy.math.hanning

jax.numpy.hanning

numpy.heaviside

brainpy.math.heaviside

jax.numpy.heaviside

numpy.histogram

brainpy.math.histogram

jax.numpy.histogram

numpy.histogram2d

brainpy.math.histogram2d

jax.numpy.histogram2d

numpy.histogram_bin_edges

brainpy.math.histogram_bin_edges

jax.numpy.histogram_bin_edges

numpy.histogramdd

brainpy.math.histogramdd

jax.numpy.histogramdd

numpy.hsplit

brainpy.math.hsplit

jax.numpy.hsplit

numpy.hstack

brainpy.math.hstack

jax.numpy.hstack

numpy.hypot

brainpy.math.hypot

jax.numpy.hypot

numpy.i0

brainpy.math.i0

jax.numpy.i0

numpy.imag

brainpy.math.imag

jax.numpy.imag

numpy.in1d

brainpy.math.in1d

jax.numpy.in1d

numpy.indices

brainpy.math.indices

jax.numpy.indices

numpy.info

brainpy.math.info

-

numpy.inner

brainpy.math.inner

jax.numpy.inner

numpy.insert

brainpy.math.insert

jax.numpy.insert

numpy.interp

brainpy.math.interp

jax.numpy.interp

numpy.intersect1d

brainpy.math.intersect1d

jax.numpy.intersect1d

numpy.invert

brainpy.math.invert

jax.numpy.invert

numpy.is_busday

-

-

numpy.isclose

brainpy.math.isclose

jax.numpy.isclose

numpy.iscomplex

brainpy.math.iscomplex

jax.numpy.iscomplex

numpy.iscomplexobj

brainpy.math.iscomplexobj

jax.numpy.iscomplexobj

numpy.isfinite

brainpy.math.isfinite

jax.numpy.isfinite

numpy.isfortran

-

-

numpy.isin

brainpy.math.isin

jax.numpy.isin

numpy.isinf

brainpy.math.isinf

jax.numpy.isinf

numpy.isnan

brainpy.math.isnan

jax.numpy.isnan

numpy.isnat

-

-

numpy.isneginf

brainpy.math.isneginf

jax.numpy.isneginf

numpy.isposinf

brainpy.math.isposinf

jax.numpy.isposinf

numpy.isreal

brainpy.math.isreal

jax.numpy.isreal

numpy.isrealobj

brainpy.math.isrealobj

jax.numpy.isrealobj

numpy.isscalar

brainpy.math.isscalar

jax.numpy.isscalar

numpy.issctype

-

-

numpy.issubclass_

brainpy.math.issubclass_

-

numpy.issubdtype

brainpy.math.issubdtype

jax.numpy.issubdtype

numpy.issubsctype

brainpy.math.issubsctype

jax.numpy.issubsctype

numpy.iterable

brainpy.math.iterable

jax.numpy.iterable

numpy.ix_

brainpy.math.ix_

jax.numpy.ix_

numpy.kaiser

brainpy.math.kaiser

jax.numpy.kaiser

numpy.kron

brainpy.math.kron

jax.numpy.kron

numpy.lcm

brainpy.math.lcm

jax.numpy.lcm

numpy.ldexp

brainpy.math.ldexp

jax.numpy.ldexp

numpy.left_shift

brainpy.math.left_shift

jax.numpy.left_shift

numpy.less

brainpy.math.less

jax.numpy.less

numpy.less_equal

brainpy.math.less_equal

jax.numpy.less_equal

numpy.lexsort

brainpy.math.lexsort

jax.numpy.lexsort

numpy.linspace

brainpy.math.linspace

jax.numpy.linspace

numpy.load

brainpy.math.load

jax.numpy.load

numpy.loads

-

-

numpy.loadtxt

brainpy.math.loadtxt

-

numpy.log

brainpy.math.log

jax.numpy.log

numpy.log10

brainpy.math.log10

jax.numpy.log10

numpy.log1p

brainpy.math.log1p

jax.numpy.log1p

numpy.log2

brainpy.math.log2

jax.numpy.log2

numpy.logaddexp

brainpy.math.logaddexp

jax.numpy.logaddexp

numpy.logaddexp2

brainpy.math.logaddexp2

jax.numpy.logaddexp2

numpy.logical_and

brainpy.math.logical_and

jax.numpy.logical_and

numpy.logical_not

brainpy.math.logical_not

jax.numpy.logical_not

numpy.logical_or

brainpy.math.logical_or

jax.numpy.logical_or

numpy.logical_xor

brainpy.math.logical_xor

jax.numpy.logical_xor

numpy.logspace

brainpy.math.logspace

jax.numpy.logspace

numpy.lookfor

-

-

numpy.mafromtxt

-

-

numpy.mask_indices

brainpy.math.mask_indices

jax.numpy.mask_indices

numpy.mat

-

-

numpy.matmul

brainpy.math.matmul

jax.numpy.matmul

numpy.max

brainpy.math.max

jax.numpy.max

numpy.maximum

brainpy.math.maximum

jax.numpy.maximum

numpy.maximum_sctype

-

-

numpy.may_share_memory

-

-

numpy.mean

brainpy.math.mean

jax.numpy.mean

numpy.median

brainpy.math.median

jax.numpy.median

numpy.meshgrid

brainpy.math.meshgrid

jax.numpy.meshgrid

numpy.min

brainpy.math.min

jax.numpy.min

numpy.min_scalar_type

-

-

numpy.minimum

brainpy.math.minimum

jax.numpy.minimum

numpy.mintypecode

-

-

numpy.mod

brainpy.math.mod

jax.numpy.mod

numpy.modf

brainpy.math.modf

jax.numpy.modf

numpy.moveaxis

brainpy.math.moveaxis

jax.numpy.moveaxis

numpy.msort

brainpy.math.msort

jax.numpy.msort

numpy.multiply

brainpy.math.multiply

jax.numpy.multiply

numpy.nan_to_num

brainpy.math.nan_to_num

jax.numpy.nan_to_num

numpy.nanargmax

brainpy.math.nanargmax

jax.numpy.nanargmax

numpy.nanargmin

brainpy.math.nanargmin

jax.numpy.nanargmin

numpy.nancumprod

brainpy.math.nancumprod

jax.numpy.nancumprod

numpy.nancumsum

brainpy.math.nancumsum

jax.numpy.nancumsum

numpy.nanmax

brainpy.math.nanmax

jax.numpy.nanmax

numpy.nanmean

brainpy.math.nanmean

jax.numpy.nanmean

numpy.nanmedian

brainpy.math.nanmedian

jax.numpy.nanmedian

numpy.nanmin

brainpy.math.nanmin

jax.numpy.nanmin

numpy.nanpercentile

brainpy.math.nanpercentile

jax.numpy.nanpercentile

numpy.nanprod

brainpy.math.nanprod

jax.numpy.nanprod

numpy.nanquantile

brainpy.math.nanquantile

jax.numpy.nanquantile

numpy.nanstd

brainpy.math.nanstd

jax.numpy.nanstd

numpy.nansum

brainpy.math.nansum

jax.numpy.nansum

numpy.nanvar

brainpy.math.nanvar

jax.numpy.nanvar

numpy.ndfromtxt

-

-

numpy.ndim

brainpy.math.ndim

jax.numpy.ndim

numpy.negative

brainpy.math.negative

jax.numpy.negative

numpy.nested_iters

-

-

numpy.nextafter

brainpy.math.nextafter

jax.numpy.nextafter

numpy.nonzero

brainpy.math.nonzero

jax.numpy.nonzero

numpy.not_equal

brainpy.math.not_equal

jax.numpy.not_equal

numpy.obj2sctype

-

-

numpy.ones

brainpy.math.ones

jax.numpy.ones

numpy.ones_like

brainpy.math.ones_like

jax.numpy.ones_like

numpy.outer

brainpy.math.outer

jax.numpy.outer

numpy.packbits

brainpy.math.packbits

jax.numpy.packbits

numpy.pad

brainpy.math.pad

jax.numpy.pad

numpy.partition

-

-

numpy.percentile

brainpy.math.percentile

jax.numpy.percentile

numpy.piecewise

brainpy.math.piecewise

jax.numpy.piecewise

numpy.place

brainpy.math.place

-

numpy.poly

brainpy.math.poly

jax.numpy.poly

numpy.polyadd

brainpy.math.polyadd

jax.numpy.polyadd

numpy.polyder

brainpy.math.polyder

jax.numpy.polyder

numpy.polydiv

brainpy.math.polydiv

-

numpy.polyfit

brainpy.math.polyfit

jax.numpy.polyfit

numpy.polyint

brainpy.math.polyint

jax.numpy.polyint

numpy.polymul

brainpy.math.polymul

jax.numpy.polymul

numpy.polysub

brainpy.math.polysub

jax.numpy.polysub

numpy.polyval

brainpy.math.polyval

jax.numpy.polyval

numpy.positive

brainpy.math.positive

jax.numpy.positive

numpy.power

brainpy.math.power

jax.numpy.power

numpy.printoptions

brainpy.math.printoptions

jax.numpy.printoptions

numpy.prod

brainpy.math.prod

jax.numpy.prod

numpy.product

brainpy.math.product

jax.numpy.product

numpy.promote_types

brainpy.math.promote_types

jax.numpy.promote_types

numpy.ptp

brainpy.math.ptp

jax.numpy.ptp

numpy.put

brainpy.math.put

-

numpy.put_along_axis

-

-

numpy.putmask

brainpy.math.putmask

-

numpy.quantile

brainpy.math.quantile

jax.numpy.quantile

numpy.rad2deg

brainpy.math.rad2deg

jax.numpy.rad2deg

numpy.radians

brainpy.math.radians

jax.numpy.radians

numpy.ravel

brainpy.math.ravel

jax.numpy.ravel

numpy.ravel_multi_index

brainpy.math.ravel_multi_index

jax.numpy.ravel_multi_index

numpy.real

brainpy.math.real

jax.numpy.real

numpy.real_if_close

-

-

numpy.recfromcsv

-

-

numpy.recfromtxt

-

-

numpy.reciprocal

brainpy.math.reciprocal

jax.numpy.reciprocal

numpy.remainder

brainpy.math.remainder

jax.numpy.remainder

numpy.repeat

brainpy.math.repeat

jax.numpy.repeat

numpy.require

-

-

numpy.reshape

brainpy.math.reshape

jax.numpy.reshape

numpy.resize

brainpy.math.resize

jax.numpy.resize

numpy.result_type

brainpy.math.result_type

jax.numpy.result_type

numpy.right_shift

brainpy.math.right_shift

jax.numpy.right_shift

numpy.rint

brainpy.math.rint

jax.numpy.rint

numpy.roll

brainpy.math.roll

jax.numpy.roll

numpy.rollaxis

brainpy.math.rollaxis

jax.numpy.rollaxis

numpy.roots

brainpy.math.roots

jax.numpy.roots

numpy.rot90

brainpy.math.rot90

jax.numpy.rot90

numpy.round

brainpy.math.round

jax.numpy.round

numpy.round_

brainpy.math.round_

jax.numpy.round_

numpy.row_stack

brainpy.math.row_stack

jax.numpy.row_stack

numpy.safe_eval

brainpy.math.safe_eval

-

numpy.save

brainpy.math.save

jax.numpy.save

numpy.savetxt

brainpy.math.savetxt

-

numpy.savez

brainpy.math.savez

jax.numpy.savez

numpy.savez_compressed

brainpy.math.savez_compressed

-

numpy.sctype2char

-

-

numpy.searchsorted

brainpy.math.searchsorted

jax.numpy.searchsorted

numpy.select

brainpy.math.select

jax.numpy.select

numpy.set_numeric_ops

-

-

numpy.set_printoptions

brainpy.math.set_printoptions

jax.numpy.set_printoptions

numpy.set_string_function

-

-

numpy.setbufsize

-

-

numpy.setdiff1d

brainpy.math.setdiff1d

jax.numpy.setdiff1d

numpy.seterr

-

-

numpy.seterrcall

-

-

numpy.seterrobj

-

-

numpy.setxor1d

brainpy.math.setxor1d

jax.numpy.setxor1d

numpy.shape

brainpy.math.shape

jax.numpy.shape

numpy.shares_memory

-

-

numpy.show_config

brainpy.math.show_config

-

numpy.sign

brainpy.math.sign

jax.numpy.sign

numpy.signbit

brainpy.math.signbit

jax.numpy.signbit

numpy.sin

brainpy.math.sin

jax.numpy.sin

numpy.sinc

brainpy.math.sinc

jax.numpy.sinc

numpy.sinh

brainpy.math.sinh

jax.numpy.sinh

numpy.size

brainpy.math.size

jax.numpy.size

numpy.sometrue

brainpy.math.sometrue

jax.numpy.sometrue

numpy.sort

brainpy.math.sort

jax.numpy.sort

numpy.sort_complex

brainpy.math.sort_complex

jax.numpy.sort_complex

numpy.source

-

-

numpy.spacing

-

-

numpy.split

brainpy.math.split

jax.numpy.split

numpy.sqrt

brainpy.math.sqrt

jax.numpy.sqrt

numpy.square

brainpy.math.square

jax.numpy.square

numpy.squeeze

brainpy.math.squeeze

jax.numpy.squeeze

numpy.stack

brainpy.math.stack

jax.numpy.stack

numpy.std

brainpy.math.std

jax.numpy.std

numpy.subtract

brainpy.math.subtract

jax.numpy.subtract

numpy.sum

brainpy.math.sum

jax.numpy.sum

numpy.swapaxes

brainpy.math.swapaxes

jax.numpy.swapaxes

numpy.take

brainpy.math.take

jax.numpy.take

numpy.take_along_axis

brainpy.math.take_along_axis

jax.numpy.take_along_axis

numpy.tan

brainpy.math.tan

jax.numpy.tan

numpy.tensordot

brainpy.math.tensordot

jax.numpy.tensordot

numpy.tile

brainpy.math.tile

jax.numpy.tile

numpy.trace

brainpy.math.trace

jax.numpy.trace

numpy.transpose

brainpy.math.transpose

jax.numpy.transpose

numpy.trapz

brainpy.math.trapz

jax.numpy.trapz

numpy.tri

brainpy.math.tri

jax.numpy.tri

numpy.tril

brainpy.math.tril

jax.numpy.tril

numpy.tril_indices

brainpy.math.tril_indices

jax.numpy.tril_indices

numpy.tril_indices_from

brainpy.math.tril_indices_from

jax.numpy.tril_indices_from

numpy.trim_zeros

brainpy.math.trim_zeros

jax.numpy.trim_zeros

numpy.triu

brainpy.math.triu

jax.numpy.triu

numpy.triu_indices

brainpy.math.triu_indices

jax.numpy.triu_indices

numpy.triu_indices_from

brainpy.math.triu_indices_from

jax.numpy.triu_indices_from

numpy.true_divide

brainpy.math.true_divide

jax.numpy.true_divide

numpy.trunc

brainpy.math.trunc

jax.numpy.trunc

numpy.typename

brainpy.math.typename

-

numpy.union1d

brainpy.math.union1d

jax.numpy.union1d

numpy.unique

brainpy.math.unique

jax.numpy.unique

numpy.unpackbits

brainpy.math.unpackbits

jax.numpy.unpackbits

numpy.unravel_index

brainpy.math.unravel_index

jax.numpy.unravel_index

numpy.unwrap

brainpy.math.unwrap

jax.numpy.unwrap

numpy.vander

brainpy.math.vander

jax.numpy.vander

numpy.var

brainpy.math.var

jax.numpy.var

numpy.vdot

brainpy.math.vdot

jax.numpy.vdot

numpy.vsplit

brainpy.math.vsplit

jax.numpy.vsplit

numpy.vstack

brainpy.math.vstack

jax.numpy.vstack

numpy.where

brainpy.math.where

jax.numpy.where

numpy.who

-

-

numpy.zeros

brainpy.math.zeros

jax.numpy.zeros

numpy.zeros_like

brainpy.math.zeros_like

jax.numpy.zeros_like

-

brainpy.math.as_numpy

-

-

brainpy.math.as_variable

-

-

brainpy.math.clip_by_norm

-

-

brainpy.math.function

-

-

brainpy.math.get_dcomplex

-

-

brainpy.math.get_dfloat

-

-

brainpy.math.get_dint

-

-

brainpy.math.jit

-

-

brainpy.math.pre2post_event_prod

-

-

brainpy.math.pre2post_event_sum

-

-

brainpy.math.pre2post_max

-

-

brainpy.math.pre2post_mean

-

-

brainpy.math.pre2post_min

-

-

brainpy.math.pre2post_prod

-

-

brainpy.math.pre2post_sum

-

-

brainpy.math.pre2syn

-

-

brainpy.math.register_op

-

-

brainpy.math.remove_diag

-

-

brainpy.math.segment_max

-

-

brainpy.math.segment_min

-

-

brainpy.math.segment_prod

-

-

brainpy.math.segment_sum

-

-

brainpy.math.set_dcomplex

-

-

brainpy.math.set_dfloat

-

-

brainpy.math.set_dint

-

-

brainpy.math.sparse_matmul

-

-

brainpy.math.syn2post

-

-

brainpy.math.syn2post_max

-

-

brainpy.math.syn2post_mean

-

-

brainpy.math.syn2post_min

-

-

brainpy.math.syn2post_prod

-

-

brainpy.math.syn2post_softmax

-

-

brainpy.math.syn2post_sum

-

Summary

  • Number of NumPy functions: 399

  • Number of functions covered by brainpy.math: 338

  • Number of functions unique in brainpy.math: 33

  • Number of functions covered by jax.numpy: 314

Linear Algebra#

NumPy

brainpy.math

jax.numpy

numpy.linalg.cholesky

brainpy.math.linalg.cholesky

jax.numpy.linalg.cholesky

numpy.linalg.det

brainpy.math.linalg.det

jax.numpy.linalg.det

numpy.linalg.eig

brainpy.math.linalg.eig

jax.numpy.linalg.eig

numpy.linalg.eigh

brainpy.math.linalg.eigh

jax.numpy.linalg.eigh

numpy.linalg.eigvals

brainpy.math.linalg.eigvals

jax.numpy.linalg.eigvals

numpy.linalg.eigvalsh

brainpy.math.linalg.eigvalsh

jax.numpy.linalg.eigvalsh

numpy.linalg.inv

brainpy.math.linalg.inv

jax.numpy.linalg.inv

numpy.linalg.lstsq

brainpy.math.linalg.lstsq

jax.numpy.linalg.lstsq

numpy.linalg.matrix_power

brainpy.math.linalg.matrix_power

jax.numpy.linalg.matrix_power

numpy.linalg.matrix_rank

brainpy.math.linalg.matrix_rank

jax.numpy.linalg.matrix_rank

numpy.linalg.multi_dot

brainpy.math.linalg.multi_dot

jax.numpy.linalg.multi_dot

numpy.linalg.norm

brainpy.math.linalg.norm

jax.numpy.linalg.norm

numpy.linalg.pinv

brainpy.math.linalg.pinv

jax.numpy.linalg.pinv

numpy.linalg.qr

brainpy.math.linalg.qr

jax.numpy.linalg.qr

numpy.linalg.slogdet

brainpy.math.linalg.slogdet

jax.numpy.linalg.slogdet

numpy.linalg.solve

brainpy.math.linalg.solve

jax.numpy.linalg.solve

numpy.linalg.svd

brainpy.math.linalg.svd

jax.numpy.linalg.svd

numpy.linalg.tensorinv

brainpy.math.linalg.tensorinv

jax.numpy.linalg.tensorinv

numpy.linalg.tensorsolve

brainpy.math.linalg.tensorsolve

jax.numpy.linalg.tensorsolve

Summary

  • Number of NumPy functions: 19

  • Number of functions covered by brainpy.math: 19

  • Number of functions unique in brainpy.math: 0

  • Number of functions covered by jax.numpy: 19

Discrete Fourier Transform#

NumPy

brainpy.math

jax.numpy

numpy.fft.fft

brainpy.math.fft.fft

jax.numpy.fft.fft

numpy.fft.fft2

brainpy.math.fft.fft2

jax.numpy.fft.fft2

numpy.fft.fftfreq

brainpy.math.fft.fftfreq

jax.numpy.fft.fftfreq

numpy.fft.fftn

brainpy.math.fft.fftn

jax.numpy.fft.fftn

numpy.fft.fftshift

brainpy.math.fft.fftshift

jax.numpy.fft.fftshift

numpy.fft.hfft

brainpy.math.fft.hfft

jax.numpy.fft.hfft

numpy.fft.ifft

brainpy.math.fft.ifft

jax.numpy.fft.ifft

numpy.fft.ifft2

brainpy.math.fft.ifft2

jax.numpy.fft.ifft2

numpy.fft.ifftn

brainpy.math.fft.ifftn

jax.numpy.fft.ifftn

numpy.fft.ifftshift

brainpy.math.fft.ifftshift

jax.numpy.fft.ifftshift

numpy.fft.ihfft

brainpy.math.fft.ihfft

jax.numpy.fft.ihfft

numpy.fft.irfft

brainpy.math.fft.irfft

jax.numpy.fft.irfft

numpy.fft.irfft2

brainpy.math.fft.irfft2

jax.numpy.fft.irfft2

numpy.fft.irfftn

brainpy.math.fft.irfftn

jax.numpy.fft.irfftn

numpy.fft.rfft

brainpy.math.fft.rfft

jax.numpy.fft.rfft

numpy.fft.rfft2

brainpy.math.fft.rfft2

jax.numpy.fft.rfft2

numpy.fft.rfftfreq

brainpy.math.fft.rfftfreq

jax.numpy.fft.rfftfreq

numpy.fft.rfftn

brainpy.math.fft.rfftn

jax.numpy.fft.rfftn

Summary

  • Number of NumPy functions: 18

  • Number of functions covered by brainpy.math: 18

  • Number of functions unique in brainpy.math: 0

  • Number of functions covered by jax.numpy: 18

Random Sampling#

NumPy

brainpy.math

jax.numpy

numpy.random.beta

brainpy.math.random.beta

jax.random.beta

numpy.random.binomial

brainpy.math.random.binomial

-

numpy.random.bytes

-

-

numpy.random.chisquare

brainpy.math.random.chisquare

-

numpy.random.choice

brainpy.math.random.choice

jax.random.choice

numpy.random.default_rng

brainpy.math.random.default_rng

-

numpy.random.dirichlet

brainpy.math.random.dirichlet

jax.random.dirichlet

numpy.random.exponential

brainpy.math.random.exponential

jax.random.exponential

numpy.random.f

brainpy.math.random.f

-

numpy.random.gamma

brainpy.math.random.gamma

jax.random.gamma

numpy.random.geometric

brainpy.math.random.geometric

-

numpy.random.get_state

-

-

numpy.random.gumbel

brainpy.math.random.gumbel

jax.random.gumbel

numpy.random.hypergeometric

brainpy.math.random.hypergeometric

-

numpy.random.laplace

brainpy.math.random.laplace

jax.random.laplace

numpy.random.logistic

brainpy.math.random.logistic

jax.random.logistic

numpy.random.lognormal

brainpy.math.random.lognormal

-

numpy.random.logseries

brainpy.math.random.logseries

-

numpy.random.multinomial

brainpy.math.random.multinomial

-

numpy.random.multivariate_normal

brainpy.math.random.multivariate_normal

jax.random.multivariate_normal

numpy.random.negative_binomial

brainpy.math.random.negative_binomial

-

numpy.random.noncentral_chisquare

brainpy.math.random.noncentral_chisquare

-

numpy.random.noncentral_f

brainpy.math.random.noncentral_f

-

numpy.random.normal

brainpy.math.random.normal

jax.random.normal

numpy.random.pareto

brainpy.math.random.pareto

jax.random.pareto

numpy.random.permutation

brainpy.math.random.permutation

jax.random.permutation

numpy.random.poisson

brainpy.math.random.poisson

jax.random.poisson

numpy.random.power

brainpy.math.random.power

-

numpy.random.rand

brainpy.math.random.rand

-

numpy.random.randint

brainpy.math.random.randint

jax.random.randint

numpy.random.randn

brainpy.math.random.randn

-

numpy.random.random

brainpy.math.random.random

-

numpy.random.random_integers

brainpy.math.random.random_integers

-

numpy.random.random_sample

brainpy.math.random.random_sample

-

numpy.random.ranf

brainpy.math.random.ranf

-

numpy.random.rayleigh

brainpy.math.random.rayleigh

-

numpy.random.sample

brainpy.math.random.sample

-

numpy.random.seed

brainpy.math.random.seed

-

numpy.random.set_state

-

-

numpy.random.shuffle

brainpy.math.random.shuffle

jax.random.shuffle

numpy.random.standard_cauchy

brainpy.math.random.standard_cauchy

-

numpy.random.standard_exponential

brainpy.math.random.standard_exponential

-

numpy.random.standard_gamma

brainpy.math.random.standard_gamma

-

numpy.random.standard_normal

brainpy.math.random.standard_normal

-

numpy.random.standard_t

brainpy.math.random.standard_t

-

numpy.random.triangular

brainpy.math.random.triangular

-

numpy.random.uniform

brainpy.math.random.uniform

jax.random.uniform

numpy.random.vonmises

brainpy.math.random.vonmises

-

numpy.random.wald

brainpy.math.random.wald

-

numpy.random.weibull

brainpy.math.random.weibull

-

numpy.random.zipf

brainpy.math.random.zipf

-

-

brainpy.math.random.bernoulli

jax.random.bernoulli

-

brainpy.math.random.call

-

-

brainpy.math.random.categorical

jax.random.categorical

-

brainpy.math.random.index

-

-

brainpy.math.random.jit

-

-

brainpy.math.random.loggamma

jax.random.loggamma

-

brainpy.math.random.maxwell

jax.random.maxwell

-

brainpy.math.random.namedtuple

-

-

brainpy.math.random.orthogonal

jax.random.orthogonal

-

brainpy.math.random.t

jax.random.t

-

brainpy.math.random.truncated_normal

jax.random.truncated_normal

-

brainpy.math.random.weibull_min

jax.random.weibull_min

-

brainpy.math.random.wraps

-

Summary

  • Number of NumPy functions: 51

  • Number of functions covered by brainpy.math: 48

  • Number of functions unique in brainpy.math: 13

  • Number of functions covered by jax.numpy: 16

Setting#

enable_x64([mode])

disable_x64()

set_platform(platform)

Changes platform to CPU, GPU, or TPU.

set_host_device_count(n)

By default, XLA considers all CPU cores as one device.

set_dt(dt)

Set the numerical integrator precision.

get_dt()

Get the numerical integrator precision.

bool_(x)

int_

alias of jax.numpy.int32

float_

alias of jax.numpy.float32

complex_

alias of jax.numpy.complex128
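A minimal sketch of how these setting helpers are typically combined; the printed value assumes the default time step was changed as shown:

import brainpy.math as bm

bm.set_platform('cpu')   # run on CPU; 'gpu' or 'tpu' are also accepted
bm.enable_x64()          # switch default float/int types to 64-bit precision
bm.set_dt(0.05)          # numerical integrator precision (time step)
print(bm.get_dt())       # -> 0.05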

brainpy.math.compat module#

Optimizers#

SGD(*args, **kwargs)

SGD optimizer.

Momentum(*args, **kwargs)

Momentum optimizer.

MomentumNesterov(*args, **kwargs)

MomentumNesterov optimizer.

Adagrad(*args, **kwargs)

Adagrad optimizer.

Adadelta(*args, **kwargs)

Adadelta optimizer.

RMSProp(*args, **kwargs)

RMSProp optimizer.

Adam(*args, **kwargs)

Adam optimizer.

Constant(*args, **kwargs)

Constant scheduler.

ExponentialDecay(*args, **kwargs)

ExponentialDecay scheduler.

InverseTimeDecay(*args, **kwargs)

InverseTimeDecay scheduler.

PolynomialDecay(*args, **kwargs)

PolynomialDecay scheduler.

PiecewiseConstant(*args, **kwargs)

PiecewiseConstant scheduler.

Losses#

cross_entropy_loss(*args, **kwargs)

Cross entropy loss.

l1_loos(*args, **kwargs)

L1 loss.

l2_loss(*args, **kwargs)

L2 loss.

l2_norm(*args, **kwargs)

L2 norm.

huber_loss(*args, **kwargs)

Huber loss.

mean_absolute_error(*args, **kwargs)

Mean absolute error loss.

mean_squared_error(*args, **kwargs)

Mean squared error loss.

mean_squared_log_error(*args, **kwargs)

Mean squared log error loss.

brainpy.dyn module#

Dynamics simulation module.

Base Class#

DynamicalSystem([name])

Base Dynamical System class.

Container(*ds_tuple[, name])

Container object which is designed to add other instances of DynamicalSystem.

Network(*ds_tuple[, name])

Base class to model network objects, an alias of Container.

ConstantDelay(size, delay[, dtype, dt])

Class used to model constant delay variables.

NeuGroup(size[, name])

Base class to model neuronal groups.

ConNeuGroup(size[, C, A, V_th, ...])

Base class to model conductance-based neuron group.

TwoEndConn(pre, post[, conn, name])

Base class to model two-end synaptic connections.

Channel(size[, name])

Abstract channel model.

ContainerWrapper(master, **children)

Channel Models#

Base Class#

Ion(size[, name])

Base class for ions.

IonChannel(size[, name])

Base class for ion channels.

Sodium Channel Models#

INa(size[, E, g_max, T, V_sh, method, name])

The sodium current model.

INa_v2(size[, E, g_max, method, name])

Potassium Channel Models#

PotassiumChannel(size[, name])

Base class for potassium channel.

IK_DR(size[, E, g_max, T, T_base, V_sh, ...])

The delayed rectifier potassium channel current.

IK2(size[, E, g_max, method, name])

Calcium Channel Models#

Calcium(size[, method, name])

The base calcium dynamics.

CalciumFixed(size[, E, C, method, name])

Fixed Calcium dynamics.

CalciumDetailed(size[, d, C_rest, tau, C_0, ...])

Dynamical Calcium model.

CalciumAbstract(size[, alpha, beta, ...])

The first-order calcium concentration model.

CalciumChannel(size[, name])

Base class for Calcium ion channels.

IAHP(size[, E, g_max, method, name])

The calcium-dependent potassium current model.

ICaN(size[, E, g_max, phi, method, name])

The calcium-activated non-selective cation channel model.

ICaT(size[, T, T_base_p, T_base_q, g_max, ...])

The low-threshold T-type calcium current model.

ICaT_RE(size[, T, T_base_p, T_base_q, ...])

The low-threshold T-type calcium current model in thalamic reticular nucleus.

ICaHT(size[, T, T_base_p, T_base_q, g_max, ...])

The high-threshold T-type calcium current model.

ICaL(size[, T, T_base_p, T_base_q, g_max, ...])

The L-type calcium channel model.

Ih Channel Models#

IhChannel(size[, name])

Base class for Ih channel models.

Ih(size[, g_max, E, phi, method, name])

The hyperpolarization-activated cation current model.

Leaky Channel Models#

LeakyChannel(size[, name])

Base class for leaky channel.

IL(size[, g_max, E, method, name])

The leakage channel current.

IKL(size[, g_max, E, method, name])

The potassium leak channel current.

Neuron Models#

Biological Models#

HH(size[, ENa, gNa, EK, gK, EL, gL, V_th, ...])

Hodgkin–Huxley neuron model.

MorrisLecar(size[, V_Ca, g_Ca, V_K, g_K, ...])

The Morris-Lecar neuron model.

PinskyRinzelModel(size[, gNa, gK, gCa, ...])

The Pinsky and Rinzel (1994) model.

WangBuzsakiModel(size[, ENa, gNa, EK, gK, ...])

Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model.

Fractional-order Models#

FractionalNeuron(size[, name])

Fractional-order neuron model.

FractionalFHR(size, alpha[, num_memory, a, ...])

The fractional-order FH-R model [1]_.

FractionalIzhikevich(size, alpha, num_step)

Fractional-order Izhikevich model [10]_.

Reduced Models#

LIF(size[, V_rest, V_reset, V_th, tau, ...])

Leaky integrate-and-fire neuron model.

ExpIF(size[, V_rest, V_reset, V_th, V_T, ...])

Exponential integrate-and-fire neuron model.

AdExIF(size[, V_rest, V_reset, V_th, V_T, ...])

Adaptive exponential integrate-and-fire neuron model.

QuaIF(size[, V_rest, V_reset, V_th, V_c, c, ...])

Quadratic Integrate-and-Fire neuron model.

AdQuaIF(size[, V_rest, V_reset, V_th, V_c, ...])

Adaptive quadratic integrate-and-fire neuron model.

GIF(size[, V_rest, V_reset, V_th_inf, ...])

Generalized Integrate-and-Fire model.

Izhikevich(size[, a, b, c, d, V_th, ...])

The Izhikevich neuron model.

HindmarshRose(size[, a, b, c, d, r, s, ...])

Hindmarsh-Rose neuron model.

FHN(size[, a, b, tau, Vth, V_initializer, ...])

FitzHugh-Nagumo neuron model.

Synapse Models#

Biological Models#

AMPA(pre, post, conn[, conn_type, g_max, ...])

AMPA conductance-based synapse model.

GABAa(pre, post, conn[, conn_type, g_max, ...])

GABAa conductance-based synapse model.

Abstract Models#

DeltaSynapse(pre, post, conn[, conn_type, ...])

Voltage Jump Synapse Model, or alias of Delta Synapse Model.

ExpCUBA(pre, post, conn[, conn_type, g_max, ...])

Current-based exponential decay synapse model.

ExpCOBA(pre, post, conn[, conn_type, g_max, ...])

Conductance-based exponential decay synapse model.

DualExpCUBA(pre, post, conn[, conn_type, ...])

Current-based dual exponential synapse model.

DualExpCOBA(pre, post, conn[, conn_type, ...])

Conductance-based dual exponential synapse model.

AlphaCUBA(pre, post, conn[, conn_type, ...])

Current-based alpha synapse model.

AlphaCOBA(pre, post, conn[, conn_type, ...])

Conductance-based alpha synapse model.

NMDA(pre, post, conn[, conn_type, g_max, ...])

Conductance-based NMDA synapse model.

Learning Rule Models#

STP(pre, post, conn[, U, tau_f, tau_d, tau, ...])

Short-term plasticity model.

Rate Models#

Population Models#

Population(size[, name])

FHN(size[, alpha, beta, gamma, delta, ...])

FitzHugh-Nagumo system used in [1]_.

FeedbackFHN(size[, a, b, delay, tau, mu, ...])

FitzHugh-Nagumo model with recurrent neural feedback.

QIF(size[, tau, eta, delta, J, x_ou_mean, ...])

A mean-field model of a quadratic integrate-and-fire neuron population.

StuartLandauOscillator(size[, a, w, ...])

Stuart-Landau model with Hopf bifurcation.

WilsonCowanModel(size[, E_tau, E_a, ...])

Wilson-Cowan population model.

ThresholdLinearModel(size[, tau_e, tau_i, ...])

A threshold linear rate model.

Coupling Models#

DelayCoupling(delay_var, target_var, ...[, ...])

Delay coupling.

DiffusiveCoupling(coupling_var1, ...[, ...])

Diffusive coupling.

AdditiveCoupling(coupling_var, target_var, ...)

Additive coupling.

Helper Models#

Noise Models#

OUProcess(size[, mean, sigma, tau, method, name])

The Ornstein–Uhlenbeck process.

Input Models#

SpikeTimeInput(*args, **kwargs)

Spike Time Input.

SpikeTimeGroup(size, times, indices[, ...])

The input neuron group characterized by spikes emitting at given times.

PoissonInput(*args, **kwargs)

Poisson Group Input.

PoissonGroup(size, freqs[, seed, keep_size, ...])

Poisson Neuron Group.

Runners#

DSRunner(target[, inputs, dt])

The runner for dynamical systems.

ReportRunner(target[, inputs, jit, dt])

The runner that provides a convenient interface for debugging.

StructRunner(target, *args, **kwargs)

The runner with the structural for-loop.
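A rough usage sketch of this module; the LIF variable names ('input', 'spike') and the inputs/monitors formats used below are assumptions, not guaranteed APIs:

import brainpy as bp

# Build a LIF neuron group and drive it with a constant external current.
group = bp.dyn.LIF(100, V_rest=-60., V_th=-50., tau=20.)

runner = bp.dyn.DSRunner(group,
                         inputs=('input', 26.),  # assumed 'input' variable
                         monitors=['spike'])     # assumed 'spike' variable
runner.run(200.)                                 # simulate 200 ms
print(runner.mon.spike.shape)                    # (num_time, num_neuron)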

brainpy.nn module#

Neural Networks (nn)

Base Classes#

This module provides the basic Node class for the whole brainpy.nn system.

  • brainpy.nn.Node: The fundamental class representing the node or the element.

  • brainpy.nn.RecurrentNode: The recurrent node which has a self-connection.

  • brainpy.nn.Network: The network model which is composed of multiple node elements. Once a Network instance receives a node operation, the wrapped elements, the new elements, and their connection edges are composed into another Network instance. This means brainpy.nn.Network is only used to pack element nodes; it will never be an element node itself.

  • brainpy.nn.FrozenNetwork: The whole network which can be represented as a basic elementary node when composing a larger network (TODO).

Node([name, input_shape, trainable])

Basic Node class for neural network building in BrainPy.

Network([nodes, ff_edges, fb_edges])

Basic Network class for neural network building in BrainPy.

RecurrentNode([name, input_shape, ...])

Basic class for recurrent node.

FrozenNetwork([nodes, ff_edges, fb_edges])

A FrozenNetwork is a Network that cannot be linked to other nodes or networks.

Node Operations#

This module provides basic operations for constructing node graphs.

It supports the following operations (a short composition sketch is given after the lists below):

  1. feedforward connection: “>>”, “>>=”

  2. feedback connection: “<<”, “<<=”

  3. merge two nodes: “&”, “&=”

  4. select subsets of one node: “[:]”

  5. concatenate a sequence of nodes: “[node1, node2, …]”, “(node1, node2, …)”

  6. wrap a set of nodes: “{node1, node2, …}”

However, all operations should satisfy the following assumptions:

  1. Feedback connection of (node1, node2) should have a feedforward path from node2 to node1.

  2. Feedforward or feedback connections cannot generate a cycle.

  3. Cannot concatenate multiple receiver nodes, e.g., a >> [b, c] is forbidden, but a >> {b, c} is allowed.
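A composition sketch under these assumptions, using the Input, Reservoir and LinearReadout nodes documented later in this module:

import brainpy as bp

# Feedforward chaining with ">>": Input -> Reservoir -> LinearReadout.
model = (bp.nn.Input(3)
         >> bp.nn.Reservoir(100)
         >> bp.nn.LinearReadout(1))

# "&" merges nodes and "{...}" wraps a set of receivers, e.g.
# inp >> {reservoir1, reservoir2} would send the same input to both.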

ff_connect(senders, receivers[, inplace, ...])

Connect two sequences of Node instances to form a brainpy.nn.base.Network instance.

fb_connect(senders, receivers[, inplace, ...])

Create a feedback connection from sender node to receiver node.

merge(node, *other_nodes[, inplace, name, ...])

Merge different Node or brainpy.nn.base.Network instances into a single brainpy.nn.base.Network instance.

select(node, index[, name])

concatenate(nodes[, axis, name])

Node Graph Tools#

This module provides basic tools for graphs, including

  • detect the senders and receivers in the network graph,

  • find input and output nodes in a given graph,

  • detect the cycle in the graph,

  • detect the path between two nodes.

find_senders_and_receivers(edges)

Find all senders and receivers in the given graph.

find_entries_and_exits(nodes, ff_edges[, ...])

Find input nodes and output nodes.

detect_cycle(nodes, edges)

Detect whether a cycle exists in the defined graph.

detect_path(from_node, to_node, edges[, method])

Detect whether a path exists in the defined graph from from_node to to_node.

Runners and Trainers#

This module provides various running and training algorithms for various neural networks.

The supported training algorithms include

  • offline training methods, like ridge regression, linear regression, etc.

  • online training methods, like recursive least squares (RLS, or Force Learning), least mean squares (LMS), etc.

  • back-propagation learning method

  • and others

The supported neural networks include

  • reservoir computing networks,

  • artificial recurrent neural networks,

  • and others.

Base RNN Runner#

RNNRunner(target[, jit])

Structural Runner for Recurrent Neural Networks.

Base RNN Trainer#

RNNTrainer(target, **kwargs)

Structural Trainer for Models with Recurrent Dynamics.

Online RNN Trainer#

OnlineTrainer(target[, fit_method])

Online trainer for models with recurrent dynamics.

ForceTrainer(target[, alpha])

Force learning.

Offline RNN Trainer#

OfflineTrainer(target[, fit_method])

Offline trainer for models with recurrent dynamics.

RidgeTrainer(target[, beta])

Trainer of ridge regression, also known as regression with Tikhonov regularization.
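A hedged sketch of offline (ridge regression) training; the initialize() call, the (batch, time, feature) data layout, and the fit([X, Y]) signature are assumptions, not guaranteed APIs:

import brainpy as bp
import brainpy.math as bm

model = bp.nn.Input(1) >> bp.nn.Reservoir(200) >> bp.nn.LinearReadout(1)
model.initialize(num_batch=1)          # assumed initialization step

X = bm.random.rand(1, 500, 1)          # assumed (batch, time, feature) layout
Y = bm.random.rand(1, 500, 1)

trainer = bp.nn.RidgeTrainer(model, beta=1e-6)
trainer.fit([X, Y])                    # assumed fit(...) interface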

Back-propagation Trainer#

BPTT(target, loss[, optimizer, ...])

The trainer implementing the back-propagation through time (BPTT) algorithm for recurrent neural networks.

BPFF(target, **kwargs)

The trainer implementing the back-propagation algorithm for feedforward neural networks.

Training Algorithms#

Online Training Algorithms#

get_supported_online_methods()

Get all supported online training methods.

register_online_method(name, method)

Register a new online learning method.

OnlineAlgorithm([name])

Base class for online training algorithm.

ForceLearning([alpha, name])

RLS([alpha, name])

The recursive least squares (RLS).

LMS([alpha, name])

The least mean squares (LMS).

Offline Training Algorithms#

get_supported_offline_methods()

Get all supported offline training methods.

register_offline_method(name, method)

Register a new offline learning method.

OfflineAlgorithm([name])

Base class for offline training algorithm.

RidgeRegression([beta, name])

Training algorithm of ridge regression.

LinearRegression([name])

Training algorithm of least-square regression.

Data Types#

DataType()

Base class for data type.

SingleData()

Pass a single data item into the node.

MultipleData([return_type])

Pass a list/tuple of data into the node.

Nodes: basic#

Activation([activation, fun_setting, ...])

Activation node.

DenseMD(num_unit[, weight_initializer, ...])

A linear transformation applied over the last dimension of the input.

Dense(num_unit[, weight_initializer, ...])

A linear transformation.

Input(input_shape[, trainable, name])

The input node.

Concat([axis, trainable])

Concatenate multiple inputs into one.

Select(index[, trainable])

Select a subset of the given input.

Reshape(shape[, trainable])

Reshape the input tensor to another tensor.

Summation([trainable])

Sum all input tensors into one.

Nodes: artificial neural network#

Artificial neural network (ANN) nodes

GeneralConv(out_channels, kernel_size[, ...])

Applies a convolution to the inputs.

Conv1D(out_channels, kernel_size, **kwargs)

Conv2D(out_channels, kernel_size, **kwargs)

Conv3D(out_channels, kernel_size, **kwargs)

Dropout(prob[, seed])

A layer that stochastically ignores a subset of inputs each training step.

VanillaRNN(num_unit[, state_initializer, ...])

Basic fully-connected RNN core.

GRU(num_unit[, wi_initializer, ...])

Gated Recurrent Unit.

LSTM(num_unit[, wi_initializer, ...])

Long short-term memory (LSTM) RNN core.

Pool(init_v, reduce_fn, window_shape, ...)

MaxPool(window_shape[, strides, padding])

Pools the input by taking the maximum over a window.

AvgPool(window_shape[, strides, padding])

Pools the input by taking the average over a window.

MinPool(window_shape[, strides, padding])

Pools the input by taking the minimum over a window.

BatchNorm(axis[, epsilon, use_bias, ...])

Batch Normalization node.

BatchNorm1d([axis])

1-D batch normalization.

BatchNorm2d([axis])

2-D batch normalization.

BatchNorm3d([axis])

3-D batch normalization.

GroupNorm([num_groups, group_size, epsilon, ...])

Group normalization layer.

LayerNorm([epsilon, use_bias, use_scale, ...])

Layer normalization (https://arxiv.org/abs/1607.06450).

InstanceNorm([epsilon, use_bias, use_scale, ...])

Instance normalization layer.

Nodes: reservoir computing#

Reservoir computing (RC) nodes

LinearReadout(num_unit, **kwargs)

Linear readout node.

NVAR(delay[, order, stride, constant, trainable])

Nonlinear vector auto-regression (NVAR) node.

Reservoir(num_unit[, leaky_rate, ...])

Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_.

brainpy.analysis module#

This module provides analysis tools for differential equations.

  • The symbolic module uses SymPy symbolic inference to analyze low-dimensional dynamical systems (it only supports ODEs).

  • The numeric module uses numerical optimization functions to analyze high-dimensional dynamical systems (it supports ODEs and discrete systems).

  • The continuation module is the analysis package with numerical continuation methods.

  • Moreover, we provide several useful functions in the stability module which may help your dynamical system analysis.

Details are given in the following.

Low-dimensional Analyzers#

Bifurcation1D(model, target_pars, target_vars)

Bifurcation analysis of 1D system.

Bifurcation2D(model, target_pars, target_vars)

Bifurcation analysis of 2D system.

FastSlow1D(model, fast_vars, slow_vars[, ...])

FastSlow2D(model, fast_vars, slow_vars[, ...])

PhasePlane1D(model, target_vars[, ...])

Phase plane analyzer for 1D dynamical system.

PhasePlane2D(model, target_vars[, ...])

Phase plane analyzer for 2D dynamical system.

High-dimensional Analyzers#

SlowPointFinder(f_cell[, f_type, ...])

Find fixed/slow points by numerical optimization.

Stability Analysis#

get_1d_stability_types()

Get the stability types of 1D system.

get_2d_stability_types()

Get the stability types of 2D system.

get_3d_stability_types()

Get the stability types of 3D system.

stability_analysis(derivatives)

Stability analysis of fixed points for low-dimensional system.

plot_scheme

A Python dict storing the plot scheme settings.

CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D, UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D, STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D, UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D

String (str) constants naming the supported stability types.

brainpy.integrators module#

This module provides numerical solvers for various differential equations, including:

  • ordinary differential equations (ODEs)

  • stochastic differential equations (SDEs)

  • fractional differential equations (FDEs)

  • delay differential equations (DDEs)

Details are given in the following.

Integrator Runner#

IntegratorRunner(target[, inits, args, ...])

Structural runner for numerical integrators in brainpy.

Joint Equation#

JointEq(*eqs)

Make a joint equation from multiple derivation functions.

Numerical Methods for ODEs#

Numerical methods for ordinary differential equations (ODEs).

Base Integrator#

ODEIntegrator(f[, var_type, dt, name, ...])

Numerical Integrator for Ordinary Differential Equations (ODEs).

Generic Functions#

odeint([f, method, var_type, dt, name, ...])

Numerical integration for ODEs.

set_default_odeint(method)

Set the default ODE numerical integrator method for differential equations.

get_default_odeint()

Get the default ODE numerical integrator method.

register_ode_integrator(name, integrator)

Register a new ODE integrator.

get_supported_methods()

Get all supported numerical methods for ODEs.
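For example, odeint can be used as a decorator to build an integrator for a user-defined derivative function; a minimal sketch (each call advances the state by one step of size dt, starting from the given time):

import brainpy as bp

@bp.odeint(method='rk4', dt=0.01)
def lorenz(x, y, z, t, sigma, beta, rho):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz

# one RK4 step from t = 0 with sigma=10, beta=8/3, rho=28
x, y, z = lorenz(1., 1., 1., 0., 10., 8 / 3, 28.)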

Explicit Runge-Kutta Methods#

This module provides explicit Runge-Kutta methods for ODEs.

Given an initial value problem specified as:

\[\frac{dy}{dt}=f(t,y),\quad y(t_{0})=y_{0}.\]

Let the step-size \(h > 0\).

Then, the general scheme of explicit Runge–Kutta methods is 1:

\[y_{n+1}=y_{n}+h\sum _{i=1}^{s}b_{i}k_{i},\]

where

\[\begin{split}\begin{aligned} k_{1}&=f(t_{n},y_{n}),\\ k_{2}&=f(t_{n}+c_{2}h,y_{n}+h(a_{21}k_{1})),\\ k_{3}&=f(t_{n}+c_{3}h,y_{n}+h(a_{31}k_{1}+a_{32}k_{2})),\\ &\\ \vdots \\ k_{s}&=f(t_{n}+c_{s}h,y_{n}+h(a_{s1}k_{1}+a_{s2}k_{2}+\cdots +a_{s,s-1}k_{s-1})). \end{aligned}\end{split}\]

To specify a particular method, one needs to provide the integer \(s\) (the number of stages), and the coefficients \(a_{ij}\) (for \(1 \le j < i \le s\)), \(b_i\) (for \(i = 1, 2, \cdots, s\)) and \(c_i\) (for \(i = 2, 3, \cdots, s\)).

The matrix \([a_{ij}]\) is called the Runge–Kutta matrix, while the \(b_i\) and \(c_i\) are known as the weights and the nodes. These data are usually arranged in a mnemonic device, known as a Butcher tableau (named after John C. Butcher):

\[\begin{split}\begin{array}{c|llll} 0 & & & & & \\ c_{2} & a_{21} & & & & \\ c_{3} & a_{31} & a_{32} & & & \\ \vdots & \vdots & & \ddots & \\ c_{s} & a_{s 1} & a_{s 2} & \cdots & a_{s, s-1} \\ \hline & b_{1} & b_{2} & \cdots & b_{s-1} & b_{s} \end{array}\end{split}\]

A Taylor series expansion shows that the Runge–Kutta method is consistent if and only if

\[\sum _{i=1}^{s}b_{i}=1.\]

Another popular condition for determining coefficients is:

\[\sum_{j=1}^{i-1}a_{ij}=c_{i}{\text{ for }}i=2,\ldots ,s.\]

For more details, please see references 2 3 4.

1

Press, W. H., B. P. Flannery, S. A. Teukolsky, and W. T. Vetterling. “Section 17.1 Runge-Kutta Method.” Numerical Recipes: The Art of Scientific Computing (2007).

2

https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods

3

Butcher, John Charles. Numerical methods for ordinary differential equations. John Wiley & Sons, 2016.

4

Iserles, A., 2009. A first course in the numerical analysis of differential equations (No. 44). Cambridge university press.

ExplicitRKIntegrator(f[, var_type, dt, ...])

Explicit Runge–Kutta methods for ordinary differential equation.

Euler(f[, var_type, dt, name, show_code, ...])

The Euler method for ODEs.

MidPoint(f[, var_type, dt, name, show_code, ...])

Explicit midpoint method for ODEs.

Heun2(f[, var_type, dt, name, show_code, ...])

Heun's method for ODEs.

Ralston2(f[, var_type, dt, name, show_code, ...])

Ralston's method for ODEs.

RK2(f[, beta, var_type, dt, name, ...])

Generic second order Runge-Kutta method for ODEs.

RK3(f[, var_type, dt, name, show_code, ...])

Classical third-order Runge-Kutta method for ODEs.

Heun3(f[, var_type, dt, name, show_code, ...])

Heun's third-order method for ODEs.

Ralston3(f[, var_type, dt, name, show_code, ...])

Ralston's third-order method for ODEs.

SSPRK3(f[, var_type, dt, name, show_code, ...])

Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

RK4(f[, var_type, dt, name, show_code, ...])

Classical fourth-order Runge-Kutta method for ODEs.

Ralston4(f[, var_type, dt, name, show_code, ...])

Ralston's fourth-order method for ODEs.

RK4Rule38(f[, var_type, dt, name, ...])

3/8-rule fourth-order method for ODEs.

Adaptive Runge-Kutta Methods#

This module provides adaptive Runge-Kutta methods for ODEs.

Adaptive methods are designed to produce an estimate of the local truncation error of a single Runge–Kutta step. This is done by having two methods, one with order \(p\) and one with order \(p-1\). These methods are interwoven, i.e., they have common intermediate steps. Thanks to this, estimating the error has little or negligible computational cost compared to a step with the higher-order method.

During the integration, the step size is adapted such that the estimated error stays below a user-defined threshold: If the error is too high, a step is repeated with a lower step size; if the error is much smaller, the step size is increased to save time. This results in an (almost) optimal step size, which saves computation time. Moreover, the user does not have to spend time on finding an appropriate step size.

The lower-order step is given by

\[y_{n+1}^{*}=y_{n}+h\sum _{i=1}^{s}b_{i}^{*}k_{i},\]

where \(k_{i}\) are the same as for the higher-order method. Then the error is

\[e_{n+1}=y_{n+1}-y_{n+1}^{*}=h\sum _{i=1}^{s}(b_{i}-b_{i}^{*})k_{i},\]

which is \(O(h^{p})\).

The Butcher tableau for this kind of method is extended to give the values of \(b_{i}^{*}\):

\[\begin{split}\begin{array}{c|llll} 0 & & & & & \\ c_{2} & a_{21} & & & & \\ c_{3} & a_{31} & a_{32} & & & \\ \vdots & \vdots & & \ddots & \\ c_{s} & a_{s 1} & a_{s 2} & \cdots & a_{s, s-1} \\ \hline & b_{1} & b_{2} & \cdots & b_{s-1} & b_{s} \\ & b_{1}^{*} & b_{2}^{*} & \cdots & b_{s-1}^{*} & b_{s}^{*} \end{array}\end{split}\]

For more details, please check references 1 2 3.

1

https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods

2

Press, W.H., Press, W.H., Flannery, B.P., Teukolsky, S.A., Vetterling, W.T., Flannery, B.P. and Vetterling, W.T., 1989. Numerical recipes in Pascal: the art of scientific computing (Vol. 1). Cambridge university press.

3

Press, W. H., & Teukolsky, S. A. (1992). Adaptive Stepsize Runge‐Kutta Integration. Computers in Physics, 6(2), 188-191.

AdaptiveRKIntegrator(f[, var_type, dt, ...])

Adaptive Runge-Kutta method for ordinary differential equations.

RKF12(f[, var_type, dt, name, adaptive, ...])

The Fehlberg RK1(2) method for ODEs.

RKF45(f[, var_type, dt, name, adaptive, ...])

The Runge–Kutta–Fehlberg method for ODEs.

DormandPrince(f[, var_type, dt, name, ...])

The Dormand–Prince method for ODEs.

CashKarp(f[, var_type, dt, name, adaptive, ...])

The Cash–Karp method for ODEs.

BogackiShampine(f[, var_type, dt, name, ...])

The Bogacki–Shampine method for ODEs.

HeunEuler(f[, var_type, dt, name, adaptive, ...])

The Heun–Euler method for ODEs.

Exponential Integrators#

This module provides exponential integrators for ODEs.

Exponential integrators are a large class of methods from numerical analysis based on the exact integration of the linear part of the initial value problem. Because the linear part is integrated exactly, this can help to mitigate the stiffness of a differential equation.

We consider initial value problems of the form,

\[u'(t)=f(u(t)),\qquad u(t_{0})=u_{0},\]

which can be decomposed as

\[u'(t)=Lu(t)+N(u(t)),\qquad u(t_{0})=u_{0},\]

where \(L={\frac {\partial f}{\partial u}}\) (the Jacobian of f) is composed of linear terms, and \(N=f(u)-Lu\) is composed of the non-linear terms.

This procedure enjoys the advantage, in each step, that \({\frac {\partial N_{n}}{\partial u}}(u_{n})=0\). This considerably simplifies the derivation of the order conditions and improves the stability when integrating the nonlinearity \(N(u(t))\).

Exact integration of this problem from time 0 to a later time \(t\) can be performed using matrix exponentials to define an integral equation for the exact solution:

\[u(t)=e^{Lt}u_{0}+\int _{0}^{t}e^{L(t-\tau )}N\left(t+\tau, u\left(\tau \right)\right)\,d\tau .\]

This representation of the exact solution is also called the variation-of-constants formula. In the case of \(N\equiv 0\), this formulation is the exact solution to the linear differential equation.

Exponential Rosenbrock methods

Exponential Rosenbrock methods were shown to be very efficient in solving large systems of stiff ODEs. Applying the variation-of-constants formula gives the exact solution at time \(t_{n+1}\) with the numerical solution \(u_n\) as

\[u(t_{n+1})=e^{h_{n}L}u(t_{n})+\int _{0}^{h_{n}}e^{(h_{n}-\tau )L}N(t_n+\tau, u(t_{n}+\tau ))d\tau . \tag{1}\]

where \(h_n=t_{n+1}-t_n\).

The idea now is to approximate the integral in (1) by some quadrature rule with nodes \(c_{i}\) and weights \(b_{i}(h_{n}L)\) (\(1\leq i\leq s\)). This yields the following class of s-stage explicit exponential Rosenbrock methods:

\[\begin{split}\begin{align} U_{ni}=&e^{c_{i}h_{n}L}u_n+h_{n}\sum_{j=1}^{i-1}a_{ij}(h_{n}L)N(U_{nj}), \\ u_{n+1}=&e^{h_{n}L}u_n+h_{n}\sum_{i=1}^{s}b_{i}(h_{n}L)N(U_{ni}) \end{align}\end{split}\]

where \(U_{ni}\approx u(t_{n}+c_{i}h_{n})\).

The coefficients \(a_{ij}(z),b_{i}(z)\) are usually chosen as linear combinations of the entire functions \(\varphi _{k}(c_{i}z),\varphi _{k}(z)\), respectively, where

\[\begin{split}\begin{align} \varphi _{k}(z)=&\int _{0}^{1}e^{(1-\theta )z}{\frac {\theta ^{k-1}}{(k-1)!}}d\theta ,\quad k\geq 1, \\ \varphi _{0}(z)=&e^{z},\\ \varphi _{k+1}(z)=&{\frac {\varphi_{k}(z)-\varphi _{k}(0)}{z}},\ k\geq 0. \end{align}\end{split}\]

By introducing the difference \(D_{ni}=N(U_{ni})-N(u_{n})\), they can be reformulated in a more efficient way for implementation as

\[\begin{split}\begin{align} U_{ni}=&u_{n}+c_{i}h_{n}\varphi _{1}(c_{i}h_{n}L)f(u_{n})+h_{n}\sum _{j=2}^{i-1}a_{ij}(h_{n}L)D_{nj}, \\ u_{n+1}=&u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}b_{i}(h_{n}L)D_{ni}. \end{align}\end{split}\]

where \(\varphi_{1}(z)=\frac{e^z-1}{z}\).

In order to implement this scheme with adaptive step size, one can consider, for the purpose of local error estimation, the following embedded methods

\[{\bar {u}}_{n+1}=u_{n}+h_{n}\varphi _{1}(h_{n}L)f(u_{n})+h_{n}\sum _{i=2}^{s}{\bar {b}}_{i}(h_{n}L)D_{ni},\]

which use the same stages \(U_{ni}\) but with weights \({\bar {b}}_{i}\).

For convenience, the coefficients of the explicit exponential Rosenbrock methods together with their embedded methods can be represented by using the so-called reduced Butcher tableau as follows:

\[\begin{split}\begin{array}{c|ccccc} c_{2} & & & & & \\ c_{3} & a_{32} & & & & \\ \vdots & \vdots & & \ddots & & \\ c_{s} & a_{s 2} & a_{s 3} & \cdots & a_{s, s-1} \\ \hline & b_{2} & b_{3} & \cdots & b_{s-1} & b_{s} \\ & \bar{b}_{2} & \bar{b}_{3} & \cdots & \bar{b}_{s-1} & \bar{b}_{s} \end{array}\end{split}\]

1

https://en.wikipedia.org/wiki/Exponential_integrator

2

Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.

ExponentialEuler(f[, var_type, dt, name, ...])

Exponential Euler method using automatic differentiation.
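A short sketch of selecting this scheme through the generic odeint interface; the method string 'exponential_euler' is assumed here:

import brainpy as bp

@bp.odeint(method='exponential_euler', dt=0.05)
def dV(V, t, gl, El, Iext):
    # leaky membrane equation: the linear part is handled exactly by the scheme
    return -gl * (V - El) + Iext

V = dV(-65., 0., 0.1, -65., 1.0)   # one exponential-Euler step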

Error Analysis of Numerical Methods#

In order to identify the essential properties of numerical methods, we define basic notions 1.

For the given ODE system

\[\frac{dy}{dt}=f(t,y),\quad y(t_{0})=y_{0},\]

we define \(y(t_n)\) as the solution of the IVP evaluated at \(t=t_n\), and \(y_n\) as a numerical approximation of \(y(t_n)\) at the same location by a generic numerical scheme (whether explicit, implicit or multi-step):

\[\begin{align} y_{n+1} = y_n + h \phi(t_n,y_n,h), \tag{2} \end{align}\]

where \(h\) is the discretization step for \(t\), i.e., \(h=t_{n+1}-t_n\), and \(\phi(t_n,y_n,h)\) is the increment function. We say that the defined numerical scheme is consistent if \(\lim_{h\to0} \phi(t,y,h) = \phi(t,y,0) = f(t,y)\).

Then, the approximation error is defined as

\[e_n = y(t_n) - y_n.\]

The absolute error is defined as

\[|e_n| = |y(t_n) - y_n|.\]

The relative error is defined as

\[r_n =\frac{|y(t_n) - y_n|}{|y(t_n)|}.\]

The exact differential operator is defined as

\[\begin{align} L_e(y) = y' - f(t,y) = 0 \end{align}\]

The approximate differential operator is defined as

\[\begin{align} L_a(y_n) = y(t_{n+1}) - [y_n + h\phi(t_n,y_n,h)]. \end{align}\]

Finally, the local truncation error (LTE) is defined as

\[\begin{align} \tau_n = \frac{1}{h} L_a(y(t_n)). \end{align}\]

In practice, the evaluation of the exact solution for different \(t\) around \(t_n\) (required by \(L_a\)) is performed using a Taylor series expansion.

Finally, we can state that a scheme is \(p\)-th order accurate by examining its LTE and observing its leading term

\[\begin{align} \tau_n = C h^p + H.O.T., \end{align}\]

where \(C\) is a constant, independent of \(h\), and \(H.O.T.\) are the higher order terms of the LTE.

Example: LTE for Euler’s scheme

Consider the IVP defined by \(y' = \lambda y\), with initial condition \(y(0)=1\).

The approximation operator for Euler’s scheme is

\[\begin{align} L^{euler}_a = y(t_{n+1}) - [y_n + h \lambda y_n], \end{align}\]

then the LTE can be computed by

\[\begin{split}\begin{align} \tau_n = & \frac{1}{h}\left\{ L_a(y(t_n))\right\} = \frac{1}{h}\left\{ y(t_{n+1}) - [y(t_n) + h \lambda y(t_n)]\right\}, \\ = & \frac{1}{h}\left\{ y(t_n) + h y'(t_n) + \frac{h^2}{2} y''(t_n) + \ldots + \frac{1}{p!} h^p y^{(p)}(t_n) - y(t_n) - h \lambda y(t_n) \right\} \\ = & \frac{1}{2} h y''(t_n) + \ldots + \frac{1}{p!} h^{p-1} y^{(p)}(t_n) \\ \approx & \frac{1}{2} h y''(t_n), \end{align}\end{split}\]

where we assume \(y_n = y(t_n)\).

1

https://folk.ntnu.no/leifh/teaching/tkt4140/._main022.html
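The first-order accuracy of Euler's scheme can also be checked numerically: halving h roughly halves the global error, so error/h stays approximately constant. A self-contained sketch in plain NumPy:

import numpy as np

lam, T = -1.0, 1.0      # y' = lam*y, y(0) = 1, exact solution exp(lam*T)

def euler_error(h):
    y = 1.0
    for _ in range(int(T / h)):
        y += h * lam * y            # explicit Euler step
    return abs(y - np.exp(lam * T))

for h in (0.1, 0.05, 0.025):
    e = euler_error(h)
    print(h, e, e / h)              # the ratio e/h stays roughly constant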

Numerical Methods for SDEs#

Numerical methods for stochastic differential equations.

Base Integrator#

SDEIntegrator(f, g[, dt, name, show_code, ...])

SDE Integrator.

Generic Functions#

sdeint([f, g, method, dt, name, show_code, ...])

Numerical integration for SDEs.

set_default_sdeint(method)

Set the default SDE numerical integrator method for differential equations.

get_default_sdeint()

Get the default SDE numerical integrator method.

register_sde_integrator(name, integrator)

Register a new SDE integrator.

get_supported_methods()

Get all supported numerical methods for SDEs.
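A minimal sketch of sdeint for a scalar SDE dx = -a*x dt + sigma dW; the call convention integral(x, t, a) is assumed here:

import brainpy as bp

def f(x, t, a):      # drift term
    return -a * x

def g(x, t, a):      # diffusion term (constant sigma)
    return 0.5

integral = bp.sdeint(f=f, g=g, method='euler', dt=0.01)

x = 1.0
for i in range(100):                 # integrate 100 steps of size dt
    x = integral(x, i * 0.01, 1.0)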

Normal Methods#

Euler(f, g[, dt, name, show_code, var_type, ...])

Heun(f, g[, dt, name, show_code, var_type, ...])

Milstein(f, g[, dt, name, show_code, ...])

ExponentialEuler(f, g[, dt, name, ...])

First order, explicit exponential Euler method.

SRK methods for scalar Wiener process#

SRK1W1(f, g[, dt, name, show_code, ...])

Order 2.0 weak SRK methods for SDEs with scalar Wiener process.

SRK2W1(f, g[, dt, name, show_code, ...])

Order 1.5 Strong SRK Methods for SDEs with Scalar Noise.

KlPl(f, g[, dt, name, show_code, var_type, ...])

Numerical Methods for FDEs#

Numerical methods for fractional differential equations.

Base Integrator#

FDEIntegrator(f, alpha, num_step[, dt, ...])

Numerical integrator for fractional differential equations (FDEs).

Generic Functions#

fdeint(alpha, num_step, inits[, f, method, ...])

Numerical integration for FDEs.

set_default_fdeint(method)

Set the default FDE numerical integrator method for differential equations.

get_default_fdeint()

Get the default FDE numerical integrator method.

register_fde_integrator(name, integrator)

Register a new FDE integrator.

get_supported_methods()

Get all supported numerical methods for FDEs.

Methods for Caputo Fractional Derivative#

This module provides numerical methods for integrating Caputo fractional derivative equations.

CaputoEuler(f, alpha, num_step, inits[, dt, ...])

One-step Euler method for Caputo fractional differential equations.

CaputoL1Schema(f, alpha, num_step, inits[, ...])

The L1 scheme method for the numerical approximation of the Caputo fractional-order derivative equations [3]_.

Methods for Riemann-Liouville Fractional Derivative#

This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.

GLShortMemory(f, alpha, inits[, num_memory, ...])

Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_.

brainpy.datasets module#

Chaotic Systems#

henon_map_series(num_step[, a, b, inits, ...])

The Hénon map time series.

logistic_map_series(num_step[, mu, inits, ...])

The logistic map time series.

modified_lu_chen_series(duration[, dt, a, ...])

Modified Lu Chen attractor.

mackey_glass_series(duration[, dt, beta, ...])

The Mackey-Glass time series.

rabinovich_fabrikant_series(duration[, dt, ...])

Rabinovich-Fabrikant equations.

chen_chaotic_series(duration[, dt, a, b, c, ...])

Chen attractor.

lu_chen_chaotic_series(duration[, dt, a, c, ...])

Lu Chen attractor.

chua_chaotic_series(duration[, dt, alpha, ...])

Chua’s system.

modified_chua_series(duration[, dt, alpha, ...])

Modified Chua chaotic attractor.

lorenz_series(duration[, dt, sigma, beta, ...])

The Lorenz system.

modified_Lorenz_series(duration[, dt, a, b, ...])

Modified Lorenz chaotic system.

double_scroll_series(duration[, dt, R1, R2, ...])

Double-scroll electronic circuit attractor.

PWL_duffing_series(duration[, dt, e, m0, ...])

PWL Duffing chaotic attractor.

brainpy.inputs module#

This module provides various methods to form current inputs. You can access them through brainpy.inputs.XXX.

section_input(values, durations[, dt, ...])

Format an input current with different sections.

constant_input(I_and_duration[, dt])

Format constant input in durations.

constant_current(I_and_duration[, dt])

Format constant input in durations.

spike_input(sp_times, sp_lens, sp_sizes, ...)

Format current input like a series of short-time spikes.

spike_current(sp_times, sp_lens, sp_sizes, ...)

Format current input like a series of short-time spikes.

ramp_input(c_start, c_end, duration[, ...])

Get the gradually changed input current.

ramp_current(c_start, c_end, duration[, ...])

Get the gradually changed input current.

wiener_process(duration[, dt, n, t_start, ...])

Stimulus sampled from a Wiener process, i.e. drawn from standard normal distribution N(0, sqrt(dt)).

ou_process(mean, sigma, tau, duration[, dt, ...])

Ornstein–Uhlenbeck input.

sinusoidal_input(amplitude, frequency, duration)

Sinusoidal input.

square_input(amplitude, frequency, duration)

Oscillatory square input.
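For instance, a step current and a ramp current can be built as follows (a minimal sketch; the exact return values are assumed to be arrays of current values sampled at dt):

import brainpy as bp

# 0 for 100 ms, 1. for 300 ms, then 0 again for 100 ms (dt = 0.1 ms).
current = bp.inputs.section_input(values=[0., 1., 0.],
                                  durations=[100., 300., 100.],
                                  dt=0.1)

# A linear ramp from 0 to 2 over 500 ms.
ramp = bp.inputs.ramp_input(c_start=0., c_end=2., duration=500., dt=0.1)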

brainpy.connect module#

This module provides methods to construct connectivity between neuron groups. You can access them through brainpy.connect.XXX.

Base Class#

set_default_dtype([mat_dtype, idx_dtype])

Set the default dtype.

csr2csc(csr, post_num[, data])

Convert csr to csc.

csr2mat(csr, num_pre, num_post)

Convert (indices, indptr) to a dense matrix.

mat2csr(dense)

Convert a dense matrix to (indices, indptr).

ij2csr(pre_ids, post_ids, num_pre)

Convert pre_ids, post_ids to (indices, indptr).

MAT_DTYPE

alias of numpy.bool_

IDX_DTYPE

alias of numpy.uint32

Connector()

Base Synaptic Connector Class.

TwoEndConnector()

Synaptic connector to build synapse connections between two neuron groups.

OneEndConnector()

Synaptic connector to build synapse connections within a population of neurons.

CONN_MAT, PRE_IDS, POST_IDS, PRE2POST, POST2PRE, PRE2SYN, POST2SYN

String (str) constants naming the supported synaptic structures.

SUPPORTED_SYN_STRUCTURE

A built-in mutable sequence (list) of the supported synaptic structure names.

Custom Connections#

MatConn(conn_mat)

Connector built from the dense connection matrix.

IJConn(i, j)

Connector built from the pre_ids and post_ids connections.

SparseMatConn(csr_mat)

Connector built from the sparse connection matrix.

Random Connections#

FixedProb(prob[, include_self, seed])

Connect the post-synaptic neurons with fixed probability.

FixedPreNum(num[, include_self, seed])

Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron.

FixedPostNum(num[, include_self, seed])

Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron.

GaussianProb(sigma[, encoding_values, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function.

SmallWorld(num_neighbor, prob[, directed, ...])

Build a Watts–Strogatz small-world graph.

ScaleFreeBA(m[, directed, seed])

Build a random graph according to the Barabási–Albert preferential attachment model.

ScaleFreeBADual(m1, m2, p[, directed, seed])

Build a random graph according to the dual Barabási–Albert preferential attachment model.

PowerLaw(m, p[, directed, seed])

Holme and Kim algorithm for growing graphs with power-law degree distribution and approximate average clustering.

Regular Connections#

One2One()

Connect two neuron groups one by one.

All2All([include_self])

Connect each neuron in the first group to all neurons in the post-synaptic neuron group.

GridFour([include_self])

The nearest four neighbors conn method.

GridEight([include_self])

The nearest eight neighbors conn method.

GridN([N, include_self])

The nearest (2*N+1) * (2*N+1) neighbors conn method.

one2one

Connect two neuron groups one by one.

all2all

Connect each neuron in the first group to all neurons in the post-synaptic neuron group.

grid_four

The nearest four neighbors conn method.

grid_eight

The nearest eight neighbors conn method.
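A short sketch of building a random connection and exporting its structures; the require() interface with structure names such as 'pre_ids' and 'post_ids' is an assumption here:

import brainpy as bp

# 10% connection probability from an 800-neuron group to a 200-neuron group.
conn = bp.connect.FixedProb(prob=0.1)(pre_size=800, post_size=200)

# Export the requested synaptic structures (assumed interface).
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')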

brainpy.initialize module#

This module provides methods to initialize weights. You can access them through brainpy.init.XXX.

Base Class#

Initializer()

Base Initialization Class.

InterLayerInitializer()

The superclass of Initializers that initialize the weights between two layers.

IntraLayerInitializer()

The superclass of Initializers that initialize the weights within a layer.

Regular Initializers#

ZeroInit()

Zero initializer.

OneInit([value])

One initializer.

Identity([value])

Returns the identity matrix.

Random Initializers#

Normal([mean, scale, seed])

Initialize weights with normal distribution.

Uniform([min_val, max_val, seed])

Initialize weights with uniform distribution.

VarianceScaling(scale, mode, distribution[, ...])

KaimingUniform([scale, mode, distribution, ...])

KaimingNormal([scale, mode, distribution, ...])

XavierUniform([scale, mode, distribution, ...])

XavierNormal([scale, mode, distribution, ...])

LecunUniform([scale, mode, distribution, ...])

LecunNormal([scale, mode, distribution, ...])

Orthogonal([scale, axis, seed])

Construct an initializer for uniformly distributed orthogonal matrices.

DeltaOrthogonal([scale, axis])

Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

Decay Initializers#

GaussianDecay(sigma, max_w[, min_w, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function.

DOGDecay(sigmas, max_ws[, min_w, ...])

Builds a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons.

brainpy.losses module#

This module implements several loss functions.

cross_entropy_loss(logits, targets[, ...])

This criterion combines LogSoftmax and NLLLoss in a single class.

l1_loos(logits, targets[, reduction])

Creates a criterion that measures the mean absolute error (MAE) between each element in the logits \(x\) and targets \(y\).

l2_loss(predicts, targets)

Computes the L2 loss.

l2_norm(x)

Computes the L2 norm.

huber_loss(predicts, targets[, delta])

Huber loss.

mean_absolute_error(x, y[, axis])

Computes the mean absolute error between x and y.

mean_squared_error(predicts, targets[, axis])

Computes the mean squared error between x and y.

mean_squared_log_error(y_true, y_pred[, axis])

Computes the mean squared logarithmic error between y_true and y_pred.
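For example (a minimal sketch with random arrays standing in for model outputs and targets):

import brainpy as bp
import brainpy.math as bm

predicts = bm.random.rand(32, 10)
targets = bm.random.rand(32, 10)

print(bp.losses.mean_squared_error(predicts, targets))
print(bp.losses.l2_loss(predicts, targets))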

brainpy.optimizers module#

Optimizers#

Optimizer(lr[, train_vars, name])

Base Optimizer Class.

SGD(lr[, train_vars, name])

Stochastic gradient descent optimizer.

Momentum(lr[, train_vars, momentum, name])

Momentum optimizer.

MomentumNesterov(lr[, train_vars, momentum, ...])

Nesterov accelerated gradient optimizer [2]_.

Adagrad(lr[, train_vars, epsilon, name])

Optimizer that implements the Adagrad algorithm.

Adadelta([train_vars, lr, epsilon, rho, name])

Optimizer that implements the Adadelta algorithm.

RMSProp(lr[, train_vars, epsilon, rho, name])

Optimizer that implements the RMSprop algorithm.

Adam(lr[, train_vars, beta1, beta2, eps, name])

Optimizer that implements the Adam algorithm.

LARS(lr[, train_vars, momentum, ...])

Layer-wise adaptive rate scaling (LARS) optimizer.

Schedulers#

make_schedule(scalar_or_schedule)

Scheduler(lr)

The learning rate scheduler.

Constant(lr)

ExponentialDecay(lr, decay_steps, decay_rate)

InverseTimeDecay(lr, decay_steps, decay_rate)

PolynomialDecay(lr, decay_steps, final_lr[, ...])

PiecewiseConstant(boundaries, values)

brainpy.measure module#

This module aims to provide commonly used analysis methods for simulated neuronal data. You can access them through brainpy.measure.XXX.

cross_correlation(spikes, bin[, dt])

Calculate the cross-correlation index between neurons.

matrix_correlation(x, y)

Pearson correlation of the lower triangular parts of two matrices.

functional_connectivity(activities)

Functional connectivity matrix of timeseries activities.

raster_plot(sp_matrix, times)

Get spike raster plot which displays the spiking activity of a group of neurons over time.

firing_rate(sp_matrix, width[, dt, numpy])

Calculate the mean firing rate in a neuron group.

voltage_fluctuation(potentials)

Calculate neuronal synchronization via voltage variance.

weighted_correlation(x, y, w)

Weighted Pearson correlation of two data series.
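For example, the population firing rate can be estimated from a (time, neuron) spike matrix as follows (a minimal sketch using fake spike trains):

import brainpy as bp
import brainpy.math as bm

spikes = bm.random.rand(1000, 100) < 0.02                 # fake spikes, dt = 0.1 ms
rate = bp.measure.firing_rate(spikes, width=10., dt=0.1)  # 10 ms smoothing window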

brainpy.running module#

This module provides APIs for brain simulations.

Monitors#

Monitor(variables[, intervals])

The basic Monitor class to store the past variable trajectories.

Parallel Pool#

process_pool(func, all_net_params, nb_process)

Run multiple models in multiple processes.

process_pool_lock(func, all_net_params, ...)

Run multiple models in multiple processes with a lock.

Runners#

Runner(target[, monitors, fun_monitors, ...])

Base Runner.

brainpy.tools module#

Type Checking#

check_shape_consistency(shapes[, free_axes, ...])

check_shape_broadcastable(shapes[, ...])

Check whether the given shapes are broadcastable.

check_shape_except_batch(shape1, shape2[, ...])

Check whether two shapes are compatible except the batch size axis.

check_shape(all_shapes[, free_axes])

check_dict_data(a_dict, key_type, val_type)

Check the dictionary data.

check_initializer(initializer[, name, ...])

Check the initializer.

check_connector(connector[, name, allow_none])

Check the connector.

check_float(value[, name, min_bound, ...])

Check float type.

check_integer(value[, name, min_bound, ...])

Check integer type.

check_string(value[, name, candidates, ...])

Check string type.

check_sequence(value[, name, elem_type, ...])

Code Tools#

copy_doc(source_f)

code_lines_to_func(lines, func_name, ...[, ...])

get_identifiers(expr[, include_numbers])

Return all the identifiers in a given string expr, that is, everything that matches a programming-language variable-like expression, which is here implemented as the regexp \b[A-Za-z_][A-Za-z0-9_]*\b.

indent(text[, num_tabs, spaces_per_tab, tab])

deindent(text[, num_tabs, spaces_per_tab, ...])

word_replace(expr, substitutions[, exclude_dot])

Applies a dict of word substitutions.

is_lambda_function(func)

Check whether the function is a lambda function.

get_main_code(func[, codes])

Get the main function code string.

get_func_source(func)

change_func_name(f, name)

Error Tools#

check_error_in_jit(pred, err_f[, err_arg])

Check errors in a jit function.

Other Tools#

to_size(x)

Return type: Optional[Tuple[int]]

size2num(size)

timeout(s)

Add a timeout parameter to a function and return it.

init_progress_bar(duration, dt[, report, ...])

Setup a progress bar.

numba_jit([f])

DictPlus(*args, **kwargs)

Python dictionaries with advanced dot notation access.

brainpy.compat module#

Brain Objects#

DynamicalSystem(*args, **kwargs)

Dynamical System.

Container(*args, **kwargs)

Container.

Network(*args, **kwargs)

Network.

ConstantDelay(*args, **kwargs)

Constant Delay.

NeuGroup(*args, **kwargs)

Neuron group.

TwoEndConn(*args, **kwargs)

Two-end synaptic connection.

Integrators#

set_default_odeint(method)

Set default ode integrator.

set_default_sdeint(method)

Set default sde integrator.

get_default_odeint()

Get default ode integrator.

get_default_sdeint()

Get default sde integrator.

Layers#

Module([name])

Basic module class.

Models#

LIF(*args, **kwargs)

LIF neuron model.

AdExIF(*args, **kwargs)

AdExIF neuron model.

Izhikevich(*args, **kwargs)

Izhikevich neuron model.

ExpCOBA(*args, **kwargs)

ExpCOBA synapse model.

ExpCUBA(*args, **kwargs)

ExpCUBA synapse model.

DeltaSynapse(*args, **kwargs)

Delta synapse model.

Runners#

IntegratorRunner(*args, **kwargs)

Integrator runner class.

DSRunner(*args, **kwargs)

Dynamical system runner class.

StructRunner(*args, **kwargs)

Dynamical system runner class.

ReportRunner(*args, **kwargs)

Dynamical system runner class.

Monitor#

Monitor(*args, **kwargs)

Monitor class.

Release notes (brainpy)#

brainpy 2.x (LTS)#

Version 2.1.11 (2022.05.15)#

What’s Changed#

Full Changelog: V2.1.10…V2.1.11

Version 2.1.10 (2022.05.05)#

What’s Changed#

Full Changelog: V2.1.8…V2.1.10

Version 2.1.8 (2022.04.26)#

What’s Changed#

Full Changelog: V2.1.7…V2.1.8

Version 2.1.7 (2022.04.22)#

What’s Changed#

Full Changelog: V2.1.5…V2.1.7

Version 2.1.5 (2022.04.18)#

What’s Changed#

Full Changelog: V2.1.4…V2.1.5

Version 2.1.4 (2022.04.04)#

What’s Changed#

Full Changelog: V2.1.3…V2.1.4

Version 2.1.3 (2022.03.27)#

This release improves the functionality and usability of BrainPy. Core changes include

  • support customization of low-level operators by using Numba

  • fix bugs

What’s Changed#

Full Changelog: V2.1.2…V2.1.3

Version 2.1.2 (2022.03.23)#

This release improves the functionality and usability of BrainPy. Core changes include

  • support rate-based whole-brain modeling

  • add more neuron models, including rate neurons/synapses

  • support Python 3.10

  • improve delay-related APIs, among others

What’s Changed#

Full Changelog: V2.1.1…V2.1.2

Version 2.1.1 (2022.03.18)#

This release continues to update the functionality of BrainPy. Core changes include

  • numerical solvers for fractional differential equations

  • more standard brainpy.nn interfaces

New Features#
  • Numerical solvers for fractional differential equations
    • brainpy.fde.CaputoEuler

    • brainpy.fde.CaputoL1Schema

    • brainpy.fde.GLShortMemory

  • Fractional neuron models
    • brainpy.dyn.FractionalFHR

    • brainpy.dyn.FractionalIzhikevich

  • support shared_kwargs in RNNTrainer and RNNRunner

Version 2.1.0 (2022.03.14)#

Highlights#

We are excited to announce the release of BrainPy 2.1.0. This release is composed of nearly 270 commits since 2.0.2, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.

BrainPy 2.1.0 updates are focused on improving usability, functionality, and stability of BrainPy. Highlights of version 2.1.0 include:

  • New module brainpy.dyn for dynamics building and simulation. It is composed of many neuron models, synapse models, and others (see the sketch after this list).

  • New module brainpy.nn for neural network building and training. It supports defining reservoir models and artificial neural networks, with ridge regression and back-propagation-through-time training.

  • New module brainpy.datasets for convenient dataset construction and initialization.

  • New module brainpy.integrators.dde for numerical integration of delay differential equations.

  • Add more numpy-like operators in brainpy.math module.

  • Add automatic continuous integration on Linux, Windows, and MacOS platforms.

  • Fully update brainpy documentation.

  • Fix bugs in brainpy.analysis and brainpy.math.autograd
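
To give a concrete feel for the new brainpy.dyn module highlighted above, here is a minimal sketch under the 2.1.x API; the LIF constructor arguments and the ('input', 5.) input specification follow the common usage pattern and should be treated as assumptions.

import brainpy as bp

group = bp.dyn.LIF(10)                          # a population of 10 LIF neurons
runner = bp.dyn.DSRunner(group,
                         monitors=['V'],        # record membrane potentials
                         inputs=('input', 5.))  # constant external current
runner.run(100.)                                # simulate 100 ms
print(runner.mon.V.shape)                       # (number of time steps, 10)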

Incompatible changes#
  • Remove brainpy.math.numpy module.

  • Remove the Numba requirement

  • Remove the Matplotlib requirement

  • Remove steps in brainpy.dyn.DynamicalSystem

  • Remove Travis CI

New Features#
  • brainpy.ddeint for numerical integration of delay differential equations; the supported methods include:

    • Euler

    • MidPoint

    • Heun2

    • Ralston2

    • RK2

    • RK3

    • Heun3

    • Ralston3

    • SSPRK3

    • RK4

    • Ralston4

    • RK4Rule38

  • set default int/float/complex types (see the sketch after this list)
    • brainpy.math.set_dfloat()

    • brainpy.math.set_dint()

    • brainpy.math.set_dcomplex()

  • Delay variables
    • brainpy.math.FixedLenDelay

    • brainpy.math.NeutralDelay

  • Dedicated operators
    • brainpy.math.sparse_matmul()

  • More numpy-like operators

  • Neural network building brainpy.nn

  • Dynamics model building and simulation brainpy.dyn
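
For the default-dtype setters mentioned in this list, usage is expected to look like the sketch below; the dtype aliases (bm.float32, bm.int32, bm.complex64) are assumed to mirror NumPy naming.

import brainpy.math as bm

# note: 64-bit defaults on the JAX backend may additionally require x64 support
bm.set_dfloat(bm.float32)       # default floating-point dtype
bm.set_dint(bm.int32)           # default integer dtype
bm.set_dcomplex(bm.complex64)   # default complex dtype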

Version 2.0.2 (2022.02.11)#

There are important updates by Chaoming Wang in BrainPy 2.0.2.

  • provide pre2post_event_prod operator

  • support array creation from a list/tuple of JaxArray in brainpy.math.asarray and brainpy.math.array (see the sketch after this list)

  • update brainpy.ConstantDelay, add .latest and .oldest attributes

  • add brainpy.IntegratorRunner support for efficient simulation of brainpy integrators

  • support automatically finding the RandomState when JIT-compiling SDE integrators

  • fix bugs in SDE exponential_euler method

  • move parallel running APIs into brainpy.simulation

  • add brainpy.math.syn2post_mean, brainpy.math.syn2post_softmax, brainpy.math.pre2post_mean and brainpy.math.pre2post_softmax operators
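
A sketch of the array-creation change referenced in this list, assuming the semantics mirror numpy.asarray applied to a sequence of arrays:

import brainpy.math as bm

a = bm.ones(3)
b = bm.zeros(3)
c = bm.asarray([a, b])  # stacks the two JaxArrays into a (2, 3) JaxArray
print(c.shape)          # (2, 3)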

Version 2.0.1 (2022.01.31)#

Today we release BrainPy 2.0.1. This release is composed of over 70 commits since 2.0.0, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.

BrainPy 2.0.1 updates are focused on improving documentation and operators. Core changes include:

  • Improve brainpylib operators

  • Complete documentation for programming system

  • Add more numpy APIs

  • Add jaxfwd in autograd module

  • And other changes

Version 2.0.0.1 (2022.01.05)#

  • Add progress bar in brainpy.StructRunner

Version 2.0.0 (2021.12.31)#

Start a new version of BrainPy.

Highlight#

We are excited to announce the release of BrainPy 2.0.0. This release is composed of over 260 commits since 1.1.7, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.

BrainPy 2.0.0 updates are focused on improving the performance, usability, and consistency of BrainPy. All the computations are migrated into JAX. Model building, simulation, training, and analysis are all based on JAX. Highlights of version 2.0.0 include:

  • brainpylib is provided to supply dedicated operators for brain dynamics programming

  • Connection APIs in brainpy.conn module are more efficient.

  • Update analysis tools for low-dimensional and high-dimensional systems in brainpy.analysis module.

  • Support more general Exponential Euler methods based on automatic differentiation.

  • Improve the usability and consistency of the brainpy.math module.

  • Remove JIT compilation based on Numba.

  • Separate brain building from brain simulation.

Incompatible changes#
  • remove brainpy.math.use_backend()

  • remove brainpy.math.numpy module

  • no longer support .run() in brainpy.DynamicalSystem (see New Features)

  • remove brainpy.analysis.PhasePlane (see New Features)

  • remove brainpy.analysis.Bifurcation (see New Features)

  • remove brainpy.analysis.FastSlowBifurcation (see New Features)

New Features#
  • Exponential Euler method based on automatic differentiation
    • brainpy.ode.ExpEulerAuto

  • Numerical optimization based low-dimensional analyzers:
    • brainpy.analysis.PhasePlane1D

    • brainpy.analysis.PhasePlane2D

    • brainpy.analysis.Bifurcation1D

    • brainpy.analysis.Bifurcation2D

    • brainpy.analysis.FastSlow1D

    • brainpy.analysis.FastSlow2D

  • Numerical optimization based high-dimensional analyzer:
    • brainpy.analysis.SlowPointFinder

  • Dedicated operators in brainpy.math module:
    • brainpy.math.pre2post_event_sum

    • brainpy.math.pre2post_sum

    • brainpy.math.pre2post_prod

    • brainpy.math.pre2post_max

    • brainpy.math.pre2post_min

    • brainpy.math.pre2syn

    • brainpy.math.syn2post

    • brainpy.math.syn2post_prod

    • brainpy.math.syn2post_max

    • brainpy.math.syn2post_min

  • Conversion APIs in brainpy.math module:
    • brainpy.math.as_device_array()

    • brainpy.math.as_variable()

    • brainpy.math.as_jaxarray()

  • New autograd APIs in brainpy.math module (see the sketch after this list):
    • brainpy.math.vector_grad()

  • Simulation runners:
    • brainpy.ReportRunner

    • brainpy.StructRunner

    • brainpy.NumpyRunner

  • Commonly used models in brainpy.models module
    • brainpy.models.LIF

    • brainpy.models.Izhikevich

    • brainpy.models.AdExIF

    • brainpy.models.SpikeTimeInput

    • brainpy.models.PoissonInput

    • brainpy.models.DeltaSynapse

    • brainpy.models.ExpCUBA

    • brainpy.models.ExpCOBA

    • brainpy.models.AMPA

    • brainpy.models.GABAa

  • Name cache cleaning via brainpy.clear_name_cache

  • add safe in-place operations for JaxArray via the update() method and .value assignment
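
Among the new features above, the autograd API brainpy.math.vector_grad() is the easiest to illustrate; the sketch below assumes the element-wise gradient semantics suggested by its name.

import brainpy.math as bm

def f(x):
    return bm.sin(x) ** 2       # element-wise function of a vector

grad_f = bm.vector_grad(f)      # gradient with respect to the vector argument
x = bm.arange(3.)
print(grad_f(x))                # 2 * sin(x) * cos(x), evaluated element-wise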

Documentation#
  • Complete tutorials for quickstart

  • Complete tutorials for dynamics building

  • Complete tutorials for dynamics simulation

  • Complete tutorials for dynamics training

  • Complete tutorials for dynamics analysis

  • Complete tutorials for API documentation

brainpy 1.1.x (LTS)#

If you are using brainpy==1.x, documentation, examples, and models for that release series are available separately.

Version 1.1.7 (2021.12.13)#

  • fix bugs on numpy_array() conversion in brainpy.math.utils module

Version 1.1.5 (2021.11.17)#

API changes:

  • fix bugs on ndarray import in brainpy.base.function.py

  • add a convenient ‘get_param’ interface in brainpy.simulation.layers

  • add more weight initialization methods

Doc changes:

  • add more examples in README

Version 1.1.4#

API changes:

  • add .struct_run() in DynamicalSystem

  • add numpy_array() conversion in brainpy.math.utils module

  • add Adagrad, Adadelta, RMSProp optimizers

  • remove setting methods in brainpy.math.jax module

  • remove import jax from brainpy.__init__.py and enable JAX settings, including

    • enable_x64()

    • set_platform()

    • set_host_device_count()

  • allow b=None to specify no bias in brainpy.simulation.layers

  • set int_ and float_ to 32 bits by default

  • remove dtype setting in Initializer constructor

Doc changes:

  • add optimizer in “Math Foundation”

  • add dynamics training docs

  • improve others

Version 1.1.3#

  • fix bugs of JAX parallel API imports

  • fix bugs of post_slice structure construction

  • update docs

Version 1.1.2#

  • add pre2syn and syn2post operators

  • add verbose and check option to Base.load_states()

  • fix bugs on JIT DynamicalSystem (numpy backend)

Version 1.1.1#

  • fix bugs on symbolic analysis: model trajectory

  • change absolute access to relative access in variable saving and loading

  • add UnexpectedTracerError hints in JAX transformation functions

Version 1.1.0 (2021.11.08)#

This package releases a new version of BrainPy.

Highlights of core changes:

math module#
  • support numpy backend

  • support JAX backend

  • support jit, vmap and pmap on class objects on JAX backend (see the sketch after this list)

  • support grad, jacobian, hessian on class objects on JAX backend

  • support make_loop, make_while, and make_cond on JAX backend

  • support jit (based on numba) on class objects on numpy backend

  • unified numpy-like ndarray operation APIs

  • numpy-like random sampling APIs

  • FFT functions

  • gradient descent optimizers

  • activation functions

  • loss functions

  • backend settings
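
As a small illustration of the jit support described in this list: jitting a plain function on the JAX backend is shown below; class objects are passed to bm.jit in the same way.

import brainpy.math as bm

@bm.jit
def update(v, inp):
    # one Euler step of a leaky integrator, compiled on the JAX backend
    return v + (-v + inp) * 0.1

print(update(bm.zeros(5), bm.ones(5)))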

base module#
  • Base class for the whole BrainPy ecosystem

  • Function to wrap functions

  • Collector and TensorCollector to collect variables, integrators, nodes and others

integrators module#
  • class integrators for ODE numerical methods

  • class integrators for SDE numerical methods

simulation module#
  • support modular and composable programming

  • support multi-scale modeling

  • support large-scale modeling

  • support simulation on GPUs

  • fix bugs on firing_rate()

  • remove _i in the update() function and replace it with _dt, meaning the dynamical system has the canonical equation form of \(dx/dt = f(x, t, dt)\)

  • reimplement the input_step and monitor_step in a more intuitive way

  • support setting dt at the single-object level (i.e., for a single instance of DynamicalSystem)

  • commonly used DNN layers

  • weight initializations

  • refine synaptic connections

brainpy 1.0.x#

Version 1.0.3 (2021.08.18)#

Fix bugs on

  • firing rate measurement

  • stability analysis

Version 1.0.2#

This release continues to improve the user-friendliness.

Highlights of core changes:

  • Remove support for Numba-CUDA backend

  • Super initialization super(XXX, self).__init__() can be called anywhere (it is no longer required to be at the bottom of the __init__() function).

  • Add an informative error message when a step function fails to run.

  • More powerful support for Monitoring

  • More powerful support for running order scheduling

  • Remove unsqueeze() and squeeze() operations in brainpy.ops

  • Add reshape() operation in brainpy.ops

  • Improve docs for numerical solvers

  • Improve tests for numerical solvers

  • Add keywords checking in ODE numerical solvers

  • Add more unified operations in brainpy.ops

  • Support “@every” in steps and monitor functions

  • Fix ODE solver bugs for class-bound functions

  • Add build phase in Monitor

Version 1.0.1#

  • Fix bugs

Version 1.0.0#

  • NEW VERSION OF BRAINPY

  • Change the coding style to object-oriented programming

  • Systematically improve the documentation

brainpy 0.x#

Version 0.3.5#

  • Add ‘timeout’ in sympy solver in neuron dynamics analysis

  • Reconstruct and generalize phase plane analysis

  • Generalize the repeat mode of Network to support different running durations between two runs

  • Update benchmarks

  • Update detailed documentation

Version 0.3.1#

  • Add a more flexible way for NeuState/SynState initialization

  • Fix bugs of “is_multi_return”

  • Add “hand_overs”, “requires” and “satisfies”.

  • Update documentation

  • Auto-transform range to numba.prange

  • Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models

Version 0.3.0#

Computation API#
  • Rename “brainpy.numpy” to “brainpy.backend”

  • Delete “pytorch”, “tensorflow” backends

  • Add “numba” requirement

  • Add GPU support

Profile setting#
  • Delete “backend” profile setting, add “jit”

Core systems#
  • Delete the “autopep8” requirement

  • Delete the format code prefix

  • Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”

  • Move the “ST” declaration out of “requires”

  • Add “repeat” mode run in Network

  • Change “vector-based” to “mode” in NeuType and SynType definition

Package installation#
  • Remove “pypi” installation; installation now relies only on “conda”

Version 0.2.4#

API changes#
  • Fix bugs

Version 0.2.3#

API changes#
  • Add “animate_1D” in visualization module

  • Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in inputs module

  • Update phase_portrait_analyzer.py

Models and examples#
  • Add CANN examples

Version 0.2.2#

API changes#
  • Redesign visualization

  • Redesign connectivity

  • Update docs

Version 0.2.1#

API changes#
  • Fix bugs in numba import

  • Fix bugs in numpy mode with scalar model

Version 0.2.0#

API changes#
  • For computation: numpy, numba

  • For model definition: NeuType, SynConn

  • For model running: Network, NeuGroup, SynConn, Runner

  • For numerical integration: integrate, Integrator, DiffEquation

  • For connectivity: One2One, All2All, GridFour, grid_four, GridEight, grid_eight, GridN, FixedPostNum, FixedPreNum, FixedProb, GaussianProb, GaussianWeight, DOG

  • For visualization: plot_value, plot_potential, plot_raster, animation_potential

  • For measurement: cross_correlation, voltage_fluctuation, raster_plot, firing_rate

  • For inputs: constant_current, spike_current, ramp_current.

Models and examples#
  • Neuron models: HH model, LIF model, Izhikevich model

  • Synapse models: AMPA, GABA, NMDA, STP, GapJunction

  • Network models: gamma oscillation

Release notes (brainpylib)#

Version 0.0.5#

  • Support operator customization on GPU by numba

Version 0.0.4#

  • Support operator customization on CPU by numba

Version 0.0.3#

  • Support event_sum() operator on GPU

  • Support event_prod() operator on CPU

  • Support atomic_sum() operator on GPU

  • Support atomic_prod() operator on CPU and GPU

Version 0.0.2#

  • Support event_sum() operator on CPU

  • Support event_sum2() operator on CPU

  • Support atomic_sum() operator on CPU
