BrainPy documentation#

BrainPy is a highly flexible and extensible framework targeting general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, BrainPy supports:

  • JIT compilation and automatic differentiation for class objects.

  • Numerical methods for ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), fractional differential equations (FDEs), etc.

  • Dynamics building with the modular and composable programming interface.

  • Dynamics simulation for various brain objects with parallel support.

  • Dynamics training with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc.

  • Dynamics analysis for low- and high-dimensional systems, including phase plane analysis, bifurcation analysis, linearization analysis, and fixed/slow point finding.

  • And many more …

Installation#

BrainPy is designed to run across platforms, including Windows, GNU/Linux, and macOS. It relies only on Python libraries.

Installation with pip#

You can install BrainPy from PyPI. To do so, use:

pip install brainpy

To update the BrainPy version, you can use

pip install -U brainpy

If you want to install the pre-release version (the latest development version) of BrainPy, you can use:

pip install --pre brainpy

Installation from source#

If you decide not to use the release on PyPI, you can install BrainPy from its source on GitHub or OpenI.

To do so, use:

pip install git+https://github.com/PKU-NIP-Lab/BrainPy

# or

pip install git+https://git.openi.org.cn/OpenI/BrainPy

Dependency 1: NumPy#

To make BrainPy work properly, users should install several dependent Python packages.

The basic functionality of BrainPy relies only on NumPy, which is easy to install through pip or conda:

pip install numpy

# or

conda install numpy

Dependency 2: JAX#

BrainPy relies on JAX. JAX is a high-performance JIT compiler which enables users to run Python code on CPU, GPU, and TPU devices. Core functionalities of BrainPy (>=2.0.0) have been migrated to the JAX backend.

Linux & MacOS#

Currently, JAX supports Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. Official binary releases of jax and jaxlib for Linux and macOS are provided on the JAX release pages.

If you want to install a CPU-only version of jax and jaxlib, you can run

pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

If you want to install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN, if they have not already been installed. Next, run

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Alternatively, you can download the preferred release “.whl” file for jaxlib from the above release links, and install it via pip:

pip install xxx-0.3.14-xxx.whl

pip install jax==0.3.14

Note

Note that the versions of jaxlib and jax should be consistent.

For example, if you are using jaxlib==0.3.14, you should also install jax==0.3.14.

Windows#

For Windows users, jax and jaxlib can be installed from community-supported wheels. Specifically, you can install them through:

pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html

If you are using a GPU, you can install GPU-enabled wheels through:

pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html

Alternatively, you can manually install your favourite version of jax and jaxlib by downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html. Then install them via pip:

pip install xxx-0.3.14-xxx.whl

pip install jax==0.3.14

WSL#

Moreover, for Windows 10 or later, we recommend using the Windows Subsystem for Linux (WSL). The installation guide can be found in the WSL Installation Guide for Windows 10/11. Then, you can install JAX in WSL just as in the Linux/macOS installation steps above.

Dependency 3: brainpylib#

Many customized operators in BrainPy are implemented in brainpylib. brainpylib can also be installed from PyPI:

pip install brainpylib

For GPU operators, you should compile brainpylib from source. For details, please see Compile GPU operators in brainpylib.

Other Dependencies#

To get full support of BrainPy, we recommend installing the following packages:

  • Numba: needed in some NumPy-based computations

pip install numba

# or

conda install numba

  • matplotlib: required in some visualization functions, but now it is recommended that users explicitly import matplotlib for visualization

pip install matplotlib

# or

conda install matplotlib

Simulating a Brain Dynamics Model#

@Xiaoyu Chen @Chaoming Wang

One of the most important approaches to studying brain dynamics is building a dynamic model and running simulations. Generally, there are two ways to construct a dynamic model. The first is spiking models, which attempt to finely simulate the activity of each neuron in the target population. They are named spiking models because the simulation process records the precise spike timing of every neuron. The second is rate models, which regard a population of neurons with similar properties as a single firing unit and examine the firing rate of this population. In this section, we will illustrate how to build and simulate a spiking neural network (SNN).

To begin with, the BrainPy package should be imported:

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.1'

Simulating an E-I balanced network#

Building an E-I balanced network#

First, let's try to build an E-I balanced network. It was proposed to explain the irregular firing of neurons in cortical areas [1]. Since the structure of an E-I balanced network is relatively simple, it is a good exercise that helps users learn the basic paradigm of brain dynamics simulation in BrainPy. The structure of an E-I balanced network is as follows:

The structure of an E-I balanced network

An E-I balanced network is composed of two neuron groups and the synaptic connections between them. Specifically, they include:

  1. a group of excitatory neurons, \(\mathrm{E}\),

  2. a group of inhibitory neurons, \(\mathrm{I}\),

  3. synaptic connections within the excitatory and inhibitory neuron groups, respectively, and

  4. the inter-connections between these two groups.

To construct the network, we need to define these components one by one. BrainPy provides plenty of handy built-in models for brain dynamics simulation. They are contained in brainpy.dyn. Let's choose the simplest yet most canonical neuron model, the Leaky Integrate-and-Fire (LIF) model, to build the excitatory and inhibitory neuron groups:

E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60.,
                   tau=20., tau_ref=5., method='exp_auto',
                   V_initializer=bp.init.Normal(-60., 2.))

I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60.,
                   tau=20., tau_ref=5., method='exp_auto',
                   V_initializer=bp.init.Normal(-60., 2.))

When defining the LIF neuron group, the parameters can be tuned according to users' needs. The first parameter denotes the number of neurons. Here the ratio of excitatory to inhibitory neurons is set to 4:1. V_rest denotes the resting potential, V_th denotes the firing threshold, V_reset denotes the reset value after firing, tau is the membrane time constant, and tau_ref is the duration of the refractory period. method refers to the numerical integration method to be used in the simulation.
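
These settings can be checked after construction, since the parameters are accessible by name (the same pattern is used later for the ExpIF model):

E.num, E.V_rest, E.tau  # -> (3200, -60.0, 20.0)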

Then the synaptic connections between these two groups can be defined as follows:

E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6,
                              tau=5., output=bp.synouts.COBA(E=0.),
                              method='exp_auto')

E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6,
                              tau=5., output=bp.synouts.COBA(E=0.),
                              method='exp_auto')

I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7,
                              tau=10., output=bp.synouts.COBA(E=-80.),
                              method='exp_auto')

I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7,
                              tau=10., output=bp.synouts.COBA(E=-80.),
                              method='exp_auto')

Here we use the Exponential synapse model (bp.synapses.Exponential) to simulate synaptic connections. Among the parameters of the model, the first two denote the pre- and post-synaptic neuron groups, respectively. The third one refers to the connection type. In this example, we use bp.conn.FixedProb, which connects the presynaptic neurons to postsynaptic neurons with a given probability (detailed information is available in Synaptic Connection). The following three parameters describe the dynamic properties of the synapse, and the last one is the numerical integration method, the same as in the LIF model.
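
For reference, the exponential synapse is commonly written as follows (a sketch of the standard form rather than BrainPy's exact implementation): a conductance-like variable \(g\) decays exponentially with time constant \(\tau\) and jumps by \(g_\mathrm{max}\) on each presynaptic spike, and the conductance-based (COBA) output converts it into a current through the reversal potential \(E\):

\[\begin{split} \tau \frac{\mathrm{d}g}{\mathrm{d}t} = -g, \qquad g \leftarrow g + g_\mathrm{max} \,\, \text{on a presynaptic spike}, \\ I_\mathrm{syn} = g\,(E - V_\mathrm{post}). \end{split}\]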

After defining all the components, they can be combined to form a network:

net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I)

In the definition, neurons and synapses are given to the network. The excitatory and inhibitory neuron groups (E and I) are passed with names, because they will be specifically accessed in the simulation (here they will receive input currents).
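
As a quick sanity check (a minimal sketch assuming the construction above), the named components become attributes of the network and can be referenced by those names:

net.E.num, net.I.num  # -> (3200, 800)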

We have successfully constructed an E-I balanced network by using BrainPy's built-in models. On the other hand, BrainPy also enables users to flexibly customize their own dynamic models, such as neuron groups, synapses, and networks. In fact, brainpy.dyn.Network() is a simple example of customizing a network model. Please refer to Dynamic Simulation for more information.

Running a simulation#

After building an SNN, we can use it for dynamic simulation. To run a simulation, we need to wrap the network model into a runner first. BrainPy provides DSRunner for this purpose, which will be described in detail in the Runners tutorial. Users can initialize DSRunner as follows:

runner = bp.DSRunner(net,
                     monitors=['E.spike', 'I.spike'],
                     inputs=[('E.input', 20.), ('I.input', 20.)],
                     dt=0.1)

To make dynamic simulation more flexible and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the spike variable in the E and I LIF models, which refers to the spiking status of the neuron group, and give a constant input to both neuron groups. The time interval of numerical integration, dt (with a default value of 0.1 ms), can also be specified.

For more details on how to specify inputs and monitors, please refer to Dynamic Simulation.

After creating the runner, we can run a simulation by calling the runner:

runner.run(100)

where the run function receives the simulation duration (usually in milliseconds) as its input. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to Just-In-Time Compilation for more details.

The simulation results are stored as NumPy arrays in the monitors, and can be visualized easily:

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4.5))

plt.subplot(121)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=False)
plt.subplot(122)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], show=True)
_images/2d47ee34e7c19207de179b26305f42a30adf055066484a565b2e95ef2cb5d743.png

In the code above, brainpy.visualize contains some useful functions to visualize simulation results based on the matplotlib package. Since the simulation results are stored as NumPy arrays, users can directly use matplotlib for visualization.
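
Because the monitored results are plain arrays, they can also be plotted directly with matplotlib; for example, a quick look at the population spike count of group E over time (a minimal sketch based on the monitors defined above):

import numpy as np

spike_count = np.asarray(runner.mon['E.spike']).sum(axis=1)  # number of E spikes at each time step
plt.plot(runner.mon.ts, spike_count)
plt.xlabel('Time (ms)')
plt.ylabel('Spike count in E')
plt.show()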

Simulating a decision-making network#

Building a decision-making network#

After learning how to build an E-I balanced network, we can try to handle a more complex model. In 2002, Wang proposed a decision-making model that could choose between two conflicting inputs by accumulating evidence over time [2].

The structure of a decision-making network is as follows. Similar to the E-I balanced network, the decision-making network contains an excitatory and an inhibitory neuron group, forming connections within each group and between each other. What is different is that there are two specific subpopulations of neurons, A and B, that receive conflicting inputs from outside (other brain areas). After the external inputs are given, if the activity of A prevails over that of B, the network chooses option A, and vice versa.

The structure of a decision-making network

To construct a decision-making network, we should build all neuron groups:

  1. Two excitatory neuron groups with different selectivity, \(\mathrm{A}\) and \(\mathrm{B}\), and other excitatory neurons, \(\mathrm{N}\);

  2. An inhibitory neuron group, \(\mathrm{I}\);

  3. Neurons generating external inputs \(\mathrm{I_A}\) and \(\mathrm{I_B}\);

  4. Neurons generating noise to all neuron groups, \(\mathrm{noise_A}\), \(\mathrm{noise_B}\), \(\mathrm{noise_N}\), and \(\mathrm{noise_I}\).

And the synaptic connections between them:

  1. Connection from excitatory neurons to others, \(\mathrm{A2A}\), \(\mathrm{A2B}\), \(\mathrm{A2N}\), \(\mathrm{A2I}\), \(\mathrm{B2A}\), \(\mathrm{B2B}\), \(\mathrm{B2N}\), \(\mathrm{B2I}\), and \(\mathrm{N2A}\), \(\mathrm{N2B}\), \(\mathrm{N2N}\), \(\mathrm{N2I}\);

  2. Connection from inhibitory neurons to others, \(\mathrm{I2A}\), \(\mathrm{I2B}\), \(\mathrm{I2N}\), \(\mathrm{I2I}\);

  3. Connection from external inputs to selective neuron groups, \(\mathrm{IA2A}\), \(\mathrm{IB2B}\);

  4. Connection from noise neurons to excitatory and inhibitory neurons, \(\mathrm{noise2A}\), \(\mathrm{noise2B}\), \(\mathrm{noise2N}\), \(\mathrm{noise2I}\).

Now let’s build these neuron groups and connections.

First of all, to imitate the biophysical experiments, we define three periods:

pre_stimulus_period = 100.  # time before the external stimuli are given
stimulus_period = 1000.  # time within which the external stimuli are given
delay_period = 500.  # time after the external stimuli are removed
total_period = pre_stimulus_period + stimulus_period + delay_period

To build \(\mathrm{I_A}\) and \(\mathrm{I_B}\), we shall define a class of neuron groups that can generate stochastic Poisson stimuli. To define neuron groups, the class should inherit from brainpy.dyn.NeuGroup.

class PoissonStim(bp.NeuGroup):
  def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs):
    super(PoissonStim, self).__init__(size=size, **kwargs)

    # initialize parameters
    self.freq_mean = freq_mean
    self.freq_var = freq_var
    self.t_interval = t_interval

    # initialize variables
    self.freq = bm.Variable(bm.zeros(1))
    self.freq_t_last_change = bm.Variable(bm.ones(1) * -1e7)
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.rng = bm.random.RandomState()

  def update(self, tdi):
    in_interval = bm.logical_and(pre_stimulus_period < tdi.t, tdi.t < pre_stimulus_period + stimulus_period)
    freq = bm.where(in_interval, self.freq[0], 0.)
    change = bm.logical_and(in_interval, (tdi.t - self.freq_t_last_change[0]) >= self.t_interval)
    self.freq[:] = bm.where(change, self.rng.normal(self.freq_mean, self.freq_var), freq)
    self.freq_t_last_change[:] = bm.where(change, tdi.t, self.freq_t_last_change[0])
    self.spike.value = self.rng.random(self.num) < self.freq[0] * tdi.dt / 1000.

Because there are many neuron groups and connections, it will be much clearer if we define a new network class that inherits from brainpy.dyn.Network to accommodate all these neurons and synapses:

class DecisionMaking(bp.Network):
  def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, dt=bm.get_dt()):
    super(DecisionMaking, self).__init__()

    # initialize neuron-group parameters
    num_exc = int(1600 * scale)
    num_inh = int(400 * scale)
    num_A = int(f * num_exc)
    num_B = int(f * num_exc)
    num_N = num_exc - num_A - num_B
    poisson_freq = 2400.  # Hz

    # initialize synapse parameters
    w_pos = 1.7
    w_neg = 1. - f * (w_pos - 1.) / (1. - f)
    g_ext2E_AMPA = 2.1  # nS
    g_ext2I_AMPA = 1.62  # nS
    g_E2E_AMPA = 0.05 / scale  # nS
    g_E2I_AMPA = 0.04 / scale  # nS
    g_E2E_NMDA = 0.165 / scale  # nS
    g_E2I_NMDA = 0.13 / scale  # nS
    g_I2E_GABAa = 1.3 / scale  # nS
    g_I2I_GABAa = 1.0 / scale  # nS

    # parameters of the AMPA synapse
    ampa_par = dict(delay_step=int(0.5 / dt), tau=2.0, output=bp.synouts.COBA(E=0.))

    # parameters of the GABA synapse
    gaba_par = dict(delay_step=int(0.5 / dt), tau=5.0, output=bp.synouts.COBA(E=-70.))

    # parameters of the NMDA synapse
    nmda_par = dict(delay_step=int(0.5 / dt), tau_decay=100, tau_rise=2.,
                    a=0.5, output=bp.synouts.MgBlock(E=0., cc_Mg=1.))

    # excitatory and inhibitory neuron groups, A, B, N, and I
    A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                       tau_ref=2., V_initializer=bp.init.OneInit(-70.))
    B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                       tau_ref=2., V_initializer=bp.init.OneInit(-70.))
    N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                       tau_ref=2., V_initializer=bp.init.OneInit(-70.))
    I = bp.neurons.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05,
                       tau_ref=1., V_initializer=bp.init.OneInit(-70.))

    # neurons generating external inputs, I_A and I_B
    IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence)
    IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence)

    # noise neurons
    self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq)
    self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq)
    self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq)
    self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq)

    # connection from excitatory neurons to others
    self.N2B_AMPA = bp.synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
    self.N2A_AMPA = bp.synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
    self.N2N_AMPA = bp.synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
    self.N2I_AMPA = bp.synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
    self.N2B_NMDA = bp.synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
    self.N2A_NMDA = bp.synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
    self.N2N_NMDA = bp.synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
    self.N2I_NMDA = bp.synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)

    self.B2B_AMPA = bp.synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par)
    self.B2A_AMPA = bp.synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
    self.B2N_AMPA = bp.synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
    self.B2I_AMPA = bp.synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
    self.B2B_NMDA = bp.synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par)
    self.B2A_NMDA = bp.synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
    self.B2N_NMDA = bp.synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
    self.B2I_NMDA = bp.synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)

    self.A2B_AMPA = bp.synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par)
    self.A2A_AMPA = bp.synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par)
    self.A2N_AMPA = bp.synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par)
    self.A2I_AMPA = bp.synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par)
    self.A2B_NMDA = bp.synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par)
    self.A2A_NMDA = bp.synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par)
    self.A2N_NMDA = bp.synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par)
    self.A2I_NMDA = bp.synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par)

    # connection from inhibitory neurons to others
    self.I2B = bp.synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
    self.I2A = bp.synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
    self.I2N = bp.synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par)
    self.I2I = bp.synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, **gaba_par)

    # connection from external inputs to selective neuron groups
    self.IA2A = bp.synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
    self.IB2B = bp.synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)

    # connection from noise neurons to excitatory and inhibitory neurons
    self.noise2B = bp.synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
    self.noise2A = bp.synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
    self.noise2N = bp.synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par)
    self.noise2I = bp.synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, **ampa_par)

    # add A, B, I, N to the class
    self.A = A
    self.B = B
    self.N = N
    self.I = I
    self.IA = IA
    self.IB = IB

Though the code is longer than that of the E-I balanced network, the basic building paradigm is the same: building neuron groups and the connections among them.

Running a simulation#

After building it, the simulation process is much the same as running an E-I balanced network. First, we should wrap the network into a runner:

net = DecisionMaking(scale=1., coherence=25.6, mu0=40.)
runner = bp.DSRunner(net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'])

Then we call the runner to run the simulation:

runner.run(total_period)

Finally, we visualize the simulation result by using matplotlib:

fig, gs = plt.subplots(4, 1, figsize=(10, 12), sharex='all')
t_start = 0.

# the raster plot of A
fig.add_subplot(gs[0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['A.spike'], markersize=1)
plt.title("Spiking activity of group A")
plt.ylabel("Neuron Index")

# the raster plot of B
fig.add_subplot(gs[1])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['B.spike'], markersize=1)
plt.title("Spiking activity of group B")
plt.ylabel("Neuron Index")

# the firing rate of A and B
fig.add_subplot(gs[2])
rateA = bp.measure.firing_rate(runner.mon['A.spike'], width=10.)
rateB = bp.measure.firing_rate(runner.mon['B.spike'], width=10.)
plt.plot(runner.mon.ts, rateA, label="Group A")
plt.plot(runner.mon.ts, rateB, label="Group B")
plt.ylabel('Firing rate [Hz]')
plt.title("Population activity")
plt.legend()

# the external stimuli
fig.add_subplot(gs[3])
plt.plot(runner.mon.ts, runner.mon['IA.freq'], label="group A")
plt.plot(runner.mon.ts, runner.mon['IB.freq'], label="group B")
plt.title("Input activity")
plt.ylabel("Firing rate [Hz]")
plt.legend()

for i in range(4):
  gs[i].axvline(pre_stimulus_period, linestyle='dashed', color=u'#444444')
  gs[i].axvline(pre_stimulus_period + stimulus_period, linestyle='dashed', color=u'#444444')

plt.xlim(t_start, total_period + 1)
plt.xlabel("Time [ms]")
plt.tight_layout()
plt.show()
_images/6507ca74eeab7f2a4a9833d80dfbec8921054f6c416a6aba6af4384da1cc94c1.png

For more information about brain dynamic simulation, please refer to Dynamics Simulation in the BDP tutorial.

Simulating a firing rate-based network#

Neural mass model#

A neural mass model is a low-dimensional population model of spiking neural networks. It aims to describe the coarse-grained activity of large populations of neurons and synapses. Mathematically, it is a dynamical system of non-linear ODEs. A classical neural mass model is the two-dimensional Wilson–Cowan model, which tracks the activity of an excitatory population of neurons coupled to an inhibitory population. Augmented with more realistic forms of synaptic and network interaction, such models have proved especially successful in fitting neuro-imaging data.
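
For reference, a classical form of the Wilson–Cowan equations reads (generic notation; the subscript conventions here may differ from the parameter names wEE, wIE, wEI, and wII used by BrainPy's WilsonCowanModel):

\[\begin{split} \tau_E \dot{E} = -E + (1 - r_E E)\, S_E\left(w_{EE} E - w_{EI} I + I_E\right), \\ \tau_I \dot{I} = -I + (1 - r_I I)\, S_I\left(w_{IE} E - w_{II} I + I_I\right), \end{split}\]

where \(S_E\) and \(S_I\) are sigmoidal gain functions and \(I_E\), \(I_I\) are external inputs.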

Here, let’s try the Wilson-Cowan model.

wc = bp.rates.WilsonCowanModel(2,
                               wEE=16., wIE=15., wEI=12., wII=3.,
                               E_a=1.5, I_a=1.5, E_theta=3., I_theta=3.,
                               method='exp_euler_auto',
                               x_initializer=bm.asarray([-0.2, 1.]),
                               y_initializer=bm.asarray([0.0, 1.]))

runner = bp.DSRunner(wc, monitors=['x', 'y'], inputs=['input', -0.5])
runner.run(10.)

fig, gs = bp.visualize.get_figure(1, 2, 4, 3)
ax = fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, plot_ids=[0, 1], legend='e', ax=ax)
ax = fig.add_subplot(gs[0, 1])
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, plot_ids=[0, 1], legend='i', ax=ax, show=True)
_images/11dc8ff14165795fe187806be1f753231c3ff98305bfe1811ddeb22af58c4260.png

We can see that this model has at least two stable states.

Bifurcation diagram

With the automatic analysis module in BrainPy, we can easily inspect the bifurcation diagram of the model. Bifurcation diagrams give us an overview of how different parameters of the model affect its dynamics (for details of BrainPy's automatic analysis support, please see the introduction in Analyzing a Dynamical Model and the tutorials in Dynamics Analysis). In this case, we take x_ext as the bifurcation parameter and examine how the system's behavior changes with x_ext.

bf = bp.analysis.Bifurcation2D(
  wc,
  target_vars={'x': [-0.2, 1.], 'y': [-0.2, 1.]},
  target_pars={'x_ext': [-2, 2]},
  pars_update={'y_ext': 0.},
  resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 40000 candidates
I am trying to filter out duplicate fixed points ...
	Found 579 fixed points.
I am plotting the limit cycle ...
_images/3e1dad4ca5099b115d2fdc12e6ec5fffb01961f882612d90c0aa39cd52bedbef.png _images/b4f3ba0f2cd2eb3f7eb7c3f2c417a98d9d10a682ecd780c97866de989d600d87.png

Similarly, simulating and analyzing a rate-based FitzHugh–Nagumo model is also straightforward with BrainPy.

fhn = bp.rates.FHN(1, method='exp_auto')

bf = bp.analysis.Bifurcation2D(
  fhn,
  target_vars={'x': [-2, 2], 'y': [-2, 2]},
  target_pars={'x_ext': [0, 2]},
  pars_update={'y_ext': 0.},
  resolutions={'x_ext': 0.01}
)
bf.plot_bifurcation()
bf.plot_limit_cycle_by_sim(duration=500)
bf.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 20000 candidates
I am trying to filter out duplicate fixed points ...
	Found 200 fixed points.
I am plotting the limit cycle ...
_images/4f43e52a1a92363d7c7a5d9375cca6f4a3c27289f73d78694b803dcb104bee42.png _images/0c4ccdb2620bf72fcd77d51a8d0f55f94eddfc755f25513e0d70c508c9be2a41.png

In this model, we find that when the external input x_ext takes a value in [0.72, 1.4], the model generates limit cycles. We can verify this by simulation.

runner = bp.DSRunner(fhn, monitors=['x', 'y'], inputs=['input', 1.0])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x')
bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y', show=True)
_images/9250b8b1d816e5f503af8c1600ce9b87090dbec1b1eb7511f4ca995a35fc3535.png

Whole-brain model#

A rate-based whole-brain model is a network model which consists of coupled brain regions. Each brain region is represented by a neural mass model which is connected to other brain regions according to the underlying network structure of the brain, also known as the connectome. To illustrate how to use BrainPy's support for whole-brain modeling, we provide a processed dataset at the following link:

Please download the dataset and place it in your favorite PATH.

PATH = './data/hcp.npz'

In general, a dataset for whole-brain modeling consists of the following parts:

1. A structural connectivity matrix which captures the synaptic connection strengths between brain areas. It is often derived from DTI tractography of the whole brain. The connectome is then typically parcellated with a preferred atlas (for example, the AAL2 atlas), and the number of axonal fibers connecting each brain area with every other area is counted. This number serves as an indication of the synaptic coupling strength between the areas of the brain.

2. A delay matrix, which is calculated from the average length of the axonal fibers connecting each brain area with every other area.

3. A set of functional data that can act as a target for model optimization. Resting-state fMRI offers an easy and fairly unbiased way for calibrating whole-brain models. EEG data could be used as well.

Now, let’s load the dataset.

data = bm.load(PATH)
# The structural connectivity matrix

data['Cmat'].shape
(80, 80)
# The fiber length matrix

data['Dmat'].shape
(80, 80)
# The functional data for 7 subjects

data['FCs'].shape
(7, 80, 80)

Let's have a look at what the data looks like.

import matplotlib.pyplot as plt

plt.rcParams['image.cmap'] = 'plasma'

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.subplots_adjust(wspace=0.28)

im = axs[0].imshow(data['Cmat'])
axs[0].set_title("Connection matrix")
fig.colorbar(im, ax=axs[0], fraction=0.046, pad=0.04)
im = axs[1].imshow(data['Dmat'], cmap='inferno')
axs[1].set_title("Fiber length matrix")
fig.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)
im = axs[2].imshow(data['FCs'][0], cmap='inferno')
axs[2].set_title("Empirical FC of subject 1")
fig.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04)
plt.show()
_images/51d426717d6909aaddcb8415548ab8e527c7b0685beb687bf566fcc25ef3ba04.png

Let's first compute the delay matrix from the fiber length matrix, the signal transmission speed between areas, and the numerical integration step dt. Here, we assume the axonal transmission speed is 20 and the simulation time step is dt=0.1 ms.

signal_speed = 20.

# the number of the delay steps
delay_mat = data['Dmat'] / signal_speed / bm.get_dt()
delay_mat = bm.asarray(delay_mat, dtype=bm.int_)
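
As a quick check of the conversion above (the fiber length of 100 is an illustrative value):

int(100. / 20. / 0.1)  # a fiber of length 100 at speed 20 with dt = 0.1 ms -> 50 delay steps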

The connectivity matrix can be obtained directly from the structural connectivity matrix multiplied by a global coupling strength parameter gc.

gc = 1.

conn_mat = bm.asarray(data['Cmat'] * gc)

# It is necessary to exclude the self-connections
bm.fill_diagonal(conn_mat, 0)

We are now ready to instantiate a whole-brain model with the neural mass model and the dataset processed before.

class WholeBrainNet(bp.Network):
  def __init__(self, Cmat, Dmat):
    super(WholeBrainNet, self).__init__()

    self.fhn = bp.rates.FHN(
      80,
      x_ou_sigma=0.01,
      y_ou_sigma=0.01,
      method='exp_auto'
    )
    self.syn = bp.synapses.DiffusiveCoupling(
      self.fhn.x,
      self.fhn.x,
      var_to_output=self.fhn.input,
      conn_mat=Cmat,
      delay_steps=Dmat.astype(bm.int_),
      initial_delay_data=bp.init.Uniform(0, 0.05)
    )

net = WholeBrainNet(conn_mat, delay_mat)

runner = bp.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)

The simulated results can be used to estimate the functional correlation matrix.

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
fc = bp.measure.functional_connectivity(runner.mon['fhn.x'])
ax = axs[0].imshow(fc)
plt.colorbar(ax, ax=axs[0])
axs[1].plot(runner.mon.ts, runner.mon['fhn.x'][:, ::5], alpha=0.8)
plt.tight_layout()
plt.show()
_images/9b757c334d4e79cf2991c0b9b1c919457ed2ab7de1e334b751a96c4f24bcc8a3.png

We can compute the element-wise Pearson correlation between the functional connectivity matrices of the simulated data and the empirical data to estimate how well the model captures the inter-areal functional correlations found in empirical resting-state recordings.

scores = [bp.measure.matrix_correlation(fc, fcemp)
          for fcemp in data['FCs']]
print("Correlation per subject:", [f"{s:.2}" for s in scores])
print("Mean FC/FC correlation: {:.2f}".format(bm.mean(bm.asarray(scores))))
Correlation per subject: ['0.61', '0.48', '0.56', '0.5', '0.56', '0.49', '0.44']
Mean FC/FC correlation: 0.52

References#

  1. van Vreeswijk, C., & Sompolinsky, H. (1996). Chaos in neuronal networks with balanced excitatory and inhibitory activity. Science (New York, N.Y.), 274(5293), 1724–1726. https://doi.org/10.1126/science.274.5293.1724

  2. Wang X. J. (2002). Probabilistic decision making by slow reverberation in cortical circuits. Neuron, 36(5), 955–968. https://doi.org/10.1016/s0896-6273(02)01092-9

Training a Brain Dynamics Model#

@Chaoming Wang

In recent years, training dynamical systems from data or tasks has provided important insights into understanding brain functions. To support this, BrainPy provides various interfaces to help users train dynamical systems.

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

bm.enable_x64()

# bm.set_platform('cpu')
bp.__version__
'2.3.1'
import matplotlib.pyplot as plt

Training a reservoir network model#

For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.

(Gauthier et al., Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model using nonlinear vector autoregression (NVAR).

The difference between the two models is illustrated in the following figure.

(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using linear weights of time-delayed states of the time series and nonlinear functionals of these states.
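
As a rough, self-contained illustration of the NVAR feature construction (this is not BrainPy's NVAR layer, and the helper name nvar_features is hypothetical), the feature vector concatenates the delayed states with their unique quadratic monomials:

import numpy as np
from itertools import combinations_with_replacement

def nvar_features(history):
    """history: a list of past state vectors, e.g. [P(t), P(t - s)]."""
    lin = np.concatenate(history)  # linear part: the delayed states themselves
    quad = [a * b for a, b in combinations_with_replacement(lin, 2)]  # unique quadratic monomials
    return np.concatenate([lin, np.asarray(quad)])

# two delayed 3-dimensional states -> 6 linear + 21 quadratic features
print(nvar_features([np.ones(3), np.arange(3.)]).shape)  # (27,)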

Here, let's implement a next-generation reservoir model to predict a chaotic time series, namely the Lorenz attractor. In particular, we expect the network to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the prediction horizon.

dt = 0.01
data = bd.chaos.LorenzEq(100, dt=dt)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.xs.flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.ys.flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.zs.flatten()))
plt.ylabel('z')
plt.show()
_images/9492b5db560f9b93d40950076a5228efc4ba5a01eab3db249b6e97f4843a9771.png

Let’s first create a function to get the data.

def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    return res.reshape((1, ) + res.shape)

To accomplish this task, we implement a next-generation reservoir model with 4 delayed history states (with a stride of 5) and their quadratic polynomial monomials, the same as in (Gauthier et al., Nature Communications, 2021).

class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
    self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bm.training_mode)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    # other arguments like "x" can be customized by users.
    return self.o(sha, self.r(sha, x))

with bm.environment(bm.batching_mode):
    model = NGRC(num_in=3, num_out=3)

Moreover, we use the ridge regression method to train the model.

trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
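
For reference, ridge regression fits the linear readout in closed form; with the collected reservoir states \(X\), the training targets \(Y\), and the regularization strength \(\alpha\) (the alpha argument above), the readout weights are

\[ W = \left(X^\top X + \alpha I\right)^{-1} X^\top Y. \]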

We warm up the network with the first 20 ms of data.

warmup_data = get_subset(data, 0, int(20/dt))

outs = trainer.predict(warmup_data)

# outputs should be an array with the shape of
# (num_batch, num_time, num_out)
outs.shape
(1, 2000, 3)

The training data is the time series from 20 ms to 80 ms. We want the network to forecast 1 time step ahead.

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)

_ = trainer.fit([x_train, y_train])

Then we test the trained network with the next 20 ms.

x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))

predictions = trainer.predict(x_test)

bp.losses.mean_squared_error(y_test, predictions)
Array(3.62923347e-09, dtype=float64)
def plot_difference(truths, predictions):
    truths = bm.as_numpy(truths)
    predictions = bm.as_numpy(predictions)

    plt.subplot(311)
    plt.plot(truths[0, :, 0], label='Ground Truth')
    plt.plot(predictions[0, :, 0], label='Prediction')
    plt.ylabel('x')
    plt.legend()
    plt.subplot(312)
    plt.plot(truths[0, :, 1], label='Ground Truth')
    plt.plot(predictions[0, :, 1], label='Prediction')
    plt.ylabel('y')
    plt.legend()
    plt.subplot(313)
    plt.plot(truths[0, :, 2], label='Ground Truth')
    plt.plot(predictions[0, :, 2], label='Prediction')
    plt.ylabel('z')
    plt.legend()
    plt.show()
plot_difference(y_test, predictions)
_images/fcda93c5b9e35bb608ab765d4d7eb1e389b1ea225522368cddcf0e36a84576a6.png

We can make the task harder by forecasting 10 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)
_images/526fec79a1ec16b2ff2f50abec6549d8cf0f1463020641ebd705c25379fcbd8b.png

Or forecast 100 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
_ = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)
_images/77c35c1c1fc5ae8100c0e0095d655d14f96e7eed34de5bba66956e75cffe8ead.png

As you can see, forecasting more time steps ahead makes learning more difficult.

Training an artificial recurrent network#

In recent years, artificial recurrent neural networks trained with back-propagation through time (BPTT) have been a useful tool to study the network mechanisms of brain functions. To support training networks with BPTT, BrainPy provides the brainpy.train.BPTT interface.

Here, we demonstrate how to train an artificial recurrent neural network using a white noise integration task. In this task, we want our trained RNN model to integrate white noise. For example, suppose we have a time series of noise data:

noises = bm.random.normal(0, 0.2, size=10)

plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()
_images/41a122747ba770dafc6053fd433deaf56f887f9b2366862b842b05719a14cab5.png

Now, we want to get a model that can produce the integral of the noise, bm.cumsum(noises) * dt:

dt = 0.1
integrals = bm.cumsum(noises) * dt

plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()
_images/65f74776cc84b21a1711917ca656960a1103e69fff64bdb3b5f15dd0776452f8.png

Here, we first define a task which generates the input data and the target integration results.

from functools import partial

dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128


@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def build_inputs_and_targets(mean=0.025, scale=0.01):
  # Create the white noise input
  sample = bm.random.normal(size=(num_batch, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  samples = bm.random.normal(size=(num_batch, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets


def train_data():
  for _ in range(100):
    yield build_inputs_and_targets()

Then, we create and initialize the model. Note that here we need the model to train its initial state, so we set train_state=True for the RNNCell instance used.

class RNN(bp.DynamicalSystem):
  def __init__(self, num_in, num_hidden):
    super(RNN, self).__init__()
    self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
    self.out = bp.layers.Dense(num_hidden, 1)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    return self.out(sha, self.rnn(sha, x))


with bm.training_environment():
    model = RNN(1, 100)

The brainpy.train.BPTT trainer receives a loss function setting and an optimizer setting. The loss function can be selected from the brainpy.losses module, or it can be a callable function that receives (predictions, targets) arguments. The optimizer setting must be an instance of brainpy.optim.Optimizer.

Here we define a loss function that uses the mean squared error (MSE) to measure the error between the targets and the predictions. We also apply an L2 regularization.
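
Written out, the loss defined below is (with model parameters \(\theta\) and the regularization coefficient \(\lambda\) given by l2_reg):

\[ \mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \left(y_i - \hat{y}_i\right)^2 + \lambda \sum_{j} \theta_j^2 \]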

# define loss function
def loss(predictions, targets, l2_reg=2e-4):
    mse = bp.losses.mean_squared_error(predictions, targets)
    l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
    return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.train.BPTT(model,
                        loss_fun=loss,
                        optimizer=opt)
# train the model
trainer.fit(train_data, num_epoch=30)
Train 0 epoch, use 2.2705 s, loss 0.5431804598636681
Train 1 epoch, use 1.1520 s, loss 0.1111436345136279
Train 2 epoch, use 1.0568 s, loss 0.028502398640104478
Train 3 epoch, use 1.0550 s, loss 0.021557415636524625
Train 4 epoch, use 1.2334 s, loss 0.02103036084318231
Train 5 epoch, use 1.0782 s, loss 0.03621808481894387
Train 6 epoch, use 1.2978 s, loss 0.020830462560546617
Train 7 epoch, use 1.1253 s, loss 0.020349677236923852
Train 8 epoch, use 1.0884 s, loss 0.01999626884753028
Train 9 epoch, use 1.0844 s, loss 0.019711449539100128
Train 10 epoch, use 1.0470 s, loss 0.01948684809571936
Train 11 epoch, use 1.0619 s, loss 0.019206390127710953
Train 12 epoch, use 1.0470 s, loss 0.018981963159608338
Train 13 epoch, use 1.0411 s, loss 0.01878478481278808
Train 14 epoch, use 1.0648 s, loss 0.018614482841389626
Train 15 epoch, use 1.0385 s, loss 0.01834170918632372
Train 16 epoch, use 1.0266 s, loss 0.0329911761349229
Train 17 epoch, use 1.0184 s, loss 0.019153171678052538
Train 18 epoch, use 1.0218 s, loss 0.01789557062710536
Train 19 epoch, use 1.0252 s, loss 0.017628752447072186
Train 20 epoch, use 1.0711 s, loss 0.01743445520243582
Train 21 epoch, use 1.0872 s, loss 0.017260501677986928
Train 22 epoch, use 1.0393 s, loss 0.017055458142235507
Train 23 epoch, use 1.0689 s, loss 0.01688203255689325
Train 24 epoch, use 1.0568 s, loss 0.01670960989088571
Train 25 epoch, use 1.0435 s, loss 0.016556937911224538
Train 26 epoch, use 1.0464 s, loss 0.01639770153466529
Train 27 epoch, use 1.0527 s, loss 0.01624334924290911
Train 28 epoch, use 1.0409 s, loss 0.0160952063182731
Train 29 epoch, use 1.0563 s, loss 0.015950705589392912

The training losses can be retrieved by the .get_hist_metric() function.

plt.figure(figsize=(8, 3))
plt.plot(trainer.get_hist_metric(metric='loss'))
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()
_images/e44feecf4371f24c2aec5670f7a98b356c07c6d768c756070ed69a911c8ec76e.png

Finally, let’s try the trained network, and test whether it can generate the correct integration results.

model.reset_state(num_batch)
x, y = build_inputs_and_targets()
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()
_images/bd5a0683531f05c355190766fbd929259d1a521d9a23023107a72cb8effd8d4a.png

Training a spiking neural network#

BrainPy also supports training spiking neural networks.

In the following, we demonstrate how to use back-propagation algorithms to train spiking neurons with a simple example.

Our model is a simple three-layer model:

  • an input layer

  • a LIF layer

  • a readout layer

The synaptic connections between layers use the Exponential synapse model.

class SNN(bp.Network):
  def __init__(self, num_in, num_rec, num_out):
    super(SNN, self).__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out

    # neuron groups
    self.i = bp.neurons.InputGroup(num_in)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), tau=10.,
                                       output=bp.synouts.CUBA(target_var=None),
                                       g_max=bp.init.KaimingNormal(scale=20.))
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), tau=10.,
                                       output=bp.synouts.CUBA(target_var=None),
                                       g_max=bp.init.KaimingNormal(scale=20.))

    # whole model
    self.model = bp.Sequential(self.i, self.i2r, self.r, self.r2o, self.o)

  def update(self, tdi, spike):
    self.model(tdi, spike)
    return self.o.V.value

with bm.training_environment():
    net = SNN(100, 10, 2)  # our task is a two-label classification task

We use this simple task to classify random spiking data into two classes.

num_step = 2000
num_sample = 256
freq = 5  # Hz
mask = bm.random.rand(num_sample, num_step, net.num_in)
x_data = bm.zeros((num_sample, num_step, net.num_in))
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)

def get_data():
    for _ in range(1):
        yield x_data, y_data

As with training artificial recurrent neural networks, we use the Adam optimizer and a cross-entropy loss to train the model.

opt = bp.optim.Adam(lr=2e-3)

def loss(predicts, targets):
  return bp.losses.cross_entropy_loss(bm.max(predicts, axis=1), targets)


trainer = bp.train.BPTT(net,
                        loss_fun=loss,
                        optimizer=opt)
trainer.fit(train_data=get_data,
            num_report=10,
            num_epoch=200)
Train 10 steps, use 0.9103 s, loss 0.7219832174729163
Train 20 steps, use 0.7449 s, loss 0.6696123267371417
Train 30 steps, use 0.6464 s, loss 0.6491206328569219
Train 40 steps, use 0.5883 s, loss 0.6156547140069775
Train 50 steps, use 0.5729 s, loss 0.5987596785982736
Train 60 steps, use 0.6343 s, loss 0.5862205241316523
Train 70 steps, use 0.7598 s, loss 0.5607236263572535
Train 80 steps, use 0.6077 s, loss 0.5457860326436039
Train 90 steps, use 0.5707 s, loss 0.5264014105800172
Train 100 steps, use 0.5862 s, loss 0.5146514133005329
Train 110 steps, use 0.5722 s, loss 0.5068201255745326
Train 120 steps, use 0.5812 s, loss 0.4896138875231886
Train 130 steps, use 0.5887 s, loss 0.4799118492626251
Train 140 steps, use 0.5656 s, loss 0.47088261120558417
Train 150 steps, use 0.5742 s, loss 0.44685486258925866
Train 160 steps, use 0.5626 s, loss 0.43019316163725896
Train 170 steps, use 0.5736 s, loss 0.4131096257548337
Train 180 steps, use 0.5706 s, loss 0.4043105738199416
Train 190 steps, use 0.5722 s, loss 0.37480466053211214
Train 200 steps, use 0.5926 s, loss 0.36532990528050513

The training loss decreases continuously, demonstrating that the network is being trained effectively.

# visualize the training losses
plt.plot(trainer.get_hist_metric())
plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.show()
_images/2df565cb25d2a6eaddcf06bbd9b6d762195f9f8e280ae0f78471573871f83a8c.png

Let’s visualize the trained spiking neurons.

import numpy as np
from matplotlib.gridspec import GridSpec

def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
  plt.figure(figsize=(15, 8))
  gs = GridSpec(*dim)
  mem = 1. * mem
  if spk is not None:
    mem[spk > 0.0] = spike_height
  mem = bm.as_numpy(mem)
  for i in range(np.prod(dim)):
    if i == 0:
      a0 = ax = plt.subplot(gs[i])
    else:
      ax = plt.subplot(gs[i], sharey=a0)
    ax.plot(mem[i])
  plt.tight_layout()
  plt.show()
# get the prediction results and neural activity

runner = bp.DSRunner(
    net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}
)
out = runner.run(inputs=x_data, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
_images/dadaef4032679dfafecabcb2b238809313fbc6024b6b97540f6ec2db5738ae76.png
# the prediction accuracy

m = bm.max(out, axis=1)  # max over time
am = bm.argmax(m, axis=1)  # argmax over output units
acc = bm.mean(y_data == am)  # compare to labels
print("Accuracy %.3f" % acc)
Accuracy 0.910

Analyzing a Brain Dynamics Model#

@Xiaoyu Chen @Chaoming Wang

In BrainPy, defined models can not only be used for simulation but are also capable of automatic dynamics analysis.

BrainPy provides rich interfaces to support analysis, including

Here we will introduce three brief examples of 1-D bifurcation analysis, 2-D phase plane analysis, and fixed/slow point analysis. For more details and more examples, please refer to the tutorials of dynamics analysis.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bm.enable_x64()  # it's better to use x64 computation
bp.__version__
'2.3.1'

Bifurcation analysis of a 1D model#

Here, we demonstrate how to perform a bifurcation analysis through a one-dimensional neuron model.

Let’s try to analyze how the external input influences the dynamics of the Exponential Integrate-and-Fire (ExpIF) model. The ExpIF model is a one-variable neuron model whose dynamics is defined by:

\[\begin{split} \tau {\dot {V}}= - (V - V_\mathrm{rest}) + \Delta_T \exp(\frac{V - V_T}{\Delta_T}) + RI \\ \mathrm{if}\, \, V > \theta, \quad V \gets V_\mathrm{reset} \end{split}\]

We can analyze the change of \({\dot {V}}\) with respect to \(V\). First, let’s generate an ExpIF model using pre-defined modules in brainpy.dyn:

expif = bp.neurons.ExpIF(1, delta_T=1.)

The default values of other parameters can be accessed directly by their names:

expif.V_rest, expif.V_T, expif.R, expif.tau
(-65.0, -59.9, 1.0, 10.0)

After defining the model, we can use it for bifurcation analysis.

bif = bp.analysis.Bifurcation1D(
    model=expif,
    target_vars={'V': [-70., -55.]},
    target_pars={'I_ext': [0., 6.]},
    resolutions={'I_ext': 0.01}
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...
_images/4146837a7dbf1e6c830f9ad01ec1ad60d8913b4888ef73c737f2d03cdb127097.png

In the Bifurcation1D analyzer, model refers to the model to be analyzed (essentially the analyzer will access the derivative function in the model), target_vars denotes the target variables, target_pars denotes the changing parameters, and resolutions determines the resolution of the analysis.

In the image above, there are two lines that “merge” together to form a bifurcation. The dots making up the lines refer to the fixed points of \(\mathrm{d}V/\mathrm{d}t\). On the left of the bifurcation point (where two lines merge together), there are two fixed points where \(\mathrm{d}V/\mathrm{d}t = 0\) given each external input \(I_\mathrm{ext}\). One of them is a stable point, and the other is an unstable one. When \(I_\mathrm{ext}\) increases, the two fixed points move closer to each other, overlap, and finally disappear.

Bifurcation analysis provides insight into the dynamics of the model, for it indicates the number of stable states and how they change with respect to different parameters.

Phase plane analysis of a 2D model#

Besides bifurcation analysis, another important tool is phase plane analysis, which displays the trajectory of the state point in the vector field. Let's take the FitzHugh–Nagumo (FHN) neuron model as an example. The dynamics of the FHN model is given by:

\[\begin{split} {\dot {v}}=v-{\frac {v^{3}}{3}}-w+I, \\ \tau {\dot {w}}=v+a-bw. \end{split}\]

Users can easily define an FHN model, which is also provided by BrainPy:

fhn = bp.neurons.FHN(1)

Because there are two variables, \(v\) and \(w\), in the FHN model, we shall use 2-D phase plane analysis to visualize how these two variables change over time.

analyzer = bp.analysis.PhasePlane2D(
  model=fhn,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  pars_update={'I_ext': 0.8}, 
  resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)
analyzer.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 866 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 V=-0.2729223248464073, w=0.5338542697673022 is a unstable node.
I am plotting the trajectory ...
_images/283a71cfd04fd008d52941e82930b0e32d25f425db891186f2b15212bf46d457.png

In the PhasePlane2D analyzer, the parameters model, target_vars, and resolutions are the same as those in Bifurcation1D. pars_update specifies the parameters to be updated during analysis. After defining the analyzer, users can visualize the nullclines, the vector field, the fixed points, and a trajectory in one image. The phase plane gives users an intuitive interpretation of the changes of \(v\) and \(w\) guided by the vector field (violet arrows).

Slow point analysis of a high-dimensional system#

BrainPy is also capable of performing fixed/slow point analysis of high-dimensional systems. Moreover, it can perform automatic linearization analysis around the fixed point.

In the following, we use a gap junction coupled FitzHugh–Nagumo (FHN) network as an example to demonstrate how to find fixed/slow points of a high-dimensional system.

We first define the gap-junction-coupled FHN network as a normal DynamicalSystem class.

class GJCoupledFHN(bp.DynamicalSystem):
  def __init__(self, num=4, method='exp_auto'):
    super(GJCoupledFHN, self).__init__()

    # parameters
    self.num = num
    self.a = 0.7
    self.b = 0.8
    self.tau = 12.5
    self.gjw = 0.0001

    # variables
    self.V = bm.Variable(bm.random.uniform(-2, 2, num))
    self.w = bm.Variable(bm.random.uniform(-2, 2, num))
    self.Iext = bm.Variable(bm.zeros(num))

    # functions
    self.int_V = bp.odeint(self.dV, method=method)
    self.int_w = bp.odeint(self.dw, method=method)

  def dV(self, V, t, w, Iext=0.):
    gj = (V.reshape((-1, 1)) - V).sum(axis=0) * self.gjw
    dV = V - V * V * V / 3 - w + Iext + gj
    return dV

  def dw(self, w, t, V):
    dw = (V + self.a - self.b * w) / self.tau
    return dw

  def update(self, tdi):
    t, dt = tdi.get('t'), tdi.get('dt')
    self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
    self.w.value = self.int_w(self.w, t, self.V, dt)
    self.Iext[:] = 0.

Through simulation, we can easily find that this system has a limit cycle attractor, implying that an unstable fixed point exists.

# initialize a network
model = GJCoupledFHN(4)
model.gjw = 0.1

# simulation with an input
Iext = bm.asarray([0., 0., 0., 0.6])
runner = bp.DSRunner(model, monitors=['V'], inputs=['Iext', Iext])
runner.run(300.)

# visualization
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',
                       plot_ids=list(range(model.num)),
                       show=True)
_images/b8df95627166a162a7847688d24573fc7fec5ddb07ffed62571f955720d0c194.png

Let's try to find the fixed points of this system by optimization. Note that we only care about the variables V and w. Different from the low-dimensional analyzers, we should provide candidate (initial) fixed points when using the high-dimensional analyzer.

# init a slow point finder
finder = bp.analysis.SlowPointFinder(f_cell=model,
                                     target_vars={'V': model.V, 'w': model.w},
                                     inputs=[model.Iext, Iext])

# optimize to find fixed points
finder.find_fps_with_gd_method(
  candidates={'V': bm.random.normal(0., 2., (1000, model.num)),
              'w': bm.random.normal(0., 2., (1000, model.num))},
  tolerance=1e-6,
  num_batch=200,
  optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)),
)

# filter fixed points whose loss is bigger than the threshold
finder.filter_loss(1e-8)

# remove the duplicate fixed points
finder.keep_unique()
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.29 sec, Training loss 0.0003104926
    Batches 201-400 in 0.28 sec, Training loss 0.0002287778
    Batches 401-600 in 0.28 sec, Training loss 0.0001775225
    Batches 601-800 in 0.30 sec, Training loss 0.0001401555
    Batches 801-1000 in 0.30 sec, Training loss 0.0001119446
    Batches 1001-1200 in 0.30 sec, Training loss 0.0000904519
    Batches 1201-1400 in 0.30 sec, Training loss 0.0000738873
    Batches 1401-1600 in 0.30 sec, Training loss 0.0000609509
    Batches 1601-1800 in 0.30 sec, Training loss 0.0000506783
    Batches 1801-2000 in 0.36 sec, Training loss 0.0000424477
    Batches 2001-2200 in 0.29 sec, Training loss 0.0000357793
    Batches 2201-2400 in 0.29 sec, Training loss 0.0000303206
    Batches 2401-2600 in 0.33 sec, Training loss 0.0000258537
    Batches 2601-2800 in 0.28 sec, Training loss 0.0000221875
    Batches 2801-3000 in 0.29 sec, Training loss 0.0000191505
    Batches 3001-3200 in 0.30 sec, Training loss 0.0000166231
    Batches 3201-3400 in 0.30 sec, Training loss 0.0000144943
    Batches 3401-3600 in 0.29 sec, Training loss 0.0000126804
    Batches 3601-3800 in 0.30 sec, Training loss 0.0000111463
    Batches 3801-4000 in 0.29 sec, Training loss 0.0000098656
    Batches 4001-4200 in 0.28 sec, Training loss 0.0000087958
    Batches 4201-4400 in 0.35 sec, Training loss 0.0000078796
    Batches 4401-4600 in 0.31 sec, Training loss 0.0000070861
    Batches 4601-4800 in 0.29 sec, Training loss 0.0000063897
    Batches 4801-5000 in 0.28 sec, Training loss 0.0000057697
    Batches 5001-5200 in 0.28 sec, Training loss 0.0000052188
    Batches 5201-5400 in 0.28 sec, Training loss 0.0000047263
    Batches 5401-5600 in 0.29 sec, Training loss 0.0000042864
    Batches 5601-5800 in 0.28 sec, Training loss 0.0000038972
    Batches 5801-6000 in 0.28 sec, Training loss 0.0000035515
    Batches 6001-6200 in 0.29 sec, Training loss 0.0000032389
    Batches 6201-6400 in 0.29 sec, Training loss 0.0000029477
    Batches 6401-6600 in 0.28 sec, Training loss 0.0000026731
    Batches 6601-6800 in 0.28 sec, Training loss 0.0000024145
    Batches 6801-7000 in 0.28 sec, Training loss 0.0000021735
    Batches 7001-7200 in 0.35 sec, Training loss 0.0000019521
    Batches 7201-7400 in 0.28 sec, Training loss 0.0000017512
    Batches 7401-7600 in 0.28 sec, Training loss 0.0000015672
    Batches 7601-7800 in 0.28 sec, Training loss 0.0000013971
    Batches 7801-8000 in 0.27 sec, Training loss 0.0000012403
    Batches 8001-8200 in 0.27 sec, Training loss 0.0000010954
    Batches 8201-8400 in 0.27 sec, Training loss 0.0000009603
    Stop optimization as mean training loss 0.0000009603 is below tolerance 0.0000010000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 815/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/815 unique fixed points with uniqueness tolerance 0.025.
print('fixed points:')
finder.fixed_points
fixed points:
{'V': array([[-1.17757852, -1.17757852, -1.17757852, -0.81465053]]),
 'w': array([[-0.59697314, -0.59697314, -0.59697314, -0.14331316]])}
print('fixed point losses:')
finder.losses
fixed point losses:
array([2.46519033e-32])

Let’s perform the linearization analysis of the found fixed points, and visualize its decomposition results.

_ = finder.compute_jacobians(finder.fixed_points, plot=True)
_images/3b8a9ab971408f50ae02ddd79d4004984e41460c8f600bd12f323001f002ac05.png

This is an unstable fixed point, because one of its eigenvalues has a real part greater than 1.
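
If you prefer to check the stability numerically instead of reading it off the decomposition plot, the eigenvalues of the Jacobian can be inspected directly. The following is only a minimal sketch; it assumes that compute_jacobians returns the Jacobian matrices when plotting is disabled, which is an assumption for illustration rather than behavior documented here.

import numpy as np

# Assumption: compute_jacobians returns one Jacobian matrix per fixed point
# when plot=False.
jacobians = finder.compute_jacobians(finder.fixed_points, plot=False)
for jac in np.asarray(jacobians):
    eigvals = np.linalg.eigvals(jac)
    print('largest |eigenvalue|:', np.abs(eigvals).max())
    print('largest real part   :', eigvals.real.max())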

Further reading#

  • For more details about how to perform bifurcation analysis and phase plane analysis, please see the tutorial of Low-dimensional Analyzers.

  • A good example of phase plane analysis and bifurcation analysis is the decision-making model; please see the tutorial Analysis of a Decision-making Model.

  • If you want to know how to analyze the slow points (or fixed points) of your high-dimensional dynamical models, please see the tutorial of High-dimensional Analyzers.

Concept 1: Object-oriented Transformation#

@Chaoming Wang

Most computation in BrainPy relies on JAX. JAX provides wonderful transformations for Python programs, including differentiation, vectorization, parallelization, and just-in-time compilation. If you are not familiar with it, please see its documentation.

However, JAX only supports functional programming, i.e., transformations for Python functions. This is not what we want, because brain dynamics modeling needs object-oriented programming.

To meet this requirement, BrainPy defines an interface for object-oriented (OO) transformations. These OO transformations can be easily performed on BrainPy objects.

In this section, let’s talk about the BrainPy concept of object-oriented transformations.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

Illustrating example: Training a network#

To illustrate this concept, we need a demonstration example. Here, we choose the popular neural network training as the illustrating case.

In this training case, we want to teach the neural network to correctly classify random arrays into two labels (True or False). That is, we have the training data:

num_in = 100
num_sample = 256
X = bm.random.rand(num_sample, num_in)
Y = (bm.random.rand(num_sample) < 0.5).astype(float)

We use a two-layer feedforward network:

class Linear(bp.BrainPyObject):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.num_in = n_in
        self.num_out = n_out
        init = bp.init.XavierNormal()
        self.W = bm.Variable(init((n_in, n_out)))
        self.b = bm.Variable(bm.zeros((1, n_out)))

    def __call__(self, x):
        return x @ self.W + self.b


net = bp.Sequential(Linear(num_in, 20),
                    bm.relu,
                    Linear(20, 2))
print(net)
Sequential(
  [0] Linear0
  [1] relu
  [2] Linear1
)

Here, we use a supervised learning training paradigm.

rng = bm.random.RandomState(123)


# Loss function
@bm.to_object(child_objs=net, dyn_vars=rng)
def loss():
    # shuffle the data
    key = rng.split_key()
    x_data = rng.permutation(X, key=key)
    y_data = rng.permutation(Y, key=key)
    # prediction
    predictions = net(dict(), x_data)
    # loss
    l = bp.losses.cross_entropy_loss(predictions, y_data)
    return l


# Gradient function
grad = bm.grad(loss, grad_vars=net.vars(), return_value=True)

# Optimizer
optimizer = bp.optim.SGD(lr=1e-2, train_vars=net.vars())


# Training step
@bm.to_object(child_objs=(grad, optimizer))
def train(i):
    grads, l = grad()
    optimizer.update(grads)
    return l


num_step = 400
for i in range(0, 4000, num_step):
    # train 400 steps once
    ls = bm.for_loop(train, operands=bm.arange(i, i + num_step))
    print(f'Train {i + num_step} epoch, loss = {bm.mean(ls):.4f}')
Train 400 epoch, loss = 0.6710
Train 800 epoch, loss = 0.5992
Train 1200 epoch, loss = 0.5332
Train 1600 epoch, loss = 0.4720
Train 2000 epoch, loss = 0.4189
Train 2400 epoch, loss = 0.3736
Train 2800 epoch, loss = 0.3335
Train 3200 epoch, loss = 0.2972
Train 3600 epoch, loss = 0.2644
Train 4000 epoch, loss = 0.2346

In the above example, we have seen the classical elements of neural network training, such as:

  • net: neural network

  • loss: loss function

  • grad: gradient function

  • optimizer: parameter optimizer

  • train: training step

In BrainPy, all these elements can be defined as class objects and can be used for performing OO transformations.

In essence, the concept of BrainPy object-oriented transformation has three components:

  • BrainPyObject: the base class for object-oriented programming

  • Variable: the variables in the class object, whose values can be changed/updated during transformation

  • ObjectTransform: the transformations for computation involving BrainPyObject and Variable
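
Putting these three components together, here is a minimal, self-contained sketch (the class name Accumulator is made up for illustration): a BrainPyObject holding a Variable, compiled by the JIT object-oriented transformation.

class Accumulator(bp.BrainPyObject):               # 1. BrainPyObject
    def __init__(self):
        super().__init__()
        self.total = bm.Variable(bm.zeros(1))      # 2. Variable

    def __call__(self, x):
        self.total += x                            # updated in place during the transform
        return self.total.value

acc = bm.jit(Accumulator())                        # 3. ObjectTransform (JIT)
acc(1.)
acc(2.)  # total now holds 3.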

BrainPyObject and its Variable#

BrainPyObject is the base class for object-oriented programming in BrainPy. It can be viewed as a container which holds all the Variable instances needed for our computation.

In the above example, the Linear object has two Variable instances: W and b. The net we defined is further composed of two Linear objects, so we can expect that four variables can be retrieved from it.

net.vars().keys()
dict_keys(['Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])

An important question is: how should Variable instances be defined in a BrainPyObject so that all of them can be retrieved?

Actually, any Variable instance that can be accessed through a self. attribute can be retrieved from a BrainPyObject recursively. No matter how deeply BrainPyObject instances are composed, as long as the BrainPyObject instances and their Variable instances can be reached through the self. operation, all of them will be retrieved.

class SuperLinear(bp.BrainPyObject):
    def __init__(self, ):
        super().__init__()
        self.l1 = Linear(10, 20)
        self.v1 = bm.Variable(3)
        
sl = SuperLinear()
# retrieve Variable
sl.vars().keys()
dict_keys(['SuperLinear0.v1', 'Linear2.W', 'Linear2.b'])
# retrieve BrainPyObject
sl.nodes().keys()
dict_keys(['SuperLinear0', 'Linear2'])

However, we cannot automatically access a BrainPyObject or Variable that is stored in a Python container (like a tuple, list, or dict). In this case, we can register our objects and variables through .register_implicit_nodes() and .register_implicit_vars():

class SuperSuperLinear(bp.BrainPyObject):
    def __init__(self, register=False):
        super().__init__()
        self.ss = [SuperLinear(), SuperLinear()]
        self.vv = {'v_a': bm.Variable(3)}
        if register:
            self.register_implicit_nodes(self.ss)
            self.register_implicit_vars(self.vv)
# without register
ssl = SuperSuperLinear(register=False)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys([])
dict_keys(['SuperSuperLinear0'])
# with register
ssl = SuperSuperLinear(register=True)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys(['SuperSuperLinear1.v_a', 'SuperLinear3.v1', 'SuperLinear4.v1', 'Linear5.W', 'Linear5.b', 'Linear6.W', 'Linear6.b'])
dict_keys(['SuperSuperLinear1', 'SuperLinear3', 'SuperLinear4', 'Linear5', 'Linear6'])

Transform a function to BrainPyObject#

Let’s go back to our network training. After the definition of net, we further define a loss function whose computation involves the net object for neural network prediction and an rng Variable for data shuffling.

This Python function is then transformed into a BrainPyObject instance through the brainpy.math.to_object interface.

loss
FunAsObject(nodes=[Sequential0],
            num_of_vars=1)

All Variable instances used in this object can also be retrieved through:

loss.vars().keys()
dict_keys(['loss0._var0', 'Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])

Note that, when using to_object(), we need to explicitly declare all BrainPyObject and Variable instances used in the Python function. Thanks to the recursive retrieval property of BrainPyObject, we only need to specify the top-level composition object.

In the above loss object, we do not need to specify the two Linear objects. Instead, we only need to pass the top-level object net into the to_object() transform.

Similarly, when we transform the train function into a BrainPyObject, we just need to point out the grad and optimizer objects we have used, rather than the previous loss, net, or rng.
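
As a quick sanity check (a sketch, not part of the training code above), the recursive retrieval can be verified on the train object itself: its collections include the nodes and variables of grad and optimizer, which in turn contain the network weights.

print(train.nodes().keys())   # grad, optimizer, and their children, collected recursively
print(train.vars().keys())    # expected to include the network weights such as 'Linear0.W'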

BrainPy object-oriented transformations#

BrainPy object-oriented transformations are designed to work on BrainPyObject. These transformations include autograd brainpy.math.grad() and JIT brainpy.math.jit().

In our case, we used two OO transformations provided in BrainPy.

First, the grad object is defined with the loss function. Within it, we specify through grad_vars which variables we want to compute gradients for.

Note that the OO transformation of any BrainPyObject results in another BrainPyObject. Therefore, it can be recursively used as a component to form a larger scope of object-oriented programming and object-oriented transformation.

grad
GradientTransform(target=loss0, 
                  num_of_grad_vars=4, 
                  num_of_dyn_vars=1)

Next, we train 400 steps at a time by using the for_loop transformation. Different from grad, which returns a BrainPyObject instance, for_loop directly returns the loop results.
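
JIT compilation is the other OO transformation mentioned above. Although it is not shown in the run above, the whole training-step object can itself be compiled; a minimal sketch:

jit_train = bm.jit(train)       # train is a BrainPyObject, so it can be JIT compiled
loss_at_step = jit_train(0)     # one compiled training step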

Concept 2: Dynamical System#

@Chaoming Wang

BrainPy supports modeling in both brain simulation and brain-inspired computing.

All this support is based on one common concept: the dynamical system, implemented via brainpy.DynamicalSystem.

Therefore, it is essential to understand:

  1. what is brainpy.DynamicalSystem?

  2. how to define brainpy.DynamicalSystem?

  3. how to run brainpy.DynamicalSystem?

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.3.1'

What is DynamicalSystem?#

All models used in brain simulation and brain-inspired computing are DynamicalSystem instances.

Note

DynamicalSystem is a subclass of BrainPyObject. Therefore, it supports the object-oriented transformations stated in the previous tutorial.

A DynamicalSystem defines the updating rule of the model at a single time step.

  1. For models with state, DynamicalSystem defines the state transition from \(t\) to \(t+dt\), i.e., \(S(t+dt) = F\left(S(t), x, t, dt\right)\), where \(S\) is the state, \(x\) is input, \(t\) is the time, and \(dt\) is the time step. This is the case for recurrent neural networks (like GRU, LSTM), neuron models (like HH, LIF), or synapse models which are widely used in brain simulation.

  2. However, for models in deep learning, like convolution and fully-connected linear layers, DynamicalSystem defines the input-to-output mapping, i.e., \(y=F\left(x, t\right)\).

How to define DynamicalSystem?#

Keep in mind that the usage of DynamicalSystem has several constraints in BrainPy.

1. .update() function#

First, every DynamicalSystem should implement the .update() function, which receives two arguments:

class YourModel(bp.DynamicalSystem):
  def update(self, s, x):
    pass
  • s (or any other name): a dict of shared arguments across all nodes/layers in the network, like

    • the current time t, or

    • the current running index i, or

    • the current time step dt, or

    • the current phase of training or testing fit=True/False.

  • x (or any other name): the individual input for this node/layer.

We call s the shared arguments because they are the same for and shared by all nodes/layers. By contrast, different nodes/layers receive different inputs x.
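
For illustration only, a shared-argument dict at one time step may look like the following (the keys follow the typical ones listed above; the concrete values are hypothetical):

s = dict(t=0.1, dt=0.1, i=1, fit=False)   # the same dict is passed to every node/layer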

Example: LIF neuron model for brain simulation

Here we illustrate the first constraint of DynamicalSystem using the Leaky Integrate-and-Fire (LIF) model.

The LIF model was first proposed in brain simulation for modeling neuron dynamics. Its equations are given by

\[\begin{split} \begin{aligned} \tau_m \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\ \text{if} \, V(t) \gt V_{th}, V(t) =V_{rest} \end{aligned} \end{split}\]

For the details of the model, users can refer to Wikipedia or other resources.

class LIF_for_BrainSimulation(bp.DynamicalSystem):
  def __init__(self, size, V_rest=0., V_th=1., tau=5., mode=None):
    super().__init__(mode=mode)

    # this model only supports non-batching mode
    bp.check.is_subclass(self.mode, bm.NonBatchingMode)

    # parameters
    self.size = size
    self.V_rest = V_rest
    self.V_th = V_th
    self.tau = tau

    # variables
    self.V = bm.Variable(bm.ones(size) * V_rest)
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))

    # integrate differential equation with exponential euler method
    self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')

  def update(self, s, x):
    # define how the model states update
    # according to the external input
    t, dt = s.get('t'), s.get('dt')
    V = self.integral(self.V, t, x, dt=dt)
    spike = V >= self.V_th
    self.V.value = bm.where(spike, self.V_rest, V)
    self.spike.value = spike
    return spike
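
As a minimal usage sketch (one of several possible ways to drive this model), the update function can be stepped manually by passing a shared-argument dict and an external input at each step; the constant current 3. is an arbitrary choice for illustration.

lif = LIF_for_BrainSimulation(10)
dt = 0.1
spikes = []
for i in range(1000):
  s = dict(t=i * dt, dt=dt)           # shared arguments for this step
  spikes.append(lif.update(s, x=3.))  # x: the external input current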

2. Computing mode#

Second, explicitly consider which computing mode your DynamicalSystem supports.

Brain simulation usually builds models without a batching dimension (we refer to this as the non-batching mode, as seen in the above LIF model), while brain-inspired computing trains models with a batch of data (the batching mode or training mode).

So, to write a model applicable to a broad range of applications in brain simulation and brain-inspired computing, you need to consider which computing modes your model supports: one of them, or both.

Example: LIF neuron model for both brain simulation and brain-inspired computing

When considering the computing mode, we can program a general LIF model for brain simulation and brain-inspired computing.

To overcome the non-differentiable property of the spike in the LIF model for brain simulation, i.e., the line

spike = V >= self.V_th

LIF models used in brain-inspired computing calculate the spiking state using a surrogate gradient function. Usually, we replace the backward gradient of the spike with a smooth function, such as

\[ g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} \]
class LIF(bp.DynamicalSystem):
  def __init__(self, size, f_surrogate=None, V_rest=0., V_th=1., tau=5., mode=None):
    super().__init__(mode=mode)
    bp.check.is_subclass(self.mode, [bm.NonBatchingMode, bm.BatchingMode, bm.TrainingMode])

    # Parameters
    self.size = size
    self.num = bp.tools.size2num(size)
    self.V_rest = V_rest
    self.V_th = V_th
    self.tau = tau
    if f_surrogate is None:
      f_surrogate = bm.surrogate.inv_square_grad
    self.f_surrogate = f_surrogate

    # integrate differential equation with exponential euler method
    self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')

    # Initialize a Variable:
    # - if non-batching mode, batch axis of V is None
    # - if batching mode,     batch axis of V is 0
    self.V = bp.init.variable_(bm.zeros, self.size, self.mode)
    self.V[:] = self.V_rest
    self.spike = bp.init.variable_(bm.zeros, self.size, self.mode)

  def reset_state(self, batch_size=None):
    self.V.value = bp.init.variable_(bm.ones, self.size, batch_size) * self.V_rest
    self.spike.value = bp.init.variable_(bm.zeros, self.size, batch_size)

  def update(self, s, x):
    t, dt = s.get('t'), s.get('dt', bm.dt)
    V = self.integral(self.V, t, x, dt=dt)
    # replace non-differential heaviside function
    # with a surrogate gradient function
    spike = self.f_surrogate(V - self.V_th)
    # reset membrane potential
    self.V.value = (1. - spike) * V + spike * self.V_rest
    self.spike.value = spike
    return spike
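
As a quick sketch, this general LIF class can now be instantiated under different computing modes, matching the environments used in the composition examples below.

with bm.environment(mode=bm.nonbatching_mode):
  lif_sim = LIF(10)                  # brain simulation: no batch axis

with bm.environment(mode=bm.training_mode):
  lif_ai = LIF(10)                   # brain-inspired computing: batched
  lif_ai.reset_state(batch_size=16)  # allocate the batch dimension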

Model composition#

The LIF model we have defined above can be recursively composed to construct networks in brain simulation and brain-inspired computing.

The following code snippet utilizes the LIF model to build an E/I balanced network EINet, which is a classical network model in brain simulation.

class EINet(bp.DynamicalSystem):
  def __init__(self, num_exc, num_inh):
    super().__init__()
    self.E = LIF(num_exc, V_rest=-55, V_th=-50., tau=20.)
    self.I = LIF(num_inh, V_rest=-55, V_th=-50., tau=20.)
    self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(0.02),
                                       g_max=1.62, tau=5., output=None)
    self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(0.02),
                                       g_max=1.62, tau=5., output=None)
    self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(0.02),
                                       g_max=-9.0, tau=10., output=None)
    self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(0.02),
                                       g_max=-9.0, tau=10., output=None)

  def update(self, s, x):
    # x is the background input
    e2e = self.E2E(s)
    e2i = self.E2I(s)
    i2e = self.I2E(s)
    i2i = self.I2I(s)
    self.E(s, e2e + i2e + x)
    self.I(s, e2i + i2i + x)

with bm.environment(mode=bm.nonbatching_mode):
  net1 = EINet(3200, 800)

Moreover, our LIF model can also be used in brain-inspired computing scenarios. The following AINet uses the LIF model to construct a network for AI training.

# This network can be used in AI applications

class AINet(bp.DynamicalSystem):
  def __init__(self, sizes):
    super().__init__()
    self.neu1 = LIF(sizes[0])
    self.syn1 = bp.layers.Dense(sizes[0], sizes[1])
    self.neu2 = LIF(sizes[1])
    self.syn2 = bp.layers.Dense(sizes[1], sizes[2])
    self.neu3 = LIF(sizes[2])

  def update(self, s, x):
    x = self.neu1(s, x)
    x = self.syn1(s, x)
    x = self.neu2(s, x)
    x = self.syn2(s, x)
    x = self.neu3(s, x)
    return x

with bm.environment(mode=bm.training_mode):
  net2 = AINet([100, 50, 10])

How to run DynamicalSystem?#

As stated above, a DynamicalSystem only defines the updating rule at a single time step. To run a DynamicalSystem instance over time, we need a for-loop mechanism.

1. brainpy.math.for_loop#

for_loop is a structural control-flow API which runs a function by looping over the inputs. Moreover, this API just-in-time compiles the looping process into machine code.

Suppose we have 200 time steps with a step size of 0.1; we can run the model with:

with bm.environment(dt=0.1):
  # construct a set of shared argument with the given time steps
  shared = bm.shared_args_over_time(num_step=200)
  # construct the inputs with shape of (time, batch, feature)
  currents = bm.random.rand(200, 10, 100)

  # run the model
  net2.reset_state(batch_size=10)
  out = bm.for_loop(net2, (shared, currents))

out.shape
(200, 10, 10)

2. brainpy.DSRunner#

Another way to run the model in BrainPy is to use the structural runner objects DSRunner and DSTrainer. They provide a more flexible way of monitoring the variables in a DynamicalSystem. For details, users should refer to the DSRunner tutorial.

with bm.environment(dt=0.1):
  runner = bp.DSRunner(net1, monitors={'E.spike': net1.E.spike, 'I.spike': net1.I.spike})
  runner.run(inputs=bm.ones(1000) * 20.)

bp.visualize.raster_plot(runner.mon['ts'], runner.mon['E.spike'])
_images/c4bf48caab8e8c38794abedf44b2754710ac4e0c1f7e3927157227ab899b16d4.png

Math Foundation#

brainpy.math.Variable#

@Chaoming Wang @Xiaoyu Chen

In BrainPy, the JIT compilation for class objects relies on brainpy.math.Variable. In this section, we are going to understand:

  • What is a Variable?

  • What are the subtypes of Variable?

  • How to update a Variable?

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.1'

brainpy.math.Variable#

brainpy.math.Variable is a pointer referring to a JAX Array. It stores an array as its value. The data in a Variable can be changed during our object-oriented JIT compilation. If an array is labeled as a Variable, it means that it is a dynamical variable that changes over time.

Arrays that are not marked as Variables will be JIT compiled as static data. Modifications of these arrays will be invalid or cause an error.
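
A minimal sketch of this distinction, reusing the to_object/jit pattern that appears later in this documentation: the Variable is updated in place across calls, while the plain array is treated as a compile-time constant.

offset = bm.zeros(1)               # not a Variable: static data under JIT
count = bm.Variable(bm.zeros(1))   # a Variable: dynamical data under JIT

@bm.jit
@bm.to_object(dyn_vars=count)
def step():
    count.value += 1.              # in-place update, persists across calls
    return count + offset

step()  # count becomes 1
step()  # count becomes 2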

  • Creating a Variable

Passing an array into brainpy.math.Variable creates a Variable. For example:

b1 = bm.random.random(5)
b1
Array([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)
b2 = bm.Variable(b1)
b2
Variable([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)
  • Accessing the value in a Variable

The data in a Variable can be obtained through .value.

b2.value
Array([0.70490587, 0.9825947 , 0.79977   , 0.21864283, 0.70959914],      dtype=float32)
(b2.value == b1).all()
Array(True, dtype=bool)
  • Supported operations on Variables

Variables support almost all the operations for a JAX array.

b2 + 1.
Array([1.7049059, 1.9825947, 1.79977  , 1.2186428, 1.7095991], dtype=float32)
b2 ** 2
Array([0.49689227, 0.9654924 , 0.63963205, 0.04780469, 0.5035309 ],      dtype=float32)
bm.floor(b2)
Array([0., 0., 0., 0., 0.], dtype=float32)

Subtypes of Variable#

brainpy.math.Variable has several subtypes, including brainpy.math.TrainVar, brainpy.math.Parameter, and brainpy.math.RandomState. Subtypes can also be customized and extended by users.

1. TrainVar#

brainpy.math.TrainVar denotes a trainable variable and is a subclass of brainpy.math.Variable. Usually, trainable variables are those whose gradients are computed so that the corresponding update values can be applied. However, users can also use TrainVar for other purposes.

b = bm.random.rand(4)

b
Array([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)
bm.TrainVar(b)
TrainVar([0.39813817, 0.2902342 , 0.0428251 , 0.7002579 ], dtype=float32)
2. Parameter#

brainpy.math.Parameter is used to label a dynamically changing parameter. It is also a subclass of brainpy.math.Variable. The advantage of using Parameter rather than Variable is that it can be easily retrieved by the Collector.subset method.

b = bm.random.rand(1)

b
Array([0.47972953], dtype=float32)
bm.Parameter(b)
Parameter([0.47972953], dtype=float32)
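
A minimal sketch of that retrieval (the class WithParam is made up for illustration):

class WithParam(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.p = bm.Parameter(bm.ones(1))
        self.v = bm.Variable(bm.zeros(1))

obj = WithParam()
print(obj.vars().subset(bm.Parameter).keys())  # only the Parameter 'p' is returned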
3. RandomState#

brainpy.math.random.RandomState is also a subclass of brainpy.math.Variable. A RandomState must store the dynamically changing key information (see JAX random number designs). Every time a RandomState performs a random sampling, its “key” changes. Therefore, it is worthwhile to label a RandomState as a Variable.

state = bm.random.RandomState(1234)

state
RandomState(key=([   0, 1234], dtype=uint32))
# perform a "random" sampling 
state.random(1)

state  # the value changed
RandomState(key=([2113592192, 1902136347], dtype=uint32))
# perform a "sample" sampling 
state.sample(1)

state  # the value changed too
RandomState(key=([1076515368, 3893328283], dtype=uint32))

Every instance of RandomState can create a new seed from the current seed with .split_key().

state.split_key()
Array([3028232624,  826525938], dtype=uint32)

It can also create multiple seeds from the current seed with .split_keys(n). This is used internally by pmap and vmap to ensure that random numbers are different in parallel threads.

state.split_keys(2)
Array([[4198471980, 1111166693],
       [1457783592, 2493283834]], dtype=uint32)
state.split_keys(5)
Array([[3244149147, 2659778815],
       [2548793527, 3057026599],
       [ 874320145, 4142002431],
       [3368470122, 3462971882],
       [1756854521, 1662729797]], dtype=uint32)

There is a default RandomState in brainpy.math.random module: DEFAULT.

bm.random.DEFAULT
RandomState(key=([1682297581, 3751629511], dtype=uint32))

The built-in random methods like randint(), rand(), shuffle(), etc. use this DEFAULT state. If you want to change the default RandomState, please use the seed() method.

bm.random.seed(654321)

bm.random.DEFAULT
RandomState(key=([     0, 654321], dtype=uint32))

In-place updating#

In BrainPy, the transformations (like JIT) usually need to update variables or arrays in-place. In-place updating does not change the reference pointing to the variable while changing the data stored in the variable.

For example, here we have a variable a.

a = bm.Variable(bm.zeros(5))

The ids of the variable and the data stored in the variable are:

id_of_a = id(a)
id_of_data = id(a.value)

assert id_of_a != id_of_data

print('id(a)       = ', id_of_a)
print('id(a.value) = ', id_of_data)
id(a)       =  2781947736704
id(a.value) =  2781965742144

An in-place update (here we use [:]) does not change the reference to the variable a but changes the data stored in it:

a[:] = 1.

print('id(a)       = ', id(a))
print('id(a.value) = ', id(a.value))
id(a)       =  2781947736704
id(a.value) =  2781965752128
print('(id(a) == id_of_a)          =', id(a) == id_of_a)
print('(id(a.value) == id_of_data) =', id(a.value) == id_of_data)
(id(a) == id_of_a)          = True
(id(a.value) == id_of_data) = False

However, if you do not use in-place operators to assign data, the name a will be rebound to a new object with a different id. This will cause serious errors when using transformations in BrainPy.

a = 10.

print('id(a) = ', id(a))
print('(id(a) == id_of_a) =', id(a) == id_of_a)
id(a) =  2781946941520
(id(a) == id_of_a) = False
Note that the following in-place operators are not limited to brainpy.math.Variable and its subclasses. They can also be applied to brainpy.math.Array.

Here, we list several commonly used in-place operators.

v = bm.Variable(bm.arange(10))
old_id = id(v)

def check_no_change(new_v):
    assert id(new_v) == old_id, 'Variable has been changed.'
1. Indexing and slicing#

Indexing and slicing are the two most commonly used operators. The details of indexing and slicing are in Array Objects Indexing.

Indexing: v[i] = a or v[(1, 3)] = c (index multiple values)

v[0] = 1

check_no_change(v)

Slicing: v[i:j] = b

v[1: 2] = 1

check_no_change(v)

Slicing all values: v[:] = d, v[...] = e

v[:] = 0

check_no_change(v)
v[...] = bm.arange(10)

check_no_change(v)
2. Augmented assignment#

All augmented assignments are in-place operations, including:

  • add: +=

  • subtract: -=

  • divide: /=

  • multiply: *=

  • floor divide: //=

  • modulo: %=

  • power: **=

  • and: &=

  • or: |=

  • xor: ^=

  • left shift: <<=

  • right shift: >>=

v += 1

check_no_change(v)
v *= 2

check_no_change(v)
v |= bm.random.randint(0, 2, 10)

check_no_change(v)
v **= 2

check_no_change(v)
v >>= 2

check_no_change(v)
3. .value assignment#

Another way to update a variable in place is to assign new data to .value. This operation is very safe, because it checks whether the type and shape of the new data are consistent with the current ones.

v.value = bm.arange(10)

check_no_change(v)
try:
    v.value = bm.asarray(1.)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.value = bm.random.random(10)
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.
4. .update() method#

Actually, the .value assignment is the same operation as the .update() method. Users who want a safe assignment can choose this method too.

v.update(bm.random.randint(0, 20, size=10))
try:
    v.update(bm.asarray(1.))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The shape of the original data is (10,), while we got () with batch_axis=None.
try:
    v.update(bm.random.random(10))
except Exception as e:
    print(type(e), e)
<class 'brainpy.errors.MathError'> The dtype of the original data is int32, while we got float32.

Control Flows for JIT compilation#

@Chaoming Wang

Control flow is the core of a program, because it defines the order in which the program’s code executes. The control flow of Python is regulated by conditional statements, loops, and function calls.

Python has two types of control structures:

  • Selection: used for decisions and branching.

  • Repetition: used for looping, i.e., repeating a piece of code multiple times.

In this section, we are going to talk about how to build effective control flows in the context of JIT compilation.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

1. Selection#

In Python, the selection statements are also known as Decision control statements or branching statements. The selection statement allows a program to test several conditions and execute instructions based on which condition is true. The commonly used control statements include:

  • if-else

  • nested if

  • if-elif-else

Non-Variable-based control statements#

Actually, BrainPy (based on JAX) allows you to write control flows normally, just like familiar Python programs, when the conditional statement depends on non-Variable instances. For example,

class OddEven(bp.BrainPyObject):
    def __init__(self, type_=1):
        super(OddEven, self).__init__()
        self.type_ = type_
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.type_ == 1:
            self.a += 1
        elif self.type_ == 2:
            self.a -= 1
        else:
            raise ValueError(f'Unknown type: {self.type_}')
        return self.a

In the above example, the condition in the if statement relies on a scalar, which is not an instance of brainpy.math.Variable. In this case, the conditional statements can be arbitrarily complex. You can write your models in normal Python code, and these models will work very well with JIT compilation.

model = bm.jit(OddEven(type_=1))

model()
Variable([1.], dtype=float32)
model = bm.jit(OddEven(type_=2))

model()
Variable([-1.], dtype=float32)
try:
    model = bm.jit(OddEven(type_=3))
    model()
except ValueError as e:
    print(f"ValueError: {str(e)}")
ValueError: Unknown type: 3
Variable-based control statements#

However, if the condition in an if ... else ... statement relies on instances of brainpy.math.Variable, writing Pythonic control flows will cause errors when using JIT compilation.

class OddEvenCauseError(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCauseError, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        if self.rand < 0.5:  self.a += 1
        else:  self.a -= 1
        return self.a
wrong_model = bm.jit(OddEvenCauseError())

try:
    wrong_model()
except Exception as e:
    print(f"{e.__class__.__name__}: {str(e)}")
ConcretizationTypeError: This problem may be caused by several ways:
1. Your if-else conditional statement relies on instances of brainpy.math.Variable. 
2. Your if-else conditional statement relies on functional arguments which do not set in "static_argnames" when applying JIT compilation. More details please see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
3. The static variables which set in the "static_argnames" are provided as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].

To perform conditional statements on Variable instances, we need structural control-flow syntax. Specifically, BrainPy provides several options (based on JAX):

  • brainpy.math.where: return element-wise conditional comparison results.

  • brainpy.math.ifelse: Conditional statements of if-else, or if-elif-else, … for a scalar-typed value.

brainpy.math.where#

The where(condition, x, y) function returns elements chosen from x or y depending on condition. It works well on scalars, vectors, and high-dimensional arrays.

a = 1.
bm.where(a < 0, 0., 1.)
Array(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
Array([0., 0., 1., 1., 0.], dtype=float32, weak_type=True)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
Array([[0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32, weak_type=True)

For the above example, we can rewrite it by using where syntax as:

class OddEvenWhere(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenWhere, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.where(self.rand < 0.5, 1., -1.)
        return self.a
model = bm.jit(OddEvenWhere())
model()
Variable([1.], dtype=float32)
brainpy.math.ifelse#

Based on JAX’s control flow syntax jax.lax.cond, BrainPy provides a more general conditional statement enabling multiple branching.

In its simplest case, brainpy.math.ifelse(condition, branches, operands, dyn_vars=None) is equivalent to:

def ifelse(condition, branches, operands, dyn_vars=None):
  true_fun, false_fun = branches
  if condition:
    return true_fun(operands)
  else:
    return false_fun(operands)

Based on this function, we can rewrite the above example by using cond syntax as:

class OddEvenCond(bp.BrainPyObject):
    def __init__(self):
        super(OddEvenCond, self).__init__()
        self.rand = bm.Variable(bm.random.random(1))
        self.a = bm.Variable(bm.zeros(1))

    def __call__(self):
        self.a += bm.ifelse(self.rand[0] < 0.5,
                            [lambda _: 1., lambda _: -1.])
        return self.a
model = bm.jit(OddEvenCond())
model()
Variable([-1.], dtype=float32)

If you want to write control flows with multiple branchings, brainpy.math.ifelse(conditions, branches, operands, dyn_vars=None) can also help you accomplish this easily. Actually, the multiple-branching case is equivalent to:

def ifelse(conditions, branches, operands, dyn_vars=None):
  pred1, pred2, ... = conditions
  func1, func2, ..., funcN = branches
  if pred1:
    return func1(operands)
  elif pred2:
    return func2(operands)
  ...
  else:
    return funcN(operands)

For example, if you have the following code:

def f(a):
  if a > 10:
    return 1.
  elif a > 5:
    return 2.
  elif a > 0:
    return 3.
  elif a > -5:
    return 4.
  else:
    return 5.

It can be expressed as:

def f(a):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[1., 2., 3., 4., 5.])
f(11.)
Array(1., dtype=float32, weak_type=True)
f(6.)
Array(2., dtype=float32, weak_type=True)
f(1.)
Array(3., dtype=float32, weak_type=True)
f(-4.)
Array(4., dtype=float32, weak_type=True)
f(-6.)
Array(5., dtype=float32, weak_type=True)

A more complex example is:

def f2(a, x):
  return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
                   branches=[lambda x: x*2,
                             2.,
                             lambda x: x**2 -1,
                             lambda x: x - 4.,
                             5.],
                   operands=x)
f2(11, 1.)
Array(2., dtype=float32, weak_type=True)
f2(6, 1.)
Array(2., dtype=float32, weak_type=True)
f2(1, 1.)
Array(0., dtype=float32, weak_type=True)
f2(-4, 1.)
Array(-3., dtype=float32, weak_type=True)
f2(-6, 1.)
Array(5., dtype=float32, weak_type=True)

If instances of brainpy.math.Variable are used in branching functions, you can declare them in the dyn_vars argument.

a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x):  a.value += 1
def false_f(x): b.value -= 1

bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])

print('a:', a)
print('b:', b)
a: Variable([1., 1.], dtype=float32)
b: Variable([0., 0.], dtype=float32)

2. Repetition#

A repetition statement is used to repeat a group (block) of programming instructions.

In Python, we generally have two loops/repetitive statements:

  • for loop: Execute a set of statements once for each item in a sequence.

  • while loop: Execute a block of statements repeatedly until a given condition is satisfied.

Pythonic loop syntax#

Actually, JAX allows you to write Pythonic loops. You just need to iterate over your sequence data and then apply your logic to the iterated items. Such Pythonic loop syntax is compatible with JIT compilation, but it takes a long time to trace and compile. For example,

class LoopSimple(bp.BrainPyObject):
    def __init__(self):
        super(LoopSimple, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = bm.Variable(rng.random(1000))
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        for s in self.seq:
            self.res += s
        return self.res.value
import time

def measure_time(f, return_res=False, verbose=True):
    t0 = time.time()
    r = f()
    t1 = time.time()
    if verbose:
        print(f'Result: {r}, Time: {t1 - t0}')
    return r if return_res else None
model = bm.jit(LoopSimple())

# First time will trigger compilation
measure_time(model)
Result: [501.74664], Time: 0.9443738460540771
# Second running
measure_time(model)
Result: [1003.4931], Time: 0.0

When the model is complex and the iteration is long, the compilation time of the first run becomes unbearable. For such cases, you need structural loop syntax.

JAX provides several important loop primitives, such as jax.lax.scan, jax.lax.while_loop, and jax.lax.fori_loop.

BrainPy also provides its own loop syntax, which is especially suitable for cases where users are working with brainpy.math.Variable. Specifically, the loop functions are:

  • brainpy.math.for_loop

  • brainpy.math.while_loop

In this section, we only talk about how to use our provided loop functions.

brainpy.math.for_loop()#

brainpy.math.for_loop() is used to run a for-loop over functions that update Variable instances.

Suppose that you are using several Variables (grouped as dyn_vars) to implement your body function body_fun, and you want to gather the history of the values it computes at every step. With plain Python syntax, this can be realized as


def for_loop_function(body_fun, dyn_vars, xs):
  ys = []
  for x in xs:
    # 'dyn_vars' are updated in 'body_fun()'
    results = body_fun(x)
    ys.append(results)
  return ys

In BrainPy, you can define this logic using brainpy.math.for_loop():

import brainpy.math

hist_of_out_vars = brainpy.math.for_loop(body_fun, dyn_vars=dyn_vars, operands=operands)

For the above example, we can rewrite it by using brainpy.math.for_loop as:

class LoopStruct(bp.BrainPyObject):
    def __init__(self):
        super(LoopStruct, self).__init__()
        rng = bm.random.RandomState(123)
        self.seq = rng.random(1000)
        self.res = bm.Variable(bm.zeros(1))

    def __call__(self):
        def add(s):
          self.res += s
          return self.res.value

        return bm.for_loop(body_fun=add, dyn_vars=[self.res], operands=self.seq)
model = bm.jit(LoopStruct())

r = measure_time(model, verbose=False, return_res=True)
r.shape
(1000, 1)

In essence, body_fun defines the one-step updating rule for how variables are updated. All returns of body_fun will be gathered as the history values. dyn_vars declares all dynamical variables used in body_fun. operands specifies the inputs of body_fun; they will be looped over along the first axis.
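
Below is a standalone sketch of these three ingredients (independent of the LoopStruct class above): a body function that updates a Variable in place, the dyn_vars it relies on, and the operands looped over along the first axis.

total = bm.Variable(bm.zeros(1))

def body(x):
    total.value += x        # in-place update of the dynamical variable
    return total.value      # gathered as the history values

history = bm.for_loop(body, dyn_vars=[total], operands=bm.arange(5.))
# history stacks the returns of body, one per looped element: shape (5, 1)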

brainpy.math.while_loop()#

brainpy.math.while_loop() is used to generate a while-loop function when you use Variable. It supports the following loop logic:


while condition:
    statements

When using brainpy.math.while_loop(), the condition should be wrapped as a cond_fun function which returns a boolean value, and the statements should be packed as a body_fun function which receives the old values from the latest step and returns the updated values at the current step:


while cond_fun(x):
    x = body_fun(x)

Note the difference between brainpy.math.for_loop and brainpy.math.while_loop:

  1. The returns of brainpy.math.for_loop are the values to be gathered as the history values, while the returns of brainpy.math.while_loop should have the same shape and type as its inputs, because they represent the updated values.

  2. The body function of brainpy.math.for_loop can receive and return anything, without explicit requirements on the returns. By contrast, the body function of brainpy.math.while_loop should return exactly what it receives.

A concrete example of brainpy.math.while_loop is as follows:

i = bm.Variable(bm.zeros(1))
counter = bm.Variable(bm.zeros(1))

def cond_f():
    return i[0] < 10

def body_f():
    i.value += 1.
    counter.value += i

bm.while_loop(body_f, cond_f, dyn_vars=[i, counter], operands=())
()

In the above example, we implement a sum from 0 to 10 by using two Variables, i and counter.

counter
Variable([55.], dtype=float32)
i
Variable([10.], dtype=float32)

Or, similarly,

i = bm.Variable(bm.zeros(1))

def cond_f(counter):
    return i[0] < 10

def body_f(counter):
    i.value += 1.
    return counter + i[0]

bm.while_loop(body_f, cond_f, dyn_vars=[i], operands=(1., ))
(Array(56., dtype=float32),)

Random Number Generation for JIT Compilation#

@Chaoming Wang

Although brainpy.math.random is designed to be seamlessly compatible with numpy.random, there are still some differences in the context of JIT compilation.

In this section, we are going to talk about how to write JIT-compatible code with brainpy.math.random.

import brainpy as bp
import brainpy.math as bm
import numpy as np

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

Using bm.random outside functions to JIT#

Using brainpy.math.random outside of functions to JIT is the same as using numpy.random.

This usage corresponds to cases where random data is generated for further processing. For example,

bm.random.rand(10)
Array([0.7222161 , 0.2043277 , 0.59838593, 0.255252  , 0.14954388,
       0.05150986, 0.214692  , 0.03857851, 0.81150043, 0.4669956 ],      dtype=float32)
np.random.rand(10)
array([0.60626677, 0.69464463, 0.81361761, 0.1583908 , 0.50378113,
       0.17677626, 0.7507633 , 0.75699064, 0.33320096, 0.38958635])

When you use API functions in brainpy.math.random, you are actually calling functions of a default RandomState. Similarly, numpy.random also has a default RandomState. Calling a random function in the numpy.random module corresponds to calling that function on this default NumPy RandomState.

bm.random.DEFAULT
RandomState(key=([3014124236, 2009892527], dtype=uint32))

Using bm.random inside a function to JIT#

If you are using random sampling in a function to be JIT compiled, there are things you need to pay attention to. Otherwise, errors are likely to be raised.

As stated above, brainpy.math.random functions use the default RandomState. A RandomState is an instance of the brainpy Variable, meaning that its value changes after calling any of its built-in random functions. What changes is the key of the RandomState. For instance,

bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([ 873106783, 4065854088], dtype=uint32))
bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([3526960574,  230845945], dtype=uint32))

Therefore, if you do not declare that you are using this DEFAULT RandomState, repeatedly calling random functions in the brainpy.math.random module inside a JIT-compiled function will not give you what you want, because its key cannot be updated. For instance,

@bm.jit
def get_data():
    return bm.random.random(2)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)

A correct way is to explicitly declare that you are using this DEFAULT variable in the JIT transformation.

bm.random.seed()
from functools import partial

@partial(bm.jit, dyn_vars=(bm.random.DEFAULT, ))
def get_data_v2():
    return bm.random.random(2)
get_data_v2()
Array([0.38541543, 0.5843446 ], dtype=float32)
get_data_v2()
Array([0.85543776, 0.36957836], dtype=float32)

Or, declare the function as a BrainPyObject, then use jit().

@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def get_data_v3():
    return bm.random.random(2)
get_data_v3()
Array([0.31096482, 0.7970413 ], dtype=float32)
get_data_v3()
Array([0.26830554, 0.15947664], dtype=float32)

Using RandomState for objects to JIT#

Another recommended way is to use instances of RandomState in the objects to JIT. For example, you can initialize a RandomState in the __init__() function, and then use the initialized RandomState anywhere.

class MyOb(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.rng = bm.random.RandomState(123)

    def __call__(self):
        size = (50, 100)
        u = self.rng.random(size)
        v = self.rng.uniform(size=size)
        z = bm.sqrt(-2 * bm.log(u)) * bm.cos(2 * bm.pi * v)
        return z
ob = bm.jit(MyOb())
ob()
Array([[ 1.3595979 , -1.3462192 ,  0.7149456 , ...,  1.4283268 ,
        -1.1362855 , -0.18378317],
       [-0.26401126, -1.6798397 , -0.8422355 , ...,  1.0795223 ,
         0.41247413, -0.955116  ],
       [ 0.6234829 , -0.44811824, -0.03835859, ..., -2.5203867 ,
        -0.02713326,  1.6490041 ],
       ...,
       [-0.9861029 ,  0.36676335, -0.31499916, ...,  1.526808  ,
        -0.7946268 , -0.86713606],
       [-1.7008592 , -0.05957834, -0.5677447 , ..., -0.04765594,
         0.574145  , -0.11830498],
       [-0.22663854, -1.8517947 , -1.3546717 , ...,  1.2332705 ,
        -0.79247886, -1.9352005 ]], dtype=float32)

Note that any Variable instance that can be directly accessed through self. is automatically found by BrainPy’s JIT transformation. Therefore, in this case, we do not need to pass the rng into dyn_vars of the bm.jit() function.
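
As a quick check (a sketch), this automatic collection can be verified on a plain, un-jitted instance: the RandomState appears among the object's variables without any manual registration.

raw_ob = MyOb()
print(raw_ob.vars().keys())  # expected to contain the 'rng' RandomState as a Variable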

Model Building#

Using Built-in Models#

@Tianqiu Zhang @Chaoming Wang

BrainPy enables modular programming and easy model debugging. To build a complex brain dynamics model, you just need to combine its building blocks. In this section, we are going to talk about what building blocks we provide and how to use them.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.1'

Initializing a neuron model#

All neuron models implemented in brainpy are subclasses of brainpy.dyn.NeuGroup. The initialization of a neuron model just needs to provide the geometry size of neurons in a population group.

hh = bp.neurons.HH(size=1)  # only 1 neuron

hh = bp.neurons.HH(size=10)  # 10 neurons in a group

hh = bp.neurons.HH(size=(10, 10))  # a grid of (10, 10) neurons in a group

hh = bp.neurons.HH(size=(5, 4, 2))  # a column of (5, 4, 2) neurons in a group

Generally speaking, there are two types of arguments that can be set by users:

  • parameters: the model parameters, like gNa, which refers to the maximum conductance of the sodium channel in the brainpy.neurons.HH model.

  • variables: the model variables, like V, which refers to the membrane potential of a neuron model.

By default, model parameters are homogeneous, i.e., scalar values shared by all neurons in the group.

hh = bp.neurons.HH(5)  # there are five neurons in this group

hh.gNa
120.0

However, neuron models support heterogeneous parameters when performing computations in a neuron group. One can initialize heterogeneous parameters in several ways.

1. Array

Users can directly provide an array as the parameter.

hh = bp.neurons.HH(5, gNa=bm.random.uniform(110, 130, size=5))

hh.gNa
Array([127.70759, 125.13152, 112.63894, 127.90401, 129.41827], dtype=float32)

2. Initializer

BrainPy provides rich support for initializations. One can provide an initializer to the parameter to instruct the model to initialize heterogeneous parameters.

hh = bp.neurons.HH(5, ENa=bp.init.OneInit(50.))

hh.ENa
Array([50., 50., 50., 50., 50.], dtype=float32)

3. Callable function

You can also directly provide a callable function which receives a shape argument.

hh = bp.neurons.HH(5, ENa=lambda shape: bm.random.uniform(40, 60, shape))

hh.ENa
Array([57.047512, 53.037655, 59.74895 , 50.8206  , 44.256607], dtype=float32)

Here, let’s see how the heterogeneous parameters influence our model simulation.

# we create 3 neurons in a group. Each neuron has a unique "gNa"

model = bp.neurons.HH(3, gNa=bp.init.Uniform(min_val=100, max_val=140))
runner = bp.dyn.DSRunner(model, monitors=['V'], inputs=['input', 5.])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 2], show=True)
_images/f360243c839b5ec2e91853d33e1df876dcbab498164c8d22e98c9926824ea1eb.png

Similarly, the setting of the initial values of a variable can also be realized through the above three ways: Array, Initializer, and Callable function. For example,

hh = bp.neurons.HH(
   3,
   V_initializer=bp.init.Uniform(-80., -60.),  # Initializer
   m_initializer=lambda shape: bm.random.random(shape),  # function
   h_initializer=bm.random.random(3),  # Array
)
print('V: ', hh.V)
print('m: ', hh.m)
print('h: ', hh.h)
V:  Variable([-72.55422 , -61.628696, -71.0226  ], dtype=float32)
m:  Variable([0.7881355 , 0.40693295, 0.7243513 ], dtype=float32)
h:  Variable([0.47316658, 0.15884387, 0.6759169 ], dtype=float32)

Initializing a synapse model#

Initializing a synapse model requires providing its pre-synaptic group (pre), post-synaptic group (post), and the connection method between them (conn). Below is an example of creating an Exponential synapse model:

neu = bp.neurons.LIF(10)

# here we create a synaptic projection within a population
syn = bp.synapses.Exponential(pre=neu, post=neu, conn=bp.conn.All2All())

BrainPy’s built-in synapse models support heterogeneous synaptic weights and delay steps through Array, Initializer, and Callable function. For example,

syn = bp.synapses.Exponential(neu, neu, bp.conn.FixedProb(prob=0.1),
                              g_max=bp.init.Uniform(min_val=0.1, max_val=1.),
                              delay_step=lambda shape: bm.random.randint(10, 30, shape))
syn.g_max
Array([0.5255966 , 0.13250259, 0.49933627, 0.9400071 , 0.56140935,
       0.7105977 , 0.89582247, 0.63783807, 0.97180253, 0.2137514 ],      dtype=float32)
syn.delay_step
Array([10, 14, 22, 18, 22, 23, 25, 21, 28, 10], dtype=int32)

However, the built-in synapse models in BrainPy only support homogeneous synaptic parameters, like the time constant \(\tau\). Users can customize their synaptic models when they need heterogeneous synaptic parameters.

Similarly, the synaptic variables can be initialized heterogeneously by using Array, Initializer, and Callable functions.

Changing model parameters during simulation#

In BrainPy, all dynamically changing variables (no matter whether they are changed inside or outside a jitted function) should be marked as brainpy.math.Variable. BrainPy’s built-in models also support modifying model parameters during simulation.

For example, suppose you want to fix gNa during the first 100 ms of the simulation and then decrease its value in the following simulations. In this case, we can provide gNa as an instance of brainpy.math.Variable when initializing the model.

hh = bp.neurons.HH(5, gNa=bm.Variable(bm.asarray([120.])))

runner = bp.dyn.DSRunner(hh, monitors=['V'], inputs=['input', 5.])
# the first running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/654af7eca8619014760ae1ae8e053a88cb9505866a00cff505da5aed4a9be99a.png
# change the gNa first
hh.gNa[:] = 100.

# the second running
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/ec99cb4fdde081ad2a11574380689fa49ff1c84aa8f5d8d3859c6d696556185c.png

Examples of using built-in models#

Here we show users how to simulate a famous neuron model: the Morris-Lecar neuron model, a two-dimensional “reduced” excitation model applicable to systems having two non-inactivating voltage-sensitive conductances.

group = bp.neurons.MorrisLecar(1)

Then users can utilize various tools provided by BrainPy to easily simulate the Morris-Lecar neuron model. Here we are not going to dive into details, so please read the corresponding tutorials if you want to learn more.

runner = bp.dyn.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.))
runner.run(1000)

fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)
_images/3d9f3f4c577332d79981007405862971471997033a950869a0680ce0a9bbf1fa.png

Next, we give users an intuitive understanding of how to build a network composed of different neuron and synapse models. Users can simply initialize these models as below and pass them into brainpy.dyn.Network.

neu1 = bp.neurons.HH(1)
neu2 = bp.neurons.HH(1)
syn1 = bp.synapses.AMPA(neu1, neu2, bp.connect.All2All())
net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)

By selecting a proper runner, users can simulate the network efficiently and plot the simulation results.

runner = bp.dyn.DSRunner(net,
                         inputs=(neu1.input, 6.),
                         monitors=['pre.V', 'post.V', 'syn.g'])
runner.run(150.)

import matplotlib.pyplot as plt

fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
plt.legend()

fig.add_subplot(gs[1, 0])
plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
plt.legend()
plt.show()
_images/e4f2621509b02e1c2df8f18a0f63008ca8a2825db9c82516500591a623de0568.png

Building Conductance-based Neuron Models#

@Xiaoyu Chen

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.1'

There are basically two types of neuron models: conductance-based models and simplified models. In conductance-based models, a single neuron can be regarded as an electric circuit, where the membrane is a capacitor, ion channels are conductors, and ion gradients are batteries. The neuronal activity is captured by the currents flowing through those ion channels. Sometimes there is an external input to the neuron, which can also be included in the equivalent circuit (see the figure below, which shows potassium channels, sodium channels, and leaky channels).

On the other hand, simplified models do not care about the physiological features of neurons but mainly focus on how to reproduce the exact spike timing. Therefore, they are more abstract and may not be biologically interpretable.

BrainPy provides a large number of predefined neuron models, including conductance-based and simplified models, for ease of use. In this section, we will only talk about how to build conductance-based models from ion channels. Please refer to Customizing Your Neuron Models for more information.

Building an ion channel#

As we know, ion channels are crucial for conductance-based neuron models. So how do we model an ion channel? Let’s take the potassium channel as an example.

The diagram above shows how a potassium channel is converted into an electric circuit. From this circuit, we obtain the differential equation:

\[\begin{split} \begin{align} c_\mathrm{M} \frac{\mathrm{d}V_\mathrm{M}}{\mathrm{d}t} &= \frac{E_\mathrm{K} - V_\mathrm{M}}{R_\mathrm{K}} \\ &= g_\mathrm{K}(E_\mathrm{K} - V_\mathrm{M}), \end{align} \end{split}\]

in which \(c_\mathrm{M}\) is the membrane capacitance, \(V_\mathrm{M}\) is the membrane potential, \(E_\mathrm{K}\) is the equilibrium potential of potassium ions, and \(R_\mathrm{K}\) (\(g_\mathrm{K}\)) refers to the resistance (conductance) of the potassium channel. We define currents from inside to outside as the positive direction.

In the equation above, the conductance of the potassium channel \(g_\mathrm{K}\) is not a constant but changes with the membrane potential, which makes it a voltage-gated ion channel. If we want to build an ion channel model, we should figure out how the conductance of the ion channel changes with the membrane potential.

Fortunately, a lot of work has been done to formulate analytical expressions for this dependence. For example, the conductance of one typical potassium channel can be written as:

\[\begin{split} \begin{align} g_\mathrm{K} &= \bar{g}_\mathrm{K} n^4, \\ \frac{\mathrm{d}n}{\mathrm{d}t} &= \phi [\alpha_n(V)(1-n) - \beta_n(V)n], \end{align} \end{split}\]

in which \(\bar{g}_\mathrm{K}\) refers to the maximal conductance and \(n\), also named the gating variable, refers to the probability (proportion) of potassium channels to open. \(\phi\) is a parameter showing the effects of temperature. In the differential equation of \(n\), there are two parameters, \(\alpha_n(V)\) and \(\beta_n(V)\), that change with membrane potential:

\[\begin{split} \begin{align} \alpha_n(V) &= \frac{0.01(V+55)}{1 - \exp(-\frac{V+55}{10})}, \\ \beta_n(V) &= 0.125 \exp\left(-\frac{V+65}{80}\right). \end{align} \end{split}\]

Now we have learned the mathematical expression of the potassium channel. Next, we try to build this channel in BrainPy.

class IK(bp.Channel):
  def __init__(self, size, E=-77., g_max=36., phi=1., method='exp_auto'):
    super(IK, self).__init__(size)
    self.g_max = g_max
    self.E = E
    self.phi = phi

    self.n = bm.Variable(bm.zeros(size))  # variables should be packed with bm.Variable
    
    self.integral = bp.odeint(self.dn, method=method)

  def dn(self, n, t, V):
    alpha_n = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta_n = 0.125 * bm.exp(-(V + 65) / 80)
    return self.phi * (alpha_n * (1. - n) - beta_n * n)

  def update(self, tdi, V):
    self.n.value = self.integral(self.n, tdi.t, V, dt=tdi.dt)

  def current(self, V):
    return self.g_max * self.n ** 4 * (self.E - V)

Note that besides the initialization and update functions, another function named current(), which computes the current flowing through this channel, must be implemented. Then this potassium channel model can be used as a building block for assembling a conductance-based neuron model.

Building a conductance-based neuron model with ion channels#

Instead of building a conductance-based model from scratch, we can utilize ion channel models as building blocks to assemble a neuron model in a modular and convenient way. Now let’s try to construct a Hodgkin-Huxley (HH) model (jump to here for the complete mathematical expression of the HH model).

The HH neuron models the current flows of potassium, sodium, and leaky channels. Besides the potassium channel that we implemented, we can import the other channel models from brainpy.dyn.channels:

from brainpy.dyn.channels import INa_HH1952, IL
# actually the potassium channel we implemented can also be found in this package as 'IK_HH1952'

Then we wrap these three channels into a single neuron model:

class HH(bp.CondNeuGroup):
  def __init__(self, size):
    super(HH, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
    self.IK = IK(size, E=-77., g_max=36.)
    self.INa = INa_HH1952(size, E=50., g_max=120.)
    self.IL = IL(size, E=-54.39, g_max=0.03)

Here the HH class should inherit the superclass bp.CondNeuGroup, which will automatically integrate the current flows by calling the current() function of each channel model to compute the neuronal activity when running a simulation.

Surprisingly, the model construction is already finished! Users do not need to implement the update function of the neuron model, because CondNeuGroup has its own way to update variables (like the membrane potential V and the spiking state spike) implicitly.

Now let’s run a simulation of this HH model to examine the changes of the inner variables.

First of all, we instantiate a neuron group with 1 HH neuron:

neu = HH(1)

Then we wrap the neuron group into a dynamical-system runner DSRunner for running a simulation:

runner = bp.DSRunner(
    neu, 
    monitors=['V', 'IK.n', 'INa.p', 'INa.q'], 
    inputs=('input', 6.)  # constant external inputs of 6 mA to all neurons
)

Then we run the simulation and visualize the result:

runner.run(200)  # the running time is 200 ms

import matplotlib.pyplot as plt

plt.plot(runner.mon['ts'], runner.mon['V'])
plt.xlabel('t (ms)')
plt.ylabel('V (mV)')

plt.show()
_images/3679da1001f8062bd892d1b03c90c54fd2d025a98f07601623a951b147386203.png

We can also visualize the changes of the gating variables of sodium and potassium channels:

plt.figure(figsize=(6, 2))
plt.plot(runner.mon['ts'], runner.mon['IK.n'], label='n')
plt.plot(runner.mon['ts'], runner.mon['INa.p'], label='m')
plt.plot(runner.mon['ts'], runner.mon['INa.q'], label='h')
plt.xlabel('t (ms)')
plt.legend()

plt.show()
_images/7151168a1fdcd7e77b4824cc55a7a193dafe6737fd66137591e6037983348aa2.png

By combining different ion channels, we can easily obtain different types of conductance-based neuron models. To see all predefined channel models in BrainPy, please click here.

Building Synapse Models#

@Chaoming Wang

In BrainPy, synapse models can be created in several ways. In this section, we will talk about a structural building process with brainpy.dyn.TwoEndConn, which is used to create models with pre- and post-synaptic neuron groups. A synapse model is decomposed into several components in brainpy.dyn.TwoEndConn. In this way, building a synapse model follows a modular and composable programming interface. For more details on defining a general synapse model, please refer to brainpy.dyn.SynConn in Tutorial: Customizing Synapse Models.

import brainpy as bp

# bp.math.set_platform('cpu')

bp.__version__
'2.3.1'

Synaptic models with brainpy.dyn.TwoEndConn#

In general, brainpy.dyn.TwoEndConn is used to model synaptic models with the following form:

\[\begin{split} \frac{dg}{dt} = f_{\mathrm{dyn}}(g, t)& \,\, \to \, &\text{dynamics of the synaptic conductance} \\ g_{\mathrm{max}} = f_{\mathrm{LTP}}(g_{\mathrm{max}}, t) & \,\, \to \, &\text{long-term plasticity on synaptic weights }\\ g = f_{\mathrm{STP}}(g, t) & \,\, \to \, &\text{short-term plasticity on synaptic conductance}\\ I_{\mathrm{post}} = f_{\mathrm{out}}(g_{\mathrm{max}} * g, t) & \,\, \to \, &\text{synaptic output onto post-synaptic neurons}\\ \end{split}\]

where each synapse model has its dynamical conductance \(g\) and synaptic weight \(g_{\mathrm{max}}\), and

  • \(I_{\mathrm{post}}\) is the synaptic current onto the post-synaptic neurons,

  • \(f_{\mathrm{dyn}}\) is the function to compute synaptic dynamics,

  • \(f_{\mathrm{LTP}}\) is the function for computing synaptic long-term plasticity,

  • \(f_{\mathrm{STP}}\) is the function for computing synaptic short-term plasticity,

  • \(f_{\mathrm{out}}\) is the way to output synaptic currents on post-synaptic neurons.

Example 1: Exponential synapse model#

For an exponential synapse model,

\[\begin{split} \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}), \, (1) \\ I_{\mathrm{post}}(t) = g_{\mathrm{max}} * g * (V_{\mathrm{post}}(t)-E), \end{split}\]

where \(f_{\mathrm{dyn}}\) is defined by equation (1), \(f_{\mathrm{LTP}}\) and \(f_{\mathrm{STP}}\) are the identity function \(x = f(x)\), and \(f_{\mathrm{out}}\) is defined as a conductance-based form with \((V_{\mathrm{post}}(t)-E)\).

Therefore, in BrainPy, we can define this model as the following form:

# a pre-synaptic neuron which generates spikes at 1 ms, 11 ms, and 21 ms
pre = bp.neurons.SpikeTimeGroup(1, [1., 11., 21.], [0, 0, 0])

# a post-synaptic integrator which integrates synaptic inputs
post = bp.neurons.LeakyIntegrator(1)

# the synaptic model we want, whose output function is defined with `bp.synouts.COBA`
bp.synapses.Exponential(pre, post, bp.conn.All2All(),
                        output=bp.synouts.COBA(E=0.))
Exponential(name=Exponential0, mode=NonBatchingMode, 
            pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
            post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))

Similarly, an Exponential synapse model with the current-based output can be defined as:

bp.synapses.Exponential(pre, post, bp.conn.All2All(),
                        output=bp.synouts.CUBA())
Exponential(name=Exponential1, mode=NonBatchingMode, 
            pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
            post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))

Example 2: NMDA synapse model#

The NMDA synapse model is different from other models, since its current onto post-synaptic groups is regulated by magnesium. Specifically, the net NMDA receptor-mediated synaptic current is given by

\[ I_{\mathrm{post}} = g_{\mathrm{max}} \cdot g(t) \cdot (V(t)-E) \cdot g_{\infty} \]

where \(g_{\infty}\) represents the fraction of channels that are not blocked by magnesium.

\[ g_{\infty} = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} \]

Here \([{Mg}^{2+}]_{o}\) is the extracellular magnesium concentration, usually 1 mM.

In BrainPy, we provide this kind of magnesium-mediated synaptic output with brainpy.synouts.MgBlock. Therefore, an NMDA synapse can be defined with:

bp.synapses.NMDA(pre, post, bp.conn.All2All(),
                 output=bp.synouts.MgBlock(E=0., cc_Mg=1.2))
NMDA(name=NMDA0, mode=NonBatchingMode, 
     pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
     post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))

Example 3: Synapse models with short-term plasticity#

Short-term synaptic plasticity is ubiquitous in synapse dynamics. BrainPy provides brainpy.synplast.STD for short-term depression and brainpy.synplast.STP for general short-term plasticity. Short-term synaptic plasticity can be added to most synapse models in BrainPy. For instance, here we define the AMPA, GABA, and NMDA synapse models used in (Guoshi Li, et al., 2017) [1].

  • [1] Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. “Unified thalamic model generates multiple distinct oscillations with state-dependent entrainment by stimulation.” PLoS computational biology 13.10 (2017): e1005797.

# AMPA synapse model with STD

bp.synapses.AMPA(pre, post, bp.conn.FixedProb(0.3),
                 stp=bp.synplast.STD(tau=700, U=0.07),
                 output=bp.synouts.COBA(E=0.),
                 alpha=0.94, beta=0.18, g_max=6e-3)
AMPA(name=AMPA0, mode=NonBatchingMode, 
     pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
     post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
# GABA synapse model with STD

bp.synapses.GABAa(pre, post, bp.conn.FixedProb(0.3),
                  stp=bp.synplast.STD(tau=700, U=0.07),
                  output=bp.synouts.COBA(E=-80),
                  alpha=10.5, beta=0.166, g_max=3e-3)
GABAa(name=GABAa0, mode=NonBatchingMode, 
      pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
      post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))
# NMDA synapse model with STD

bp.synapses.NMDA(pre, post, bp.conn.FixedProb(0.3),
                 stp=bp.synplast.STD(tau=700, U=0.07),
                 output=bp.synouts.MgBlock(E=0., cc_Mg=1.2))
NMDA(name=NMDA1, mode=NonBatchingMode, 
     pre=SpikeTimeGroup(name=SpikeTimeGroup0, mode=NonBatchingMode, size=(1,)), 
     post=LeakyIntegrator(name=LeakyIntegrator0, mode=NonBatchingMode, size=(1,)))

Example 4: Synapse models with long-term plasticity#

TODO.

Building Network Models#

@Xiaoyu Chen @Chaoming Wang

In previous sections, it has been illustrated how to define neuron models by brainpy.dyn.NeuGroup and synapse models by brainpy.dyn.TwoEndConn. This section will introduce brainpy.dyn.Network, which is the base class used to build network models.

In essence, brainpy.dyn.Network is a container, whose function is to compose the individual elements. It is a subclass of a more general class: brainpy.dyn.Container.

Below, we take an excitation-inhibition (E-I) balanced network model as an example to illustrate how to compose the LIF neurons and Exponential synapses defined in previous tutorials into a network.

import brainpy as bp

# bp.math.set_platform('cpu')

bp.__version__
'2.3.1'

Excitation-Inhibition (E-I) Balanced Network#

The E-I balanced network was first proposed to explain the irregular firing patterns of cortical neurons and was later confirmed by experimental data. The network [1] we are going to implement consists of excitatory (E) neurons and inhibitory (I) neurons with a ratio of about 4:1. The biggest difference between excitatory and inhibitory neurons is the reversal potential: the reversal potential of inhibitory neurons is much lower than that of excitatory neurons. Besides, the membrane time constant of inhibitory neurons is longer than that of excitatory neurons, which indicates that inhibitory neurons have slower dynamics.

[1] Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98.

# BrainPy has some built-in canonical neuron and synapse models

LIF = bp.neurons.LIF
Exponential = bp.synapses.Exponential

Two ways to define network models#

There are several ways to define a Network model.

1. Defining a network as a class#

The first way to define a network model is as follows.

class EINet(bp.Network):
  def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
    super(EINet, self).__init__(**kwargs)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    E = LIF(num_exc, **pars, method=method)
    I = LIF(num_inh, **pars, method=method)
    E.V.value = bp.math.random.randn(num_exc) * 2 - 55.
    I.V.value = bp.math.random.randn(num_inh) * 2 - 55.

    # synapses
    w_e = 0.6  # excitatory synaptic weight
    w_i = 6.7  # inhibitory synaptic weight
    E_pars = dict(output=bp.synouts.COBA(E=0.), g_max=w_e, tau=5.)
    I_pars = dict(output=bp.synouts.COBA(E=-80.), g_max=w_i, tau=10.)
    
    # Neurons connect to each other randomly with a connection probability of 2%
    self.E2E = Exponential(E, E, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
    self.E2I = Exponential(E, I, bp.conn.FixedProb(prob=0.02), **E_pars, method=method)
    self.I2E = Exponential(I, E, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)
    self.I2I = Exponential(I, I, bp.conn.FixedProb(prob=0.02), **I_pars, method=method)

    self.E = E
    self.I = I

In an instance of brainpy.dyn.Network, all elements accessed through self. can be gathered automatically by the .nodes() function.

EINet(8, 2).nodes(level=-1).subset(bp.DynamicalSystem)
{'EINet0': EINet(),
 'Exponential0': Exponential(name=Exponential0, mode=NonBatchingMode, 
             pre=LIF(name=LIF0, mode=NonBatchingMode, size=(8,)), 
             post=LIF(name=LIF0, mode=NonBatchingMode, size=(8,))),
 'Exponential1': Exponential(name=Exponential1, mode=NonBatchingMode, 
             pre=LIF(name=LIF0, mode=NonBatchingMode, size=(8,)), 
             post=LIF(name=LIF1, mode=NonBatchingMode, size=(2,))),
 'Exponential2': Exponential(name=Exponential2, mode=NonBatchingMode, 
             pre=LIF(name=LIF1, mode=NonBatchingMode, size=(2,)), 
             post=LIF(name=LIF0, mode=NonBatchingMode, size=(8,))),
 'Exponential3': Exponential(name=Exponential3, mode=NonBatchingMode, 
             pre=LIF(name=LIF1, mode=NonBatchingMode, size=(2,)), 
             post=LIF(name=LIF1, mode=NonBatchingMode, size=(2,))),
 'LIF0': LIF(name=LIF0, mode=NonBatchingMode, size=(8,)),
 'LIF1': LIF(name=LIF1, mode=NonBatchingMode, size=(2,)),
 'COBA2': COBA,
 'COBA4': COBA,
 'COBA3': COBA,
 'COBA5': COBA}

Note that in the above EINet, we do not define the update() function. This is because any subclass of brainpy.dyn.Network has a default update function, which automatically gathers the elements defined in the network and sequentially runs the update function of each element.
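Conceptually, this default behavior can be sketched as follows (a rough illustration only, not BrainPy’s actual source code):

def conceptual_network_update(net, tdi):
  # gather the dynamical systems directly contained in the network
  children = net.nodes(level=1).subset(bp.DynamicalSystem)
  for node in children.values():
    if node is not net:   # .nodes() also returns the container itself
      node.update(tdi)    # run the update function of each element in turn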

Let’s try to simulate our defined EINet model.

net = EINet(3200, 800, method='exp_auto')  # "method": the numerical integrator method

runner = bp.DSRunner(net,
                     monitors=['E.spike', 'I.spike'],
                     inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
                         title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
                         title='Spikes of Inhibitory Neurons', show=True)
Used time None s
_images/16f64f02ffb29e33cdd4a85e9a27b6d12a407f779c7d4215246830e8518ca445.png _images/3087ac52fb2ec61d8cf39782ae5c5ff4f8a79bc841dc95978710998148095d51.png
2. Instantiating a network directly#

Another way to instantiate a network model is to directly pass the elements into the constructor of brainpy.Network. It receives *args and **kwargs arguments.

# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
E = LIF(3200, **pars)
I = LIF(800, **pars)
E.V.value = bp.math.random.randn(E.num) * 2 - 55.
I.V.value = bp.math.random.randn(I.num) * 2 - 55.

# synapses
E_pars = dict(output=bp.synouts.COBA(E=0.), g_max=0.6, tau=5.)
I_pars = dict(output=bp.synouts.COBA(E=-80.), g_max=6.7, tau=10.)
E2E = Exponential(E, E, bp.conn.FixedProb(prob=0.02), **E_pars)
E2I = Exponential(E, I, bp.conn.FixedProb(prob=0.02), **E_pars)
I2E = Exponential(I, E, bp.conn.FixedProb(prob=0.02), **I_pars)
I2I = Exponential(I, I, bp.conn.FixedProb(prob=0.02), **I_pars)


# Network
net2 = bp.Network(E2E, E2I, I2E, I2I, exc_group=E, inh_group=I)

All elements passed as **kwargs arguments can be accessed by the provided keys. This will affect the following dynamics simulation.

net2.exc_group
LIF(name=LIF4, mode=NonBatchingMode, size=(3200,))
net2.inh_group
LIF(name=LIF5, mode=NonBatchingMode, size=(800,))

After construction, the simulation goes the same way:

runner = bp.DSRunner(net2,
                     monitors=['exc_group.spike', 'inh_group.spike'],
                     inputs=[('exc_group.input', 20.), ('inh_group.input', 20.)])
t = runner.run(100.)
print(f'Used time {t} s')

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['exc_group.spike'],
                         title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['inh_group.spike'],
                         title='Spikes of Inhibitory Neurons', show=True)
Used time None s
_images/c934a0427b1600456c0a3801446fa5bc2efe05ff79356597cc48ed05ae127ef5.png _images/eaffe89ef815429bda45e0bd3eedbd44462a834f29413c65f8e0a6c19f1a3a9c.png

Customizing update function#

If you want to control the updating logic in a network, you can overwrite the update function update(tdi) and customize it yourself.

For the above E/I balanced network model, we can define its update function as:

class EINetV2(bp.Network):
  def __init__(self, num_exc, num_inh, method='exp_auto', **kwargs):
    super(EINetV2, self).__init__(**kwargs)

    # neurons
    self.N = LIF(num_exc + num_inh,
                 V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                 method=method, V_initializer=bp.init.Normal(-55., 2.))

    # synapses
    self.Esyn = bp.synapses.Exponential(self.N[:num_exc], self.N,
                                        bp.conn.FixedProb(prob=0.02),
                                        output=bp.synouts.COBA(E=0.),
                                        g_max=0.6, tau=5.,
                                        method=method)
    self.Isyn = bp.synapses.Exponential(self.N[num_exc:], self.N,
                                        bp.conn.FixedProb(prob=0.02),
                                        output=bp.synouts.COBA(E=-80.),
                                        g_max=6.7, tau=10.,
                                        method=method)

  def update(self, tdi):
    self.Esyn(tdi)
    self.Isyn(tdi)
    self.N(tdi)
    self.update_local_delays()  # IMPORTANT

In the above, we define one population and create E (excitatory) and I (inhibitory) projections within this population. Then, we first update the synapse models by calling self.Esyn(tdi) and self.Isyn(tdi). This ensures that all synaptic inputs are gathered onto the neuron models before we update the neurons. After updating the synapses, we update the state of the neurons by calling self.N(tdi). Finally, it is worth noting that we need to update all delays used in this network through self.update_local_delays(). This is because the delay variables rely on the neuronal states: once the neuronal states have been updated, the delays must be refreshed according to these new values.

net = EINetV2(3200, 800)
runner = bp.DSRunner(net, monitors={'spikes': net.N.spike}, inputs=[(net.N.input, 20.)])
runner.run(100.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)
_images/4c4e2608b90f1b28ad9047ccbbf53c3534c10a80360cf0f553b07b8c8ddb82f3.png

Customizing Your Neuron Models#

@Xiaoyu Chen @Chaoming Wang

The previous section shows all available models that users can utilize by simply instantiating the abstract model. In the following sections, we will dive into the details of how to build a neuron model with brainpy.dyn.NeuGroup. Neurons are the most basic components in neural dynamics simulation. In BrainPy, brainpy.dyn.NeuGroup is used for neuron modeling.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bp.__version__
'2.3.1'

brainpy.dyn.NeuGroup#

Generally, any neuron model can evolve continuously or discontinuously. Discontinuous evolution may be triggered by events, such as the reset of membrane potential. Moreover, it is common in a neural system that a dynamical system has different states, such as the excitable or refractory state in a leaky integrate-and-fire (LIF) model. In this section, we will use two examples to illustrate how to capture these complexities in neuron modeling.

Defining a neuron model in BrainPy is simple. You just need to inherit from brainpy.dyn.NeuGroup, and satisfy the following two requirements:

  • Providing the size of the neural group in the constructor when initializing a new neural group class. size can be an integer referring to the number of neurons, or a tuple/list of integers referring to the geometry of the neural group in different dimensions. According to the provided group size, NeuGroup will automatically calculate the total number num of neurons in this group.

  • Creating an update(tdi) function. The update function provides the rule for how the neuron states evolve from the current time \(\mathrm{tdi.t}\) to the next time \(\mathrm{tdi.t + tdi.dt}\). A minimal sketch satisfying these two requirements is given right after this list.
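As such a minimal sketch, consider the following toy rate unit (an illustrative example, not a built-in BrainPy model; the dynamics \(\tau \dot{r} = -r + I\) is chosen arbitrarily):

class SimpleRateUnit(bp.NeuGroup):
  def __init__(self, size, tau=10.):
    super(SimpleRateUnit, self).__init__(size=size)
    self.tau = tau
    # variables should be packed with bm.Variable
    self.r = bm.Variable(bm.zeros(self.num))
    self.input = bm.Variable(bm.zeros(self.num))
    # dr/dt = (-r + input) / tau
    self.integral = bp.odeint(lambda r, t, I: (-r + I) / self.tau)

  def update(self, tdi):
    self.r.value = self.integral(self.r, tdi.t, self.input, dt=tdi.dt)
    self.input[:] = 0.  # reset the external input after each step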

In the following part, a Hodgkin-Huxley (HH) model is used as an example for illustration.

Hodgkin–Huxley Model#

The Hodgkin-Huxley (HH) model is a continuous-time dynamical system. It is one of the most successful mathematical models of a complex biological process that has ever been formulated. Changes of the membrane potential influence the conductance of different channels, elaborately modeling the neural activities in biological systems. Mathematically, the model is given by:

\[\begin{split} \begin{aligned} C_m \frac {dV} {dt} &= -(\bar{g}_{Na} m^3 h (V -E_{Na}) + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) \quad\quad(1) \\ \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x x, \quad x\in {\rm{\{m, h, n\}}} \quad\quad(2) \\ &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} \quad\quad(3) \\ &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) \quad\quad(4) \\ &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) \quad\quad(5) \\ &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} \quad\quad(6) \\ &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} \quad\quad(7) \\ &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) \quad\quad(8) \\ \end{aligned} \end{split}\]

where \(V\) is the membrane potential, \(C_m\) is the membrane capacitance per unit area, \(E_K\) and \(E_{Na}\) are the potassium and sodium reversal potentials, respectively, \(E_{leak}\) is the leak reversal potential, \(\bar{g}_K\) and \(\bar{g}_{Na}\) are the maximal potassium and sodium conductances per unit area, respectively, and \(g_{leak}\) is the leak conductance per unit area. Because the potassium and sodium channels are voltage-sensitive, according to biological experiments, \(m\), \(n\), and \(h\) are used to simulate the activation of the channels. Specifically, \(n\) measures the activation of potassium channels, and \(m\) and \(h\) measure the activation and inactivation of sodium channels, respectively. \(\alpha_{x}\) and \(\beta_{x}\) are rate constants for ion channel \(x\) and depend exclusively on the membrane potential.

To implement the HH model, variables should be specified. According to the above equations, the following five state variables change with respect to time:

  • V: the membrane potential

  • m: the activation of sodium channels

  • h: the inactivation of sodium channels

  • n: the activation of potassium channels

  • input: the external/synaptic input

Besides, the spiking state and the last spiking time can also be recorded for statistical analysis:

  • spike: whether a spike is produced

  • t_last_spike: the last spiking time

Based on these state variables, the HH model can be implemented as below.

class HH(bp.NeuGroup):
  def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
               V_th=20., C=1.0, **kwargs):
    # providing the group "size" information
    super(HH, self).__init__(size=size, **kwargs)

    # initialize parameters
    self.ENa = ENa
    self.EK = EK
    self.EL = EL
    self.gNa = gNa
    self.gK = gK
    self.gL = gL
    self.C = C
    self.V_th = V_th

    # initialize variables
    self.V = bm.Variable(bm.random.randn(self.num) - 70.)
    self.m = bm.Variable(0.5 * bm.ones(self.num))
    self.h = bm.Variable(0.6 * bm.ones(self.num))
    self.n = bm.Variable(0.32 * bm.ones(self.num))
    self.input = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

    # integral functions
    self.int_V = bp.odeint(f=self.dV, method='exp_auto')
    self.int_m = bp.odeint(f=self.dm, method='exp_auto')
    self.int_h = bp.odeint(f=self.dh, method='exp_auto')
    self.int_n = bp.odeint(f=self.dn, method='exp_auto')

  def dV(self, V, t, m, h, n, Iext):
    I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    I_K = (self.gK * n ** 4.0) * (V - self.EK)
    I_leak = self.gL * (V - self.EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
    return dVdt

  def dm(self, m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    return dmdt
  
  def dh(self, h, t, V):
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    return dhdt

  def dn(self, n, t, V):
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    return dndt

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    # compute V, m, h, n
    V = self.int_V(self.V, _t, self.m, self.h, self.n, self.input, dt=_dt)
    self.h.value = self.int_h(self.h, _t, self.V, dt=_dt)
    self.m.value = self.int_m(self.m, _t, self.V, dt=_dt)
    self.n.value = self.int_n(self.n, _t, self.V, dt=_dt)

    # update the spiking state and the last spiking time
    self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
    self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)

    # update V
    self.V.value = V

    # reset the external input
    self.input[:] = 0.

When defining the HH model, equation (1) is implemented with brainpy.odeint as an ODEIntegrator. The details are covered in the Numerical Solvers for ODEs tutorial.

The variables that will be updated during dynamics simulation should be packed as brainpy.math.Variable so that they can be processed by JIT compilers to accelerate simulation.
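A quick stand-alone illustration of this point (the values here are arbitrary):

v = bm.Variable(bm.zeros(3))
v[:] = 1.         # in-place updates of a Variable are tracked during JIT compilation
v.value = v + 1.  # assigning through .value replaces the whole content
# v now holds [2., 2., 2.]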

In the following part, a leaky integrate-and-fire (LIF) model is introduced as another example for illustration.

Leaky Integrate-and-Fire Model#

The LIF model is a classical neuron model which contains a continuous process and a discontinuous spike-reset operation. Formally, it is given by:

\[\begin{split} \begin{aligned} \tau_m \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \quad\quad (1) \\ \text{if} \, V(t) \gt V_{th}, V(t) =V_{rest} \, \text{after} \, \tau_{ref} \, \text{ms} \quad\quad (2) \end{aligned} \end{split}\]

where \(V\) is the membrane potential, \(V_{rest}\) is the rest membrane potential, \(V_{th}\) is the spike threshold, \(\tau_m\) is the time constant, \(\tau_{ref}\) is the refractory time period, and \(I\) is the time-variant synaptic inputs.

The above two equations model the continuous change and the spiking of neurons, respectively. Moreover, the neuron has multiple states: the subthreshold state, and the spiking or refractory state. The membrane potential \(V\) is integrated according to equation (1) when it is below \(V_{th}\). Once \(V\) reaches the threshold \(V_{th}\), according to equation (2), \(V\) is reset to \(V_{rest}\), and the neuron enters the refractory period, during which the membrane potential \(V\) remains constant for the following \(\tau_{ref}\) ms.

The neuronal variables, like the membrane potential and external input, can be captured by the following two variables:

  • V: the membrane potential

  • input: the external/synaptic input

In order to define the different states of a LIF neuron, we define additional variables:

  • spike: whether a spike is produced

  • refractory: whether the neuron is in the refractory period

  • t_last_spike: the last spiking time

Based on these state variables, the LIF model can be implemented as below.

class LIF(bp.NeuGroup):
  def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., R=1., tau=10., t_ref=5., **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)

    # initialize parameters
    self.V_rest = V_rest
    self.V_reset = V_reset
    self.V_th = V_th
    self.R = R
    self.tau = tau
    self.t_ref = t_ref

    # initialize variables
    self.V = bm.Variable(bm.random.randn(self.num) + V_reset)
    self.input = bm.Variable(bm.zeros(self.num))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # integral function
    self.integral = bp.odeint(f=self.derivative, method='exp_auto')

  def derivative(self, V, t, Iext):
    dvdt = (-V + self.V_rest + self.R * Iext) / self.tau
    return dvdt

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    # Whether the neurons are in the refractory period
    refractory = (_t - self.t_last_spike) <= self.t_ref
    
    # compute the membrane potential
    V = self.integral(self.V, _t, self.input, dt=_dt)
    
    # computed membrane potential is valid only when the neuron is not in the refractory period 
    V = bm.where(refractory, self.V, V)
    
    # update the spiking state
    spike = self.V_th <= V
    self.spike.value = spike
    
    # update the last spiking time
    self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
    
    # update the membrane potential and reset spiked neurons
    self.V.value = bm.where(spike, self.V_reset, V)
    
    # update the refractory state
    self.refractory.value = bm.logical_or(refractory, spike)
    
    # reset the external input
    self.input[:] = 0.

In the above, the discontinuous reset is implemented with the brainpy.math.where operation.
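A tiny stand-alone illustration of this reset logic (threshold and reset values are made up for the demonstration):

V = bm.asarray([-2., 21., 5.])
spike = V >= 20.             # which elements crossed the threshold
V = bm.where(spike, -5., V)  # reset the spiked entries to V_reset = -5.
# V is now [-2., -5., 5.]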

Instantiation and running#

Here, let’s try to instantiate an HH neuron group:

neu = HH(10)

in which a neural group containing 10 HH neurons is generated.

The details of model simulation will be expanded in the Runners section. In brief, running any dynamical system instance should be accomplished with a runner, such as brainpy.DSRunner or brainpy.ReportRunner. The variables to be monitored and the input currents to be applied in the simulation can be provided when initializing the runner. The details are accessible in Monitors and Inputs.

runner = bp.DSRunner(
    neu, 
    monitors=['V'], 
    inputs=('input', 22.)  # constant external inputs of 22 mA to all neurons
)

Then the simulation can be performed with a given time period, and the simulation result can be visualized:

runner.run(200)  # the running time is 200 ms

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/c2c3aac34c09b9c27af650eed80d7bec04f396646b77542a498071c226294254.png

A LIF neural group can be instantiated and applied in simulation in a similar way:

group = LIF(10)

runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 22.), jit=True)
runner.run(200)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
_images/8c6beddf759e343da72d4c7b3bcb192d7f17ff6d0b42152623b384848de8860e.png

Customizing Your Synapse Models#

@Chaoming Wang @Xiaoyu Chen

Synaptic computation is the core of brain dynamics programming, because in a real project most of the simulation time is spent on the computation of synapses. In order to achieve efficient synaptic computation, BrainPy provides many useful supports. Here, we are going to explore the details of these supports.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')

bp.__version__
'2.3.1'

Synapse Models in Math#

Before we talk about the implementation of synapses in BrainPy, it is better to understand the targets (synapse models) we are going to implement. For illustration purposes, we are going to implement two synapse models: the exponential synapse model and the AMPA synapse model.

1. The exponential synapse model#

The exponential synapse model assumes that once a pre-synaptic neuron generates a spike, the synaptic state rises instantaneously and then decays with a certain time constant \(\tau_{decay}\). Its dynamics is given by:

\[ \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-D-t^{k}) \]

where \(g\) is the synaptic state, \(t^{k}\) is the spike time of the pre-synaptic neuron, and \(D\) is the synaptic delay.

Afterward, the current output onto the post-synaptic neuron is given in the conductance-based form:

\[ I_{syn}(t) = g_{max} g \left( V-E \right) \]

where \(E\) is the reversal potential of the synapse, \(V\) is the post-synaptic membrane potential, and \(g_{max}\) is the maximum synaptic conductance.

2. The AMPA synapse model#

A classical model of the AMPA synapse uses a Markov process to model ion channel switching. Here \(g\) represents the probability of the channel being open, \(1-g\) represents the probability of the channel being closed, and \(\alpha\) and \(\beta\) are the transition probabilities. Specifically, its formula is given by

\[ \frac{dg}{dt} =\alpha[T](1-g)-\beta g \]

where \(\alpha [T]\) denotes the transition probability from state \((1-g)\) to state \((g)\), and \(\beta\) represents the transition probability in the other direction. \(\alpha\) is the binding constant and \(\beta\) is the unbinding constant. \([T]\) is the neurotransmitter concentration, which has a duration of 0.5 ms.

Moreover, the post-synaptic current on the post-synaptic neuron is formulated as

\[I_{syn} = g_{max} g (V-E)\]

where \(g_{max}\) is the maximum conductance, and \(E\) is the reverse potential.

Synapse Models in Silicon#

The implementation of synapse models is accomplished through the brainpy.dyn.TwoEndConn interface. In this section, we talk about what support is provided for implementing synapse models in silicon.

1. brainpy.dyn.TwoEndConn#

In BrainPy, brainpy.dyn.SynConn is used to model two-end synaptic computations.

To define a synapse model, two requirements should be satisfied:

1. Constructor function __init__(), in which three key arguments are needed.

  • pre: the pre-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.

  • post: the post-synaptic neural group. It should be an instance of brainpy.dyn.NeuGroup.

  • conn (optional): the connection type between these two groups. BrainPy provides abundant connection types that are described in detail in Synaptic Connections.

2. An update function update(tdi), which describes the updating rule from the current time \(\mathrm{tdi.t}\) to the next time \(\mathrm{tdi.t + tdi.dt}\). A minimal skeleton that satisfies these two requirements is sketched below.
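The following is a minimal, do-nothing skeleton for illustration only (it computes no synaptic currents):

class MinimalSyn(bp.SynConn):
  def __init__(self, pre, post, conn):
    super(MinimalSyn, self).__init__(pre=pre, post=post, conn=conn)

  def update(self, tdi):
    # evolve the synaptic state from tdi.t to tdi.t + tdi.dt here
    pass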

2. Variable delays#

As seen in the above two synapse models, synaptic computations are usually involved with variable delays. A delay time (typically 0.3–0.5 ms) is usually required for a neurotransmitter to be released from a pre-synaptic membrane, diffuse across the synaptic cleft, and bind to a receptor site on the post-synaptic membrane.

BrainPy provides several kinds of delay variables for users, including:

  • brainpy.math.LengthDelay: a delay variable which defines a constant number of steps for the delay.

  • brainpy.math.TimeDelay: a delay variable which defines a constant time length for the delay.

Assume we need a delay variable with a 1 ms delay. If the numerical integration precision dt is 0.1 ms, then we can create a brainpy.math.LengthDelay with 10 delay time steps.

target_data_to_delay = bm.Variable(bm.zeros(10))

example_delay = bm.LengthDelay(target_data_to_delay,
                               delay_len=10)  # delay 10 steps
example_delay(5)  # call the delay data at 5 delay step
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(10)  # call the delay data at 10 delay step
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

Alternatively, we can create an instance of brainpy.math.TimeDelay, which uses the time t as the index to retrieve the delayed data.

t0 = 0.
example_delay = bm.TimeDelay(target_data_to_delay,
                             delay_len=1.0, t0=t0)  # delay 1.0 ms
example_delay(t0 - 1.0)  # the delay data at t-1. ms
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
example_delay(t0 - 0.5)  # the delay data at t-0.5 ms
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
3. Synaptic connections#

Synaptic computations usually need to create connections between groups. BrainPy provides rich support for constructing synaptic connections. Simply speaking, brainpy.conn.Connector can create various data structures you want through the require() function. Take the random connection brainpy.conn.FixedProb, which will be used below, as an example:

example_conn = bp.conn.FixedProb(0.2)(pre_size=(5,), post_size=(8, ))

we can require the connection matrix (which has the shape of (num_pre, num_post)):

example_conn.require('conn_mat')
Array([[False,  True, False, False, False, False,  True, False],
       [False,  True,  True, False, False,  True, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False,  True, False],
       [False, False, False,  True,  True, False, False, False]],      dtype=bool)

we can also require the connected indices of pre-synaptic neurons (pre_ids) and post-synaptic neurons (post_ids):

example_conn.require('pre_ids', 'post_ids')
(Array([0, 1, 2, 3, 4], dtype=int32), Array([3, 0, 2, 3, 7], dtype=int32))

Or, we can require the connection structure pre2post, which stores the information of how each pre-synaptic neuron connects to post-synaptic neurons:

example_conn.require('pre2post')
(Array([7, 6, 1, 4, 1], dtype=int32), Array([0, 1, 2, 3, 4, 5], dtype=int32))

Warning

Every require() call will establish a new connection pattern and return the data structures users have required. Therefore, any two require() calls may return different connection patterns, just like the examples above. Please keep in mind to require all the data structures at once if you want a consistent connection pattern.
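For example, to obtain a consistent pattern, request every structure you need in a single call (this assumes, as in the pre_ids/post_ids example above, that require() accepts several structure names at once):

conn = bp.conn.FixedProb(0.2)(pre_size=(5,), post_size=(8,))
conn_mat, pre_ids, post_ids = conn.require('conn_mat', 'pre_ids', 'post_ids')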

For more details of the connection structures, please see the tutorial of Synaptic Connections.

Achieving efficient synaptic computation is difficult#

Synaptic computations usually need to transform data of the pre-synaptic dimension into data of the post-synaptic dimension, or into data with the shape of the synapse number. There does not exist a universal computation method that is efficient in all cases. Usually, we need different ways for different connection situations to achieve efficient synaptic computation. In the next two sections, we will talk about how to define efficient synaptic models when your connections are sparse or dense.

Before we start, we need to define some useful helper functions to define and show synapse models. Then, we will highlight the key differences of model definition when using different synaptic connections.

# Basic Model to define the exponential synapse model. This class 
# defines the basic parameters, variables, and integral functions. 


class BaseExpSyn(bp.SynConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., method='exp_auto'):
    super(BaseExpSyn, self).__init__(pre=pre, post=post, conn=conn)

    # check whether the pre group has the needed attribute: "spike"
    self.check_pre_attrs('spike')

    # check whether the post group has the needed attribute: "input" and "V"
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max

    # use "LengthDelay" to store the spikes of the pre-synaptic neuron group
    self.delay_step = int(delay/bm.get_dt())
    self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)

    # integral function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
# Basic Model to define the AMPA synapse model. This class 
# defines the basic parameters, variables, and integral functions. 


class BaseAMPASyn(bp.SynConn):
  def __init__(self, pre, post, conn, delay=0., g_max=0.42, E=0., alpha=0.98,
               beta=0.18, T=0.5, T_duration=0.5, method='exp_auto'):
    super(BaseAMPASyn, self).__init__(pre=pre, post=post, conn=conn)

    # check whether the pre group has the needed attribute: "spike"
    self.check_pre_attrs('spike')

    # check whether the post group has the needed attribute: "input" and "V"
    self.check_post_attrs('input', 'V')

    # parameters
    self.delay = delay
    self.g_max = g_max
    self.E = E
    self.alpha = alpha
    self.beta = beta
    self.T = T
    self.T_duration = T_duration

    # use "LengthDelay" to store the spikes of the pre-synaptic neuron group
    self.delay_step = int(delay/bm.get_dt())
    self.pre_spike = bm.LengthDelay(pre.spike, self.delay_step)

    # store the arrival time of the pre-synaptic spikes
    self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7)

    # integral function
    self.integral = bp.odeint(self.derivative, method=method)

  def derivative(self, g, t, TT):
    dg = self.alpha * TT * (1 - g) - self.beta * g
    return dg
# for more details of how to run a simulation please see the tutorials in "Dynamics Simulation"

def show_syn_model(model):
  pre = bp.neurons.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
  post = bp.neurons.LIF(1, V_rest=-60., V_reset=-60., V_th=-40.)
  syn = model(pre, post, conn=bp.conn.One2One())
  net = bp.Network(pre=pre, post=post, syn=syn)

  runner = bp.DSRunner(net,
                       monitors=['pre.V', 'post.V', 'syn.g'],
                       inputs=['pre.input', 22.])
  runner.run(100.)

  fig, gs = bp.visualize.get_figure(1, 2, 3, 4)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['syn.g'], legend='syn.g')
  fig.add_subplot(gs[0, 1])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['pre.V'], legend='pre.V')
  bp.visualize.line_plot(runner.mon.ts, runner.mon['post.V'], legend='post.V', show=True)

Computation with Dense Connections#

Matrix-based synaptic computation is straightforward. Especially when your models are densely connected, using a matrix is highly efficient.

conn_mat#

Assume two neuron groups are connected through a fixed probability of 0.7.

conn = bp.conn.FixedProb(0.7)(pre_size=6, post_size=8)

Then you can create the connection matrix through conn.require("conn_mat"):

conn.require('conn_mat')
Array([[ True, False,  True,  True,  True,  True,  True, False],
       [False,  True, False,  True, False,  True,  True,  True],
       [ True,  True,  True, False, False,  True,  True, False],
       [ True,  True,  True, False, False,  True,  True,  True],
       [ True,  True,  True,  True, False,  True,  True,  True],
       [ True,  True,  True, False,  True,  True,  True,  True]],      dtype=bool)

conn_mat has the shape of (num_pre, num_post). Therefore, transforming data with the pre-synaptic dimension into data of the post-synaptic dimension is very easy. You just need to perform a matrix multiplication: brainpy.math.dot(pre_values, conn_mat) (\(\mathbb{R}^\mathrm{num\_pre} @ \mathbb{R}^\mathrm{(num\_pre, num\_post)} \to \mathbb{R}^\mathrm{num\_post}\)).
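A tiny sketch of this transformation with made-up values:

pre_values = bm.asarray([1., 0., 2.])   # shape (num_pre,)
mat = bm.asarray([[1., 0.],
                  [1., 1.],
                  [0., 1.]])            # shape (num_pre, num_post)
post_values = bm.dot(pre_values, mat)   # shape (num_post,), here [1., 2.]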

With the synaptic connection of conn_mat above, we can define the exponential synapse model as follows. It is worth noting that in exponential synapses, the states output onto the same post-synaptic neuron can be superposed. This means we can declare the synapse variable with the shape of the post-synaptic group, rather than the total number of synapses.

class ExpConnMat(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpConnMat, self).__init__(*args, **kwargs)

    # connection matrix
    self.conn_mat = self.conn.require('conn_mat').astype(float)

    # synapse gating variable
    # -------
    # NOTE: Here the synapse number is the same with 
    #       the post-synaptic neuron number. This is 
    #       different from the AMPA synapse.
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    # pull the delayed pre spikes for computation
    delayed_spike = self.pre_spike(self.delay_step)
    # push the latest pre spikes into the bottom
    self.pre_spike.update(self.pre.spike)
    # integrate the synapse state
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # update synapse states according to the pre spikes
    post_sps = bm.dot(delayed_spike.astype(float), self.conn_mat)
    self.g += post_sps
    # get the post-synaptic current
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpConnMat)
_images/161578dd8e909cdc1b15fa7dfc1ceb0c4df86bbc7c0980454225ee200cf4ae40.png

We can also use conn_mat to define an AMPA synapse model. Note that here the shape of the synapse variable \(g\) is (num_pre, num_post), rather than self.post.num as in the above exponential synapse model. This is because the synaptic states of the AMPA model cannot be superposed.

class AMPAConnMat(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPAConnMat, self).__init__(*args, **kwargs)

    # connection matrix
    self.conn_mat = self.conn.require('conn_mat').astype(float)

    # synapse gating variable
    # -------
    # NOTE: Here the synapse shape is (num_pre, num_post),
    #       in contrast to the ExpConnMat
    self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    # pull the delayed pre spikes for computation
    delayed_spike = self.pre_spike(self.delay_step)
    # push the latest pre spikes into the bottom
    self.pre_spike.update(self.pre.spike)
    # get the time of pre spikes arrive at the post synapse
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    # get the neurotransmitter concentration at the current time
    TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
    # integrate the synapse state
    TT = TT.reshape((-1, 1)) * self.conn_mat  # NOTE: only keep the concentrations
                                              #       on the existing (valid) connections
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    # get the post-synaptic current
    g_post = self.g.sum(axis=0)
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAConnMat)
_images/cc727f9a8f4924c4f3477fde7b0ed5561c05515b4bd23e45f3a78bf307021c51.png
Special connections#

Sometimes, we can define some synapse models with special connection types, such as all-to-all connection, or one-to-one connection. For these special situations, even the connection information can be ignored, i.e., we do not need conn_mat or other structures anymore.

Assume the pre-synaptic group connects to the post-synaptic group in an all-to-all fashion. Then the exponential synapse model can be defined as,

class ExpAll2All(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpAll2All, self).__init__(*args, **kwargs)

    # synapse gating variable
    # -------
    # The synapse variable has the shape of the post-synaptic group
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.g.value = self.integral(self.g, _t, dt=_dt)
    self.g += delayed_spike.sum()  # NOTE: HERE is the difference
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpAll2All)
_images/161578dd8e909cdc1b15fa7dfc1ceb0c4df86bbc7c0980454225ee200cf4ae40.png

Similarly, the AMPA synapse model can be defined as

class AMPAAll2All(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPAAll2All, self).__init__(*args, **kwargs)

    # synapse gating variable
    # -------
    # The synapse variable has the shape of the post-synaptic group
    self.g = bm.Variable(bm.zeros((self.pre.num, self.post.num)))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T
    TT = TT.reshape((-1, 1))  # NOTE: here is the difference
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    g_post = self.g.sum(axis=0) # NOTE: here is also different
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPAAll2All)
_images/cc727f9a8f4924c4f3477fde7b0ed5561c05515b4bd23e45f3a78bf307021c51.png

Actually, synaptic computation with these special connections can be very efficient! For a concrete example, please see the decision-making spiking model in BrainPy-Examples. That implementation achieves at least a four-fold acceleration compared with implementations in other frameworks.

Computation with Sparse Connections#

However, in real neural systems, neurons are in essence sparsely connected.

Imagine you want to connect 10,000 pre-synaptic neurons to 10,000 post-synaptic neurons with a 10% random connection probability. Using a matrix, you need \(10^8\) floats to save the synaptic state, and at each update step you need to compute over all \(10^8\) floats, whereas the number of synapses that actually exist is only \(10^7\). There is thus a huge waste of memory and computing resources. Moreover, at a given time \(\mathrm{\_t}\), the number of pre-synaptic neurons in the spiking state is small on average. This means we make many useless computations when defining synaptic computations with matrix-based connections (zeros dotted with the connection matrix yield zeros).

Therefore, we need new ways to define synapse models. Specifically, we use vectors to store the connected neuron indices, like the pre_ids and post_ids (see Synaptic Connections).

In what follows, we assume you have learned the synaptic connection types detailed in the tutorial of Synaptic Connections.

The pre2post operator#

A notable difference between brain dynamics models and deep learning models is that the former are sparse and event-driven. In order to support this significantly different kind of computation, BrainPy has built many useful operators. In this section, we talk about a set of operators needed in pre2post computations.

Note that, as said before, the exponential synapse model can perform its computation at the dimension of the post-synaptic group. Therefore, we can directly transform the pre-synaptic data into data of the post-synaptic shape. brainpy.math.pre2post_event_sum(events, pre2post, post_num, values) can satisfy this requirement. This operator needs the synaptic structure pre2post (a tuple containing the post_ids and idnptr of the pre-synaptic neurons).

If values is a scalar, pre2post_event_sum is equivalent to:

post_val = np.zeros(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
  if events[i]:
    for j in range(idnptr[i], idnptr[i+1]):
      post_val[post_ids[j]] += values

If values is a vector, pre2post_event_sum is equivalent to:

post_val = np.zeros(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
  if events[i]:
    for j in range(idnptr[i], idnptr[i+1]):
      post_val[post_ids[j]] += values[j]
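As a small, hedged demonstration of the operator itself (the indices below are made up, and the call assumes the signature stated above works on your installation):

events = bm.asarray([True, False, True])  # spikes of 3 pre-synaptic neurons
post_ids = bm.asarray([0, 1, 1, 2])       # flattened post-synaptic targets
idnptr = bm.asarray([0, 2, 3, 4])         # pre i targets post_ids[idnptr[i]:idnptr[i+1]]
out = bm.pre2post_event_sum(events, (post_ids, idnptr), 3, 1.)
# pre 0 -> posts 0 and 1, pre 2 -> post 2, so out is expected to be [1., 1., 1.]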

With this operator, the exponential synapse model can be defined as:

class ExpSparse(BaseExpSyn):
  def __init__(self, *args, **kwargs):
    super(ExpSparse, self).__init__(*args, **kwargs)

    # connections
    self.pre2post = self.conn.require('pre2post')

    # synapse variable
    self.g = bm.Variable(bm.zeros(self.post.num))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    self.g.value = self.integral(self.g, _t, dt=_dt)
    # NOTE: update synapse states according to the pre spikes
    post_sps = bm.pre2post_event_sum(delayed_spike, self.pre2post, self.post.num, 1.)
    self.g += post_sps
    self.post.input += self.g_max * self.g * (self.E - self.post.V)
show_syn_model(ExpSparse)
_images/161578dd8e909cdc1b15fa7dfc1ceb0c4df86bbc7c0980454225ee200cf4ae40.png

This model will be very efficient when your synapses are connected sparsely.

The pre2syn and syn2post operators#

However, for the AMPA synapse model, the pre-synaptic values cannot be directly transformed into post-synaptic dimensional data. Therefore, we need to first map the pre-synaptic data onto the synapse dimension, and then transform the synapse-dimensional data into post-dimensional data.

Therefore, the core problem of synaptic computation is how to convert values among arrays of different shapes. Specifically, in the above AMPA synapse model, we have three kinds of array shapes (see the following figure): arrays with the dimension of the pre-synaptic group, arrays with the dimension of the post-synaptic group, and arrays with the shape of the synaptic connections. Converting the pre-synaptic spiking state into the synaptic state, and aggregating the synaptic variable into the post-synaptic current, are the central problems of synaptic computation.

Here BrainPy provides two operators brainpy.math.pre2syn(pre_values, pre_ids) and brainpy.math.syn2post(syn_values, post_ids, post_num) to convert vectors among different dimensions.

  • brainpy.math.pre2syn() receives two arguments: “pre_values” (the variable of the pre-synaptic dimension) and “pre_ids” (the connected pre-synaptic neuron index).

  • brainpy.math.syn2post() receives three arguments: “syn_values” (the variable with the synaptic size), “post_ids” (the connected post-synaptic neuron index) and “post_num” (the number of the post-synaptic neurons).
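Before the full model, here is a minimal toy sketch (hypothetical numbers) of the two operators, showing how a pre-dimensional vector is scattered to the synapse dimension and then gathered to the post dimension:

import brainpy.math as bm

# 2 pre neurons, 3 post neurons, 3 synapses:
#   pre 0 -> post 0;   pre 0 -> post 1;   pre 1 -> post 2
pre_ids = bm.array([0, 0, 1])
post_ids = bm.array([0, 1, 2])

pre_values = bm.array([0.5, 2.0])                   # one value per pre neuron
syn_values = bm.pre2syn(pre_values, pre_ids)        # -> [0.5, 0.5, 2.0]
post_values = bm.syn2post(syn_values, post_ids, 3)  # -> [0.5, 0.5, 2.0]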

Based on these two operators, we can define the AMPA synapse model as:

class AMPASparse(BaseAMPASyn):
  def __init__(self, *args, **kwargs):
    super(AMPASparse, self).__init__(*args, **kwargs)

    # connection matrix
    self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')

    # synapse gating variable
    # -------
    # NOTE: Here the synapse shape is (num_syn,)
    self.g = bm.Variable(bm.zeros(len(self.pre_ids)))

  def update(self, tdi, x=None):
    _t, _dt = tdi.t, tdi.dt
    delayed_spike = self.pre_spike(self.delay_step)
    self.pre_spike.update(self.pre.spike)
    # get the time of pre spikes arrive at the post synapse
    self.spike_arrival_time.value = bm.where(delayed_spike, _t, self.spike_arrival_time)
    # get the arrival time with the synapse dimension
    arrival_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
    # get the neurotransmitter concentration at the current time
    TT = ((_t - arrival_times) < self.T_duration) * self.T
    # integrate the synapse state
    self.g.value = self.integral(self.g, _t, TT, dt=_dt)
    # get the post-synaptic current
    g_post = bm.syn2post(self.g, self.post_ids, self.post.num)
    self.post.input += self.g_max * g_post * (self.E - self.post.V)
show_syn_model(AMPASparse)
_images/cc727f9a8f4924c4f3477fde7b0ed5561c05515b4bd23e45f3a78bf307021c51.png

We hope this tutorial helps you define your synapse models efficiently.

Model Simulation#

Simulation with DSRunner#

@Tianqiu Zhang @Chaoming Wang @Xiaoyu Chen

The convenient simulation interface for dynamical systems in BrainPy is implemented by brainpy.dyn.DSRunner. It can simulate various levels of models including channels, neurons, synapses and systems. In this tutorial, we will introduce how to use brainpy.dyn.DSRunner in detail.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bp.__version__
'2.3.1'

Initializing a DSRunner#

Generally, we can initialize a runner for dynamical systems with the format of:

runner = DSRunner(target=instance_of_dynamical_system,
                  inputs=inputs_for_target_DynamicalSystem,
                  monitors=interested_variables_to_monitor,
                  dyn_vars=dynamical_changed_variables,
                  jit=enable_jit_or_not,
                  progress_bar=report_the_running_progress,
                  numpy_mon_after_run=transform_into_numpy_ndarray
                  )

In which

  • target specifies the model to be simulated. It must be an instance of brainpy.DynamicalSystem.

  • inputs is used to define the input operations for specific variables.

    • It should be in the format of [(target, value, [type, operation])], where target is the input target, value is the input value, type is the input type (such as “fix”, “iter”, “func”), and operation is the operation applied to the input (such as “+”, “-”, “*”, “/”, “=”). Also, if you want to specify multiple inputs, just give multiple (target, value, [type, operation]) tuples, such as [(target1, value1), (target2, value2)].

    • It can also be a function, which is used to manually specify the inputs for the target variables. This input function should receive one argument tdi which contains the shared arguments like time t, time step dt, and index i.

  • monitors is used to define the target variables to record in the model. During the simulation, the history values of the monitored variables will be recorded. Variables can also be monitored through callable functions; in this case, monitors should be a dict whose key is a string for later retrieval by runner.mon[key] and whose value is a callable function which receives one argument: tdi.

  • dyn_vars is used to specify all the dynamically changed variables used in the target model.

  • jit determines whether to use JIT compilation during the simulation.

  • progress_bar determines whether to use progress bar to report the running progress or not.

  • numpy_mon_after_run determines whether to transform the JAX arrays into numpy ndarray or not when the network finishes running.

Running a DSRunner#

After initializing the runner, users can call the .run() function to run the simulation. The format of the .run() function is shown as follows (a brief usage sketch is given after the parameter list below):

runner.run(duration=simulation_time_length,
           inputs=input_data,
           reset_state=whether_reset_the_model_states,
           shared_args=shared_arguments_across_different_layers,
           progress_bar=report_the_running_progress,
           eval_time=evaluate_the_running_time
           )

In which

  • duration is the simulation time length.

  • inputs is the input data. If inputs_are_batching=True, inputs must be a PyTree of data with two dimensions: (num_sample, num_time, ...). Otherwise, the inputs should be a PyTree of data with one dimension: (num_time, ...).

  • reset_state determines whether to reset the model states.

  • shared_args is shared arguments across different layers. All the layers can access the elements in shared_args.

  • progress_bar determines whether to use progress bar to report the running progress or not.

  • eval_time determines whether to evaluate the running time.
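As a brief sketch (reusing the runner constructed in the example below; the shared argument 'fit' is only illustrative), these arguments are typically passed as:

runner.run(100.)                              # simulate 100 ms
runner.run(100., shared_args={'fit': False})  # pass shared arguments to all layers
runner.run(100., eval_time=True)              # additionally evaluate the running time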

Here we define an E/I balance network as the simulation model.

class EINet(bp.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    super(EINet, self).__init__()

    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
    self.E = bp.neurons.LIF(num_exc, **pars, method=method)
    self.I = bp.neurons.LIF(num_inh, **pars, method=method)

    # synapses
    prob = 0.1
    we = 0.6 / scale / (prob / 0.02) ** 2  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale / (prob / 0.02) ** 2  # inhibitory synaptic weight
    self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob),
                                       output=bp.synouts.COBA(E=0.), g_max=we,
                                       tau=5., method=method)
    self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(prob),
                                       output=bp.synouts.COBA(E=0.), g_max=we,
                                       tau=5., method=method)
    self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(prob),
                                       output=bp.synouts.COBA(E=-80.), g_max=wi,
                                       tau=10., method=method)
    self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(prob),
                                       output=bp.synouts.COBA(E=-80.), g_max=wi,
                                       tau=10., method=method)

Then we will wrap it into DSRunner for dynamic simulation. brainpy.dyn.DSRunner aims to provide model simulation with outstanding performance. It takes advantage of the structural loop primitive to lower the model onto XLA devices.

# instantiate EINet
net = EINet()
# initialize DSRunner
runner = bp.DSRunner(target=net,
                     monitors=['E.spike'],
                     inputs=[('E.input', 20.), ('I.input', 20.)],
                     jit=True)
# run the simulation
runner.run(duration=1000.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'])
_images/6cdac3561941603bd0af185a788b3df0c163821815dbc806c69be54ab9ad3ba3.png

We have run a simple example of using DSRunner, but there are many more advanced usages beyond this. Next we will formally introduce the two main aspects of DSRunner that are used most frequently: monitors and inputs.

Monitors in DSRunner#

In BrainPy, any instance of brainpy.dyn.DSRunner has a built-in monitor. Users can set up a monitor when initializing a runner. There are multiple ways to initialize a monitor. The first method is to initialize a monitor through a list of strings.

Initialization with a list of strings#
# initialize monitor through a list of strings
runner1 = bp.DSRunner(target=net,
                      monitors=['E.spike', 'E.V', 'I.spike', 'I.V'],  # 4 elements in monitors
                      inputs=[('E.input', 20.), ('I.input', 20.)],
                      jit=True)

where all the strings correspond to the names of variables in the E/I network:

net.E.V, net.E.spike
(Variable([-55.31656 , -58.02285 , -61.898117, ..., -55.487587, -53.33741 ,
           -56.158283], dtype=float32),
 Variable([False, False, False, ..., False, False, False], dtype=bool))

Once we call the runner with a given time duration, the monitor will automatically record the variable evolution in the corresponding models. Afterwards, users can access these variable trajectories through .mon[variable_name]. The default history times .mon.ts will also be generated after the model finishes running. Let’s see an example.

runner1.run(100.)
bp.visualize.raster_plot(runner1.mon.ts, runner1.mon['E.spike'], show=True)
_images/4169f7fb9f96558f67a8fc59364484bda8736406a47cdcf192cb4f3d58a9a206.png
Initialization with index specification#

The second method is similar to the first one, with the difference that an index specification is added. Index specification means users only monitor specific neurons and ignore all the others. Sometimes we do not care about all the contents of a variable and are only interested in the values at certain indices. Moreover, for a huge network with a long simulation, monitors will consume a large amount of RAM, so monitoring variables only at selected indices is more practical. BrainPy supports monitoring a subset of elements of a Variable with a tuple format like this:

# initialize monitor through a list of strings with index specification
runner2 = bp.DSRunner(target=net,
                      monitors=[('E.spike', [1, 2, 3]),  # monitor values of Variable at index of [1, 2, 3]
                                'E.V'],  # monitor all values of Variable 'V'
                      inputs=[('E.input', 20.), ('I.input', 20.)],
                      jit=True)
runner2.run(100.)
print('The monitor shape of "E.V" is (run length, variable size) = {}'.format(runner2.mon['E.V'].shape))
print('The monitor shape of "E.spike" is (run length, index size) = {}'.format(runner2.mon['E.spike'].shape))
The monitor shape of "E.V" is (run length, variable size) = (1000, 3200)
The monitor shape of "E.spike" is (run length, index size) = (1000, 3)
Explicit monitor target#

The third method is to use a dict with the explicit monitor target. Users can access model instance and get certain variables as monitor target:

# initialize monitor through a dict with the explicit monitor target
runner3 = bp.DSRunner(target=net,
                      monitors={'spike': net.E.spike, 'V': net.E.V},
                      inputs=[('E.input', 20.), ('I.input', 20.)],
                      jit=True)
runner3.run(100.)
print('The monitor shape of "V" is = {}'.format(runner3.mon['V'].shape))
print('The monitor shape of "spike" is = {}'.format(runner3.mon['spike'].shape))
The monitor shape of "V" is = (1000, 3200)
The monitor shape of "spike" is = (1000, 3200)
Explicit monitor target with index specification#

The fourth method is similar to the third one, with the difference that the index specification is added:

# initialize monitor through a dict with the explicit monitor target
runner4 = bp.DSRunner(target=net,
                      monitors={'E.spike': (net.E.spike, [1, 2]),  # monitor values of Variable at index of [1, 2]
                                'E.V': net.E.V},  # monitor all values of Variable 'V'
                      inputs=[('E.input', 20.), ('I.input', 20.)],
                      jit=True)
runner4.run(100.)
print('The monitor shape of "E.V" is = {}'.format(runner4.mon['E.V'].shape))
print('The monitor shape of "E.spike" is = {}'.format(runner4.mon['E.spike'].shape))
The monitor shape of "E.V" is = (1000, 3200)
The monitor shape of "E.spike" is = (1000, 2)

Besides the four methods mentioned above, BrainPy also provides a convenient way to monitor more complicated quantities: functional monitors (fun_monitor). Users describe the quantity with a function and pass it as a dict, where the key is a string for later retrieval by runner.mon[key] and the value is a callable function which receives one argument: tdi. The format of such a monitor is shown below:

fun_monitor = {'key_name': lambda tdi: body_func(tdi)}

Here we monitor a variable that concatenates the spike trains of the excitatory and inhibitory populations:

runner5 = bp.DSRunner(target=net,
                      monitors={'E-I.spike': lambda tdi: bm.concatenate((net.E.spike, net.I.spike), axis=0)},
                      inputs=[('E.input', 20.), ('I.input', 20.)],
                      jit=True)
runner5.run(100.)
bp.visualize.raster_plot(runner5.mon.ts, runner5.mon['E-I.spike'])
_images/dbf612263cccc0ae14142f018edb1c767a3b76c4a6c5c91b6e393305dacbd8b3.png

Inputs in DSRunner#

In brain dynamics simulation, various inputs are usually given to different units of the dynamical system. In BrainPy, inputs can be specified to runners for dynamical systems. The aim of inputs is to mimic the input operations in experiments like Transcranial Magnetic Stimulation (TMS) and patch clamp recording.

inputs should have the format like (target, value, [type, operation]), where

  • target is the target variable to inject the input.

  • value is the input value. It can be a scalar, a tensor, or an iterable object/function.

  • type is the type of the input value. It supports two types of input: fix and iter. The first means the data is static; the second means the data is iterable, no matter whether the input value is a tensor or a function. The iter type must be stated explicitly.

  • operation is the input operation on the target variable. It should be set as one of { + , - , * , / , = }, and if users do not provide this item explicitly, it will be set to ‘+’ by default, which means that the target variable will be updated as val = val + input.

Users can also give multiple inputs for different target variables, like:


inputs=[(target1, value1, [type1, op1]),
        (target2, value2, [type2, op2]),
              ... ]
Static inputs#

The first example provides static inputs. The excitatory and inhibitory neurons all receive the same current intensity:

runner6 = bp.DSRunner(target=net,
                      monitors=['E.spike'],
                      inputs=[('E.input', 20.), ('I.input', 20.)],  # static inputs
                      jit=True)
runner6.run(100.)
bp.visualize.raster_plot(runner6.mon.ts, runner6.mon['E.spike'])
_images/edc9b699a851fe63aad32b561c3a9b236d40588936442c0af4ec822d09f6ea2e.png
Iterable inputs#

The second example provides iterable inputs. Users need to set the type to iter and pass an iterable object or function as value:

I, length = bp.inputs.section_input(values=[0, 20., 0],
                                    durations=[100, 1000, 100],
                                    return_length=True,
                                    dt=0.1)

runner7 = bp.DSRunner(target=net,
                      monitors=['E.spike'],
                      inputs=[('E.input', I, 'iter'), ('I.input', I, 'iter')],  # iterable inputs
                      jit=True)
runner7.run(length)
bp.visualize.raster_plot(runner7.mon.ts, runner7.mon['E.spike'])
_images/ccb0a7038f18e613354ee0da5a3be20ace152ebca86f715b7ed1cbddc1761e53.png

From the examples given above, users can easily understand the usage of the inputs parameter. Similar to monitors, inputs can also take a more complicated, functional form; BrainPy provides fun_inputs to receive customized functional inputs created by users.

def set_input(tdi):
  net.E.input[:] = 20
  net.I.input[:] = 20.


runner8 = bp.DSRunner(target=net,
                      monitors=['E.spike'],
                      inputs=set_input,  # functional inputs
                      jit=True)
runner8.run(200.)
bp.visualize.raster_plot(runner8.mon.ts, runner8.mon['E.spike'])
_images/d04c0fc88228174ad66a83586078d3456028e5e4ee65b8bc1893c39abf5a1bf1.png

Parallel Simulation for Parameter Exploration#

@Tianqiu Zhang @Chaoming Wang

Parameter exploration and selection is an essential part of brain dynamics modeling. In general, there are two problems in parameter exploration:

  1. how to run multiple models concurrently?

  2. how to manage device memory allowing multiple models to run concurrently?

First, most BrainPy models support multiple kinds of parallelization, including multi-threading and multi-processing on a single machine, and parallelization across multiple devices. Below, we will illustrate these parallelization APIs one by one.

Second, every call of a BrainPy model will consume a fraction of device memory. Therefore, BrainPy provides an API, brainpy.math.clear_buffer_memory(), for memory cleaning.

In the following, we will illustrate how to combine them together to get an efficient parameter exploration for your models.

import brainpy as bp
import brainpy.math as bm
import numpy as np

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

Parallelization across different CPU processors#

Parallelization across multiple CPU processors can be easily achieved with a single function call, brainpy.running.cpu_ordered_parallel(). The following pseudocode demonstrates the usage of this API.

import brainpy as bp

# define your function
def run_model(par):
  model = YourModel(par)
  runner = bp.dyn.DSRunner(model)
  runner.run(duration)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# run models in Jupyter
results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)

# run models in python file
if __name__ == '__main__':
  results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)

We will use a simple HH neuron model as an example to show this parallelization method. In this example, we use the multi-processing technique to test ten different input current values.

First, define your running function with the well-defined input and output data.

def hh_spike_num(bg_current): # "input" is the bg_current
  import brainpy as bp  # packages need to be re-imported when
                        # running the function in Jupyter
  model = bp.neurons.HH(1)
  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
  runner.run(1000.)
  return runner.mon['spike'].sum()  # "output" is the spike number

Then, define all your parameter spaces.

current = bm.linspace(1, 10.1, 10)  # here only one parameter

Finally, run your model concurrently with the parallelization syntax.

r = bp.running.cpu_ordered_parallel(hh_spike_num, [current], num_process=10)

r
[0, 0, 1, 48, 53, 0, 54, 63, 66, 68]

However, the above usage accumulates buffer memory on the running device. If a single model occupies too much memory, an out-of-memory error will be raised during the parameter exploration.

A simple way to solve this issue is to clear all buffers after each run of the function. For example, call brainpy.math.clear_buffer_memory() before returning your results.

def hh_spike_num2(bg_current): # "input" is the bg_current
  import brainpy as bp  # packages need to be re-imported when
                        # running the function in Jupyter

  bg_current = bp.math.as_jax(bg_current)
  model = bp.neurons.HH(1)
  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
  runner.run(1000.)

  bp.math.clear_buffer_memory()
  return runner.mon['spike'].sum()  # "output" is the spike number

Note that clear_buffer_memory() will clear all JAX arrays on the device; therefore, it is better to pass inputs as NumPy arrays and return outputs as NumPy arrays.

current = np.linspace(1., 10., 10)

r = bp.running.cpu_ordered_parallel(hh_spike_num2, [current], num_process=10)
r
[0, 0, 1, 0, 0, 57, 60, 58, 65, 68]

If the order of the results does not matter, you can also use the cpu_unordered_parallel() function. This maximizes the running efficiency of all processors, since all workers run in a non-blocking, unordered manner.
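A minimal sketch, assuming cpu_unordered_parallel() shares the calling convention of cpu_ordered_parallel() used above:

current = np.linspace(1., 10., 10)
r = bp.running.cpu_unordered_parallel(hh_spike_num2, [current], num_process=10)
# r holds the same spike counts, but possibly in a different order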

Parallelization with jax.vmap#

The second approach to multi-threading parallelization is JAX’s vectorization map, jax.vmap. jax.vmap vectorizes functions by compiling the mapped axis as primitive operations. It avoids recompiling models in the same batch and automatically parallelizes the model running on the given machine. The following pseudocode demonstrates how simple this parallelization approach is.

from jax import vmap

def run_model(par):
  model = YourModel(par)
  runner = bp.dyn.DSRunner(model)
  runner.run(duration)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# batch simulation through jax.vmap
r = vmap(run_model)(*all_params)

Note that if you have too many parameters to search, jax.vmap will consume too much memory. In this case, you can use our wrapped API brainpy.running.jax_vectorize_map(), which controls the running batch size through the num_parallel parameter. You can set a smaller num_parallel when your device memory is not enough (whether on a CPU or a GPU device).

def hh_spike_num3(bg_current): # "input" is the bg_current
  model = bp.neurons.HH(1)
  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current],
                           numpy_mon_after_run=False)
  runner.run(1000.)
  return runner.mon['spike'].sum()  # "output" is the spike number
current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3)
r
Array([ 0,  0,  0,  0,  0, 45, 60, 63, 66, 68], dtype=int32)

The function passed into jax_vectorize_map() cannot call clear_buffer_memory(); otherwise, errors will be raised. Instead, users can set clear_buffer=True/False in jax_vectorize_map(). With this usage, all inputs and outputs will be automatically transformed into NumPy arrays.

current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3, clear_buffer=True)
r
array([ 0,  1,  1,  0,  0, 57, 60, 63, 66, 68])

Parallelization across multiple devices#

BrainPy supports parallelization on multiple devices (e.g., multiple GPU devices or TPU cores) or HPC systems (e.g., supercomputers). Different from the thread-based and process-based parallelization methods above, in which the same model runs in parallel on the same device, device-based parallelization runs the same model in parallel on multiple devices.

One way to express multi-device parallelization of BrainPy models is the jax.pmap instruction. JAX provides jax.pmap to express SIMD programs. It offers an interface to run the same model on multiple devices with different parameter values. Its usage is analogous to jax.vmap. The following pseudocode presents an example of running BrainPy models on multiple devices.

from jax import pmap

def run_model(par):
  model = YourModel(par)
  runner = bp.dyn.DSRunner(model)
  runner.run(<int>)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# parallel simulation through jax.pmap
r = pmap(run_model)(*all_params)

jax.pmap has a similar issue to jax.vmap when you parallelize across many parameters. In this case, you can use the wrapped function brainpy.running.jax_parallelize_map().

If you are using pmap on your CPU device, you can set the virtual number of devices by calling brainpy.math.set_host_device_count(n). Then, you can call jax_parallelize_map() safely on your CPU platform.

bp.math.set_host_device_count(10)  # this should place on the top of the file

current = bm.linspace(1., 10.1, 20)
r = bp.running.jax_parallelize_map(hh_spike_num3, [current], num_parallel=10, clear_buffer=True)
r
array([ 0,  0,  0,  0,  0,  0,  0, 49, 52, 54, 56, 58, 59, 61, 62, 63, 65,
       66, 67, 68])

BrainPy also works well with job scheduling systems such as SLURM in a supercomputer center. Therefore, another way to express multi-device parallelization is to employ a classical resource management system. The following script demonstrates an example batch script submitted to SLURM.


#!/bin/bash
#SBATCH -J <name>
#SBATCH -o <file name>
#SBATCH -p <str>
#SBATCH -n <int>
#SBATCH -N <int>
#SBATCH -c <int>

python your_script.py

Model Training#

This tutorial shows how to train a dynamical system from data or task.

Building Training Models#

In this section, we are going to talk about how to build models for training.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.3.1'

Use built-in models#

The brainpy.dyn.DynamicalSystem models provided in BrainPy can be used for model training.

mode settings#

Some built-in models have implemented the training interface. Users can instantiate these models with the parameter mode=brainpy.modes.training to customize them for training.

For example, brainpy.neurons.LIF is a model commonly used in computational simulation, but it can also be used in training.

# Instantiate a LIF model for simulation

lif = bp.neurons.LIF(1)
lif.mode
NonBatchingMode
# Instantiate a LIF model for training.
# In this mode, the model implement variables and functions
# compatible with BrainPy's training interface.

lif = bp.neurons.LIF(1, mode=bp.modes.training)
lif.mode
TrainingMode

But some built-in models do not support training.

try:
    bp.layers.NVAR(1, 1, mode=bp.modes.training)
except Exception as e:
    print(type(e), e)
<class 'NotImplementedError'> NVAR does not support TrainingMode. We only support BatchingMode, NonBatchingMode. 

The mode can also be used to control the weight types. Let’s take a dense layer as another example. For a non-trainable dense layer, the weights and bias are Array instances.

l = bp.layers.Dense(3, 4, mode=bm.batching_mode)

l.W
Array([[ 0.13143115,  0.7037631 , -0.50639415, -0.49906093],
       [-0.39095506,  0.5210247 ,  0.6293488 ,  0.7321653 ],
       [ 0.2841127 ,  0.3818757 , -0.19256772, -0.6708007 ]],      dtype=float32)
l = bp.layers.Dense(3, 4, mode=bm.training_mode)

l.W
TrainVar([[-0.36896244,  0.6050412 , -0.53849053,  0.03913487],
          [-0.78182685, -0.7611104 , -0.00870763, -0.06463569],
          [-0.2160572 ,  0.5157468 ,  0.09730986,  0.16213563]],      dtype=float32)

Moreover, for some recurrent models, e.g., LSTM or GRU, the state can be set to be trainable or not through the train_state argument. When train_state=True is set for the recurrent instance, a new attribute .state2train will be created.

rnn = bp.layers.RNNCell(1, 3, train_state=True)

rnn.state2train
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [7], in <cell line: 3>()
      1 rnn = bp.layers.RNNCell(1, 3, train_state=True)
----> 3 rnn.state2train

AttributeError: 'RNNCell' object has no attribute 'state2train'

Note the difference between the .state2train and the original .state:

  1. .state2train has no batch axis.

  2. When the node.reset_state() function is called, all values in .state will be filled with .state2train.

rnn.reset_state(batch_size=5)
rnn.state
Naming a node#

For convenience, you can name a layer by specifying the name keyword argument:

bp.layers.Dense(128, 100, name='hidden_layer')
Initializing parameters#

Many models have parameters. We can set the parameters of a model with the following methods.

  • Arrays

If an array is provided, this is used unchanged as the parameter variable. For example:

l = bp.layers.Dense(10, 50, W_initializer=bm.random.normal(0, 0.01, size=(10, 50)))

l.W.shape
  • Callable function

If a callable function (which receives a shape argument) is provided, the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:

def init(shape):
    return bm.random.random(shape)

l = bp.layers.Dense(20, 30, W_initializer=init)

l.W.shape
  • Instance of brainpy.init.Initializer

If a brainpy.init.Initializer instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:

l = bp.layers.Dense(20, 30, W_initializer=bp.init.Normal(0.01))

l.W.shape

The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.init for more information).

  • None parameter

Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:

l = bp.layers.Dense(20, 100, b_initializer=None)

print(l.b)

Customize your models#

Customizing your training models is simple. You just need to subclass brainpy.DynamicalSystem, and implement its update() and reset_state() functions.

Here, we demonstrate the model customization using two examples. The first is a recurrent layer.

class RecurrentLayer(bp.DynamicalSystem):
    def __init__(self, num_in, num_out):
        super(RecurrentLayer, self).__init__()

        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in = num_in
        self.num_out = num_out

        # define variables
        self.state = bm.Variable(bm.zeros((1, num_out)), batch_axis=0)

        # define weights
        self.win = bm.TrainVar(bm.random.normal(0., 1./num_in ** 0.5, size=(num_in, num_out)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1./num_out ** 0.5, size=(num_out, num_out)))

    def reset_state(self, batch_size):
        # this function defines how to reset the mode states
        self.state.value = bm.zeros((batch_size, self.num_out))

    def update(self, sha, x):
        # this function defines how the model updates its state and produces its output
        out = bm.dot(x, self.win) + bm.dot(self.state, self.wrec)
        self.state.value = bm.tanh(out)
        return self.state.value

This simple example illustrates many features essential for a training model. The reset_state() function defines how to reset the model states, and will be called at the first time step; the update() function defines how the model states evolve, and will be called at every time step.

Another example is the dropout layer, which is useful for demonstrating how to define a model with multiple behaviours.

class Dropout(bp.dyn.DynamicalSystem):
  def __init__(self, prob: float, seed: int = None, name: str = None):
    super(Dropout, self).__init__(name=name)

    bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode))
    self.prob = prob
    self.rng = bm.random.RandomState(seed=seed)

  def update(self, sha, x):
    if sha.get('fit', True):
      keep_mask = self.rng.bernoulli(self.prob, x.shape)
      return bm.where(keep_mask, x / self.prob, 0.)
    else:
      return x

Here, the model produces different outputs according to the value of the shared parameter fit.

You can define your own shared parameters, and then provide their values when calling the trainer objects (see the following section).
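As a toy illustration (a hypothetical call, assuming the shared arguments are passed as a plain dict like in the trainers above), the Dropout layer defined here behaves differently depending on the shared value fit:

drop = Dropout(prob=0.8)
x = bm.random.rand(4)

y_train = drop({'fit': True}, x)   # a random mask is applied; kept values are scaled by 1/prob
y_test = drop({'fit': False}, x)   # dropout is disabled; x is returned unchanged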

Examples of training models#

In the following, we illustrate several examples to build a trainable neural network model.

Artificial neural networks#

BrainPy provides neural network layers which can be useful to define artificial neural networks.

Here, let’s define a deep RNN model.

class DeepRNN(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_recs, num_out):
        super(DeepRNN, self).__init__()

        self.l1 = bp.layers.LSTM(num_in, num_recs[0])
        self.d1 = bp.layers.Dropout(0.2)
        self.l2 = bp.layers.LSTM(num_recs[0], num_recs[1])
        self.d2 = bp.layers.Dropout(0.2)
        self.l3 = bp.layers.LSTM(num_recs[1], num_recs[2])
        self.d3 = bp.layers.Dropout(0.2)
        self.l4 = bp.layers.LSTM(num_recs[2], num_recs[3])
        self.d4 = bp.layers.Dropout(0.2)
        self.lout = bp.layers.Dense(num_recs[3], num_out)

    def update(self, sha, x):
        x = self.d1(sha, self.l1(sha, x))
        x = self.d2(sha, self.l2(sha, x))
        x = self.d3(sha, self.l3(sha, x))
        x = self.d4(sha, self.l4(sha, x))
        return self.lout(sha, x)

with bm.training_environment():
    model = DeepRNN(100, [200, 200, 200, 200], 10)

Note that a difference of model building here from PyTorch is that the first argument of the update() function should be the shared parameters sha (i.e., parameters shared across all models, like the time t, the running index i, and the model running phase fit). All other arguments can be customized by users. The details of the model definition specification can be seen in ????

Moreover, it is worth noting that this model only defines the one-step update rule of how the model evolves according to the input x.

Reservoir computing models#

In this example, we define a reservoir computing model called next-generation reservoir computing (NGRC) using the built-in models provided in BrainPy.

class NGRC(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__(mode=bm.batching_mode)
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5, mode=bm.batching_mode)
    self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))

In the above model, brainpy.layers.NVAR is a nonlinear vector autoregression machine, which has no trainable parameters. Therefore, we define its mode as the batching mode. In contrast, brainpy.layers.Dense has trainable weights for model training.

Spiking Neural Networks#

Building trainable spiking neural networks in BrainPy is also straightforward. We provide commonly used spiking models for traditional dynamics simulation, and most of them can be used for training too.

In the following, we provide an implementation of spiking neural networks in (Neftci, Mostafa, & Zenke, 2019) for surrogate gradient learning.

class SNN(bp.dyn.Network):
  def __init__(self, num_in, num_rec, num_out):
    super(SNN, self).__init__()

    # neuron groups
    self.i = bp.neurons.InputGroup(num_in)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=20.))
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=20.))

  def update(self, tdi, spike):
    self.i2r(tdi, spike)
    self.r2o(tdi)
    self.r(tdi)
    self.o(tdi)
    return self.o.V.value

with bm.training_environment():
    snn = SNN(10, 100, 2)

Note that here the mode of all models is specified as brainpy.modes.TrainingMode.

Training with Offline Algorithms#

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import matplotlib.pyplot as plt

bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')

bp.__version__
'2.3.1'

BrainPy provides many offline training algorithms that can help users train models such as reservoir computing models.

Train a reservoir model#

Here, we train an echo-state machine to predict chaotic dynamics. This example is used to illustrate how to use brainpy.train.OfflineTrainer.

We first get the training dataset.

def get_subset(data, start, end):
  res = {'x': data.xs[start: end],
         'y': data.ys[start: end],
         'z': data.zs[start: end]}
  res = bm.hstack([res['x'], res['y'], res['z']])
  # Training data must have batch size, here the batch is 1
  return res.reshape((1, ) + res.shape)
dt = 0.01
t_warmup, t_train, t_test = 5., 100., 50.  # ms
num_warmup, num_train, num_test = int(t_warmup/dt), int(t_train/dt), int(t_test/dt)
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
                                  dt=dt,
                                  inits={'x': 17.67715816276679,
                                         'y': 12.931379185960404,
                                         'z': 43.91404334248268})
X_warmup = get_subset(lorenz_series, 0, num_warmup - 5)
X_train = get_subset(lorenz_series, num_warmup - 5, num_warmup + num_train - 5)
X_test = get_subset(lorenz_series,
                    num_warmup + num_train - 5,
                    num_warmup + num_train + num_test - 5)
# our target data is the activity 5 time steps ahead
Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
Y_test = get_subset(lorenz_series,
                    num_warmup + num_train,
                    num_warmup + num_train + num_test)

Then, we build an echo-state machine to predict the chaotic dynamics five time steps ahead.

class ESN(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_hidden, num_out):
    super(ESN, self).__init__()
    self.r = bp.layers.Reservoir(num_in, num_hidden,
                                 Win_initializer=bp.init.Uniform(-0.1, 0.1),
                                 Wrec_initializer=bp.init.Normal(scale=0.1),
                                 in_connectivity=0.02,
                                 rec_connectivity=0.02,
                                 comp_type='dense')
    self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal(),
                             mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))
model = ESN(3, 100, 3)

Here, we use ridge regression as the training algorithm to fit the model to the chaotic data.

trainer = bp.train.OfflineTrainer(model, fit_method=bp.algorithms.RidgeRegression(1e-7), dt=dt)
# first warmup the reservoir

_ = trainer.predict(X_warmup)
# then fit the reservoir model

_ = trainer.fit([X_train, Y_train])
def plot_lorenz(ground_truth, predictions):
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(121, projection='3d')
  ax.set_title("Generated attractor")
  ax.set_xlabel("$x$")
  ax.set_ylabel("$y$")
  ax.set_zlabel("$z$")
  ax.grid(False)
  ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.set_title("Real attractor")
  ax2.grid(False)
  ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
  plt.show()
# finally, predict the model with the test data

outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS:  0.041717741900418666
_images/2b80927a168915b443d208cd1f0ac2c7ec8131f743fcf41a804ff4fb33378a78.png

Switch different training algorithms#

brainpy.train.OfflineTrainer supports easy switching of training algorithms. You just need to provide the fit_method argument when instantiating an offline trainer.

Many offline algorithms, like linear regression, ridge regression, and Lasso regression, are provided as built-in methods.

model = ESN(3, 100, 3)
model.reset_state(1)
trainer = bp.train.OfflineTrainer(model, fit_method=bp.algorithms.LinearRegression())

_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
_images/e2fada8f89ef21ee5aa2c474f51e525e905c0f53f2b51d95e937b27559806f6f.png

Customize your training algorithms#

brainpy.train.OfflineTrainer also supports training models with your customized training algorithms.

Specifically, the customization of an offline algorithm should follow the interface of brainpy.algorithms.OfflineAlgorithm, in which users specify how the model parameters are calculated according to the input, prediction, and target data.

For instance, here we use the Lasso model provided in the scikit-learn package to define an offline training algorithm.

from sklearn.linear_model import Lasso

class LassoAlgorithm(bp.algorithms.OfflineAlgorithm):
  def __init__(self, alpha=1., max_iter=int(1e4)):
    super(LassoAlgorithm, self).__init__()
    self.model = Lasso(alpha=alpha, max_iter=max_iter)

  def __call__(self, identifier, y, x, outs=None):
    x = bm.as_numpy(x[0])
    y = bm.as_numpy(y[0])
    x_new = self.model.fit(x, y).coef_.T
    return bm.expand_dims(bm.asarray(x_new), 1)
model = ESN(3, 100, 3)
model.reset_state(1)

# note that scikit-learn algorithms do not support JAX JIT compilation,
# therefore the "jit" of the "fit" phase is set to False.
trainer = bp.train.OfflineTrainer(model, fit_method=LassoAlgorithm(),
                                  jit={'fit': False})

_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
_images/cee42599780665d39db95dfea35084f6b621b5ccc66e7553dc7e2adbf73471ce.png

Training with Online Algorithms#

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
import brainpy_datasets as bd

bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')

bp.__version__
'2.3.1'

Online training algorithms, such as FORCE learning, have played vital roles in brain modeling. BrainPy provides brainpy.train.OnlineTrainer for model training with online algorithms.

Train a reservoir model#

Here, we are going to use brainpy.train.OnlineTrainer to train a next generation reservoir computing model (NGRC) to predict chaotic dynamics.

We first get the training dataset.

def get_subset(data, start, end):
  res = {'x': data.xs[start: end],
         'y': data.ys[start: end],
         'z': data.zs[start: end]}
  res = bm.hstack([res['x'], res['y'], res['z']])
  # Training data must have batch size, here the batch is 1
  return res.reshape((1, ) + res.shape)
dt = 0.01
t_warmup, t_train, t_test = 5., 100., 50.  # ms
num_warmup, num_train, num_test = int(t_warmup/dt), int(t_train/dt), int(t_test/dt)
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
                                  dt=dt,
                                  inits={'x': 17.67715816276679,
                                         'y': 12.931379185960404,
                                         'z': 43.91404334248268})
X_warmup = get_subset(lorenz_series, 0, num_warmup - 5)
X_train = get_subset(lorenz_series, num_warmup - 5, num_warmup + num_train - 5)
X_test = get_subset(lorenz_series,
                    num_warmup + num_train - 5,
                    num_warmup + num_train + num_test - 5)
# our target data is the activity 5 time steps ahead
Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
Y_test = get_subset(lorenz_series,
                    num_warmup + num_train,
                    num_warmup + num_train + num_test)

Then, we build an NGRC model to predict the chaotic dynamics five time steps ahead.

class NGRC(bp.dyn.DynamicalSystem):
  def __init__(self, num_in):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True)
    self.o = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None, mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))
model = NGRC(3)
model.reset_state(1)

Here, we use the recursive least squares (RLS) algorithm to train the model online.

trainer = bp.train.OnlineTrainer(model, fit_method=bp.algorithms.RLS(), dt=dt)
# first warmup the reservoir

_ = trainer.predict(X_warmup)
# then fit the reservoir model

_ = trainer.fit([X_train, Y_train])
def plot_lorenz(ground_truth, predictions):
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(121, projection='3d')
  ax.set_title("Generated attractor")
  ax.set_xlabel("$x$")
  ax.set_ylabel("$y$")
  ax.set_zlabel("$z$")
  ax.grid(False)
  ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.set_title("Real attractor")
  ax2.grid(False)
  ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
  plt.show()
# finally, predict the model with the test data

outputs = trainer.predict(X_test)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS:  0.0007826608198954596
_images/df1a65dd55ed60212714cc4776b745055c33e7d61a141145679b2ac73f12f8bf.png

Training with Back-propagation Algorithms#

Back-propagation (BP) training has become a foundation of machine learning algorithms. In this section, we are going to talk about how to train models with BP.

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import numpy as np

bm.set_mode(bm.training_mode)
bm.set_platform('cpu')

bp.__version__
'2.3.1'

Here, we train two kinds of models to classify the Fashion-MNIST dataset. The first is an ANN model commonly used in deep learning. The second is an SNN model.

Train an ANN model#

We first build a three-layer ANN model:

i >> r >> o

where the recurrent layer r is an LSTM cell and the output o is a linear readout.

class ANNModel(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_rec, num_out):
        super(ANNModel, self).__init__()
        self.rec = bp.layers.LSTMCell(num_in, num_rec)
        self.out = bp.layers.Dense(num_rec, num_out)

    def update(self, sha, x):
        x = self.rec(sha, x)
        x = self.out(sha, x)
        return x

Before training this model, we get and clean the data we want.

root = r"D:\data"
train_dataset = bd.vision.FashionMNIST(root, split='train', download=True)
test_dataset = bd.vision.FashionMNIST(root, split='test', download=True)


def get_data(dataset, batch_size=256):
  rng = bm.random.clone_rng()

  def data_generator():
    X = bm.array(dataset.data, dtype=bm.float_) / 255
    Y = bm.array(dataset.targets, dtype=bm.float_)
    key = rng.split_key()
    rng.shuffle(X, key=key)
    rng.shuffle(Y, key=key)
    for i in range(0, len(dataset), batch_size):
      yield X[i: i + batch_size], Y[i: i + batch_size]

  return data_generator

Then, we start to train our ANN model with the brainpy.train.BPTT training interface.

# model
model = ANNModel(28, 100, 10)

# loss function
def loss_fun(predicts, targets):
    predicts = bm.max(predicts, axis=1)
    loss = bp.losses.cross_entropy_loss(predicts, targets)
    acc = bm.mean(predicts.argmax(axis=-1) == targets)
    return loss, {'acc': acc}

# optimizer
optimizer=bp.optim.Adam(lr=1e-3)

# trainer
trainer = bp.train.BPTT(model,
                        loss_fun=loss_fun,
                        loss_has_aux=True,
                        optimizer=optimizer)
trainer.fit(train_data=get_data(train_dataset, 256),
            test_data=get_data(test_dataset, 512),
            num_epoch=10)
Train 0 epoch, use 11.7911 s, loss 0.9686446785926819, acc 0.657629668712616
Test 0 epoch, use 1.1449 s, loss 0.6075307726860046, acc 0.7804630398750305
Train 1 epoch, use 9.1454 s, loss 0.5276323556900024, acc 0.8128490447998047
Test 1 epoch, use 0.2927 s, loss 0.5302323698997498, acc 0.8131089210510254
Train 2 epoch, use 9.0979 s, loss 0.4588756263256073, acc 0.8355662226676941
Test 2 epoch, use 0.3037 s, loss 0.4678855538368225, acc 0.8310604095458984
Train 3 epoch, use 9.1892 s, loss 0.42316296696662903, acc 0.8461214303970337
Test 3 epoch, use 0.2911 s, loss 0.45249971747398376, acc 0.8364487886428833
Train 4 epoch, use 9.0769 s, loss 0.3961907625198364, acc 0.8553800582885742
Test 4 epoch, use 0.2947 s, loss 0.4217829704284668, acc 0.8458294868469238
Train 5 epoch, use 8.9839 s, loss 0.3784363567829132, acc 0.8621509075164795
Test 5 epoch, use 0.3015 s, loss 0.41539546847343445, acc 0.8533375859260559
Train 6 epoch, use 8.9756 s, loss 0.362664133310318, acc 0.8676861524581909
Test 6 epoch, use 0.3095 s, loss 0.3904822766780853, acc 0.8592026829719543
Train 7 epoch, use 9.0243 s, loss 0.34826964139938354, acc 0.8724401593208313
Test 7 epoch, use 0.2876 s, loss 0.3742746412754059, acc 0.8639591336250305
Train 8 epoch, use 9.0806 s, loss 0.3381759822368622, acc 0.8756925463676453
Test 8 epoch, use 0.2951 s, loss 0.3876565992832184, acc 0.8559110760688782
Train 9 epoch, use 9.1019 s, loss 0.32923951745033264, acc 0.8797928094863892
Test 9 epoch, use 0.2833 s, loss 0.3779725432395935, acc 0.8602309226989746
import matplotlib.pyplot as plt

plt.plot(trainer.get_hist_metric('fit'), label='fit')
plt.plot(trainer.get_hist_metric('test'), label='test')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
_images/a999c8c3164f9479f03829422da3bc2680b4a98401e7bfcba7579099012bcdeb.png

Train a SNN model#

Similarly, brainpy.train.BPTT can also be used to train SNN models.

We first build a three-layer SNN model:

i >> [exponential synapse] >> r >> [exponential synapse] >> o
class SNNModel(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_rec, num_out):
    super(SNNModel, self).__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out

    # neuron groups
    self.i = bp.neurons.InputGroup(num_in)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(),
                                       tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=2.))
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(),
                                       tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=2.))

  def update(self, shared, spike):
    self.i2r(shared, spike)
    self.r2o(shared)
    self.r(shared)
    self.o(shared)
    return self.o.V.value

As the model receives spiking inputs, we define functions that are necessary to transform the continuous values to spiking data.

def current2firing_time(x, tau=20., thr=0.2, tmax=1.0, epsilon=1e-7):
  x = np.clip(x, thr + epsilon, 1e9)
  T = tau * np.log(x / (x - thr))
  T = np.where(x < thr, tmax, T)
  return T

def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True):
  labels_ = np.array(y, dtype=bm.int_)
  sample_index = np.arange(len(X))

  # compute discrete firing times
  tau_eff = 2. / bm.get_dt()
  unit_numbers = np.arange(nb_units)
  firing_times = np.array(current2firing_time(X, tau=tau_eff, tmax=nb_steps), dtype=bm.int_)

  if shuffle:
    np.random.shuffle(sample_index)

  counter = 0
  number_of_batches = len(X) // batch_size
  while counter < number_of_batches:
    batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
    all_batch, all_times, all_units = [], [], []
    for bc, idx in enumerate(batch_index):
      c = firing_times[idx] < nb_steps
      times, units = firing_times[idx][c], unit_numbers[c]
      batch = bc * np.ones(len(times), dtype=bm.int_)
      all_batch.append(batch)
      all_times.append(times)
      all_units.append(units)
    all_batch = np.concatenate(all_batch).flatten()
    all_times = np.concatenate(all_times).flatten()
    all_units = np.concatenate(all_units).flatten()
    x_batch = bm.zeros((batch_size, nb_steps, nb_units))
    x_batch[all_batch, all_times, all_units] = 1.
    y_batch = bm.asarray(labels_[batch_index])
    yield x_batch, y_batch
    counter += 1

Now, we can define a BP trainer for this SNN model.

def loss_fun(predicts, targets):
    predicts, mon = predicts
    # L1 loss on total number of spikes
    l1_loss = 1e-5 * bm.sum(mon['r.spike'])
    # L2 loss on spikes per neuron
    l2_loss = 1e-5 * bm.mean(bm.sum(bm.sum(mon['r.spike'], axis=0), axis=0) ** 2)
    # predictions
    predicts = bm.max(predicts, axis=1)
    loss = bp.losses.cross_entropy_loss(predicts, targets)
    acc = bm.mean(predicts.argmax(-1) == targets)
    return loss + l2_loss + l1_loss, {'acc': acc}

model = SNNModel(num_in=28*28, num_rec=100, num_out=10)

trainer = bp.train.BPTT(
    model,
    loss_fun=loss_fun,
    loss_has_aux=True,
    optimizer=bp.optim.Adam(lr=1e-3),
    monitors={'r.spike': model.r.spike},
)

The training process is similar to that of the ANN model, except that the data is generated by the sparse generator function we defined above.

x_train = bm.array(train_dataset.data, dtype=bm.float_) / 255
y_train = bm.array(train_dataset.targets, dtype=bm.int_)

trainer.fit(lambda: sparse_data_generator(x_train.reshape(x_train.shape[0], -1),
                                          y_train,
                                          batch_size=256,
                                          nb_steps=100,
                                          nb_units=28 * 28),
            num_epoch=10)
Train 0 epoch, use 56.6148 s, loss 10.524602890014648, acc 0.3441840410232544
Train 1 epoch, use 48.7201 s, loss 1.947080373764038, acc 0.4961271286010742
Train 2 epoch, use 50.2106 s, loss 1.5027152299880981, acc 0.5980067849159241
Train 3 epoch, use 53.0944 s, loss 1.371555209159851, acc 0.63353031873703
Train 4 epoch, use 54.2528 s, loss 1.294083833694458, acc 0.6476696133613586
Train 5 epoch, use 56.5207 s, loss 1.2385631799697876, acc 0.6586705446243286
Train 6 epoch, use 61.7909 s, loss 1.2144725322723389, acc 0.6649806499481201
Train 7 epoch, use 72.7359 s, loss 1.1915594339370728, acc 0.6712072491645813
Train 8 epoch, use 76.2446 s, loss 1.153993010520935, acc 0.6776843070983887
Train 9 epoch, use 79.4869 s, loss 1.1312021017074585, acc 0.682542085647583
plt.plot(trainer.get_hist_metric('fit'), label='fit')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
_images/8f30c85bb7413742434179c15deec274d96bae2351eaed83a32239a0adf9ae0b.png

Customize your BP training#

Actually, brainpy.train.BPTT is just one way to perform back-propagation training with your model. You can easily customize your training process.

Below, we demonstrate how to define a BP training process by hand with the above ANN model.

# packages we need

from time import time
# define the model
model = ANNModel(28, 100, 10)
# define the loss function
@bm.to_object(child_objs=model)
def loss_fun(inputs, targets):
  runner = bp.train.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
  predicts = runner.predict(inputs, reset_state=True)
  predicts = bm.max(predicts, axis=1)
  loss = bp.losses.cross_entropy_loss(predicts, targets)
  acc = bm.mean(predicts.argmax(-1) == targets)
  return loss, acc
# define the gradient function which computes the
# gradients of the trainable weights
grad_fun = bm.grad(loss_fun,
                   grad_vars=model.train_vars().unique(),
                   has_aux=True,
                   return_value=True)
# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())
# training function

@bm.jit
@bm.to_object(child_objs=(opt, grad_fun))
def train(xs, ys):
  grads, loss, acc = grad_fun(xs, ys)
  opt.update(grads)
  return loss, acc
# start training

k = 0
num_batch = 256
running_loss = 0
running_acc = 0
print_step = 100
X_train = bm.asarray(x_train)
Y_train = bm.asarray(y_train)
t0 = time()
for _ in range(10):  # number of epoch
  X_train = bm.random.permutation(X_train, key=123)
  Y_train = bm.random.permutation(Y_train, key=123)

  for i in range(0, X_train.shape[0], num_batch):
    X = X_train[i: i + num_batch]
    Y = Y_train[i: i + num_batch]
    loss_, acc_ = train(X, Y)
    running_loss += loss_
    running_acc += acc_
    k += 1
    if k % print_step == 0:
      print('Step {}, Used {:.4f} s, Loss {:0.4f}, Acc {:0.4f}'.format(
        k, time() - t0,  running_loss / print_step, running_acc / print_step)
      )
      t0 = time()
      running_loss = 0
      running_acc = 0
Step 100, Used 6.7523 s, Loss 1.2503, Acc 0.5630
Step 200, Used 5.3020 s, Loss 0.6340, Acc 0.7779
Step 300, Used 6.5825 s, Loss 0.5545, Acc 0.8056
Step 400, Used 5.3013 s, Loss 0.5028, Acc 0.8198
Step 500, Used 5.3458 s, Loss 0.4659, Acc 0.8340
Step 600, Used 5.3190 s, Loss 0.4601, Acc 0.8316
Step 700, Used 5.2990 s, Loss 0.4297, Acc 0.8443
Step 800, Used 5.3577 s, Loss 0.4244, Acc 0.8456
Step 900, Used 5.3054 s, Loss 0.4053, Acc 0.8538
Step 1000, Used 5.3404 s, Loss 0.3913, Acc 0.8568
Step 1100, Used 5.2744 s, Loss 0.3943, Acc 0.8534
Step 1200, Used 5.4739 s, Loss 0.3863, Acc 0.8592
Step 1300, Used 5.4073 s, Loss 0.3709, Acc 0.8647
Step 1400, Used 5.3310 s, Loss 0.3791, Acc 0.8607
Step 1500, Used 5.3793 s, Loss 0.3644, Acc 0.8643
Step 1600, Used 5.3164 s, Loss 0.3562, Acc 0.8718
Step 1700, Used 5.4404 s, Loss 0.3585, Acc 0.8677
Step 1800, Used 5.4584 s, Loss 0.3533, Acc 0.8716
Step 1900, Used 5.4216 s, Loss 0.3460, Acc 0.8727
Step 2000, Used 5.4207 s, Loss 0.3445, Acc 0.8729
Step 2100, Used 5.3493 s, Loss 0.3375, Acc 0.8749
Step 2200, Used 5.3991 s, Loss 0.3317, Acc 0.8773
Step 2300, Used 5.3003 s, Loss 0.3356, Acc 0.8755

Introduction to Echo State Network#

@Chaoming Wang

import brainpy as bp
import brainpy.math as bm

# enable x64 computation
bm.set_environment(x64=True, mode=bm.batching_mode)
# bm.set_platform('cpu')

bp.__version__
'2.3.1'
import brainpy_datasets as bd
bd.__version__
'0.0.0.2'
import matplotlib.pyplot as plt

Echo State Network#

Echo State Networks (ESNs) are applied to supervised temporal machine learning tasks where for a given training input signal \(x(n)\) a desired target output signal \(y^{target}(n)\) is known. Here \(n=1, ..., T\) is the discrete time and \(T\) is the number of data points in the training dataset.

The task is to learn a model with output \(y(n)\), where \(y(n)\) matches \(y^{target}(n)\) as well as possible, minimizing an error measure \(E(y, y^{target})\), and, more importantly, generalizes well to unseen data.

ESNs use an RNN type with leaky-integrated discrete-time continuous-value units. The typical update equations are

\[\begin{split} \hat{h}(n) = \tanh(W^{in} x(n) + W^{rec}h(n-1) + W^{fb}y(n-1) + b^{rec}) \\ h(n) = (1 - \alpha) h(n-1)+\alpha \hat{h}(n) \end{split}\]

where \(h(n)\) is a vector of reservoir neuron activations, \(W^{in}\) and \(W^{rec}\) are the input and recurrent weight matrices respectively, and \(\alpha \in (0, 1]\) is the leaking rate. The model is also sometimes used without the leaky integration, which is a special case of \(\alpha=1\).

The linear readout layer is defined as

\[ y(n) = W^{out} h(n) + b^{out} \]

where \(y(n)\) is network output, \(W^{out}\) the output weight matrix, and \(b^{out}\) is the output bias.

An additional nonlinearity can be applied to \(y(n)\), as well as feedback connections \(W^{fb}\) from \(y(n-1)\) to \(\hat{h}(n)\).
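As a concrete (hypothetical) illustration, the following minimal sketch performs one reservoir update step with brainpy.math, using randomly initialized weights W_in and W_rec and ignoring the optional feedback term \(W^{fb}y(n-1)\):

num_in, num_rec = 1, 100
alpha = 0.3  # leaking rate

# hypothetical weights, for illustration only
W_in = bm.random.uniform(0., 0.2, (num_in, num_rec))
W_rec = bm.random.normal(0., 0.1, (num_rec, num_rec))
b_rec = bm.zeros(num_rec)

def reservoir_step(h, x):
  # candidate state \hat{h}(n)
  h_hat = bm.tanh(x @ W_in + h @ W_rec + b_rec)
  # leaky integration: h(n) = (1 - alpha) h(n-1) + alpha \hat{h}(n)
  return (1 - alpha) * h + alpha * h_hat

h = bm.zeros(num_rec)  # initial reservoir state
h = reservoir_step(h, bm.ones(num_in))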

A graphical representation of an ESN illustrating our notation and the idea for training is depicted in the following figure.

echo state machine

Ridge regression#

Finding the optimal weights \(W^{out}\) that minimize the squared error between \(y(n)\) and \(y^{target}(n)\) amounts to solving a typically overdetermined system of linear equations

\[Y^{target} = W^{out}X\]

Probably the most universal and stable solution is ridge regression, also known as regression with Tikhonov regularization:

\[W^{out} = Y^{target}X^T(XX^T+\beta I)^{-1}\]
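In code, this closed-form solution can be sketched as follows (a minimal NumPy illustration with randomly generated states and targets, not the implementation behind bp.train.RidgeTrainer used later):

import numpy as np

num_hidden, num_out, num_time = 100, 1, 1000
beta = 1e-6  # regularization coefficient

X = np.random.rand(num_hidden, num_time)  # each column is a reservoir state h(n)
Y = np.random.rand(num_out, num_time)     # each column is a target y^target(n)

# W^out = Y^target X^T (X X^T + beta I)^{-1}
W_out = Y @ X.T @ np.linalg.inv(X @ X.T + beta * np.eye(num_hidden))
print(W_out.shape)  # (1, 100)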

Dataset#

The Mackey-Glass equation is a delay differential equation describing the temporal behaviour of certain physiological signals, for example, the relative quantity of mature blood cells over time.

The equation is defined as:

\[ \frac{dP(t)}{dt} = \frac{\beta P(t - \tau)}{1 + P(t - \tau)^n} - \gamma P(t) \]

where \(\beta = 0.2\), \(\gamma = 0.1\), \(n = 10\), and the time delay \(\tau = 17\). \(\tau\) controls the chaotic behaviour of the equation (the higher it is, the more chaotic the timeseries becomes; \(\tau=17\) already gives good chaotic results).
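For intuition, the delayed equation can be integrated by hand with a simple Euler scheme and a delay buffer (a rough sketch for illustration only; the brainpy_datasets class used below is the recommended way):

import numpy as np

beta, gamma, n, tau, dt = 0.2, 0.1, 10, 17., 0.1
num_steps = 50000
delay_steps = int(tau / dt)

P = np.full(num_steps + delay_steps, 1.2)  # constant initial history
for i in range(delay_steps, num_steps + delay_steps - 1):
  P_tau = P[i - delay_steps]  # P(t - tau)
  dP = beta * P_tau / (1. + P_tau ** n) - gamma * P[i]
  P[i + 1] = P[i] + dt * dP

series = P[delay_steps:]  # drop the constant history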

def plot_mackey_glass_series(ts, x_series, x_tau_series, num_sample):
  plt.figure(figsize=(13, 5))

  plt.subplot(121)
  plt.title(f"Timeserie - {num_sample} timesteps")
  plt.plot(ts[:num_sample], x_series[:num_sample], lw=2, color="lightgrey", zorder=0)
  plt.scatter(ts[:num_sample], x_series[:num_sample], c=ts[:num_sample], cmap="viridis", s=6)
  plt.xlabel("$t$")
  plt.ylabel("$P(t)$")

  ax = plt.subplot(122)
  ax.margins(0.05)
  plt.title(f"Phase diagram: $P(t) = f(P(t-\\tau))$")
  plt.plot(x_tau_series[: num_sample], x_series[: num_sample], lw=1, color="lightgrey", zorder=0)
  plt.scatter(x_tau_series[:num_sample], x_series[: num_sample], lw=0.5, c=ts[:num_sample], cmap="viridis", s=6)
  plt.xlabel("$P(t-\\tau)$")
  plt.ylabel("$P(t)$")
  cbar = plt.colorbar()
  cbar.ax.set_ylabel('$t$')

  plt.tight_layout()
  plt.show()

An easy way to get Mackey-Glass time-series data is to use the brainpy_datasets package (the bd.chaos.MackeyGlassEq class used below). If you want to see the details of the implementation, please see the corresponding source code.

dt = 0.1
mg_data = bd.chaos.MackeyGlassEq(25000, dt=dt, tau=17, beta=0.2, gamma=0.1, n=10)
ts = mg_data.ts
xs = mg_data.xs
ys = mg_data.ys
plot_mackey_glass_series(ts, xs, ys, num_sample=int(1000 / dt))
_images/a4d0f937c6f01fd1159403680e73ad6696b6be10f4add9965b9cfe6695cd8d7f.png

Task 1: prediction of Mackey-Glass timeseries#

Predict \(P(t+1), \cdots, P(t+N)\) from \(P(t)\).

Prepare the data#
def get_data(t_warm, t_forecast, t_train, sample_rate=1):
    warmup = int(t_warm / dt)  # number of steps to warm up the reservoir
    forecast = int(t_forecast / dt)  # number of steps to predict ahead
    train_length = int(t_train / dt)

    X_warm = xs[:warmup:sample_rate]
    X_warm = bm.expand_dims(X_warm, 0)

    X_train = xs[warmup: warmup+train_length: sample_rate]
    X_train = bm.expand_dims(X_train, 0)

    Y_train = xs[warmup+forecast: warmup+train_length+forecast: sample_rate]
    Y_train = bm.expand_dims(Y_train, 0)

    X_test = xs[warmup + train_length: -forecast: sample_rate]
    X_test = bm.expand_dims(X_test, 0)

    Y_test = xs[warmup + train_length + forecast::sample_rate]
    Y_test = bm.expand_dims(Y_test, 0)

    return X_warm, X_train, Y_train, X_test, Y_test
# First warmup the reservoir using the first 100 ms
# Then, train the network in 20000 ms to predict 1 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(100, 1, 20000)
sample = 3000
fig = plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()
_images/c82f966fc24f8ac099adbf566cb59464ea2fd6dfaa86591b61d0073af29eb733.png
Prepare the ESN#
class ESN(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_hidden, num_out, sr=1., leaky_rate=0.3,
               Win_initializer=bp.init.Uniform(0, 0.2)):
    super(ESN, self).__init__()
    self.r = bp.layers.Reservoir(
        num_in, num_hidden,
        Win_initializer=Win_initializer,
        spectral_radius=sr,
        leaky_rate=leaky_rate,
    )
    self.o = bp.layers.Dense(num_hidden, num_out, mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))
Train and test ESN#
model = ESN(1, 100, 1)
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
# warmup
_ = trainer.predict(x_warm)
# train
_ = trainer.fit([x_train, y_train])

Test the training data.

ys_predict = trainer.predict(x_train)
start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0],
         lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_train)[0, start:end, 0], linestyle="--",
         lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_train)}')
plt.legend()
plt.show()
_images/adc15aa9f399315760897dbe54f2e2975df193a489893cc40a3b72f4ebd87708.png

Test the testing data.

ys_predict = trainer.predict(x_test)

start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0], lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_test)[0,start:end, 0], linestyle="--", lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_test)}')
plt.legend()
plt.show()
_images/38d71c742794c352d5196318110dd598101dc4f1774c34ce5a0481cd1899fe92.png
Make the task harder#
# First warmup the reservoir using the first 100 ms
# Then, train the network in 20000 ms to predict 10 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(100, 10, 20000)
sample = 3000
plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()
_images/a8aa7878861dc3c0e9d0a65be6aac25b67538f60abad6bcd75ba15a73e13a6f9.png
model = ESN(1, 100, 1, sr=1.1)
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
# warmup
_ = trainer.predict(x_warm)

# train
_ = trainer.fit([x_train, y_train])
ys_predict = trainer.predict(x_test, )

start, end = 1000, 6000
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0], lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_test)[0, start:end, 0], linestyle="--", lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_test)}')
plt.legend()
plt.show()
_images/533ff3aa5f397c41a7e31ad2f846d0da3d383a987ebdd4db5ad4c295ff448ac8.png
Diving into the reservoir#

Let’s have a look at the effect of some of the hyperparameters of the ESN.

Spectral radius#

The spectral radius is defined as the maximum absolute eigenvalue of the reservoir weight matrix.
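For instance, a random recurrent matrix can be rescaled to a target spectral radius like this (a small NumPy sketch; presumably this is what the spectral_radius argument of the Reservoir layer controls internally):

import numpy as np

num_rec = 100
target_sr = 1.25

W = np.random.normal(0., 1., (num_rec, num_rec))
sr = np.max(np.abs(np.linalg.eigvals(W)))  # spectral radius = max |eigenvalue|
W_scaled = W * (target_sr / sr)

print(np.max(np.abs(np.linalg.eigvals(W_scaled))))  # ~ 1.25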

num_sample = 20
all_radius = [-0.5, 0.5, 1.25, 2.5, 10.]

plt.figure(figsize=(15, len(all_radius) * 3))
for i, s in enumerate(all_radius):
  model = ESN(1, 100, 1, sr=s)
  model.reset_state(1)
  runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
  _ = runner.predict(x_test[:, :10000])
  states = bm.as_numpy(runner.mon['state'])

  plt.subplot(len(all_radius), 1, i + 1)
  plt.plot(states[0, :, :num_sample])
  plt.ylabel(f"spectral radius=${all_radius[i]}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()
_images/2b590620db88680bfdddf50b3a1b93b939e6dddb680b9c1c54015f3ee7efb171.png
  • spectral radius < 1 \(\rightarrow\) stable dynamics

  • spectral radius > 1 \(\rightarrow\) chaotic dynamics

In most cases, it should have a value around \(1.0\) to ensure the echo state property (ESP): the dynamics of the reservoir should not be bound to the initial state chosen, and should remain close to chaos.

This value also heavily depends on the input scaling.

Input scaling#

The input scaling controls how the ESN interacts with the inputs. It is a coefficient applied to the input matrix \(W^{in}\).

num_sample = 20
all_input_scaling = [0.1, 1.0, 10.0]

plt.figure(figsize=(15, len(all_input_scaling) * 3))
for i, s in enumerate(all_input_scaling):
  model = ESN(1, 100, 1, sr=1., Win_initializer=bp.init.Uniform(max_val=s))
  model.reset_state(1)
  runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
  _ = runner.predict(x_test[:, :10000])
  states = bm.as_numpy(runner.mon['state'])

  plt.subplot(len(all_input_scaling), 1, i + 1)
  plt.plot(states[0, :, :num_sample])
  plt.ylabel(f"input scaling=${s}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()
_images/4c9541378dadec1159f67c1078ff3047c81cc2ad5abaabb553770c14172c7110.png
Leaking rate#

The leaking rate (\(\alpha\)) controls the “memory feedback” of the ESN. The ESN states are indeed computed as:

\[ h(t+1) = \underbrace{\color{red}{(1 - \alpha)} h(t)}_{\text{previous states}} + \underbrace{\color{red}\alpha f(x(t+1), h(t))}_{\text{new states}} \]

where \(h\) is the state, \(x\) is the input data, \(f\) is the ESN model function, defined as:

\[ f(x, h) = \tanh(W^{in} \cdotp x + W^{rec} \cdotp h) \]

\(\alpha\) must be in \([0, 1]\).

num_sample = 20
all_rates = [0.001, 0.01, 0.1, 1.]

plt.figure(figsize=(15, len(all_rates) * 3))
for i, s in enumerate(all_rates):
  model = ESN(1, 100, 1, sr=1., leaky_rate=s,
              Win_initializer=bp.init.Uniform(max_val=1.), )
  model.reset_state(1)
  runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
  _ = runner.predict(x_test[:, :10000])
  states = bm.as_numpy(runner.mon['state'])

  plt.subplot(len(all_rates), 1, i + 1)
  plt.plot(states[0, :, :num_sample])
  plt.ylabel(f"leaky rate=${s}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()
_images/444bb6bc7970e13bafd70fce6e57edfd6c5804a3d451f634e6436d91c55f5ae8.png

Let’s reduce the input influence to see what is happening inside the reservoir (input scaling set to 0.2):

num_sample = 20
all_rates = [0.001, 0.01, 0.1, 1.]

plt.figure(figsize=(15, len(all_rates) * 3))
for i, s in enumerate(all_rates):
  model = ESN(1, 100, 1, sr=1., leaky_rate=s,
              Win_initializer=bp.init.Uniform(max_val=.2), )
  model.reset_state(1)
  runner = bp.train.DSTrainer(model, monitors={'state': model.r.state})
  _ = runner.predict(x_test[:, :10000])
  states = bm.as_numpy(runner.mon['state'])

  plt.subplot(len(all_rates), 1, i + 1)
  plt.plot(states[0, :, :num_sample])
  plt.ylabel(f"leaky rate=${s}$")
plt.xlabel(f"States ({num_sample} neurons)")
plt.show()
_images/84e1b96b2e6fa13b6cb940ef7f572821cd92643631498eb7646f351c113549c4.png
  • high leaking rate \(\rightarrow\) low inertia, little memory of previous states

  • low leaking rate \(\rightarrow\) high inertia, big memory of previous states

The leaking rate can be seen as the inverse of the reservoir’s time constant.

Task 2: generation of Mackey-Glass timeseries#

Generative mode: the output of ESN will be used as the input.

During this task, the ESN is trained to make a short forecast of the timeseries (1 timestep ahead). Then, it is asked to run on its own outputs, trying to predict its own behaviour.

# First warmup the reservoir using the first 500 ms
# Then, train the network in 20000 ms to predict 1 ms chaotic series ahead
x_warm, x_train, y_train, x_test, y_test = get_data(500, 1, 20000, sample_rate=int(1/dt))
sample = 300
fig = plt.figure(figsize=(15, 5))
plt.plot(x_train[0, :sample], label="Training data")
plt.plot(y_train[0, :sample], label="True prediction")
plt.legend()
plt.show()
_images/87f4107869067a388601f14dd5fb322070d20697c7c0cb323b5d28dd488177e0.png
model = ESN(1, 100, 1, sr=1.1, Win_initializer=bp.init.Uniform(max_val=.2), )
model.reset_state(1)
trainer = bp.train.RidgeTrainer(model, alpha=1e-7)
# warmup
_ = trainer.predict(x_warm)

# train
trainer.fit([x_train, y_train])

# test
ys_predict = trainer.predict(x_train)
start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.as_numpy(ys_predict)[0, start:end, 0],
         lw=3, label="ESN prediction")
plt.plot(bm.as_numpy(y_train)[0, start:end, 0], linestyle="--",
         lw=2, label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(ys_predict, y_train)}')
plt.legend()
plt.show()
_images/4d6872e02f828c0773c1c52d87452392edd21a3800bdbf1ea33491c088d96bab.png
jit_model = bm.jit(model)
outputs = [x_test[:, 0]]
truths = [x_test[:, 1]]
for i in range(200):
    outputs.append(jit_model(dict(), outputs[-1]))
    truths.append(x_test[:, i+2])
outputs = bm.asarray(outputs)
truths = bm.asarray(truths)
plt.figure(figsize=(15, 10))
plt.plot(bm.as_numpy(truths).squeeze()[:200], label='Ground truth')
plt.plot(bm.as_numpy(outputs).squeeze()[:200], label='Prediction')
plt.legend()
plt.show()
_images/fa101605828f0312bf3b93cb2609b23fab31f9c20570190f140fc69a55cfee70.png

References#

  • Jaeger, H.: The “echo state” approach to analysing and training recurrent neural networks. Technical Report GMD Report 148, German National Research Center for Information Technology (2001)

  • Lukoševičius, Mantas. “A Practical Guide to Applying Echo State Networks.” Neural Networks: Tricks of the Trade (2012).

Model Analysis#

Low-dimensional Analyzers#

@Chaoming Wang

We have talked about model simulation and training for dynamical systems with BrainPy. In this tutorial, we are going to dive into how to perform automatic analysis for your defined systems.

As we all know, dynamics analysis is necessary in neurodynamics. This is because blind simulation of nonlinear systems is likely to produce few or misleading results. BrainPy has good support for low-dimensional systems, no matter how nonlinear your defined system is. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:

  1. phase plane analysis;

  2. codimension 1 or codimension 2 bifurcation analysis;

  3. bifurcation analysis of the fast-slow system.

BrainPy will help you probe the dynamical mechanism of your defined systems rapidly.

import brainpy as bp
import brainpy.math as bm

bm.enable_x64()  # It's better to enable x64 when performing analysis
bm.set_platform('cpu')

bp.__version__
'2.3.0'
import numpy as np
import matplotlib.pyplot as plt

A simple case#

Here we test BrainPy with a simple case:

\[ \frac{dx}{dt} = \mathrm{sin}(x) + I, \]

where \(x \in [-10, 10]\).

As we know, this function has multiple fixed points (\(\frac{dx}{dt} = 0\)) when \(I=0\).

xs = np.arange(-10, 10, 0.01)

plt.plot(xs, np.sin(xs))
plt.scatter([-3*np.pi, -1*np.pi, 1*np.pi, 3 * np.pi], np.zeros(4), s=80, edgecolors='y')
plt.scatter([-2*np.pi, 0, 2*np.pi], np.zeros(3), s=80, facecolors='none', edgecolors='r')
plt.axhline(0)
plt.show()
_images/bee937252d7b426ada46deeaf78c6b87e24d0f89205325af5ee78d03574b1e7e.png

According to dynamical systems theory, the red hollow points are unstable fixed points, while the solid ones are stable.

Now let’s come back to BrainPy, and test whether BrainPy can give us the right answer.

As the analysis interfaces in BrainPy only receive an ODEIntegrator or an instance of DynamicalSystem, we first define an integrator with BrainPy (if you want to know how to define an ODE integrator, please refer to the tutorial of Numerical Solvers for ODEs):

@bp.odeint
def int_x(x, t, Iext):
    return bp.math.sin(x) + Iext

This is a one-dimensional dynamical system. So we are trying to use brainpy.analysis.PhasePlane1D for phase plane analysis. The usage of phase plane analysis will be detailed in the following section. Now, we just focus on the following four arguments:

  • model: It specifies the target system to analyze. It can be a list/tuple of ODEIntegrator. However, it can also be an instance of DynamicalSystem. For DynamicalSystem argument, we will use model.ints().subset(bp.ode.ODEIntegrator) to retrieve all instances of ODEIntegrator later.

  • target_vars: It specifies the variables to analyze. It must be a dict with the format of <var_name, var_interval>, where var_name is the variable name, and var_interval is the boundary of this variable.

  • pars_update: Parameters to update.

  • resolutions: The resolution to evaluate the fixed points.

Let’s try it.

pp = bp.analysis.PhasePlane1D(
  model=int_x,
  target_vars={'x': [-10, 10]},
  pars_update={'Iext': 0.},
  resolutions={'x': 0.01}
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)
I am creating the vector field ...
I am searching fixed points ...
Fixed point #1 at x=-9.424777960769386 is a stable point.
Fixed point #2 at x=-6.283185307179586 is a unstable point.
Fixed point #3 at x=-3.1415926535897984 is a stable point.
Fixed point #4 at x=3.552755127361717e-18 is a unstable point.
Fixed point #5 at x=3.1415926535897984 is a stable point.
Fixed point #6 at x=6.283185307179586 is a unstable point.
Fixed point #7 at x=9.424777960769386 is a stable point.
_images/bb4a6e51ecedf8eb6f2b6d2e166cd66fe92382ce1bb439a7f3bc49e7224a9908.png

Indeed, brainpy.analysis.PhasePlane1D gives us the right fixed points, and correctly evaluates the stability of these fixed points.

The phase plane is important because it gives us an intuitive understanding of how the system evolves with the given parameters. However, in most cases we care about how the parameters affect the system behaviors, so we should perform bifurcation analysis. brainpy.analysis.Bifurcation1D is a convenient interface to help you gain insight into how the dynamics of a 1D system change with parameters.

Similar to brainpy.analysis.PhasePlane1D, brainpy.analysis.Bifurcation1D receives arguments like “model”, “target_vars”, “pars_update”, and “resolutions”. Besides, one more important argument “target_pars” should be provided, which specifies the range of the target parameter in bifurcation analysis.

Here, we systematically change the parameter “Iext” from 0 to 1.5. According to bifurcation theory, we know this simple system has a fold bifurcation at \(I=1.0\): at \(I=1.0\), two fixed points collide with each other into a saddle point and then disappear. Is BrainPy’s analysis toolkit brainpy.analysis.Bifurcation1D capable of performing these analyses? Let’s give it a try.

bif = bp.analysis.Bifurcation1D(
    model=int_x,
    target_vars={'x': [-10, 10]},
    target_pars={'Iext': [0., 1.5]},
    resolutions={'Iext': 0.005, 'x': 0.05}
)
bif.plot_bifurcation(show=True)
I am making bifurcation analysis ...
_images/a3412772174cc10cff261287b1cb198a12f5c48d5b55381f52dc3f484e44f7dd.png

Once again, the BrainPy analysis toolkit gives the right answer. It tells us how the fixed points evolve as the parameter \(I\) increases.

It is worth noting that the bifurcation analysis in BrainPy can hardly find the saddle point (at \(I=1.0\) for this system). This is because the saddle point at the bifurcation exists only for an instant, and the numerical method used in BrainPy’s analysis toolkit is almost unable to evaluate the point exactly at the saddle. However, with minimal knowledge of bifurcation theory, the saddle point (the collision point of two fixed points) can easily be inferred from the evolution of the fixed points.

BrainPy’s analysis toolkit is highly useful, especially when the mathematical equations are too complex to obtain analytical solutions. For an example, please refer to the tutorial Analysis of a Decision-making Model.

Phase plane analysis#

Phase plane analysis is one of the most important techniques for studying the behavior of nonlinear systems, since there is usually no analytical solution for a nonlinear system. BrainPy can help users to plot phase plane of 1D systems or 2D systems. Specifically, we provide brainpy.analysis.PhasePlane1D and brainpy.analysis.PhasePlane2D. It can help to plot:

  • Nullcline: The zero-growth isoclines, such as \(f(x, y)=0\) and \(g(x, y)=0\).

  • Fixed points: The equilibrium points of the system, which are located where all the nullclines intersect.

  • Vector field: The vector field of the system.

  • Limit cycles: The limit cycles.

  • Trajectories: A simulation trajectory with the given initial values.

We have talked about brainpy.analysis.PhasePlane1D above. Now we focus on brainpy.analysis.PhasePlane2D, using the well-known FitzHugh-Nagumo neuron model.

The FitzHugh-Nagumo model is given by:

\[\begin{split} \frac {dV} {dt} = V(1 - \frac {V^2} 3) - w + I_{ext} \\ \tau \frac {dw} {dt} = V + a - b w \end{split}\]

There are two variables \(V\) and \(w\), so this is a two-dimensional system with three parameters \(a, b\) and \(\tau\).

For the system to analyze, users can define it by using the pure brainpy.odeint or define it as a class of DynamicalSystem. For this FitzHugh-Nagumo model, we define it as a class because later we will perform simulation to verify the analysis results.

class FitzHughNagumoModel(bp.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(FitzHughNagumoModel, self).__init__()

    # parameters
    self.a = 0.7
    self.b = 0.8
    self.tau = 12.5

    # variables
    self.V = bm.Variable(bm.zeros(1))
    self.w = bm.Variable(bm.zeros(1))
    self.Iext = bm.Variable(bm.zeros(1))

    # functions
    def dV(V, t, w, Iext=0.): 
        return V - V * V * V / 3 - w + Iext
    def dw(w, t, V, a=0.7, b=0.8): 
        return (V + a - b * w) / self.tau
    self.int_V = bp.odeint(dV, method=method)
    self.int_w = bp.odeint(dw, method=method)

  def update(self, tdi):
    self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
    self.w.value = self.int_w(self.w, tdi.t, self.V, self.a, self.b, tdi.dt)
    self.Iext[:] = 0.
model = FitzHughNagumoModel()

Here we perform a phase plane analysis with parameters \(a=0.7, b=0.8, \tau=12.5\), and input \(I_{ext} = 0.8\).

pp = bp.analysis.PhasePlane2D(
  model,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  pars_update={'Iext': 0.8}, 
  resolutions={'V': 0.01, 'w': 0.01},
)
# By default, nullclines will be plotted as points,
# but we can set the plot style to lines
pp.plot_nullcline(x_style={'fmt': '-'}, y_style={'fmt': '-'})

# The vector field can be plotted in two ways:
# - plot_method="streamplot" (default)
# - plot_method="quiver"
pp.plot_vector_field()

# There are many ways to search fixed points. 
# By default, it will use the nullcline points of the first 
# variable ("V") as the initial points to perform fixed point searching
pp.plot_fixed_point()

# Trajectory plotting receives the setting of the initial points.
# There may be multiple trajectories, therefore the initial points 
# should be provided as a list/tuple/numpy.ndarray/Array
pp.plot_trajectory({'V': [-2.8], 'w': [-1.8]}, duration=100.)

# show the phase plane figure
pp.show_figure()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 866 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 V=-0.2729223248464073, w=0.5338542697673022 is a unstable node.
I am plotting the trajectory ...
_images/b9005b87b02892ff1d1ebf053d7c1142b3578e0b7d231c1652692596dd125534.png

We can see an unstable node at the point (\(V=-0.27, w=0.53\)) inside a limit cycle.

We can run a simulation with the same parameters and initial values to verify the periodic activity that corresponds to the limit cycle.

runner = bp.DSRunner(model, monitors=['V', 'w'], inputs=['Iext', 0.8])
runner.run(100.)

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w', show=True)
_images/9fa17c245ebad3a0e015f79c57a15a0706bee147d165fed5bd2561ff32351434.png

Understanding settings#

There are several key settings that need to be understood.

resolutions#

resolutions is one of the most important parameters in PhasePlane and Bifurcation analysis toolkits of BrainPy. It is very important because it has a profound impact on the efficiency of model analysis.

We can set resolutions with the following ways.

  1. None. If we detect that there is no resolution setting for a variable, the corresponding resolution for this variable will be \(\frac{\mathrm{max\_value} - \mathrm{min\_value}}{20}\).

  2. A float. It sets the same resolution for each target variable and parameter.

  3. A dict. It specifies different resolutions for individual variables/parameters. Each value can be a float, or a vector in the format of Array or numpy.ndarray.

Note

It is highly recommended that users specify the resolution for specific parameters or variables with a dict rather than a single float value, which would be applied to all variables. Otherwise, the computation may occupy too much memory if the resolution is set very small. For example, if you want to set the resolution of variable x to 0.01, please use resolutions={'x': 0.01}.

Setting resolutions with a tensor gives the user maximal flexibility. Usually, numerical analysis does not work well near inflection points, so we can increase the granularity around them. For example, if there is an inflection point at \(1\), we can set the resolution with:

r1 = bm.arange(0.00, 0.95, 0.01)
r2 = bm.arange(0.95, 1.01, 0.001)
r3 = bm.arange(1.05, 1.50, 0.01)
resolution = bm.concatenate([r1, r2, r3])

Tip: For bifurcation analysis, we usually need to set a small resolution for parameters, leaving the resolutions of variables at their defaults. Please see the following examples.
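For instance, the non-uniform resolution built above could be passed to a bifurcation analyzer roughly like this (a hypothetical sketch reusing the int_x integrator defined earlier; the variable resolution is left at its default):

bif = bp.analysis.Bifurcation1D(
    model=int_x,
    target_vars={'x': [-10, 10]},
    target_pars={'Iext': [0., 1.5]},
    resolutions={'Iext': resolution},  # the concatenated vector defined above
)
bif.plot_bifurcation(show=True)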

vars and pars#

What can be set as variables *_vars or parameters *_pars (such as target_vars or target_pars) for further analysis? The variables and parameters are recognized according to the programming paradigm of the ODE numerical integrators. Simply speaking, the arguments before t are treated as variables, while the arguments after t are treated as parameters.

BrainPy’s analysis toolkit only supports one variable per differential equation. It cannot analyze a joint differential equation in which multiple variables are defined in the same function.

Moreover, the low-dimensional analyzers in BrainPy cannot analyze dynamical systems that depend on time \(t\).
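For example, in the following hypothetical integrator, x (before t) is recognized as the variable, while a and Iext (after t) are recognized as parameters:

@bp.odeint
def toy_system(x, t, a=1., Iext=0.):
  # "x" comes before "t"       -> treated as a variable
  # "a" and "Iext" follow "t"  -> treated as parameters
  return -a * x + Iext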

Bifurcation analysis#

Nonlinear dynamical systems are characterized by their parameters. When a parameter changes, the system’s behavior can change qualitatively. Therefore, we care about how the system changes with the smooth change of parameters.

Codimension 1 bifurcation analysis

We will first see the codimension 1 bifurcation analysis of the model. For example, we vary the input \(I_{ext}\) between 0 and 1 and see how the system changes its stability.

analyzer = bp.analysis.Bifurcation2D(
  model,
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  target_pars={'Iext': [0., 1.]},
  resolutions={'Iext': 0.002},
)

# "num_rank" specifies the number of initial poinits for
# fixed point optimization under a set of parameters
analyzer.plot_bifurcation(num_rank=10)

# show figure
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 5000 candidates
I am trying to filter out duplicate fixed points ...
	Found 500 fixed points.
_images/d599ee51b537505988581b46187aaa60b369fa3d076492d26a0077279e9fb9d2.png _images/a2a291fce0278a831b56c74dfc55d9aba8345c950f993861ed9b4d301b352f1d.png

Codimension 2 bifurcation analysis

We simultaneously change \(I_{ext}\) and the parameter \(a\).

analyzer = bp.analysis.Bifurcation2D(
    model,
    target_vars=dict(V=[-3, 3], w=[-3., 3.]),
    target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
    resolutions={'a': 0.01, 'Iext': 0.01},
)
analyzer.plot_bifurcation(num_rank=10, tol_aux=1e-9)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 50000 candidates
I am trying to filter out duplicate fixed points ...
	Found 4997 fixed points.
_images/9e4e9a5083e2a4507f6fb8582db24dca5d07dae3acee8df4dfc5e875ddf1c8a8.png _images/8d911eaaf36f04b36b835370335f78b9227b9f3a67207079ab447224a11788ed.png

Fast-slow system bifurcation#

BrainPy also provides tools for fast-slow system bifurcation analysis: brainpy.analysis.FastSlow1D and brainpy.analysis.FastSlow2D. This method was proposed by John Rinzel (1985, 1986, 1987) [1, 2, 3], who suggested that in a fast-slow dynamical system, we can treat the slow variables as the bifurcation parameters, and then study how different values of the slow variables affect the bifurcation of the fast subsystem.

Fast-slow bifurcation methods are very useful in the analysis of bursting neurons. We will illustrate this with the Hindmarsh-Rose model. The Hindmarsh–Rose model of neuronal activity aims to study the spiking-bursting behavior of the membrane potential observed in experiments made with a single neuron. Its dynamics are governed by:

\[\begin{split} \begin{aligned} \frac{d V}{d t} &= y - a V^3 + b V^2 - z + I\\ \frac{d y}{d t} &= c - d V^2 - y\\ \frac{d z}{d t} &= r (s (V - V_{rest}) - z) \end{aligned} \end{split}\]

First, let’s define the Hindmarsh–Rose model with BrainPy.

a = 1.
b = 3.
c = 1.
d = 5.
s = 4.
x_r = -1.6
r = 0.001
Vth = 1.9


@bp.odeint
def int_x(x, t, y, z, Isyn):
    return y - a * x ** 3 + b * x * x - z + Isyn

@bp.odeint
def int_y(y, t, x):
    return c - d * x * x - y

@bp.odeint
def int_z(z, t, x):
    return r * (s * (x - x_r) - z)

We can now start to analyze the underlying bifurcation mechanism.

analyzer = bp.analysis.FastSlow2D(
  [int_x, int_y, int_z],
  fast_vars={'x': [-3, 3], 'y': [-10., 5.]},
  slow_vars={'z': [-5., 5.]},
  pars_update={'Isyn': 0.5},
  resolutions={'z': 0.01}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 20000 candidates
I am trying to filter out duplicate fixed points ...
	Found 1156 fixed points.
_images/e0cd936de4f141742dd686c912621cf15b183c0c925cb020beddd89cc590cc0e.png _images/73bc6355b217983aaac5270ad9c1833c2d45d7ffea7558bb37693320fb495439.png

References#

[1] Rinzel, John. “Bursting oscillations in an excitable membrane model.” In Ordinary and partial differential equations, pp. 304-316. Springer, Berlin, Heidelberg, 1985.

[2] Rinzel, John , and Y. S. Lee . On Different Mechanisms for Membrane Potential Bursting. Nonlinear Oscillations in Biology and Chemistry. Springer Berlin Heidelberg, 1986.

[3] Rinzel, John. “A formal classification of bursting mechanisms in excitable systems.” In Mathematical topics in population biology, morphogenesis and neurosciences, pp. 267-281. Springer, Berlin, Heidelberg, 1987.

High-dimensional Analyzers#

@Chaoming Wang

It is hard to analyze high-dimensional systems. However, we often have to do so.

Here, based on numerical optimization methods, BrainPy provides brainpy.analysis.SlowPointFinder to help users find slow points (or fixed points) [1] for your high-dimensional dynamical systems.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.3.0'

What are slow points?#

For the given system,

\[ \dot{x} = f(x), \]

we wish to find values \(x^∗\) around which the system is approximately linear. Using Taylor series expansion, we have

\[ f(x^* + \delta x) = f(x^*) + f'(x^*)\delta x + 1/2 \delta x f''(x^*) \delta x + \cdots \]

We want the first derivative term (i.e., the linear term) to be dominant, which means \(f(x^*) = 0\) or \(f(x^*) \approx 0\).

  • For \(f(x^*) \approx 0\) which is nonzero but small, we call the point \(x^*\) a slow point.

  • More specifically, if \(f(x^*) = 0\), \(x^*\) is a fixed point.

How to find slow points?#

In order to find slow points, we can first define an auxiliary scalar function for your continuous system \(\dot{x} = f(x)\),

\[ p(x) = |f(x)|^2. \]

Or, if your system is discrete \(x_n = f(x_{n-1})\), the auxiliary scalar function can be defined as

\[ p(x) = |x - f(x)|^2. \]

If \(x^*\) is a slow point, \(p(x^*) \to 0\).

Then, by minimizing the scalar function \(p(x)\), we can get the candidate points for slow points and for further linearization. For the linearized system, its stability is evaluated by the eigenvalues of the Jacobian matrix.
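As a toy illustration of this idea (independent of SlowPointFinder, using plain JAX gradient descent on a hypothetical one-dimensional system):

from jax import grad

def f(x):
  # a toy 1D system dx/dt = f(x), with fixed points at x = 0 and x = 1
  return x * (1. - x)

def p(x):
  # auxiliary scalar function p(x) = |f(x)|^2
  return f(x) ** 2

grad_p = grad(p)

x = 0.4  # initial candidate
for _ in range(1000):
  x = x - 0.1 * grad_p(x)  # gradient descent on p(x)
print(x)  # approaches a point where f(x) is (nearly) zero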

Here, BrainPy provides brainpy.analysis.SlowPointFinder. It receives f_cell to specify the target function/object to analyze.

If the provided f_cell is a function, SlowPointFinder supports specifying:

  • f_type: the type of the function (it can be “continuous” or “discrete”).

  • f_loss: the loss function to minimize the optimization error.

  • args: extra arguments passed into the function when performing fixed point optimization.

If the provided f_cell is an instance of DynamicalSystem, SlowPointFinder supports specifying:

  • f_loss: the loss function to minimize the optimization error.

  • args: extra arguments passed into the defined update() function when performing fixed point optimization.

  • inputs and fun_inputs: inputs to this dynamical system. Similar to the inputs of DSRunner and DSTrainer.

  • target_vars: the selected variables which are used to optimize fixed points. Other variables like “input” and “spike” can be ignored.

Then, brainpy.analysis.SlowPointFinder can help you:

  • optimize to find the fixed/slow points with gradient descent algorithms (find_fps_with_gd_method()) or nonlinear optimization solver (find_fps_with_opt_solver())

  • exclude any fixed points whose losses are above threshold: filter_loss()

  • exclude any non-unique fixed points according to a tolerance: keep_unique()

  • exclude any far-away “outlier” fixed points: exclude_outliers()

  • compute the Jacobian matrix for the given fixed/slow points: compute_jacobians()

Example 1: Decision Making Model#

brainpy.analysis.SlowPointFinder is aimed at finding slow/fixed points of high-dimensional systems. Of course, it can also find fixed points of low-dimensional systems. We take the 2D decision-making system as an example.

# parameters

gamma = 0.641  # Saturation factor for gating variable
tau = 0.06  # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154

JE = 0.3725  # self-coupling strength [nA]
JI = -0.1137  # cross-coupling strength [nA]
JAext = 0.00117  # Stimulus input strength [nA]

mu = 20.  # Stimulus firing rate [spikes/sec]
coh = 0.5  # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
  I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
  r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
  return - s1 / tau + (1. - s1) * gamma * r1

@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
  I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
  r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
  return - s2 / tau + (1. - s2) * gamma * r2

def step(s):
    ds1 = int_s1.f(s[0], 0., s[1])
    ds2 = int_s2.f(s[1], 0., s[0])
    return bm.asarray([ds1, ds2])

We first use brainpy.analysis.PhasePlane2D to get the standard answer.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    resolutions=0.001,
)
analyzer.plot_fixed_point(select_candidates='aux_rank', with_plot=False)
I am searching fixed points ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 100 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.2827633321285248, s2=0.40635180473327637 is a saddle node.
	#2 s1=0.013946513645350933, s2=0.6573889851570129 is a stable node.
	#3 s1=0.7004518508911133, s2=0.004864312242716551 is a stable node.

Then, let’s check whether the high-dimensional analyzer also works.

finder = bp.analysis.SlowPointFinder(f_cell=step, f_type="continuous")
finder.find_fps_with_gd_method(
    candidates=bm.random.random((1000, 2)),
    tolerance=1e-5,
    num_batch=200,
    optimizer=bp.optimizers.Adam(bp.optimizers.ExponentialDecay(0.01, 1, 0.9999))
)
finder.filter_loss(1e-5)
finder.keep_unique()
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.24 sec, Training loss 0.0510474481
    Batches 201-400 in 0.25 sec, Training loss 0.0046035680
    Batches 401-600 in 0.32 sec, Training loss 0.0007384720
    Batches 601-800 in 0.27 sec, Training loss 0.0001601687
    Batches 801-1000 in 0.25 sec, Training loss 0.0000381663
    Batches 1001-1200 in 0.25 sec, Training loss 0.0000088441
    Stop optimization as mean training loss 0.0000088441 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
    Kept 934/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
    Kept 3/934 unique fixed points with uniqueness tolerance 0.025.
finder.fixed_points
array([[0.28276306, 0.40635154],
       [0.7004519 , 0.00486429],
       [0.01394659, 0.6573889 ]], dtype=float32)

Indeed, the fixed points found by brainpy.analysis.PhasePlane2D and brainpy.analysis.SlowPointFinder are nearly the same.

Example 2: Continuous-attractor Neural Network#

The continuous-attractor neural network (CANN) [2] proposed by Si Wu is a special model that has a line of attractors.

class CANN1D(bp.dyn.NeuGroup):
  def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., z_min=-bm.pi, z_max=bm.pi):
    super(CANN1D, self).__init__(size=num)

    # parameters
    self.tau = tau  # The synaptic time constant
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
    self.rho = num / self.z_range  # The neural density
    self.dx = self.z_range / num  # The stimulus density

    # variables
    self.u = bm.Variable(bm.zeros(num))
    self.input = bm.Variable(bm.zeros(num))

    # The connection matrix
    self.conn_mat = self.make_conn(self.x)

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, u, t, Iext):
    r1 = bm.square(u)
    r2 = 1.0 + self.k * bm.sum(r1)
    r = r1 / r2
    Irec = bm.dot(self.conn_mat, r)
    du = (-u + Irec + Iext) / self.tau
    return du

  def dist(self, d):
    d = bm.remainder(d, self.z_range)
    d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
    return d

  def make_conn(self, x):
    assert bm.ndim(x) == 1
    x_left = bm.reshape(x, (-1, 1))
    x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
    d = self.dist(x_left - x_right)
    Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
    return Jxx

  def get_stimulus_by_pos(self, pos):
    return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

  def update(self, tdi):
    self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
    self.input[:] = 0.

  def cell(self, u):
    return self.derivative(u, 0., 0.)
cann = CANN1D(num=512, k=0.1, A=30)

The following code demonstrates how to use SlowPointFinder to find fixed points of a continuous attractor neural network.

# initialize an instance of slow point finder
finder = bp.analysis.SlowPointFinder(
    f_cell=cann,
    target_vars={'u': cann.u},
    dt=1.,
)

# we can initialize our candidate points with noisy bumps.
candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1)))
candidates += bm.random.normal(0., 0.01, candidates.shape)

# optimize to find fixed points
finder.find_fps_with_opt_solver({'u': candidates})
finder.filter_loss(1e-6)
finder.keep_unique()
Optimizing with BFGS to find fixed points:
    Found 629 fixed points from 629 initial points.
Excluding fixed points with squared speed above tolerance 1e-06:
    Kept 357/629 fixed points with tolerance under 1e-06.
Excluding non-unique fixed points:
    Kept 357/357 unique fixed points with uniqueness tolerance 0.025.

The found fixed points form a series of attractors. We can visualize this line of attractors in a 2D space.

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points['u'])
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()
_images/2845c740bf2ec00d465748c8a480a5e3d621e881a2de16d3caf2318b2216f49c.png

These fixed points can also be plotted in the feature space. In the following, we plot the selected points.

def visualize_fixed_points(fps, plot_ids=(0,), xs=None):
  for i in plot_ids:
    if xs is None:
      plt.plot(fps[i], label=f'FP-{i}')
    else:
      plt.plot(xs, fps[i], label=f'FP-{i}')
  plt.legend()
  plt.xlabel('Feature')
  plt.ylabel('Bump activity')
  plt.show()
visualize_fixed_points(finder.fixed_points['u'],
                       plot_ids=(10, 20, 30, 40, 50, 60, 70, 80),
                       xs=cann.x)
_images/208363168444db596ded0c6e4ecfd57996aa8172430fb78b4852f615c41e397f.png

Let’s find the linear part, i.e., the Jacobian matrix, around the fixed points. We decompose the Jacobian matrices and then visualize their stability.

from jax import tree_map

# select the first ten fixed points
fps = tree_map(lambda a: a[:10], finder._fixed_points)

# compute jacobian and visualize the decomposed jacobian matrix
J = finder.compute_jacobians(fps, plot=True, num_col=2)
_images/64c17fc6981b4b5b7d7dbe58c30e36aaf29fc46de86d7aa6201d9a50362483f9.png

For more examples of dynamics analysis, such as analyzing the fixed points in a recurrent neural network, please see BrainPy Examples.

References#

[1] Sussillo, D. , and O. Barak . “Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks.” Neural computation 25.3(2013):626-649.

[2] Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.

Analysis of a Decision-making Model#

@Chaoming Wang

In this section, we are going to use the low-dimensional analyzers to make phase plane and bifurcation analysis for the decision-making model proposed by (Wong & Wang) [1].

Decision making model#

This model considers two excitatory neural assemblies, populations 1 and 2, that compete with each other through a shared pool of inhibitory neurons. In our analysis, we use the following model equations.

Let \(r_1\) and \(r_2\) be the firing rates of populations 1 and 2. The total synaptic input current \(I_i\) and the resulting firing rate \(r_i\) of neural population \(i\) obey the following input-output relationship (\(F - I\) curve):

\[ r_i = F(I_i) = \frac{aI_i - b}{1-\exp(-d(a I_i - b))} \]

which captures the current-frequency function of a leaky integrate-and-fire neuron. The parameter values are \(a\) = 270 Hz/nA, \(b\) = 108 Hz, \(d\) = 0.154 sec.
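As a quick numerical check of this transfer function (a small sketch evaluating the formula with the parameter values above; the input current value is arbitrary):

import brainpy.math as bm

a, b, d = 270., 108., 0.154

def F(I):
  # firing rate (Hz) as a function of the total input current (nA)
  return (a * I - b) / (1. - bm.exp(-d * (a * I - b)))

print(F(0.5))  # firing rate for a total input of 0.5 nA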

Assume that the synaptic drive variables \(S_1\) and \(S_2\) obey

\[\begin{split} \frac{dS_1}{dt} = F(I_1)\,\gamma(1-S_1)-S_1/\tau_s\\ \frac{dS_2}{dt} = F(I_2)\,\gamma(1-S_2)-S_2/\tau_s \end{split}\]

where \(\gamma\) = 0.641. The net current into each population is given by

\[\begin{split} I_1 = J_E S_1 + J_I S_2 + I_{b1} + J_{ext}\mu_1 \\ I_2 = J_E S_2 + J_I S_1 +I_{b2} +J_{ext}\mu_2. \end{split}\]

The synaptic time constant is \(\tau_s\) = 100 ms (NMDA time constant). The synaptic coupling strengths are \(J_E\) = 0.2609 nA, \(J_I\) = -0.0497 nA, and \(J_{ext}\) = 0.00052 nA. Stimulus-selective inputs to populations 1 and 2 are governed by unitless parameters \(\mu_1\) and \(\mu_2\), respectively.

For the decision-making paradigm, the input rates \(\mu_1\) and \(\mu_2\) are determined by the stimulus coherence \(c'\) which ranges between 0 (0%) and 1 (100%):

\[\begin{split} \mu_1 =\mu_0(1+c')\\ \mu_2 =\mu_0(1-c') \end{split}\]
import brainpy as bp
import brainpy.math as bm

bp.math.enable_x64()
# bp.math.set_platform('cpu')
bp.__version__
'2.3.0'

Parameters#

gamma = 0.641  # Saturation factor for gating variable
tau = 0.1  # Synaptic time constant [sec]
a = 270.  #  Hz/nA
b = 108.  # Hz
d = 0.154  # sec

I0 = 0.3255  # background current [nA]
JE = 0.2609  # self-coupling strength [nA]
JI = -0.0497  # cross-coupling strength [nA]
JAext = 0.00052  # Stimulus input strength [nA]
Ib = 0.3255  # The background input

Model implementation#

@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
    I1 = JE * s1 + JI * s2 + Ib + JAext * mu * (1. + coh)
    r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
    return - s1 / tau + (1. - s1) * gamma * r1

@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
    I2 = JE * s2 + JI * s1 + Ib + JAext * mu * (1. - coh)
    r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
    return - s2 / tau + (1. - s2) * gamma * r2

Phase plane analysis#

The advantage of the reduced model is that we can understand what dynamical behaviors the model generates for a particular parameter set using phase-plane analysis, and then explore how this behavior changes when the model parameters are varied (bifurcation analysis).

To this end, we will use the brainpy.analysis module.

We construct the phase portraits of the reduced model for different stimulus inputs (see Figure 4 and Figure 5 in (Wong & Wang, 2006) [1]).

No stimulus: \(\mu_0 =0\) Hz. In the absence of a stimulus, the two nullclines intersect with each other five times, producing five steady states, of which three are stable (attractors) and two are unstable.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 0.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 5 fixed points.
	#1 s1=0.5669871605297269, s2=0.031891419715715866 is a stable node.
	#2 s1=0.31384492489136057, s2=0.05578533347184539 is a saddle node.
	#3 s1=0.1026514458219984, s2=0.10265095098914433 is a stable node.
	#4 s1=0.05578534267632889, s2=0.3138449310808786 is a saddle node.
	#5 s1=0.03189144636489119, s2=0.5669870352865433 is a stable node.
_images/a1fc7bca7387aa87c944e1737fdbc117b24fdb8a566e757f579d0a5e2e259ce5.png

Symmetric stimulus: \(\mu_0=30\) Hz, \(c'=0\). When a stimulus is applied, the phase space of the model is reconfigured. The spontaneous state vanishes. At the same time, a saddle-type unstable steady state is created that separates the two asymmetrical attractors.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 0.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.658694232143127, s2=0.05180719943991283 is a stable node.
	#2 s1=0.42445578984858384, s2=0.4244556283731401 is a saddle node.
	#3 s1=0.05180717720080605, s2=0.6586942355713474 is a stable node.
_images/66a26b8f723c739101c82f4f403d39251cd97e7b4b6b984e15052d7a0fe15a70.png

Biased stimulus: \(\mu_0=30\) Hz, \(c' = 0.14\) (14 % coherence). The phase space changes when a weak motion stimulus is presented. The phase space is no longer symmetrical: the attractor state s1 (correct choice) has a larger basin of attraction than attractor s2.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 0.14},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 3 fixed points.
	#1 s1=0.6679776124172938, s2=0.04583022226100692 is a stable node.
	#2 s1=0.3845586078985544, s2=0.4536309035289816 is a saddle node.
	#3 s1=0.059110032802350894, s2=0.6481046659437735 is a stable node.
_images/f13e9d077d2c0a4bf4f39b44a499a09130b1662770fe6cf76b14936b92c21445.png

Stimulus to one population only: \(\mu_0=30\) Hz, \(c'=1.\) (100 % coherence). When \(c'\) is sufficiently large, the saddle steady state annihilates with the less favored attractor, leaving only one choice attractor.

analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'mu': 30., 'coh': 1.},
    resolutions=0.001,
)
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
analyzer.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
	There are 1212 candidates
I am trying to filter out duplicate fixed points ...
	Found 1 fixed points.
	#1 s1=0.7092805209334904, s2=0.02396366304199462 is a stable node.
_images/416a78aabaf84a332d20b5389ac11f9321badc4355dccbd415fd337ff0d0a68d.png

Bifurcation analysis#

To see how the phase portrait of the system changes when we change the stimulus current, we will generate a bifurcation diagram for the reduced model. On the bifurcation diagram, the fixed points of the model are shown as a function of a changing parameter.

Next, we generate bifurcation diagrams with different parameters.

Fix the coherence \(c'=0\), vary the stimulus strength \(\mu_0\). See Figure 10 in (Wong & Wang, 2006) [1].

analyzer = bp.analysis.Bifurcation2D(
  model=[int_s1, int_s2],
  target_vars={'s1': [0., 1.], 's2': [0., 1.]},
  target_pars={'mu': [-30., 90.]},
  pars_update={'coh': 0.},
  resolutions={'mu': 0.2},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 30000 candidates
I am trying to filter out duplicate fixed points ...
	Found 1744 fixed points.
_images/7fa70b42ca7aa5e37d8ccc98aef5d9a36b9a655d3d13668fe37c03f098364eb9.png _images/4b600dce214264855370dfa7183a4ea78081603a37f0d9c8668d98c9574b9c79.png

Fix the stimulus strength \(\mu_0 = 30\) Hz, vary the coherence \(c'\).

analyzer = bp.analysis.Bifurcation2D(
  model=[int_s1, int_s2],
  target_vars={'s1': [0., 1.], 's2': [0., 1.]},
  target_pars={'coh': [0., 1.]},
  pars_update={'mu': 30.},
  resolutions={'coh': 0.005},
)
analyzer.plot_bifurcation(num_rank=50)
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
	There are 10000 candidates
I am trying to filter out duplicate fixed points ...
	Found 474 fixed points.
_images/3f3f33b1fa148f61506d0bf628a72f2487fbd7a387e702e603b86e9187439708.png _images/a1e7719eedca06cdccec98f49d81697cf14cbe0b32dc75481f5dd9ae20016488.png

References#

[1] Wong K-F and Wang X-J (2006). A recurrent network mechanism for time integration in perceptual decisions. J. Neurosci 26, 1314-1328.

How do low-dimensional analyzers work?#

@Chaoming Wang

Dynamics analysis is essential in neurodynamics, because blind simulation of nonlinear systems is likely to produce few or even misleading results. BrainPy provides solid support for low-dimensional systems, no matter how nonlinear your defined system is. Specifically, BrainPy provides the following methods for the analysis of low-dimensional systems:

  1. phase plane analysis;

  2. codimension 1 or codimension 2 bifurcation analysis;

  3. bifurcation analysis of the fast-slow system.

BrainPy will help you probe the dynamical mechanism of your defined systems rapidly.

import brainpy as bp
import brainpy.math as bm

# bp.math.set_platform('cpu')
bp.math.enable_x64()  # It's better to enable x64 when performing analysis
import numpy as np
import matplotlib.pyplot as plt

In this section, we provide a basic tutorial to understand how the brainpy.analysis.LowDimAnalyzer works.

Terminology#

Given the FitzHugh-Nagumo model, we define an analyzer,

class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(FitzHughNagumoModel, self).__init__()

    # parameters
    self.a = 0.7
    self.b = 0.8
    self.tau = 12.5

    # variables
    self.V = bm.Variable(bm.zeros(1))
    self.w = bm.Variable(bm.zeros(1))
    self.Iext = bm.Variable(bm.zeros(1))

    # functions
    def dV(V, t, w, Iext=0.):
        return V - V * V * V / 3 - w + Iext
    def dw(w, t, V, a=0.7, b=0.8):
        return (V + a - b * w) / self.tau
    self.int_V = bp.odeint(dV, method=method)
    self.int_w = bp.odeint(dw, method=method)

  def update(self, tdi):
    self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
    self.w.value = self.int_w(self.w, tdi.t, self.V, self.a, self.b, tdi.dt)
    self.Iext[:] = 0.
model = FitzHughNagumoModel()
analyzer = bp.analysis.PhasePlane2D(
  [model.int_V, model.int_w],
  target_vars={'V': [-3, 3], 'w': [-3., 3.]},
  resolutions={'V': 0.01, 'w': 0.01},
)

In this instance of brainpy.analysis.LowDimAnalyzer, we use the following terminologies.

  • x_var and y_var are defined by the order of the user setting. If the user sets the “target_vars” as “{‘V’: …, ‘w’: …}”, x_var and y_var will be “V” and “w” respectively. Otherwise, if “target_vars”=”{‘w’: …, ‘V’: …}”, x_var and y_var will be “w” and “V” respectively.

analyzer.x_var, analyzer.y_var
('V', 'w')
  • fx and fy are defined as differential equations of x_var and y_var respectively, i.e.,

fx is

def dV(V, t, w, Iext=0.):
    return V - V * V * V / 3 - w + Iext

fy is

def dw(w, t, V, a=0.7, b=0.8):
    return (V + a - b * w) / self.tau
analyzer.F_fx, analyzer.F_fy
(JITTransform(target=f2, 
              num_of_vars=0),
 JITTransform(target=f2, 
              num_of_vars=0))
  • int_x and int_y are defined as integral functions of the differential equations for x_var and y_var respectively.

analyzer.F_int_x, analyzer.F_int_y
(functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x00000268A2599E50>),
 functools.partial(<function std_derivative.<locals>.inner.<locals>.call at 0x00000268A2599EE0>))
  • x_by_y_in_fx and y_by_x_in_fx: They denote that x_var and y_var can be separated from each other in the “fx” nullcline function. Specifically, x_by_y_in_fx or y_by_x_in_fx denotes \(x = F(y)\) or \(y = F(x)\) according to the \(f_x=0\) equation. For example, in the above FitzHugh-Nagumo model, \(w\) can be easily represented by \(V\) when \(\mathrm{dV(V, t, w, I_{ext})} = 0\), i.e., y_by_x_in_fx is \(w= V - V^3 / 3 + I_{ext}\).

  • Similarly, x_by_y_in_fy (\(x=F(y)\)) and y_by_x_in_fy (\(y=F(x)\)) denote x_var and y_var can be separated from each other in “fy” nullcline function. For example, in the above FitzHugh-Nagumo model, y_by_x_in_fy is \(w= \frac{V + a}{b}\), and x_by_y_in_fy is \(V= b * w - a\).

  • x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy and y_by_x_in_fy can be set in the options argument.

Mechanism for 1D system analysis#

In order to understand the advantages and disadvantages of BrainPy’s analysis toolkit, it is helpful to know the basic mechanism of how brainpy.analysis works.

The automatic model analysis in BrainPy heavily relies on numerical optimization methods, including Brent’s method and BFGS method. For example, for the above one-dimensional system (\(\frac{dx}{dt} = \mathrm{sin}(x) + I\)), after the user sets the resolution to 0.001, we will get the evaluation points according to the variable boundary [-10, 10].

bp.math.arange(-10, 10, 0.001)
Array([-10.   ,  -9.999,  -9.998, ...,   9.997,   9.998,   9.999],      dtype=float64)

Then, BrainPy selects the candidate intervals in which the roots lie. Specifically, it tries to find all intervals \([x_1, x_2]\) where \(f(x_1) * f(x_2) \le 0\) for the 1D system \(\frac{dx}{dt} = f(x)\).

For example, the following two points which have opposite signs are candidate points we want.

def plot_interval(x0, x1, f):
    xs = np.linspace(x0, x1, 100)
    plt.plot(xs, f(xs))
    plt.scatter([x0, x1], f(np.asarray([x0, x1])), edgecolors='r')
    plt.axhline(0)
    plt.show()
plot_interval(-0.001, 0.001, lambda x: np.sin(x))
_images/8b28f1dfe559b1adc3bd8f40b72bc6b8bf573177c343d82532c2cfaa92219dc1.png

According to the intermediate value theorem, there must be a solution between \(x_1\) and \(x_2\) when \(f(x_1) * f(x_2) \le 0\).

Based on these candidate intervals, BrainPy uses Brent’s method to find the roots of \(f(x) = 0\). Further, after obtaining the value of each root, BrainPy uses automatic differentiation to evaluate its stability.
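
The following is a minimal sketch of this 1D procedure using NumPy and SciPy rather than BrainPy’s internals. The system is the \(\frac{dx}{dt} = \mathrm{sin}(x) + I\) example discussed above, with an arbitrary example value \(I = 0.5\), and the analytical derivative \(\cos(x)\) standing in for automatic differentiation:

import numpy as np
from scipy.optimize import brentq

I = 0.5                                 # arbitrary example value
f = lambda x: np.sin(x) + I             # the 1D system dx/dt = sin(x) + I

# 1. evaluate f on the grid given by the resolution and the variable boundary
xs = np.arange(-10, 10, 0.001)
fs = f(xs)

# 2. keep the candidate intervals where the sign changes: f(x1) * f(x2) <= 0
idx = np.where(fs[:-1] * fs[1:] <= 0)[0]

# 3. refine each candidate interval with Brent's method, then judge stability
#    from the sign of df/dx (analytical here, automatic differentiation in BrainPy)
for i in idx:
    root = brentq(f, xs[i], xs[i + 1])
    stable = np.cos(root) < 0           # df/dx < 0  =>  stable fixed point
    print(f'x* = {root:.4f}, {"stable" if stable else "unstable"}')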

Overall, BrainPy’s analysis toolkit shows significant advantages and disadvantages.

Pros: Because BrainPy uses numerical methods to find roots and evaluate their stabilities, it does not care about how complex your function is. Therefore, it can be applied to general problems, including any 1D and 2D dynamical systems, and partly to low-dimensional (\(\ge 3\)) dynamical systems (see later sections). In particular, BrainPy’s analysis toolkit is highly useful when the mathematical equations are too complex to obtain analytical solutions (for an example, please refer to the tutorial Analysis of a Decision Making Model).

Cons: However, the numerical methods used in BrainPy can hardly find fixed points that exist only momentarily. Moreover, when the resolution is small, the amount of computation becomes large. Users should pay attention to choosing suitable resolution settings.

Mechanism for 2D system analysis#

plot_vector_field()

Plotting the vector field is simple: we just need to evaluate each differential equation on a grid of points, as in the sketch below.
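
As a rough sketch (using Matplotlib directly instead of the analyzer), the vector field of the FitzHugh-Nagumo model defined above can be drawn by evaluating dV and dw on a grid:

import numpy as np
import matplotlib.pyplot as plt

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.

# evaluate both derivatives on a coarse grid of (V, w) values
V, w = np.meshgrid(np.linspace(-3, 3, 30), np.linspace(-3, 3, 30))
dV = V - V ** 3 / 3 - w + Iext
dw = (V + a - b * w) / tau

plt.quiver(V, w, dV, dw)
plt.xlabel('V')
plt.ylabel('w')
plt.show()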

plot_nullcline()

Nullclines are evaluated through Brent’s method. In order to get all \((x, y)\) values that satisfy fx=0 (i.e., \(f_x(x, y) = 0\)), we first fix \(y=y_0\) and then apply Brent optimization to find all \(x'\) satisfying \(f_x(x', y_0) = 0\) (alternatively, we can fix \(x\) and then optimize \(y\)). Therefore, Brent optimization is performed many times, because we iterate over all \(y\) values according to the resolution setting, as in the sketch below.
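
A simplified sketch of this procedure for the fx-nullcline of the FitzHugh-Nagumo model, again with SciPy rather than the analyzer’s internals:

import numpy as np
from scipy.optimize import brentq

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.
fx = lambda V, w: V - V ** 3 / 3 - w + Iext    # dV/dt

nullcline = []
for w0 in np.arange(-3, 3, 0.01):              # iterate over all w values
    # scan V for sign changes of fx(V, w0), then refine each with Brent's method
    Vs = np.arange(-3, 3, 0.01)
    fs = fx(Vs, w0)
    for i in np.where(fs[:-1] * fs[1:] <= 0)[0]:
        V_root = brentq(fx, Vs[i], Vs[i + 1], args=(w0,))
        nullcline.append((V_root, w0))
# "nullcline" now holds the (V, w) points lying on the fx-nullcline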

plot_fixed_points()

The fixed point finding in BrainPy relies on the BFGS method. First, we define an auxiliary function \(L(x, y)\):

\[ L(x, y) = f_x^2(x, y) + f_y^2(x, y). \]

\(L(x, y)\) is always non-negative. We use BFGS optimization to find all local minima, and then keep the minima whose losses are smaller than \(10^{-8}\) as fixed points, as sketched below.
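
A bare-bones illustration of this idea with scipy.optimize.minimize on the FitzHugh-Nagumo model defined above, started from a few arbitrarily chosen initial points (the next paragraph describes how BrainPy actually chooses them):

import numpy as np
from scipy.optimize import minimize

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.

def loss(p):
    V, w = p
    fx = V - V ** 3 / 3 - w + Iext       # dV/dt
    fy = (V + a - b * w) / tau           # dw/dt
    return fx ** 2 + fy ** 2             # the auxiliary function L(V, w)

# run BFGS from several initial points and keep only the low-loss minima
fixed_points = []
for x0 in [(-2., -1.), (0., 0.), (2., 1.)]:
    res = minimize(loss, x0, method='BFGS')
    if res.fun < 1e-8:
        fixed_points.append(res.x)
print(fixed_points)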

For this method, how to choose the initial points for optimization is the challenge, especially when the parameter resolutions are small. Generally, there are four strategies provided in BrainPy.

  • fx-nullcline: Choose the points in “fx” nullcline as the initial points for optimization.

  • fy-nullcline: Choose the points in “fy” nullcline as the initial points for optimization.

  • nullclines: Choose both the points in “fx” nullcline and “fy” nullcline as the initial points for optimization.

  • aux_rank: For a given set of parameters, we evaluate loss function at each point according to the resolution setting. Then we choose the first num_rank (default is 100) points which have the smallest losses.

However, if users provide one of the functions x_by_y_in_fx, y_by_x_in_fx, x_by_y_in_fy or y_by_x_in_fy, things become much simpler, because we can turn the 2D system into a 1D system and then find the fixed points with Brent optimization alone, as sketched below.
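
For instance, once y_by_x_in_fy gives \(w = (V + a)/b\), substituting it into fx turns fixed-point finding into a 1D root-finding problem in \(V\). A minimal SciPy sketch of this reduction (not BrainPy’s internal implementation):

import numpy as np
from scipy.optimize import brentq

a, b, tau, Iext = 0.7, 0.8, 12.5, 0.

# substitute w = (V + a) / b (the fy-nullcline) into fx, giving a 1D function of V
h = lambda V: V - V ** 3 / 3 - (V + a) / b + Iext

# the usual 1D procedure: scan for sign changes, then refine with Brent's method
Vs = np.arange(-3, 3, 0.01)
hs = h(Vs)
for i in np.where(hs[:-1] * hs[1:] <= 0)[0]:
    V_fp = brentq(h, Vs[i], Vs[i + 1])
    w_fp = (V_fp + a) / b
    print(f'fixed point: V={V_fp:.4f}, w={w_fp:.4f}')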

For the given FitzHugh-Nagumo model, we can set

analyzer = bp.analysis.Bifurcation2D(
    model,
    target_vars=dict(V=[-3, 3], w=[-3., 3.]),
    target_pars=dict(a=[0.5, 1.], Iext=[0., 1.]),
    resolutions={'a': 0.01, 'Iext': 0.01},
    options={bp.analysis.C.y_by_x_in_fy: (lambda V, a=0.7, b=0.8: (V + a) / b)}
)
analyzer.plot_bifurcation()
analyzer.show_figure()
I am making bifurcation analysis ...
I am trying to find fixed points by brentq optimization ...
I am trying to filter out duplicate fixed points ...
	Found 5000 fixed points.
_images/0763fe241e6b578d643d27f6afceea43a46eb52a0ca75c9ca42fa2700be759c4.png _images/ed3e796840cce21eecb78676157c37cccf0caa8f934206d629029f0ed86414c4.png


Interoperation with other JAX frameworks#

BrainPy is designed to be easily interoperated with other JAX frameworks.

import jax
import brainpy as bp
# math library of BrainPy, JAX, NumPy
import brainpy.math as bm
import jax.numpy as jnp
import numpy as np

1. Data are exchangeable among different frameworks.#

This can be realized because an Array can be directly converted to a JAX ndarray or a NumPy ndarray.

Convert an Array into a JAX ndarray.

b = bm.random.randint(10, size=5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# Array.value is a JAX DeviceArray
b.value
DeviceArray([9, 9, 0, 4, 7], dtype=int32)

Convert an Array into a NumPy ndarray.

# Array can be easily converted to a numpy ndarray
np.asarray(b)
array([9, 9, 0, 4, 7])

Convert a NumPy ndarray into an Array.

bm.asarray(np.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)

Convert a JAX ndarray into an Array.

bm.asarray(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
bm.Array(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)

2. Transformations in brainpy.math also work on functions.#

APIs in other JAX frameworks can be naturally integrated in BrainPy. Let’s take the gradient-based optimization library Optax as an example to illustrate how to use other JAX frameworks in BrainPy.

import optax
# First create several useful functions.

network = jax.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))
optimizer = optax.adam(learning_rate=1e-1)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

@bm.jit
def train(params, opt_state, xs, ys):
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state
# Generate some data

bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer

params = bm.array([0.0, 0.0])
opt_state = optimizer.init(params)
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'

3. Other JAX frameworks can be integrated into a BrainPy program.#

In this example, we use Flax, a library for building deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN model into our RNN model, which is defined with BrainPy’s syntax.

Here, we first use Flax to define a CNN.

from flax import linen as nn

class CNN(nn.Module):
  """A CNN model implemented by using Flax."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    return x

Then, we define an RNN model by using our BrainPy interface.

from jax.tree_util import tree_flatten, tree_map, tree_unflatten

class Network(bp.dyn.DynamicalSystem):
  """A network model implemented by BrainPy"""

  def __init__(self):
    super(Network, self).__init__()

    # cnn and its parameters
    self.cnn = CNN()
    rng = bm.random.DEFAULT.split_key()
    params = self.cnn.init(rng, jnp.ones([1, 4, 28, 1]))['params']
    leaves, self.tree = tree_flatten(params)
    self.implicit_vars.update(tree_map(bm.TrainVar, leaves))

    # rnn
    self.rnn = bp.layers.GRU(256, 100)

    # readout
    self.linear = bp.layers.Dense(100, 10)

  def update(self, sha, x):
    params = tree_unflatten(self.tree, [v.value for v in self.implicit_vars.values()])
    x = self.cnn.apply({'params': params}, bm.as_jax(x))
    x = self.rnn(sha, x)
    x = self.linear(sha, x)
    return x

We initialize the network, optimizer, loss function, and BP trainer.

net = Network()
opt = bp.optim.Momentum(0.1)

def loss_func(predictions, targets):
  logits = bm.max(predictions, axis=1)
  loss = bp.losses.cross_entropy_loss(logits, targets)
  accuracy = bm.mean(bm.argmax(logits, -1) == targets)
  return loss, {'accuracy': accuracy}

trainer = bp.train.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)

We get the MNIST dataset.

train_dataset = bp.datasets.MNIST(r'D:\data\mnist', train=True, download=True)
X = train_dataset.data.reshape((-1, 7, 4, 28, 1)) / 255
Y = train_dataset.targets

Finally, we train our model by using the BPTT.fit() function.

trainer.fit([X, Y], batch_size=256, num_epoch=10)
Train 100 steps, use 32.5824 s, train loss 0.96465, accuracy 0.66015625
Train 200 steps, use 30.9035 s, train loss 0.38974, accuracy 0.89453125
Train 300 steps, use 33.1075 s, train loss 0.31525, accuracy 0.890625
Train 400 steps, use 31.4062 s, train loss 0.23846, accuracy 0.91015625
Train 500 steps, use 32.3371 s, train loss 0.21995, accuracy 0.9296875
Train 600 steps, use 32.5692 s, train loss 0.20885, accuracy 0.92578125
Train 700 steps, use 33.0139 s, train loss 0.24748, accuracy 0.90625
Train 800 steps, use 31.9635 s, train loss 0.14563, accuracy 0.953125
Train 900 steps, use 31.8845 s, train loss 0.17017, accuracy 0.94140625
Train 1000 steps, use 32.0537 s, train loss 0.09413, accuracy 0.95703125
Train 1100 steps, use 32.3714 s, train loss 0.06015, accuracy 0.984375
Train 1200 steps, use 31.6957 s, train loss 0.12061, accuracy 0.94921875
Train 1300 steps, use 31.8346 s, train loss 0.13908, accuracy 0.953125
Train 1400 steps, use 31.5252 s, train loss 0.10718, accuracy 0.953125
Train 1500 steps, use 31.7274 s, train loss 0.07869, accuracy 0.96875
Train 1600 steps, use 32.3928 s, train loss 0.08295, accuracy 0.96875
Train 1700 steps, use 31.7718 s, train loss 0.07569, accuracy 0.96484375
Train 1800 steps, use 31.9243 s, train loss 0.08607, accuracy 0.9609375
Train 1900 steps, use 32.2454 s, train loss 0.04332, accuracy 0.984375
Train 2000 steps, use 31.6231 s, train loss 0.02369, accuracy 0.9921875
Train 2100 steps, use 31.7800 s, train loss 0.03862, accuracy 0.9765625
Train 2200 steps, use 31.5431 s, train loss 0.01871, accuracy 0.9921875
Train 2300 steps, use 32.1064 s, train loss 0.03255, accuracy 0.9921875

Numerical Solvers for Ordinary Differential Equations#

@Chaoming Wang @Xiaoyu Chen

The brain modeling toolkit provided in BrainPy focuses on differential equations. How to solve differential equations is the essence of neurodynamics simulation. Exact algebraic solutions are only available for low-order differential equations; for coupled, high-dimensional, nonlinear brain dynamical systems, we need to resort to numerical methods.

This section will illustrate how to define ordinary differential equations (ODEs) and how to define the numerical integration methods for ODEs in BrainPy.

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

bm.set_platform('cpu')

bp.__version__
'2.3.0'

How to define ODE functions?#

BrainPy provides a convenient and intuitive way to define ODE systems. For the ODEs

\[\begin{split} {dx \over dt} = f_1(x, t, y, p_1)\\ {dy \over dt} = g_1(y, t, x, p_2) \end{split}\]

we can define them in a Python function:

def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

where t denotes the current time, x and y passed before t denote the dynamical variables, and p1 and p2 after t denote the parameters needed in this system. In the function body, the derivatives f1 and g1 can be customized according to the user’s needs. Finally, the corresponding derivatives dx and dy are returned in the same order as the variables in the function arguments.

Each variable can be a scalar (var_type = bp.integrators.SCALAR_VAR), a vector/matrix (var_type = bp.integrators.POP_VAR), or a system (var_type = bp.integrators.SYSTEM_VAR). Here, “system” means that the argument x denotes an array of variables. Taking the above example as a demonstration again, we can redefine it as:

def diff(xy, t, p1, p2):
    x, y = xy
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return bm.array([dx, dy])

How to define the numerical integration for ODEs?#

After the definition of ODE functions, it is very easy to define the numerical integration for these functions. We just need to put a decorator bp.odeint above the ODE function.

@bp.odeint
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

After wrapping it by bp.odeint, the function becomes an instance of ODEintegrator.

isinstance(diff, bp.ode.ODEIntegrator)
True

bp.odeint receives several arguments:

  • “method”: A string, used to specify the numerical methods to integrate the ODE functions. The default method is Euler.

diff
<brainpy.integrators.ode.explicit_rk.Euler at 0x15abadeedc0>
  • “dt”: A float, used to set the default numerical precision. The default “dt” is 0.1.

diff.dt
0.1
  • “show_code”: bool, to indicate whether to show the numerical integration code. Let’s take Euler method and RK4 method as the illustrated examples.

@bp.odeint(method='euler', show_code=True, dt=0.01)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode1_diff(x, y, t, p1, p2, dt=0.01):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  x_new = x + dx_k1 * dt * 1
  y_new = y + dy_k1 * dt * 1
  return x_new, y_new

{'f': <function diff at 0x0000015ABADFBF70>}
<brainpy.integrators.ode.explicit_rk.Euler at 0x15abae03550>
@bp.odeint(method='rk4', show_code=True, dt=0.1)
def diff(x, y, t, p1, p2):
    dx = f1(x, t, y, p1)
    dy = g1(y, t, x, p2)
    return dx, dy

diff
def brainpy_itg_of_ode2_diff(x, y, t, p1, p2, dt=0.1):
  dx_k1, dy_k1 = f(x, y, t, p1, p2)
  k2_x_arg = x + dt * dx_k1 * 0.5
  k2_y_arg = y + dt * dy_k1 * 0.5
  k2_t_arg = t + dt * 0.5
  dx_k2, dy_k2 = f(k2_x_arg, k2_y_arg, k2_t_arg, p1, p2)
  k3_x_arg = x + dt * dx_k2 * 0.5
  k3_y_arg = y + dt * dy_k2 * 0.5
  k3_t_arg = t + dt * 0.5
  dx_k3, dy_k3 = f(k3_x_arg, k3_y_arg, k3_t_arg, p1, p2)
  k4_x_arg = x + dt * dx_k3
  k4_y_arg = y + dt * dy_k3
  k4_t_arg = t + dt
  dx_k4, dy_k4 = f(k4_x_arg, k4_y_arg, k4_t_arg, p1, p2)
  x_new = x + dx_k1 * dt * 1/6 + dx_k2 * dt * 1/3 + dx_k3 * dt * 1/3 + dx_k4 * dt * 1/6
  y_new = y + dy_k1 * dt * 1/6 + dy_k2 * dt * 1/3 + dy_k3 * dt * 1/3 + dy_k4 * dt * 1/6
  return x_new, y_new

{'f': <function diff at 0x0000015ABAE023A0>}
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abae03d30>

Two Examples#

Example 1: FitzHugh–Nagumo model#

Now, let’s take the well-known FitzHugh–Nagumo model as an example to illustrate how to define ODE solvers for brain modeling. The FitzHugh–Nagumo (FHN) model has two dynamical variables, which are governed by the following equations:

\[\begin{split} \tau {\dot {w}}&=v+a-bw\\ {\dot {v}} &=v-{\frac {v^{3}}{3}}-w+I_{\rm {ext}} \end{split}\]

For this FHN model, we can code it in BrainPy like this:

@bp.odeint(dt=0.01)
def integral(V, w, t, Iext, a, b, tau):
    dw = (V + a - b * w) / tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

After defining the numerical solver, the solution of the ODE system in the given times can be easily solved. For example, for the given parameters,

a = 0.7;   b = 0.8;   tau = 12.5;   Iext = 1.

the solution of the FHN model between 0 and 100 ms can be approximated by

hist_times = bm.arange(0, 100, 0.01)
hist_V = []
V, w = 0., 0.
for t in hist_times:
    V, w = integral(V, w, t, Iext, a, b, tau)
    hist_V.append(V)

plt.plot(hist_times, hist_V)
plt.show()
_images/52d96ce30b228c8a4b88d88e5d52a49e8b34aeb17bb3f92959c9ea0fcbb7f584.png

This manual loop in Python code is usually slow. In BrainPy, we provide a structural runner for integrators: brainpy.integrators.IntegratorRunner, which can benefit from the JIT compilation.

runner = bp.IntegratorRunner(
    integral,
    monitors=['V'],
    inits=dict(V=0., w=0.),
    args=dict(a=a, b=b, tau=tau, Iext=Iext),
    dt=0.01
)
runner.run(100.)

plt.plot(runner.mon.ts, runner.mon.V)
plt.show()
_images/ed252b7e5a3c5db60df5dc5a704e9fcaff6c46f30b791c0c2f38ed6ad57758b0.png

Example 2: Hodgkin–Huxley model#

Another, more complex example is the classical Hodgkin–Huxley neuron model. In the HH model, four dynamical variables (V, m, n, h) are used to model the initiation and propagation of action potentials. Specifically, they are governed by the following equations:

\[\begin{split} \begin{aligned} C_{m} \frac{d V}{d t} &=-\bar{g}_{\mathrm{K}} n^{4}\left(V-V_{K}\right)- \bar{g}_{\mathrm{Na}} m^{3} h\left(V-V_{N a}\right)-\bar{g}_{l}\left(V-V_{l}\right)+I_{s y n} \\ \frac{d m}{d t} &=\alpha_{m}(V)(1-m)-\beta_{m}(V) m \\ \frac{d h}{d t} &=\alpha_{h}(V)(1-h)-\beta_{h}(V) h \\ \frac{d n}{d t} &=\alpha_{n}(V)(1-n)-\beta_{n}(V) n \end{aligned} \end{split}\]

In BrainPy, such a dynamical system can be coded as:

@bp.odeint(method='rk4', dt=0.01)
def integral(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C

    return dVdt, dmdt, dhdt, dndt

Same as the FHN model, we can also integrate the HH model in the given parameters and time interval:

Iext = 10.;   ENa = 50.;   EK = -77.;   EL = -54.387
C = 1.0;      gNa = 120.;  gK = 36.;    gL = 0.03
runner = bp.IntegratorRunner(
    integral,
    monitors=list('Vmhn'),
    inits=[0., 0., 0., 0.],
    args=dict(Iext=Iext, gNa=gNa, ENa=ENa, gK=gK, EK=EK, gL=gL, EL=EL, C=C),
    dt=0.01
)
runner.run(100.)

plt.subplot(211)
plt.plot(runner.mon.ts, runner.mon.V, label='V')
plt.legend()
plt.subplot(212)
plt.plot(runner.mon.ts, runner.mon.m, label='m')
plt.plot(runner.mon.ts, runner.mon.h, label='h')
plt.plot(runner.mon.ts, runner.mon.n, label='n')
plt.legend()
plt.show()
_images/0638e06f31efcc92c216855969bd3e877c4fb488496fffa9470e2a59a58b4544.png

Provided ODE Numerical Solvers#

BrainPy provides several types of numerical methods for ODEs, including explicit Runge-Kutta methods, adaptive Runge-Kutta methods, and Exponential Euler methods.

1. Explicit Runge-Kutta (RK) methods for ODEs#

The first category of ODE numerical integration support is the explicit Runge-Kutta (RK) methods. RK methods are a huge family of numerical methods with a wide variety of trade-offs: efficiency, accuracy, stability, etc. The supported RK methods are listed in the following table:

Methods | Keywords
--- | ---
Euler | euler
Midpoint | midpoint
Heun’s second-order method | heun2
Ralston’s second-order method | ralston2
RK2 | rk2
RK3 | rk3
RK4 | rk4
Heun’s third-order method | heun3
Ralston’s third-order method | ralston3
Third-order Strong Stability Preserving Runge-Kutta | ssprk3
Ralston’s fourth-order method | ralston4
Runge-Kutta 3/8-rule fourth-order method | rk4_38rule

Users can utilize these methods by specifying the method option in brainpy.odeint() with their corresponding keyword. For example:

@bp.odeint(method='rk4')
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abc88be50>

Or, you can directly instantiate your favorite integrator:

@bp.ode.RK4
def int_v(v, t, p):
    # do something
    return v

int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abca4daf0>
def derivative(v, t, p):
    # do something
    return v

int_v = bp.ode.RK4(derivative, dt=0.01)
int_v
<brainpy.integrators.ode.explicit_rk.RK4 at 0x15abc88bca0>

2. Adaptive Runge-Kutta (RK) methods for ODEs#

The second category of ODE numerical support is the adaptive RK methods. What differs from the explicit RK methods is that adaptive methods are designed to produce an estimate of the local truncation error in a single Runge-Kutta step; this error can then be used to adaptively control the numerical step size. Specifically, if \(error > tol\), then replace \(dt\) with \(dt_{new}\) and repeat the step. Therefore, adaptive RK methods allow a varied step size. The following adaptive RK methods are provided in BrainPy:

Methods | Keywords
--- | ---
Runge–Kutta–Fehlberg 4(5) | rkf45
Runge–Kutta–Fehlberg 1(2) | rkf12
Dormand–Prince method | rkdp
Cash–Karp method | ck
Bogacki–Shampine method | bs
Heun–Euler method | heun_euler

By default, the above methods are not adaptive, unless users provide the keyword adaptive=True in brainpy.odeint(). When users use adaptive RK methods for numerical integration, the instantaneously adjusted step size dt will be appended to the functional arguments. Moreover, the tolerance tol for step-size adjustment can also be modified. Below, a short sketch first illustrates the general accept/reject logic of adaptive stepping, and then the Lorenz system is used as an example of the BrainPy interface.
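
The following sketch is a simplified, generic illustration in plain Python, using the Heun–Euler embedded pair on the toy equation \(dy/dt = -y\); the error estimate and the step-update rule are common textbook choices, not BrainPy’s exact formulas:

import numpy as np

f = lambda t, y: -y                    # toy ODE: dy/dt = -y, exact solution exp(-t)
tol = 1e-4
t, y, dt = 0.0, 1.0, 0.1

while t < 5.0:
    k1 = f(t, y)
    k2 = f(t + dt, y + dt * k1)
    y_euler = y + dt * k1              # 1st-order (Euler) estimate
    y_heun = y + dt * (k1 + k2) / 2    # 2nd-order (Heun) estimate
    error = abs(y_heun - y_euler)      # local error estimate of the embedded pair
    if error <= tol:                   # accept the step and advance
        t, y = t + dt, y_heun
    # shrink or enlarge dt according to the error (a simplified control rule)
    dt *= 0.9 * min(2.0, max(0.2, (tol / max(error, 1e-16)) ** 0.5))

print(y, np.exp(-t))                   # numerical result vs. exact solution at time t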

# adaptively adjust step-size

@bm.jit
@bp.odeint(method='rkf45', 
           adaptive=True, # active the "adaptive" option
           tol=0.001) # set the tolerance
def lorenz(x, y, z, t, sigma, beta, rho):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz
times = bm.arange(0, 100, 0.01)
hist_x, hist_y, hist_z, hist_dt = [], [], [], []
x, y, z, dt = bm.array([1]), bm.array([1]), bm.array([1]), 0.05
for t in times:
    # should provide one more argument "dt" when using the adaptive rk method
    x, y, z, dt = lorenz(x, y, z, t, sigma=10, beta=8/3, rho=28, dt=dt)  
    hist_x.append(x.value)
    hist_y.append(y.value)
    hist_z.append(z.value)
    hist_dt.append(dt)
hist_x = bm.array(hist_x).flatten()
hist_y = bm.array(hist_y).flatten()
hist_z = bm.array(hist_z).flatten()
hist_dt = bm.array(hist_dt)
fig = plt.figure()
ax = plt.subplot(projection='3d')
plt.plot(hist_x, hist_y, hist_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

fig = plt.figure()
plt.plot(hist_dt[:100])
plt.xlabel('Step No.')
plt.ylabel('Adaptive dt')
plt.show()
_images/6958da492e220291e4e29a2f7d08e7750f7f74ebf3a749e96a15891bc2164082.png _images/405aef29e9edbf7fe787f5b7a39e48c427d72155c89e07a65df744859d2df3a4.png

3. Exponential Euler methods for ODEs#

Finally, BrainPy provides exponential integrators for ODEs. For your ODE systems, we highly recommend the Exponential Euler method. The Exponential Euler method provided in BrainPy uses automatic differentiation to find the linear part.

Methods | Keywords
--- | ---
Exponential Euler | exp_euler

Let’s take a linear system as the theoretical demonstration,

\[ {dy \over dt} = A - By \]

the exponential Euler scheme is given by:

\[ y(t+dt) = y(t) e^{-B*dt} + {A \over B}(1 - e^{-B*dt}) \]

As you can see, for such linear systems, the exponential Euler scheme is nearly the exact solution.
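
To make this concrete, the following sketch compares one exponential Euler step with the exact solution of \(\frac{dy}{dt} = A - By\), using plain NumPy rather than BrainPy’s exp_euler implementation:

import numpy as np

A, B = 2.0, 0.5
y0, dt = 0.0, 1.0      # even a large step is fine for this linear system

# exponential Euler update: y(t+dt) = y(t) * exp(-B*dt) + (A/B) * (1 - exp(-B*dt))
y_exp_euler = y0 * np.exp(-B * dt) + A / B * (1 - np.exp(-B * dt))

# exact solution starting from y0
y_exact = A / B + (y0 - A / B) * np.exp(-B * dt)

print(y_exp_euler, y_exact)   # identical, since A and B are constants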

However, using the Exponential Euler method requires us to write each derivative function separately; otherwise, the automatic differentiation will lead to wrong results.

Interestingly, the computationally expensive Hodgkin–Huxley neuron model is a linear-like ODE system. You will find that, by using the Exponential Euler method, the numerical step can be greatly enlarged to save computation time.

\[\begin{split} \begin{aligned} C_{m}{\frac {d V}{dt}}&= -\left[{\bar {g}}_{\text{K}}n^{4} + {\bar {g}}_{\text{Na}}m^{3}h + {\bar {g}}_{l} \right] V +{\bar {g}}_{\text{K}}n^{4} V_{K} + {\bar {g}}_{\text{Na}}m^{3}h V_{Na} + {\bar {g}}_{l} V_{l} + I_{syn} \\ {\frac {dm}{dt}} &= \left[-\alpha _{m}(V)-\beta _{m}(V)\right]m + \alpha _{m}(V) \\ {\frac {dh}{dt}} &= \left[-\alpha _{h}(V)-\beta _{h}(V)\right]h + \alpha _{h}(V) \\ {\frac {dn}{dt}} &= \left[-\alpha _{n}(V)-\beta _{n}(V)\right]n + \alpha _{n}(V) \\ \end{aligned} \end{split}\]
Iext=10.;   ENa=50.;   EK=-77.;   EL=-54.387
C=1.0;      gNa=120.;  gK=36.;    gL=0.03
def dm(m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m
    return dmdt
def dh(h, t, V):
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h
    return dhdt
def dn(n, t, V):
    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n
    return dndt
def dV(V, t, m, h, n, Iext):
    I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    I_K = (gK * n ** 4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (- I_Na - I_K - I_leak + Iext) / C
    return dVdt

Although we define the HH differential equations as separate functions, we can numerically integrate these equations jointly by relying on brainpy.JointEq.

hh_derivative = bp.JointEq([dV, dm, dh, dn])
def run(method, Iext=10., dt=0.1):
    integral = bp.odeint(hh_derivative, method=method)

    runner = bp.IntegratorRunner(
        integral,
        monitors=list('Vmhn'),
        inits=[0., 0., 0., 0.],
        args=dict(Iext=Iext),
        dt=dt
    )
    runner.run(100.)

    plt.subplot(211)
    plt.plot(runner.mon.ts, runner.mon.V, label='V')
    plt.legend()
    plt.subplot(212)
    plt.plot(runner.mon.ts, runner.mon.m, label='m')
    plt.plot(runner.mon.ts, runner.mon.h, label='h')
    plt.plot(runner.mon.ts, runner.mon.n, label='n')
    plt.legend()
    plt.show()

Euler method: unable to complete the integration when the time step is a bit larger.

run('euler', Iext=10, dt=0.02)
_images/53bb95044f9fa062159814c93ee0ea2193c8825fa6926c18b2c4cf4ea8bdc2ed.png
run('euler', Iext=10, dt=0.1)
_images/b7c5da9795079b6bc5ea90b2079da3fcd67f9b7480b5f212b67dbc8116570328.png

RK4 method: better than the Euler method, but it still requires the time step to be small.

run('rk4', Iext=10, dt=0.1)
_images/2bc0a0c5b93272ee97554f48e2ffc76369a686da4282d906d8a1e9cc28c86591.png
run('rk4', Iext=10, dt=0.2)
_images/b3576aa64e02125f901cb2bfd0e698be89103636d04f71287534aa1e060cb46b.png

Exponential Euler method: allows a larger time step and generates accurate results.

run('exp_euler', Iext=10, dt=0.2)
_images/a564256740276af5e7532922b3d5f357e08e2c79a12c5590ab6791280bcf1cf2.png

Numerical Solvers for Stochastic Differential Equations#

@Chaoming Wang

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.

import brainpy as bp

bp.__version__
'2.3.0'
import matplotlib.pyplot as plt

How to define SDE functions?#

For a one-dimensional stochastic differential equation (SDE) with scalar Wiener noise, it is given by

\[ \begin{aligned} d X_{t}&=f\left(X_{t}, t, p_1\right) d t+g\left(X_{t}, t, p_2\right) d W_{t} \quad (1) \end{aligned} \]

where \(X_t = X(t)\) is the realization of a stochastic process or random variable, \(f(X_t, t)\) is the drift coefficient, \(g(X_t, t)\) denotes the diffusion coefficient, and the stochastic process \(W_t\) is called a Wiener process.

For this SDE system, we can define two Python functions \(f\) and \(g\) to represent it.

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df

As with the ODE functions, the arguments before \(t\) denote the random variables, while the arguments after \(t\) represent the parameters. For an SDE function with scalar noise, the returned \(df\) and \(dg\) should have the same size, for example, \(df \in R^d, dg \in R^d\).

However, a more general SDE system usually has a multi-dimensional driving Wiener process:

\[ dX_t=f(X_t)dt+\sum_{\alpha=1}^{m}g_{\alpha }(X_t)dW_t ^{\alpha} \]

For such an \(m\)-dimensional noise system, the coding schema is the same as in the scalar case, except that the data of \(dg\) has one more dimension, for example, \(df \in R^{d}, dg \in R^{m \times d}\).
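
To illustrate these shape conventions, here is a hypothetical system with \(d = 2\) state variables and \(m = 3\) independent Wiener processes; the functions below are placeholders whose only purpose is to show the expected return shapes:

import numpy as np

d, m = 2, 3  # two state variables, three independent Wiener processes

def f_part(x, t, p1, p2):
    # drift: one entry per state variable, shape (d,)
    return -p1 * x

def g_part(x, t, p1, p2):
    # diffusion for vector Wiener noise: shape (m, d)
    return p2 * np.ones((m, d))

x = np.zeros(d)
print(f_part(x, 0., 1., 0.1).shape)   # (2,)
print(g_part(x, 0., 1., 0.1).shape)   # (3, 2)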

How to define the numerical integration for SDEs?#

Before the numerical integration of SDE functions, we should distinguish two kinds of SDE integrals. For the integration of system (1), we can get

\[ \begin{aligned} X_{t}&=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} \quad (2) \end{aligned} \]

In the 1940s, the Japanese mathematician K. Ito defined a type of integral called the Ito stochastic integral. In the 1960s, the Russian physicist R. L. Stratonovich proposed another kind of stochastic integral, called the Stratonovich stochastic integral, and used the symbol \(\circ\) to distinguish it from the former Ito integral.

\[\begin{split} \begin{aligned} d X_{t} &=f\left(X_{t}, t\right) d t+g\left(X_{t}, t\right) \circ d W_{t} \\ X_{t} &=X_{t_{0}}+\int_{t_{0}}^{t} f\left(X_{s}, s\right) d s+\int_{t_{0}}^{t} g\left(X_{s}, s\right) \circ d W_{s} \quad (3) \end{aligned} \end{split}\]

The difference of Ito integral (2) and Stratonovich integral (3) lies at the second integral term, which can be written in a general form as

\[\begin{split} \begin{split} \int_{t_{0}}^{t} g\left(X_{s}, s\right) d W_{s} &=\lim _{h \rightarrow 0} \sum_{k=0}^{m-1} g\left(X_{\tau_{k}}, \tau_{k}\right)\left(W\left(t_{k+1}\right)-W\left(t_{k}\right)\right) \\ \mathrm{where} \quad & h = t_{k+1} - t_{k} \\ & \tau_k = (1-\lambda)t_k +\lambda t_{k+1} \end{split} \end{split}\]
  • In the stochastic integral of the Ito SDE, \(\lambda=0\), thus \(\tau_k=t_k\);

  • In the definition of the Stratonovich integral, \(\lambda=0.5\), thus \(\tau_k=(t_{k+1} + t_{k}) / 2\).

In BrainPy, these two different integrals can be easily implemented. What users need to do is to provide the keyword intg_type in bp.sdeint. intg_type can be “bp.integrators.STRA_SDE” or “bp.integrators.ITO_SDE” (default). Also, the different types of Wiener process can be easily distinguished by the wiener_type keyword. It can be “bp.integrators.SCALAR_WIENER” (default) or “bp.integrators.VECTOR_WIENER”.

Now, let’s numerically integrate the SDE (1) by the Ito way with the Milstein method:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

@bp.sdeint(g=g_part, method='milstein')
def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Or, it can be expressed as:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(d,)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, g=g_part, method='milstein')

However, if you want to numerically integrate an SDE with a multi-dimensional Wiener process in the Stratonovich sense, you can code it like this:

def g_part(x, t, p1, p2):
    dg = g(x, t, p2)
    return dg  # shape=(m, d)

def f_part(x, t, p1, p2):
    df = f(x, t, p1)
    return df  # shape=(d,)

integral = bp.sdeint(f=f_part, 
                     g=g_part, 
                     method='milstein', 
                     intg_type=bp.integrators.STRA_SDE, 
                     wiener_type=bp.integrators.VECTOR_WIENER)

Example: Noisy Lorenz system#

Here, let’s demonstrate how to define a numerical solver for SDEs with the famous Lorenz system:

\[\begin{split} \begin{aligned} \frac{dx}{dt} &= \sigma(y-x) + p x \xi_x \\ \frac{dy}{dt} &= x(\rho-z) - y + p y \xi_y \\ \frac{dz}{dt} &= x y - \beta z + p z \xi_z \end{aligned} \end{split}\]
sigma = 10; beta = 8/3; 
rho = 28;   p = 0.1

def lorenz_g(x, y, z, t):
    return p * x, p * y, p * z

def lorenz_f(x, y, z, t):
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return dx, dy, dz

lorenz = bp.sdeint(f=lorenz_f, 
                   g=lorenz_g, 
                   intg_type=bp.integrators.ITO_SDE,
                   wiener_type=bp.integrators.SCALAR_WIENER)

To run this integrator, we use brainpy.integrators.IntegratorRunner, which can JIT compile the model to gain impressive speed.

runner = bp.IntegratorRunner(
    lorenz,
    monitors=['x', 'y', 'z'],
    inits=[1., 1., 1.],
    dt=0.001
)
runner.run(50.)

fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
_images/d6939f634d69e544c824422b3be44a47e476257ae627f0d5e7b22788b971245d.png

We can also rewrite the above differential equations as a JointEq of separate equations, so that the Exponential Euler method can be applied.

dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
lorenz_f = bp.JointEq(dx, dy, dz)
lorenz = bp.sdeint(f=lorenz_f,
                   g=lorenz_g,
                   intg_type=bp.integrators.ITO_SDE,
                   wiener_type=bp.integrators.SCALAR_WIENER,
                   method='exp_euler')

runner = bp.IntegratorRunner(
    lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.], dt=0.001
)
runner.run(50.)

plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
_images/baa8d9eb50feea119cb9fd7f7abcb88b12c3971d5f3e18d4e97d6fe1a13a6ed1.png

Supported SDE Numerical Methods#

BrainPy provides several numerical methods for stochastic differential equations (SDEs). Specifically, we provide explicit Runge-Kutta methods, derivative-free Milstein methods, and exponential Euler method for SDE numerical integration.

Methods | Keywords | Ito SDE support | Stratonovich SDE support | Scalar Wiener support | Vector Wiener support
--- | --- | --- | --- | --- | ---
Strong SRK scheme: SRI1W1 | srk1w1_scalar | Yes | | Yes |
Strong SRK scheme: SRI2W1 | srk2w1_scalar | Yes | | Yes |
Strong SRK scheme: KlPl | KlPl_scalar | Yes | | Yes |
Euler method | euler | Yes | Yes | Yes | Yes
Heun method | heun | | Yes | Yes | Yes
Milstein | milstein | Yes | Yes | Yes | Yes
Derivative-free Milstein | milstein_grad_free | Yes | Yes | Yes | Yes
Exponential Euler | exp_euler | Yes | | Yes | Yes

Numerical Solvers for Fractional Differential Equations#

@Chaoming Wang

import matplotlib.pyplot as plt

import brainpy as bp
import brainpy.math as bm

bp.__version__
'2.3.0'

Fractional differential equations have several definitions. They can be defined in a variety of ways that often do not all lead to the same result, even for smooth functions. In neuroscience, we usually use the following two definitions:

  • Grünwald-Letnikov derivative

  • Caputo fractional derivative

See Fractional calculus - Wikipedia for more details.
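
For reference, for \(0 < \alpha < 1\) these two derivatives of a function \(f(t)\) are commonly written as follows (standard textbook forms, not BrainPy-specific notation):

\[ {}^{GL}D^{\alpha}f(t) = \lim_{h \to 0} \frac{1}{h^{\alpha}} \sum_{k=0}^{\lfloor t/h \rfloor} (-1)^{k} \binom{\alpha}{k} f(t - kh), \qquad {}^{C}D^{\alpha}f(t) = \frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{f'(\tau)}{(t-\tau)^{\alpha}} d\tau \]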

Methods for Caputo FDEs#

For a given fractional differential equation

\[ \frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t) \]

where the fractional order \(0<\alpha\le 1\). BrainPy provides two kinds of methods:

  • Euler method - brainpy.fde.CaputoEuler

  • L1 schema integration - brainpy.fde.CaputoL1Schema

brainpy.fde.CaputoEuler#

brainpy.fde.CaputoEuler provides a one-step Euler method for integrating Caputo fractional differential equations.

Given a fractional-order Qi chaotic system

\[\begin{split} \left\{\begin{array}{l} D^{\alpha} x_{1}=a\left(x_{1}-x_{2}\right)+x_{2} x_{3} \\ D^{\alpha} x_{2}=c x_{1}-x_{2}-x_{1} x_{3} \\ D^{\alpha} x_{3}=x_{1} x_{2}-b x_{3} \end{array}\right. \end{split}\]

we can solve the equation system by:

a, b, c = 35, 8 / 3, 80


def qi_system(x, y, z, t):
  dx = -a * x + a * y + y * z
  dy = c * x - y - x * z
  dz = -b * z + x * y
  return dx, dy, dz
dt = 0.001
duration = 50
inits = [0.1, 0.2, 0.3]

# The numerical integration of an FDE needs to know all
# history information; therefore, we provide the overall
# number of steps ("num_memory") so that all history
# values can be saved.
integrator = bp.fde.CaputoEuler(qi_system,
                                alpha=0.98,  # fractional order
                                num_memory=int(duration / dt),
                                inits=inits)

runner = bp.IntegratorRunner(integrator,
                             monitors=list('xyz'),
                             inits=inits,
                             dt=dt)
runner.run(duration)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()
_images/f95d9de76c81a6572ee75a67e9a656604aa0da2d36a24467a7ab77b0de14bfbe.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/910e2f23eb5aca7388f5db5e36f5c1ba1e1c26e9f5096802d7acdeea6b9d46b9.png

brainpy.fde.CaputoL1Schema#

brainpy.fde.CaputoL1Schema is another commonly used method to integrate Caputo derivative equations. Let’s try it with a fractional-order Lorenz system, which is given by:

\[\begin{split} \left\{\begin{array}{l} D^{\alpha} x=a\left(y-x\right) \\ D^{\alpha} y= x * (b - z) - y \\ D^{\alpha} z =x * y - c * z \end{array}\right. \end{split}\]
a, b, c = 10, 28, 8 / 3


def lorenz_system(x, y, z, t):
  dx = a * (y - x)
  dy = x * (b - z) - y
  dz = x * y - c * z
  return dx, dy, dz
dt = 0.001
duration = 50
inits = [1, 2, 3]

integrator = bp.fde.CaputoL1Schema(lorenz_system,
                                   alpha=0.99,  # fractional order
                                   num_memory=int(duration / dt),
                                   inits=inits)

runner = bp.IntegratorRunner(integrator,
                             monitors=list('xyz'),
                             inits=inits,
                             dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.y)
plt.show()
_images/60f7408bffdfc5111c92580e598c05455082d516406e9952411c7b060400efb3.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/bf44ff13321d76b798f263ecc1d1eefbf9983f3eb019737d5feaa378fc106bcb.png

Methods for Grünwald-Letnikov FDEs#

The Grünwald-Letnikov FDE is another commonly used type in neuroscience. Here, we provide an efficient computation method based on the short-memory principle of the Grünwald-Letnikov method.

brainpy.fde.GLShortMemory#

brainpy.fde.GLShortMemory is highly efficient, because it does not require an infinite memory length for the numerical solution. Due to the decay property of the coefficients, brainpy.fde.GLShortMemory uses a limited memory length to reduce the computational time. Specifically, it only relies on a memory window of num_memory length. As the width of the memory window increases, the accuracy of the numerical approximation increases.

Here, we demonstrate it by using a fractional-order Chua system, which is defined as

\[\begin{split} \left\{\begin{array}{l} D^{\alpha_{1}} x=a\{y- (1+m_1) x-0.5*(m_0-m_1)*(|x+1|-|x-1|)\} \\ D^{\alpha_{2}} y=x-y+z \\ D^{\alpha_{3}} z=-b y-c z \end{array}\right. \end{split}\]
a, b, c = 10.725, 10.593, 0.268
m0, m1 = -1.1726, -0.7872


def chua_system(x, y, z, t):
  f = m1 * x + 0.5 * (m0 - m1) * (abs(x + 1) - abs(x - 1))
  dx = a * (y - x - f)
  dy = x - y + z
  dz = -b * y - c * z
  return dx, dy, dz
dt = 0.001
duration = 200
inits = [0.2, -0.1, 0.1]

integrator = bp.fde.GLShortMemory(chua_system,
                                  alpha=[0.93, 0.99, 0.92],
                                  num_memory=1000,
                                  inits=inits)

runner = bp.IntegratorRunner(integrator,
                             monitors=list('xyz'),
                             inits=inits,
                             dt=dt)
runner.run(duration)
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.x, runner.mon.z)
plt.show()
_images/e540ca614b30533afb874d0004f52bbd136e804cb47055353a9e82720bf922ea.png
plt.figure(figsize=(10, 8))
plt.plot(runner.mon.y, runner.mon.z)
plt.show()
_images/d4f35745b711c56911206cbaceb476ce92a7a115d370a2c32151423c39137d0e.png

Actually, the coefficients used in brainpy.fde.GLShortMemory can be inspected through:

plt.figure(figsize=(10, 6))
coef = integrator.binomial_coef
alphas = bm.as_numpy(integrator.alpha)

plt.subplot(211)
for i in range(3):
  plt.plot(coef[:, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.subplot(212)
for i in range(3):
  plt.plot(coef[:10, i], label=r'$\alpha$=' + str(alphas[i]))
plt.legend()
plt.show()
_images/fe2131e6d2e8caad99147e218dfb546631034522edc37cf8ede2bb968afabd67.png

As you see, the coefficients decay very quickly!
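
This decay follows from the recursion satisfied by the Grünwald-Letnikov binomial coefficients, \(c_0 = 1\), \(c_j = \left(1 - \frac{1+\alpha}{j}\right) c_{j-1}\). A small stand-alone sketch of this recursion (independent of BrainPy’s internal arrays):

import numpy as np

def gl_coefficients(alpha, n):
    # recursive computation of the Grünwald-Letnikov coefficients
    c = np.zeros(n)
    c[0] = 1.0
    for j in range(1, n):
        c[j] = (1 - (1 + alpha) / j) * c[j - 1]
    return c

print(gl_coefficients(alpha=0.93, n=10))   # the magnitude decays quickly

This fast decay is why a limited memory window of num_memory terms is usually sufficient.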

Further reading#

For more examples of how to use the numerical solvers for fractional differential equations defined in BrainPy, please see:

Numerical Solvers for Delay Differential Equations#

@Chaoming Wang

Delays are frequently encountered in practical systems, such as automatic control, biology, economics, and long transmission lines. Delay differential equations (DDEs) are used to describe such dynamical systems.

Delay differential equations (DDEs) are a type of differential equation in which the derivative at a certain time is given in terms of the values of the function at previous times.

Let’s take delay ODEs as the example. The simplest constant delay equations have the form

\[ y'(t) = f(t, y(t), y(t-\tau_1), y(t-\tau_2),\ldots, y(t-\tau_k)) \]

where the time delays (lags) \(\tau_j\) are positive constants.

For neutral-type DDEs, delays appear in derivative terms:

\[ y'(t) = f(t, y(t), y'(t-\tau_1), y'(t-\tau_2),\ldots, y'(t-\tau_k)) \]

More generally, state dependent delays may depend on the solution, that is \(\tau_i = \tau_i (t,y(t))\).

In BrainPy, we support delay differential equations based on delay variables. Specifically, for delays of state variables, we have:

  • brainpy.math.TimeDelay

  • brainpy.math.LengthDelay

For neutral-type delays, we use:

  • brainpy.math.NeuTimeDelay

  • brainpy.math.NeuLenDelay

import matplotlib.pyplot as plt

import brainpy as bp
import brainpy.math as bm

bm.enable_x64()

bp.__version__
'2.3.0'

Delay variables#

For an ODE system, the numerical methods need to know its initial condition \(y(t_0)=y_0\) and its derivative rule \(y'(t_0) = y'_0\). However, for DDEs it is not enough to give a set of initial values for the function and its derivatives at \(t_0\); one must also provide a set of functions that supply the historical values for \(t_0 - \max(\tau) \leq t \leq t_0\).

Therefore, you need some delay variables to wrap the variable delays. brainpy.math.TimeDelay can be used to define delay variables which depend on states, and brainpy.math.NeuTimeDelay is used to define delay variables which depend on the derivative.

d = bm.TimeDelay(bm.zeros(2), delay_len=10, dt=1, t0=0, before_t0=lambda t: t)
d(0.)
Array([0., 0.], dtype=float64)
d(-0.5)
Array([-0.5, -0.5], dtype=float64)

Requesting a time outside of \([t_0 - \mathrm{delay\_len}, t_0]\) will cause an error.

try:
  d(0.1)
except Exception as e:
  print(e)
ERROR:jax.experimental.host_callback:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x0000016DB6F87280> threw exception The request time should be less than the current time 0.0. But we got 0.1 > 0.0.
INTERNAL: Generated function failed: CpuCallback error: ValueError: The request time should be less than the current time 0.0. But we got 0.1 > 0.0


Delay ODEs#

Here we illustrate how to make numerical integration of delay ODEs with several examples. Before that, we define a general function to simulate a delay ODE function.

def delay_odeint(duration, eq, args=None, inits=None,
                 state_delays=None, neutral_delays=None,
                 monitors=('x',), method='euler', dt=0.1):
  # define integrators of ODEs based on `brainpy.odeint`
  dde = bp.odeint(eq,
                  state_delays=state_delays,
                  neutral_delays=neutral_delays,
                  method=method)
  # define IntegratorRunner
  runner = bp.IntegratorRunner(dde,
                               args=args,
                               monitors=monitors,
                               dt=dt,
                               inits=inits)
  runner.run(duration)
  return runner.mon

Example #1: First-order DDE with one constant delay and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t-1) \]

where the delay is 1 s. The example compares the solutions of three different cases using three different constant history functions:

  • Case #1: \(\phi(t)=-1\)

  • Case #2: \(\phi(t)=0\)

  • Case #3: \(\phi(t)=1\)

def equation(x, t, xdelay):
  return -xdelay(t - 1)


case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round')
case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=0., interp_method='round')
case3_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=1., interp_method='round')
case1 = delay_odeint(20., equation, args={'xdelay': case1_delay},
                     state_delays={'x': case1_delay})  # delay for variable "x"
case2 = delay_odeint(20., equation, args={'xdelay': case2_delay}, state_delays={'x': case2_delay})
case3 = delay_odeint(20., equation, args={'xdelay': case3_delay}, state_delays={'x': case3_delay})
fig, axs = plt.subplots(3, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1)$")

axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')

axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=0$')

axs[2].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[2].set_title('$ihf(t)=1$')

plt.show()
_images/303d77dfea1d889691c7dbd4d4bb883ca264145cda0c67d0a2942914914ffeeb.png

Example #2: First-order DDE with one constant delay and a non constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t-2) \]

where the delay is 2 s. The example compares the solutions of four different cases using two different non-constant history functions and two different intervals of \(t\):

  • Case #1: \(\phi(t)=e^{-t} - 1, t \in [0, 4]\)

  • Case #2: \(\phi(t)=e^{t} - 1, t \in [0, 4]\)

  • Case #3: \(\phi(t)=e^{-t} - 1, t \in [0, 60]\)

  • Case #4: \(\phi(t)=e^{t} - 1, t \in [0, 60]\)

def eq(x, t, xdelay):
  return -xdelay(t - 2)


delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t) - 1, dt=0.01, interp_method='round')
delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t) - 1, dt=0.01, interp_method='round')
delay3 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(-t) - 1, dt=0.01, interp_method='round')
delay4 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: bm.exp(t) - 1, dt=0.01, interp_method='round')
case1 = delay_odeint(4., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(4., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
case3 = delay_odeint(60., eq, args={'xdelay': delay3}, state_delays={'x': delay3}, dt=0.01)
case4 = delay_odeint(60., eq, args={'xdelay': delay4}, state_delays={'x': delay4}, dt=0.01)
fig, axs = plt.subplots(2, 2)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-2)$")

axs[0, 0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 4]$')

axs[0, 1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[0, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 4]$')

axs[1, 0].plot(case3.ts, case3.x, color='red', linewidth=1)
axs[1, 0].set_title('$ihf(t)=e^{-t} - 1, t \in [0, 60]$')

axs[1, 1].plot(case4.ts, case4.x, color='red', linewidth=1)
axs[1, 1].set_title('$ihf(t)=e^t - 1, t \in [0, 60]$')

plt.show()
_images/4d2e6be4b99fcea306bd6d71e317537fc32c91de7e9725a20b5a83aa1cecebb4.png

Example #3: First-order DDE with two constant delays and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=-y(t - 1) + 0.3 y(t - 2) \]

where there are two constant delays, equal to 1 s and 2 s respectively. The initial history function is also constant: \(\phi(t)=1\).

def eq(x, t):
  return -delay(t - 1) + 0.3 * delay(t - 2)


delay = bm.TimeDelay(bm.ones(1), 2., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(10., eq, inits=[1.], state_delays={'x': delay}, dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=-y(t-1) + 0.3\ y(t-2)$")

axs.plot(mon.ts, mon.x, color='red', linewidth=1)
axs.set_title('$ihf(t)=1$')

plt.show()
_images/4169a22b9b1506f56ceb45847ee61f1ddd2231cc7aa64d207cccfbb607f07965.png

Example #4: System of two first-order DDEs with one constant delay and two constant initial history functions#

Let the following system of DDEs be given:

\[\begin{split} \begin{cases} y_1'(t) = y_1(t) y_2(t-0.5) \\ y_2'(t) = y_2(t) y_1(t-0.5) \end{cases} \end{split}\]

where there is a single constant delay equal to 0.5 s, and the initial history functions are also constant. Because the system consists of two first-order equations, one history function is needed for each unknown: \(y_1(t)=1\) and \(y_2(t)=-1\).

def eq(x, y, t):
  dx = x * ydelay(t - 0.5)
  dy = y * xdelay(t - 0.5)
  return dx, dy


xdelay = bm.TimeDelay(bm.ones(1), 0.5, before_t0=1., dt=0.01, interp_method='round')
ydelay = bm.TimeDelay(-bm.ones(1), 0.5, before_t0=-1., dt=0.01, interp_method='round')

mon = delay_odeint(3., eq, inits=[1., -1], state_delays={'x': xdelay, 'y': ydelay},
                   dt=0.01, monitors=['x', 'y'])
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$x'(t)=x(t) y(t-d); y'(t)=y(t) x(t-d)$")

axs.plot(mon.ts, mon.x.flatten(), color='red', linewidth=1)
axs.plot(mon.ts, mon.y.flatten(), color='blue', linewidth=1)
axs.set_title('$ihf_x(t)=1; ihf_y(t)=-1; d=0.5$')

plt.show()
_images/e42751a377ec45c32331e6d490f7ab3573fbcb8a601aa9baa7457e9057dd5768.png

Example #5: Second-order DDE with one constant delay and two constant initial history functions#

Let the following DDE be given:

\[ y''(t) = -y'(t) - 2y(t) - 0.5 y(t-1) \]

where there is a single constant delay equal to 1 s. Since the DDE is second order (the second derivative of the unknown function appears), two history functions are required: one to give the values of the unknown \(y(t)\) for \(t \le 0\), and one to provide the values of the first derivative \(y'(t)\) for \(t \le 0\).

In this example they are the following two constant functions: \(y(t)=1, y'(t)=0\).

Due to the properties of the second-order equations, the given DDE is equivalent to the following system of first-order equations:

\[\begin{split} \begin{cases} y_1'(t) = y_2(t) \\ y_2'(t) = -y_2(t) - 2y_1(t) - 0.5 y_1(t-1) \end{cases} \end{split}\]

so the implementation reduces to the previous example of a system of first-order equations.

def eq(x, y, t):
  dx = y
  dy = -y - 2 * x - 0.5 * xdelay(t - 1)
  return dx, dy


xdelay = bm.TimeDelay(bm.ones(1), 1., before_t0=1., dt=0.01, interp_method='round')
mon = delay_odeint(16., eq, inits=[1., 0.], state_delays={'x': xdelay}, monitors=['x', 'y'], dt=0.01)
fig, axs = plt.subplots(1, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y''(t)=-y'(t) - 2 y(t) - 0.5 y(t-1)$")
axs.plot(mon.ts, mon.x[:, 0], color='red', linewidth=1)
axs.plot(mon.ts, mon.y[:, 0], color='green', linewidth=1)
axs.set_title('$ih \, f_y(t)=1; ihf\,dy/dt(t)=0$')

plt.show()
_images/841eec1d11ae1d07544b4c112244311ddd6e92fd273df189c24227ab5e914c25.png

Example #6: First-order DDE with one non-constant delay and a constant initial history function#

Let the following DDE be given:

\[ y'(t)=y(t-\mathrm{delay}(y, t)) \]

where the delay is not constant but given by the function \(\mathrm{delay}(y, t)=|\frac{1}{10} t \, y(\frac{1}{10} t)|\). The example compares the solutions of two cases using two different constant history functions:

  • Case #1: \(\phi(t)=-1\)

  • Case #2: \(\phi(t)=1\)

def eq(x, t, xdelay):
  delay = abs(t * xdelay(t - 0.9 * t) / 10)  # an array with shape (1,)
  delay = delay[0]
  return xdelay(t - delay)

Note

Note that here we do not know the maximum length of the delay. Therefore, we can declare a fixed-length delay variable with delay_len equal to or even bigger than the running duration.

delay1 = bm.TimeDelay(bm.ones(1), 30., before_t0=-1, dt=0.01)
delay2 = bm.TimeDelay(-bm.ones(1), 30., before_t0=1, dt=0.01)
case1 = delay_odeint(30., eq, args={'xdelay': delay1}, state_delays={'x': delay1}, dt=0.01)
case2 = delay_odeint(30., eq, args={'xdelay': delay2}, state_delays={'x': delay2}, dt=0.01)
fig, axs = plt.subplots(2, 1)
fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0)
fig.suptitle("$y'(t)=y(t-delay(y, t))$")

axs[0].plot(case1.ts, case1.x, color='red', linewidth=1)
axs[0].set_title('$ihf(t)=-1$')

axs[1].plot(case2.ts, case2.x, color='red', linewidth=1)
axs[1].set_title('$ihf(t)=1$')

plt.show()
_images/a6678ca1c246ff5c4bc303ef11512de84289c53839674595a307ee43ebaa930a.png

Delay SDEs#

Same as delay ODEs, state-dependent delay variables can be appended to the state_delays argument of the brainpy.sdeint function.

delay = bm.TimeDelay(bm.zeros(1),
                     2.,
                     before_t0=lambda t: bm.exp(-t) - 1,
                     dt=0.01,
                     interp_method='round')

f = lambda x, t: -delay(t - 2)
g = lambda x, t, *args: 0.01

dt = 0.01
integral = bp.sdeint(f, g, state_delays={'x': delay})
runner = bp.IntegratorRunner(integral,
                             monitors=['x'],
                             dt=dt)
runner.run(100.)

plt.plot(runner.mon.ts, runner.mon.x)
plt.show()
_images/c61369e9576dbed600ce24ad60e9ff05d7ca1f66a3cf608072919cc17740b5df.png

Delay FDEs#

Fractional-order delayed differential equations, as a generalization of delayed differential equations, provide more freedom when describing these systems. Let's see how we can use BrainPy to accelerate the simulation of fractional-order delayed differential equations.

Fractional delayed differential equations have the general form:

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y(t)=f(t, y(t), y(t-\tau)), \quad t \geq \xi \\ y(t)=\phi(t), \quad t \in[\xi-\tau, \xi] \end{gathered} \end{split}\]

Lemmings’ population cycle#

The fractional order version of the four-year life cycle of a population of lemmings is given by

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y(t)=3.5 y(t)\left(1-\frac{y(t-0.74)}{19}\right), \\ y(0)=19.00001, \\ y(t)=19, \quad t<0 \end{gathered} \end{split}\]

dt = 0.05
delay = bm.TimeDelay(bm.asarray([19.00001]), 0.74, before_t0=19., dt=dt)
f = lambda y, t: 3.5 * y * (1 - delay(t - 0.74) / 19)
integral = bp.fde.GLShortMemory(f,
                                alpha=0.97,
                                inits=[19.00001],
                                num_memory=500,
                                state_delays={'y': delay})
runner = bp.IntegratorRunner(integral,
                             inits=bm.asarray([19.00001]),
                             monitors=['y'],
                             fun_monitors={'y(t-0.74)': lambda t, _: delay(t - 0.74)},
                             dyn_vars=delay.vars(),
                             dt=dt)
runner.run(100.)

plt.plot(runner.mon['y'], runner.mon['y(t-0.74)'])
plt.xlabel('y(t)')
plt.ylabel('y(t-0.74)')
plt.show()
_images/058c48783d993059cb97de2e5f2a63a97357f3d067eeb275987d2c53c4a37e5e.png

Time delay Chen system#

The time-delay Chen system, as a famous chaotic system with time delay, has important applications in many fields.

\[\begin{split} \left\{\begin{array}{l} D^{\alpha_{1}} x=a(y(t)-x(t-\tau)) \\ D^{\alpha_{2}} y=(c-a) x(t-\tau)-x(t) z(t)+c y(t) \\ D^{\alpha_{3}} z=x(t) y(t)-b z(t-\tau) \end{array}\right. \end{split}\]
dt = 0.001
tau = 0.009
xdelay = bm.TimeDelay(bm.asarray([0.2]), tau, dt=dt)
zdelay = bm.TimeDelay(bm.asarray([0.5]), tau, dt=dt)


def derivative(x, y, z, t):
  a = 35
  b = 3
  c = 27
  dx = a * (y - xdelay(t - tau))
  dy = (c - a) * xdelay(t - tau) - x * z + c * y
  dz = x * y - b * zdelay(t - tau)
  return dx, dy, dz


integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.94,
                                inits=[0.2, 0., 0.5],
                                num_memory=500,
                                state_delays={'x': xdelay, 'z': zdelay})
runner = bp.IntegratorRunner(integral,
                             inits=[0.2, 0., 0.5],
                             monitors=['x', 'y', 'z'],
                             dyn_vars=xdelay.vars() + zdelay.vars(),
                             dt=dt)
runner.run(100.)
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0])
plt.show()
_images/fb38065a81dd0cd224bedfe9f7e5ccca8f25c5f36f43b24ffe02dd4f898d3b19.png

Enzyme kinetics#

Let’s see a more complex example of the fractional order version of enzyme kinetics with an inhibitor molecule:

\[\begin{split} \begin{gathered} D_{t}^{\alpha} y_{1}(t)=10.5-\frac{y_{1}(t)}{1+0.0005 y_{4}^{3}(t-4)} \\ D_{t}^{\alpha} y_{2}(t)=\frac{y_{1}(t)}{1+0.0005 y_{4}^{3}(t-4)}-y_{2}(t) \\ D_{t}^{\alpha} y_{3}(t)=y_{2}(t)-y_{3}(t) \\ D_{t}^{\alpha} y_{4}(t)=y_{3}(t)-0.5 y_{4}(t) \\ y(t)=[60,10,10,20], t \leq 0 \end{gathered} \end{split}\]
dt = 0.01
tau = 4.
delay = bm.TimeDelay(bm.asarray([20.]), tau, before_t0=20, dt=dt)


def derivative(a, b, c, d, t):
  da = 10.5 - a / (1 + 0.0005 * delay(t - tau) ** 3)
  db = a / (1 + 0.0005 * delay(t - tau) ** 3) - b
  dc = b - c
  dd = c - 0.5 * d
  return da, db, dc, dd


integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.95,
                                inits=[60, 10, 10, 20],
                                num_memory=500,
                                state_delays={'d': delay})
runner = bp.IntegratorRunner(integral,
                             inits=[60, 10, 10, 20],
                             monitors=list('abcd'),
                             dyn_vars=delay.vars(),
                             dt=dt)
runner.run(200.)
plt.plot(runner.mon.ts, runner.mon.a, label='a')
plt.plot(runner.mon.ts, runner.mon.b, label='b')
plt.plot(runner.mon.ts, runner.mon.c, label='c')
plt.plot(runner.mon.ts, runner.mon.d, label='d')
plt.legend()
plt.xlabel('Time [ms]')
plt.show()
_images/5b7b4936074b51a7cb50dfa77a87f9fce78d2cf06275a82f1f62eb4eec8006ba.png

Fractional matrix delayed differential equations#

BrainPy is also capable of solving fractional matrix delayed differential equations:

\[ D_{t_{0}}^{\alpha} \mathbf{x}(t)=\mathbf{A}(t) \mathbf{x}(t)+\mathbf{B}(t) \mathbf{x}(t-\tau)+\mathbf{c}(t) \]

Here \(\mathbf{x}(t)\) is the state vector of the system and \(\mathbf{c}(t)\) is a known disturbance function.

We explain the detailed usage by using an example:

\[\begin{split} \mathbf{x}(t)=\left(\begin{array}{l} x_{1}(t) \\ x_{2}(t) \\ x_{3}(t) \\ x_{4}(t) \end{array}\right) \end{split}\]
\[\begin{split} \mathbf{A}=\left(\begin{array}{cccc} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & -2 & 0 & 0 \\ -2 & 0 & 0 & 0 \end{array}\right) \end{split}\]
\[\begin{split} \mathbf{B}=\left(\begin{array}{cccc} 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ -2 & 0 & 0 & 0 \\ 0 & -2 & 0 & 0 \end{array}\right) \end{split}\]

With initial condition:

\[\begin{split} \mathbf{x}_{0}(t)=\left(\begin{array}{c} \sin (t) \cos (t) \\ \sin (t) \cos (t) \\ \cos ^{2}(t)-\sin ^{2}(t) \\ \cos ^{2}(t)-\sin ^{2}(t) \end{array}\right) \end{split}\]
dt = 0.01
tau = 3.1416
f = lambda t: bm.asarray([bm.sin(t) * bm.cos(t),
                          bm.sin(t) * bm.cos(t),
                          bm.cos(t) ** 2 - bm.sin(t) ** 2,
                          bm.cos(t) ** 2 - bm.sin(t) ** 2])
delay = bm.TimeDelay(f(0.), tau, before_t0=f, dt=dt)

A = bm.asarray([[0, 0, 1, 0], [0, 0, 0, 1], [0, -2, 0, 0], [-2, 0, 0, 0]])
B = bm.asarray([[0, 0, 0, 0], [0, 0, 0, 0], [-2, 0, 0, 0], [0, -2, 0, 0]])
c = bm.asarray([0, 0, 0, 0])
derivative = lambda x, t: A @ x + B @ delay(t - tau) + c

integral = bp.fde.GLShortMemory(derivative,
                                alpha=0.4,
                                inits=[f(0.)],
                                num_memory=500,
                                state_delays={'x': delay})
runner = bp.IntegratorRunner(integral,
                             inits=[f(0.)],
                             monitors=['x'],
                             dyn_vars=delay.vars(),
                             dt=dt)
runner.run(100.)
plt.plot(runner.mon.x[:, 0], runner.mon.x[:, 2])
plt.xlabel('x0')
plt.ylabel('x2')
plt.show()
_images/b736737ae09bb38b42e4a6850a650315a6d918a792d5c3258fd261e426dbac44.png

Acknowledgement#

This tutorial is highly inspired by the work of Ettore Messina [1] and of Qingyu Qu [2].

Joint Differential Equations#

@Xiaoyu Chen

In a dynamical system, there may be multiple variables that change dynamically over time. Sometimes these variables are interconnected, and updating one variable requires others as the input. For example, in the widely known Hodgkin–Huxley model, the variables \(V\), \(m\), \(h\), and \(n\) are updated synchronously and interdependently (please refer to Building Neuron Models for details). To achieve higher integration accuracy, it is recommended to use brainpy.JointEq to jointly solve interconnected differential equations.

import brainpy as bp

brainpy.JointEq#

brainpy.JointEq is used to merge individual but interconnected differential equations into a single joint equation. For example, below are the two differential equations of the Izhikevich model:

a, b = 0.02, 0.20
dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
du = lambda u, t, V: a * (b * V - u)

Here, updating \(V\) requires \(u\) as the input, and updating \(u\) requires \(V\) as the input. The joint equation can be defined as:

joint_eq = bp.JointEq(dV, du)

brainpy.JointEq receives only one argument named eqs, which can be a list or tuple containing multiple differential equations. Then it can be packed into a numerical integrator that solves the equation with a specified method, just as can be done with any individual differential equation.

itg = bp.odeint(joint_eq, method='rk2')

There are several requirements for defining a joint equation:

  1. Every individual differential equation should follow the format for defining an ODE or SDE function in BrainPy. For example, the arguments before t denote the dynamical variables and the arguments after t denote the parameters.

  2. The same variable in different equations should have the same name. Different variables should be named differently.

Note that brainpy.JointEq supports nesting, which means an instance of JointEq can be used as an element when composing a new JointEq, as sketched below.
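For instance, a minimal sketch of nesting, reusing dV and du defined above (the extra equation dw is hypothetical and only for illustration, and we assume the same calling convention as bp.JointEq(dV, du)):

dw = lambda w, t, V: 0.01 * (V - w)   # a made-up auxiliary variable, for illustration only
inner_eq = bp.JointEq(dV, du)         # the Izhikevich pair defined above
nested_eq = bp.JointEq(inner_eq, dw)  # an existing JointEq instance composes a new JointEq
itg_nested = bp.odeint(nested_eq, method='rk2')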

Why use brainpy.JointEq?#

Users may be confused about the purpose of brainpy.JointEq, because multiple differential equations can also be written in a single function:

def diff(V, u, t, Iext):
    dV = 0.04 * V * V + 5 * V + 140 - u + Iext
    du = a * (b * V - u)
    return dV, du

itg_V_u = bp.odeint(diff, method='rk2')

or simply packed into integrators separately:

int_V = bp.odeint(dV, method='rk2')
int_u = bp.odeint(du, method='rk2')

To illustrate the difference between joint and separate differential equations, let's dive into the numerical integration code generated for these two types of equations.

If we build a numerical solver for each derivative function, the functions will be solved independently:

import brainpy as bp
bp.odeint(dV, method='rk2', show_code=True)
def brainpy_itg_of_ode6(V, t, u, Iext, dt=0.1):
  dV_k1 = f(V, t, u, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  return V_new

{'f': <function <lambda> at 0x0000022948DD6A60>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x229660543a0>

As is shown in the output code, the variable \(V\) is integrated twice by the RK2 method. For the second differential value dV_k2, the updated value of \(V\) (k2_V_arg) and original \(u\) are used to calculate the differential value. This will generate a tiny error, since the values of \(V\) and \(u\) are taken at different times.

To eliminate this error, the differential equation of \(V\) and \(u\) should be solved jointly through brainpy.JointEq:

eq = bp.JointEq(eqs=(dV, du))
bp.odeint(eq, method='rk2', show_code=True)
def brainpy_itg_of_ode12_joint_eq(V, u, t, Iext, dt=0.1):
  dV_k1, du_k1 = f(V, u, t, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_u_arg = u + dt * du_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
  return V_new, u_new

{'f': <brainpy.integrators.joint_eq.JointEq object at 0x0000022967EC0C40>}
<brainpy.integrators.ode.explicit_rk.RK2 at 0x22967ec0160>

This output code shows that the second differential values of \(V\) and \(u\) are calculated using the updated values (k2_V_arg and k2_u_arg) at the same time. This results in a more accurate integration.

The figure below compares the simulation results of the Izhikevich model using joint and separate differential equations (\(dt = 0.2\) ms). It shows that as the simulation time increases, the integration error becomes greater.

Synaptic Connections#

@Tianqiu Zhang @Xiaoyu Chen

Synaptic connections are an essential part of building a neural dynamic system. BrainPy provides several commonly used connection methods in the brainpy.connect module (which can be accessed by the shortcut bp.conn) that help users easily construct many types of synaptic connections, including built-in and self-customized connectors.

An Overview of BrainPy Connectors#

Here we provide an overview of BrainPy connectors.

Base class: bp.conn.Connector#

The base class of connectors is brainpy.connect.Connector. All connectors, built-in or customized, should inherit from the Connector class.

Two subclasses: TwoEndConnector and OneEndConnector#

There are two classes inheriting from the base class bp.conn.Connector:

  • bp.conn.TwoEndConnector: a connector to build synaptic connections between two neuron groups.

  • bp.conn.OneEndConnector: a connector to build synaptic connections within a population of neurons.

Users can click the link of each class above to look through the API documentation.

Connector.__init__()#

All connectors need to be initialized first. For each built-in connector, users need to pass in the corresponding parameters for initialization. For details, please see the specific connector type below.

Connector.__call__()#

After initialization, users should call the connector and pass in parameters depending on specific connection types:

  • TwoEndConnector: It has two input parameters pre_size and post_size, each representing the size of the pre- and post-synaptic neuron group. It will result in a connection matrix with the shape of (pre_num, post_num).

  • OneEndConnector: It has only one parameter pre_size, which represents the size of the neuron group. It will result in a connection matrix with the shape of (pre_num, pre_num).

The __call__ function returns the connector instance itself.
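As a quick sketch (using the built-in All2All connector introduced later in this tutorial):

conn = bp.connect.All2All()           # step 1: initialize the connector
conn = conn(pre_size=3, post_size=4)  # step 2: __call__ stores the sizes and returns the connector itself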

Connector.build_conn()#

Users can customize the connection in the build_conn() function. Note that there are three connection types users can provide:

  • ‘mat’: Dense connection, including a connection matrix.

  • ‘ij’: Index projection, including a pre-neuron index vector and a post-neuron index vector.

  • ‘csr’: Sparse connection, including an index vector and an indptr vector.

Return type can be either a dict or a tuple. Here are two examples of how to return your connection data:

Example 1:

def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)

  return dict(csr=(ind, indptr), mat=None, ij=None)

Example 2:

def build_conn(self):
  ind = np.arange(self.pre_num)
  indptr = np.arange(self.pre_num + 1)

  return 'csr', (ind, indptr)

After creating the synaptic connection, users can use the require() method to access some useful properties of the connection.

Connector.require()#

This method returns the connection properties required by users. The connection properties are elaborated in the following sections in detail. Here is a brief summary of the connection properties users can require.

  • conn_mat (2-D array, matrix): Dense connection matrix.

  • pre_ids (1-D array, vector): Indices of the pre-synaptic neuron group.

  • post_ids (1-D array, vector): Indices of the post-synaptic neuron group.

  • pre2post (tuple of two vectors): The post-synaptic neuron indices and the corresponding pre-synaptic neuron pointers.

  • post2pre (tuple of two vectors): The pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers.

  • pre2syn (tuple of two vectors): The synapse indices sorted by pre-synaptic neurons and the corresponding pre-synaptic neuron pointers.

  • post2syn (tuple of two vectors): The synapse indices sorted by post-synaptic neurons and the corresponding post-synaptic neuron pointers.

Users can call this method as in the following statement:

pre_ids, post_ids, pre2post, conn_mat = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

Note

Note that this method can return multiple connection properties.

Connection Properties#

There are multiple connection properties that can be required by users.

1. conn_mat#

The matrix-based synaptic connection is one of the most intuitive ways to build synaptic computations. The connection matrix between two neuron groups can be easily obtained through connector.requires('conn_mat'). Each connection matrix is an array with the shape \((n_{pre}, n_{post})\):
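For example, a minimal sketch using the built-in All2All connector introduced below:

conn = bp.connect.All2All()(pre_size=3, post_size=4)
conn_mat = conn.require('conn_mat')  # a boolean array with shape (3, 4)
print(conn_mat.shape)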

2. pre_ids and post_ids#

Using vectors to store the connection between neuron groups is a much more efficient way to reduce memory when the connection matrix is sparse. For the connection matrix conn_mat defined above, we can align the connected pre-synaptic neurons and the post-synaptic neurons by two one-dimensional arrays: pre_ids and post_ids.

In this way, we only need two vectors (pre_ids and post_ids) to store the synaptic connection. syn_id in the figure indicates the indices of each neuron pair, i.e. each synapse.
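A brief sketch of requesting these two vectors (using the All2All connector introduced below):

conn = bp.connect.All2All(include_self=False)(pre_size=3, post_size=3)
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')
# pre_ids[k] and post_ids[k] together identify synapse k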

3. pre2post and post2pre#

Another two synaptic structures are pre2post and post2pre. They establish the mapping between the pre- and post-synaptic neurons.

pre2post is a tuple containing two vectors, one of which is the post-synaptic neuron indices and the other is the corresponding pre-synaptic neuron pointers. For example, the following figure shows the indices of the pre-synaptic neurons and the post-synaptic neurons to which the pre-synaptic neurons project:

To record the connection, the post_ids are first concatenated into a single vector called the post-synaptic index vector (indices). Because the post-synaptic neuron indices have been sorted by the pre-synaptic neuron indices, it is sufficient to record only the starting position of each pre-synaptic neuron's block. These starting positions, together with the end position of the last block, make up the pre-synaptic index pointer vector (indptr), which is illustrated in the figure below.

The post-synaptic neuron indices to which pre-synaptic neuron \(i\) projects can be obtained by array slicing:

indices[indptr[i]: indptr[i+1]]
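For instance, a hedged sketch using the All2All connection shown later in this tutorial:

conn = bp.connect.All2All(include_self=False)(pre_size=5, post_size=5)
indices, indptr = conn.require('pre2post')
i = 2
print(indices[indptr[i]: indptr[i + 1]])  # post-synaptic targets of pre-synaptic neuron i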

Similarly, post2pre is a 2-element tuple containing the pre-synaptic neuron indices and the corresponding post-synaptic neuron pointers. Taking the connection in the illustration above as an example, the post-synaptic neuron indices and the pre-synaptic neuron indices from which the post-synaptic neurons receive connections are shown as:

The pre-synaptic index vector (indices) and the post-synaptic index pointer vector (indptr) are listed below:

When the connection is sparse, pre2post (or post2pre) is a very efficient way to store the connection, since the lengths of the two vectors in the tuple are \(n_{synapse}\) and \(n_{pre}\) (\(n_{post}\)), respectively.

4. pre2syn and post2syn#

The last two properties are pre2syn and post2syn that record pre- and post-synaptic projection, respectively.

For pre2syn, similar to pre2post and post2pre, there is a synapse index vector and a pre-synaptic index pointer vector that refers to the starting position of each pre-synaptic neuron's block in the synapse index vector.
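A brief sketch of requesting both structures (using the One2One connector introduced later, where synapse k simply links neuron k to neuron k):

conn = bp.connect.One2One()(pre_size=5, post_size=5)
pre2syn, post2syn = conn.require('pre2syn', 'post2syn')
# each is a tuple: (synapse index vector, index pointer vector)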

Below is the same example identifying the connection by pre-synaptic neuron indices and the synapses belonging to them.

For better understanding, the synapse indices and the pre- and post-synaptic neuron indices are shown below:

The pre-synaptic index pointer vector is computed in the same way as in pre2post:

Similarly, post2syn is also a tuple containing the synapse indices and the corresponding post-synaptic neuron pointers.

The only difference from pre2syn is that the synapse indices are (most of the time) originally sorted by pre-synaptic neurons; when computing post2syn, the synapses should instead be sorted by post-synaptic neuron indices:

The synapse index vector (the first row) and the post-synaptic index pointer vector (the last row) are listed below:

import brainpy as bp
import brainpy.math as bm

# bp.math.set_platform('cpu')

bp.__version__
'2.3.0'
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

Built-in regular connections#

brainpy.connect.One2One#

The neurons in the pre-synaptic neuron group only connect to the neurons in the same position of the post-synaptic group. Thus, this connection requires the two neuron groups to have the same size. Otherwise, an error will occur.

conn = bp.connect.One2One()
conn(pre_size=10, post_size=10)
One2One

where pre_size denotes the size of the pre-synaptic neuron group and post_size denotes the size of the post-synaptic neuron group. Note that the size parameter can be an int, a tuple of ints, or a list of ints, where each element represents one dimension of the neuron group.

In One2One connection, particularly, pre_size and post_size must be the same.

Class One2One is inherited from TwoEndConnector. Users can use method require or requires to get specific connection properties.

Here is an example:

size = 5
conn = bp.connect.One2One()(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
pre_ids: Array([0, 1, 2, 3, 4], dtype=int32)
post_ids: Array([0, 1, 2, 3, 4], dtype=int32)
pre2post: (Array([0, 1, 2, 3, 4], dtype=int32), Array([0, 1, 2, 3, 4, 5], dtype=int32))

brainpy.connect.All2All#

All neurons of the post-synaptic population form connections with all neurons of the pre-synaptic population (dense connectivity). Users can choose whether to connect neurons at the same position (include_self=True or False).

conn = bp.connect.All2All(include_self=False)
conn(pre_size=size, post_size=size)
All2All(include_self=False)

Class All2All is inherited from TwoEndConnector. Users can use method require or requires to get specific connection properties.

Here is an example:

conn = bp.connect.All2All(include_self=False)(pre_size=size, post_size=size)
res = conn.require('pre_ids', 'post_ids', 'pre2post', 'conn_mat')

print('pre_ids:', res[0])
print('post_ids:', res[1])
print('pre2post:', res[2])
print('conn_mat:', res[3])
pre_ids: Array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], dtype=int32)
post_ids: Array([1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], dtype=int32)
pre2post: (Array([1, 2, 3, 4, 0, 2, 3, 4, 4, 3, 1, 0, 0, 1, 2, 4, 0, 1, 2, 3], dtype=int32), Array([ 0,  4,  8, 12, 16, 20], dtype=int32))
conn_mat: Array([[False,  True,  True,  True,  True],
       [ True, False,  True,  True,  True],
       [ True,  True, False,  True,  True],
       [ True,  True,  True, False,  True],
       [ True,  True,  True,  True, False]], dtype=bool)

brainpy.connect.GridFour#

GridFour is the four-nearest-neighbors connection. Each neuron connects to its nearest four neurons.

conn = bp.connect.GridFour(include_self=False)
conn(pre_size=size)
GridFour(include_self=False, periodic_boundary=False)

Class GridFour is inherited from OneEndConnector; therefore, there is only one parameter pre_size representing the size of the neuron group, which should be a two-dimensional geometry.

Here is an example:

size = (4, 4)
conn = bp.connect.GridFour(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')

print('pre_ids', res[0])
pre_ids Array([ 0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  4,  4,  4,  5,  5,  5,  5,
        6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9,  9, 10, 10, 10,
       10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15],      dtype=int32)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()
_images/2a6de7994b1e94d340744e7851191a6707316b5eddfe1a0e4df050c1f41aae6e.png

brainpy.connect.GridEight#

GridEight is the eight-nearest-neighbors connection. Each neuron connects to its nearest eight neurons.

conn = bp.connect.GridEight(include_self=False)
conn(pre_size=size)
GridEight(N=1, include_self=False, periodic_boundary=False)

Class GridEight is inherited from GridN, which is introduced below.

Here is an example:

size = (4, 4)
conn = bp.connect.GridEight(include_self=False)(pre_size=size)
res = conn.require('pre_ids', 'conn_mat')

print('pre_ids', res[0])
pre_ids Array([ 0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,  4,
        4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,
        6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,
        9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11,
       12, 12, 12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15],      dtype=int32)

Take an interior point (e.g., id = 5) as an example: its neighbors are the eight surrounding neurons. Therefore, its row in conn_mat is True at those eight positions and False elsewhere (including itself).

# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res[1])
nx.draw(G, with_labels=True)
plt.show()
_images/15abb30f98db0310deac8c9e4c6595872f9ec319bc0b09bb1ff24d2112412d62.png

brainpy.connect.GridN#

GridN is also a nearest-neighbors connection. Each neuron connects to its nearest \((2N+1) \cdot (2N+1)\) neurons (if including itself).

Here are some examples to help understand GridN. It is closely related to GridEight: GridEight is equivalent to GridN when N = 1.

  • When N = 1: \(\begin{bmatrix} x & x & x\\ x & I & x\\ x & x & x \end{bmatrix}\)

  • When N = 2: \( \begin{bmatrix} x & x & x & x & x\\ x & x & x & x & x\\ x & x & I & x & x\\ x & x & x & x & x\\ x & x & x & x & x \end{bmatrix} \)

conn = bp.connect.GridN(N=2, include_self=False)
conn(pre_size=size)
GridN(N=2, include_self=False, periodic_boundary=False)

Here is an example:

size = (4, 4)
conn = bp.connect.GridN(N=1, include_self=False)(pre_size=size)
res = conn.require('conn_mat')
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(res)
nx.draw(G, with_labels=True)
plt.show()
_images/698eb1551797eee4ad3d2b564f4883f672c8fe8d8ce46738d59a971bb9ea2bcf.png

Built-in random connections#

brainpy.connect.FixedProb#

For each post-synaptic neuron, there is a fixed probability that it forms a connection with a neuron of the pre-synaptic population. It is basically an all-to-all projection, except that some synapses are not created, making the projection sparser.

Class brainpy.connect.FixedProb is inherited from TwoEndConnector, and it receives three settings:

  • prob: Fixed probability for connection with a pre-synaptic neuron for each post-synaptic neuron.

  • include_self: Whether to connect a neuron to itself.

  • seed: Seed the random generator.

Two parameters, pre_size and post_size, are passed in when calling the instance of the class.

conn = bp.connect.FixedProb(prob=0.5, include_self=False, seed=134)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[False,  True,  True,  True],
       [False, False,  True,  True],
       [False,  True, False, False],
       [ True, False, False, False]], dtype=bool)

brainpy.connect.FixedPreNum#

Each neuron in the post-synaptic population receives connections from a fixed number of neurons of the pre-synaptic population chosen randomly. It may happen that two post-synaptic neurons are connected to the same pre-synaptic neuron and that some pre-synaptic neurons are connected to nothing.

Class brainpy.connect.FixedPreNum is inherited from TwoEndConnector, and it receives three settings:

  • num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).

  • include_self: Whether to connect a neuron to itself.

  • seed: Seed the random generator.

Two parameters, pre_size and post_size, are passed in when calling the instance of the class.

conn = bp.connect.FixedPreNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[ True,  True, False, False],
       [False,  True, False,  True],
       [ True, False,  True, False],
       [False, False,  True,  True]], dtype=bool)

brainpy.connect.FixedPostNum#

Each neuron in the pre-synaptic population sends a connection to a fixed number of neurons of the post-synaptic population chosen randomly. It may happen that two pre-synaptic neurons are connected to the same post-synaptic neuron and that some post-synaptic neurons receive no connection at all.

Class brainpy.connect.FixedPostNum is inherited from TwoEndConnector, and it receives three settings:

  • num: The connection probability (if num is a float) or the fixed number of connections (if num is an int).

  • include_self: Whether to connect a neuron to itself.

  • seed: Seed the random generator.

Two parameters, pre_size and post_size, are passed in when calling the instance of the class.

conn = bp.connect.FixedPostNum(num=2, include_self=True, seed=1234)
conn(pre_size=4, post_size=4)
conn.require('conn_mat')
Array([[ True, False,  True, False],
       [ True,  True, False, False],
       [False, False,  True,  True],
       [False,  True, False,  True]], dtype=bool)

brainpy.connect.GaussianProb#

GaussianProb builds a Gaussian connection pattern between the two populations, where the connection probability decays according to the Gaussian function.

Specifically,

\[ p=\exp\left(-\frac{(x-x_c)^2+(y-y_c)^2}{2\sigma^2}\right) \]

where \((x, y)\) is the position of the pre-synaptic neuron and \((x_c,y_c)\) is the position of the post-synaptic neuron.

For example, in a \(30 \times 30\) two-dimensional network, when \(\beta = \frac{1}{2\sigma^2} = 0.1\), the connection pattern is shown as follows:

GaussianProb is inherited from OneEndConnector, and it receives the following settings:

  • sigma: (float) Width of the Gaussian function.

  • encoding_values: (optional, list, tuple, int, float) The value ranges to encode for neurons at each axis.

  • periodic_boundary: (bool) Whether the neurons encode the value space with a periodic boundary.

  • normalize: (bool) Whether to normalize the connection probability.

  • include_self: (bool) Whether to create a connection at the same position.

  • seed: (int) The random seed.

conn = bp.connect.GaussianProb(sigma=2, periodic_boundary=True, normalize=True, include_self=True, seed=21)
conn(pre_size=10)
conn.require('conn_mat')
Array([[ True,  True, False,  True, False, False, False, False,  True,
         True],
       [ True,  True,  True,  True, False, False, False, False, False,
         True],
       [ True,  True,  True,  True, False, False, False, False, False,
         True],
       [False,  True,  True,  True,  True, False, False, False, False,
        False],
       [False, False, False,  True,  True,  True, False, False, False,
        False],
       [False, False,  True, False,  True,  True,  True,  True, False,
        False],
       [False, False, False,  True, False,  True,  True,  True, False,
        False],
       [ True, False, False, False, False,  True,  True,  True, False,
         True],
       [False,  True, False, False, False, False,  True,  True,  True,
         True],
       [ True, False,  True, False, False,  True, False,  True,  True,
         True]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/8dbec3352cedf208c06b5f9dc68a41e5523c6d056f7620987f2c80dd880314fa.png

brainpy.connect.SmallWorld#

SmallWorld is a connector class to help build a small-world network [1]. A small-world network is defined as a network where the typical distance L between two randomly chosen nodes (the number of steps required) grows proportionally to the logarithm of the number of nodes N in the network, that is:

\[ L\propto \log N \]

[1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, Nature, 393, pp. 440–442, 1998.

Currently, SmallWorld only supports a one-dimensional network with a ring structure. It receives four settings:

  • num_neighbor: the number of the nearest neighbors to connect.

  • prob: the probability of rewiring each edge.

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • include_self: whether a neuron is allowed to connect to itself.

conn = bp.connect.SmallWorld(num_neighbor=5, prob=0.2, directed=False, include_self=False)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False,  True,  True, False, False, False, False, False,  True,
         True],
       [ True, False,  True, False, False, False,  True, False, False,
         True],
       [ True,  True, False,  True,  True, False,  True, False, False,
        False],
       [False, False,  True, False,  True,  True,  True, False, False,
        False],
       [False, False,  True,  True, False, False,  True, False, False,
        False],
       [False, False, False,  True, False, False,  True,  True, False,
        False],
       [False,  True,  True,  True,  True,  True, False, False,  True,
        False],
       [False, False, False, False, False,  True, False, False,  True,
         True],
       [ True, False, False, False, False, False,  True,  True, False,
         True],
       [ True,  True, False, False, False, False, False,  True,  True,
        False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/53f25a302318e18bc96602c4a5694ebb0b81e1ace14b0e42bdf6129aedf0cdac.png

brainpy.connect.ScaleFreeBA#

ScaleFreeBA is a connector class to help build a random scale-free network according to the Barabási–Albert preferential attachment model [2]. ScaleFreeBA receives the following settings:

  • m: Number of edges to attach from a new node to existing nodes.

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed: Indicator of random number generation state.

[2] A. L. Barabási and R. Albert “Emergence of scaling in random networks”, Science 286, pp 509-512, 1999.

conn = bp.connect.ScaleFreeBA(m=5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False, False, False,  True,  True, False, False,
        False],
       [False, False, False, False, False,  True,  True,  True, False,
         True],
       [False, False, False, False, False,  True,  True,  True,  True,
        False],
       [False, False, False, False, False,  True,  True,  True,  True,
         True],
       [False, False, False, False, False,  True, False, False, False,
        False],
       [ True,  True,  True,  True,  True, False,  True,  True,  True,
         True],
       [ True,  True,  True,  True, False,  True, False,  True,  True,
         True],
       [False,  True,  True,  True, False,  True,  True, False,  True,
         True],
       [False, False,  True,  True, False,  True,  True,  True, False,
        False],
       [False,  True, False,  True, False,  True,  True,  True, False,
        False]], dtype=bool)
# Using NetworkX to visualize network connection
G = nx.from_numpy_matrix(conn.require('conn_mat'))
nx.draw(G, with_labels=True)
plt.show()
_images/bc4958a5d87d48af7bf66b8957c1d621b46ac61635ee9bf76f5ff12032c52706.png

brainpy.connect.ScaleFreeBADual#

ScaleFreeBADual is a connector class to help build a random scale-free network according to the dual Barabási–Albert preferential attachment model [3]. ScaleFreeBADual receives the following settings:

  • p: The probability of attaching \(m_1\) edges (as opposed to \(m_2\) edges).

  • m1 : Number of edges to attach from a new node to existing nodes with probability \(p\).

  • m2: Number of edges to attach from a new node to existing nodes with probability \(1-p\).

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed: Indicator of random number generation state.

[3] N. Moshiri. “The dual-Barabasi-Albert model”, arXiv:1810.10538.

conn = bp.connect.ScaleFreeBADual(m1=3, m2=5, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False, False, False,  True,  True,  True, False,
         True],
       [False, False, False, False, False,  True, False,  True,  True,
        False],
       [False, False, False, False, False,  True,  True,  True,  True,
         True],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [ True,  True,  True, False, False, False,  True,  True,  True,
         True],
       [ True, False,  True, False, False,  True, False,  True,  True,
         True],
       [ True,  True,  True, False, False,  True,  True, False,  True,
         True],
       [False,  True,  True, False, False,  True,  True,  True, False,
        False],
       [ True, False,  True, False, False,  True,  True,  True, False,
        False]], dtype=bool)

brainpy.connect.PowerLaw#

PowerLaw is a connector class to help build a random graph with a power-law degree distribution and approximate average clustering [4]. It receives the following settings:

  • m : the number of random edges to add for each new node

  • p : Probability of adding a triangle after adding a random edge

  • directed: whether the connection is directed (directed=True) or undirected (directed=False).

  • seed : Indicator of random number generation state.

[4] P. Holme and B. J. Kim, “Growing scale-free networks with tunable clustering”, Phys. Rev. E, 65, 026107, 2002.

conn = bp.connect.PowerLaw(m=3, p=0.5, directed=False, seed=12345)
conn(pre_size=10, post_size=10)
conn.require('conn_mat')
Array([[False, False, False,  True,  True, False,  True,  True,  True,
        False],
       [False, False, False,  True, False,  True, False, False, False,
        False],
       [False, False, False,  True,  True,  True,  True,  True, False,
         True],
       [ True,  True,  True, False,  True, False, False, False, False,
        False],
       [ True, False,  True,  True, False,  True, False, False,  True,
        False],
       [False,  True,  True, False,  True, False,  True,  True, False,
         True],
       [ True, False,  True, False, False,  True, False, False,  True,
        False],
       [ True, False,  True, False, False,  True, False, False, False,
         True],
       [ True, False, False, False,  True, False,  True, False, False,
        False],
       [False, False,  True, False, False,  True, False,  True, False,
        False]], dtype=bool)

Encapsulate your existing connections#

BrainPy also allows users to encapsulate existing connections with convenient class interfaces. Users can provide connection types as:

  • Index projection;

  • Dense matrix;

  • Sparse matrix.

Then users should provide pre_size and post_size information in order to instantiate the connection. In this way, based on the following connection classes, users can easily generate any other synaptic structures (such as pre2post, pre2syn, conn_mat, etc.).

bp.conn.IJConn#

Here, let's take a simple connection as an example. In this example, we create a connection from the user's own index projection by using bp.conn.IJConn.

pre_list = np.array([0, 1, 2])
post_list = np.array([0, 0, 0])
conn = bp.conn.IJConn(i=pre_list, j=post_list)
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
Array([[ True, False, False],
       [ True, False, False],
       [ True, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)
conn.requires('pre2post')
(Array([0, 0, 0], dtype=int32), Array([0, 1, 2, 3, 3, 3], dtype=int32))
conn.requires('pre2syn')
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2, 3, 3, 3], dtype=int32))

bp.conn.MatConn#

In the next example, we create a connection from the user's own dense connection matrix by using bp.conn.MatConn.

bp.math.random.seed(123)
conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_))
conn = conn(pre_size=5, post_size=3)
conn.requires('conn_mat')
Array([[False,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True],
       [False,  True,  True],
       [False, False,  True]], dtype=bool)
conn.requires('pre2post')
(Array([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 2], dtype=int32),
 Array([ 0,  2,  5,  8, 10, 11], dtype=int32))
conn.require('pre2syn')
(Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32),
 Array([ 0,  2,  5,  8, 10, 11], dtype=int32))

bp.conn.SparseMatConn#

In the last example, we create a connection from the user's own sparse connection matrix by using bp.conn.SparseMatConn.

from scipy.sparse import csr_matrix

conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
sparse_mat = csr_matrix(conn_mat)
conn = bp.conn.SparseMatConn(sparse_mat)
conn = conn(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1])
conn.requires('conn_mat')
Array([[ True, False,  True],
       [ True, False,  True],
       [ True, False,  True],
       [False,  True,  True],
       [ True,  True, False]], dtype=bool)
conn.requires('pre2post')
(Array([0, 2, 0, 2, 0, 2, 1, 2, 0, 1], dtype=int32),
 Array([ 0,  2,  4,  6,  8, 10], dtype=int32))
conn.requires('post2syn')
(Array([0, 2, 4, 8, 6, 9, 1, 3, 5, 7], dtype=int32),
 Array([ 0,  4,  6, 10], dtype=int32))

Using NetworkX to provide connections and pass into Connector#

NetworkX is a Python package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks.

Users can design their own complex network by using NetworkX.

import networkx as nx
G = nx.Graph()

By definition, a Graph is a collection of nodes (vertices) along with identified pairs of nodes (called edges, links, etc).

To learn more about NetworkX, please check the official documentation: NetworkX tutorial

Using class brainpy.connect.MatConn to construct connections is recommended here.

  • Dense adjacency matrix: a two-dimensional ndarray.

Here is an example illustrating how to transform a random graph into synaptic connections by using a dense adjacency matrix.

G = nx.fast_gnp_random_graph(5, 0.5) # initialize a random graph G
B = nx.adjacency_matrix(G)
A = np.array(nx.adjacency_matrix(G).todense()) # get dense adjacency matrix of G

print('dense adjacency matrix:')
print(A)
nx.draw(G, with_labels=True)
plt.show()
dense adjacency matrix:
[[0 1 1 0 1]
 [1 0 1 0 0]
 [1 1 0 1 0]
 [0 0 1 0 1]
 [1 0 0 1 0]]
_images/7a18a2d387a1002ef515773d222007f4770e7b54286459440faf40799532e7f7.png

Users can use the class MatConn, inherited from TwoEndConnector, to construct connections. A dense adjacency matrix should be passed in when initializing the MatConn class. Note that when calling the instance of the class, users should pass in two parameters: pre_size and post_size. In this case, users can use the shape of the dense adjacency matrix as the parameters.

conn = bp.connect.MatConn(A)(pre_size=A.shape[0], post_size=A.shape[1])
res = conn.require('conn_mat')

print(res)
Array([[False,  True,  True, False,  True],
       [ True, False,  True, False, False],
       [ True,  True, False,  True, False],
       [False, False,  True, False,  True],
       [ True, False, False,  True, False]], dtype=bool)

Customize your connections#

BrainPy allows users to customize their connections. The following requirements should be satisfied:

  • Your connection class should inherit from brainpy.connect.TwoEndConnector or brainpy.connect.OneEndConnector.

  • __init__ function should be implemented and essential parameters should be initialized.

  • Users should also override the build_csr(), build_coo() or build_mat() function to describe how to build the connection.

Let’s take an example to illustrate the details of customization.

class FixedProb(bp.connect.TwoEndConnector):
  """Connect the post-synaptic neurons with fixed probability.

  Parameters
  ----------
  prob : float
      The conn probability.
  include_self : bool
      Whether to create (i, i) connection.
  seed : optional, int
      Seed the random generator.
  """

  def __init__(self, prob, include_self=True, seed=None):
    super(FixedProb, self).__init__()
    assert 0. <= prob <= 1.
    self.prob = prob
    self.include_self = include_self
    self.seed = seed
    self.rng = np.random.RandomState(seed=seed)

  def build_csr(self):
    ind = []
    count = np.zeros(self.pre_num, dtype=np.uint32)

    def _random_prob_conn(rng, pre_i, num_post, prob, include_self):
      p = rng.random(num_post) <= prob
      if (not include_self) and pre_i < num_post:
        p[pre_i] = False
      conn_j = np.asarray(np.where(p)[0], dtype=np.uint32)
      return conn_j

    for i in range(self.pre_num):
      posts = _random_prob_conn(self.rng, pre_i=i, num_post=self.post_num,
                                prob=self.prob, include_self=self.include_self)
      ind.append(posts)
      count[i] = len(posts)

    ind = np.concatenate(ind)
    indptr = np.concatenate(([0], count)).cumsum()

    return ind, indptr

Then users can initialize their own connections as below:

conn = FixedProb(prob=0.5, include_self=True)(pre_size=5, post_size=5)
conn.require('conn_mat')
Array([[False,  True, False,  True,  True],
       [ True, False,  True,  True,  True],
       [False,  True, False, False, False],
       [ True, False, False,  True, False],
       [ True,  True, False, False, False]], dtype=bool)

Synaptic Weights#

@Xiaoyu Chen

In a brain model, synaptic weights, the strength of the connection between presynaptic and postsynaptic neurons, are crucial to the dynamics of the model. In this section, we will illustrate how to build synaptic weights in a synapse model.

import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt

bp.math.set_platform('cpu')
bp.__version__
'2.3.0'

Creating Static Weights#

Some computational models focus on the network structure and its influence on network dynamics, thus not modeling neural plasticity for simplicity. In this condition, synaptic weights are fixed and do not change in simulation. They can be stored as a scalar, a matrix or a vector depending on the connection strength and density.

1. Storing weights with a scalar#

If all synaptic weights are designed to be the same, the single weight value can be stored as a scalar in the synapse model to save memory space.

weight = 1.

The weight can be stored in a synapse model. When updating the synapse, this weight is assigned to all synapses by scalar multiplication.
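A minimal sketch of this idea (the presynaptic spike vector below is hypothetical):

weight = 1.
pre_spike = bm.asarray([0., 1., 0., 1.])  # hypothetical presynaptic spikes
post_input = weight * pre_spike           # the same scalar weight applies to every synapse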

2. Storing weights with a matrix#

When the synaptic connection is dense and the synapses are assigned with different weights, weights can be stored in a matrix \(W\), where \(W(i, j)\) refers to the weight of presynaptic neuron \(i\) to postsynaptic neuron \(j\).

BrainPy provides brainpy.initialize.Initializer (or brainpy.init for short) for weight initialization as a matrix. The tutorial of brainpy.init.Initializer is introduced later.

For example, a weight matrix can be constructed using brainpy.init.Uniform, which initializes weights with a random distribution:

pre_size = (4, 4)
post_size = (3, 3)

uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init((pre_size, post_size))
print('shape of weights: {}'.format(weights.shape))
shape of weights: (16, 9)

Then, the weights can be assigned to a group of connections with the same shape. For example, an all-to-all connection matrix can be obtained by brainpy.conn.All2All() whose tutorial is contained in Synaptic Connections:

conn = bp.conn.All2All()
conn(pre_size, post_size)
conn_mat = conn.requires('conn_mat')  # request the connection matrix

Therefore, weights[i, j] refers to the weight of connection (i, j).

i, j = (2, 3)
print('whether (i, j) is connected: {}'.format(conn_mat[i, j]))
print('synaptic weights of (i, j): {}'.format(weights[i, j]))
whether (i, j) is connected: True
synaptic weights of (i, j): 0.974509060382843

3. Storing weights with a vector#

When the synaptic connection is sparse, using a matrix to store synaptic weights is too wasteful. Instead, the weights can be stored in a vector which has the same length as the synaptic connections.

Weights can be assigned to the corresponding synapses as long as they are aligned with each other.

size = 5

conn = bp.conn.One2One()
conn(size, size)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')

print('presynaptic neuron ids: {}'.format(pre_ids))
print('postsynaptic neuron ids: {}'.format(post_ids))
print('synapse ids: {}'.format(bm.arange(size)))
presynaptic neuron ids: [0 1 2 3 4]
postsynaptic neuron ids: [0 1 2 3 4]
synapse ids: [0 1 2 3 4]

The weight vector is aligned with the synapse vector, i.e. the synapse ids:

weights = bm.random.uniform(0, 2, size)

for i in range(size):
    print('weight of synapse {}: {}'.format(i, weights[i]))
weight of synapse 0: 1.1543986797332764
weight of synapse 1: 1.7815501689910889
weight of synapse 2: 0.6559045314788818
weight of synapse 3: 0.48931145668029785
weight of synapse 4: 0.19386005401611328
Conversion from a weight matrix to a weight vector#

For users who would like to obtain the weight vector from the weight matrix, they can first build a connection according to the non-zero elements in the weight matrix and then slice the weight matrix according to the connection:

weight_mat = np.array([[1., 1.5, 0., 0.5], [0., 2.5, 0., 0.], [2., 0., 3, 0.]])
print('weight matrix: \n{}'.format(weight_mat))

conn = bp.conn.MatConn(weight_mat)
pre_ids, post_ids = conn.requires('pre_ids', 'post_ids')

weight_vec = weight_mat[pre_ids, post_ids]
print('weight_vector: \n{}'.format(weight_vec))
weight matrix: 
[[1.  1.5 0.  0.5]
 [0.  2.5 0.  0. ]
 [2.  0.  3.  0. ]]
weight_vector: 
[1.  1.5 0.5 2.5 2.  3. ]

Note

However, this approach is not recommended when the connection is sparse and large-scale, because generating the full weight matrix takes up too much space.

Creating Dynamic Weights#

Sometimes users may want to model neural plasticity in a brain model, which requires the synaptic weights to change during simulation. In this case, weights should be considered as variables and thus defined as brainpy.math.Variable. If they are packed in a synapse model, weight updating should be implemented in the update(_t, _dt) function of the synapse model.

weights = bm.Variable(bm.ones(10))
weights
Variable([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)
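The following sketch (illustrative only; the model name and the exponential-decay rule are made up, not a built-in BrainPy synapse) shows how dynamic weights can be packed into a model and modified in its update function:

class PlasticSyn(bp.DynamicalSystem):
    def __init__(self, num):
        super(PlasticSyn, self).__init__()
        self.w = bm.Variable(bm.ones(num))  # dynamic synaptic weights

    def update(self, _t, _dt):
        # any plasticity rule can be placed here; this sketch simply decays the weights in place
        self.w *= 0.99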

Built-in Weight Initializers#

Base Class: bp.init.Initializer#

The base class of weight initializers is brainpy.initialize.Initializer, which can be accessed by the shortcut bp.init. All initializers, built-in or customized, should inherit from the Initializer class.

Weight initialization is implemented in the __call__ function, so it is performed automatically when the initializer is called. The __call__ function takes a shape parameter, whose meaning differs between the following two superclasses, and returns a weight matrix.

Superclass 1: bp.init.InterLayerInitializer#

The InterLayerInitializer is an abstract subclass of Initializer. Its subclasses initialize the weights between two fully connected layers. The shape parameter of the __call__ function should be a 2-element tuple \((m, n)\), which refers to the number of presynaptic neurons \(m\) and of postsynaptic neurons \(n\). The output of the __call__ function is a bp.math.ndarray with the shape of \((m, n)\), where the value at \((i, j)\) is the initialized weight of presynaptic neuron \(i\) to postsynaptic neuron \(j\).

Superclass 2: bp.init.IntraLayerInitializer#

The IntraLayerInitializer is also an abstract subclass of Initializer. Its subclasses initialize the weights within a single layer. The shape parameter of the __call__ function refers to the structure of the neural population \((n_1, n_2, ..., n_d)\). The __call__ function returns a 2-D bp.math.ndarray with the shape of \((\prod_{k=1}^d n_k, \prod_{k=1}^d n_k)\). In the 2-D array, the value at \((i, j)\) is the initialized weight of neuron \(i\) to neuron \(j\) of the flattened neural sequence.

1. Built-In Regular Initializers#

Regular initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a regular pattern. The built-in regular initializers include ZeroInit, OneInit, and Identity. Here we show how to use the OneInit initializer. The remaining two classes are used in a similar way.

# visualization
def mat_visualize(matrix, cmap=plt.cm.get_cmap('coolwarm')):
    im = plt.matshow(matrix, cmap=cmap)
    plt.colorbar(mappable=im, shrink=0.8, aspect=15)
    plt.show()
# 'OneInit' initializes all the weights with the same value
shape = (5, 6)
one_init = bp.init.OneInit(value=2.5)
weights = one_init(shape)
print(weights)
Array([[2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
       [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
       [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
       [2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
       [2.5, 2.5, 2.5, 2.5, 2.5, 2.5]], dtype=float32)

2. Built-In Random Initializers#

Random initializers all belong to InterLayerInitializer and initialize the connection weights between two layers with a random distribution. The built-in random initializers include Normal, Uniform, Orthogonal, and others. Here we show how to use the Normal and Uniform initializers.

bp.init.Normal

This initializer initializes the weights with a normal distribution. The variance of the distribution changes according to the scale parameter. In the following example, 10 presynaptic neurons are fully connected to 20 postsynaptic neurons with random weight values:

shape = (10, 20)
normal_init = bp.init.Normal(scale=1.0)
weights = normal_init(shape)
mat_visualize(weights)
_images/db79961e274558d1c84097e389ec7905fbb8693c1c73671da6677e395be35f0b.png

bp.init.Uniform

This initializer resembles brainpy.init.Normal but initializes the weights with a uniform distribution.

uniform_init = bp.init.Uniform(min_val=0., max_val=1.)
weights = uniform_init(shape)
mat_visualize(weights)
_images/404ef715cacdf61c8e01eb0e5d0a6046df7500705178fd7631bc60f1b14c40c9.png

3. Built-In Decay Initializers#

Decay initializers all belong to IntraLayerInitializer and initialize the connection weights within a layer with a decay function of the neural distance. The built-in decay initializers include GaussianDecay and DOGDecay. Below are examples of how to use them.

brainpy.initialize.GaussianDecay

This initializer creates a Gaussian connectivity pattern within a population of neurons, where the weights decay with a Gaussian function. Specifically, for any pair of neurons \( (i, j) \), the weight is computed as

\[ w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) \]

where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).

The example below is a neural population with the size of \( 5 \times 5 \). Note that this shape is the structure of the target neural population, not the size of presynaptic and postsynaptic neurons.

size = (5, 5)
gaussian_init = bp.init.GaussianDecay(sigma=2., max_w=10., include_self=True)
weights = gaussian_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (25, 25)

Self-connections are created if include_self=True. The connection weights of neuron \(i\) with others are stored in row \(i\) of weights. For instance, the connection weights of neuron (1, 2) to other neurons are stored in weights[7] (\(5 \times 1 + 2 = 7\)). After reshaping, the weights are:

mat_visualize(weights[0].reshape(size), cmap=plt.cm.get_cmap('Reds'))
_images/7f74d8bc1ab7d0542930653c6236f7d067d260499c1fbb939003839517f23980.png

brainpy.initialize.DOGDecay

This initializer creates a Difference-Of-Gaussian (DOG) connectivity pattern within a population of neurons. Specifically, for the given pair of neurons \( (i, j) \), the weight between them is computed as

\[ w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}) - w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2}) \]

where \( v_k^i \) is the \( i \)-th neuron’s encoded value (position) at dimension \( k \).

The example below is a neural population with the size of \( 10 \times 12 \):

size = (10, 12)
dog_init = bp.init.DOGDecay(sigmas=(1., 3.), max_ws=(10., 5.), min_w=0.1, include_self=True)
weights = dog_init(size)
print('shape of weights: {}'.format(weights.shape))
shape of weights: (120, 120)

Weights smaller than min_w will not be created. If not specified, min_w defaults to 0.005 times the minimum of max_ws. The organization of weights is similar to that of the GaussianDecay initializer. For instance, the connection weights of neuron (3, 4) to other neurons, after reshaping, are shown below:

mat_visualize(weights[3*12+4].reshape(size), cmap=plt.cm.get_cmap('Reds'))
_images/1f82b70a5ed28ac9d71f0b2f1dddb5bcf158e2b3fcff2356b7568027b83f217f.png

Customizing your initializers#

BrainPy also allows users to customize their own weight initializers. When customizing an initializer, users should follow the instructions below:

  • Your initializer should inherit from brainpy.initialize.Initializer.

  • Override the __call__ function, to which the shape parameter should be given.

Here is an example of creating an inter-layer initializer that initializes the weights as follows:

\[ w(i, j) = max(w_{max} - \sigma |v_i - v_j|, 0) \]
class LinearDecay(bp.init.InterLayerInitializer):
    def __init__(self, max_w, sigma=1.):
        self.max_w = max_w  # maximum weight value
        self.sigma = sigma  # weight decrement per unit of distance

    def __call__(self, shape, dtype=None):
        mat = bp.math.zeros(shape, dtype=dtype)
        n_pre, n_post = shape
        seq = np.arange(n_pre)
        current_w = self.max_w

        # walk outwards from the diagonal: connections at distance i from the
        # diagonal receive the current weight, which decreases by sigma at
        # every step until it reaches zero
        for i in range(max(n_pre, n_post)):
            if current_w <= 0:
                break
            seq_plus = ((seq + i) >= 0) & ((seq + i) < n_post)
            seq_minus = ((seq - i) >= 0) & ((seq - i) < n_post)
            mat[seq[seq_plus], (seq + i)[seq_plus]] = current_w
            mat[seq[seq_minus], (seq - i)[seq_minus]] = current_w
            current_w -= self.sigma

        return mat
shape = (10, 15)
lin_init = LinearDecay(max_w=5, sigma=1.)
weights = lin_init(shape)
mat_visualize(weights, cmap=plt.cm.get_cmap('Reds'))
_images/a8f43de1bcd910467b54e01296a8179aee0a31263252ee8d5436b52a005b44dd.png

Note

Note that customized initializers, or brainpy.init.Initializer in general, are not limited to returning a matrix. Although currently all the built-in initializers use a matrix to store weights, they can also be designed to return a vector to store synaptic weights.

Gradient Descent Optimizers#

@Chaoming Wang @Xiaoyu Chen

Gradient descent is one of the most popular optimization methods. At present, gradient descent optimizers, combined with loss functions, are the key to machine learning, especially deep learning. In this section, we are going to understand:

  • how to use optimizers in BrainPy?

  • how to customize your own optimizer?

import brainpy as bp
import brainpy.math as bm

# bp.math.set_platform('cpu')
bp.__version__
'2.3.0'
import matplotlib.pyplot as plt

Optimizers in BrainPy#

The basic optimizer class in BrainPy is brainpy.optim.Optimizer, which includes the following optimizers:

  • SGD

  • Momentum

  • Nesterov momentum

  • Adagrad

  • Adadelta

  • RMSProp

  • Adam

All supported optimizers can be inspected through the brainpy.optim APIs.

Generally, an optimizer initialization receives the learning rate lr, the trainable variables train_vars, and other hyperparameters for the specific optimizer.

  • lr can be a float, or an instance of brainpy.optim.Scheduler.

  • train_vars should be a dict of Variable.

Here we launch an SGD optimizer.

a = bm.Variable(bm.ones((5, 4)))
b = bm.Variable(bm.zeros((3, 3)))

op = bp.optim.SGD(lr=0.001, train_vars={'a': a, 'b': b})

When you try to update the parameters, you must provide the corresponding gradients for each parameter in the update() method.

op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
a: Variable([[0.9993626 , 0.9997406 , 0.999853  , 0.999312  ],
          [0.9993036 , 0.99934477, 0.9998294 , 0.9997739 ],
          [0.99900717, 0.9997449 , 0.99976104, 0.99953616],
          [0.9995185 , 0.99917144, 0.9990044 , 0.99914813],
          [0.9997468 , 0.9999408 , 0.99917686, 0.9999825 ]], dtype=float32)
b: Variable([[-0.00034196, -0.00046545, -0.00027317],
          [-0.00045028, -0.00076825, -0.00026088],
          [-0.0007135 , -0.00020507, -0.00073902]], dtype=float32)

You can process the gradients before applying them. For example, here we clip the gradients by the maximum L2-norm.

grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
{'a': Array([[0.6356058 , 0.10750175, 0.93578255, 0.2557603 ],
        [0.77525663, 0.8615701 , 0.35919654, 0.6861898 ],
        [0.9569112 , 0.98981357, 0.3033744 , 0.62852013],
        [0.36589646, 0.86694443, 0.6335902 , 0.44947362],
        [0.01782513, 0.11465573, 0.5505476 , 0.56196713]], dtype=float32),
 'b': Array([[0.2326113 , 0.14437485, 0.6543677 ],
        [0.46068823, 0.9811108 , 0.30460846],
        [0.261765  , 0.71705794, 0.6173099 ]], dtype=float32)}
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
{'a': Array([[0.22753015, 0.0384828 , 0.33498552, 0.09155546],
        [0.2775215 , 0.30841944, 0.12858291, 0.24563788],
        [0.34254903, 0.3543272 , 0.10860006, 0.22499368],
        [0.13098131, 0.3103433 , 0.22680864, 0.16089973],
        [0.00638093, 0.04104374, 0.19708155, 0.20116945]], dtype=float32),
 'b': Array([[0.14066657, 0.08730751, 0.39571446],
        [0.27859107, 0.5933052 , 0.18420528],
        [0.15829663, 0.433625  , 0.3733046 ]], dtype=float32)}
op.update(grads_post)

print('a:', a)
print('b:', b)
a: Variable([[0.9991351 , 0.9997021 , 0.99951804, 0.99922043],
          [0.99902606, 0.9990364 , 0.99970084, 0.9995283 ],
          [0.9986646 , 0.99939054, 0.99965245, 0.99931115],
          [0.9993875 , 0.9988611 , 0.9987776 , 0.99898726],
          [0.9997404 , 0.99989974, 0.9989798 , 0.9997813 ]], dtype=float32)
b: Variable([[-0.00048263, -0.00055276, -0.00066889],
          [-0.00072887, -0.00136155, -0.00044508],
          [-0.00087179, -0.0006387 , -0.00111233]], dtype=float32)

Note

Optimizers usually have their own dynamically changed variables. If you JIT a function whose logic contains an optimizer update, the dyn_vars in bm.jit() should include the variables in Optimizer.vars().

op.vars()  # the SGD optimizer only has an iterable `step` variable to record the training step
{'Constant0.step': Variable([2], dtype=int32)}
bp.optim.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
{'Momentum0.a_v': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Momentum0.b_v': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32),
 'Constant1.step': Variable([0], dtype=int32)}
bp.optim.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
{'Adam0.a_m': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Adam0.b_m': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32),
 'Adam0.a_v': Variable([[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32),
 'Adam0.b_v': Variable([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32),
 'Constant2.step': Variable([0], dtype=int32)}
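Following the note above, a JIT-compiled training step that calls op.update() might look like the sketch below. The gradients passed in are placeholders (a real workflow would obtain them from a loss function and bm.grad), and passing a plain dict both to dyn_vars and as the function argument is assumed to be supported here:

def train_step(grads):
    op.update(grads)

# both the trainable variables and the optimizer's own variables go into dyn_vars
train_step = bm.jit(train_step, dyn_vars={'a': a, 'b': b, **op.vars()})
train_step({'a': bm.ones(a.shape), 'b': bm.ones(b.shape)})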

Creating A Self-Customized Optimizer#

To create your own optimization algorithm, simply inherit from the bp.optim.Optimizer class and override the following methods:

  • __init__(): the init function, which receives the learning rate (lr) and the trainable variables (train_vars). Do not forget to register your dynamically changed variables into implicit_vars.

  • update(grads): update function that computes the updated parameters.

The general structure is shown below:

class CustomizeOp(bp.optim.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        
        # customize your initialization
        
    def update(self, grads):
        # customize your update logic
        pass
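As a concrete (and deliberately simple) example of this structure, the sketch below implements plain SGD with an extra gradient-scaling factor. Keeping our own reference to train_vars is an illustrative assumption rather than the definitive base-class API; reading the current learning rate via self.lr() follows the scheduler interface described below:

class ScaledSGD(bp.optim.Optimizer):
    def __init__(self, lr, train_vars, scale=1.0):
        super(ScaledSGD, self).__init__(lr, train_vars)
        self.scale = scale
        self._train_vars = dict(train_vars)  # keep our own handle on the variables (assumption)

    def update(self, grads):
        lr = self.lr()  # current learning rate value from the scheduler
        for key, grad in grads.items():
            # plain gradient descent step with an extra scaling factor
            self._train_vars[key] -= lr * self.scale * grad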

Schedulers#

A scheduler seeks to adjust the learning rate during training by reducing it according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay, and exponential decay.

Here we set up an exponential decay scheduler, in which the learning rate decays exponentially over training steps.

sc = bp.optim.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
_images/d4c6e12211f0bb94f7f05195c0bb03fa1b2d186c95e208f8959d7281d009042a.png

After optimizer initialization, the learning rate self.lr will always be an instance of bp.optim.Scheduler. Initializing with a scalar float learning rate results in a Constant scheduler.

op.lr
Constant(0.001)

One can get the current learning rate value by calling Scheduler.__call__(i=None).

  • If i is not provided, the learning rate value will be evaluated at the built-in training step.

  • Otherwise, the learning rate value will be evaluated at the given step i.

op.lr()
0.001

In BrainPy, several commonly used learning rate schedulers are provided:

  • Constant

  • ExponentialDecay

  • InverseTimeDecay

  • PolynomialDecay

  • PiecewiseConstant

For more details, please see the brainpy.optim APIs.

# InverseTimeDecay scheduler

rates = bp.optim.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)
_images/e4e5977b81108bffea5108caed258da31cb0243fdff98d530cea6e070365787c.png
# PolynomialDecay scheduler

rates = bp.optim.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
_images/6741c56d6b37a41cb910962467af3413fc2dbfe1a4846f3692fc772ef6576d14.png

Creating a Self-Customized Scheduler#

To implement your own scheduler, simply inherit from the bp.optim.Scheduler class and override the following methods:

  • __init__(): the init function.

  • __call__(i=None): the learning rate value evaluation.

class CustomizeScheduler(bp.optim.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        
        # customize your initialization
        
    def __call__(self, i=None):
        # customize your update logic
        pass
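For instance, a simple step-decay scheduler could look like the sketch below. It assumes the base class stores the initial learning rate as self.lr, and it evaluates at step 0 when i is not given (the real base class would instead use its built-in step counter):

class StepDecay(bp.optim.Scheduler):
    def __init__(self, lr, decay_every=100, decay_rate=0.5):
        super(StepDecay, self).__init__(lr)
        self.decay_every = decay_every  # number of steps between decays
        self.decay_rate = decay_rate    # multiplicative decay factor

    def __call__(self, i=None):
        i = 0 if i is None else i  # fall back to step 0 in this sketch
        return self.lr * self.decay_rate ** (i // self.decay_every)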

Saving and Loading#

@Chaoming Wang

Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.

import brainpy as bp

bp.math.set_platform('cpu')

Saving and loading variables#

Model saving and loading in BrainPy are implemented with .save_states() and .load_states() functions.

BrainPy supports saving and loading model variables with various Python standard file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

Here’s a simple example:

class EINet(bp.dyn.Network):
    def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        E = bp.models.LIF(num_exc, **pars, method=method)
        I = bp.models.LIF(num_inh, **pars, method=method)
        E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
        I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

        # synapses
        E2E = bp.models.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        E2I = bp.models.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        I2E = bp.models.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)
        I2I = bp.models.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)

        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
        
        
net = EINet()
import os
if not os.path.exists('./data'): 
    os.makedirs('./data')
# model saving

net.save_states('./data/net.h5')
# model loading

net.load_states('./data/net.h5')
  • The .save_states(filename, all_vars=None) function receives a string to specify the output file name. If all_vars is not provided, BrainPy will retrieve all variables in the model through the relative path.

  • The .load_states(filename, verbose, check_missing) function receives several arguments. The first is a string of the output file name. The second, verbose, specifies whether to report the loading progress. The final argument, check_missing, will warn about variables of the model that are missing in the output file.

# model loading with warning and checking

net.load_states('./data/net.h5', verbose=True)
WARNING:brainpy.base.io:There are variable states missed in ./data/net.h5. The missed variables are: ['ExpCOBA0.pre.V', 'ExpCOBA0.pre.input', 'ExpCOBA0.pre.refractory', 'ExpCOBA0.pre.spike', 'ExpCOBA0.pre.t_last_spike', 'ExpCOBA1.pre.V', 'ExpCOBA1.pre.input', 'ExpCOBA1.pre.refractory', 'ExpCOBA1.pre.spike', 'ExpCOBA1.pre.t_last_spike', 'ExpCOBA1.post.V', 'ExpCOBA1.post.input', 'ExpCOBA1.post.refractory', 'ExpCOBA1.post.spike', 'ExpCOBA1.post.t_last_spike', 'ExpCOBA2.pre.V', 'ExpCOBA2.pre.input', 'ExpCOBA2.pre.refractory', 'ExpCOBA2.pre.spike', 'ExpCOBA2.pre.t_last_spike', 'ExpCOBA2.post.V', 'ExpCOBA2.post.input', 'ExpCOBA2.post.refractory', 'ExpCOBA2.post.spike', 'ExpCOBA2.post.t_last_spike', 'ExpCOBA3.pre.V', 'ExpCOBA3.pre.input', 'ExpCOBA3.pre.refractory', 'ExpCOBA3.pre.spike', 'ExpCOBA3.pre.t_last_spike'].
Loading E.V ...
Loading E.input ...
Loading E.refractory ...
Loading E.spike ...
Loading E.t_last_spike ...
Loading ExpCOBA0.g ...
Loading ExpCOBA0.pre_spike.data ...
Loading ExpCOBA0.pre_spike.in_idx ...
Loading ExpCOBA0.pre_spike.out_idx ...
Loading ExpCOBA1.g ...
Loading ExpCOBA1.pre_spike.data ...
Loading ExpCOBA1.pre_spike.in_idx ...
Loading ExpCOBA1.pre_spike.out_idx ...
Loading ExpCOBA2.g ...
Loading ExpCOBA2.pre_spike.data ...
Loading ExpCOBA2.pre_spike.in_idx ...
Loading ExpCOBA2.pre_spike.out_idx ...
Loading ExpCOBA3.g ...
Loading ExpCOBA3.pre_spike.data ...
Loading ExpCOBA3.pre_spike.in_idx ...
Loading ExpCOBA3.pre_spike.out_idx ...
Loading I.V ...
Loading I.input ...
Loading I.refractory ...
Loading I.spike ...
Loading I.t_last_spike ...

Note

By default, the model variables are retrieved by the relative path. Relative path retrieval usually results in duplicate variables in the returned ArrayCollector. Therefore, there will always be missing keys when loading the variables.

Custom saving and loading#

You can easily write your own saving and loading functions, because all variables in the model can be collected through .vars(). Saving variables is just transforming these variables to numpy.ndarray and then storing them onto the disk. Similarly, to load variables, you just need to read the numpy arrays from the disk and then transform these arrays back into instances of Variable.

The only gotcha to pay attention to is to avoid saving duplicated variables.
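A minimal sketch of such a custom routine is given below. It stores the variables with pickle; the .unique() call (used here to drop duplicated variables) and the in-place v[:] = ... assignment are common BrainPy idioms, but please check them against your version:

import pickle
import brainpy.math as bm

def save_model(model, filename):
    # collect all variables, drop duplicates, and convert them to numpy arrays
    data = {k: bm.as_numpy(v) for k, v in model.vars().unique().items()}
    with open(filename, 'wb') as f:
        pickle.dump(data, f)

def load_model(model, filename):
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    for k, v in model.vars().unique().items():
        if k in data:
            v[:] = data[k]  # write the stored array back into the Variable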

Inputs Construction#

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about stimulus inputs.

import brainpy as bp
import brainpy.math as bm

bp.__version__
'2.3.0'

Input construction functions#

Like electrophysiological experiments, model simulation also needs various kinds of inputs. BrainPy provides several convenient input functions to help users construct input currents.

1. brainpy.inputs.section_input()#

brainpy.inputs.section_input() is an updated version of the previous brainpy.inputs.constant_input() function (see below).

Sometimes we need input currents with different values in different periods. For example, if you want an input that is 0 in the first 100 ms, 1 in the next 300 ms, and 0 again in the last 100 ms, you can define:

current1, duration = bp.inputs.section_input(values=[0, 1., 0.],
                                             durations=[100, 300, 100],
                                             return_length=True,
                                             dt=0.1)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Here values receives a list/array of the current values in each section and durations receives a list/array of the duration of each section. The function returns a tensor as the current, whose length equals the total duration\(/\mathrm{d}t\) (if not specified, \(\mathrm{d}t=0.1 \mathrm{ms}\)). We can visualize the current input by:

import numpy as np
import matplotlib.pyplot as plt

def show(current, duration, title):
    ts = np.arange(0, duration, bm.get_dt())
    plt.plot(ts, current)
    plt.title(title)
    plt.xlabel('Time [ms]')
    plt.ylabel('Current Value')
    plt.show()

show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]')
_images/208d5bde0f750d16d38196af4538ceaaac3af9739b8af85eadfb816309a9f4c5.png

2. brainpy.inputs.constant_input()#

brainpy.inputs.constant_input() function helps users to format constant currents in several periods.

We can generate the above input current with constant_input() by:

current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)])

Here each tuple in the list contains the value and the duration of the input in the corresponding section.

show(current2, duration, '[(0, 100), (1, 300), (0, 100)]')
_images/9ad3a4639d16ee12322c51d63a0776788505b8f9507b8e4bc61f722856a5f731.png

3. brainpy.inputs.spike_input()#

brainpy.inputs.spike_input() constructs an input containing a series of short-time spikes. It receives the following settings:

  • sp_times : The spike time-points. Must be an iterable object. For example, list, tuple, or arrays.

  • sp_lens : The length of each point-current, mimicking the spike durations. It can be a scalar float to specify a unified duration, or a list/tuple/array of time lengths with the same length as sp_times.

  • sp_sizes : The current sizes. It can be a scalar value, or a list/tuple/array of spike current sizes with the same length as sp_times.

  • duration : The total current duration.

  • dt : The time step precision. The default is None (will be initialized as the default dt step).

For example, if you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, where each spike lasts 1 ms and the average value for each spike is 0.5, then you can define the current by:

current3 = bp.inputs.spike_input(
    sp_times=[10, 20, 30, 200, 300],
    sp_lens=1.,  # can be a list to specify the spike length at each point
    sp_sizes=0.5,  # can be a list to specify the spike current size at each point
    duration=400.)

show(current3, 400, 'Spike Input Example')
_images/f0d8e5e26887fe91f5f69ce8927e4b5e3f2eac12112945dbcc11446b731a45bd.png

4. brainpy.inputs.ramp_input()#

brainpy.inputs.ramp_input() mimics a ramp or a step current to the input of the circuit. It receives the following settings:

  • c_start : The minimum (or maximum) current size.

  • c_end : The maximum (or minimum) current size.

  • duration : The total duration.

  • t_start : The ramped current start time-point.

  • t_end : The ramped current end time-point. Default is None.

  • dt : The current precision.

We illustrate the usage of brainpy.inputs.ramp_input() by two examples.

In the first example, we increase the current size from 0. to 1. between the start time (0 ms) and the end time (500 ms).

duration = 500
current4 = bp.inputs.ramp_input(0, 1, duration)

show(current4, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=0, $t_{end}$=None' % (duration))
_images/4bec14dd14fa47a413c6c2d85c3fd96329418930e49e74a62de8c7fb19c24e79.png

In the second example, we increase the current size from 0. to 1. from the 100 ms to 400 ms.

duration, t_start, t_end = 500, 100, 400
current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end)

show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, '
                        r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end))
_images/72b44977aec3c8f19c7322a98ed4431b8d48f6919d38ce214ffc09eb35c83331.png

5. brainpy.inputs.wiener_process#

brainpy.inputs.wiener_process() is used to generate the basic Wiener process \(dW\), i.e. random numbers drawn from a normal distribution with mean 0 and standard deviation \(\sqrt{dt}\).

duration = 200
current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.)
show(current6, duration, 'Wiener Process')
_images/3d4891e2eecb477b535c297c03e8c82053173524425dc444dd4637edfdc10918.png

6. brainpy.inputs.ou_process#

brainpy.inputs.ou_process() is used to generate a noise time series from the Ornstein-Uhlenbeck process \(dx = (\mu - x)/\tau \cdot dt + \sigma\cdot dW\).

duration = 200
current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.)
show(current7, duration, 'Ornstein-Uhlenbeck Process')
_images/8ec63e2412baa8157f4fa0199195a97eb48ce9a40a798041abd23ff6cde13688.png

7. brainpy.inputs.sinusoidal_input#

brainpy.inputs.sinusoidal_input() can help to generate sinusoidal inputs.

duration = 2000
current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration,  t_start=100., )
show(current8, duration, 'Sinusoidal Input')
_images/8b9feec0dad71469c049a2a6315034e3c5c208655d1881f51103da1d1f8e3a0d.png

8. brainpy.inputs.square_input#

brainpy.inputs.square_input() can help to generate oscillatory square inputs.

duration = 2000
current9 = bp.inputs.square_input(amplitude=1., frequency=2.0,
                                  duration=duration, t_start=100)
show(current9, duration, 'Square Input')
_images/d918d9eb1d0882c4deb1b86df4d7001b4908611abf5b257bbd35d29a6e21f83c.png

More complex inputs#

Because the current input is stored as a tensor, a complex input can be realized by combining several simple currents.

show(current1 + current5, 500, 'A Complex Current Input')
_images/c6a759f5ac00839d5bebbc371ab8f1289f1a6c32a4dc20436880f95b8c4e16db.png

General properties of input functions#

1. Every input function receives a dt specification.

If dt is not provided, input functions will use the default dt in the whole BrainPy system.

I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1)
I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01)
print('I1.shape: {}'.format(I1.shape))
print('I2.shape: {}'.format(I2.shape))
I1.shape: (600,)
I2.shape: (6000,)

2. All input functions can automatically broadcast the current shapes if they are heterogeneous among different periods.

For example, during period 1 we give an input with a scalar value, during period 2 we give an input with a vector shape, and during period 3 we give a matrix input value. Input functions will broadcast them to the maximum shape. For example:

current = bp.inputs.section_input(values=[0, bm.ones(10), bm.random.random((3, 10))],
                                  durations=[100, 300, 100])

current.shape
(5000, 3, 10)

How to cite BrainPy?#

If BrainPy has been significant in your research, and you would like to acknowledge the project in your academic publication, we suggest citing the following papers:

If you are using BrainPy 2.x, please use:

  • Chaoming Wang, Xiaoyu Chen, Tianqiu Zhang, Si Wu. BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming. bioRxiv 2022.10.28.514024; doi: https://doi.org/10.1101/2022.10.28.514024

@article {Wang2022brainpy,
    author = {Wang, Chaoming and Chen, Xiaoyu and Zhang, Tianqiu and Wu, Si},
    title = {BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming},
    elocation-id = {2022.10.28.514024},
    year = {2022},
    doi = {10.1101/2022.10.28.514024},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/early/2022/10/28/2022.10.28.514024},
    eprint = {https://www.biorxiv.org/content/early/2022/10/28/2022.10.28.514024.full.pdf},
    journal = {bioRxiv}
}

If you are using BrainPy 1.x, please use:

  • Wang, C., Jiang, Y., Liu, X., Lin, X., Zou, X., Ji, Z., & Wu, S. (2021, December). A Just-In-Time Compilation Approach for Neural Dynamics Simulation. In International Conference on Neural Information Processing (pp. 15-26). Springer, Cham.

@inproceedings{wang2021just,
  title={A Just-In-Time Compilation Approach for Neural Dynamics Simulation},
  author={Wang, Chaoming and Jiang, Yingqian and Liu, Xinyu and Lin, Xiaohan and Zou, Xiaolong and Ji, Zilong and Wu, Si},
  booktitle={International Conference on Neural Information Processing},
  pages={15--26},
  year={2021},
  organization={Springer}
}

How is brainpy different from other frameworks?#

@Chaoming Wang @Xiaoyu Chen

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

bp.__version__

BrainPy vs Brian2/NEST/NEURON …#

Different from traditional brain simulators (most of which employ a descriptive language for programming brain dynamics models), BrainPy aims to provide full support for brain dynamics modeling.

Currently, brain dynamics modeling goes far beyond simulation. There are many new modeling approaches that take inspiration from the machine learning community. Moreover, brain dynamics modeling has also inspired new developments in brain-inspired computation.

These new advances cannot be captured by traditional brain simulators. Therefore, BrainPy aims to provide an ecosystem for brain dynamics modeling, in which users can build various models easily and flexibly, and extend new modeling approaches conveniently.

The core idea behind BrainPy is the Just-In-Time (JIT) compilation. JIT compilation enables your Python code to be compiled into machine code “just-in-time” for execution. Subsequently, such transformed code can run at native machine code speed!

Based on this, BrainPy provides an integrated platform for brain dynamics modeling, including

  • universal model building

  • dynamics simulation

  • dynamics training

  • dynamics analysis

Such an integrative framework can help users study brain dynamics comprehensively.

BrainPy vs JAX/Numba#

BrainPy relies on JAX and Numba. But it also has important aspects which are different from them.

JAX and Numba are excellent JIT compilers in Python. However, they are designed to work only on pure Python functions. Most computational neuroscience models have too many parameters and variables to manage using functions only. Therefore, BrainPy provides an object-oriented programming interface for brain dynamics modeling. These object-oriented transformations are implemented in the brainpy.math module.

brainpy.math is not intended to be a reimplementation of the API of any other frameworks. All we are trying to do is to make a better brain dynamics programming framework for Python users.

There are important differences between brainpy.math and JAX and JAX-related frameworks.

Specifically, brainpy.math provides:

  1. Numpy-like ndarray.

Python users are familiar with NumPy, especially its ndarray. JAX has similar ndarray structures and operations. However, several basic features are fundamentally different from numpy ndarray. For example, JAX ndarray does not support in-place mutating updates, like x[i] += y. To overcome these drawbacks, brainpy.math provides Array that can be used in the same way as numpy ndarray.

b = bm.arange(5)
b
Array([0, 1, 2, 3, 4], dtype=int32)
b[0] += 5
b
Array([5, 1, 2, 3, 4], dtype=int32)
  2. Numpy-like random sampling.

JAX has its own style to make random numbers, which is very different from the original NumPy. To provide a consistent experience, brainpy.math provides brainpy.math.random for random sampling just like the numpy.random module. For example:

# random sampling in "brainpy.math.random"

bm.random.seed(12345)
bm.random.random(5)
Array([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ],            dtype=float32)
bm.random.normal(0., 2., 5)
Array([-1.5375282, -0.5970201, -2.272839 ,  3.233081 , -0.2738593],            dtype=float32)

For more details, please see the Arrays tutorial.

  3. JAX transformations on class objects.

OOP is the essence of Python. However, JAX's excellent transformations (like JIT compilation) only support pure functions. To make them work on object-oriented coding in brain dynamics programming, brainpy.math extends JAX transformations to Python classes. For details, please see BrainPy Concept of Object-oriented Transformation.
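As a tiny illustration (the Counter class below is made up for this example; it is not part of BrainPy), a stateful object can be JIT-compiled by declaring its variables:

class Counter(bm.BrainPyObject):
    def __init__(self):
        super(Counter, self).__init__()
        self.count = bm.Variable(bm.zeros(1))

    def add(self, x):
        self.count += x  # in-place update of a Variable

c = Counter()
jitted_add = bm.jit(c.add, dyn_vars=c.vars())
jitted_add(2.)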

BrainPy Ecosystem for Brain Dynamics Modeling#

BrainPy aims to build a complete ecosystem for brain dynamics modeling.

Although there is still a long way to go, we have currently made progress on:

BrainPy#

Based on JAX, BrainPy provides a universal simulation, training, and analysis engine, which serves as the foundation of the whole project. Specifically, BrainPy provides an object-oriented programming interface for brain dynamics modeling.

brainpy-examples#

brainpy-examples provides comprehensive examples for brain dynamics modeling with BrainPy. It implements many classical models introduced in the latest computational neuroscience and brain-inspired computation research.

brainpylib#

brainpylib aims to provide operators specialized for brain dynamics modeling. Brain dynamics features sparse connections and event-driven computation, and brainpylib provides dedicated operators for such sparse and event-based computation. These operators can be used in computational neuroscience research as well as in the brain-inspired computation community.

brainpy-datasets#

brainpy-datasets aims to provide commonly used datasets in brain dynamics modeling, including neuromorphic datasets and cognitive tasks for training brain-like neural networks.

brainpy-largescale#

brainpy-largescale provides one solution for large-scale modeling. It enables multi-device running for BrainPy models.

brainpy module#

Numerical Differential Integration#

JointEq(*eqs)

Make a joint equation from multiple derivation functions.

IntegratorRunner(target[, inits, dt, ...])

Structural runner for numerical integrators in brainpy.

odeint([f, method, var_type, dt, name, ...])

Numerical integration for ODEs.

sdeint([f, g, method, dt, name, show_code, ...])

Numerical integration for SDEs.

fdeint(alpha, num_memory, inits[, f, ...])

Numerical integration for FDEs.

Building Dynamical System#

DynamicalSystem([name, mode])

Base Dynamical System class.

Container(*dynamical_systems_as_tuple[, ...])

Container object which is designed to add other instances of DynamicalSystem.

Sequential(*modules_as_tuple[, name, mode])

A sequential input-output module.

Network(*ds_tuple[, name, mode])

Base class to model network objects, an alias of Container.

NeuGroup(size[, keep_size, name, mode])

Base class to model neuronal groups.

SynConn(pre, post[, conn, name, mode])

Base class to model two-end synaptic connections.

SynOut([name, target_var])

Base class for synaptic current output.

SynSTP(*args, **kwargs)

Base class for synaptic short-term plasticity.

SynLTP(*args, **kwargs)

Base class for synaptic long-term plasticity.

TwoEndConn(pre, post[, conn, output, stp, ...])

Base class to model synaptic connections.

CondNeuGroup(size[, keep_size, C, A, V_th, ...])

Base class to model conductance-based neuron group.

Channel(size[, name, keep_size, mode])

Abstract channel class.

Simulating Dynamical System#

DSRunner(target[, inputs, monitors, ...])

The runner for DynamicalSystem.

Training Dynamical System#

DSTrainer(target, **kwargs)

Structural Trainer for Dynamical Systems.

BPTT(target, loss_fun[, optimizer, ...])

The trainer implementing the back-propagation through time (BPTT) algorithm for training dynamical systems.

BPFF(target, loss_fun[, optimizer, ...])

The trainer implementing back propagation algorithm for feedforward neural networks.

OnlineTrainer(target[, fit_method])

Online trainer for models with recurrent dynamics.

ForceTrainer(target[, alpha])

FORCE learning.

OfflineTrainer(target[, fit_method])

Offline trainer for models with recurrent dynamics.

RidgeTrainer(target[, alpha])

Trainer of ridge regression, also known as regression with Tikhonov regularization.

Dynamical System Helpers#

DSPartial(target, *args[, child_objs, ...])

NoSharedArg(target[, name])

Transform an instance of DynamicalSystem into a callable BrainPyObject \(y=f(x)\).

LoopOverTime(target[, out_vars, no_state, name])

Transform a single step DynamicalSystem into a multiple-step forward propagation BrainPyObject.

brainpy.math module#

Basis for Object-oriented Transformations#

BrainPyObject([name])

The BrainPyObject class for whole BrainPy ecosystem.

FunAsObject(target[, child_objs, dyn_vars, name])

Transform a Python function as a BrainPyObject.

dyn_seq([iterable])

A list to represent a dynamically changed numerical sequence whose elements can be changed during JIT compilation.

dyn_dict

A dict to represent a dynamically changed numerical dictionary whose elements can be changed during JIT compilation.

Variable(value_or_size[, dtype, batch_axis])

The pointer to specify the dynamical variable.

Parameter(value_or_size[, dtype, batch_axis])

The pointer to specify the parameter.

TrainVar(value_or_size[, dtype, batch_axis])

The pointer to specify the trainable variable.

Partial(fun, *args[, child_objs, dyn_vars])

Object-oriented Transformations#

grad(func[, grad_vars, dyn_vars, ...])

Automatic gradient computation for functions or class objects.

vector_grad(func[, grad_vars, dyn_vars, ...])

Take vector-valued gradients for function func.

jacobian(func[, grad_vars, dyn_vars, ...])

Extending automatic Jacobian (reverse-mode) of func to classes.

jacrev(func[, grad_vars, dyn_vars, ...])

Extending automatic Jacobian (reverse-mode) of func to classes.

jacfwd(func[, grad_vars, dyn_vars, ...])

Extending automatic Jacobian (forward-mode) of func to classes.

hessian(func[, grad_vars, dyn_vars, ...])

Hessian of func as a dense array.

make_loop(body_fun, dyn_vars[, out_vars, ...])

Make a for-loop function, which iterate over inputs.

make_while(cond_fun, body_fun, dyn_vars)

Make a while-loop function.

make_cond(true_fun, false_fun[, dyn_vars])

Make a condition (if-else) function.

cond(pred, true_fun, false_fun, operands[, ...])

Simple conditional statement (if-else) with instance of Variable.

ifelse(conditions, branches[, operands, ...])

If-else control flows looks like native Pythonic programming.

for_loop(body_fun, operands[, dyn_vars, ...])

for-loop control flow with Variable.

while_loop(body_fun, cond_fun, operands[, ...])

while-loop control flow with Variable.

to_object([f, child_objs, dyn_vars, name])

Transform a Python function to BrainPyObject.

to_dynsys([f, child_objs, dyn_vars, name])

Transform a Python function to a DynamicalSystem.

function([f, nodes, dyn_vars, name])

Transform a Python function into a BrainPyObject.

jit(func[, dyn_vars, child_objs, ...])

JIT (Just-In-Time) compilation for class objects.

ObjectTransform([name])

Object-oriented JAX transformation for BrainPy computation.

Brain Dynamics Dedicated Operators#

pre2post_sum(pre_values, post_num, post_ids)

The pre-to-post synaptic summation.

pre2post_prod(pre_values, post_num, post_ids)

The pre-to-post synaptic production.

pre2post_max(pre_values, post_num, post_ids)

The pre-to-post synaptic maximization.

pre2post_min(pre_values, post_num, post_ids)

The pre-to-post synaptic minimization.

pre2post_mean(pre_values, post_num, post_ids)

The pre-to-post synaptic mean computation.

pre2post_event_sum(events, pre2post, post_num)

The pre-to-post event-driven synaptic summation with CSR synapse structure.

pre2post_coo_event_sum(events, pre_ids, ...)

The pre-to-post synaptic computation with event-driven summation.

pre2post_event_prod(events, pre2post, post_num)

The pre-to-post synaptic computation with event-driven production.

pre2syn(pre_values, pre_ids)

The pre-to-syn computation.

syn2post_sum(syn_values, post_ids, post_num)

The syn-to-post summation computation.

syn2post(syn_values, post_ids, post_num[, ...])

The syn-to-post summation computation.

syn2post_prod(syn_values, post_ids, post_num)

The syn-to-post product computation.

syn2post_max(syn_values, post_ids, post_num)

The syn-to-post maximum computation.

syn2post_min(syn_values, post_ids, post_num)

The syn-to-post minimization computation.

syn2post_mean(syn_values, post_ids, post_num)

The syn-to-post mean computation.

syn2post_softmax(syn_values, post_ids, post_num)

The syn-to-post softmax computation.

sparse_matmul(A, B)

Sparse matrix multiplication.

csr_matvec(values, indices, indptr, vector, ...)

Product of CSR sparse matrix and a dense vector.

event_csr_matvec(values, indices, indptr, ...)

The pre-to-post event-driven synaptic summation with CSR synapse structure.

segment_sum(data, segment_ids[, ...])

segment_sum operator for brainpy Array and Variable.

segment_prod(data, segment_ids[, ...])

segment_prod operator for brainpy Array and Variable.

segment_max(data, segment_ids[, ...])

segment_max operator for brainpy Array and Variable.

segment_min(data, segment_ids[, ...])

segment_min operator for brainpy Array and Variable.

XLACustomOp([eval_shape, con_compute, ...])

Creating a XLA custom call operator.

Activation Functions#

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU activation function

hard_swish(x)

Hard SiLU activation function

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis])

Log-Softmax function.

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.

normalize(x[, axis, mean, variance, epsilon])

Normalizes an array by subtracting mean and dividing by sqrt(var).

relu(x)

relu6(x)

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis])

Softmax function.

softplus(x)

Softplus activation function.

silu(x)

SiLU activation function.

swish(x)

SiLU activation function.

selu(x)

Scaled exponential linear unit activation.

identity(n[, dtype])

Return the identity array.

tanh(x)

Compute hyperbolic tangent element-wise.

Array Operations#

flatten(input[, start_dim, end_dim])

Flattens input by reshaping it into a one-dimensional tensor.

fill_diagonal(a, val)

remove_diag(arr)

Remove the diagonal of the matrix.

clip_by_norm(t, clip_norm[, axis])

empty(shape[, dtype])

Return a new array of given shape and type, without initializing entries.

empty_like(prototype[, dtype, shape])

Return a new array with the same shape and type as a given array.

ones(shape[, dtype])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape])

Return an array of ones with the same shape and type as a given array.

zeros(shape[, dtype])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape])

Return an array of zeros with the same shape and type as a given array.

array(object[, dtype, copy, order, ndmin])

Create an array.

asarray(a[, dtype, order])

Convert the input to an array.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

linspace(start, stop[, num, endpoint, ...])

Return evenly spaced numbers over a specified interval.

logspace(start, stop[, num, endpoint, base, ...])

Return numbers spaced evenly on a log scale.

as_device_array(tensor[, dtype])

Convert the input to a jax.numpy.DeviceArray.

as_jax(tensor[, dtype])

Convert the input to a jax.numpy.DeviceArray.

as_ndarray(tensor[, dtype])

Convert the input to a numpy.ndarray.

as_numpy(tensor[, dtype])

Convert the input to a numpy.ndarray.

as_variable(tensor[, dtype])

Convert the input to a brainpy.math.Variable.

Delay Variables#

TimeDelay(delay_target, delay_len[, ...])

Delay variable which has a fixed delay time length.

LengthDelay(delay_target, delay_len[, ...])

Delay variable which has a fixed delay length.

NeuTimeDelay(delay_target, delay_len[, ...])

Neutral Time Delay.

NeuLenDelay(delay_target, delay_len[, ...])

Neutral Length Delay.

Environment Settings#

set_float(dtype)

Set global default float type.

get_float()

Get the default float data type.

set_int(dtype)

Set global default integer type.

get_int()

Get the default int data type.

set_bool(dtype)

Set global default boolean type.

get_bool()

Get the default boolean data type.

set_complex(dtype)

Set global default complex type.

get_complex()

Get the default complex data type.

set_dt(dt)

Set the default numerical integrator precision.

get_dt()

Get the numerical integrator precision.

set_mode(mode)

Set the default computing mode.

get_mode()

Get the default computing mode.

set_environment([mode, dt, x64, complex_, ...])

Set the default computation environment.

enable_x64()

disable_x64()

set_platform(platform)

Changes platform to CPU, GPU, or TPU.

get_platform()

Get the computing platform.

set_host_device_count(n)

By default, XLA considers all CPU cores as one device.

clear_buffer_memory([platform])

Clear all on-device buffers.

enable_gpu_memory_preallocation()

Enable pre-allocating the GPU memory.

disable_gpu_memory_preallocation()

Disable pre-allocating the GPU memory.

ditype()

Default int type.

dftype()

Default float type.

environment([mode, dt, x64, complex_, ...])

Context-manager that sets a computing environment for brain dynamics computation.

batching_environment([dt, x64, complex_, ...])

Environment with the batching mode.

training_environment([dt, x64, complex_, ...])

Environment with the training mode.

Computing Modes#

Mode()

Base class for computation Mode

NonBatchingMode()

Normal non-batching mode.

BatchingMode()

Batching mode.

TrainingMode()

Training mode requires data batching.

nonbatching_mode

Normal non-batching mode.

batching_mode

Batching mode.

training_mode

Training mode requires data batching.

brainpy.math.random module#

seed([seed])

split_key()

default_rng([seed_or_key, clone])

rtype

RandomState

rand(*dn[, key])

randint(low[, high, size, dtype, key])

random_integers(low[, high, size, key])

randn(*dn[, key])

random([size, key])

random_sample([size, key])

ranf([size, key])

sample([size, key])

choice(a[, size, replace, p, key])

permutation(x[, axis, independent, key])

shuffle(x[, axis, key])

beta(a, b[, size, key])

exponential([scale, size, key])

gamma(shape[, scale, size, key])

gumbel([loc, scale, size, key])

laplace([loc, scale, size, key])

logistic([loc, scale, size, key])

normal([loc, scale, size, key])

pareto(a[, size, key])

poisson([lam, size, key])

standard_cauchy([size, key])

standard_exponential([size, key])

standard_gamma(shape[, size, key])

standard_normal([size, key])

standard_t(df[, size, key])

uniform([low, high, size, key])

truncated_normal(lower, upper[, size, ...])

Sample truncated standard normal random values with given shape and dtype.

bernoulli([p, size, key])

Sample Bernoulli random values with given shape and mean.

lognormal([mean, sigma, size, key])

binomial(n, p[, size, key])

chisquare(df[, size, key])

dirichlet(alpha[, size, key])

geometric(p[, size, key])

f(dfnum, dfden[, size, key])

hypergeometric(ngood, nbad, nsample[, size, key])

logseries(p[, size, key])

multinomial(n, pvals[, size, key])

multivariate_normal(mean, cov[, size, ...])

negative_binomial(n, p[, size, key])

noncentral_chisquare(df, nonc[, size, key])

noncentral_f(dfnum, dfden, nonc[, size, key])

power(a[, size, key])

rayleigh([scale, size, key])

triangular([size, key])

vonmises(mu, kappa[, size, key])

wald(mean, scale[, size, key])

weibull(a[, size, key])

weibull_min(a[, scale, size, key])

zipf(a[, size, key])

maxwell([size, key])

t(df[, size, key])

Sample Student’s t random values.

orthogonal(n[, size, key])

Sample uniformly from the orthogonal group O(n).

loggamma(a[, size, key])

Sample log-gamma random values.

categorical(logits[, axis, size, key])

RandomState([seed_or_key, seed])

RandomState that tracks the random generator state.

Generator

alias of brainpy._src.math.random.RandomState

DEFAULT

RandomState that tracks the random generator state.

brainpy.math.surrogate module#

sigmoid

Spike function with the sigmoid-shaped surrogate gradient.

piecewise_quadratic

Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.

piecewise_exp

Judge spiking state with a piecewise exponential function [1]_.

soft_sign

Judge spiking state with a soft sign function.

arctan

Judge spiking state with an arctan function.

nonzero_sign_log

Judge spiking state with a nonzero sign log function.

erf

Judge spiking state with an erf function [1]_ [2]_ [3]_.

piecewise_leaky_relu

Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.

squarewave_fourier_series

Judge spiking state with a squarewave fourier series.

s2nn

Judge spiking state with the S2NN surrogate spiking function [1]_.

q_pseudo_spike

Judge spiking state with the q-PseudoSpike surrogate function [1]_.

leaky_relu

Judge spiking state with the Leaky ReLU function.

log_tailed_relu

Judge spiking state with the Log-tailed ReLU function [1]_.

relu_grad

Spike function with the ReLU gradient function [1]_.

gaussian_grad

Spike function with the Gaussian gradient function [1]_.

inv_square_grad

Spike function with the inverse-square surrogate gradient.

multi_gaussian_grad

Spike function with the multi-Gaussian gradient function [1]_.

slayer_grad

Spike function with the slayer surrogate gradient function.

inv_square_grad2

relu_grad2

brainpy.channels module#

Basic Channel Classes#

Ion(size[, name, keep_size, mode])

Base class for ions.

IonChannel(size[, name, keep_size, mode])

Base class for ion channels.

Calcium(size[, keep_size, method, name, mode])

The brainpy_object calcium dynamics.

IhChannel(size[, name, keep_size, mode])

Base class for Ih channel models.

CalciumChannel(size[, name, keep_size, mode])

Base class for Calcium ion channels.

SodiumChannel(size[, name, keep_size, mode])

Base class for sodium channel.

PotassiumChannel(size[, name, keep_size, mode])

Base class for potassium channel.

LeakyChannel(size[, name, keep_size, mode])

Base class for leaky channel.

Voltage-dependent Sodium Channel Models#

INa_Ba2002(size[, keep_size, T, E, g_max, ...])

The sodium current model.

INa_TM1991(size[, keep_size, E, g_max, phi, ...])

The sodium current model described by (Traub and Miles, 1991) [1]_.

INa_HH1952(size[, keep_size, E, g_max, phi, ...])

The sodium current model described by Hodgkin–Huxley model [1]_.

Voltage-dependent Potassium Channel Models#

IKDR_Ba2002(size[, keep_size, E, g_max, ...])

The delayed rectifier potassium channel current.

IK_TM1991(size[, keep_size, E, g_max, phi, ...])

The potassium channel described by (Traub and Miles, 1991) [1]_.

IK_HH1952(size[, keep_size, E, g_max, phi, ...])

The potassium channel described by Hodgkin–Huxley model [1]_.

IKA1_HM1992(size[, keep_size, E, g_max, ...])

The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_.

IKA2_HM1992(size[, keep_size, E, g_max, ...])

The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_.

IKK2A_HM1992(size[, keep_size, E, g_max, ...])

The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_.

IKK2B_HM1992(size[, keep_size, E, g_max, ...])

The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_.

IKNI_Ya1989(size[, keep_size, E, g_max, ...])

A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.

Voltage-dependent Calcium Channel Models#

CalciumFixed(size[, keep_size, E, C, ...])

Fixed Calcium dynamics.

CalciumDyna(size[, keep_size, C0, T, ...])

Calcium ion flow with dynamics.

CalciumDetailed(size[, keep_size, T, d, ...])

Dynamical Calcium model proposed.

CalciumFirstOrder(size[, keep_size, T, ...])

The first-order calcium concentration model.

ICaN_IS2008(size[, keep_size, E, g_max, ...])

The calcium-activated non-selective cation channel model proposed by (Inoue & Strowbridge, 2008) [2]_.

ICaT_HM1992(size[, keep_size, T, T_base_p, ...])

The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_.

ICaT_HP1992(size[, keep_size, T, T_base_p, ...])

The low-threshold T-type calcium current model for thalamic reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_.

ICaHT_HM1992(size[, keep_size, T, T_base_p, ...])

The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_.

ICaL_IS2008(size[, keep_size, T, T_base_p, ...])

The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_.

Calcium-dependent Potassium Channel Models#

IAHP_De1994(size[, keep_size, E, n, g_max, ...])

The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_.

Hyperpolarization-activated Cation Channel Models#

Ih_HM1992(size[, keep_size, g_max, E, phi, ...])

The hyperpolarization-activated cation current model proposed by (Huguenard & McCormick, 1992) [1]_.

Ih_De1996(size[, keep_size, E, k2, k4, ...])

The hyperpolarization-activated cation current model proposed by (Destexhe, et al., 1996) [1]_.

Leakage Channel Models#

IL(size[, keep_size, g_max, E, method, ...])

The leakage channel current.

IKL(size[, keep_size, g_max, E, method, ...])

The potassium leak channel current.
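The channel classes above are building blocks rather than stand-alone models: they are meant to be plugged into a conductance-based neuron group. A minimal sketch, following the same composition pattern shown in the 2.2.0 release notes later in this document (the particular channels and parameters are illustrative):

import brainpy as bp

class HH(bp.dyn.CondNeuGroup):
  def __init__(self, size):
    super(HH, self).__init__(size)
    self.INa = bp.channels.INa_HH1952(size)                 # transient sodium current
    self.IK = bp.channels.IK_HH1952(size)                   # delayed-rectifier potassium current
    self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)   # leak current

neu = HH(10)  # a group of 10 conductance-based neurons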

brainpy.layers module#

Basic ANN Layer Class#

Layer([name, mode])

Base class for a layer of artificial neural network.

Convolutional Layers#

Conv1d(in_channels, out_channels, kernel_size)

One-dimensional convolution.

Conv2d(in_channels, out_channels, kernel_size)

Two-dimensional convolution.

Conv3d(in_channels, out_channels, kernel_size)

Three-dimensional convolution.

Dropout Layers#

Dropout(prob[, seed, mode, name])

A layer that stochastically ignores a subset of inputs each training step.

Function Layers#

Activation(activate_fun[, name, mode])

Applies an activation function to the inputs.

Flatten([name, mode])

Flattens a contiguous range of dims into 2D or 1D.

FunAsLayer(fun[, name, mode, has_shared])

Dense Connection Layers#

Dense(num_in, num_out[, W_initializer, ...])

A linear transformation applied over the last dimension of the input.

Normalization Layers#

BatchNorm1d(num_features[, axis, epsilon, ...])

1-D batch normalization [1]_.

BatchNorm2d(num_features[, axis, epsilon, ...])

2-D batch normalization [1]_.

BatchNorm3d(num_features[, axis, epsilon, ...])

3-D batch normalization [1]_.

LayerNorm(normalized_shape[, epsilon, ...])

Layer normalization (https://arxiv.org/abs/1607.06450).

GroupNorm(num_groups, num_channels[, ...])

Group normalization layer.

InstanceNorm(num_channels[, epsilon, ...])

Instance normalization layer.

NVAR Layers#

NVAR(num_in, delay[, order, stride, ...])

Nonlinear vector auto-regression (NVAR) node.

Pooling Layers#

MaxPool(kernel_size[, stride, padding, ...])

Pools the input by taking the maximum over a window.

MinPool(kernel_size[, stride, padding, ...])

Pools the input by taking the minimum over a window.

AvgPool(kernel_size[, stride, padding, ...])

Pools the input by taking the average over a window.

AvgPool1d(kernel_size[, stride, padding, ...])

Applies a 1D average pooling over an input signal composed of several input planes.

AvgPool2d(kernel_size[, stride, padding, ...])

Applies a 2D average pooling over an input signal composed of several input planes.

AvgPool3d(kernel_size[, stride, padding, ...])

Applies a 3D average pooling over an input signal composed of several input planes.

MaxPool1d(kernel_size[, stride, padding, ...])

Applies a 1D max pooling over an input signal composed of several input planes.

MaxPool2d(kernel_size[, stride, padding, ...])

Applies a 2D max pooling over an input signal composed of several input planes.

MaxPool3d(kernel_size[, stride, padding, ...])

Applies a 3D max pooling over an input signal composed of several input planes.

AdaptiveAvgPool1d(target_shape[, ...])

Adaptive one-dimensional average down-sampling.

AdaptiveAvgPool2d(target_shape[, ...])

Adaptive two-dimensional average down-sampling.

AdaptiveAvgPool3d(target_shape[, ...])

Adaptive three-dimensional average down-sampling.

AdaptiveMaxPool1d(target_shape[, ...])

Adaptive one-dimensional maximum down-sampling.

AdaptiveMaxPool2d(target_shape[, ...])

Adaptive two-dimensional maximum down-sampling.

AdaptiveMaxPool3d(target_shape[, ...])

Adaptive three-dimensional maximum down-sampling.

Reservoir Layers#

Reservoir(input_shape, num_out[, ...])

Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_.

Artificial Recurrent Layers#

RNNCell(num_in, num_out[, ...])

Basic fully-connected RNN core.

GRUCell(num_in, num_out[, Wi_initializer, ...])

Gated Recurrent Unit.

LSTMCell(num_in, num_out[, Wi_initializer, ...])

Long short-term memory (LSTM) RNN core.

VanillaRNN(*args, **kwargs)

Vanilla RNN.

GRU(*args, **kwargs)

GRU.

LSTM(*args, **kwargs)

LSTM.

brainpy.neurons module#

Biological Models#

HH(size[, keep_size, ENa, gNa, EK, gK, EL, ...])

Hodgkin–Huxley neuron model.

MorrisLecar(size[, keep_size, V_Ca, g_Ca, ...])

The Morris-Lecar neuron model.

PinskyRinzelModel(size[, keep_size, gNa, ...])

The Pinsky and Rinzel (1994) model.

WangBuzsakiModel(size[, keep_size, ENa, ...])

Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model.

Fractional-order Models#

FractionalNeuron(size[, keep_size, name, mode])

Fractional-order neuron model.

FractionalFHR(size, alpha[, num_memory, a, ...])

The fractional-order FH-R model [1]_.

FractionalIzhikevich(size, alpha, num_memory)

Fractional-order Izhikevich model [10]_.

Reduced Models#

LeakyIntegrator(size[, keep_size, V_rest, ...])

Leaky Integrator Model.

LIF(size[, keep_size, V_rest, V_reset, ...])

Leaky integrate-and-fire neuron model.

ExpIF(size[, V_rest, V_reset, V_th, V_T, ...])

Exponential integrate-and-fire neuron model.

AdExIF(size[, V_rest, V_reset, V_th, V_T, ...])

Adaptive exponential integrate-and-fire neuron model.

QuaIF(size[, V_rest, V_reset, V_th, V_c, c, ...])

Quadratic Integrate-and-Fire neuron model.

AdQuaIF(size[, V_rest, V_reset, V_th, V_c, ...])

Adaptive quadratic integrate-and-fire neuron model.

GIF(size[, V_rest, V_reset, V_th_inf, ...])

Generalized Integrate-and-Fire model.

ALIFBellec2020(size[, keep_size, V_rest, ...])

Leaky Integrate-and-Fire model with SFA [1]_.

Izhikevich(size[, a, b, c, d, V_th, ...])

The Izhikevich neuron model.

HindmarshRose(size[, a, b, c, d, r, s, ...])

Hindmarsh-Rose neuron model.

FHN(size[, a, b, tau, Vth, V_initializer, ...])

FitzHugh-Nagumo neuron model.

Noise Models#

OUProcess(size[, mean, sigma, tau, method, ...])

The Ornstein–Uhlenbeck process.

Input Models#

InputGroup(size[, keep_size, mode, name])

Input neuron group serving as a placeholder.

OutputGroup(size[, keep_size, mode, name])

Output neuron group serving as a placeholder.

SpikeTimeGroup(size, times, indices[, ...])

The input neuron group characterized by spikes emitting at given times.

PoissonGroup(size, freqs[, seed, keep_size, ...])

Poisson Neuron Group.
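A minimal sketch of simulating one of the neuron groups above with brainpy.dyn.DSRunner (the runner usage mirrors the examples in the release notes below; the input amplitude, duration, and monitored names are illustrative, and bp.visualize.line_plot is assumed to be available alongside the raster_plot helper used in the release notes):

import brainpy as bp

# a small LIF population; the parameter names follow the LIF entry above
group = bp.neurons.LIF(10, V_rest=-60., V_reset=-60., V_th=-50., tau=20.)

# drive every neuron with a constant input of 22. and record membrane potential and spikes
runner = bp.dyn.DSRunner(group, monitors=['V', 'spike'], inputs=('input', 22.))
runner.run(200.)  # simulate 200 ms

bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)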

brainpy.rates module#

RateModel(size[, keep_size, name, mode])

FHN(size[, keep_size, alpha, beta, gamma, ...])

FitzHugh-Nagumo system used in [1]_.

FeedbackFHN(size[, keep_size, a, b, delay, ...])

FitzHugh-Nagumo model with recurrent neural feedback.

QIF(size[, keep_size, tau, eta, delta, J, ...])

A mean-field model of a quadratic integrate-and-fire neuron population.

StuartLandauOscillator(size[, keep_size, a, ...])

Stuart-Landau model with Hopf bifurcation.

WilsonCowanModel(size[, keep_size, E_tau, ...])

Wilson-Cowan population model.

ThresholdLinearModel(size[, tau_e, tau_i, ...])

A threshold linear rate model.

brainpy.synapses module#

Abstract Models#

Delta(pre, post, conn[, output, stp, ...])

Voltage Jump Synapse Model, or alias of Delta Synapse Model.

Exponential(pre, post, conn[, output, stp, ...])

Exponential decay synapse model.

DualExponential(pre, post, conn[, stp, ...])

Dual exponential synapse model.

Alpha(pre, post, conn[, output, stp, ...])

Alpha synapse model.

NMDA(pre, post, conn[, output, stp, ...])

NMDA synapse model.

PoissonInput(target_var, num_input, freq, weight)

Poisson Input to the given Variable.

Biological Models#

AMPA(pre, post, conn[, output, stp, ...])

AMPA synapse model.

GABAa(pre, post, conn[, output, stp, ...])

GABAa synapse model.

BioNMDA(pre, post, conn[, output, stp, ...])

Biological NMDA synapse model.

Coupling Models#

DelayCoupling(delay_var, var_to_output, ...)

Delay coupling.

DiffusiveCoupling(coupling_var1, ...[, ...])

Diffusive coupling.

AdditiveCoupling(coupling_var, ...[, ...])

Additive coupling.

Gap Junction Models#

GapJunction(pre, post, conn[, comp_method, ...])

Learning Rule Models#

brainpy.synouts module#

COBA([E, target_var, membrane_var, name])

Conductance-based synaptic output.

CUBA([target_var, name])

Current-based synaptic output.

MgBlock([E, cc_Mg, alpha, beta, target_var, ...])

Synaptic output based on Magnesium blocking.

brainpy.synplast module#

STD([tau, U, method, name])

Synaptic output with short-term depression.

STP([U, tau_f, tau_d, method, name])

Synaptic output with short-term plasticity.
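The synapse models in brainpy.synapses are designed to be composed with the output models in brainpy.synouts and the short-term plasticity models in brainpy.synplast. A minimal sketch, mirroring the composable-synapse example in the 2.2.0 release notes below (population size and parameters are arbitrary):

import brainpy as bp

E = bp.neurons.LIF(100, V_rest=-60., V_reset=-60., V_th=-50., tau=20.)

# an exponential synapse whose output, plasticity, and connectivity are composed
# from the classes listed above
syn = bp.synapses.Exponential(
  pre=E, post=E,
  conn=bp.connect.FixedProb(0.02),
  g_max=0.03, tau=5.,
  output=bp.synouts.COBA(E=0.),   # conductance-based synaptic output
  stp=bp.synplast.STD(),          # short-term depression
)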

brainpy.integrators module#

ODE integrators#

Base ODE Integrator#

ODEIntegrator(f[, var_type, dt, name, ...])

Numerical Integrator for Ordinary Differential Equations (ODEs).

Generic ODE Functions#

set_default_odeint(method)

Set the default ODE numerical integrator method for differential equations.

get_default_odeint()

Get the default ODE numerical integrator method.

register_ode_integrator(name, integrator)

Register a new ODE integrator.

get_supported_methods()

Get all supported numerical methods for ODEs.

Explicit Runge-Kutta ODE Integrators#

ExplicitRKIntegrator(f[, var_type, dt, ...])

Explicit Runge–Kutta methods for ordinary differential equation.

Euler(f[, var_type, dt, name, show_code, ...])

The Euler method for ODEs.

MidPoint(f[, var_type, dt, name, show_code, ...])

Explicit midpoint method for ODEs.

Heun2(f[, var_type, dt, name, show_code, ...])

Heun's method for ODEs.

Ralston2(f[, var_type, dt, name, show_code, ...])

Ralston's method for ODEs.

RK2(f[, beta, var_type, dt, name, ...])

Generic second order Runge-Kutta method for ODEs.

RK3(f[, var_type, dt, name, show_code, ...])

Classical third-order Runge-Kutta method for ODEs.

Heun3(f[, var_type, dt, name, show_code, ...])

Heun's third-order method for ODEs.

Ralston3(f[, var_type, dt, name, show_code, ...])

Ralston's third-order method for ODEs.

SSPRK3(f[, var_type, dt, name, show_code, ...])

Third-order Strong Stability Preserving Runge-Kutta (SSPRK3).

RK4(f[, var_type, dt, name, show_code, ...])

Classical fourth-order Runge-Kutta method for ODEs.

Ralston4(f[, var_type, dt, name, show_code, ...])

Ralston's fourth-order method for ODEs.

RK4Rule38(f[, var_type, dt, name, ...])

3/8-rule fourth-order method for ODEs.

Adaptive Runge-Kutta ODE Integrators#

AdaptiveRKIntegrator(f[, var_type, dt, ...])

Adaptive Runge-Kutta method for ordinary differential equations.

RKF12(f[, var_type, dt, name, adaptive, ...])

The Fehlberg RK1(2) method for ODEs.

RKF45(f[, var_type, dt, name, adaptive, ...])

The Runge–Kutta–Fehlberg method for ODEs.

DormandPrince(f[, var_type, dt, name, ...])

The Dormand–Prince method for ODEs.

CashKarp(f[, var_type, dt, name, adaptive, ...])

The Cash–Karp method for ODEs.

BogackiShampine(f[, var_type, dt, name, ...])

The Bogacki–Shampine method for ODEs.

HeunEuler(f[, var_type, dt, name, adaptive, ...])

The Heun–Euler method for ODEs.

Exponential ODE Integrators#

ExponentialEuler(f[, var_type, dt, name, ...])

Exponential Euler method using automatic differentiation.
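A minimal sketch of how these ODE integrators are typically used: a derivative function is wrapped by bp.odeint (the functional entry point that builds one of the integrator classes above, selected by method name, here 'rk4'), and the returned integrator advances the state by one dt per call. The time constant and external input are illustrative:

import brainpy as bp

@bp.odeint(method='rk4', dt=0.01)
def integral(v, t, Iext):
  return (-v + Iext) / 10.   # simple leaky dynamics dv/dt = (-v + Iext) / tau

v, t = 0., 0.
for _ in range(1000):
  v = integral(v, t, Iext=1.5)   # one integration step of size dt
  t += 0.01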

SDE integrators#

Base SDE Integrator#

SDEIntegrator(f, g[, dt, name, show_code, ...])

SDE Integrator.

Generic SDE Functions#

set_default_sdeint(method)

Set the default SDE numerical integrator method for differential equations.

get_default_sdeint()

Get the default SDE numerical integrator method.

register_sde_integrator(name, integrator)

Register a new SDE integrator.

get_supported_methods()

Get all supported numerical methods for SDEs.

Normal SDE Integrators#

Euler(f, g[, dt, name, show_code, var_type, ...])

Euler method for the Ito and Stratonovich integrals.

Heun(f, g[, dt, name, show_code, var_type, ...])

The Euler-Heun method for Stratonovich integral scheme.

Milstein(f, g[, dt, name, show_code, ...])

Milstein method for Ito or Stratonovich integrals.

MilsteinGradFree(f, g[, dt, name, ...])

Derivative-free Milstein method for Ito or Stratonovich integrals.

ExponentialEuler(f, g[, dt, name, ...])

First order, explicit exponential Euler method.

SRK methods for scalar Wiener process#

SRK1W1(f, g[, dt, name, show_code, ...])

Order 2.0 weak SRK methods for SDEs with scalar Wiener process.

SRK2W1(f, g[, dt, name, show_code, ...])

Order 1.5 Strong SRK Methods for SDEs with Scalar Noise.

KlPl(f, g[, dt, name, show_code, var_type, ...])

FDE integrators#

Base FDE Integrator#

FDEIntegrator(f, alpha, num_memory[, dt, ...])

Numerical integrator for fractional differential equations (FDEs).

Generic FDE Functions#

set_default_fdeint(method)

Set the default FDE numerical integrator method for differential equations.

get_default_fdeint()

Get the default FDE numerical integrator method.

register_fde_integrator(name, integrator)

Register a new FDE integrator.

get_supported_methods()

Get all supported numerical methods for FDEs.

Methods for Caputo Fractional Derivative#

CaputoEuler(f, alpha, num_memory, inits[, ...])

One-step Euler method for Caputo fractional differential equations.

CaputoL1Schema(f, alpha, num_memory, inits)

The L1 scheme method for the numerical approximation of the Caputo fractional-order derivative equations [3]_.

Methods for Riemann-Liouville Fractional Derivative#

GLShortMemory(f, alpha, inits, num_memory[, ...])

Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_.

brainpy.analysis module#

Low-dimensional Analyzers#

PhasePlane1D(model, target_vars[, ...])

Phase plane analyzer for 1D dynamical system.

PhasePlane2D(model, target_vars[, ...])

Phase plane analyzer for 2D dynamical system.

Bifurcation1D(model, target_pars, target_vars)

Bifurcation analysis of 1D system.

Bifurcation2D(model, target_pars, target_vars)

Bifurcation analysis of 2D system.

FastSlow1D(model, fast_vars, slow_vars[, ...])

FastSlow2D(model, fast_vars, slow_vars[, ...])

High-dimensional Analyzers#

SlowPointFinder(f_cell[, f_type, f_loss, ...])

Find fixed/slow points by numerical optimization.
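A minimal sketch of a two-dimensional phase-plane analysis in the style of the BrainPy low-dimensional analysis tutorials; the FitzHugh–Nagumo derivatives, the variable ranges, and the pars_update/resolutions keyword names are assumptions taken from those tutorials:

import brainpy as bp

# FitzHugh-Nagumo derivatives, each wrapped as an ODE integrator
@bp.odeint
def int_V(V, t, w, Iext):
  return V - V ** 3 / 3 - w + Iext

@bp.odeint
def int_w(w, t, V, a=0.7, b=0.8):
  return (V + a - b * w) / 12.5

analyzer = bp.analysis.PhasePlane2D(
  model=[int_V, int_w],
  target_vars={'V': [-3., 3.], 'w': [-3., 3.]},
  pars_update={'Iext': 0.8},
  resolutions=0.01,
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.show_figure()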

brainpy.connect module#

Base Connection Classes and Tools#

mat2coo(dense)

mat2csc(dense)

mat2csr(dense)

Convert a dense matrix to (indices, indptr).

csr2csc(csr, post_num[, data])

Convert CSR to CSC.

csr2mat(csr, num_pre, num_post)

Convert (indices, indptr) to a dense matrix.

csr2coo(csr)

coo2csr(coo, num_pre)

Convert (pre_ids, post_ids) to (indices, indptr), used when 'jax_platform_name' is 'gpu'.

coo2csc(coo, post_num[, data])

Convert COO to CSC.

coo2mat(ij, num_pre, num_post)

Convert (pre_ids, post_ids) to a dense matrix.

Connector()

Base Synaptic Connector Class.

TwoEndConnector([pre, post])

Synaptic connector to build connections between two neuron groups.

OneEndConnector(*args, **kwargs)

Synaptic connector to build synapse connections within a population of neurons.

CONN_MAT

String name of the dense connection matrix structure.

PRE_IDS

String name of the pre-synaptic neuron indices structure.

POST_IDS

String name of the post-synaptic neuron indices structure.

PRE2POST

String name of the pre-to-post connection structure (indices, indptr).

POST2PRE

String name of the post-to-pre connection structure (indices, indptr).

PRE2SYN

String name of the pre-to-synapse connection structure.

POST2SYN

String name of the post-to-synapse connection structure.

Custom Connections#

MatConn(conn_mat, **kwargs)

Connector built from the dense connection matrix.

IJConn(i, j, **kwargs)

Connector built from the pre_ids and post_ids connections.

CSRConn(indices, inptr, **kwargs)

Connector built from the CSR sparse connection matrix.

SparseMatConn(csr_mat, **kwargs)

Connector built from the sparse connection matrix.

Random Connections#

FixedProb(prob[, pre_ratio, include_self, ...])

Connect the post-synaptic neurons with fixed probability.

FixedPreNum(num[, include_self, ...])

Connect a fixed number of pre-synaptic neurons for each post-synaptic neuron.

FixedPostNum(num[, include_self, ...])

Connect a fixed number of post-synaptic neurons for each pre-synaptic neuron.

FixedTotalNum(num[, seed])

Connect the neurons with a fixed total number of connections.

GaussianProb(sigma[, encoding_values, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the connection probability decays according to a Gaussian function.

ProbDist([dist, prob, pre_ratio, seed, ...])

Connection with a maximum distance under a probability p.

SmallWorld(num_neighbor, prob[, directed, ...])

Build a Watts–Strogatz small-world graph.

ScaleFreeBA(m[, directed, seed])

Build a random graph according to the Barabási–Albert preferential attachment model.

ScaleFreeBADual(m1, m2, p[, directed, seed])

Build a random graph according to the dual Barabási–Albert preferential attachment model.

PowerLaw(m, p[, directed, seed])

Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering.

Regular Connections#

One2One(*args, **kwargs)

Connect two neuron groups one by one.

All2All(*args[, include_self])

Connect each neuron in the first group to all neurons in the post-synaptic neuron group.

GridFour([include_self, periodic_boundary])

The nearest four neighbors connection method.

GridEight([include_self, periodic_boundary])

The nearest eight neighbors connection method.

GridN([N, include_self, periodic_boundary])

The nearest (2*N+1) * (2*N+1) neighbors connection method.

one2one

Connect two neuron groups one by one.

all2all

Connect each neuron in the first group to all neurons in the post-synaptic neuron group.

grid_four

The nearest four neighbors connection method.

grid_eight

The nearest eight neighbors connection method.
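A minimal sketch of the typical connector workflow: build a connector, bind the pre- and post-population sizes, then require the connection structures you need by the names listed above (the require() call and the structure names follow this module's conventions; the probability and sizes are arbitrary):

import brainpy as bp

conn = bp.connect.FixedProb(prob=0.1)
conn = conn(pre_size=100, post_size=80)      # bind the pre- and post-population sizes

# ask the connector for the data structures you need, by name
pre_ids, post_ids = conn.require('pre_ids', 'post_ids')
conn_mat = conn.require('conn_mat')          # dense boolean connection matrix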

brainpy.encoding module#

Encoder([name])

Base class for encoding rate values as spike trains.

LatencyEncoder(min_val, max_val, num_period)

Encode the rate input as the spike train.

WeightedPhaseEncoder(min_val, max_val, num_phase)

Encode the rate input into the spike train according to [1]_.

PoissonEncoder([min_val, max_val, seed])

Encode the rate input as the Poisson spike train.

brainpy.initialize module#

This module provides methods to initialize weights. You can access them through brainpy.init.XXX.

Basic Initialization Classes#

Initializer()

Base Initialization Class.

InterLayerInitializer()

The superclass of Initializers that initialize the weights between two layers.

IntraLayerInitializer()

The superclass of Initializers that initialize the weights within a layer.

Regular Initializers#

ZeroInit()

Zero initializer.

Constant([value])

Constant initializer.

OneInit([value])

One initializer.

Identity([value])

Returns the identity matrix.

Random Initializers#

Normal([mean, scale, seed])

Initialize weights with normal distribution.

Uniform([min_val, max_val, seed])

Initialize weights with uniform distribution.

VarianceScaling(scale, mode, distribution[, ...])

KaimingUniform([scale, mode, distribution, ...])

KaimingNormal([scale, mode, distribution, ...])

XavierUniform([scale, mode, distribution, ...])

XavierNormal([scale, mode, distribution, ...])

LecunUniform([scale, mode, distribution, ...])

LecunNormal([scale, mode, distribution, ...])

Orthogonal([scale, axis, seed])

Construct an initializer for uniformly distributed orthogonal matrices.

DeltaOrthogonal([scale, axis])

Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

Decay Initializers#

GaussianDecay(sigma, max_w[, min_w, ...])

Builds a Gaussian connectivity pattern within a population of neurons, where the weights decay according to a Gaussian function.

DOGDecay(sigmas, max_ws[, min_w, ...])

Builds a Difference-Of-Gaussian (dog) connectivity pattern within a population of neurons.
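A minimal sketch of using an initializer directly and of passing one to a model; the shapes and distribution parameters are arbitrary, and the V_initializer usage mirrors the network example in the 2.2.1 release notes below:

import brainpy as bp

init = bp.init.Normal(mean=0., scale=0.1)   # accessed as brainpy.init.XXX, as noted above
W = init((100, 50))                         # sample a 100 x 50 weight matrix

# initializers can also be handed to models directly, e.g. for membrane potentials
lif = bp.neurons.LIF(100, V_initializer=bp.init.Normal(-55., 2.))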

brainpy.inputs module#

section_input(values, durations[, dt, ...])

Format an input current with different sections.

constant_input(I_and_duration[, dt])

Format constant input in durations.

spike_input(sp_times, sp_lens, sp_sizes, ...)

Format current input like a series of short-time spikes.

ramp_input(c_start, c_end, duration[, ...])

Get the gradually changed input current.

wiener_process(duration[, dt, n, t_start, ...])

Stimulus sampled from a Wiener process, i.e. drawn from standard normal distribution N(0, sqrt(dt)).

ou_process(mean, sigma, tau, duration[, dt, ...])

Ornstein–Uhlenbeck input.

sinusoidal_input(amplitude, frequency, duration)

Sinusoidal input.

square_input(amplitude, frequency, duration)

Oscillatory square input.
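A minimal sketch of building input currents with the helpers above; the amplitudes and durations are arbitrary:

import brainpy as bp

# a sectioned current: 0 for 100 ms, 10 for 300 ms, then 0 again for 100 ms
current = bp.inputs.section_input(values=[0., 10., 0.],
                                  durations=[100., 300., 100.])

# a current ramping linearly from 0 to 5 over 500 ms
ramp = bp.inputs.ramp_input(c_start=0., c_end=5., duration=500.)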

brainpy.losses module#

Comparison#

cross_entropy_loss(predicts, targets[, ...])

This criterion combines LogSoftmax and NLLLoss in a single class.

cross_entropy_sparse(predicts, targets)

Computes the softmax cross-entropy loss.

cross_entropy_sigmoid(predicts, targets)

Computes the sigmoid cross-entropy loss.

l1_loos(logits, targets[, reduction])

Creates a criterion that measures the mean absolute error (MAE) between each element in the logits \(x\) and targets \(y\).

l2_loss(predicts, targets)

Computes the L2 loss.

huber_loss(predicts, targets[, delta])

Huber loss.

mean_absolute_error(x, y[, axis, reduction])

Computes the mean absolute error between x and y.

mean_squared_error(predicts, targets[, ...])

Computes the mean squared error between x and y.

mean_squared_log_error(predicts, targets[, ...])

Computes the mean squared logarithmic error between y_true and y_pred.

binary_logistic_loss(predicts, targets)

Binary logistic loss.

multiclass_logistic_loss(label, logits)

Multiclass logistic loss.

sigmoid_binary_cross_entropy(logits, labels)

Computes sigmoid cross entropy given logits and multiple class labels.

softmax_cross_entropy(logits, labels)

Computes the softmax cross entropy between sets of logits and labels.

log_cosh_loss(predicts, targets)

Calculates the log-cosh loss for a set of predictions.

ctc_loss_with_forward_probs(logits, ...[, ...])

Computes CTC loss and CTC forward-probabilities.

ctc_loss(logits, logit_paddings, labels, ...)

Computes CTC loss.

Regularization#

l2_norm(x[, axis])

Computes the L2 norm of x.

mean_absolute(outputs[, axis])

Computes the mean absolute value of the outputs.

mean_square(predicts[, axis])

log_cosh(errors)

Calculates the log-cosh loss for a set of predictions.

smooth_labels(labels, alpha)

Apply label smoothing.
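A minimal sketch of computing a couple of the comparison losses above on random data (the shapes and class count are arbitrary):

import brainpy as bp
import brainpy.math as bm

logits = bm.random.randn(8, 10)             # a batch of 8 samples with 10 classes
targets = bm.random.randint(0, 10, (8,))    # integer class labels

ce = bp.losses.cross_entropy_loss(logits, targets)
mse = bp.losses.mean_squared_error(bm.random.randn(8, 10), bm.random.randn(8, 10))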

brainpy.measure module#

cross_correlation(spikes, bin[, dt, numpy, ...])

Calculate cross correlation index between neurons.

voltage_fluctuation(potentials[, numpy, method])

Calculate neuronal synchronization via voltage variance.

matrix_correlation(x, y[, numpy])

Pearson correlation of the lower triangular parts of two matrices.

weighted_correlation(x, y, w[, numpy])

Weighted Pearson correlation of two data series.

functional_connectivity(activities[, numpy])

Functional connectivity matrix of timeseries activities.

raster_plot(sp_matrix, times)

Get spike raster plot which displays the spiking activity of a group of neurons over time.

firing_rate(spikes, width[, dt, numpy])

Calculate the mean firing rate of a neuron group.

unitary_LFP(times, spikes[, spike_type, ...])

A kernel-based method to calculate unitary local field potentials (uLFP) from a network of spiking neurons [1]_.
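A minimal sketch of applying these measures to monitored simulation data; the LIF parameters, input amplitude, and smoothing width are illustrative:

import brainpy as bp

lif = bp.neurons.LIF(100, V_rest=-60., V_reset=-60., V_th=-50., tau=20.)
runner = bp.dyn.DSRunner(lif, monitors=['V', 'spike'], inputs=('input', 26.))
runner.run(500.)

rate = bp.measure.firing_rate(runner.mon['spike'], width=10.)   # smoothed population rate
sync = bp.measure.voltage_fluctuation(runner.mon['V'])          # synchronization index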

brainpy.optim module#

Optimizers#

Optimizer(lr[, train_vars, name])

Base Optimizer Class.

SGD(lr[, train_vars, weight_decay, name])

Stochastic gradient descent optimizer.

Momentum(lr[, train_vars, momentum, ...])

Momentum optimizer.

MomentumNesterov(lr[, train_vars, ...])

Nesterov accelerated gradient optimizer [2]_.

Adagrad(lr[, train_vars, weight_decay, ...])

Optimizer that implements the Adagrad algorithm.

Adadelta([lr, train_vars, weight_decay, ...])

Optimizer that implements the Adadelta algorithm.

RMSProp(lr[, train_vars, weight_decay, ...])

Optimizer that implements the RMSprop algorithm.

Adam(lr[, train_vars, beta1, beta2, eps, ...])

Optimizer that implements the Adam algorithm.

LARS(lr[, train_vars, momentum, ...])

Layer-wise adaptive rate scaling (LARS) optimizer [1]_.

Adan([lr, train_vars, betas, eps, ...])

Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.

AdamW(lr[, train_vars, beta1, beta2, eps, ...])

Adam with weight decay regularization [1]_.

Schedulers#

make_schedule(scalar_or_schedule)

Scheduler(lr[, last_epoch])

The learning rate scheduler.

Constant(lr[, last_epoch])

StepLR(lr, step_size[, gamma, last_epoch])

Decays the learning rate of each parameter group by gamma every step_size epochs.

MultiStepLR(lr, milestones[, gamma, last_epoch])

Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.

CosineAnnealingLR(lr, T_max[, eta_min, ...])

Set the learning rate of each parameter group using a cosine annealing schedule, where \(\eta_{max}\) is set to the initial lr and \(T_{cur}\) is the number of epochs since the last restart in SGDR.

CosineAnnealingWarmRestarts(lr, ...[, ...])

Set the learning rate of each parameter group using a cosine annealing schedule with warm restarts.

ExponentialLR(lr, gamma[, last_epoch])

Decays the learning rate of each parameter group by gamma every epoch.

ExponentialDecay(lr, decay_steps, decay_rate)

InverseTimeDecay(lr, decay_steps, decay_rate)

PolynomialDecay(lr, decay_steps, final_lr[, ...])

PiecewiseConstant(boundaries, values[, ...])
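A minimal sketch of pairing an optimizer with a learning-rate scheduler from the lists above; the variable shape is arbitrary and the hand-written gradients stand in for gradients that would normally come from brainpy.math.grad or brainpy.math.vector_grad:

import brainpy as bp
import brainpy.math as bm

w = bm.TrainVar(bm.random.normal(size=(10, 5)))     # a trainable weight variable

# an Adam optimizer driven by an exponentially decaying learning rate
lr = bp.optim.ExponentialDecay(lr=1e-2, decay_steps=100, decay_rate=0.99)
opt = bp.optim.Adam(lr=lr, train_vars={'w': w})

grads = {'w': bm.ones((10, 5))}   # stand-in gradients, matching the train_vars keys
opt.update(grads)                 # apply one update step to the registered variables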

brainpy.running module#

jax_vectorize_map(func, arguments, num_parallel)

Perform a vectorized map of a function by using jax.vmap.

jax_parallelize_map(func, arguments, ...[, ...])

Perform a parallelized map of a function by using jax.pmap.

process_pool(func, all_params, num_process)

Run multiple models in multi-processes.

process_pool_lock(func, all_params, num_process)

Run multiple models in multi-processes with lock.

cpu_ordered_parallel(func, arguments[, ...])

Performs a parallel ordered map with a progress bar.

cpu_unordered_parallel(func, arguments[, ...])

Performs a parallel unordered map with a progress bar.

Release notes (brainpy)#

Note

All history release notes please see GitHub releases.

brainpy 2.2.x#

BrainPy 2.2.x is a complete re-design of the framework, tackling the shortcomings of the brainpy 2.1.x generation and bringing it in line with research needs and standards.

Version 2.2.1 (2022.09.09)#

This release fixes bugs found in the codebase and improves the usability and functions of BrainPy.

Bug fixes#
  1. Fix the bug of operator customization in brainpy.math.XLACustomOp and brainpy.math.register_op. Operator customization now supports the NumPy and Numba interfaces. For instance,

import brainpy.math as bm

def abs_eval(events, indices, indptr, post_val, values):
  # abstract evaluation: the output has the same shape and dtype as `post_val`
  return post_val

def con_compute(outs, ins):
  # concrete computation written with the NumPy/Numba interface
  post_val = outs
  events, indices, indptr, _, values = ins
  for i in range(events.size):
    if events[i]:
      for j in range(indptr[i], indptr[i + 1]):
        index = indices[j]
        old_value = post_val[index]
        post_val[index] = values + old_value

event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
  2. Fix the bug of brainpy.tools.DotDict. Now, it is compatible with JAX transformations. For instance,

import brainpy as bp
from jax import vmap

@vmap
def multiple_run(I):
  hh = bp.neurons.HH(1)
  runner = bp.dyn.DSRunner(hh, inputs=('input', I), numpy_mon_after_run=False)
  runner.run(100.)
  return runner.mon

mon = multiple_run(bp.math.arange(2, 10, 2))
New features#
  1. Add numpy operators brainpy.math.mat, brainpy.math.matrix, brainpy.math.asmatrix.

  2. Improve the translation rules of brainpylib operators and their running speed.

  3. Support DSView of a DynamicalSystem instance. Now, models can be defined with a slice view of a DS instance. For example,

import brainpy as bp
import brainpy.math as bm


class EINet_V2(bp.dyn.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    super(EINet_V2, self).__init__()

    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    self.N = bp.neurons.LIF(num_exc + num_inh,
                            V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                            method=method, V_initializer=bp.initialize.Normal(-55., 2.))

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc], post=self.N,
                                        conn=bp.connect.FixedProb(0.02),
                                        g_max=we, tau=5.,
                                        output=bp.synouts.COBA(E=0.),
                                        method=method)
    self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:], post=self.N,
                                        conn=bp.connect.FixedProb(0.02),
                                        g_max=wi, tau=10.,
                                        output=bp.synouts.COBA(E=-80.),
                                        method=method)

net = EINet_V2(scale=1., method='exp_auto')
# simulation
runner = bp.dyn.DSRunner(
    net,
    monitors={'spikes': net.N.spike},
    inputs=[(net.N.input, 20.)]
  )
runner.run(100.)

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)

Version 2.2.0 (2022.08.12)#

This release has provided important improvements for BrainPy, including usability, speed, functions, and others.

Backwards Incompatible changes#
  1. The brainpy.nn module is no longer supported and has been removed since version 2.2.0. Instead, users should use the brainpy.train module for back-propagation, online, or offline training, and the brainpy.algorithms module for the online/offline training algorithms themselves.

  2. The update() function for the model definition has been changed:

>>> # 2.1.x
>>>
>>> import brainpy as bp
>>>
>>> class SomeModel(bp.dyn.DynamicalSystem):
>>>      def __init__(self, ):
>>>            ......
>>>      def update(self, t, dt):
>>>           pass
>>> # 2.2.x
>>>
>>> import brainpy as bp
>>>
>>> class SomeModel(bp.dyn.DynamicalSystem):
>>>      def __init__(self, ):
>>>            ......
>>>      def update(self, tdi):
>>>           t, dt = tdi.t, tdi.dt
>>>           pass

where tdi can be defined with other names, like sha, to represent the shared argument across modules.

Deprecations#
  1. brainpy.dyn.xxx (neurons) and brainpy.dyn.xxx (synapses) are no longer supported. Please use the brainpy.neurons and brainpy.synapses modules.

  2. brainpy.running.monitor has been removed.

  3. brainpy.nn module has been removed.

New features#
  1. brainpy.math.Variable receives a batch_axis setting to represent the batch axis of the data.

>>> import brainpy.math as bm
>>> a = bm.Variable(bm.zeros((1, 4, 5)), batch_axis=0)
>>> a.value = bm.zeros((2, 4, 5))  # success
>>> a.value = bm.zeros((1, 2, 5))  # failed
MathError: The shape of the original data is (2, 4, 5), while we got (1, 2, 5) with batch_axis=0.
  2. brainpy.train provides brainpy.train.BPTT for back-propagation algorithms, brainpy.train.OnlineTrainer for online training algorithms, and brainpy.train.OfflineTrainer for offline training algorithms.

  3. The brainpy.Base class supports the _excluded_vars setting to ignore variables when retrieving variables with the Base.vars() method.

>>> class OurModel(bp.Base):
>>>     _excluded_vars = ('a', 'b')
>>>     def __init__(self):
>>>         super(OurModel, self).__init__()
>>>         self.a = bm.Variable(bm.zeros(10))
>>>         self.b = bm.Variable(bm.ones(20))
>>>         self.c = bm.Variable(bm.random.random(10))
>>>
>>> model = OurModel()
>>> model.vars().keys()
dict_keys(['OurModel0.c'])
  4. brainpy.analysis.SlowPointFinder supports directly analyzing an instance of brainpy.dyn.DynamicalSystem.

>>> hh = bp.neurons.HH(1)
>>> finder = bp.analysis.SlowPointFinder(hh, target_vars={'V': hh.V, 'm': hh.m, 'h': hh.h, 'n': hh.n})
  5. brainpy.datasets supports MNIST, FashionMNIST, and other datasets.

  6. Support defining conductance-based neuron models.

>>> class HH(bp.dyn.CondNeuGroup):
>>>   def __init__(self, size):
>>>     super(HH, self).__init__(size)
>>>
>>>     self.INa = channels.INa_HH1952(size, )
>>>     self.IK = channels.IK_HH1952(size, )
>>>     self.IL = channels.IL(size, E=-54.387, g_max=0.03)
  7. The brainpy.layers module provides commonly used models for DNNs and reservoir computing.

  8. Support composable definition of synaptic models by using TwoEndConn, SynOut, SynSTP and SynLTP.

>>> bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob),
>>>                      g_max=0.03 / scale, tau=5,
>>>                      output=bp.synouts.COBA(E=0.),
>>>                      stp=bp.synplast.STD())
  9. Provide commonly used surrogate gradient functions for spike generation, including

    • brainpy.math.spike_with_sigmoid_grad

    • brainpy.math.spike_with_linear_grad

    • brainpy.math.spike_with_gaussian_grad

    • brainpy.math.spike_with_mg_grad

  10. Provide shortcuts for GPU memory management via brainpy.math.disable_gpu_memory_preallocation() and brainpy.math.clear_buffer_memory().

What’s Changed#

Full Changelog: V2.1.12…V2.2.0

brainpy 2.1.x#

Version 2.1.12 (2022.05.17)#

Highlights#

This release is excellent. We have made important improvements.

  1. We provide dozens of NumPy random sampling functions that are not supported in JAX, such as brainpy.math.random.bernoulli, brainpy.math.random.lognormal, brainpy.math.random.binomial, brainpy.math.random.chisquare, brainpy.math.random.dirichlet, brainpy.math.random.geometric, brainpy.math.random.f, brainpy.math.random.hypergeometric, brainpy.math.random.logseries, brainpy.math.random.multinomial, brainpy.math.random.multivariate_normal, brainpy.math.random.negative_binomial, brainpy.math.random.noncentral_chisquare, brainpy.math.random.noncentral_f, brainpy.math.random.power, brainpy.math.random.rayleigh, brainpy.math.random.triangular, brainpy.math.random.vonmises, brainpy.math.random.wald, and brainpy.math.random.weibull.

  2. Make checking of numerical values more efficient. Instead of direct id_tap() checking, which has a large overhead, brainpy.tools.check_erro_in_jit() is now highly efficient.

  3. Fix JaxArray operator errors on None.

  4. Improve the speed of object-oriented-to-function transformations.

  5. IO utilities: .save_states() and .load_states().

What’s Changed#

Full Changelog: V2.1.11…V2.1.12

Version 2.1.11 (2022.05.15)#

What’s Changed#

Full Changelog: V2.1.10…V2.1.11

Version 2.1.10 (2022.05.05)#

What’s Changed#

Full Changelog: V2.1.8…V2.1.10

Version 2.1.8 (2022.04.26)#

What’s Changed#

Full Changelog: V2.1.7…V2.1.8

Version 2.1.7 (2022.04.22)#

What’s Changed#

Full Changelog: V2.1.5…V2.1.7

Version 2.1.5 (2022.04.18)#

What’s Changed#

Full Changelog: V2.1.4…V2.1.5

Version 2.1.4 (2022.04.04)#

What’s Changed#

Full Changelog: V2.1.3…V2.1.4

Version 2.1.3 (2022.03.27)#

This release improves the functionality and usability of BrainPy. Core changes include

  • support customization of low-level operators by using Numba

  • fix bugs

What’s Changed#

Full Changelog : V2.1.2…V2.1.3

Version 2.1.2 (2022.03.23)#

This release improves the functionality and usability of BrainPy. Core changes include

  • support rate-based whole-brain modeling

  • add more neuron models, including rate neurons/synapses

  • support Python 3.10

  • improve delay and other APIs

What’s Changed#

Full Changelog: V2.1.1…V2.1.2

Version 2.1.1 (2022.03.18)#

This release continues to update the functionality of BrainPy. Core changes include

  • numerical solvers for fractional differential equations

  • more standard brainpy.nn interfaces

New Features#
  • Numerical solvers for fractional differential equations
    • brainpy.fde.CaputoEuler

    • brainpy.fde.CaputoL1Schema

    • brainpy.fde.GLShortMemory

  • Fractional neuron models
    • brainpy.dyn.FractionalFHR

    • brainpy.dyn.FractionalIzhikevich

  • support shared_kwargs in RNNTrainer and RNNRunner

Version 2.1.0 (2022.03.14)#

Highlights#

We are excited to announce the release of BrainPy 2.1.0. This release is composed of nearly 270 commits since 2.0.2, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang .

BrainPy 2.1.0 updates are focused on improving usability, functionality, and stability of BrainPy. Highlights of version 2.1.0 include:

  • New module brainpy.dyn for dynamics building and simulation. It is composed of many neuron models, synapse models, and others.

  • New module brainpy.nn for neural network building and training. It supports defining reservoir models and artificial neural networks, with ridge regression and back-propagation-through-time training.

  • New module brainpy.datasets for convenient dataset construction and initialization.

  • New module brainpy.integrators.dde for numerical integration of delay differential equations.

  • Add more numpy-like operators in brainpy.math module.

  • Add automatic continuous integration on Linux, Windows, and MacOS platforms.

  • Fully update brainpy documentation.

  • Fix bugs on brainpy.analysis and brainpy.math.autograd

Incompatible changes#
  • Remove brainpy.math.numpy module.

  • Remove numba requirements

  • Remove matplotlib requirements

  • Remove steps in brainpy.dyn.DynamicalSystem

  • Remove travis CI

New Features#
  • brainpy.ddeint for numerical integration of delay differential equations; the supported methods include:

    • Euler

    • MidPoint

    • Heun2

    • Ralston2

    • RK2

    • RK3

    • Heun3

    • Ralston3

    • SSPRK3

    • RK4

    • Ralston4

    • RK4Rule38

  • set default int/float/complex types
    • brainpy.math.set_dfloat()

    • brainpy.math.set_dint()

    • brainpy.math.set_dcomplex()

  • Delay variables
    • brainpy.math.FixedLenDelay

    • brainpy.math.NeutralDelay

  • Dedicated operators
    • brainpy.math.sparse_matmul()

  • More numpy-like operators

  • Neural network building brainpy.nn

  • Dynamics model building and simulation brainpy.dyn

Version 2.0.2 (2022.02.11)#

There are important updates by Chaoming Wang in BrainPy 2.0.2.

  • provide pre2post_event_prod operator

  • support array creation from a list/tuple of JaxArray in brainpy.math.asarray and brainpy.math.array

  • update brainpy.ConstantDelay, add .latest and .oldest attributes

  • add brainpy.IntegratorRunner support for efficient simulation of brainpy integrators

  • support automatic detection of RandomState when JIT-compiling SDE integrators

  • fix bugs in SDE exponential_euler method

  • move parallel running APIs into brainpy.simulation

  • add brainpy.math.syn2post_mean, brainpy.math.syn2post_softmax, brainpy.math.pre2post_mean and brainpy.math.pre2post_softmax operators

Version 2.0.1 (2022.01.31)#

Today we release BrainPy 2.0.1. This release is composed of over 70 commits since 2.0.0, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.

BrainPy 2.0.1 updates are focused on improving documentation and operators. Core changes include:

  • Improve brainpylib operators

  • Complete documentation for programming system

  • Add more numpy APIs

  • Add jaxfwd in autograd module

  • And other changes

Version 2.0.0.1 (2022.01.05)#

  • Add progress bar in brainpy.StructRunner

Version 2.0.0 (2021.12.31)#

Start a new version of BrainPy.

Highlight#

We are excited to announce the release of BrainPy 2.0.0. This release is composed of over 260 commits since 1.1.7, made by Chaoming Wang, Xiaoyu Chen, and Tianqiu Zhang.

BrainPy 2.0.0 updates are focused on improving the performance, usability, and consistency of BrainPy. All the computations are migrated into JAX. Model building, simulation, training, and analysis are all based on JAX. Highlights of version 2.0.0 include:

  • brainpylib is provided as a collection of dedicated operators for brain dynamics programming

  • Connection APIs in brainpy.conn module are more efficient.

  • Update analysis tools for low-dimensional and high-dimensional systems in brainpy.analysis module.

  • Support more general Exponential Euler methods based on automatic differentiation.

  • Improve the usability and consistency of the brainpy.math module.

  • Remove JIT compilation based on Numba.

  • Separate brain building from brain simulation.

Incompatible changes#
  • remove brainpy.math.use_backend()

  • remove brainpy.math.numpy module

  • no longer support .run() in brainpy.DynamicalSystem (see New Features)

  • remove brainpy.analysis.PhasePlane (see New Features)

  • remove brainpy.analysis.Bifurcation (see New Features)

  • remove brainpy.analysis.FastSlowBifurcation (see New Features)

New Features#
  • Exponential Euler method based on automatic differentiation
    • brainpy.ode.ExpEulerAuto

  • Numerical optimization based low-dimensional analyzers:
    • brainpy.analysis.PhasePlane1D

    • brainpy.analysis.PhasePlane2D

    • brainpy.analysis.Bifurcation1D

    • brainpy.analysis.Bifurcation2D

    • brainpy.analysis.FastSlow1D

    • brainpy.analysis.FastSlow2D

  • Numerical optimization based high-dimensional analyzer:
    • brainpy.analysis.SlowPointFinder

  • Dedicated operators in brainpy.math module:
    • brainpy.math.pre2post_event_sum

    • brainpy.math.pre2post_sum

    • brainpy.math.pre2post_prod

    • brainpy.math.pre2post_max

    • brainpy.math.pre2post_min

    • brainpy.math.pre2syn

    • brainpy.math.syn2post

    • brainpy.math.syn2post_prod

    • brainpy.math.syn2post_max

    • brainpy.math.syn2post_min

  • Conversion APIs in brainpy.math module:
    • brainpy.math.as_device_array()

    • brainpy.math.as_variable()

    • brainpy.math.as_jaxarray()

  • New autograd APIs in brainpy.math module:
    • brainpy.math.vector_grad()

  • Simulation runners:
    • brainpy.ReportRunner

    • brainpy.StructRunner

    • brainpy.NumpyRunner

  • Commonly used models in brainpy.models module
    • brainpy.models.LIF

    • brainpy.models.Izhikevich

    • brainpy.models.AdExIF

    • brainpy.models.SpikeTimeInput

    • brainpy.models.PoissonInput

    • brainpy.models.DeltaSynapse

    • brainpy.models.ExpCUBA

    • brainpy.models.ExpCOBA

    • brainpy.models.AMPA

    • brainpy.models.GABAa

  • Naming cache clean: brainpy.clear_name_cache

  • add safe in-place operations of update() method and .value assignment for JaxArray

Documentation#
  • Complete tutorials for quickstart

  • Complete tutorials for dynamics building

  • Complete tutorials for dynamics simulation

  • Complete tutorials for dynamics training

  • Complete tutorials for dynamics analysis

  • Complete tutorials for API documentation

brainpy 1.1.x#

If you are using brainpy==1.x, you can find documentation, examples, and models through the following links:

Version 1.1.7 (2021.12.13)#

  • fix bugs on numpy_array() conversion in brainpy.math.utils module

Version 1.1.5 (2021.11.17)#

API changes:

  • fix bugs on ndarray import in brainpy.base.function.py

  • add a convenient ‘get_param’ interface in brainpy.simulation.layers

  • add more weight initialization methods

Doc changes:

  • add more examples in README

Version 1.1.4#

API changes:

  • add .struct_run() in DynamicalSystem

  • add numpy_array() conversion in brainpy.math.utils module

  • add Adagrad, Adadelta, RMSProp optimizers

  • remove setting methods in brainpy.math.jax module

  • remove import jax in brainpy.__init__.py and enable jax setting, including

    • enable_x64()

    • set_platform()

    • set_host_device_count()

  • enable b=None as no bias in brainpy.simulation.layers

  • set int_ and float_ as default 32 bits

  • remove dtype setting in Initializer constructor

Doc changes:

  • add optimizer in “Math Foundation”

  • add dynamics training docs

  • improve others

Version 1.1.3#

  • fix bugs of JAX parallel API imports

  • fix bugs of post_slice structure construction

  • update docs

Version 1.1.2#

  • add pre2syn and syn2post operators

  • add verbose and check option to Base.load_states()

  • fix bugs on JIT DynamicalSystem (numpy backend)

Version 1.1.1#

  • fix bugs on symbolic analysis: model trajectory

  • change absolute access in variable saving and loading to relative access

  • add UnexpectedTracerError hints in JAX transformation functions

Version 1.1.0 (2021.11.08)#

This package releases a new version of BrainPy.

Highlights of core changes:

math module#
  • support numpy backend

  • support JAX backend

  • support jit, vmap and pmap on class objects on JAX backend

  • support grad, jacobian, hessian on class objects on JAX backend

  • support make_loop, make_while, and make_cond on JAX backend

  • support jit (based on numba) on class objects on numpy backend

  • unified numpy-like ndarray operation APIs

  • numpy-like random sampling APIs

  • FFT functions

  • gradient descent optimizers

  • activation functions

  • loss function

  • backend settings

base module#
  • Base for the whole BrainPy ecosystem

  • Function to wrap functions

  • Collector and TensorCollector to collect variables, integrators, nodes and others

integrators module#
  • class integrators for ODE numerical methods

  • class integrators for SDE numerical methods

simulation module#
  • support modular and composable programming

  • support multi-scale modeling

  • support large-scale modeling

  • support simulation on GPUs

  • fix bugs on firing_rate()

  • remove _i in the update() function and replace it with _dt, meaning the dynamical system has the canonical equation form \(dx/dt = f(x, t, dt)\)

  • reimplement the input_step and monitor_step in a more intuitive way

  • support setting dt at the single-object level (i.e., for a single instance of DynamicalSystem)

  • commonly used DNN layers

  • weight initializations

  • refine synaptic connections

brainpy 1.0.x#

Version 1.0.3 (2021.08.18)#

Fix bugs on

  • firing rate measurement

  • stability analysis

Version 1.0.2#

This release continues to improve the user-friendliness.

Highlights of core changes:

  • Remove support for Numba-CUDA backend

  • Super initialization super(XXX, self).__init__() can be done anywhere (it is no longer required to be at the bottom of the __init__() function).

  • Add output messages for errors raised while running step functions.

  • More powerful support for Monitoring

  • More powerful support for running order scheduling

  • Remove unsqueeze() and squeeze() operations in brainpy.ops

  • Add reshape() operation in brainpy.ops

  • Improve docs for numerical solvers

  • Improve tests for numerical solvers

  • Add keywords checking in ODE numerical solvers

  • Add more unified operations in brainpy.ops

  • Support “@every” in steps and monitor functions

  • Fix ODE solver bugs for class bounded function

  • Add build phase in Monitor

Version 1.0.1#

  • Fix bugs

Version 1.0.0#

  • NEW VERSION OF BRAINPY

  • Change the coding style into the object-oriented programming

  • Systematically improve the documentation

brainpy 0.x#

Version 0.3.5#

  • Add ‘timeout’ in sympy solver in neuron dynamics analysis

  • Reconstruct and generalize phase plane analysis

  • Generalize the repeat mode of Network to allow different running durations between two runs

  • Update benchmarks

  • Update detailed documentation

Version 0.3.1#

  • Add a more flexible way for NeuState/SynState initialization

  • Fix bugs of “is_multi_return”

  • Add “hand_overs”, “requires” and “satisfies”.

  • Update documentation

  • Auto-transform range to numba.prange

  • Support _obj_i, _pre_i, _post_i for more flexible operation in scalar-based models

Version 0.3.0#

Computation API#
  • Rename “brainpy.numpy” to “brainpy.backend”

  • Delete “pytorch”, “tensorflow” backends

  • Add “numba” requirement

  • Add GPU support

Profile setting#
  • Delete “backend” profile setting, add “jit”

Core systems#
  • Delete “autopepe8” requirement

  • Delete the format code prefix

  • Change keywords “_t_, _dt_, _i_” to “_t, _dt, _i”

  • Change the “ST” declaration out of “requires”

  • Add “repeat” mode run in Network

  • Change “vector-based” to “mode” in NeuType and SynType definition

Package installation#
  • Remove “pypi” installation; installation now relies only on “conda”

Version 0.2.4#

API changes#
  • Fix bugs

Version 0.2.3#

API changes#
  • Add “animate_1D” in visualization module

  • Add “PoissonInput”, “SpikeTimeInput” and “FreqInput” in inputs module

  • Update phase_portrait_analyzer.py

Models and examples#
  • Add CANN examples

Version 0.2.2#

API changes#
  • Redesign visualization

  • Redesign connectivity

  • Update docs

Version 0.2.1#

API changes#
  • Fix bugs in numba import

  • Fix bugs in numpy mode with scalar model

Version 0.2.0#

API changes#
  • For computation: numpy, numba

  • For model definition: NeuType, SynConn

  • For model running: Network, NeuGroup, SynConn, Runner

  • For numerical integration: integrate, Integrator, DiffEquation

  • For connectivity: One2One, All2All, GridFour, grid_four, GridEight, grid_eight, GridN, FixedPostNum, FixedPreNum, FixedProb, GaussianProb, GaussianWeight, DOG

  • For visualization: plot_value, plot_potential, plot_raster, animation_potential

  • For measurement: cross_correlation, voltage_fluctuation, raster_plot, firing_rate

  • For inputs: constant_current, spike_current, ramp_current.

Models and examples#
  • Neuron models: HH model, LIF model, Izhikevich model

  • Synapse models: AMPA, GABA, NMDA, STP, GapJunction

  • Network models: gamma oscillation
