# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict, Sequence, Union, Callable, Any
import brainstate.environ
import jax
import numpy as np
import tqdm.auto
import brainpy.math as bm
from brainpy import tools
from brainpy.algorithms.offline import get, RidgeRegression, OfflineAlgorithm
from brainpy.context import share
from brainpy.dynsys import DynamicalSystem
from brainpy.mixin import SupportOffline
from brainpy.runners import _call_fun_with_share
from brainpy.types import ArrayType, Output
from ._utils import format_ys
from .base import DSTrainer
# Public names exported by ``from <this module> import *``.
__all__ = ['OfflineTrainer', 'RidgeTrainer']
# (stray "[docs]" link marker removed here: it is page markup, not Python code)
class OfflineTrainer(DSTrainer):
    """Offline trainer for models with recurrent dynamics.

    For more parameters, users should refer to :py:class:`~.DSRunner`.

    Parameters::

    target: DynamicalSystem
      The target model to train.
    fit_method: OfflineAlgorithm, Callable, dict, str
      The fitting method applied to the target model.

      - It can be a string, which specifies the shortcut name of the training algorithm.
        Like, ``fit_method='ridge'`` means using the Ridge regression method.
        All supported fitting methods can be obtained through
        :py:func:`~get_supported_offline_methods`.
      - It can be a dict, whose "name" item specifies the name of the training algorithm,
        and the other items specify the initialization parameters of the algorithm.
        For example, ``fit_method={'name': 'ridge', 'alpha': 0.1}``.
        The dict passed in is not modified.
      - It can be an instance of :py:class:`brainpy.algorithms.OfflineAlgorithm`.
        For example, ``fit_method=bp.algorithms.RidgeRegression(alpha=0.1)``.
      - It can also be a callable function, which receives three arguments "targets", "x" and "y".
        For example, ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``.
      - If ``None`` (the default), ``RidgeRegression(alpha=1e-7)`` is used.
    kwargs: Any
      Other general parameters please see :py:class:`~.DSRunner`.

    Raises::

    ValueError
      If the target contains no node in training mode, or if the resolved
      ``fit_method`` is not callable.
    TypeError
      If a node in training mode is not an instance of ``SupportOffline``.
    KeyError
      If ``fit_method`` is a dict without a ``"name"`` item.
    """

    def __init__(
        self,
        target: DynamicalSystem,
        fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None,
        **kwargs
    ):
        # Remember the caller's preference, but always hand ``False`` to the
        # parent runner: ``fit()`` performs the NumPy conversion itself, after
        # the fit records have been consumed and removed from the monitors.
        self._true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True)
        kwargs['numpy_mon_after_run'] = False
        super().__init__(target=target, **kwargs)

        # get all trainable nodes
        nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
        self.train_nodes = tuple(
            node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode)
        )
        if len(self.train_nodes) == 0:
            raise ValueError('Found no trainable nodes.')

        # check the required interface in the trainable nodes
        self._check_interface()

        # training method
        if fit_method is None:
            fit_method = RidgeRegression(alpha=1e-7)
        elif isinstance(fit_method, str):
            fit_method = get(fit_method)()
        elif isinstance(fit_method, dict):
            # Pop "name" from a copy so that the caller's dict stays intact and
            # can be reused (e.g. to build a second trainer).
            options = dict(fit_method)
            name = options.pop('name')
            fit_method = get(name)(**options)
        if not callable(fit_method):
            raise ValueError(f'"fit_method" must be an instance of callable function, '
                             f'but we got {type(fit_method)}.')
        self.fit_method = fit_method

        # set the training method
        for node in self.train_nodes:
            node.offline_fit_by = fit_method

        # training function
        self._jit_fun_train = bm.jit(self._fun_train)

    def __repr__(self):
        name = self.__class__.__name__
        prefix = ' ' * len(name)
        return (f'{name}(target={self.target}, \n\t'
                f'{prefix}fit_method={self.fit_method})')

    def predict(
        self,
        inputs: Any,
        reset_state: bool = False,
        shared_args: Dict = None,
        eval_time: bool = False
    ) -> Output:
        """Prediction function.

        Delegates to the parent class ``predict()`` and afterwards clears the
        ``fit_record`` of every trainable node.

        Parameters::

        inputs: ArrayType
          The input values.
        reset_state: bool
          Reset the target state before running.
        eval_time: bool
          Whether we evaluate the running time or not?
        shared_args: dict
          The shared arguments across nodes.

        Returns::

        output: ArrayType
          The running output.
        """
        outs = super().predict(inputs=inputs, reset_state=reset_state,
                               eval_time=eval_time, shared_args=shared_args)
        for node in self.train_nodes:
            node.fit_record.clear()
        return outs

    def fit(
        self,
        train_data: Sequence,
        reset_state: bool = False,
        shared_args: Dict = None,
    ) -> Output:
        """Fit the target model according to the given training data.

        The model is first run on ``X`` through :py:meth:`predict`; the fit
        records gathered in the monitors during that run are then passed,
        together with the formatted targets, to the training function.

        Parameters::

        train_data: sequence of data
          It should be a pair of `(X, Y)` train set.

          - ``X``: should be a tensor or a dict of tensors with the shape of
            `(num_sample, num_time, num_feature)`, where `num_sample` is
            the number of samples, `num_time` is the number of the time step,
            and `num_feature` is the number of features.
          - ``Y``: Target values. A tensor or a dict of tensors.

            - If the shape of each tensor is `(num_sample, num_feature)`,
              then we will only fit the model with the only last output.
            - If the shape of each tensor is `(num_sample, num_time, num_feature)`,
              then the fitting happens on the whole data series.
        reset_state: bool
          Whether reset the initial states of the target model.
        shared_args: dict
          The shared keyword arguments for the target models. The mapping
          passed in is copied and is not modified.

        Returns::

        output: ArrayType
          The output of running the target model on ``X``.

        Raises::

        ValueError
          If ``train_data`` is not a list or tuple of length two.
        """
        # NOTE(review): the extent of this ``with`` block is reconstructed; the
        # whole method body is kept inside the ``fit=True`` context.
        with brainstate.environ.context(fit=True):
            # Work on a copy so the caller's mapping never gains a 'fit' key.
            shared_args = dict() if shared_args is None else dict(shared_args)
            shared_args['fit'] = shared_args.get('fit', True)
            shared_args = tools.DotDict(shared_args)

            # checking training data
            if not isinstance(train_data, (list, tuple)):
                raise ValueError(f"{self.__class__.__name__} only support "
                                 f"training data with the format of (X, Y) pair, "
                                 f"but we got a {type(train_data)}.")
            if len(train_data) != 2:
                raise ValueError(f"{self.__class__.__name__} only support "
                                 f"training data with the format of (X, Y) pair, "
                                 f"but we got a sequence with length {len(train_data)}")
            xs, ys = train_data

            # prediction, get all needed data
            outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args)

            # check target data
            ys = format_ys(self, ys)

            # init progress bar
            if self.progress_bar:
                self._pbar = tqdm.auto.tqdm(total=len(self.train_nodes))
                self._pbar.set_description(f"Train {len(self.train_nodes)} nodes: ", refresh=True)

            try:
                # training: gather the fit record of every trainable node from
                # the monitors (``None`` when a record was not monitored)
                monitor_data = dict()
                for node in self.train_nodes:
                    key = f'{node.name}-fit_record'
                    monitor_data[key] = self.mon.get(key)
                run_fun = self._jit_fun_train if self.jit['fit'] else self._fun_train
                # NOTE(review): ``shared_args`` is not forwarded to ``run_fun``,
                # so this assignment has no visible effect in this method.
                shared_args['fit'] = True
                run_fun(monitor_data, ys)
                del monitor_data
            finally:
                # close the progress bar even when training fails
                if self.progress_bar:
                    self._pbar.close()

            # final things
            for node in self.train_nodes:
                # Only pop if the key exists
                fit_record_key = f'{node.name}-fit_record'
                if fit_record_key in self.mon:
                    self.mon.pop(fit_record_key)
                node.fit_record.clear()  # clear fit records
            if self._true_numpy_mon_after_run:
                for key in self.mon.keys():
                    self.mon[key] = np.asarray(self.mon[key])
            return outs

    def _fun_train(self,
                   monitor_data: Dict[str, ArrayType],
                   target_data: Dict[str, ArrayType],
                   shared_args: Dict = None):
        """Call ``offline_fit`` on every trainable node.

        Parameters::

        monitor_data: dict
          Maps ``'<node.name>-fit_record'`` to the recorded fit data of that
          node. A missing key results in ``None`` being passed to the node.
        target_data: dict
          Maps ``node.name`` to the target values of that node.
        shared_args: dict
          Items saved into ``share`` before fitting. Defaults to no items.
        """
        if shared_args is None:
            shared_args = dict()
        share.save(**shared_args)

        for node in self.train_nodes:
            fit_record_key = f'{node.name}-fit_record'
            fit_record = monitor_data.get(fit_record_key, None)
            targets = target_data[node.name]
            node.offline_fit(targets, fit_record)
            if self.progress_bar:
                # Advance the bar through a JAX debug callback (presumably so the
                # update still happens when this function runs under ``bm.jit``).
                jax.debug.callback(lambda *args: self._pbar.update(), ())

    def _step_func_monitor(self):
        """Collect the current values of all registered monitors.

        A callable monitor is evaluated through ``_call_fun_with_share``; any
        other monitor is a ``(variable, idx)`` pair, read entirely when ``idx``
        is ``None`` and indexed by ``idx`` otherwise. When the shared ``'fit'``
        flag is truthy, the ``fit_record`` of every trainable node is added
        under the key ``'<node.name>-fit_record'``.

        Returns::

        res: dict
          The monitored values, keyed by monitor name.
        """
        res = dict()
        for key, val in self._monitors.items():
            if callable(val):
                res[key] = _call_fun_with_share(val)
            else:
                (variable, idx) = val
                if idx is None:
                    res[key] = variable.value
                else:
                    res[key] = variable[bm.asarray(idx)]
        if share.load('fit'):
            for node in self.train_nodes:
                res[f'{node.name}-fit_record'] = node.fit_record
        return res

    def _check_interface(self):
        """Ensure every trainable node is an instance of ``SupportOffline``.

        Raises::

        TypeError
          If a trainable node is not an instance of ``SupportOffline``.
        """
        for node in self.train_nodes:
            if not isinstance(node, SupportOffline):
                raise TypeError(
                    f'''
The node
{node}
is set to be computing mode of {bm.training_mode} with {self.__class__.__name__}.
However, {self.__class__.__name__} only support training nodes that are instances of {SupportOffline}.
'''
                )
# (stray "[docs]" link marker removed here: it is page markup, not Python code)
class RidgeTrainer(OfflineTrainer):
    """Offline trainer that fits with ridge regression (Tikhonov regularization).

    This is :py:class:`OfflineTrainer` with ``fit_method`` fixed to
    ``{'name': 'ridge', 'alpha': alpha}``. For more parameters, users should
    refer to :py:class:`~.DSRunner`.

    Parameters::

    target: TrainingSystem, DynamicalSystem
      The target model.
    alpha: float
      The regularization coefficient.
    kwargs: Any
      Forwarded unchanged to :py:class:`OfflineTrainer`.
    """

    def __init__(self, target, alpha=1e-7, **kwargs):
        # Describe the algorithm by name and let the parent class resolve it.
        ridge_config = {'name': 'ridge', 'alpha': alpha}
        super().__init__(target=target, fit_method=ridge_config, **kwargs)