brainpy.nn.runners.OfflineTrainer
- class brainpy.nn.runners.OfflineTrainer(target, fit_method=None, **kwargs)
Offline trainer for models with recurrent dynamics.
- Parameters
target (Node) – The target model to train.
fit_method (OfflineAlgorithm, Callable, dict, str) –
The fitting method applied to the target model. It can be given in one of four forms:
- A string that specifies the shortcut name of the training algorithm. For example, fit_method='ridge' selects the Ridge regression method. All supported fitting methods can be obtained through brainpy.nn.runners.get_supported_offline_methods().
- A dict whose “name” item specifies the name of the training algorithm and whose other items specify the initialization parameters of that algorithm. For example, fit_method={'name': 'ridge', 'beta': 1e-4}.
- An instance of brainpy.nn.runners.OfflineAlgorithm. For example, fit_method=bp.nn.runners.RidgeRegression(beta=1e-5).
- A callable that receives three arguments, “targets”, “x” and “y”. For example, fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0].
**kwargs – Other general parameters used to initialize the RNN runner.
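As a sketch of the callable form of fit_method, the function below implements closed-form ridge regression with the documented (targets, x, y) signature. It is plain NumPy written under stated assumptions, not code taken from BrainPy: ridge_fit is a hypothetical name, x is assumed to be a 2-D (num_samples, num_features) array, targets a (num_samples, num_outputs) array, and y is accepted only to match the signature and is ignored.

```python
import numpy as np


def ridge_fit(targets, x, y, beta=1e-4):
    """Hypothetical ridge-regression fit function (not part of BrainPy).

    Assumes x has shape (num_samples, num_features) and targets has shape
    (num_samples, num_outputs). The argument y is ignored; it is accepted
    only so the function matches the (targets, x, y) calling convention.
    """
    x = np.asarray(x, dtype=float)
    targets = np.asarray(targets, dtype=float)
    n_features = x.shape[1]
    # Solve the regularized normal equations (X^T X + beta * I) W = X^T T
    # directly instead of forming an explicit matrix inverse.
    gram = x.T @ x + beta * np.eye(n_features)
    return np.linalg.solve(gram, x.T @ targets)
```

Passed to the trainer, this would look like fit_method=lambda targets, x, y: ridge_fit(targets, x, y, beta=1e-4). With beta close to zero and a well-conditioned x, the result agrees with the numpy.linalg.lstsq example above. The array shapes BrainPy actually passes may differ between versions, so check them before relying on this sketch.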
Methods
- __init__(target[, fit_method])
- f_train([shared_kwargs]) – Get the training function.
- fit(train_data[, test_data, reset, ...]) – Fit the target model according to the given training and testing data.
- predict(xs[, forced_states, ...]) – Run the target model on a series of input data and return its predictions.
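To illustrate what "offline" training means for a model with recurrent dynamics, the sketch below mimics the fit/predict workflow in plain NumPy. A fixed random recurrent layer is driven by the whole input sequence, its states are collected, and a linear readout is then solved in a single ridge-regression step. TinyOfflineTrainer and all of its parameters are hypothetical. It is an echo-state-style analogue, not BrainPy's implementation, and it ignores options such as test_data, reset and forced_states.

```python
import numpy as np


class TinyOfflineTrainer:
    """Hypothetical NumPy stand-in for an offline fit/predict workflow (not BrainPy code)."""

    def __init__(self, num_in, num_rec, beta=1e-6, seed=0):
        rng = np.random.default_rng(seed)
        self.w_in = rng.uniform(-0.5, 0.5, size=(num_in, num_rec))
        w_rec = rng.normal(size=(num_rec, num_rec))
        # Rescale the recurrent weights to a spectral radius of 0.9,
        # a common heuristic for keeping echo-state dynamics stable.
        w_rec *= 0.9 / np.max(np.abs(np.linalg.eigvals(w_rec)))
        self.w_rec = w_rec
        self.beta = beta
        self.w_out = None

    def _run(self, xs):
        # Drive the fixed recurrent layer with the input sequence and record every state.
        h = np.zeros(self.w_rec.shape[0])
        states = []
        for x in xs:
            h = np.tanh(x @ self.w_in + h @ self.w_rec)
            states.append(h)
        return np.stack(states)

    def fit(self, xs, ys):
        # Offline: collect all states first, then solve for the readout in one step.
        states = self._run(xs)
        gram = states.T @ states + self.beta * np.eye(states.shape[1])
        self.w_out = np.linalg.solve(gram, states.T @ ys)
        return self

    def predict(self, xs):
        # Re-run the recurrent layer from a zero state and apply the fitted readout.
        return self._run(xs) @ self.w_out


# Example: learn a one-step-ahead prediction of a sine wave.
t = np.arange(301)
u = np.sin(0.2 * t)
xs, ys = u[:-1, None], u[1:, None]
trainer = TinyOfflineTrainer(num_in=1, num_rec=50).fit(xs, ys)
train_mse = np.mean((trainer.predict(xs) - ys) ** 2)
```

The key design point is that only the readout is learned, and it is learned in one closed-form solve over all collected states rather than by step-by-step gradient updates. This is what separates offline training from online or backpropagation-based training.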
Attributes