OnlineTrainer
- class brainpy.OnlineTrainer(target, fit_method=None, **kwargs)
Online trainer for models with recurrent dynamics.
For additional parameters, see DSRunner.
- Parameters:
target (DynamicalSystem) – The target model to train.
fit_method (Union[OnlineAlgorithm, Callable, Dict, str]) – The fitting method applied to the target model. It can be one of the following:
  - A string giving the shortcut name of the training algorithm. For example, fit_method='rls' selects the RLS (recursive least squares) method. All supported fitting methods can be listed with get_supported_online_methods().
  - A dict whose "name" item specifies the name of the training algorithm and whose other items are the initialization parameters of that algorithm. For example, fit_method={'name': 'rls', 'alpha': 0.1}.
  - An instance of brainpy.algorithms.OnlineAlgorithm. For example, fit_method=bp.algorithms.RLS(alpha=1e-5).
  - A callable function.
kwargs (Any) – Other general parameters; see DSRunner.
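The four accepted forms of fit_method can be illustrated with a minimal, self-contained sketch of how a trainer might normalize them into a single callable. Everything in it is hypothetical: OnlineAlgorithm, RLS, REGISTRY and resolve_fit_method are stand-ins written for illustration, not brainpy's classes or its internal dispatch logic. The sketch also does not model how the real trainer handles the signature default fit_method=None.

```python
from typing import Any, Callable, Dict, Union


# Hypothetical stand-ins for illustration only; these are NOT brainpy classes.
class OnlineAlgorithm:
    """Placeholder base class for an online fitting algorithm."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class RLS(OnlineAlgorithm):
    # The default alpha here is a placeholder, not necessarily brainpy's default.
    def __init__(self, alpha: float = 0.1) -> None:
        self.alpha = alpha


# Hypothetical registry mapping shortcut names to algorithm classes.
REGISTRY: Dict[str, type] = {"rls": RLS}


def resolve_fit_method(
    fit_method: Union[OnlineAlgorithm, Callable, Dict[str, Any], str]
) -> Callable:
    """Hypothetical helper: turn any accepted fit_method form into a callable."""
    if isinstance(fit_method, str):
        # Shortcut name: build the algorithm with its default parameters.
        return REGISTRY[fit_method]()
    if isinstance(fit_method, dict):
        params = dict(fit_method)  # copy so the caller's dict is left untouched
        name = params.pop("name")  # "name" selects the algorithm
        return REGISTRY[name](**params)  # remaining items are init parameters
    if isinstance(fit_method, OnlineAlgorithm):
        return fit_method  # already an algorithm instance
    if callable(fit_method):
        return fit_method  # plain user-supplied function
    raise TypeError(f"Unsupported fit_method: {fit_method!r}")
```

The OnlineAlgorithm check comes before the generic callable check because algorithm instances are themselves callable; in this sketch both branches return the object unchanged, but keeping them separate leaves room for algorithm-specific setup.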
- predict(inputs, reset_state=False, shared_args=None, eval_time=False)
Prediction function.
Unlike the predict() function of DynamicalSystem, this method sets inputs_are_batching to True by default, so the inputs are treated as batched.
- Parameters:
inputs (Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], where ArrayType is TypeVar(ArrayType, Array, Variable, TrainVar, Array, ndarray)) – The input values.
reset_state (bool) – Whether to reset the target state before running.
shared_args (Dict) – The arguments shared across nodes.
eval_time (bool) – Whether to evaluate the running time.
- Return type:
TypeVar(Output)
- Returns:
output (ArrayType) – The running output.
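Because inputs_are_batching defaults to True, inputs passed to predict() are treated as batched. The sketch below shows one way to give a single trial an explicit batch axis with NumPy. It assumes a (num_batch, num_step, num_feature) layout, that is, batch axis first and time axis second; that layout and the helper to_batch_first are assumptions for illustration, so confirm the axis convention against the DSRunner documentation for your brainpy version.

```python
import numpy as np


def to_batch_first(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: give a single (num_step, num_feature) sequence a
    leading batch axis so that it has shape (1, num_step, num_feature)."""
    x = np.asarray(x)
    if x.ndim == 2:
        return x[np.newaxis, ...]  # add a batch axis of size 1 in front
    if x.ndim == 3:
        return x  # assumed to be (num_batch, num_step, num_feature) already
    raise ValueError(f"expected a 2-D or 3-D array, got ndim={x.ndim}")


# One trial of 100 time steps with 3 input features.
single = np.random.rand(100, 3)
print(to_batch_first(single).shape)  # (1, 100, 3)

# Several trials of equal length can be stacked along a new first axis.
trials = [np.random.rand(100, 3) for _ in range(4)]
print(np.stack(trials, axis=0).shape)  # (4, 100, 3)
```

Stacking only works when all trials have the same number of time steps; trials of different lengths would need padding or separate calls.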