Source code for brainpy.integrators.base

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from contextlib import contextmanager
from typing import Dict, Sequence, Union, Callable

import jax
from brainstate.transform import jaxpr_to_python_code

from brainpy._errors import DiffEqError
from brainpy.check import is_float, is_dict_data
from brainpy.math import TimeDelay, LengthDelay
from brainpy.math.object_transform.base import BrainPyObject
from .constants import DT

__all__ = [
    'AbstractIntegrator',
    'Integrator',
    'compile_integrators',
]


class AbstractIntegrator(BrainPyObject):
    """Basic Integrator Class.

    Abstract root of the integrator hierarchy. Calling an instance of this
    class directly raises ``NotImplementedError``; concrete integrators are
    expected to override :meth:`__call__`.
    """

    # NOTE(review): an earlier note listed ``func_name``, ``derivative`` and
    # ``code_scope`` here, presumably attributes provided by subclasses. None of
    # them is defined in this module — confirm before relying on them.

    def __call__(self, *args, **kwargs):
        """Invoke the integrator. Subclasses must override this method."""
        raise NotImplementedError()


class Integrator(AbstractIntegrator):
    """Basic Integrator Class.

    Stores the metadata shared by numerical integrators (variable, parameter
    and argument names, the step size ``dt`` and optional state delays) and
    dispatches calls to the integral function installed with
    :meth:`set_integral`.

    Parameters
    ----------
    variables: sequence of str
        Names of the variables of the differential equation.
    parameters: sequence of str
        Names of the parameters of the differential equation.
    arguments: sequence of str
        Argument names of the integrator; ``f'{DT}={dt}'`` is appended.
    dt: float
        The numerical integration step (ints are accepted).
    name: str, optional
        The object name.
    state_delays: dict of str to TimeDelay or LengthDelay, optional
        Delays keyed by variable name. After every call, each delay is updated
        with the new value of its variable.

    Raises
    ------
    DiffEqError
        If a key of ``state_delays`` is not one of ``variables``.
    """

    def __init__(
        self,
        variables: Sequence[str],
        parameters: Sequence[str],
        arguments: Sequence[str],
        dt: float,
        name: str = None,
        state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None,
    ):
        super().__init__(name=name)

        # validate first so an invalid ``dt`` is never stored on the instance
        is_float(dt, 'dt', allow_none=False, allow_int=True)
        self._dt = dt

        self._variables = list(variables)  # variables
        self._parameters = list(parameters)  # parameters
        self._arguments = list(arguments) + [f'{DT}={self._dt}', ]  # arguments
        self._integral = None  # integral function
        # names assigned, in order, to positional arguments of ``__call__``
        self.arg_names = self._variables + self._parameters + [DT]

        # state delays
        self._state_delays = dict()
        if state_delays is not None:
            is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay))
            for key, delay in state_delays.items():
                if key not in self.variables:
                    raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}')
                self._state_delays[key] = delay
        self.register_implicit_nodes(self._state_delays)

        # math expression, filled in by ``_call_integral`` in compile mode
        self._math_expr = None

    @property
    def dt(self):
        """The numerical integration precision."""
        return self._dt

    @dt.setter
    def dt(self, value):
        raise ValueError('Cannot set "dt" by users.')

    @property
    def variables(self):
        """The variables defined in the differential equation."""
        return self._variables

    @variables.setter
    def variables(self, values):
        raise ValueError('Cannot set "variables" by users.')

    @property
    def parameters(self):
        """The parameters defined in the differential equation."""
        return self._parameters

    @parameters.setter
    def parameters(self, values):
        raise ValueError('Cannot set "parameters" by users.')

    @property
    def arguments(self):
        """All arguments when calling the numerical integrator of the differential equation."""
        return self._arguments

    @arguments.setter
    def arguments(self, values):
        raise ValueError('Cannot set "arguments" by users.')

    @property
    def integral(self):
        """The integral function."""
        return self._integral

    @integral.setter
    def integral(self, f):
        self.set_integral(f)

    def set_integral(self, f):
        """Set the integral function.

        Raises
        ------
        ValueError
            If ``f`` is not callable.
        """
        if not callable(f):
            raise ValueError(f'integral function must be a callable function, '
                             f'but we got {type(f)}: {f}')
        self._integral = f

    @property
    def state_delays(self):
        """State delays."""
        return self._state_delays

    @state_delays.setter
    def state_delays(self, value):
        raise ValueError('Cannot set "state_delays" by users.')

    def _call_integral(self, *args, **kwargs):
        """Invoke the integral function with keyword arguments only.

        Positional ``args`` are accepted but ignored. ``t`` defaults to ``0.``
        when it is missing or ``None``. While ``compile_integrators`` is active,
        the integral is traced with ``jax.make_jaxpr`` and evaluated from the
        resulting jaxpr, and its Python-code rendering is stored for
        :meth:`to_math_expr`.
        """
        # ``**kwargs`` is already a fresh dict, so mutating it is safe
        t = kwargs.get('t', None)
        kwargs['t'] = 0. if t is None else t
        if _during_compile:
            jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
            # leaves of ``kwargs`` follow the same (sorted-key) order as the jaxpr inputs
            outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))
            _, tree = jax.tree.flatten(out_shapes)
            new_vars = tree.unflatten(outs)
            self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr)
        else:
            new_vars = self.integral(**kwargs)
        return new_vars

    def __call__(self, *args, **kwargs):
        """Run the integral function and update the state delays.

        Positional arguments are mapped, in order, onto ``self.arg_names``
        (variables, then parameters, then ``dt``). A positional argument
        overrides a keyword argument with the same name.

        Returns
        -------
        The integral function's output: a single value when there is one
        variable, otherwise an indexable with one entry per variable.

        Raises
        ------
        AssertionError
            If no integral function has been set.
        TypeError
            If more positional arguments are given than ``arg_names`` holds.
        ValueError
            If a state delay is neither ``LengthDelay`` nor ``TimeDelay``.
        """
        # explicit check instead of ``assert`` so it is not stripped under ``python -O``;
        # the AssertionError type is kept for backward compatibility
        if self.integral is None:
            raise AssertionError('Please build the integrator first.')

        # check arguments
        if len(args) > len(self.arg_names):
            raise TypeError(f'{type(self).__name__} accepts at most {len(self.arg_names)} '
                            f'positional arguments {self.arg_names}, but {len(args)} were given.')
        kwargs.update(zip(self.arg_names, args))

        # integral
        new_vars = self._call_integral(**kwargs)

        # post-process: name each output after its variable
        if len(self.variables) == 1:
            dict_vars = {self.variables[0]: new_vars}
        else:
            dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)}

        # update state delay variables
        for key, delay in self.state_delays.items():
            if not isinstance(delay, (LengthDelay, TimeDelay)):
                raise ValueError('Unknown delay variable. We only supports '
                                 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. '
                                 f'While we got {delay}')
            delay.update(dict_vars[key])
        return new_vars

    def to_math_expr(self):
        """Return the Python-code rendering of the integral recorded at compile time.

        Raises
        ------
        ValueError
            If the integrator has not been called inside ``compile_integrators``.
        """
        if self._math_expr is None:
            raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.')
        return self._math_expr
# Module-level flag read by ``Integrator._call_integral``. When True, integrator
# calls are traced with ``jax.make_jaxpr`` so that their math expression is recorded.
_during_compile = False


@contextmanager
def _during_compile_context():
    """Enable ``_during_compile`` for the duration of the ``with`` block.

    The previous value is restored on exit, including when an exception
    escapes. Nested uses therefore keep compile mode on until the outermost
    context exits.
    """
    global _during_compile
    previous = _during_compile
    try:
        _during_compile = True
        yield
    finally:
        # restore rather than reset: an inner context must not switch off an outer one
        _during_compile = previous


def compile_integrators(f: Callable, *args, **kwargs):
    """Compile integrators in the given function.

    Calls ``f(*args, **kwargs)`` with compile mode enabled. Every integrator
    invoked during that call is traced to a jaxpr and records its math
    expression, which can afterwards be read with ``Integrator.to_math_expr()``.

    Parameters
    ----------
    f: callable
        The function that calls one or more integrators.
    *args, **kwargs
        Forwarded to ``f``.

    Returns
    -------
    Whatever ``f`` returns.
    """
    with _during_compile_context():
        return f(*args, **kwargs)