# -*- coding: utf-8 -*-
import time
from collections.abc import Iterable
from typing import Union, Dict, Callable, Sequence, Optional
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map
from tqdm import tqdm
import brainpy.losses as losses
import brainpy.math as bm
from brainpy import optim
from brainpy import tools
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.running import constants as c
from brainpy.errors import UnsupportedError, NoLongerSupportError
from brainpy.types import ArrayType, Output
from ._utils import msg
from .base import DSTrainer
__all__ = [
'BPTT',
'BPFF',
]
def _is_brainpy_array(s):
return isinstance(s, bm.Array)
class BPTrainer(DSTrainer):
"""Trainer implementing back-propagation algorithm for supervised trasks.
For more parameters, users should refer to :py:class:`~.DSRunner`.
Parameters
----------
target: DynamicalSystem
The target model to train.
loss_fun: str, callable
The loss function. If it is a string, it should be the name of a
function in the ``brainpy.losses`` module. Otherwise, a callable
function which receives the arguments ``(predicts, targets)``
should be provided.
loss_has_aux: bool
To indicate whether the `loss_fun` returns auxiliary data.
loss_auto_run: bool
Whether the target model is run automatically over the input data when computing the loss.
optimizer: optim.Optimizer
The optimizer used for training.
numpy_mon_after_run: bool
Whether to convert the monitored results into NumPy arrays.
logger: Any
A file-like object (stream); defaults to the current `sys.stdout`.
shuffle_data: bool
.. deprecated:: 2.2.4.1
Users should control data shuffling by themselves.
seed: int
.. deprecated:: 2.2.4.1
Users should control data shuffling by themselves.
kwargs: Any
Other general parameters please see :py:class:`~.DSRunner`.
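Examples
--------
When ``loss_fun`` is a callable, it receives ``(predicts, targets)`` and returns
a scalar loss (plus an auxiliary dict when ``loss_has_aux=True``). A minimal
sketch of such a callable (the names below are illustrative only):
.. code-block:: python

    import brainpy.math as bm

    def my_loss(predicts, targets):
        # mean squared error between model outputs and supervision signals
        return bm.mean((predicts - targets) ** 2)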
"""
def __init__(
self,
target: DynamicalSystem,
loss_fun: Union[str, Callable], # loss function
optimizer: optim.Optimizer = None, # optimizer
loss_has_aux: bool = False, # whether the loss function returns auxiliary data
loss_auto_run: bool = True, # whether to run the model automatically when computing the loss
# -------------
# API deprecated
seed: int = None, # deprecated
shuffle_data: bool = None, # deprecated
**kwargs,
):
super().__init__(target=target, **kwargs)
if shuffle_data is not None:
raise NoLongerSupportError(
f'''
"shuffle_data" is no longer supported.
To be general, users should shuffle their data by themselves.
See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1
for the solution of how to fix this.
'''
)
if seed is not None:
NoLongerSupportError('"seed" is no longer supported. '
'Please shuffle your data by yourself.')
# jit settings
if isinstance(self._origin_jit, bool):
self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, self._origin_jit)
self.jit[c.LOSS_PHASE] = self.jit.get(c.LOSS_PHASE, self._origin_jit)
self.jit[c.FIT_PHASE] = self.jit.get(c.FIT_PHASE, self._origin_jit)
else:
self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True)
self.jit[c.LOSS_PHASE] = self._origin_jit.get(c.LOSS_PHASE, True)
self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True)
# optimizer
if optimizer is None:
lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
optimizer = optim.Adam(lr=lr)
self.optimizer: optim.Optimizer = optimizer
if len(self.optimizer.vars_to_train) == 0:
self.optimizer.register_train_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique())
# loss function
self.loss_has_aux = loss_has_aux
if isinstance(loss_fun, str):
loss_fun = getattr(losses, loss_fun)
elif callable(loss_fun):
pass  # already a callable; use it as provided
else:
raise UnsupportedError(f'Do not support {type(loss_fun)} to specify the loss function. '
f'We only support str and callable function.')
self._loss_func = loss_fun
self.loss_auto_run = loss_auto_run
# loss data
self._report_train_metrics = dict()
self._report_test_metrics = dict()
self._detailed_train_metrics = dict()
self._detailed_test_metrics = dict()
# functions
self._jit_step_func_grad = bm.jit(self._step_func_grad, static_argnums=(0,))
self._jit_step_func_loss = bm.jit(self._step_func_loss, static_argnums=(0,))
self._jit_step_func_fit = bm.jit(self._step_func_fit, static_argnums=(0,))
def __repr__(self):
name = self.__class__.__name__
prefix = ' ' * len(name)
return (f'{name}(target={self.target}, \n\t'
f'{prefix}jit={self.jit}, \n\t'
f'{prefix}loss={self._loss_func}, \n\t'
f'{prefix}optimizer={self.optimizer})')
def get_hist_metric(self, phase='fit', metric='loss', which='report'):
"""Get history losses."""
assert phase in [c.FIT_PHASE, c.TEST_PHASE, c.TRAIN_PHASE, c.PREDICT_PHASE]
assert which in ['report', 'detailed']
if phase in [c.FIT_PHASE, c.TRAIN_PHASE]:
if which == 'report':
return self._report_train_metrics.get(metric, None)
elif which == 'detailed':
return self._detailed_train_metrics.get(metric, None)
elif phase in [c.TEST_PHASE, c.PREDICT_PHASE]:
if which == 'report':
return self._report_test_metrics.get(metric, None)
elif which == 'detailed':
return self._detailed_test_metrics.get(metric, None)
@property
def train_losses(self):
return self.get_hist_metric(phase='fit')
@property
def test_losses(self):
return self.get_hist_metric(phase='test')
def fit(
self,
train_data: Union[Callable, Iterable],
test_data: Optional[Union[Callable, Iterable]] = None,
num_epoch: int = 100,
num_report: int = -1,
reset_state: bool = True,
shared_args: Optional[Dict] = None,
fun_after_report: Optional[Callable] = None,
# ------
# API deprecated
batch_size: int = None,
):
"""Fit the target model according to the given training data.
Parameters
----------
train_data: callable, iterable
The training data (see the example below). It can be:
- An iterable which yields batches of `(X, Y)` data.
- A callable which returns such an iterable (for example, a generator function).
- ``X``: should be a tensor or a dict of tensors with the shape of
`(num_sample, num_time, num_feature)`, where `num_sample` is
the number of samples, `num_time` is the number of time steps,
and `num_feature` is the number of features.
- ``Y``: Target values. A tensor or a dict of tensors.
- If the shape of each tensor is `(num_sample, num_feature)`,
then we will only fit the model with the only last output.
- If the shape of each tensor is `(num_sample, num_time, num_feature)`,
then the fitting happens on the whole data series.
test_data: callable, iterable, optional
Same as ``train_data``.
num_epoch: int
The number of training epochs. Default 100.
num_report: int
The number of steps between two progress reports.
If `num_report=-1`, the training progress is reported once per epoch.
reset_state: bool
Whether to reset the initial states of the target model.
shared_args: dict
The shared keyword arguments for the target models.
fun_after_report: optional, Callable
The function to call after each report of `fit` phase or `test` phase.
The function should receive three arguments:
- ``idx``: the current running index. If ``num_report=-1``, it is the epoch
index; otherwise, it is the fit step index for the 'fit' phase and the
test step index for the 'test' phase.
- ``metrics``: the metrics defined in the loss function.
- ``phase``: indicates whether the current phase is 'fit' or 'test'.
.. versionadded:: 2.3.1
batch_size: int
.. deprecated:: 2.2.4.1
Please set batch size in your dataset.
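Examples
--------
A minimal sketch of a training-data pipeline and a ``fit`` call. The arrays,
shapes and the ``trainer`` object below are illustrative assumptions only:
.. code-block:: python

    import numpy as np

    def train_data():
        # yield batches of (X, Y); X: (num_sample, num_time, num_feature)
        for _ in range(100):
            X = np.random.rand(32, 50, 10)
            Y = np.random.rand(32, 10)  # targets for the last time step only
            yield X, Y

    trainer.fit(train_data, num_epoch=10, num_report=20)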
"""
if shared_args is None:
shared_args = dict()
shared_args['fit'] = shared_args.get('fit', True)
shared_args = tools.DotDict(shared_args)
if batch_size is not None:
raise NoLongerSupportError('Please set batch size in your data. '
'Specifically, make an iterable dataset '
'which return a batch of (X, Y) data.')
if isinstance(train_data, (tuple, list)):
if len(train_data) == 2:
# a raw (X, Y) pair is no longer supported; pass an iterable of (X, Y) batches instead
raise UnsupportedError(msg)
if fun_after_report is not None:
assert callable(fun_after_report), ('\n'
'Unknown "fun_after_report", '
'it should be a callable function receiving '
'three arguments: idx, metrics, phase')
true_progress_bar = self.progress_bar
self.progress_bar = False
# training the model
detailed_train_metric = dict()
report_train_metric = dict()
detailed_test_metric = dict()
report_test_metric = dict()
fit_i, fit_t = 0, 0
test_i, test_t = 0, 0
for epoch_idx in range(num_epoch):
# training set
fit_t0 = time.time()
fit_epoch_metric = dict(loss=[])
_training_data = train_data() if callable(train_data) else train_data
if hasattr(_training_data, '__len__'):
bar = tqdm(total=len(_training_data))
else:
bar = None
for x, y in _training_data:
# reset state
if reset_state:
self.target.reset_state(self._get_input_batch_size(x))
self.reset_state()
# training
res = self.f_train(shared_args, x, y)
# loss
fit_epoch_metric['loss'].append(res[0])
if self.loss_has_aux:
if not isinstance(res[1], dict):
raise TypeError(f'Auxiliary data in loss function should be a dict. But we got {type(res[1])}')
for k, v in res[1].items():
if k not in fit_epoch_metric:
fit_epoch_metric[k] = []
fit_epoch_metric[k].append(v)
if bar is not None:
bar.update(1)
# report
fit_i += 1
if num_report > 0 and fit_i % num_report == 0:
fit_t1 = time.time()
aux = {}
for k, v in fit_epoch_metric.items():
aux[k] = jnp.mean(bm.as_jax(bm.asarray(v)))
if k not in report_train_metric:
report_train_metric[k] = []
detailed_train_metric[k] = []
report_train_metric[k].append(aux[k])
detailed_train_metric[k].extend(v)
v.clear()
_report = (f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' +
', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
if bar is not None:
bar.set_description(_report, refresh=True)
else:
print(_report)
if fun_after_report is not None:
fun_after_report(fit_i, aux, 'fit')
fit_t0 = time.time()
fit_t = 0
if num_report <= 0:
fit_t1 = time.time()
aux = {}
for k, v in fit_epoch_metric.items():
aux[k] = np.mean(np.asarray(v))
if k not in report_train_metric:
report_train_metric[k] = []
detailed_train_metric[k] = []
report_train_metric[k].append(aux[k])
detailed_train_metric[k].extend(v)
v.clear()
_report = (f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' +
', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
if bar is not None:
bar.set_description(_report, refresh=True)
else:
print(_report)
if fun_after_report is not None:
fun_after_report(epoch_idx, aux, 'fit')
else:
fit_t = time.time() - fit_t0
self.optimizer.lr.step_epoch()
if bar is not None: bar.close()
# testing set
if test_data is not None:
test_t0 = time.time()
test_epoch_metric = dict(loss=[])
_testing_data = test_data() if callable(test_data) else test_data
if hasattr(_testing_data, '__len__'):
bar = tqdm(total=len(_testing_data))
else:
bar = None
for x, y in _testing_data:
# reset state
if reset_state:
self.target.reset_state(self._get_input_batch_size(x))
self.reset_state()
# testing
res = self.f_loss(shared_args, x, y)
# loss
if self.loss_has_aux:
test_epoch_metric['loss'].append(res[0])
if not isinstance(res[1], dict):
raise TypeError(f'Auxiliary data in loss function should be a dict. But we got {type(res[1])}')
for k, v in res[1].items():
if k not in test_epoch_metric:
test_epoch_metric[k] = []
test_epoch_metric[k].append(v)
else:
test_epoch_metric['loss'].append(res)
if bar is not None: bar.update(1)
# report
test_i += 1
if num_report > 0 and test_i % num_report == 0:
test_t1 = time.time()
aux = {}
for k, v in test_epoch_metric.items():
aux[k] = np.mean(np.asarray(v))
if k not in report_test_metric:
report_test_metric[k] = []
detailed_test_metric[k] = []
report_test_metric[k].append(aux[k])
detailed_test_metric[k].extend(v)
v.clear()
_report = (f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' +
', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
if bar is not None:
bar.set_description(_report, refresh=True)
else:
print(_report)
if fun_after_report is not None:
fun_after_report(test_i, aux, 'test')
test_t0 = time.time()
test_t = 0
if num_report <= 0:
test_t1 = time.time()
aux = {}
for k, v in test_epoch_metric.items():
aux[k] = jnp.mean(bm.as_jax(bm.asarray(v)))
if k not in report_test_metric:
report_test_metric[k] = []
detailed_test_metric[k] = []
report_test_metric[k].append(aux[k])
detailed_test_metric[k].extend(v)
v.clear()
_report = (f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' +
', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
if bar is not None:
bar.set_description(_report, refresh=True)
else:
print(_report)
if fun_after_report is not None:
fun_after_report(epoch_idx, aux, 'test')
else:
test_t = time.time() - test_t0
if bar is not None: bar.close()
# finally
self._report_train_metrics = {k: np.asarray(v) for k, v in report_train_metric.items()}
self._detailed_train_metrics = {k: np.asarray(v) for k, v in detailed_train_metric.items()}
self._report_test_metrics = {k: np.asarray(v) for k, v in report_test_metric.items()}
self._detailed_test_metrics = {k: np.asarray(v) for k, v in detailed_test_metric.items()}
self.progress_bar = true_progress_bar
def _step_func_grad(self, shared_args, inputs, targets):
tran_vars = self.target.train_vars().unique()
grad_f = bm.grad(self._step_func_loss,
grad_vars=tran_vars,
return_value=True,
has_aux=self.loss_has_aux)
return grad_f(shared_args, inputs, targets)
def _step_func_loss(self, shared_args, inputs, targets):
raise NotImplementedError
@property
def f_loss(self):
return self._jit_step_func_loss if self.jit[c.LOSS_PHASE] else self._step_func_loss
def _step_func_fit(self, shared_args, inputs, targets):
raise NotImplementedError
@property
def f_train(self):
return self._jit_step_func_fit if self.jit[c.FIT_PHASE] else self._step_func_fit
@property
def f_grad(self):
return self._jit_step_func_grad if self.jit[c.FIT_PHASE] else self._step_func_grad
class BPTT(BPTrainer):
"""The trainer implementing the back-propagation through time (BPTT)
algorithm for training dynamical systems.
For more parameters, users should refer to :py:class:`~.DSRunner`.
Parameters
----------
target: DynamicalSystem
The target model to train.
loss_fun: str, callable
The loss function.
- If it is a string, it should be the name of a function in the ``brainpy.losses`` module.
- Otherwise, a callable function which receives the arguments ``(predicts, targets)``
should be provided.
.. note::
If ``monitors`` has been set in the trainer, the ``predicts`` contains two
parts: the network history prediction outputs, and the monitored values.
See BrainPy examples for more information.
loss_has_aux: bool
To indicate whether the loss function returns auxiliary data besides the loss.
If so, the auxiliary data should be a dict, whose keys are used as the logging
item names and whose values are the corresponding logged data.
For example,
.. code-block:: python
def loss_fun(predicts, targets):
return loss, {'acc': acc, 'spike_num': spike_num}
optimizer: Optimizer
The optimizer used for training. Should be an instance of :py:class:`~.Optimizer`.
numpy_mon_after_run: bool
Whether to convert the monitored results into NumPy arrays.
logger: Any
A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`.
data_first_axis: str
To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the
time length (``data_first_axis='T'``).
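Examples
--------
A minimal usage sketch. The import path, the model and the data pipeline below
are illustrative assumptions; any :py:class:`~.DynamicalSystem` holding
``bm.TrainVar`` parameters and any iterable of ``(X, Y)`` batches can be used:
.. code-block:: python

    import brainpy as bp
    import brainpy.math as bm

    def mse_loss(predicts, targets):
        # any callable of (predicts, targets) -> scalar loss works here
        return bm.mean((predicts - targets) ** 2)

    model = ...  # a DynamicalSystem with trainable (bm.TrainVar) parameters
    trainer = bp.BPTT(model, loss_fun=mse_loss, optimizer=bp.optim.Adam(lr=1e-3))
    # train_data: an iterable (or a callable returning one) of (X, Y) batches
    trainer.fit(train_data, num_epoch=10)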
"""
def _step_func_loss(self, shared_args, inputs, targets):
num_step = self._get_input_time_step(xs=inputs)
indices = np.arange(self.i0, self.i0 + num_step, dtype=np.int_)
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
# data is batch-first; move the batch axis so that time becomes the leading axis
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.Array))
if not isinstance(inputs, (tuple, list)):
inputs = (inputs,)
outs, mons = self._predict(indices, *inputs, shared_args=shared_args)
predicts = (outs, mons) if len(mons) > 0 else outs
return self._loss_func(predicts, targets)
def _step_func_fit(self, shared_args, inputs, targets):
res = self.f_grad(shared_args, inputs, targets)
self.optimizer.update(res[0])
return res[1:]
class BPFF(BPTrainer):
"""
The trainer implementing back propagation algorithm
for feedforward neural networks.
For more parameters, users should refer to :py:class:`~.DSRunner`.
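Examples
--------
A minimal sketch. The model, the loss and the data below are illustrative
assumptions; note that the inputs carry no time dimension:
.. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm

    def my_loss(predicts, targets):
        return bm.mean((predicts - targets) ** 2)

    def train_data():
        for _ in range(100):
            X = np.random.rand(32, 10)  # (num_sample, num_feature)
            Y = np.random.rand(32, 2)
            yield X, Y

    model = ...  # a feedforward DynamicalSystem with bm.TrainVar parameters
    trainer = bp.BPFF(model, loss_fun=my_loss)
    trainer.fit(train_data, num_epoch=5)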
"""
def _step_func_loss(self, shared_args, inputs, targets):
if not isinstance(inputs, (tuple, list)):
inputs = (inputs,)
outputs, mon = self._step_func_predict(*inputs, shared_args=shared_args)
outs = (outputs, mon) if len(mon) > 0 else outputs
loss = self._loss_func(outs, targets)
return loss
def _step_func_fit(self, shared_args, inputs, targets):
res = self.f_grad(shared_args, inputs, targets)
self.optimizer.update(res[0])
return res[1:]
def _step_func_predict(self, *x, shared_args=None):
assert self.data_first_axis == 'B', (f'There is no time dimension when '
f'using the trainer {self.__class__.__name__}.')
if shared_args is not None:
assert isinstance(shared_args, dict)
share.save(**shared_args)
share.save(dt=self.dt)
# input step
clear_input(self.target)
self._step_func_input()
# dynamics update step
out = self.target(*x)
# monitor step
mon = self._step_func_monitor()
# share.clear_shargs()
return out, mon
def _fun_predict(self, *inputs, shared_args=None):
if self.jit['predict']:
return self._jit_step_func_predict(*inputs, shared_args=shared_args)
else:
return self._step_func_predict(*inputs, shared_args=shared_args)
def predict(
self,
inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]],
reset_state: bool = True,
shared_args: Dict = None,
eval_time: bool = False
) -> Output:
"""Predict a series of input data with the given target model.
This function uses JIT compilation to accelerate the model simulation.
Moreover, it can automatically monitor the node variables, states, inputs,
feedbacks and outputs.
Parameters
----------
inputs: ArrayType, dict
The feedforward input data. It carries no time dimension and typically
has the shape of `(num_sample, num_feature)`.
reset_state: bool
Whether to reset the model states.
shared_args: optional, dict
The shared arguments across different layers.
eval_time: bool
Evaluate the time used for running.
Returns
-------
output: ArrayType, dict
The model output.
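Examples
--------
A minimal sketch (``X_test`` is an illustrative input batch):
.. code-block:: python

    outputs = trainer.predict(X_test)
    elapsed, outputs = trainer.predict(X_test, eval_time=True)  # also return the run time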
"""
if shared_args is None:
shared_args = dict()
shared_args['fit'] = shared_args.get('fit', False)
shared_args = tools.DotDict(shared_args)
# reset the model states
if reset_state:
self.target.reset_state(self._get_input_batch_size(xs=inputs))
self.reset_state()
# init monitor
for key in self._monitors.keys():
self.mon[key] = [] # reshape the monitor items
# prediction
if not isinstance(inputs, (tuple, list)):
inputs = (inputs,)
if eval_time: t0 = time.time()
outs, hists = self._fun_predict(*inputs, shared_args=shared_args)
if eval_time: t1 = time.time()
# post-running for monitors
for key in hists.keys():
self.mon[key] = bm.asarray(hists[key])
if self.numpy_mon_after_run:
for key in hists.keys():
self.mon[key] = np.asarray(self.mon[key])
return (t1 - t0, outs) if eval_time else outs