# Source code for brainpy._src.dyn.others.noise

from typing import Union, Callable

import jax.numpy as jnp

import brainpy.math as bm
from brainpy._src.context import share
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import variable_, parameter
from brainpy._src.integrators.sde.generic import sdeint
from brainpy.types import Shape, ArrayType

__all__ = [
  'OUProcess',
]


class OUProcess(NeuDyn):
  r"""The Ornstein–Uhlenbeck process.

  The Ornstein–Uhlenbeck process :math:`x_{t}` integrated by this model is
  defined by the following stochastic differential equation:

  .. math::

     dx_{t} = \frac{\mu - x_{t}}{\tau} \, dt + \sigma \, dW_{t}

  where :math:`\mu` is the mean, :math:`\tau > 0` is the decay time constant,
  :math:`\sigma > 0` is the noise amplitude, and :math:`W_{t}` denotes the
  Wiener process. Equivalently, with the mean-reversion rate
  :math:`\theta = 1 / \tau`,
  :math:`dx_{t} = \theta (\mu - x_{t}) \, dt + \sigma \, dW_{t}`.

  Note that :math:`\sigma` multiplies the Wiener increment directly (it is not
  divided by :math:`\tau`), so the stationary standard deviation of
  :math:`x_{t}` is :math:`\sigma \sqrt{\tau / 2}`, not :math:`\sigma`.

  Parameters
  ----------
  size: int, sequence of int
    The model size.
  mean: float, ArrayType, Callable
    The noise mean value :math:`\mu`. It is also the value that ``x`` is
    (re)initialized to by ``reset_state``.
  sigma: float, ArrayType, Callable
    The noise amplitude :math:`\sigma` (the diffusion coefficient).
  tau: float, ArrayType, Callable
    The decay time constant :math:`\tau`.
  method: str
    The numerical integration method for the stochastic differential
    equation, passed to ``sdeint``. Defaults to ``'exp_euler'``.
  keep_size: bool
    Forwarded unchanged to the ``NeuDyn`` base class.
  mode: Mode
    The computing mode, forwarded to the ``NeuDyn`` base class and used to
    initialize the state variable.
  name: str
    The model name.
  """

  def __init__(
      self,
      size: Shape,
      mean: Union[float, ArrayType, Callable] = 0.,
      sigma: Union[float, ArrayType, Callable] = 1.,
      tau: Union[float, ArrayType, Callable] = 10.,
      method: str = 'exp_euler',
      keep_size: bool = False,
      mode: bm.Mode = None,
      name: str = None,
  ):
    super().__init__(size=size, name=name, keep_size=keep_size, mode=mode)

    # parameters: each is materialized against the model's ``varshape``;
    # ``allow_none=False`` rejects a ``None`` value.
    self.mean = parameter(mean, self.varshape, allow_none=False)
    self.sigma = parameter(sigma, self.varshape, allow_none=False)
    self.tau = parameter(tau, self.varshape, allow_none=False)

    # variables: create ``self.x``, starting at ``mean``
    self.reset_state(self.mode)

    # integral functions: drift ``df`` and diffusion ``dg`` of the SDE
    self.integral = sdeint(f=self.df, g=self.dg, method=method)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """(Re)create the state variable ``x``, filled with ``mean``.

    Parameters
    ----------
    batch_or_mode:
      Forwarded to ``variable_`` together with ``varshape``; presumably a
      batch size or a ``Mode`` that decides whether a batch axis is
      added -- confirm against ``variable_``.
    **kwargs:
      Accepted for interface compatibility and ignored.
    """
    # ``s`` is the final shape chosen by ``variable_``; multiplying by
    # ``self.mean`` broadcasts the mean into that shape.
    self.x = variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_or_mode)

  def df(self, x, t):
    """Drift term: ``(mean - x) / tau``."""
    return (self.mean - x) / self.tau

  def dg(self, x, t):
    """Diffusion term: the constant amplitude ``sigma`` (independent of ``x``)."""
    return self.sigma

  def update(self):
    """Advance ``x`` by one integration step and return its new value.

    The current time ``t`` and step size ``dt`` are read from the shared
    context (``share.load``).
    """
    t = share.load('t')
    dt = share.load('dt')
    self.x.value = self.integral(self.x, t, dt)
    return self.x.value