Training a Brain Dynamics Model#

@Chaoming Wang

In recent years, we have seen that training dynamical systems from data or tasks provides important insights for understanding brain function. To support this, BrainPy provides various interfaces to help users train dynamical systems.

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

bm.enable_x64()

bm.set_platform('cpu')
bp.__version__
'2.3.8'
import matplotlib.pyplot as plt

Training a reservoir network model#

For an echo state network, we have three components: an input node (“I”), a reservoir node (“R”) for dimension expansion, and an output node (“O”) for linear readout.
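
In a traditional ESN, only the readout is trained: the reservoir is a fixed random recurrent network that performs a leaky nonlinear state update. A minimal sketch of that update with plain brainpy.math arrays (the weight scales, leak rate, and sizes here are illustrative assumptions, not BrainPy internals):

num_in, num_hidden = 3, 100
leak = 0.3  # leak rate (assumed value)

W_in = bm.random.normal(0., 0.1, (num_in, num_hidden))       # fixed input weights
W_rec = bm.random.normal(0., 0.1, (num_hidden, num_hidden))  # fixed recurrent weights
r = bm.zeros(num_hidden)                                     # reservoir state

def reservoir_step(r, x):
  # leaky nonlinear state update; only a linear readout of r is trained
  return (1 - leak) * r + leak * bm.tanh(x @ W_in + r @ W_rec)

r = reservoir_step(r, bm.random.normal(0., 1., (num_in,)))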

Gauthier et al. (Nature Communications, 2021) proposed a next-generation reservoir computing (NG-RC) model that uses nonlinear vector autoregression (NVAR).

The difference between the two models is illustrated in the following figure.

(A) A traditional RC processes time-series data using an artificial recurrent neural network. (B) The NG-RC performs a forecast using linear weights of time-delayed states of the time series and nonlinear functionals of these states.
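
To make “time-delay states plus nonlinear functionals” concrete, here is a rough NumPy sketch of how an NVAR feature vector could be assembled for a single time point (the function and indexing are illustrative; the bp.layers.NVAR layer used below builds these features for you):

import numpy as np

def nvar_features(series, t, num_delay=4, stride=5):
  # linear part: the current state plus delayed copies of it
  lin = np.concatenate([series[t - i * stride] for i in range(num_delay)])
  # nonlinear part: all unique quadratic monomials of the linear features
  rows, cols = np.triu_indices(lin.size)
  quad = np.outer(lin, lin)[rows, cols]
  return np.concatenate([lin, quad])

series = np.random.randn(1000, 3)  # a toy 3-dimensional time series
feats = nvar_features(series, t=500)
feats.shape  # (90,): 4 delays x 3 dims = 12 linear terms + 78 quadratic terms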

Here, let’s implement a next-generation reservoir model to predict a chaotic time series, the Lorenz attractor. In particular, we expect the network to predict \(P(t+l)\) from \(P(t)\), where \(l\) is the prediction lead time.

dt = 0.01
data = bd.chaos.LorenzEq(100, dt=dt)
plt.figure(figsize=(10, 5))
plt.subplot(311)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.xs.flatten()))
plt.ylabel('x')
plt.subplot(312)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.ys.flatten()))
plt.ylabel('y')
plt.subplot(313)
plt.plot(bm.as_numpy(data.ts), bm.as_numpy(data.zs.flatten()))
plt.ylabel('z')
plt.show()
[figure: x, y and z time series of the Lorenz system]

Let’s first create a function to get the data.

def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    return res.reshape((1, ) + res.shape)
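
For example, taking the first 100 points gives a batched array of shape (1, 100, 3): one batch, 100 time steps, and the three variables x, y and z.

get_subset(data, 0, 100).shape  # (1, 100, 3)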

To accomplish this task, we implement a next-generation reservoir model that uses 4 delayed history steps with a stride of 5, together with their quadratic polynomial monomials, the same as in Gauthier et al. (Nature Communications, 2021).

class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
    self.o = bp.layers.Dense(self.r.num_out, num_out, mode=bm.training_mode)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    # other arguments like "x" can be customized by users.
    return self.o(sha, self.r(sha, x))
with bm.environment(bm.batching_mode):
    model = NGRC(num_in=3, num_out=3)

Moreover, we use the ridge regression method to train the model.

trainer = bp.RidgeTrainer(model, alpha=1e-6)
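
For reference, ridge regression fits the linear readout in closed form rather than by gradient descent: given features \(X\) and targets \(Y\), it solves \(W = (X^\top X + \alpha I)^{-1} X^\top Y\). A NumPy sketch of that computation (shapes and names are illustrative stand-ins, not the trainer’s internals):

import numpy as np

num_time, num_feature, num_out = 1000, 90, 3
X = np.random.randn(num_time, num_feature)  # stand-in reservoir features
Y = np.random.randn(num_time, num_out)      # stand-in training targets
alpha = 1e-6                                # same regularization strength as above

W = np.linalg.solve(X.T @ X + alpha * np.eye(num_feature), X.T @ Y)
prediction = X @ W  # the trained linear readout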

We warm up the network with the first 20 ms of data.

warmup_data = get_subset(data, 0, int(20/dt))

outs = trainer.predict(warmup_data)

# outputs should be an array with the shape of
# (num_batch, num_time, num_out)
outs.shape
(1, 2000, 3)

The training data is the time series from 20 ms to 80 ms. We want the network to forecast 1 time step ahead.

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+1, int(80/dt)+1)

_ = trainer.fit([x_train, y_train])

Then we test the trained network with the next 20 ms.

x_test = get_subset(data, int(80/dt), int(100/dt)-1)
y_test = get_subset(data, int(80/dt) + 1, int(100/dt))

predictions = trainer.predict(x_test)

bp.losses.mean_squared_error(y_test, predictions)
Array(5.36414848e-10, dtype=float64)
def plot_difference(truths, predictions):
    truths = bm.as_numpy(truths)
    predictions = bm.as_numpy(predictions)

    plt.subplot(311)
    plt.plot(truths[0, :, 0], label='Ground Truth')
    plt.plot(predictions[0, :, 0], label='Prediction')
    plt.ylabel('x')
    plt.legend()
    plt.subplot(312)
    plt.plot(truths[0, :, 1], label='Ground Truth')
    plt.plot(predictions[0, :, 1], label='Prediction')
    plt.ylabel('y')
    plt.legend()
    plt.subplot(313)
    plt.plot(truths[0, :, 2], label='Ground Truth')
    plt.plot(predictions[0, :, 2], label='Prediction')
    plt.ylabel('z')
    plt.legend()
    plt.show()
plot_difference(y_test, predictions)
[figure: ground truth vs. prediction for x, y and z, 1-step-ahead forecast]

We can make the task harder by forecasting 10 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
outs = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+10, int(80/dt)+10)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-10)
y_test = get_subset(data, int(80/dt) + 10, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)
[figure: ground truth vs. prediction for x, y and z, 10-step-ahead forecast]

Or forecast 100 time steps ahead.

warmup_data = get_subset(data, 0, int(20/dt))
_ = trainer.predict(warmup_data)

x_train = get_subset(data, int(20/dt), int(80/dt))
y_train = get_subset(data, int(20/dt)+100, int(80/dt)+100)
trainer.fit([x_train, y_train])

x_test = get_subset(data, int(80/dt), int(100/dt)-100)
y_test = get_subset(data, int(80/dt) + 100, int(100/dt))
predictions = trainer.predict(x_test)

plot_difference(y_test, predictions)
[figure: ground truth vs. prediction for x, y and z, 100-step-ahead forecast]

As you can see, forecasting over a larger lead time makes learning more difficult.

Training an artificial recurrent network#

In recent years, artificial recurrent neural networks trained with back-propagation through time (BPTT) have become a useful tool for studying the network mechanisms of brain functions. To support training networks with BPTT, BrainPy provides the brainpy.train.BPTT interface.

Here, we demonstrate how to train an artificial recurrent neural network on a white noise integration task. In this task, we want the trained RNN model to integrate white noise. For example, suppose we have a time series of noise data,

noises = bm.random.normal(0, 0.2, size=10)

plt.figure(figsize=(8, 2))
plt.plot(noises.to_numpy().flatten())
plt.show()
[figure: white noise time series]

Now, we want a model that can compute the integral of the noise, bm.cumsum(noises) * dt:

dt = 0.1
integrals = bm.cumsum(noises) * dt

plt.figure(figsize=(8, 2))
plt.plot(integrals.to_numpy().flatten())
plt.show()
[figure: integral of the white noise]

Here, we first define a task which generates the input data and the target integration results.

from functools import partial

dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128


@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def build_inputs_and_targets(mean=0.025, scale=0.01):
  # Create the white noise input
  sample = bm.random.normal(size=(num_batch, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  samples = bm.random.normal(size=(num_batch, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets


def train_data():
  for _ in range(100):
    yield build_inputs_and_targets()
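
As a quick sanity check of the generator (with dt = 0.04 above, num_step = 25):

xs, ys = build_inputs_and_targets()
xs.shape, ys.shape  # ((128, 25, 1), (128, 25, 1))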

Then, we create and initialize the model. Note that we want the model to train its initial state, so we set train_state=True for the RNNCell instance.

class RNN(bp.DynamicalSystem):
  def __init__(self, num_in, num_hidden):
    super(RNN, self).__init__()
    self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
    self.out = bp.layers.Dense(num_hidden, 1)

  def update(self, sha, x):
    # "sha" is the arguments shared across all nodes.
    return self.out(sha, self.rnn(sha, x))


with bm.training_environment():
    model = RNN(1, 100)

The brainpy.train.BPTT trainer receives a loss function setting and an optimizer setting. The loss function can be selected from the brainpy.losses module, or it can be a callable function that receives the (predictions, targets) arguments. The optimizer setting must be an instance of brainpy.optim.Optimizer.

Here we define a loss function that uses the mean squared error (MSE) to measure the error between the targets and the predictions. We also apply an L2 regularization.

# define loss function
def loss(predictions, targets, l2_reg=2e-4):
    mse = bp.losses.mean_squared_error(predictions, targets)
    l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
    return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.BPTT(model,
                 loss_fun=loss,
                 optimizer=opt)
# train the model
trainer.fit(train_data, num_epoch=30)
Train 0 epoch, use 2.2865 s, loss 0.3474464803593203
Train 1 epoch, use 0.9384 s, loss 0.026605883508514516
Train 2 epoch, use 0.9478 s, loss 0.021708405535614907
Train 3 epoch, use 0.9892 s, loss 0.02143528795935897
Train 4 epoch, use 0.9300 s, loss 0.02107475878707875
Train 5 epoch, use 0.9298 s, loss 0.020932997073748006
Train 6 epoch, use 0.9566 s, loss 0.020855205349191275
Train 7 epoch, use 0.9072 s, loss 0.020789264013002805
Train 8 epoch, use 0.9631 s, loss 0.02066686516861348
Train 9 epoch, use 0.8968 s, loss 0.020621776997250745
Train 10 epoch, use 0.9438 s, loss 0.020569378459587808
Train 11 epoch, use 0.8785 s, loss 0.02049849876797083
Train 12 epoch, use 0.8515 s, loss 0.02047079743844964
Train 13 epoch, use 0.8533 s, loss 0.02039058010677752
Train 14 epoch, use 0.9209 s, loss 0.02035540442302181
Train 15 epoch, use 0.9844 s, loss 0.0203037559193207
Train 16 epoch, use 1.0006 s, loss 0.02025348558545429
Train 17 epoch, use 0.9397 s, loss 0.020213293431676486
Train 18 epoch, use 0.8990 s, loss 0.0201581882182178
Train 19 epoch, use 0.8812 s, loss 0.020140154456911596
Train 20 epoch, use 0.8976 s, loss 0.020074232282055494
Train 21 epoch, use 0.8953 s, loss 0.020042431126861646
Train 22 epoch, use 0.8584 s, loss 0.020006458781740344
Train 23 epoch, use 0.9393 s, loss 0.019969488637816196
Train 24 epoch, use 0.9012 s, loss 0.019924267557334032
Train 25 epoch, use 0.8897 s, loss 0.019882438760702597
Train 26 epoch, use 0.9212 s, loss 0.019851161733026937
Train 27 epoch, use 0.8935 s, loss 0.019828587850703315
Train 28 epoch, use 0.9329 s, loss 0.01978077887574506
Train 29 epoch, use 0.8505 s, loss 0.019749579454078393

The training losses can be retrieved with the .get_hist_metric() function.

plt.figure(figsize=(8, 3))
plt.plot(trainer.get_hist_metric(metric='loss'))
plt.xlabel('Number of Training Step')
plt.ylabel('Training Loss')
plt.show()
[figure: training loss curve]

Finally, let’s test whether the trained network can generate the correct integration results.

model.reset_state(num_batch)
x, y = build_inputs_and_targets()
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()
[figure: ground truth vs. predicted integral]

Training a spiking neural network#

BrainPy also supports training spiking neural networks.

In the following, we demonstrate how to use back-propagation algorithms to train spiking neurons on a simple example.
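
One subtlety: spike generation is a hard threshold and therefore non-differentiable, so gradient-based training of spiking networks typically relies on surrogate gradients, where the forward pass keeps the binary spike while the backward pass substitutes a smooth derivative. BrainPy’s trainable neuron models handle this internally; the sketch below only illustrates the trick in plain JAX, with an assumed surrogate shape:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
  # forward pass: hard threshold at 0
  return (v >= 0.).astype(jnp.float32)

def _spike_fwd(v):
  return spike(v), v  # save v for the backward pass

def _spike_bwd(v, grad_out):
  # backward pass: replace the zero-almost-everywhere true derivative
  # with a smooth surrogate peaked at the threshold
  surrogate = 1. / (1. + 10. * jnp.abs(v)) ** 2
  return (grad_out * surrogate,)

spike.defvjp(_spike_fwd, _spike_bwd)

# gradients now flow through the threshold
jax.grad(lambda v: spike(v).sum())(jnp.array([-0.5, 0.2]))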

Our model is a simple three-layer model:

  • an input layer

  • a LIF layer

  • a readout layer

The synaptic connections between layers use the Exponential synapse model.

class SNN(bp.Network):
  def __init__(self, num_in, num_rec, num_out):
    super(SNN, self).__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out

    # neuron groups
    self.i = bp.neurons.InputGroup(num_in)
    self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)

    # synapse: i->r
    self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), tau=10.,
                                       output=bp.synouts.CUBA(target_var=None),
                                       g_max=bp.init.KaimingNormal(scale=20.))
    # synapse: r->o
    self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), tau=10.,
                                       output=bp.synouts.CUBA(target_var=None),
                                       g_max=bp.init.KaimingNormal(scale=20.))

    # whole model
    self.model = bp.Sequential(self.i, self.i2r, self.r, self.r2o, self.o)

  def update(self, tdi, spike):
    self.model(tdi, spike)
    return self.o.V.value
with bm.training_environment():
    net = SNN(100, 10, 2)  # our task is a two-label classification task

We use this simple task to classify random spiking data into two classes.

num_step = 2000
num_sample = 256
freq = 5  # Hz
mask = bm.random.rand(num_sample, num_step, net.num_in)
x_data = bm.zeros((num_sample, num_step, net.num_in))
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)

def get_data():
    for _ in range(1):
        yield x_data, y_data
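
Each input bin spikes with probability freq * bm.get_dt() / 1000, so the empirical firing rate of x_data should come out near the requested 5 Hz:

# spikes per bin, converted to a rate in Hz
x_data.mean() * 1000. / bm.get_dt()  # ≈ 5.0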

As with the training of artificial recurrent neural networks, we use the Adam optimizer and a cross-entropy loss to train the model.

opt = bp.optim.Adam(lr=2e-3)

def loss(predicts, targets):
  return bp.losses.cross_entropy_loss(bm.max(predicts, axis=1), targets)


trainer = bp.BPTT(net,
                  loss_fun=loss,
                  optimizer=opt)
trainer.fit(train_data=get_data,
            num_report=10,
            num_epoch=200)
Train 10 steps, use 0.5471 s, loss 0.8119578839372008
Train 20 steps, use 0.5587 s, loss 0.7544552171954115
Train 30 steps, use 0.5898 s, loss 0.7020412181865148
Train 40 steps, use 0.5782 s, loss 0.6750034517542842
Train 50 steps, use 0.5996 s, loss 0.6638762207402613
Train 60 steps, use 0.5469 s, loss 0.6457876645959164
Train 70 steps, use 0.5286 s, loss 0.625173858852444
Train 80 steps, use 0.5480 s, loss 0.6108308776234428
Train 90 steps, use 0.5672 s, loss 0.5928884702255813
Train 100 steps, use 0.5597 s, loss 0.5864131114195933
Train 110 steps, use 0.5726 s, loss 0.5829393663656248
Train 120 steps, use 0.5730 s, loss 0.5648404392129361
Train 130 steps, use 0.5711 s, loss 0.5471623937603349
Train 140 steps, use 0.5665 s, loss 0.5335702569949301
Train 150 steps, use 0.5618 s, loss 0.5198388588329248
Train 160 steps, use 0.5529 s, loss 0.5154762874039034
Train 170 steps, use 0.5734 s, loss 0.49874691621484557
Train 180 steps, use 0.5523 s, loss 0.4925224810657929
Train 190 steps, use 0.5375 s, loss 0.48411000993072006
Train 200 steps, use 0.5524 s, loss 0.46809022931937005

The training loss decreases continuously, demonstrating that the network is learning effectively.

# visualize the training losses
plt.plot(trainer.get_hist_metric())
plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.show()
[figure: training loss curve]

Let’s visualize the trained spiking neurons.

import numpy as np
from matplotlib.gridspec import GridSpec

def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
  plt.figure(figsize=(15, 8))
  gs = GridSpec(*dim)
  mem = 1. * mem  # copy so that the monitored data is not modified in place
  if spk is not None:
    mem[spk > 0.0] = spike_height  # overlay spikes on top of the voltage traces
  mem = bm.as_numpy(mem)
  for i in range(np.prod(dim)):
    if i == 0:
      a0 = ax = plt.subplot(gs[i])
    else:
      ax = plt.subplot(gs[i], sharey=a0)
    ax.plot(mem[i])
  plt.tight_layout()
  plt.show()
# get the prediction results and neural activity

runner = bp.DSRunner(
    net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}
)
out = runner.run(inputs=x_data, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
[figure: membrane potential traces of the trained LIF neurons]
# the prediction accuracy

m = bm.max(out, axis=1)  # max over time
am = bm.argmax(m, axis=1)  # argmax over output units
acc = bm.mean(y_data == am)  # compare to labels
print("Accuracy %.3f" % acc)
Accuracy 0.672