Training a Brain Dynamics Model#

@Chaoming Wang

In recent years, we have seen a revolution in which dynamical systems trained from data or tasks provide important insights into brain function. To support this, BrainPy provides various interfaces to help users train dynamical systems.

import brainpy as bp
import brainpy.math as bm

bm.enable_x64()

# bm.set_platform('cpu')
import matplotlib.pyplot as plt

Training a reservoir network model#

For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.

(Gauthier et al., Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model based on nonlinear vector autoregression (NVAR).

The difference between the two models is illustrated in the following figure.

(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using linear weights of time-delayed states of the time series and nonlinear functionals of these states.
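To make the NVAR feature construction concrete, here is a minimal NumPy sketch (illustrative only, not BrainPy's implementation) of how a feature vector can be assembled from time-delayed states and their unique quadratic monomials; the delay and stride values here are arbitrary:

import numpy as np

def nvar_features(history, delay=2, stride=1):
    # "history" is a (time, feature) array of past observations.
    # Linear part: `delay` past states, spaced `stride` steps apart.
    linear = np.concatenate([history[-1 - i * stride] for i in range(delay)])
    # Nonlinear part: all unique quadratic monomials of the linear features.
    rows, cols = np.triu_indices(linear.size)
    quadratic = linear[rows] * linear[cols]
    return np.concatenate([linear, quadratic])

# Example: 3-dimensional observations with 2 delayed states
history = np.random.rand(10, 3)
feats = nvar_features(history)
print(feats.size)  # 6 linear + 21 quadratic = 27 features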

Here, let’s implement a next-generation reservoir model to predict a chaotic time series, the Lorenz attractor. In particular, we expect the network to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the number of time steps to forecast ahead.

dt = 0.01
data = bp.datasets.lorenz_series(100, dt=dt)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['x'].flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['y'].flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data['ts']), bm.as_numpy(data['z'].flatten()))
plt.ylabel('z')
plt.show()

Let’s first create a function to get the data.

def get_subset(data, start, end):
    # select a time window and stack the (x, y, z) variables as features
    res = {'x': data['x'][start: end],
           'y': data['y'][start: end],
           'z': data['z'][start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    # add a leading batch dimension: (num_batch, num_time, num_feature)
    return res.reshape((1, ) + res.shape)
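For instance, taking the first 100 points should yield a single-batch array of shape (num_batch, num_time, num_feature):

sample = get_subset(data, 0, 100)
print(sample.shape)  # (1, 100, 3)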

To accomplish this task, we implement a next-generation reservoir model with 4 delayed states at a stride of 5, together with their quadratic polynomial monomials, the same setting as in (Gauthier et al., Nature Communications, 2021).

class NGRC(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5, mode=bp.modes.batching)
    self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bp.modes.training)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    # other arguments like "x" can be customized by users.
    return self.o(sha, self.r(sha, x))
model = NGRC(num_in=3, num_out=3)
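With num_in=3, delay=4, and order=2, we expect 3 × 4 = 12 linear features plus 12 × 13 / 2 = 78 unique quadratic monomials, i.e., 90 NVAR features in total (assuming no constant term). This can be checked through the num_out attribute of the NVAR node:

print(model.r.num_out)  # expected: 12 + 78 = 90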

Moreover, we use the ridge regression method to train the model.

trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
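Conceptually, ridge regression fits the linear readout weights in closed form by regularized least squares. Below is a minimal NumPy sketch of the underlying computation (not BrainPy's internal code), where X stacks the reservoir features over time and Y the targets:

import numpy as np

def ridge_solve(X, Y, alpha=1e-6):
    # closed-form ridge solution: W = (X^T X + alpha * I)^(-1) X^T Y
    num_feature = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(num_feature), X.T @ Y)

X = np.random.rand(200, 90)  # 200 time steps, 90 reservoir features
Y = np.random.rand(200, 3)   # 3 output dimensions
print(ridge_solve(X, Y).shape)  # (90, 3)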

We warm up the network with the first 20 ms of data.

warmup_data = get_subset(data, 0, int(20/dt))

outs = trainer.predict(warmup_data)

# outputs should be an array with the shape of
# (num_batch, num_time, num_out)
outs.shape
(1, 2000, 3)

The training data is the time series from 20 ms to 80 ms. We want the network to forecast 1 time step ahead.

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)

trainer.fit([x_train, y_train])
JaxArray([[[10.42193545, 33.57694113, 21.64385191],
           [10.00120939, 32.31907872, 20.65487474],
           [ 9.57161603, 31.17880345, 19.74202592],
           ...,
           [ 9.36649357, 32.98189915, 20.23057405],
           [ 8.72150015, 32.16000653, 19.52957227],
           [ 8.00738287, 31.46370921, 18.94740094]]], dtype=float64)

Then we test the trained network with the next 20 ms.

x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))

predictions = trainer.predict(x_test)

bp.losses.mean_squared_error(y_test, predictions)
DeviceArray(2.27040876e-09, dtype=float64)
def plot_difference(truths, predictions):
    truths = bm.as_numpy(truths)
    predictions = bm.as_numpy(predictions)

    plt.subplot(311)
    plt.plot(truths[0, :, 0], label='Ground Truth')
    plt.plot(predictions[0, :, 0], label='Prediction')
    plt.ylabel('x')
    plt.legend()
    plt.subplot(312)
    plt.plot(truths[0, :, 1], label='Ground Truth')
    plt.plot(predictions[0, :, 1], label='Prediction')
    plt.ylabel('y')
    plt.legend()
    plt.subplot(313)
    plt.plot(truths[0, :, 2], label='Ground Truth')
    plt.plot(predictions[0, :, 2], label='Prediction')
    plt.ylabel('z')
    plt.legend()
    plt.show()
plot_difference(y_test, predictions)

We can make the task harder by forecasting 10 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)

Or forecast 100 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
_ = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)

As you can see, forecasting more time steps ahead makes the learning problem more difficult.

Training an artificial recurrent network#

In recent years, artificial recurrent neural networks trained with backpropagation through time (BPTT) have become a useful tool for studying the network mechanisms of brain function. To support training networks with BPTT, BrainPy provides the brainpy.train.BPTT interface.

Here, we demonstrate how to train an artificial recurrent neural network on a white noise integration task, in which the trained RNN model must integrate white noise. For example, given a time series of noise data,

noises = bm.random.normal(0, 0.2, size=10)

plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()

Now, we want a model that can integrate the noise, i.e., compute bm.cumsum(noises) * dt:

dt = 0.1
integrals = bm.cumsum(noises) * dt

plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()

Here, we first define the task, which generates the input data and the target integration results.

from functools import partial

dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128


@partial(bm.jit,
         dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
         static_argnames=['batch_size'])
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
  # a random constant offset for each sequence in the batch
  sample = bm.random.normal(size=(batch_size, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  # white noise input, scaled by 1 / sqrt(dt)
  samples = bm.random.normal(size=(batch_size, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  # the target is the running integral of the input
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets


def train_data():
  for _ in range(100):
    yield build_inputs_and_targets(batch_size=num_batch)

Then, we create and initialize the model. Note that we want the model to train its initial state, so we set train_state=True for the RNNCell instance.

class RNN(bp.dyn.DynamicalSystem):
  def __init__(self, num_in, num_hidden):
    super(RNN, self).__init__()
    self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
    self.out = bp.layers.Dense(num_hidden, 1)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    return self.out(sha, self.rnn(sha, x))


model = RNN(1, 100)

The brainpy.train.BPTT trainer receives a loss function setting and an optimizer setting. The loss function can be selected from the brainpy.losses module, or it can be any callable that receives (predictions, targets) arguments. The optimizer setting must be an instance of brainpy.optim.Optimizer.

Here we define a loss function that uses the mean squared error (MSE) to measure the error between the targets and the predictions. We also apply an L2 regularization.

# define loss function
def loss(predictions, targets, l2_reg=2e-4):
    mse = bp.losses.mean_squared_error(predictions, targets)
    l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
    return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
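# Note: assuming the standard exponential schedule lr_t = lr * decay_rate ** (t / decay_steps),
# the learning rate above decays from 0.025 to about 0.0195 after 1000 steps
# and to about 0.0118 after 3000 steps.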
# create a trainer
trainer = bp.train.BPTT(model,
                        loss_fun=loss,
                        optimizer=opt)
# train the model
trainer.fit(train_data,
            batch_size=num_batch,
            num_epoch=30,
            num_report=500)
Train 500 steps, use 6.2379 s, train loss 0.02201
Train 1000 steps, use 5.1785 s, train loss 0.02029
Train 1500 steps, use 4.8608 s, train loss 0.01913
Train 2000 steps, use 4.7570 s, train loss 0.01809
Train 2500 steps, use 4.7750 s, train loss 0.0172
Train 3000 steps, use 4.7693 s, train loss 0.01643

The training losses are recorded in the .train_losses attribute.

plt.figure(figsize=(8, 3))
plt.plot(trainer.train_losses.numpy())
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()

Finally, let’s run the trained network and test whether it generates the correct integration results.

model.reset_state(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()

Training a spiking neural network#

BrainPy also supports training spiking neural networks.

In the following, we demonstrate how to use back-propagation algorithms to train spiking neurons with a simple example.

Our model is a simple three-layer network:

  • an input layer

  • a LIF layer

  • a readout layer

The synaptic connections between layers use the Exponential synapse model.
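An exponential synapse low-pass filters presynaptic spikes: the synaptic conductance \(g\) decays as \(\tau \, dg/dt = -g\) and jumps by the synaptic weight whenever a presynaptic spike arrives. A minimal sketch of this update rule (illustrative only, not BrainPy's implementation):

def exp_syn_step(g, pre_spike, g_max, tau, dt):
  # Euler step of the exponential decay dg/dt = -g / tau
  g = g * (1 - dt / tau)
  # increment the conductance by the weight on each presynaptic spike
  return g + g_max * pre_spike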

class SNN(bp.dyn.Network):
  def __init__(self, num_in, num_rec, num_out):
    super(SNN, self).__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out

    # neuron groups
    self.i = bp.neurons.InputGroup(num_in, mode=bp.modes.training)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.,
                            mode=bp.modes.training) # note here the "mode" should be "training"
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, mode=bp.modes.training)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=20.),
                                       mode=bp.modes.training)
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(),
                                       output=bp.synouts.CUBA(), tau=10.,
                                       g_max=bp.init.KaimingNormal(scale=20.),
                                       mode=bp.modes.training)

  def update(self, tdi, spike):
    self.i2r(tdi, spike)
    self.r2o(tdi)
    self.r(tdi)
    self.o(tdi)
    return self.o.V.value
net = SNN(100, 10, 2)  # our task is a two-label classification task

We use this simple task to classify random spiking data into two classes.

num_step = 2000
num_sample = 256
freq = 5  # spike rate, in Hz
mask = bm.random.rand(num_sample, num_step, net.num_in)
x_data = bm.zeros((num_sample, num_step, net.num_in))
# convert the rate in Hz into a per-step spike probability
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
# random binary class labels
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.dftype())
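Assuming the default simulation step of 0.1 ms, the expression freq * bm.get_dt() / 1000 converts the 5 Hz rate into a per-step spike probability of 0.0005, so each input neuron emits on average about one spike per 2000-step trial:

p_spike = freq * bm.get_dt() / 1000.  # per-step spike probability
print(p_spike * num_step)  # expected spikes per neuron per trial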

As with the artificial recurrent network above, we use the Adam optimizer and a cross-entropy loss to train the model.

opt = bp.optim.Adam(lr=2e-3)

def loss(predicts, targets):
  # use the maximum readout voltage over time as the classification logits
  return bp.losses.cross_entropy_loss(bm.max(predicts, axis=1), targets)


trainer = bp.train.BPTT(net,
                        loss_fun=loss,
                        optimizer=opt)
trainer.fit([x_data, y_data],
            batch_size=num_sample,
            num_epoch=1500)
Train 100 steps, use 34.1163 s, train loss 0.61404
Train 200 steps, use 33.5027 s, train loss 0.51463
Train 300 steps, use 33.4973 s, train loss 0.38637
Train 400 steps, use 33.3690 s, train loss 0.30086
Train 500 steps, use 33.8671 s, train loss 0.23846
Train 600 steps, use 33.8336 s, train loss 0.18554
Train 700 steps, use 34.2268 s, train loss 0.15962
Train 800 steps, use 35.0706 s, train loss 0.11911
Train 900 steps, use 34.7535 s, train loss 0.09325
Train 1000 steps, use 33.9460 s, train loss 0.0732
Train 1100 steps, use 33.9581 s, train loss 0.06083
Train 1200 steps, use 33.7295 s, train loss 0.04783
Train 1300 steps, use 33.9351 s, train loss 0.04094
Train 1400 steps, use 33.6706 s, train loss 0.03436
Train 1500 steps, use 33.6868 s, train loss 0.0283

The training loss decreases continuously, demonstrating that the network is training effectively.

# visualize the training losses
plt.plot(trainer.train_losses)
plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.show()

Let’s visualize the trained spiking neurons.

import numpy as np
from matplotlib.gridspec import GridSpec

def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
  plt.figure(figsize=(15, 8))
  gs = GridSpec(*dim)
  mem = 1. * mem
  if spk is not None:
    mem[spk > 0.0] = spike_height
  mem = bm.as_numpy(mem)
  for i in range(np.prod(dim)):
    if i == 0:
      a0 = ax = plt.subplot(gs[i])
    else:
      ax = plt.subplot(gs[i], sharey=a0)
    ax.plot(mem[i])
  plt.tight_layout()
  plt.show()
# get the prediction results and neural activity

runner = bp.dyn.DSRunner(
    net,
    monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}
)
out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
# the prediction accuracy

m = bm.max(out, axis=1)  # max over time
am = bm.argmax(m, axis=1)  # argmax over output units
acc = bm.mean(y_data == am)  # compare to labels
print("Accuracy %.3f" % acc)
Accuracy 0.992