brainpy.nn.runners.OnlineTrainer
- class brainpy.nn.runners.OnlineTrainer(target, fit_method=None, **kwargs)
Online trainer for models with recurrent dynamics.
- Parameters
target (Node) – The target model to train.
fit_method (OnlineAlgorithm, Callable, dict, str) –
The fitting method applied to the target model. It can take one of four forms:
- A string that specifies the shortcut name of the training algorithm. For example, fit_method='ridge' means using the RLS method. All supported fitting methods can be obtained through brainpy.nn.runners.get_supported_online_methods().
- A dict whose "name" item specifies the name of the training algorithm and whose other items specify the initialization parameters of the algorithm. For example, fit_method={'name': 'ridge', 'beta': 1e-4}.
- An instance of brainpy.nn.runners.OnlineAlgorithm. For example, fit_method=bp.nn.runners.RLS(alpha=1e-5).
- A callable function.
**kwargs – The other general parameters for RNN running initialization.
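The four accepted forms of fit_method can be pictured as a small normalization step. The following is only a sketch of that idea, not BrainPy's actual implementation: the classes OnlineAlgorithm and RLS are empty stand-ins defined here, and SUPPORTED and resolve_fit_method are hypothetical names introduced for illustration. The 'ridge' key simply mirrors the shortcut used in the parameter description above.

```python
# Hypothetical stand-ins: none of these definitions are BrainPy internals.
class OnlineAlgorithm:
    """Minimal placeholder for an online training algorithm."""

    def __init__(self, **params):
        self.params = params


class RLS(OnlineAlgorithm):
    """Placeholder for a recursive-least-squares algorithm."""


# Illustrative registry mapping shortcut names to algorithm classes.
SUPPORTED = {'ridge': RLS}


def resolve_fit_method(fit_method):
    """Turn any of the four documented forms into one usable object.

    How the real trainer handles the default ``None`` is not modeled here.
    """
    if isinstance(fit_method, str):
        # Shortcut name: build the algorithm with its default parameters.
        return SUPPORTED[fit_method]()
    if isinstance(fit_method, dict):
        params = dict(fit_method)  # copy so the caller's dict is untouched
        name = params.pop('name')
        # Remaining items become initialization parameters.
        return SUPPORTED[name](**params)
    if isinstance(fit_method, OnlineAlgorithm):
        # Already an algorithm instance: use it unchanged.
        return fit_method
    if callable(fit_method):
        # A plain function is passed through as-is.
        return fit_method
    raise TypeError(f'Unsupported fit_method: {fit_method!r}')
```

In this sketch the instance check comes before the callable check, so an algorithm object that also happens to define __call__ is still treated as an algorithm.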
Methods
__init__(target[, fit_method])
fit(train_data[, test_data, reset, ...])
predict(xs[, forced_states, ...]) – Predict a series of input data with the given target model.
Attributes