# Source code for brainpy._src.train.base
# -*- coding: utf-8 -*-
from typing import Dict, Sequence, Any, Optional
import brainpy.math as bm
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.runners import DSRunner
from brainpy._src.running import constants as c
from brainpy.errors import NoLongerSupportError
from brainpy.types import Output
# Public API of this module: only the trainer base class is exported.
__all__ = ['DSTrainer']
[docs]
class DSTrainer(DSRunner):
    """Structural Trainer for Dynamical Systems.

    Base class for trainers. It validates the computing mode of the training
    target, configures the JIT switches of the prediction and fitting phases,
    and leaves :py:meth:`fit` to be implemented by subclasses.

    For more parameters, users should refer to :py:class:`~.DSRunner`.

    Parameters
    ----------
    target: DynamicalSystem
      The training target.
    kwargs: Any
      Other general parameters in :py:class:`~.DSRunner`.

    Raises
    ------
    NoLongerSupportError
      If ``target.mode`` is not an instance of ``bm.BatchingMode``.
    """

    target: DynamicalSystem
    '''The training target.'''

    train_nodes: Sequence[DynamicalSystem]  # need to be initialized by subclass
    '''All children nodes in this training target.'''

    def __init__(
        self,
        target: DynamicalSystem,
        **kwargs
    ):
        super().__init__(target=target, **kwargs)

        # Trainers only work with targets whose mode is a batching mode.
        if not isinstance(self.target.mode, bm.BatchingMode):
            raise NoLongerSupportError(f'''
From version 2.3.1, DSTrainer must receive a DynamicalSystem instance with
the computing mode of {bm.batching_mode} or {bm.training_mode}.
See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1
for the solution of how to fix this.
''')

        # jit
        # A boolean ``jit`` setting applies to both phases; otherwise each
        # phase is looked up individually and falls back to ``True``.
        origin_jit = self._origin_jit
        for phase in (c.PREDICT_PHASE, c.FIT_PHASE):
            if isinstance(origin_jit, bool):
                self.jit[phase] = origin_jit
            else:
                self.jit[phase] = origin_jit.get(phase, True)

    def predict(
        self,
        inputs: Any,
        reset_state: bool = False,
        shared_args: Optional[Dict] = None,
        eval_time: bool = False
    ) -> Output:
        """Prediction function.

        Delegates to :py:meth:`DSRunner.predict` after making sure that the
        shared arguments carry a ``'fit'`` entry (``False`` unless the caller
        provided one).

        Parameters
        ----------
        inputs: ArrayType, sequence of ArrayType, dict of ArrayType
          The input values.
        reset_state: bool
          Reset the target state before running.
        eval_time: bool
          Whether we evaluate the running time or not?
        shared_args: dict
          The shared arguments across nodes. The given dict is not modified;
          a shallow copy is forwarded to the parent implementation.

        Returns
        -------
        output: ArrayType, sequence of ArrayType, dict of ArrayType
          The running output.
        """
        # Work on a copy so that the caller's dict is never mutated.
        shared_args = {} if shared_args is None else dict(shared_args)
        shared_args.setdefault('fit', False)
        return super().predict(inputs=inputs,
                               reset_state=reset_state,
                               shared_args=shared_args,
                               eval_time=eval_time)

    def fit(
        self,
        train_data: Any,
        reset_state: bool = False,
        shared_args: Optional[Dict] = None
    ) -> Output:  # need to be implemented by subclass
        """Fit the target on the training data.

        This base implementation only defines the interface; subclasses must
        override it.

        Parameters
        ----------
        train_data: Any
          The training data.
        reset_state: bool
          Whether to reset the target state before fitting.
        shared_args: dict
          The shared arguments across nodes.

        Raises
        ------
        NotImplementedError
          Always, since the base class provides no training algorithm.
        """
        raise NotImplementedError('Must implement the fit function. ')