Source code for brainpy._src.train.offline

# -*- coding: utf-8 -*-

from typing import Dict, Sequence, Union, Callable, Any

import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap

import brainpy.math as bm
from brainpy import tools
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.mixin import SupportOffline
from brainpy._src.runners import _call_fun_with_share
from brainpy.algorithms.offline import get, RidgeRegression, OfflineAlgorithm
from brainpy.types import ArrayType, Output
from ._utils import format_ys
from .base import DSTrainer

__all__ = [
  'OfflineTrainer',
  'RidgeTrainer',
]


class OfflineTrainer(DSTrainer):
  """Offline trainer for models with recurrent dynamics.

  For more parameters, users should refer to :py:class:`~.DSRunner`.

  Parameters
  ----------
  target: DynamicalSystem
    The target model to train.
  fit_method: OfflineAlgorithm, Callable, dict, str
    The fitting method applied to the target model.

    - It can be a string, which specifies the shortcut name of the training
      algorithm. For example, ``fit_method='ridge'`` means using the Ridge
      regression method. All supported fitting methods can be obtained through
      :py:func:`~get_supported_offline_methods`.
    - It can be a dict, whose "name" item specifies the name of the training
      algorithm, and whose other items specify the initialization parameters
      of the algorithm. For example, ``fit_method={'name': 'ridge', 'alpha': 0.1}``.
    - It can be an instance of :py:class:`brainpy.algorithms.OfflineAlgorithm`.
      For example, ``fit_method=bp.algorithms.RidgeRegression(alpha=0.1)``.
    - It can also be a callable function, which receives three arguments
      "targets", "x" and "y". For example,
      ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``.

    A usage sketch demonstrating these forms is given after this class definition.
  kwargs: Any
    For other general parameters, please see :py:class:`~.DSRunner`.
  """

  def __init__(
      self,
      target: DynamicalSystem,
      fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None,
      **kwargs
  ):
    self._true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True)
    kwargs['numpy_mon_after_run'] = False
    super().__init__(target=target, **kwargs)

    # get all trainable nodes
    nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
    self.train_nodes = tuple([node for node in nodes.values()
                              if isinstance(node.mode, bm.TrainingMode)])
    if len(self.train_nodes) == 0:
      raise ValueError('Found no trainable nodes.')

    # check the required interface in the trainable nodes
    self._check_interface()

    # training method
    if fit_method is None:
      fit_method = RidgeRegression(alpha=1e-7)
    elif isinstance(fit_method, str):
      fit_method = get(fit_method)()
    elif isinstance(fit_method, dict):
      name = fit_method.pop('name')
      fit_method = get(name)(**fit_method)
    if not callable(fit_method):
      raise ValueError(f'"fit_method" must be a callable function, '
                       f'but we got {type(fit_method)}.')
    self.fit_method = fit_method

    # set the training method for each trainable node
    for node in self.train_nodes:
      node.offline_fit_by = fit_method

    # training function
    self._jit_fun_train = bm.jit(self._fun_train, static_argnames=['shared_args'])

  def __repr__(self):
    name = self.__class__.__name__
    prefix = ' ' * len(name)
    return (f'{name}(target={self.target}, \n\t'
            f'{prefix}fit_method={self.fit_method})')
  def predict(
      self,
      inputs: Any,
      reset_state: bool = False,
      shared_args: Dict = None,
      eval_time: bool = False
  ) -> Output:
    """Prediction function.

    What differs from the ``predict()`` function in :py:class:`~.DynamicalSystem`
    is that here ``inputs_are_batching`` defaults to ``True``.

    Parameters
    ----------
    inputs: ArrayType
      The input values.
    reset_state: bool
      Whether to reset the target state before running.
    eval_time: bool
      Whether to evaluate the running time.
    shared_args: dict
      The shared arguments across nodes.

    Returns
    -------
    output: ArrayType
      The running output.
    """
    outs = super().predict(inputs=inputs,
                           reset_state=reset_state,
                           eval_time=eval_time,
                           shared_args=shared_args)
    for node in self.train_nodes:
      node.fit_record.clear()
    return outs
  def fit(
      self,
      train_data: Sequence,
      reset_state: bool = False,
      shared_args: Dict = None,
  ) -> Output:
    """Fit the target model according to the given training data.

    Parameters
    ----------
    train_data: sequence of data
      It should be a pair of ``(X, Y)`` train set (the expected shapes are also
      illustrated in the usage sketch after this class definition).

      - ``X``: a tensor or a dict of tensors with the shape of
        ``(num_sample, num_time, num_feature)``, where ``num_sample`` is the
        number of samples, ``num_time`` is the number of time steps, and
        ``num_feature`` is the number of features.
      - ``Y``: the target values, a tensor or a dict of tensors.

        - If the shape of each tensor is ``(num_sample, num_feature)``,
          the model is fitted with only the last output.
        - If the shape of each tensor is ``(num_sample, num_time, num_feature)``,
          the fitting happens on the whole data series.
    reset_state: bool
      Whether to reset the initial states of the target model.
    shared_args: dict
      The shared keyword arguments for the target models.
    """
    if shared_args is None:
      shared_args = dict()
    shared_args['fit'] = shared_args.get('fit', True)
    shared_args = tools.DotDict(shared_args)

    # check the training data
    if not isinstance(train_data, (list, tuple)):
      raise ValueError(f"{self.__class__.__name__} only supports "
                       f"training data with the format of an (X, Y) pair, "
                       f"but we got a {type(train_data)}.")
    if len(train_data) != 2:
      raise ValueError(f"{self.__class__.__name__} only supports "
                       f"training data with the format of an (X, Y) pair, "
                       f"but we got a sequence with length {len(train_data)}.")
    xs, ys = train_data

    # prediction, get all needed data
    shared_args['fit'] = shared_args.get('fit', False)
    outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args)

    # check the target data
    ys = format_ys(self, ys)

    # init the progress bar
    if self.progress_bar:
      self._pbar = tqdm.auto.tqdm(total=len(self.train_nodes))
      self._pbar.set_description(f"Train {len(self.train_nodes)} nodes: ", refresh=True)

    # training
    monitor_data = dict()
    for node in self.train_nodes:
      key = f'{node.name}-fit_record'
      monitor_data[key] = self.mon.get(key)
    run_fun = self._jit_fun_train if self.jit['fit'] else self._fun_train
    shared_args['fit'] = True
    run_fun(monitor_data, ys, shared_args=shared_args)
    del monitor_data

    # close the progress bar
    if self.progress_bar:
      self._pbar.close()

    # clear the fit records
    for node in self.train_nodes:
      self.mon.pop(f'{node.name}-fit_record')
      node.fit_record.clear()
    if self._true_numpy_mon_after_run:
      for key in self.mon.keys():
        self.mon[key] = np.asarray(self.mon[key])
    return outs
  def _fun_train(self,
                 monitor_data: Dict[str, ArrayType],
                 target_data: Dict[str, ArrayType],
                 shared_args: Dict = None):
    if shared_args is None:
      shared_args = dict()
    share.save(**shared_args)
    for node in self.train_nodes:
      fit_record = monitor_data[f'{node.name}-fit_record']
      targets = target_data[node.name]
      node.offline_fit(targets, fit_record)
      if self.progress_bar:
        id_tap(lambda *args: self._pbar.update(), ())

  def _step_func_monitor(self):
    res = dict()
    for key, val in self._monitors.items():
      if callable(val):
        res[key] = _call_fun_with_share(val)
      else:
        (variable, idx) = val
        if idx is None:
          res[key] = variable.value
        else:
          res[key] = variable[bm.asarray(idx)]
    if share.load('fit'):
      for node in self.train_nodes:
        res[f'{node.name}-fit_record'] = node.fit_record
    return res

  def _check_interface(self):
    for node in self.train_nodes:
      if not isinstance(node, SupportOffline):
        raise TypeError(
          f'The node {node} is set to the computing mode of {bm.training_mode} '
          f'with {self.__class__.__name__}. However, {self.__class__.__name__} '
          f'only supports training nodes that are instances of {SupportOffline}.'
        )
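
# A minimal usage sketch (not part of the library source). The model below and
# the ``bp.dyn.Reservoir`` / ``bp.dnn.Dense`` layer names are assumptions for
# illustration only; any model whose trainable readout supports offline fitting
# would work the same way.
#
#   import brainpy as bp
#   import brainpy.math as bm
#
#   model = bp.Sequential(
#       bp.dyn.Reservoir(3, 100, mode=bm.batching_mode),  # assumed layer name
#       bp.dnn.Dense(100, 3, mode=bm.training_mode),      # trainable readout
#   )
#
#   # Each of these selects the same Ridge regression algorithm:
#   trainer = bp.train.OfflineTrainer(model, fit_method='ridge')
#   trainer = bp.train.OfflineTrainer(model, fit_method={'name': 'ridge', 'alpha': 0.1})
#   trainer = bp.train.OfflineTrainer(model, fit_method=bp.algorithms.RidgeRegression(alpha=0.1))
#
#   # X: (num_sample, num_time, num_feature); Y with the same leading dimensions.
#   X = bm.random.rand(16, 200, 3)
#   Y = bm.random.rand(16, 200, 3)
#   trainer.fit([X, Y])       # one closed-form (offline) fit over the whole series
#   out = trainer.predict(X)  # out.shape == (16, 200, 3)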
class RidgeTrainer(OfflineTrainer):
  """Trainer of ridge regression, also known as regression with Tikhonov regularization.

  For more parameters, users should refer to :py:class:`~.DSRunner`.

  Parameters
  ----------
  target: TrainingSystem, DynamicalSystem
    The target model.
  alpha: float
    The regularization coefficient.
  """

  def __init__(self, target, alpha=1e-7, **kwargs):
    super().__init__(target=target,
                     fit_method=dict(name='ridge', alpha=alpha),
                     **kwargs)
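
# A usage sketch for ``RidgeTrainer``, reusing the hypothetical ``model``, ``X``
# and ``Y`` from the sketch above. It is shorthand for
# ``OfflineTrainer(model, fit_method=dict(name='ridge', alpha=...))``.
#
#   trainer = bp.train.RidgeTrainer(model, alpha=1e-6)
#   trainer.fit([X, Y])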