brainpy.train.online.OnlineTrainer

class brainpy.train.online.OnlineTrainer(target, fit_method=None, **kwargs)

Online trainer for models with recurrent dynamics.
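Online training adjusts the trainable weights at every time step while the model runs, instead of solving for them once over a whole recorded sequence. The NumPy sketch below illustrates this with recursive least squares (RLS), one of the online fitting methods this trainer can use, applied to a linear readout of precomputed hidden states. It illustrates the technique only: the function rls_train and its delta parameter are hypothetical and do not reproduce BrainPy's RLS class or its parameterization.

```python
import numpy as np


def rls_train(states, targets, delta=0.1):
    """Fit a linear readout W (targets ~ states @ W) online with recursive least squares.

    states:  (T, N) array, one hidden state per time step
    targets: (T, M) array, desired output per time step
    delta:   regularization; P starts as I / delta (hypothetical name, not BrainPy's)
    """
    T, N = states.shape
    M = targets.shape[1]
    W = np.zeros((N, M))                  # readout weights, updated at every step
    P = np.eye(N) / delta                 # running inverse of (delta*I + sum of x x^T)
    for t in range(T):
        x = states[t][:, None]            # current state as an (N, 1) column
        Px = P @ x
        k = Px / (1.0 + x.T @ Px)         # (N, 1) gain vector
        err = targets[t] - (x.T @ W)[0]   # (M,) error of the current weights
        W += k @ err[None, :]             # move W along the gain direction
        P -= k @ Px.T                     # rank-one update of the inverse
    return W
```

Starting from zero weights and P = I/delta, each RLS step yields exactly the ridge-regression solution with penalty delta over the samples seen so far, so an online fit of a readout matches the offline ridge fit on the same data.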

Parameters
  • target (DynamicalSystem) – The target model to train.

  • fit_method (OnlineAlgorithm, Callable, dict, str) –

    The fitting method applied to the target model. It can take one of the following forms:

    • It can be a string that specifies the shortcut name of the training algorithm. For example, fit_method='rls' selects the recursive least squares (RLS) method. All supported fitting methods can be obtained through brainpy.nn.runners.get_supported_online_methods().

    • It can be a dict whose “name” item specifies the name of the training algorithm; the remaining items specify the initialization parameters of the algorithm. For example, fit_method={'name': 'rls', 'alpha': 1e-4}.

    • It can be an instance of brainpy.nn.runners.OnlineAlgorithm. For example, fit_method=bp.nn.runners.RLS(alpha=1e-5).

    • It can also be a callable function.

  • **kwargs – Other general parameters used to initialize the underlying RNN runner.
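The accepted forms of fit_method (a string, a dict, an algorithm instance, or a callable) can be normalized by a small dispatch function. The sketch below is hypothetical: OnlineAlgorithm, RLS and REGISTRY are minimal stand-ins rather than the BrainPy classes, and treating None as a default RLS instance is an assumption of this sketch. It checks for an algorithm instance before the generic callable case, because algorithm objects may themselves be callable.

```python
from typing import Callable, Dict, Optional, Union


class OnlineAlgorithm:
    """Hypothetical stand-in for an online fitting algorithm base class."""

    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class RLS(OnlineAlgorithm):
    """Hypothetical RLS placeholder that only stores its hyper-parameter."""

    def __init__(self, alpha: float = 0.1):
        self.alpha = alpha


# Hypothetical registry mapping shortcut names to algorithm classes.
REGISTRY: Dict[str, type] = {'rls': RLS}


def _lookup(name: str) -> type:
    if name not in REGISTRY:
        raise ValueError(f'Unknown fit method {name!r}; supported: {sorted(REGISTRY)}')
    return REGISTRY[name]


FitMethod = Optional[Union[str, dict, OnlineAlgorithm, Callable]]


def resolve_fit_method(fit_method: FitMethod):
    """Normalize the accepted fit_method forms into one callable object."""
    if fit_method is None:
        return RLS()                          # assumed default in this sketch
    if isinstance(fit_method, str):
        return _lookup(fit_method)()          # shortcut name -> default instance
    if isinstance(fit_method, dict):
        params = dict(fit_method)             # copy so the caller's dict is untouched
        name = params.pop('name')             # 'name' selects the algorithm
        return _lookup(name)(**params)        # remaining items are init arguments
    if isinstance(fit_method, OnlineAlgorithm):
        return fit_method                     # already configured, use as-is
    if callable(fit_method):
        return fit_method                     # user-supplied function
    raise ValueError(f'Unsupported fit_method: {fit_method!r}')
```

For example, resolve_fit_method({'name': 'rls', 'alpha': 1e-4}) returns an RLS stand-in whose alpha is 1e-4, while an already constructed instance is passed through unchanged.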

__init__(target, fit_method=None, **kwargs)

Methods

  • __init__(target[, fit_method])

  • build_monitors(return_without_idx, ...)

  • f_predict([shared_args])

  • fit(train_data[, reset_state, shared_args]) – Return type: TypeVar(Output).

  • format_monitors()

  • predict(inputs[, reset_state, shared_args, ...]) – Prediction function.

  • reset_state()

  • run(*args, **kwargs) – Predict a series of input data with the given target model.
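Because predict and run step a recurrent model through an input series, their result depends on the hidden state the model starts from, which is presumably what the reset_state argument of predict controls. The NumPy sketch below uses a hypothetical TinyReservoir class, not a BrainPy model, to show the difference: with reset_state=True the state is cleared before the series, otherwise it continues from wherever the previous call left it. The default reset_state=False is a choice of this sketch, not a statement about BrainPy's default.

```python
import numpy as np


class TinyReservoir:
    """Hypothetical recurrent model: h_t = tanh(h_{t-1} W + x_t U), y_t = h_t W_out."""

    def __init__(self, n_in, n_hidden, n_out, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.5 / np.sqrt(n_hidden), size=(n_hidden, n_hidden))
        self.U = rng.normal(size=(n_in, n_hidden))
        self.W_out = rng.normal(size=(n_hidden, n_out))
        self.h = np.zeros(n_hidden)          # hidden state carried across calls

    def reset_state(self):
        self.h = np.zeros_like(self.h)       # back to the initial (zero) state

    def predict(self, inputs, reset_state=False):
        """Step through inputs of shape (T, n_in); return outputs of shape (T, n_out)."""
        if reset_state:
            self.reset_state()
        outputs = []
        for x in inputs:                     # one recurrent update per time step
            self.h = np.tanh(self.h @ self.W + x @ self.U)
            outputs.append(self.h @ self.W_out)
        return np.stack(outputs)
```

Calling predict(xs, reset_state=True) twice on the same series gives identical outputs, whereas a plain predict(xs) after the first call starts from the leftover state and generally gives different ones.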

Attributes