Source code for brainpy.train.base

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict, Sequence, Any, Optional

import brainpy.math as bm
from brainpy._errors import NoLongerSupportError
from brainpy.dynsys import DynamicalSystem
from brainpy.runners import DSRunner
from brainpy.running import constants as c
from brainpy.types import Output

__all__ = [
    'DSTrainer',
]


class DSTrainer(DSRunner):
  """Structural Trainer for Dynamical Systems.

  For more parameters, users should refer to :py:class:`~.DSRunner`.

  Parameters::

  target: DynamicalSystem
    The training target. Its computing mode (``target.mode``) must be an
    instance of ``bm.BatchingMode``; otherwise ``NoLongerSupportError``
    is raised at construction time.
  kwargs: Any
    Other general parameters in :py:class:`~.DSRunner`.

  Raises::

  NoLongerSupportError
    If ``target`` is not in a batching computing mode.
  """

  target: DynamicalSystem
  '''The training target.'''

  train_nodes: Sequence[DynamicalSystem]  # need to be initialized by subclass
  '''All children nodes in this training target.'''

  def __init__(
      self,
      target: DynamicalSystem,
      **kwargs
  ):
    super().__init__(target=target, **kwargs)

    # From version 2.3.1 on, training requires a batching computing mode;
    # fail fast with a pointer to the migration notes.
    if not isinstance(self.target.mode, bm.BatchingMode):
      raise NoLongerSupportError(f'''
      From version 2.3.1, DSTrainer must receive a DynamicalSystem instance with
      the computing mode of {bm.batching_mode} or {bm.training_mode}.

      See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1
      for the solution of how to fix this.
      ''')

    # Resolve the per-phase JIT flags.
    # NOTE(review): ``self._origin_jit`` is presumably the raw ``jit`` argument
    # stored by ``DSRunner.__init__`` -- confirm. A bool applies to both the
    # predict and fit phases; anything else is read as a mapping keyed by
    # phase name, where a missing phase defaults to ``True``.
    if isinstance(self._origin_jit, bool):
      self.jit[c.PREDICT_PHASE] = self._origin_jit
      self.jit[c.FIT_PHASE] = self._origin_jit
    else:
      self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True)
      self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True)

  def predict(
      self,
      inputs: Any,
      reset_state: bool = False,
      shared_args: Optional[Dict] = None,
      eval_time: bool = False
  ) -> Output:
    """Prediction function.

    Delegates to :py:meth:`DSRunner.predict`, making sure the shared
    argument ``'fit'`` defaults to ``False`` so nodes run in prediction
    (non-training) mode unless the caller explicitly says otherwise.

    Parameters::

    inputs: ArrayType, sequence of ArrayType, dict of ArrayType
      The input values.
    reset_state: bool
      Reset the target state before running.
    shared_args: dict, optional
      The shared arguments across nodes. The caller's dict is copied and
      never modified; ``'fit'`` is set to ``False`` unless already present.
    eval_time: bool
      Whether we evaluate the running time or not?

    Returns::

    output: ArrayType, sequence of ArrayType, dict of ArrayType
      The running output.
    """
    # Work on a private copy so the ``'fit'`` default does not leak back
    # into a dict the caller may reuse for later calls.
    shared_args = {} if shared_args is None else dict(shared_args)
    shared_args.setdefault('fit', False)
    return super().predict(inputs=inputs,
                           reset_state=reset_state,
                           shared_args=shared_args,
                           eval_time=eval_time)

  def fit(
      self,
      train_data: Any,
      reset_state: bool = False,
      shared_args: Optional[Dict] = None
  ) -> Output:
    """Train the target on ``train_data``.

    Abstract hook: every concrete trainer must override this method.

    Parameters::

    train_data: Any
      The training data; its expected structure is defined by the subclass.
    reset_state: bool
      Presumably whether to reset the target state before training,
      mirroring :py:meth:`predict` -- the subclass defines the semantics.
    shared_args: dict, optional
      The shared arguments across nodes.

    Raises::

    NotImplementedError
      Always, in this base class.
    """
    raise NotImplementedError('Must implement the fit function. ')