Source code for brainpy._src.integrators.ode.base

# -*- coding: utf-8 -*-


from typing import Dict, Callable, Union

from brainpy.errors import DiffEqError, CodeError
from brainpy._src import math as bm
from brainpy._src.integrators import constants, utils
from brainpy._src.integrators.base import Integrator
from brainpy._src.integrators.constants import DT
from brainpy.check import is_dict_data

__all__ = [
  'ODEIntegrator',
]


def f_names(f):
  func_name = constants.unique_name('ode')
  if f.__name__.isidentifier():
    func_name += '_' + f.__name__
  return func_name


[docs] class ODEIntegrator(Integrator): """Numerical Integrator for Ordinary Differential Equations (ODEs). Parameters ---------- f : callable The derivative function. var_type: str The type for each variable. dt: float, int The numerical precision. name: str The integrator name. """ def __init__( self, f: Callable, var_type: str = None, dt: float = None, name: str = None, show_code: bool = False, state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, neutral_delays: Dict[str, Union[bm.NeuTimeDelay, bm.NeuLenDelay]] = None ): dt = bm.get_dt() if dt is None else dt parses = utils.get_args(f) variables = parses[0] # variable names, (before 't') parameters = parses[1] # parameter names, (after 't') arguments = parses[2] # function arguments for p in tuple(variables) + tuple(parameters): if p == DT: raise CodeError(f'{DT} is a system keyword denotes the ' f'precision of numerical integration. ' f'It cannot be used as a variable or parameter, ' f'please change an another name.') # super initialization super(ODEIntegrator, self).__init__(name=name, variables=variables, parameters=parameters, arguments=arguments, dt=dt, state_delays=state_delays) # others self.show_code = show_code self.var_type = var_type # variable type # derivative function self.derivative = {constants.F: f} self.f = f # code scope self.code_scope = {constants.F: f} # code lines self.func_name = f_names(f) self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] # neutral delays self._neutral_delays = dict() if neutral_delays is not None: is_dict_data(neutral_delays, key_type=str, val_type=bm.NeuTimeDelay) for key, delay in neutral_delays.items(): if key not in self.variables: raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') self._neutral_delays[key] = delay self.register_implicit_nodes(self._neutral_delays) @property def neutral_delays(self): """neutral delays.""" return self._neutral_delays @neutral_delays.setter def neutral_delays(self, value): raise ValueError('Cannot set "neutral_delays" by users.') def __call__(self, *args, **kwargs): assert self.integral is not None, 'Please build the integrator first.' # check arguments for i, arg in enumerate(args): kwargs[self.arg_names[i]] = arg # integral new_vars = self.integral(**kwargs) if len(self.variables) == 1: dict_vars = {self.variables[0]: new_vars} else: dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} dt = kwargs.pop(DT, self.dt) # update neutral delay variables if len(self.neutral_delays): kwargs.update(dict_vars) new_devs = self.f(**kwargs) if len(self.variables) == 1: new_devs = {self.variables[0]: new_devs} else: new_devs = {k: new_devs[i] for i, k in enumerate(self.variables)} for key, delay in self.neutral_delays.items(): if isinstance(delay, bm.NeuLenDelay): delay.update(new_devs[key]) elif isinstance(delay, bm.NeuTimeDelay): delay.update(kwargs['t'] + dt, new_devs[key]) else: raise ValueError('Unknown delay variable. We only supports ' f'{bm.NeuTimeDelay.__name__} and {bm.NeuLenDelay.__name__}. ' f'While we got {delay}') # update state delay variables for key, delay in self.state_delays.items(): if isinstance(delay, bm.LengthDelay): delay.update(dict_vars[key]) elif isinstance(delay, bm.TimeDelay): delay.update(dict_vars[key]) else: raise ValueError('Unknown delay variable. We only supports ' f'{bm.LengthDelay.__name__} and {bm.TimeDelay.__name__}. ' f'While we got {delay}') return new_vars