# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Callable, Dict, Sequence

import jax.numpy as jnp

from brainpy import errors
from brainpy import math as bm
from brainpy.integrators import constants, utils, joint_eq
from brainpy.integrators.constants import DT
from brainpy.integrators.sde.base import SDEIntegrator
from brainpy.integrators.sde.generic import register_sde_integrator
from brainpy.integrators.utils import format_args
# Public API of this module: the integrator classes exported by a
# star-import (``from <this module> import *``).
__all__ = [
    'Euler',
    'Heun',
    'Milstein',
    'MilsteinGradFree',
    'ExponentialEuler',
]
def df_and_dg(code_lines, variables, parameters):
    """Append generated source lines that evaluate ``f`` and ``g``.

    Two assignment statements are appended to ``code_lines``: one unpacking
    the result of ``f(...)`` into ``<var>_df`` names and one unpacking the
    result of ``g(...)`` into ``<var>_dg`` names. A whitespace-only separator
    line is appended afterwards.

    Parameters
    ----------
    code_lines : list of str
        Buffer of generated source lines; modified in place.
    variables : list of str
        Names of the variables; they come first in the generated call.
    parameters : list of str
        Remaining argument names, concatenated after ``variables``.
    """
    call_args = ", ".join(variables + parameters)
    # The same call signature is used for both the drift (f -> *_df) and
    # the diffusion (g -> *_dg) statements.
    for suffix, func_name in (('df', 'f'), ('dg', 'g')):
        targets = ", ".join(f'{name}_{suffix}' for name in variables)
        code_lines.append(f' {targets} = {func_name}({call_args})')
    code_lines.append(' ')
def dfdt(code_lines, variables):
    """Append generated lines computing ``<var>_dfdt = <var>_df * <DT>``.

    One statement is appended per variable, where ``<DT>`` is the text of
    ``constants.DT``; a whitespace-only separator line follows.

    Parameters
    ----------
    code_lines : list of str
        Buffer of generated source lines; modified in place.
    variables : iterable of str
        Names of the variables to emit a statement for.
    """
    code_lines.extend(
        f' {name}_dfdt = {name}_df * {constants.DT}' for name in variables
    )
    code_lines.append(' ')
def noise_terms(code_lines, variables):
    """Append generated lines that draw a Wiener increment per variable.

    For every variable ``v`` the following snippet is emitted::

        if v_dg is not None:
            v_dW = random.normal(0.000, dt_sqrt, math.shape(v)).value

    followed, after all variables, by a whitespace-only separator line.

    The assignment is emitted one indentation level deeper than the ``if``
    header. Emitting both lines at the same indent (as was done before)
    yields code that cannot be compiled, because the ``if`` has no body.

    Parameters
    ----------
    code_lines : list of str
        Buffer of generated source lines; modified in place.
    variables : iterable of str
        Names of the variables to emit a noise draw for.
    """
    for name in variables:
        code_lines.append(f' if {name}_dg is not None:')
        # Body of the generated ``if``: must be indented deeper than its header.
        code_lines.append(f'  {name}_dW = random.normal(0.000, dt_sqrt, math.shape({name})).value')
    code_lines.append(' ')
[docs]
class Euler(SDEIntegrator):
    r"""Euler method for the Ito and Stratonovich integrals.

    For the Ito scheme, the Euler method (also known as the Euler-Maruyama
    method) is given by:

    .. math::

        \begin{aligned}
        Y_{n+1} &=Y_{n}+f\left(Y_{n}\right) h_{n}+g\left(Y_{n}\right) \Delta W_{n} \\
        \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    As the order of convergence for the Euler-Maruyama method is low (strong
    order of convergence 0.5, weak order of convergence 1), the numerical results
    are inaccurate unless a small step size is used. In fact, Euler-Maruyama
    represents the order 0.5 strong Taylor scheme.

    For the Stratonovich scheme, the Euler-Heun method has to be used instead
    of the Euler-Maruyama method:

    .. math::

        \begin{aligned}
        Y_{n+1} &=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\
        \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\
        \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    Parameters
    ----------
    f : callable
        Drift function; :meth:`step` calls it with the formatted arguments
        as keywords.
    g : callable
        Diffusion function; called the same way as ``f``.
    dt, name, var_type, intg_type, wiener_type, state_delays
        Forwarded unchanged to ``SDEIntegrator.__init__``.
    show_code : bool
        Accepted for signature compatibility; this constructor does not
        forward it to the base class.

    See Also
    --------
    Heun
    """

    def __init__(
        self, f, g, dt=None, name=None, show_code=False,
        var_type=None, intg_type=None, wiener_type=None,
        state_delays=None,
    ):
        super(Euler, self).__init__(f=f, g=g, dt=dt, name=name,
                                    var_type=var_type, intg_type=intg_type,
                                    wiener_type=wiener_type,
                                    state_delays=state_delays)
        self.set_integral(self.step)

    def step(self, *args, **kwargs):
        """Advance all variables by one step.

        Positional and keyword arguments are mapped onto ``self.arg_names``
        by ``format_args``. An entry named ``DT`` overrides ``self.dt`` for
        this call.

        Returns
        -------
        The updated value when there is a single variable, otherwise a list
        with one updated value per variable (in ``self.variables`` order).

        Raises
        ------
        ValueError
            If the drift/diffusion return values do not have the expected
            container type, or a zero-dimensional diffusion is given for a
            vector Wiener process.
        """
        all_args = format_args(args, kwargs, self.arg_names)
        dt = all_args.pop(DT, self.dt)
        is_vector_wiener = self.wiener_type == constants.VECTOR_WIENER

        # drift values
        drifts = self.f(**all_args)
        if len(self.variables) == 1:
            if not isinstance(drifts, (bm.ndarray, jnp.ndarray)):
                raise ValueError('Drift values must be a tensor when there '
                                 'is only one variable in the equation.')
            drifts = {self.variables[0]: drifts}
        else:
            if not isinstance(drifts, (tuple, list)):
                raise ValueError('Drift values must be a list/tuple of tensors '
                                 'when there are multiple variables in the equation.')
            drifts = {var: drifts[i] for i, var in enumerate(self.variables)}

        # diffusion values
        diffusions = self.g(**all_args)
        if len(self.variables) == 1:
            # The single-variable diffusion is intentionally not type-checked,
            # so ``None`` (no noise term) flows through to the checks below.
            diffusions = {self.variables[0]: diffusions}
        else:
            if not isinstance(diffusions, (tuple, list)):
                raise ValueError('Diffusion values must be a list/tuple of tensors '
                                 'when there are multiple variables in the equation.')
            diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)}
        if is_vector_wiener:
            for key, val in diffusions.items():
                if val is not None and jnp.ndim(val) == 0:
                    raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple "
                                     f"dimensional diffusion value. But we got a scale value for "
                                     f"variable {key}.")

        # integral results
        integrals = []
        if self.intg_type == constants.ITO_SDE:
            # Euler-Maruyama: Y_{n+1} = Y_n + f_n h + g_n dW_n
            for key in self.variables:
                integral = all_args[key] + drifts[key] * dt
                if diffusions[key] is not None:
                    shape = jnp.shape(all_args[key])
                    if self.wiener_type == constants.SCALAR_WIENER:
                        integral += diffusions[key] * bm.random.randn(*shape) * jnp.sqrt(dt)
                    else:
                        # one independent noise sample per trailing diffusion
                        # component, reduced back onto the variable's shape
                        shape += jnp.shape(diffusions[key])[-1:]
                        integral += jnp.sum(diffusions[key] * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt)
                integrals.append(integral)

        else:
            # Euler-Heun. Supporting value: \bar{Y}_{n} = Y_{n} + g_{n} \Delta W_{n}
            all_args_bar = dict(all_args)
            all_noises = {}
            for key in self.variables:
                if diffusions[key] is None:
                    # no noise on this variable: \bar{Y} stays equal to Y
                    continue
                shape = jnp.shape(all_args[key])
                if is_vector_wiener:
                    noise_shape = jnp.shape(diffusions[key])
                    self._check_vector_wiener_dim(noise_shape, shape)
                    shape += noise_shape[-1:]
                # \Delta W_n ~ sqrt(h) N(0, 1). The *scaled* increment is used
                # both for \bar{Y}_n and for the final update, as the scheme
                # requires the same \Delta W_n in both places.
                delta_w = bm.random.randn(*shape) * jnp.sqrt(dt)
                all_noises[key] = delta_w
                if is_vector_wiener:
                    all_args_bar[key] = all_args[key] + jnp.sum(diffusions[key] * delta_w, axis=-1)
                else:
                    all_args_bar[key] = all_args[key] + diffusions[key] * delta_w

            # g(\bar{Y}_{n})
            diffusion_bars = self.g(**all_args_bar)
            if len(self.variables) == 1:
                diffusion_bars = {self.variables[0]: diffusion_bars}
            else:
                diffusion_bars = {var: diffusion_bars[i] for i, var in enumerate(self.variables)}

            # Y_{n+1} = Y_{n} + f_{n} h + 1/2 [g_{n} + g(\bar{Y}_{n})] \Delta W_{n}
            for key in self.variables:
                integral = all_args[key] + drifts[key] * dt
                # A noise sample exists only where g_n was not None.
                if key in all_noises and diffusion_bars[key] is not None:
                    noise_term = (diffusions[key] + diffusion_bars[key]) / 2 * all_noises[key]
                    if is_vector_wiener:
                        # collapse the Wiener axis so the term matches the
                        # shape of the variable being updated
                        noise_term = jnp.sum(noise_term, axis=-1)
                    integral += noise_term
                integrals.append(integral)

        # return integrals
        if len(self.variables) == 1:
            return integrals[0]
        else:
            return integrals
# Register the Euler class with ``register_sde_integrator`` under the name
# 'euler' (presumably so it can be selected by that string -- see
# ``brainpy.integrators.sde.generic``).
register_sde_integrator('euler', Euler)
[docs]
class Heun(Euler):
    r"""The Euler-Heun method for Stratonovich integral scheme.

    Its mathematical expression is given by

    .. math::

        \begin{aligned}
        Y_{n+1} &=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\
        \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\
        \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    The stepping logic is inherited from :class:`Euler`; this class only
    restricts the integral type to ``constants.STRA_SDE``.

    Parameters
    ----------
    f, g, dt, name, show_code, var_type, wiener_type, state_delays
        Passed on to :class:`Euler` (``show_code`` is accepted but not
        forwarded).
    intg_type : str
        Must equal ``constants.STRA_SDE``; there is no implicit default, so
        leaving it as ``None`` is rejected as well.

    Raises
    ------
    errors.IntegratorError
        If ``intg_type`` is anything other than ``constants.STRA_SDE``.

    See Also
    --------
    Euler
    """

    def __init__(self, f, g, dt=None, name=None, show_code=False,
                 var_type=None, intg_type=None, wiener_type=None,
                 state_delays=None, ):
        # Validate before touching the base class so that an unsupported
        # integral type never yields a half-initialised integrator.
        if intg_type != constants.STRA_SDE:
            raise errors.IntegratorError(f'Heun method only supports Stratonovich '
                                         f'integral of SDEs, but we got {intg_type} integral.')
        super(Heun, self).__init__(f=f, g=g, dt=dt, name=name,
                                   var_type=var_type, intg_type=intg_type,
                                   wiener_type=wiener_type, state_delays=state_delays)
# Register the Heun class with ``register_sde_integrator`` under the name
# 'heun' (presumably so it can be selected by that string).
register_sde_integrator('heun', Heun)
[docs]
class Milstein(SDEIntegrator):
    r"""Milstein method for Ito or Stratonovich integrals.

    The Milstein scheme represents the order 1.0 strong Taylor scheme. For the Ito integral,

    .. math::

        \begin{aligned}
        &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left[\left(\Delta W_{n}\right)^{2}-h\right] \\
        &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    where :math:`g_{n}^{\prime}=\frac{d g\left(Y_{n}\right)}{d Y_{n}}` is the first derivative of :math:`g_n`.

    For the Stratonovich integral, the Milstein method is given by

    .. math::

        \begin{aligned}
        &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left(\Delta W_{n}\right)^{2} \\
        &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    Parameters
    ----------
    f : Callable
        Drift function; :meth:`step` calls it with the formatted arguments
        as keywords.
    g : Callable
        Diffusion function; its derivative with respect to its first
        argument is obtained through ``bm.vector_grad``.
    dt, name, var_type, intg_type, wiener_type, state_delays
        Forwarded unchanged to ``SDEIntegrator.__init__``.
    show_code
        Accepted for signature compatibility; this constructor does not
        forward it to the base class.
    """

    def __init__(
        self,
        f: Callable,
        g: Callable,
        dt: float = None,
        name: str = None,
        show_code=False,
        var_type: str = None,
        intg_type: str = None,
        wiener_type: str = None,
        state_delays: Dict[str, bm.AbstractDelay] = None,
    ):
        super(Milstein, self).__init__(f=f,
                                       g=g,
                                       dt=dt,
                                       name=name,
                                       var_type=var_type,
                                       intg_type=intg_type,
                                       wiener_type=wiener_type,
                                       state_delays=state_delays)
        self.set_integral(self.step)

    def _get_g_grad(self, f, allow_raise=False, need_grad=True):
        """Collect ``(grad_fn, variables, parameters)`` for ``f``.

        A ``joint_eq.JointEq`` is expanded recursively, yielding one entry per
        sub-equation. For a plain function the argument names are read with
        ``utils.get_args`` and exactly one variable is required.

        Parameters
        ----------
        f : callable or joint_eq.JointEq
            Function (or joint equation) to analyse.
        allow_raise : bool
            When True, a ``DiffEqError`` raised while parsing the arguments
            propagates to the caller. When False the error is swallowed, the
            variable/parameter slots of the entry stay ``None`` and the
            returned state is False.
        need_grad : bool
            When True, the first slot of each entry is
            ``bm.vector_grad(f, argnums=0)``; otherwise it stays ``None``.

        Returns
        -------
        tuple
            ``(entries, state)`` where ``entries`` is a list of 3-tuples and
            ``state`` is True only if every argument parse succeeded.
        """
        if isinstance(f, joint_eq.JointEq):
            results = []
            state = True
            for sub_eq in f.eqs:
                sub_results, sub_state = self._get_g_grad(sub_eq, allow_raise, need_grad)
                results.extend(sub_results)
                state &= sub_state
            return results, state

        grad_fn = None
        var_names = None
        par_names = None
        state = True
        try:
            found_vars, found_pars, _ = utils.get_args(f)
            if len(found_vars) != 1:
                raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__,
                                                                         vars=str(found_vars), eq=str(f)))
            # only publish the names once the parse is known to be valid
            var_names, par_names = found_vars, found_pars
        except errors.DiffEqError:
            state = False
            # Propagate only when the caller asked for it; otherwise the
            # failure is reported through ``state`` so a fallback can be tried.
            if allow_raise:
                raise
        if need_grad:
            grad_fn = bm.vector_grad(f, argnums=0)
        return [(grad_fn, var_names, par_names)], state

    def step(self, *args, **kwargs):
        """Advance all variables by one Milstein step.

        Positional and keyword arguments are mapped onto ``self.arg_names``
        by ``format_args``. An entry named ``DT`` overrides ``self.dt`` for
        this call.

        Returns
        -------
        The updated value when there is a single variable, otherwise a list
        with one updated value per variable (in ``self.variables`` order).

        Raises
        ------
        ValueError
            If drift/diffusion return values have the wrong container type,
            if a zero-dimensional diffusion is given for a vector Wiener
            process, or if ``f`` and ``g`` do not share the same structure
            when the argument names have to be taken from ``f``.
        errors.DiffEqError
            If the argument names can be parsed from neither ``g`` nor ``f``.
        """
        # Parse the gradient function and the argument names from ``g``
        # without raising; if that fails, take the argument names from ``f``
        # (raising this time) and keep only the gradient functions of ``g``.
        parses, state = self._get_g_grad(self.g, allow_raise=False, need_grad=True)
        if not state:
            parses2, _ = self._get_g_grad(self.f, allow_raise=True, need_grad=False)
            if len(parses2) != len(parses):
                raise ValueError(f'"f" and "g" should defined with JointEq both, and should '
                                 f'keep the same structure.')
            parses = [a[:1] + b[1:] for a, b in zip(parses, parses2)]

        # input arguments
        all_args = format_args(args, kwargs, self.arg_names)
        dt = all_args.pop(DT, self.dt)
        is_vector_wiener = self.wiener_type == constants.VECTOR_WIENER

        # drift values
        drifts = self.f(**all_args)
        if len(self.variables) == 1:
            if not isinstance(drifts, (bm.ndarray, jnp.ndarray)):
                raise ValueError('Drift values must be a tensor when there '
                                 'is only one variable in the equation.')
            drifts = {self.variables[0]: drifts}
        else:
            if not isinstance(drifts, (tuple, list)):
                raise ValueError('Drift values must be a list/tuple of tensors '
                                 'when there are multiple variables in the equation.')
            drifts = {var: drifts[i] for i, var in enumerate(self.variables)}

        # diffusion values
        diffusions = self.g(**all_args)
        if len(self.variables) == 1:
            if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
                raise ValueError('Diffusion values must be a tensor when there '
                                 'is only one variable in the equation.')
            diffusions = {self.variables[0]: diffusions}
        else:
            if not isinstance(diffusions, (tuple, list)):
                raise ValueError('Diffusion values must be a list/tuple of tensors '
                                 'when there are multiple variables in the equation.')
            diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)}
        if is_vector_wiener:
            for key, val in diffusions.items():
                if val is not None and jnp.ndim(val) == 0:
                    raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple "
                                     f"dimensional diffusion value. But we got a scale value for "
                                     f"variable {key}.")

        # derivative of diffusion parts: g'(Y_n) for every variable
        all_dg = {}
        for i, key in enumerate(self.variables):
            f_dg, vars_, pars_ = parses[i]
            vps = vars_ + pars_
            # first name is the differentiated variable (positional); the
            # rest are passed by keyword when present in this call
            all_dg[key] = f_dg(all_args[vps[0]], **{arg: all_args[arg] for arg in vps[1:] if arg in all_args})

        # integral results
        integrals = []
        for key in self.variables:
            integral = all_args[key] + drifts[key] * dt
            if diffusions[key] is not None:
                shape = jnp.shape(all_args[key])
                if is_vector_wiener:
                    noise_shape = jnp.shape(diffusions[key])
                    self._check_vector_wiener_dim(noise_shape, shape)
                    shape += noise_shape[-1:]
                # \Delta W_n ~ sqrt(h) N(0, 1)
                noise = bm.random.randn(*shape) * jnp.sqrt(dt)
                if is_vector_wiener:
                    integral += jnp.sum(diffusions[key] * noise, axis=-1)
                else:
                    integral += diffusions[key] * noise
                # second-order term: the Ito form subtracts h, Stratonovich does not
                noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2
                correction = diffusions[key] * all_dg[key] / 2 * noise_p2
                if is_vector_wiener:
                    correction = jnp.sum(correction, axis=-1)
                integral += correction
            integrals.append(integral)
        return integrals if len(self.variables) > 1 else integrals[0]
# Register the Milstein class with ``register_sde_integrator`` under the
# name 'milstein' (presumably so it can be selected by that string).
register_sde_integrator('milstein', Milstein)
[docs]
class MilsteinGradFree(SDEIntegrator):
    r"""Derivative-free Milstein method for Ito or Stratonovich integrals.

    The following implementation approximates the first derivative of :math:`g`
    thanks to a Runge-Kutta approach.

    For the Ito integral, the derivative-free Milstein method is given by

    .. math::

        \begin{aligned}
        Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left[\left(\Delta W_{n}\right)^{2}-h\right] \\
        \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\
        \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    For the Stratonovich integral, the derivative-free Milstein method is given by

    .. math::

        \begin{aligned}
        Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left(\Delta W_{n}\right)^{2} \\
        \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\
        \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1)
        \end{aligned}

    Parameters
    ----------
    f : Callable
        Drift function; :meth:`step` calls it with the formatted arguments
        as keywords.
    g : Callable
        Diffusion function; evaluated at the current state and at the
        supporting value :math:`\bar{Y}_n`.
    dt, name, var_type, intg_type, wiener_type, state_delays
        Forwarded unchanged to ``SDEIntegrator.__init__``.
    show_code
        Accepted for signature compatibility; this constructor does not
        forward it to the base class.
    """

    def __init__(
        self,
        f: Callable,
        g: Callable,
        dt: float = None,
        name: str = None,
        show_code=False,
        var_type: str = None,
        intg_type: str = None,
        wiener_type: str = None,
        state_delays: Dict[str, bm.AbstractDelay] = None,
    ):
        super(MilsteinGradFree, self).__init__(
            f=f,
            g=g,
            dt=dt,
            name=name,
            var_type=var_type,
            intg_type=intg_type,
            wiener_type=wiener_type,
            state_delays=state_delays,
        )
        self.set_integral(self.step)

    def step(self, *args, **kwargs):
        """Advance all variables by one derivative-free Milstein step.

        Positional and keyword arguments are mapped onto ``self.arg_names``
        by ``format_args``. An entry named ``DT`` overrides ``self.dt`` for
        this call.

        Returns
        -------
        The updated value when there is a single variable, otherwise a list
        with one updated value per variable (in ``self.variables`` order).

        Raises
        ------
        ValueError
            If drift/diffusion return values have the wrong container type,
            or a zero-dimensional diffusion is given for a vector Wiener
            process.
        """
        # NOTE(review): for a vector Wiener process the supporting value and
        # the correction term below combine arrays without first reducing the
        # trailing noise axis -- confirm the intended shapes with callers.
        step_args = format_args(args, kwargs, self.arg_names)
        dt = step_args.pop(DT, self.dt)
        only_one = len(self.variables) == 1

        # drift values, normalised to a {variable: value} mapping
        drift_out = self.f(**step_args)
        if only_one:
            if not isinstance(drift_out, (bm.ndarray, jnp.ndarray)):
                raise ValueError('Drift values must be a tensor when there '
                                 'is only one variable in the equation.')
            drifts = {self.variables[0]: drift_out}
        elif isinstance(drift_out, (tuple, list)):
            drifts = {name: drift_out[idx] for idx, name in enumerate(self.variables)}
        else:
            raise ValueError('Drift values must be a list/tuple of tensors '
                             'when there are multiple variables in the equation.')

        # diffusion values, normalised the same way
        diffusion_out = self.g(**step_args)
        if only_one:
            if not isinstance(diffusion_out, (bm.ndarray, jnp.ndarray)):
                raise ValueError('Diffusion values must be a tensor when there '
                                 'is only one variable in the equation.')
            diffusions = {self.variables[0]: diffusion_out}
        elif isinstance(diffusion_out, (tuple, list)):
            diffusions = {name: diffusion_out[idx] for idx, name in enumerate(self.variables)}
        else:
            raise ValueError('Diffusion values must be a list/tuple of tensors '
                             'when there are multiple variables in the equation.')

        vector_wiener = self.wiener_type == constants.VECTOR_WIENER
        if vector_wiener:
            for name, value in diffusions.items():
                if value is not None and jnp.ndim(value) == 0:
                    raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple "
                                     f"dimensional diffusion value. But we got a scale value for "
                                     f"variable {name}.")

        # supporting values: \bar{Y}_n = Y_n + f_n h + g_n sqrt(h)
        support_args = dict(step_args)
        for name in self.variables:
            support = step_args[name] + drifts[name] * dt
            if diffusions[name] is not None:
                support += diffusions[name] * jnp.sqrt(dt)
            support_args[name] = support

        # g(\bar{Y}_n)
        bar_out = self.g(**support_args)
        if only_one:
            diffusion_bars = {self.variables[0]: bar_out}
        else:
            diffusion_bars = {name: bar_out[idx] for idx, name in enumerate(self.variables)}

        # integral results
        use_ito = self.intg_type == constants.ITO_SDE
        integrals = []
        for name in self.variables:
            updated = step_args[name] + drifts[name] * dt
            g_now = diffusions[name]
            if g_now is not None:
                shape = jnp.shape(step_args[name])
                if vector_wiener:
                    noise_shape = jnp.shape(g_now)
                    self._check_vector_wiener_dim(noise_shape, shape)
                    shape += noise_shape[-1:]
                # \Delta W_n ~ sqrt(h) N(0, 1)
                noise = bm.random.randn(*shape) * jnp.sqrt(dt)
                if vector_wiener:
                    updated += jnp.sum(g_now * noise, axis=-1)
                else:
                    updated += g_now * noise
                # the Ito form subtracts h from (dW)^2, Stratonovich does not
                if use_ito:
                    noise_p2 = noise ** 2 - dt
                else:
                    noise_p2 = noise ** 2
                # finite-difference stand-in for g * g' / 2
                half_slope = (diffusion_bars[name] - g_now) / 2 / jnp.sqrt(dt)
                if vector_wiener:
                    updated += half_slope * jnp.sum(noise_p2, axis=-1)
                else:
                    updated += half_slope * noise_p2
            integrals.append(updated)

        if len(self.variables) > 1:
            return integrals
        return integrals[0]
# Register the MilsteinGradFree class with ``register_sde_integrator`` under
# two names, 'milstein2' and 'milstein_grad_free' (both map to the same class).
register_sde_integrator('milstein2', MilsteinGradFree)
register_sde_integrator('milstein_grad_free', MilsteinGradFree)
[docs]
class ExponentialEuler(SDEIntegrator):
    r"""First order, explicit exponential Euler method.

    For a SDE equation of the form

    .. math::

        d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \quad y(0)=y_{0}

    its schema is given by [1]_

    .. math::

        y_{n+1} & =e^{\Delta t A}(y_{n}+ g(y_n)\Delta W_{n})+\varphi(\Delta t A) F(y_{n}) \Delta t \\
         &= y_n + \Delta t \varphi(\Delta t A) f(y) + e^{\Delta t A}g(y_n)\Delta W_{n}

    where :math:`\varphi(z)=\frac{e^{z}-1}{z}`.

    Only the Ito integral type is supported; constructing the integrator with
    ``constants.STRA_SDE`` raises ``NotImplementedError``.

    NOTE(review): the code adds the noise increment as
    :math:`g(y_n)\Delta W_{n}` without an explicit :math:`e^{\Delta t A}`
    factor -- confirm whether this simplification is intended.

    Parameters
    ----------
    f : Callable
        Drift function. Its value and its derivative with respect to the
        first argument are obtained through ``bm.vector_grad``.
    g : Callable
        Diffusion function, called with the formatted arguments as keywords.
    dt, show_code, name, var_type, intg_type, wiener_type, state_delays
        Forwarded unchanged to ``SDEIntegrator.__init__``.
    dyn_vars
        Accepted for signature compatibility; not used by this constructor.

    References
    ----------
    .. [1] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic
           differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016).

    See Also
    --------
    Euler, Heun, Milstein
    """

    def __init__(
        self,
        f: Callable,
        g: Callable,
        dt: float = None,
        name: str = None,
        show_code: bool = False,
        var_type: str = None,
        intg_type: str = None,
        wiener_type: str = None,
        dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
        state_delays: Dict[str, bm.AbstractDelay] = None
    ):
        super(ExponentialEuler, self).__init__(
            f=f,
            g=g,
            dt=dt,
            show_code=show_code,
            name=name,
            var_type=var_type,
            intg_type=intg_type,
            wiener_type=wiener_type,
            state_delays=state_delays,
        )

        if self.intg_type == constants.STRA_SDE:
            raise NotImplementedError(
                f'{self.__class__.__name__} does not support integral type of {constants.STRA_SDE}. '
                f'It only supports {constants.ITO_SDE} now. ')

        # build the integrator
        self.integral = self.build()

    def build(self):
        """Create and return the function that performs one integration step.

        The drift is split into per-variable integrators by
        :meth:`_build_integrator`; the returned closure evaluates them in
        order and adds the noise contribution of ``self.g`` to each result.
        """
        parses = self._build_integrator(self.f)
        all_vps = self.variables + self.parameters

        def integral_func(*args, **kwargs):
            # positional arguments are named after variables + parameters
            params_in = bm.Collector()
            for position, value in enumerate(args):
                params_in[all_vps[position]] = value
            params_in.update(kwargs)
            dt = params_in.pop(constants.DT, self.dt)

            # diffusion part (evaluated without the step size among its arguments)
            diffusions = self.g(**params_in)

            # the drift integrators expect the step size under ``constants.DT``
            params_in[constants.DT] = dt
            results = []
            for index, (f_integral, vars_, pars_) in enumerate(parses):
                vps = vars_ + pars_ + [constants.DT]
                # integral of the drift part: first name is positional, the
                # rest are passed by keyword when present in this call
                extra = {arg: params_in[arg] for arg in vps[1:] if arg in params_in}
                r = f_integral(params_in[vps[0]], **extra)

                # pick the diffusion that belongs to this variable
                if isinstance(diffusions, (tuple, list)):
                    diffusion = diffusions[index]
                else:
                    assert len(parses) == 1
                    diffusion = diffusions

                # diffusion part
                shape = jnp.shape(params_in[vps[0]])
                if diffusion is not None:
                    diffusion = bm.as_jax(diffusion)
                    if self.wiener_type == constants.VECTOR_WIENER:
                        noise_shape = jnp.shape(diffusion)
                        self._check_vector_wiener_dim(noise_shape, shape)
                        shape += noise_shape[-1:]
                        diffusion = jnp.sum(diffusion * bm.random.randn(*shape), axis=-1)
                    else:
                        diffusion = diffusion * bm.random.randn(*shape)
                    r += diffusion * jnp.sqrt(params_in[constants.DT])

                # final result
                results.append(r)

            if len(self.variables) > 1:
                return results
            return results[0]

        return integral_func

    def _build_integrator(self, f):
        """Return a list of ``(integral, variables, parameters)`` for ``f``.

        A ``joint_eq.JointEq`` is expanded recursively into one entry per
        sub-equation. A plain function must declare exactly one variable,
        otherwise ``errors.DiffEqError`` is raised.
        """
        if isinstance(f, joint_eq.JointEq):
            collected = []
            for sub_eq in f.eqs:
                collected.extend(self._build_integrator(sub_eq))
            return collected

        var_names, par_names, _ = utils.get_args(f)
        if len(var_names) != 1:
            raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__,
                                                                     vars=str(var_names), eq=str(f)))
        value_and_grad = bm.vector_grad(f, argnums=0, return_value=True)

        # integration function
        def integral(*args, **kwargs):
            assert len(args) > 0
            dt = kwargs.pop('dt', self.dt)
            linear, derivative = value_and_grad(*args, **kwargs)
            # y + dt * exprel(dt * linear) * derivative
            phi = bm.as_jax(bm.exprel(dt * linear))
            return args[0] + dt * phi * derivative

        return [(integral, var_names, par_names), ]
# Register the ExponentialEuler class with ``register_sde_integrator`` under
# four names; every alias below maps to the same class.
register_sde_integrator('exponential_euler', ExponentialEuler)
register_sde_integrator('exp_euler', ExponentialEuler)
register_sde_integrator('exp_euler_auto', ExponentialEuler)
register_sde_integrator('exp_auto', ExponentialEuler)