brainpy.train.offline.OfflineTrainer

class brainpy.train.offline.OfflineTrainer(target, fit_method=None, **kwargs)

Offline trainer for models with recurrent dynamics.

Parameters
  • target (DynamicalSystem) – The target model to train.

  • fit_method (OfflineAlgorithm, Callable, dict, str) –

    The fitting method applied to the target model:

    • It can be a string, which specifies the shortcut name of the training algorithm. For example, fit_method='ridge' selects the Ridge regression method. All supported fitting methods can be obtained through brainpy.nn.runners.get_supported_offline_methods().

    • It can be a dict, whose “name” item specifies the name of the training algorithm and whose remaining items are the initialization parameters of that algorithm. For example, fit_method={'name': 'ridge', 'beta': 1e-4}.

    • It can be an instance of brainpy.nn.runners.OfflineAlgorithm. For example, fit_method=bp.nn.runners.RidgeRegression(beta=1e-5).

    • It can also be a callable, which receives three arguments, “targets”, “x” and “y”, and returns the fitted weights. For example, fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0].

  • **kwargs – Other general parameters used to initialize the underlying RNN runner.
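The callable form of fit_method can be illustrated with a small NumPy function. This is a minimal sketch, not BrainPy's own RidgeRegression: it assumes that x is a 2-D array of shape (num_samples, num_features), that targets has shape (num_samples, num_outputs), and that the returned array is used as the readout weights. The function name ridge_fit and its beta keyword are hypothetical, and y is accepted only to match the documented three-argument signature.

```python
import numpy as np


def ridge_fit(targets, x, y, beta=1e-4):
    """Hypothetical fit_method callable: closed-form ridge regression.

    Assumptions (not confirmed by the BrainPy docs):
      x       -- shape (num_samples, num_features)
      targets -- shape (num_samples, num_outputs)
      y       -- accepted to match the (targets, x, y) signature; unused here
    Returns weights of shape (num_features, num_outputs).
    """
    num_features = x.shape[1]
    # Solve the regularized normal equations (x^T x + beta * I) W = x^T targets.
    gram = x.T @ x + beta * np.eye(num_features)
    return np.linalg.solve(gram, x.T @ targets)
```

To pass a different regularization strength, wrap it, e.g. fit_method=lambda targets, x, y: ridge_fit(targets, x, y, beta=1e-5). With beta=0 this reduces to the ordinary least-squares solution computed by the numpy.linalg.lstsq example above (when x has full column rank); a small positive beta keeps the solve well conditioned when the features are highly correlated.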

__init__(target, fit_method=None, **kwargs)

Methods

__init__(target[, fit_method])

build_monitors(return_without_idx, ...)

f_predict([shared_args])

f_train([shared_args])

Get the training function.

fit(train_data[, reset_state, shared_args])

Fit the target model according to the given training and testing data.

format_monitors()

predict(inputs[, reset_state, shared_args, ...])

Prediction function.

reset_state()

run(*args, **kwargs)

Predict a series of input data with the given target model.
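The idea behind offline training of a recurrent model (run the dynamics once to collect states, then fit the readout in a single closed-form step, then predict) can be sketched in plain NumPy. The class below is hypothetical and is not how OfflineTrainer is implemented: ToyOfflineTrainer, its fixed random tanh reservoir, and the choice to pass the readout's current outputs as "y" are all assumptions made for illustration. Only the fit/predict/reset_state names and the (targets, x, y) callable convention mirror the interface documented above.

```python
import numpy as np


class ToyOfflineTrainer:
    """Hypothetical stand-in for an offline trainer (illustration only).

    The recurrent part (a fixed random reservoir) is never trained; only
    the linear readout is fitted, in one shot, by `fit_method`.
    """

    def __init__(self, num_in, num_rec, fit_method, seed=0):
        rng = np.random.default_rng(seed)
        self.w_in = rng.uniform(-0.5, 0.5, size=(num_in, num_rec))
        w_rec = rng.normal(size=(num_rec, num_rec))
        # Rescale the recurrent weights to spectral radius 0.9.
        self.w_rec = 0.9 * w_rec / np.max(np.abs(np.linalg.eigvals(w_rec)))
        self.fit_method = fit_method
        self.w_out = None
        self.reset_state()

    def reset_state(self):
        self.h = np.zeros(self.w_rec.shape[0])

    def _run(self, inputs, reset_state=True):
        # inputs: (num_steps, num_in) -> states: (num_steps, num_rec)
        if reset_state:
            self.reset_state()
        states = []
        for u in inputs:
            self.h = np.tanh(u @ self.w_in + self.h @ self.w_rec)
            states.append(self.h)
        return np.stack(states)

    def fit(self, train_data, reset_state=True):
        inputs, targets = train_data
        x = self._run(inputs, reset_state)
        # Assumption: "y" holds the readout's current outputs (zeros before
        # the first fit); the documented lstsq example ignores it anyway.
        y = np.zeros_like(targets) if self.w_out is None else x @ self.w_out
        self.w_out = self.fit_method(targets, x, y)

    def predict(self, inputs, reset_state=True):
        return self._run(inputs, reset_state) @ self.w_out
```

For example, to learn one-step-ahead prediction of a sine wave, fit on the first part of the series and call predict(test_inputs, reset_state=False) so the recurrent state continues from where training stopped. The design point is that only the readout weights are learned, so a single linear solve replaces gradient-based training.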

Attributes