OnlineTrainer


class brainpy.OnlineTrainer(target, fit_method=None, **kwargs)

Online trainer for models with recurrent dynamics.

For additional parameters, see DSRunner.

Parameters:
  • target (DynamicalSystem) – The target model to train.

  • fit_method (OnlineAlgorithm, Callable, dict, str) –

    The fitting method applied to the target model.

    • It can be a string specifying the shortcut name of the training algorithm. For example, fit_method='rls' uses the RLS method. All supported fitting methods can be obtained through get_supported_online_methods().

    • It can be a dict whose “name” item specifies the name of the training algorithm and whose other items are passed as initialization parameters of the algorithm. For example, fit_method={'name': 'rls', 'alpha': 0.1}.

    • It can be an instance of brainpy.algorithms.OnlineAlgorithm. For example, fit_method=bp.algorithms.RLS(alpha=1e-5).

    • It can also be a callable function.

  • kwargs (Any) – Other general parameters; see DSRunner.
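The four accepted forms of fit_method can be illustrated with a small pure-Python sketch of how a trainer might normalize them into a single callable. Everything in it is hypothetical: ToyRLS stands in for bp.algorithms.RLS, REGISTRY stands in for whatever backs get_supported_online_methods(), and resolve_fit_method is not BrainPy's actual code. What fit_method=None resolves to is not documented above, so the sketch does not handle that case.

```python
from typing import Callable, Dict, Union


class ToyRLS:
    """Hypothetical stand-in for an online algorithm such as bp.algorithms.RLS."""

    def __init__(self, alpha: float = 0.1):
        self.alpha = alpha

    def __call__(self, *args, **kwargs):
        # Placeholder only: the real weight-update rule is not reproduced here.
        raise NotImplementedError


# Hypothetical registry mapping shortcut names to algorithm classes.
REGISTRY: Dict[str, type] = {"rls": ToyRLS}


def resolve_fit_method(fit_method: Union[str, dict, Callable]) -> Callable:
    """Normalize a str, dict, algorithm instance, or function into one callable."""
    if isinstance(fit_method, str):
        # Shortcut name, e.g. 'rls': build the algorithm with default parameters.
        return REGISTRY[fit_method]()
    if isinstance(fit_method, dict):
        # {'name': ..., **init_kwargs}: copy so the caller's dict is not mutated.
        params = dict(fit_method)
        name = params.pop("name")
        return REGISTRY[name](**params)
    if callable(fit_method):
        # An algorithm instance or a plain function is used as-is.
        return fit_method
    # Anything else (including None in this sketch) is rejected.
    raise TypeError(f"Unsupported fit_method: {fit_method!r}")
```

For example, resolve_fit_method({'name': 'rls', 'alpha': 0.1}) builds ToyRLS(alpha=0.1), while passing an already-constructed instance returns that same object unchanged.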

predict(inputs, reset_state=False, shared_args=None, eval_time=False)

Prediction function.

Unlike the predict() function in DynamicalSystem, this method sets inputs_are_batching to True by default.

Parameters:
  • inputs (ArrayType) – The input values.

  • reset_state (bool) – Reset the target state before running.

  • shared_args (dict) – The shared arguments across nodes.

  • eval_time (bool) – Whether to evaluate the running time.

Returns:

output – The running output.

Return type:

ArrayType
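Because inputs_are_batching defaults to True, predict treats the first axis of inputs as the batch axis. The pure-Python sketch below assumes a batched layout of (batch, time, features), which is not stated above, and shows how a single unbatched (time, features) sequence could be given a leading batch axis of size 1 before being passed in. as_batched is a hypothetical helper, not part of BrainPy.

```python
from typing import List, Sequence, Tuple


def _rank(x) -> int:
    """Count nesting depth by following the first element of each level."""
    r = 0
    while isinstance(x, (list, tuple)) and len(x) > 0:
        x = x[0]
        r += 1
    return r


def as_batched(inputs: Sequence) -> Tuple[List, Tuple[int, int, int]]:
    """Return inputs in an assumed (batch, time, features) layout plus its shape."""
    r = _rank(inputs)
    if r == 2:
        # A single (time, features) sequence: add a leading batch axis of size 1.
        batched = [list(inputs)]
    elif r == 3:
        # Already batched: keep as-is.
        batched = list(inputs)
    else:
        raise ValueError(f"expected rank 2 or 3 input, got rank {r}")
    shape = (len(batched), len(batched[0]), len(batched[0][0]))
    return batched, shape
```

With NumPy or JAX arrays the same step is simply x[None, ...].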