Source code for brainpy._src.integrators.ode.generic

# -*- coding: utf-8 -*-

from typing import Dict

from brainpy._src.math.delayvars import AbstractDelay, NeuTimeDelay
from .base import ODEIntegrator

# Public API of this module. ``odeint`` is the main entry point and must be
# listed here so that ``from ... import *`` exports it along with the helpers.
__all__ = [
  'odeint',
  'set_default_odeint',
  'get_default_odeint',
  'register_ode_integrator',
  'get_supported_methods',
]

# Registry mapping a method's shortcut name (str) to its integrator class.
# It starts empty and is filled through ``register_ode_integrator``.
name2method = {
}

# Name of the method ``odeint`` falls back to when ``method`` is None.
# NOTE(review): the identifier says "DDE" although this module handles ODEs;
# the name is kept because other functions in this file refer to it.
_DEFAULT_DDE_METHOD = 'euler'


def odeint(
    f=None,
    method=None,
    var_type=None,
    dt=None,
    name=None,
    show_code=False,
    state_delays: Dict[str, AbstractDelay] = None,
    neutral_delays: Dict[str, NeuTimeDelay] = None,
    **kwargs
):
  """Numerical integration for ODEs.

  Can be called directly with the derivative function, or without it to
  obtain a decorator (see the example below).

  Examples
  --------

  .. plot::
    :include-source: True

    >>> import brainpy as bp
    >>> import matplotlib.pyplot as plt
    >>>
    >>> a=0.7;  b=0.8;  tau=12.5;  Vth=1.9
    >>> V = 0;  w = 0  # initial values
    >>>
    >>> @bp.odeint(method='rk4', dt=0.04)
    >>> def integral(V, w, t, Iext):
    >>>   dw = (V + a - b * w) / tau
    >>>   dV = V - V * V * V / 3 - w + Iext
    >>>   return dV, dw
    >>>
    >>> hist_V = []
    >>> for t in bp.math.arange(0, 100, integral.dt):
    >>>   V, w = integral(V, w, t, 0.5)
    >>>   hist_V.append(V)
    >>> plt.plot(bp.math.arange(0, 100, integral.dt), hist_V)
    >>> plt.show()

  Parameters
  ----------
  f : callable, function
    The derivative function. When omitted, a decorator is returned instead.
  method : str
    The shortcut name of the numerical integrator. Falls back to the
    module-level default method when ``None``.
  var_type: str
    The type of the variable defined in the equation.
  dt: float
    The numerical integration precision.
  name: str
    The name of the integrator.
  show_code: bool
    Show the formatted code.
  state_delays: dict
    The state delay variables, keyed by name.
  neutral_delays: dict
    The neutral delay variables, keyed by name.
  **kwargs
    Any further keyword arguments are forwarded unchanged to the selected
    integrator (for example ``adaptive`` to use the adaptive mode and ``tol``
    as the tolerance for adapting the step size, where supported).

  Returns
  -------
  integral : ODEIntegrator
    The numerical solver of `f`, or a one-argument decorator producing it
    when `f` is not given.

  Raises
  ------
  ValueError
    If `method` is not the name of a registered integrator.
  """
  if method is None:
    method = _DEFAULT_DDE_METHOD
  if method not in name2method:
    raise ValueError(f'Unknown ODE numerical method "{method}". Currently '
                     f'BrainPy only support: {list(name2method.keys())}')

  def make_integrator(f):
    # The registry is looked up at call time, exactly like the direct path.
    return name2method[method](f,
                               var_type=var_type,
                               dt=dt,
                               name=name,
                               show_code=show_code,
                               state_delays=state_delays,
                               neutral_delays=neutral_delays,
                               **kwargs)

  # Decorator usage: no function yet, hand back the factory itself.
  if f is None:
    return make_integrator
  return make_integrator(f)
def set_default_odeint(method):
  """Set the default ODE numerical integrator method.

  Parameters
  ----------
  method : str
    Name of a registered numerical integrator method.

  Raises
  ------
  ValueError
    If `method` is not a string, or is not a registered method name.
  """
  # The assignment at the end must rebind the module-level default, so the
  # name declared global and the name assigned have to be the same one.
  global _DEFAULT_DDE_METHOD
  if not isinstance(method, str):
    raise ValueError(f'Only support string, not {type(method)}.')
  if method not in name2method:
    raise ValueError(f'Unsupported ODE_INT numerical method: {method}.')
  _DEFAULT_DDE_METHOD = method
def get_default_odeint():
  """Return the name of the current default ODE numerical integrator.

  Returns
  -------
  method : str
    The default numerical integrator method.
  """
  current_default = _DEFAULT_DDE_METHOD
  return current_default
def register_ode_integrator(name, integrator):
  """Register a new ODE integrator.

  Parameters
  ----------
  name: str
    The shortcut name under which the integrator is registered.
  integrator: type
    The integrator class; it must be a subclass of ``ODEIntegrator``.

  Raises
  ------
  ValueError
    If `name` is already registered, or if `integrator` is not a class
    derived from ``ODEIntegrator``.
  """
  if name in name2method:
    raise ValueError(f'"{name}" has been registered in ODE integrators.')
  # ``issubclass`` raises TypeError for non-class arguments, so check that
  # first and report every invalid integrator with the same ValueError.
  if not (isinstance(integrator, type) and issubclass(integrator, ODEIntegrator)):
    raise ValueError(f'"integrator" must be a subclass of {ODEIntegrator.__name__}')
  name2method[name] = integrator
def get_supported_methods():
  """Get the names of all registered numerical methods for ODEs.

  Returns
  -------
  methods : list
    A new list of the registered method names.
  """
  return [method_name for method_name in name2method.keys()]