Source code for brainpy._src.integrators.sde.base
# -*- coding: utf-8 -*-
from typing import Dict, Callable, Union, Sequence
import jax.numpy as jnp
from brainpy import errors
from brainpy._src import math as bm
from brainpy._src.integrators import constants, utils
from brainpy._src.integrators.base import Integrator
from brainpy._src.math.delayvars import AbstractDelay
__all__ = [
'SDEIntegrator',
]
def f_names(f):
func_name = constants.unique_name('sde')
if f.__name__.isidentifier():
func_name += '_' + f.__name__
return func_name
[docs]
class SDEIntegrator(Integrator):
"""SDE Integrator."""
def __init__(
self,
f: Callable,
g: Callable,
dt: float = None,
name: str = None,
show_code: bool = False,
var_type: str = None,
intg_type: str = None,
wiener_type: str = None,
state_delays: Dict[str, AbstractDelay] = None,
):
dt = bm.get_dt() if dt is None else dt
parses = utils.get_args(f)
variables = parses[0] # variable names, (before 't')
parameters = parses[1] # parameter names, (after 't')
arguments = parses[2] # function arguments
# super initialization
super(SDEIntegrator, self).__init__(name=name,
variables=variables,
parameters=parameters,
arguments=arguments,
dt=dt,
state_delays=state_delays)
# derivative functions
self.derivative = {constants.F: f, constants.G: g}
self.f = f
self.g = g
# essential parameters
intg_type = constants.ITO_SDE if intg_type is None else intg_type
var_type = constants.SCALAR_VAR if var_type is None else var_type
wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type
if intg_type not in constants.SUPPORTED_INTG_TYPE:
raise errors.IntegratorError(f'Currently, BrainPy only support SDE_INT types: '
f'{constants.SUPPORTED_INTG_TYPE}. But we got {intg_type}.')
if var_type not in constants.SUPPORTED_VAR_TYPE:
raise errors.IntegratorError(f'Currently, BrainPy only supports variable types: '
f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.')
if wiener_type not in constants.SUPPORTED_WIENER_TYPE:
raise errors.IntegratorError(f'Currently, BrainPy only supports Wiener '
f'Process types: {constants.SUPPORTED_WIENER_TYPE}. '
f'But we got {wiener_type}.')
self.var_type = var_type # variable type
self.intg_type = intg_type # integral type
self.wiener_type = wiener_type # wiener process type
# code scope
self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': bm.random.DEFAULT}
# code lines
self.func_name = f_names(f)
self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):']
# others
self.show_code = show_code
def _check_vector_wiener_dim(self, noise_size, var_size):
if noise_size[:-1] > var_size[-len(noise_size) +1:]:
raise ValueError(f"Incompatible shapes for shapes of noise {noise_size} and variable {var_size}")