# Source code for brainpy.integrators.sde.generic

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict, Union

import brainpy.math as bm
from .base import SDEIntegrator

# Public API of this module. ``sdeint`` is the main entry point defined below
# and must be listed here so that ``from ... import *`` exports it together
# with the registry helpers.
__all__ = [
    'sdeint',
    'set_default_sdeint',
    'get_default_sdeint',
    'register_sde_integrator',
    'get_supported_methods',
]

# Registry mapping a shortcut name (str) to an SDE integrator class.
# It starts empty and is filled through ``register_sde_integrator``.
name2method = {}

# Shortcut name used by ``sdeint`` when no ``method`` is given; it can be
# changed with ``set_default_sdeint``.
# NOTE(review): 'euler' is presumably registered by another module at import
# time -- confirm, since nothing in this file registers it.
_DEFAULT_SDE_METHOD = 'euler'


def sdeint(
    f=None,
    g=None,
    method=None,
    dt: float = None,
    name: str = None,
    show_code: bool = False,
    var_type: str = None,
    intg_type: str = None,
    wiener_type: str = None,
    state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None
):
    """Numerical integration for SDEs.

    Can be called directly with both ``f`` and ``g``, or with only one of
    them, in which case a one-argument callable is returned that accepts the
    missing function (so it can be used as a decorator).

    Parameters::

    f : callable, function
        The derivative function.
    g : callable, function
        The function forwarded to the integrator as ``g``
        (presumably the diffusion/noise term -- confirm against
        ``SDEIntegrator``).
    method : str
        The shortcut name of the numerical integrator. When ``None``, the
        module-level default method is used.
    dt, name, show_code, var_type, intg_type, wiener_type, state_delays
        Forwarded unchanged, as keyword arguments, to the integrator
        constructor registered under ``method``.

    Returns::

    integral : SDEIntegrator
        The numerical solver of `f` when both ``f`` and ``g`` are given.
        Otherwise a callable taking the missing function (``g`` or ``f``)
        and returning the solver.

    Raises::

    ValueError
        If ``method`` is not a registered name, or if neither ``f`` nor
        ``g`` is provided.
    """
    if method is None:
        method = _DEFAULT_SDE_METHOD
    if method not in name2method:
        raise ValueError(f'Unknown SDE numerical method "{method}". Currently '
                         f'BrainPy only support: {list(name2method.keys())}')

    if f is None and g is None:
        raise ValueError('Must provide "f" or "g".')

    # Options shared by every construction path below.
    shared_options = dict(
        dt=dt,
        name=name,
        show_code=show_code,
        var_type=var_type,
        intg_type=intg_type,
        wiener_type=wiener_type,
        state_delays=state_delays,
    )

    def _build(f_func, g_func):
        # The registry is looked up at construction time, not when the
        # partial form is created.
        return name2method[method](f=f_func, g=g_func, **shared_options)

    if f is None:
        # Only ``g`` is known: wait for ``f``.
        return lambda f: _build(f, g)
    if g is None:
        # Only ``f`` is known: wait for ``g``.
        return lambda g: _build(f, g)
    return _build(f, g)
def set_default_sdeint(method):
    """Set the default SDE numerical integrator method.

    Parameters::

    method : str
        Shortcut name of a registered SDE integrator.

    Raises::

    ValueError
        If ``method`` is not a string, or is not a registered name.
    """
    global _DEFAULT_SDE_METHOD

    is_string = isinstance(method, str)
    if not is_string:
        raise ValueError(f'Only support string, not {type(method)}.')

    is_registered = method in name2method
    if not is_registered:
        raise ValueError(f'Unsupported SDE_INT numerical method: {method}.')

    _DEFAULT_SDE_METHOD = method
def get_default_sdeint() -> str:
    """Return the shortcut name of the current default SDE integrator.

    Returns::

    method : str
        The default numerical integrator method.
    """
    return _DEFAULT_SDE_METHOD
def register_sde_integrator(name, integrator):
    """Register a new SDE integrator class under a shortcut name.

    Parameters::

    name : str
        The shortcut name to register. It must not be registered already.
    integrator : type
        The integrator class; it must be a subclass of ``SDEIntegrator``.

    Raises::

    ValueError
        If ``name`` is already registered, or if ``integrator`` is not a
        subclass of ``SDEIntegrator``.
    """
    if name in name2method:
        raise ValueError(f'"{name}" has been registered in SDE integrators.')
    # ``issubclass`` raises a bare TypeError for non-class arguments (e.g. an
    # integrator *instance*), so check for a class first and report every
    # invalid input with the same ValueError.
    is_integrator_class = (isinstance(integrator, type)
                           and issubclass(integrator, SDEIntegrator))
    if not is_integrator_class:
        raise ValueError(f'"integrator" must be a subclass of {SDEIntegrator.__name__}')
    name2method[name] = integrator
def get_supported_methods():
    """Get all supported numerical methods for SDEs.

    Returns::

    methods : list of str
        The shortcut names currently registered in the SDE integrator
        registry, as a new list.
    """
    return list(name2method)