Building Training Models

In this section, we are going to talk about how to build models for training.

import brainpy as bp
import brainpy.math as bm

Use built-in models

Models built on brainpy.dyn.DynamicalSystem, including the built-in models provided in BrainPy, can be used for model training.

Mode settings

Some built-in models implement the training interface. Users can instantiate these models for training by passing the mode argument.

For example, brainpy.neurons.LIF is a model commonly used in computational neuroscience simulations, but it can also be used for training.

# Instantiate a LIF model for simulation

lif = bp.neurons.LIF(1)
# Instantiate a LIF model for training.
# In this mode, the model implements variables and functions
# compatible with BrainPy's training interface.

lif = bp.neurons.LIF(1, mode=bp.modes.training)

However, some built-in models do not support training.

bp.layers.NVAR(1, 1, mode=bp.modes.training)
NotImplementedError: NVAR does not support TrainingMode. We only support BatchingMode, NormalMode.
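This error comes from a mode-compatibility check that runs when the layer is constructed. As a rough illustration, the following plain-Python sketch re-creates that kind of check with toy classes. The class names follow the error message, but the hierarchy (TrainingMode as a subclass of BatchingMode) and the helper check_mode are assumptions made for illustration, not BrainPy's source code.

```python
# Toy mode classes named after the error message.
# Assumption: TrainingMode is modeled as a specialization of BatchingMode.
class Mode:
    pass

class NormalMode(Mode):
    pass

class BatchingMode(Mode):
    pass

class TrainingMode(BatchingMode):
    pass


def check_mode(mode, supported_modes, name):
    """Hypothetical helper: raise NotImplementedError for unsupported modes.

    The mode is accepted only if some supported class is a subclass of the
    mode's own class, so a TrainingMode instance is rejected by a model that
    lists only BatchingMode and NormalMode.
    """
    if not any(issubclass(smode, type(mode)) for smode in supported_modes):
        supported = ', '.join(m.__name__ for m in supported_modes)
        raise NotImplementedError(
            f"{name} does not support {type(mode).__name__}. "
            f"We only support {supported}.")


# An NVAR-like model that declares only batching and normal modes.
check_mode(BatchingMode(), (BatchingMode, NormalMode), 'NVAR')  # accepted
try:
    check_mode(TrainingMode(), (BatchingMode, NormalMode), 'NVAR')
except NotImplementedError as err:
    print(err)
```

With this design, declaring support for a more specialized mode implicitly covers its parent: a model listing TrainingMode would also accept a BatchingMode instance.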

The mode can also control the weight types. Let's take a dense layer as another example. For a non-trainable dense layer (e.g., one created in batching mode), the weights and bias are JaxArray instances.

l = bp.layers.Dense(3, 4, mode=bp.modes.batching)
l.W
JaxArray([[-0.2552617 ,  0.40152806, -0.75552243,  0.5301098 ],
          [ 0.11408956, -0.0063706 ,  0.26513448, -0.12788086],
          [ 0.07695759,  0.4182222 ,  0.80788815, -0.0341561 ]], dtype=float32)

For a trainable dense layer (created in training mode), the weights and bias are TrainVar instances.

l = bp.layers.Dense(3, 4, mode=bp.modes.training)
l.W
TrainVar([[ 0.13648991, -1.1017411 ,  0.04438929, -0.03525464],
          [-0.1966483 ,  0.42640603,  0.18005033,  0.75901693],
          [-0.46449846,  0.75061077,  1.0296121 , -0.58486235]], dtype=float32)
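Whatever the weight type, a dense layer computes the same affine map on each row of a batch; the mode only decides whether W and b are registered as trainable. A minimal NumPy sketch, assuming the batch axis comes first and using a scaled normal initialization chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

num_in, num_out, batch_size = 3, 4, 2
# weight matrix of shape (num_in, num_out), like l.W above
W = rng.normal(0.0, 1.0 / np.sqrt(num_in), size=(num_in, num_out))
b = np.zeros(num_out)                       # bias vector of shape (num_out,)

x = rng.normal(size=(batch_size, num_in))   # batch axis first
y = x @ W + b                               # shape (batch_size, num_out)
print(y.shape)  # (2, 4)
```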

Moreover, for some recurrent models, e.g., LSTM or GRU, the state can be made trainable or non-trainable with the train_state argument. When train_state=True is set for a recurrent instance, a new attribute .state2train is created.

rnn = bp.layers.VanillaRNN(1, 3, train_state=True)
rnn.state2train
TrainVar([0., 0., 0.], dtype=float32)

Note the differences between .state2train and the original .state:

  1. .state2train has no batch axis.

  2. When node.reset_state() is called, all values in .state are filled with .state2train.

rnn.reset_state(batch_size=5)
rnn.state
Variable([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=float32)
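The relationship between .state2train and .state can be sketched in NumPy: the trainable initial state is a single vector without a batch axis, and resetting copies it into every row of a batch-sized state. The function reset_state below is a hypothetical stand-in for the method, and nonzero initial values are assumed so that the copy is visible.

```python
import numpy as np

num_out = 3
# trainable initial state: one vector, no batch axis (like .state2train)
state2train = np.array([0.1, -0.2, 0.3])

def reset_state(batch_size):
    """Fill a (batch_size, num_out) state with copies of state2train."""
    return np.tile(state2train, (batch_size, 1))

state = reset_state(batch_size=5)
print(state.shape)  # (5, 3)
```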

Naming a node

For convenience, you can name a layer by specifying the name keyword argument:

bp.layers.Dense(128, 100, name='hidden_layer')
Dense(name=hidden_layer, num_in=128, num_out=100, mode=TrainingMode)

Initializing parameters

Many models have parameters. A model's parameters can be set in the following ways.

  • Arrays

If an array is provided, this is used unchanged as the parameter variable. For example:

l = bp.layers.Dense(10, 50, W_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.W.shape
(10, 50)
  • Callable function

If a callable function (which receives a shape argument) is provided, the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:

def init(shape):
    return bm.random.random(shape)

l = bp.layers.Dense(20, 30, W_initializer=init)
l.W.shape
(20, 30)
  • Instance of brainpy.init.Initializer

If a brainpy.init.Initializer instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:

l = bp.layers.Dense(20, 30, W_initializer=bp.init.Normal(0.01))
l.W.shape
(20, 30)

The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.init for more information).

  • None parameter

Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:

l = bp.layers.Dense(20, 100, b_initializer=None)
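The four options above amount to a single dispatch rule. The NumPy sketch below expresses that rule with a hypothetical helper, init_param; NormalInit is a hypothetical stand-in for an initializer instance such as bp.init.Normal. It illustrates the rules described above and is not BrainPy's internal implementation.

```python
import numpy as np

class NormalInit:
    """Hypothetical initializer object: callable with a shape."""
    def __init__(self, scale=1.0, seed=0):
        self.scale = scale
        self.rng = np.random.default_rng(seed)

    def __call__(self, shape):
        return self.rng.normal(0.0, self.scale, size=shape)


def init_param(spec, shape):
    """Resolve a parameter specification into an array, or None."""
    if spec is None:                    # parameter omitted (e.g., no bias)
        return None
    if isinstance(spec, np.ndarray):    # array: used unchanged
        if spec.shape != tuple(shape):
            raise ValueError(f"expected shape {tuple(shape)}, got {spec.shape}")
        return spec
    if callable(spec):                  # plain function or initializer object
        return np.asarray(spec(shape))
    raise TypeError(f"unsupported parameter specification: {spec!r}")


W1 = init_param(np.full((10, 50), 0.01), (10, 50))
W2 = init_param(lambda shape: np.ones(shape), (20, 30))
W3 = init_param(NormalInit(0.01), (20, 30))
b = init_param(None, (100,))
print(W1.shape, W2.shape, W3.shape, b)  # (10, 50) (20, 30) (20, 30) None
```

Since an initializer object is itself callable with a shape, functions and initializer instances can share one branch.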


Customize your models

Customizing your training models is simple. You just need to subclass brainpy.dyn.DynamicalSystem, and implement its update() and reset_state() functions.

Here, we demonstrate the model customization using two examples. The first is a recurrent layer.

class RecurrentLayer(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_out):
        super(RecurrentLayer, self).__init__()

        # define parameters
        self.num_in = num_in
        self.num_out = num_out

        # define variables
        self.state = bm.Variable(bm.zeros((1, num_out)), batch_axis=0)

        # define weights
        self.win = bm.TrainVar(bm.random.normal(0., 1./num_in ** 0.5, size=(num_in, num_out)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1./num_out ** 0.5, size=(num_out, num_out)))

    def reset_state(self, batch_size):
        # this function defines how to reset the model states
        self.state.value = bm.zeros((batch_size, self.num_out))

    def update(self, sha, x):
        # this function defines how the model updates its state and produces its output
        out = bm.dot(x, self.win) + bm.dot(self.state, self.wrec)
        self.state.value = bm.tanh(out)
        return self.state.value

This simple example illustrates several features essential for a training model. The reset_state() function defines how to reset the model states and is called at the first time step; the update() function defines how the model states evolve and is called at every time step.
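The same computation can be traced in plain NumPy, which makes the call order explicit: reset once, then update at every step. In this sketch, the loop is a hypothetical stand-in for BrainPy's runner, and the input layout (time, batch, features) is an assumption.

```python
import numpy as np

rng = np.random.default_rng(42)
num_in, num_out, batch_size = 2, 3, 4

win = rng.normal(0.0, 1.0 / np.sqrt(num_in), size=(num_in, num_out))
wrec = rng.normal(0.0, 1.0 / np.sqrt(num_out), size=(num_out, num_out))

def reset_state(batch_size):
    """Zero state with a leading batch axis."""
    return np.zeros((batch_size, num_out))

def update(state, x):
    """One time step: h_t = tanh(x_t @ W_in + h_{t-1} @ W_rec)."""
    return np.tanh(x @ win + state @ wrec)

inputs = rng.normal(size=(10, batch_size, num_in))  # (time, batch, features)
state = reset_state(batch_size)      # called once, before the first step
outputs = []
for x_t in inputs:                   # update() is called at every step
    state = update(state, x_t)
    outputs.append(state)
outputs = np.stack(outputs)
print(outputs.shape)  # (10, 4, 3)
```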

Another example is a dropout layer, which demonstrates how to define a model with multiple behaviours.

class Dropout(bp.dyn.DynamicalSystem):
  def __init__(self, prob: float, seed: int = None, name: str = None):
    super(Dropout, self).__init__(name=name)
    self.prob = prob
    self.rng = bm.random.RandomState(seed=seed)

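  def update(self, sha, x):
    if sha.get('fit', True):
      # training phase: keep each unit with probability `prob` and rescale it
      keep_mask = self.rng.bernoulli(self.prob, x.shape)
      return bm.where(keep_mask, x / self.prob, 0.)
    else:
      # inference phase: pass the input through unchanged
      return x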

Here, the model produces different outputs depending on the value of the shared parameter fit.

You can define your own shared parameters, and then provide their shared parameters when calling the trainer objects (see the following section).
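The two behaviours of the Dropout class can be checked numerically with a NumPy version of the same logic. As in the class above, prob is treated as the probability of keeping a unit, and kept units are scaled by 1/prob so the expected output equals the input. The function dropout is a hypothetical stand-in for the layer.

```python
import numpy as np

def dropout(shared, x, prob, rng):
    """Inverted dropout; `prob` is the probability of KEEPING each unit."""
    if shared.get('fit', True):
        keep_mask = rng.random(x.shape) < prob      # Bernoulli(prob) mask
        return np.where(keep_mask, x / prob, 0.0)   # rescale kept units
    return x                                        # inference: identity

rng = np.random.default_rng(0)
x = np.ones((1000, 100))
train_out = dropout({'fit': True}, x, prob=0.8, rng=rng)
test_out = dropout({'fit': False}, x, prob=0.8, rng=rng)
print(train_out.mean())  # close to 1.0 thanks to the 1/prob rescaling
```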

Examples of training models

The following examples illustrate how to build trainable neural network models.

Artificial neural networks

BrainPy provides neural network layers which can be useful to define artificial neural networks.

Here, let’s define a deep RNN model.

class DeepRNN(bp.dyn.DynamicalSystem):
    def __init__(self, num_in, num_recs, num_out):
        super(DeepRNN, self).__init__()

        self.l1 = bp.layers.LSTM(num_in, num_recs[0])
        self.d1 = bp.layers.Dropout(0.2)
        self.l2 = bp.layers.LSTM(num_recs[0], num_recs[1])
        self.d2 = bp.layers.Dropout(0.2)
        self.l3 = bp.layers.LSTM(num_recs[1], num_recs[2])
        self.d3 = bp.layers.Dropout(0.2)
        self.l4 = bp.layers.LSTM(num_recs[2], num_recs[3])
        self.d4 = bp.layers.Dropout(0.2)
        self.lout = bp.layers.Dense(num_recs[3], num_out)

    def update(self, sha, x):
        x = self.d1(sha, self.l1(sha, x))
        x = self.d2(sha, self.l2(sha, x))
        x = self.d3(sha, self.l3(sha, x))
        x = self.d4(sha, self.l4(sha, x))
        return self.lout(sha, x)

Note that, unlike model building in PyTorch, the first argument of the update() function should be the shared parameters sha, i.e., parameters shared across all models, such as the time t, the running index i, and the model running phase fit. The remaining arguments can be customized freely by users. The details of the model definition specification can be found in ????

Moreover, it is worth noting that this model only defines the one-step update rule, i.e., how the model evolves given the input x.
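Because update() only describes a single step, something else must loop over time and supply the shared parameters. The following NumPy sketch shows that division of labor with a hypothetical run() driver and a toy model; in BrainPy, the trainer objects mentioned earlier take on this role. The dt entry in the shared dict is an assumption added for the toy dynamics.

```python
import numpy as np

def run(update_fn, inputs, dt=0.1, fit=True):
    """Hypothetical driver: call a one-step update at every time step.

    The shared dict mimics the shared parameters (time t, step index i,
    phase flag fit); dt is an extra assumption for this toy example.
    """
    outputs = []
    for i, x in enumerate(inputs):
        sha = {'t': i * dt, 'i': i, 'dt': dt, 'fit': fit}
        outputs.append(update_fn(sha, x))
    return np.stack(outputs)


class LeakyAccumulator:
    """Toy one-step model: s <- (1 - dt) * s + dt * x."""
    def __init__(self, size):
        self.s = np.zeros(size)

    def update(self, sha, x):
        self.s = (1.0 - sha['dt']) * self.s + sha['dt'] * x
        return self.s.copy()


model = LeakyAccumulator(size=2)
ys = run(model.update, np.ones((50, 2)))
print(ys.shape)  # (50, 2)
```

With a constant input of 1, the state after n steps is 1 - 0.9**n, so the trajectory is fully determined by the one-step rule plus the driver.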

Reservoir computing models

In this example, we define a reservoir computing model called next generation reservoir computing (NG-RC) using the built-in models provided in BrainPy.

class NGRC(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5, mode=bp.modes.batching)
    self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bp.modes.training)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))

In the above model, brainpy.layers.NVAR is a nonlinear vector autoregression machine, which does not support training. Therefore, we set its mode to batching mode. In contrast, brainpy.layers.Dense has trainable weights for model training.
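To give a feel for what a nonlinear vector autoregression computes, here is a NumPy sketch of NG-RC style features: a constant, delayed copies of the input (delay taps spaced stride steps apart), and all unique degree-2 products of those taps. The helper nvar_features is hypothetical, and the exact feature ordering and count produced by brainpy.layers.NVAR may differ.

```python
import itertools
import numpy as np

def nvar_features(history, delay, stride, order=2, constant=True):
    """NG-RC style feature vector built from the most recent input history.

    history: array of shape (T, num_in), oldest sample first, with
    T >= 1 + (delay - 1) * stride so that every delay tap exists.
    """
    # linear part: `delay` taps spaced `stride` steps apart, newest first
    taps = [history[-1 - k * stride] for k in range(delay)]
    linear = np.concatenate(taps)
    feats = [np.ones(1)] if constant else []
    feats.append(linear)
    # nonlinear part: unique monomials of degree 2..order of the linear terms
    for p in range(2, order + 1):
        combos = itertools.combinations_with_replacement(range(linear.size), p)
        feats.append(np.array([np.prod(linear[list(c)]) for c in combos]))
    return np.concatenate(feats)


history = np.arange(20.0).reshape(20, 1)    # toy one-dimensional input signal
f = nvar_features(history, delay=4, stride=5, order=2)
print(f.shape)  # (15,): 1 constant + 4 linear + 10 quadratic
```

Only the readout on top of these features has to be learned, which is why the Dense layer is the only trainable part of the model above.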

Spiking Neural Networks

Building trainable spiking neural networks in BrainPy is also straightforward. BrainPy provides commonly used spiking models for traditional dynamics simulation, and most of them can be used for training as well.

In the following, we provide an implementation of spiking neural networks in (Neftci, Mostafa, & Zenke, 2019) for surrogate gradient learning.

class SNN(bp.dyn.Network):
  def __init__(self, num_in, num_rec, num_out):
    super(SNN, self).__init__()

    # neuron groups
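    self.i = bp.neurons.InputGroup(num_in, mode=bp.modes.training)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1., mode=bp.modes.training)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, mode=bp.modes.training)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       mode=bp.modes.training)
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       mode=bp.modes.training)

  def update(self, tdi, spike):
    # propagate input spikes to the recurrent group, then to the output group
    self.i2r(tdi, spike)
    self.r2o(tdi)
    # update the neuron groups
    self.r(tdi)
    self.o(tdi)
    return self.o.V.value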

Note that the mode of every model here is specified as brainpy.modes.TrainingMode.
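Surrogate gradient learning replaces the derivative of the non-differentiable spike function with a smooth surrogate during backpropagation. The NumPy sketch below shows one forward-Euler step of an LIF neuron, using the tau, V_reset, V_rest and V_th values from the SNN above, together with a fast-sigmoid surrogate derivative. The time step dt, the steepness beta, and the function names are assumptions; BrainPy's own integration scheme and surrogate functions may differ.

```python
import numpy as np

def lif_step(V, I, dt=0.1, tau=10.0, V_rest=0.0, V_reset=0.0, V_th=1.0):
    """One forward-Euler step of tau * dV/dt = -(V - V_rest) + I."""
    V = V + dt / tau * (-(V - V_rest) + I)
    spike = (V >= V_th).astype(float)      # Heaviside step: zero gradient a.e.
    V = np.where(spike > 0, V_reset, V)    # reset the neurons that fired
    return V, spike


def surrogate_grad(V, V_th=1.0, beta=10.0):
    """Fast-sigmoid surrogate for d(spike)/dV, used only in the backward pass."""
    return 1.0 / (beta * np.abs(V - V_th) + 1.0) ** 2


V = np.zeros(3)
I = np.array([0.5, 2.0, 4.0])   # constant input currents
total_spikes = np.zeros(3)
for _ in range(300):
    V, spike = lif_step(V, I)
    total_spikes += spike
print(total_spikes)  # the sub-threshold neuron (I = 0.5) never fires
```

The surrogate peaks at the threshold and decays on both sides, so membrane potentials near V_th receive the largest gradient signal.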