Training with Offline Algorithms#
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
import matplotlib.pyplot as plt
# Configure the global BrainPy environment: turn on the x64 option and make
# batching mode the default computing mode for everything created below.
bm.set_environment(x64=True, mode=bm.batching_mode)
BrainPy provides many offline training algorithms that can help users train models such as reservoir computing models.
Train a reservoir model#
Here, we train an echo-state machine to predict chaotic dynamics. This example illustrates how to use brainpy.train.OfflineTrainer.
We first get the training dataset.
def get_subset(data, start, end):
    """Cut the window ``[start, end)`` out of a dataset and batch it.

    The ``xs``, ``ys`` and ``zs`` series of ``data`` are sliced over the same
    window, joined with ``bm.hstack`` and given a leading axis of length one.
    """
    window = slice(start, end)
    stacked = bm.hstack([data.xs[window], data.ys[window], data.zs[window]])
    # Training data must have batch size, here the batch is 1
    return stacked.reshape((1,) + stacked.shape)
# Integration step and the duration of each phase of the experiment.
dt = 0.01
t_warmup, t_train, t_test = 5., 100., 50.  # ms
num_warmup, num_train, num_test = int(t_warmup/dt), int(t_train/dt), int(t_test/dt)
# Number of time steps by which the targets lead the inputs.
num_lead = 5

# Generate the series with the same step that the sample counts above were
# derived from, so the index arithmetic below lines up with the data.
lorenz_series = bd.chaos.LorenzEq(
    t_warmup + t_train + t_test,
    dt=dt,
    inits={'x': 17.67715816276679,
           'y': 12.931379185960404,
           'z': 43.91404334248268},
)

# Inputs: every window is shifted back by `num_lead` steps.
X_warmup = get_subset(lorenz_series, 0, num_warmup - num_lead)
X_train = get_subset(lorenz_series,
                     num_warmup - num_lead,
                     num_warmup + num_train - num_lead)
X_test = get_subset(lorenz_series,
                    num_warmup + num_train - num_lead,
                    num_warmup + num_train + num_test - num_lead)

# our target data is the activity ahead of `num_lead` time steps
Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
Y_test = get_subset(lorenz_series,
                    num_warmup + num_train,
                    num_warmup + num_train + num_test)
Then, we try to build an echo-state machine to predict the chaotic dynamics ahead of five time steps.
class ESN(bp.DynamicalSystemNS):
    """Echo-state machine: a reservoir layer followed by a dense layer.

    Args:
        num_in: size of the input fed to the reservoir.
        num_hidden: size of the reservoir, and of the dense layer's input.
        num_out: size of the dense layer's output.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        # NOTE(review): the original Reservoir(...) call was cut off after
        # `Win_initializer`; any further keyword arguments were lost and are
        # not reconstructed here -- restore them from the upstream tutorial.
        self.r = bp.layers.Reservoir(
            num_in, num_hidden,
            Win_initializer=bp.init.Uniform(-0.1, 0.1),
        )
        # NOTE(review): the original Dense(...) call was cut off as well.
        # The global default mode is batching mode, so the readout is put in
        # training mode explicitly; presumably this is what lets the offline
        # trainer fit it -- TODO confirm against the BrainPy version in use.
        self.o = bp.layers.Dense(
            num_hidden, num_out,
            W_initializer=bp.init.Normal(),
            mode=bm.training_mode,
        )

    def update(self, x):
        """Pipe ``x`` through the reservoir and then the dense layer."""
        return x >> self.r >> self.o
# Three inputs, a hidden size of 100, three outputs.
model = ESN(num_in=3, num_hidden=100, num_out=3)
Here, we use ridge regression as the training algorithm to train the chaotic model.
trainer = bp.OfflineTrainer(model, fit_method=bp.algorithms.RidgeRegression(1e-7), dt=dt)

# first warmup the reservoir
_ = trainer.predict(X_warmup)

# then fit the reservoir model on the (input, target) pair
_ = trainer.fit([X_train, Y_train])
def plot_lorenz(ground_truth, predictions):
    """Draw the predicted and the real trajectories as two 3D panels.

    The first three columns of each argument are used as the coordinates.
    The predictions go in the left panel, the ground truth in the right one.
    """
    fig = plt.figure(figsize=(15, 10))
    panels = (
        (121, "Generated attractor", predictions),
        (122, "Real attractor", ground_truth),
    )
    for position, title, series in panels:
        axis = fig.add_subplot(position, projection='3d')
        axis.set_title(title)
        axis.plot(series[:, 0], series[:, 1], series[:, 2])
# finally, run the trained model on the test data
outputs = trainer.predict(X_test)

# report the mean squared error against the test targets
test_mse = bp.losses.mean_squared_error(outputs, Y_test)
print('Prediction NMS: ', test_mse)

# compare the real and the generated trajectories visually
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS: 0.858903742335844

Switch different training algorithms#
OfflineTrainer supports easy switching of training algorithms. You just need to provide the fit_method
argument when instantiating an offline trainer.
Many offline algorithms, like linear regression, ridge regression, and Lasso regression, have been provided as the built-in models.
# Start from a fresh model and swap in linear regression as the fit method.
model = ESN(3, 100, 3)
trainer = bp.OfflineTrainer(model, fit_method=bp.algorithms.LinearRegression())

# warm up, fit on the (input, target) pair, then predict on the test data
_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)

plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())

Customize your training algorithms#
OfflineTrainer also supports training models with your customized training algorithms.
Specifically, the customization of an offline algorithm should follow the interface of brainpy.algorithms.OfflineAlgorithm,
in which users specify how the model parameters are calculated according to the input, prediction, and target data.
For instance, here we use the Lasso model
provided in the scikit-learn package to define an offline training algorithm.
from sklearn.linear_model import Lasso
class LassoAlgorithm(bp.algorithms.OfflineAlgorithm):
    """Offline training algorithm backed by scikit-learn's ``Lasso`` model.

    Args:
        alpha: forwarded to ``Lasso`` as ``alpha``.
        max_iter: forwarded to ``Lasso`` as ``max_iter``.
    """

    def __init__(self, alpha=1., max_iter=int(1e4)):
        super().__init__()
        self.model = Lasso(alpha=alpha, max_iter=max_iter)

    def __call__(self, y, x, outs=None):
        """Fit the Lasso model on ``(x, y)`` and return its coefficients.

        Args:
            y: target data; only its first entry along the leading axis is used.
            x: input data; only its first entry along the leading axis is used.
            outs: unused here.

        Returns:
            The transposed ``coef_`` of the fitted model, converted with
            ``bm.asarray`` and with an extra axis inserted at position 1.
        """
        # presumably the leading axis is the batch axis (batch size 1 in this
        # tutorial) -- TODO confirm; scikit-learn needs plain NumPy arrays.
        x = bm.as_numpy(x[0])
        y = bm.as_numpy(y[0])
        # `fit` returns the estimator itself, so `coef_` can be read right away.
        x_new = self.model.fit(x, y).coef_.T
        return bm.expand_dims(bm.asarray(x_new), 1)
model = ESN(3, 100, 3)

# note here scikit-learn algorithms do not support JAX jit,
# therefore the "jit" of the "fit" phase is set to be False.
# NOTE(review): the original `fit_method` argument was lost from this call;
# the algorithm is built with its default arguments here -- tune `alpha` as
# needed.
trainer = bp.OfflineTrainer(model,
                            fit_method=LassoAlgorithm(),
                            jit={'fit': False})

# warm up, fit on the (input, target) pair, then predict on the test data
_ = trainer.predict(X_warmup)
_ = trainer.fit([X_train, Y_train])
outputs = trainer.predict(X_test)

plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())