# Source code for brainpy._src.integrators.base
# -*- coding: utf-8 -*-
from typing import Dict, Sequence, Union
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math import TimeDelay, LengthDelay
from brainpy.check import is_float, is_dict_data
from brainpy.errors import DiffEqError
from .constants import DT
# Names exported by ``from ... import *``; ``AbstractIntegrator`` is
# intentionally not part of this list.
__all__ = ['Integrator']
class AbstractIntegrator(BrainPyObject):
    """Root of the integrator class hierarchy.

    This class only fixes the calling protocol: calling an instance of it
    directly always raises ``NotImplementedError``, so concrete integrators
    must override :meth:`__call__`.
    """

    # NOTE(review): the names below were kept as bare notes; they are
    # presumably attributes concrete integrators are expected to define —
    # confirm against the subclasses.
    #   func_name
    #   derivative
    #   code_scope

    def __call__(self, *args, **kwargs):
        """Integrate one step; subclasses must provide the implementation."""
        raise NotImplementedError()
class Integrator(AbstractIntegrator):
    """Basic integrator class shared by the concrete numerical integrators.

    It records the names of the variables, parameters and call arguments of
    a differential equation, the integration step ``dt``, an optional set of
    state delays, and the integral function (installed later through
    :meth:`set_integral` or the :attr:`integral` setter).

    Parameters
    ----------
    variables : sequence of str
        Names of the state variables of the equation.
    parameters : sequence of str
        Names of the remaining parameters of the equation.
    arguments : sequence of str
        Argument declarations of the integral function. The string
        ``f'{DT}={dt}'`` is appended to them automatically.
    dt : float
        Numerical integration step; an ``int`` is also accepted, ``None``
        is rejected.
    name : str, optional
        Object name, forwarded to the ``BrainPyObject`` base class.
    state_delays : dict, optional
        Maps a variable name (which must appear in ``variables``) to a
        ``TimeDelay`` or ``LengthDelay``. Each delay is updated with the new
        value of its variable after every call of the integrator.

    Raises
    ------
    DiffEqError
        If a key of ``state_delays`` is not one of ``variables``.
    """

    def __init__(
        self,
        variables: Sequence[str],
        parameters: Sequence[str],
        arguments: Sequence[str],
        dt: float,
        name: str = None,
        state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None,
    ):
        super().__init__(name=name)
        # Validate ``dt`` before it is stored or baked into ``_arguments``.
        is_float(dt, 'dt', allow_none=False, allow_int=True)
        self._dt = dt
        self._variables = list(variables)
        self._parameters = list(parameters)
        # The step size is appended as a ``'<DT>=<value>'`` declaration;
        # presumably rendered as a keyword default of the integral
        # function — confirm against the code generators.
        self._arguments = list(arguments) + [f'{DT}={self._dt}']
        self._integral = None  # installed via ``set_integral``
        # Names used by ``__call__`` to bind positional arguments, in order.
        self.arg_names = self._variables + self._parameters + [DT]

        # state delays
        self._state_delays = {}
        if state_delays is not None:
            is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay))
            for key, delay in state_delays.items():
                if key not in self.variables:
                    raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}')
                self._state_delays[key] = delay
        self.register_implicit_nodes(self._state_delays)

    @property
    def dt(self):
        """The numerical integration precision."""
        return self._dt

    @dt.setter
    def dt(self, value):
        raise ValueError('Cannot set "dt" by users.')

    @property
    def variables(self):
        """The variables defined in the differential equation."""
        return self._variables

    @variables.setter
    def variables(self, values):
        raise ValueError('Cannot set "variables" by users.')

    @property
    def parameters(self):
        """The parameters defined in the differential equation."""
        return self._parameters

    @parameters.setter
    def parameters(self, values):
        raise ValueError('Cannot set "parameters" by users.')

    @property
    def arguments(self):
        """All arguments when calling the numerical integrator of the differential equation."""
        return self._arguments

    @arguments.setter
    def arguments(self, values):
        raise ValueError('Cannot set "arguments" by users.')

    @property
    def integral(self):
        """The integral function."""
        return self._integral

    @integral.setter
    def integral(self, f):
        self.set_integral(f)

    def set_integral(self, f):
        """Set the integral function.

        Raises
        ------
        ValueError
            If ``f`` is not callable.
        """
        if not callable(f):
            raise ValueError(f'integral function must be a callable function, '
                             f'but we got {type(f)}: {f}')
        self._integral = f

    @property
    def state_delays(self):
        """State delays, keyed by variable name."""
        return self._state_delays

    @state_delays.setter
    def state_delays(self, value):
        raise ValueError('Cannot set "state_delays" by users.')

    def __call__(self, *args, **kwargs):
        """Run the integral function once and update the state delays.

        Positional ``args`` are bound, in order, to ``arg_names`` (the
        variables, then the parameters, then ``DT``) and merged into
        ``kwargs``. A positional value overrides a keyword of the same name,
        and supplying more positional values than ``arg_names`` raises
        ``IndexError``.

        Returns
        -------
        The output of the integral function, unchanged. When there is one
        variable it is taken as that variable's new value. Otherwise it is
        indexed in ``variables`` order to feed the state delays.

        Raises
        ------
        AssertionError
            If no integral function has been set.
        ValueError
            If a registered state delay is neither ``LengthDelay`` nor
            ``TimeDelay``.
        """
        # Explicit check instead of ``assert`` so it is not stripped under
        # ``python -O``; the exception type is kept for compatibility.
        if self.integral is None:
            raise AssertionError('Please build the integrator first.')

        # bind positional arguments to their names
        for i, arg in enumerate(args):
            kwargs[self.arg_names[i]] = arg

        # integral
        new_vars = self.integral(**kwargs)
        if len(self.variables) == 1:
            dict_vars = {self.variables[0]: new_vars}
        else:
            dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)}

        # update state delay variables (both delay kinds share ``update(value)``)
        for key, delay in self.state_delays.items():
            if not isinstance(delay, (LengthDelay, TimeDelay)):
                raise ValueError('Unknown delay variable. We only supports '
                                 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. '
                                 f'While we got {delay}')
            delay.update(dict_vars[key])
        return new_vars