# Source code for brainpy.dyn.neurons.noise_groups

# -*- coding: utf-8 -*-

from typing import Union, Callable

from brainpy import math as bm, initialize as init
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import Initializer
from brainpy.integrators.sde import sdeint
from brainpy.modes import Mode, normal
from brainpy.types import Array, Shape

# Names exported by ``from ... import *``: only the OU noise generator.
__all__ = ["OUProcess"]

class OUProcess(NeuGroup):
  r"""The Ornstein–Uhlenbeck process.

  The Ornstein–Uhlenbeck process :math:`x_{t}` implemented here is defined by
  the following stochastic differential equation:

  .. math::

     dx_{t} = \frac{\mu - x_{t}}{\tau} \, dt + \sigma \, dW_{t}

  where :math:`\mu` is the mean, :math:`\tau > 0` is the decay time constant,
  :math:`\sigma > 0` is the noise amplitude, and :math:`W_{t}` denotes the
  Wiener process. The state :math:`x` starts at (and is reset to) :math:`\mu`.

  Parameters
  ----------
  size: int, sequence of int
    The model size.
  mean: Parameter
    The noise mean value :math:`\mu`; also used as the initial value of ``x``.
  sigma: Parameter
    The noise amplitude :math:`\sigma` (additive, independent of ``x``).
  tau: Parameter
    The decay time constant :math:`\tau`.
  method: str
    The numerical integration method for the stochastic differential equation.
  keep_size: bool
    Forwarded to ``NeuGroup.__init__``.
  mode: Mode
    The computing mode; forwarded to ``NeuGroup.__init__`` and used when
    creating the state variable ``x``.
  name: str
    The model name.
  """

  def __init__(
      self,
      size: Shape,
      mean: Union[float, Array, Initializer, Callable] = 0.,
      sigma: Union[float, Array, Initializer, Callable] = 1.,
      tau: Union[float, Array, Initializer, Callable] = 10.,
      method: str = 'exp_euler',
      keep_size: bool = False,
      mode: Mode = normal,
      name: str = None,
  ):
    super().__init__(size=size, name=name, keep_size=keep_size, mode=mode)

    # parameters, materialized over the group's variable shape
    self.mean = init.parameter(mean, self.varshape, allow_none=False)
    self.sigma = init.parameter(sigma, self.varshape, allow_none=False)
    self.tau = init.parameter(tau, self.varshape, allow_none=False)

    # state variable; starts at the mean (same initializer as reset_state)
    self.x = init.variable_(self._init_x, self.varshape, mode)

    # integral function built from the drift (df) and diffusion (dg) terms
    self.integral = sdeint(f=self.df, g=self.dg, method=method)

  def _init_x(self, shape):
    """Return an array of ``shape`` filled with the current ``self.mean``."""
    return bm.ones(shape) * self.mean

  def reset_state(self, batch_size=None):
    """Reset ``x`` to the mean; ``batch_size`` is forwarded to ``init.variable_``."""
    self.x.value = init.variable_(self._init_x, self.varshape, batch_size)

  def df(self, x, t):
    """Drift term: relaxation of ``x`` toward ``mean`` with time constant ``tau``."""
    return (self.mean - x) / self.tau

  def dg(self, x, t):
    """Diffusion term: constant amplitude ``sigma`` (does not depend on ``x``)."""
    return self.sigma

  def update(self, tdi):
    """Advance ``x`` by one integration step.

    ``tdi`` must provide the current time under ``'t'`` and the step size
    under ``'dt'``.
    """
    self.x.value = self.integral(self.x, tdi['t'], tdi['dt'])