Source code for brainpy._src.dynold.neurons.fractional_models

# -*- coding: utf-8 -*-

from typing import Union, Sequence, Callable

import jax.numpy as jnp

import brainpy.math as bm
from brainpy._src.context import share
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy._src.integrators.fde import CaputoL1Schema
from brainpy._src.integrators.fde import GLShortMemory
from brainpy._src.integrators.joint_eq import JointEq
from brainpy.check import is_float, is_integer, is_initializer
from brainpy.types import Shape, ArrayType

# Public API of this module: the abstract fractional-order base class and the
# two concrete fractional-order neuron models defined below.
__all__ = [
  'FractionalNeuron',
  'FractionalFHR',
  'FractionalIzhikevich',
]


class FractionalNeuron(NeuDyn):
  """Common base class for fractional-order neuron models.

  It adds no behavior on top of ``NeuDyn``; it serves as the shared parent
  type of the fractional-order models in this module.
  """
class FractionalFHR(FractionalNeuron):
  r"""Fractional-order FitzHugh-Rinzel (FH-R) neuron model [1]_.

  FitzHugh and Rinzel introduced the FH-R model (1976, in an unpublished
  article) as a modification of the classical FHN neuron model. Its
  fractional-order version is described by

  .. math::

     \begin{array}{rcl}
     \frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\
     \frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\
     \frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y),
     \end{array}

  where :math:`v, w` and :math:`y` represent the membrane voltage, the
  recovery variable and the slow modulation of the current, respectively.
  :math:`I` measures the constant magnitude of the external stimulus current,
  and :math:`\alpha` is the fractional exponent, which ranges in the interval
  :math:`(0 < \alpha \le 1)`. :math:`a, b, c, d, \delta` and :math:`\mu` are
  the system parameters. The system reduces to the original classical-order
  system when :math:`\alpha=1`.

  :math:`\mu` is a small parameter that determines the pace of the slow
  system variable :math:`y`. The fast subsystem (:math:`v-w`) presents a
  relaxation oscillator in the phase plane, where :math:`\delta` is a small
  parameter. :math:`v` is expressed in mV (millivolt) scale. Time :math:`t`
  is in ms (millisecond) scale. The model exhibits tonic spiking or a
  quiescent state depending on the parameter sets for a fixed value of
  :math:`I`. The parameter :math:`a` in the 2D FHN model corresponds to the
  parameter :math:`c` of the FH-R neuron model. Decreasing the value of
  :math:`a` causes longer intervals between two burstings, although there
  exists a relatively fixed time of bursting duration. With increasing
  :math:`a`, the interburst intervals become shorter and periodic bursting
  changes to tonic spiking.

  Examples
  --------

  - `(Mondal, et al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model
    <https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html>`__

  Parameters
  ----------
  size: int, sequence of int
    The size of the neuron group.
  alpha: float, tensor
    The fractional order.
  num_memory: int
    The total number of the short memory.
  a, b, c, d, delta, mu: float, ArrayType, Initializer, Callable
    The system parameters appearing in the equations above.
  Vth: float, ArrayType, Initializer, Callable
    The threshold used for spike detection: a spike is emitted when the
    newly integrated voltage is at or above ``Vth`` while the previous
    voltage was below it.
  V_initializer, w_initializer, y_initializer: Initializer, Callable, ArrayType
    The initializers of the state variables ``V``, ``w`` and ``y``.
  input_var: bool
    Whether to create the ``input`` variable that accumulates external
    inputs across calls of ``update()``.
  name: str
    Forwarded to the parent class constructor.
  keep_size: bool
    Forwarded to the parent class constructor.

  References
  ----------
  .. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities
         of a fractional-order FitzHugh-Rinzel bursting neuron model and its
         coupled dynamics. *Sci Rep* **9,** 15721 (2019).
         https://doi.org/10.1038/s41598-019-52061-4
  """

  def __init__(
      self,
      size: Shape,
      alpha: Union[float, Sequence[float]],
      num_memory: int = 1000,
      a: Union[float, ArrayType, Initializer, Callable] = 0.7,
      b: Union[float, ArrayType, Initializer, Callable] = 0.8,
      c: Union[float, ArrayType, Initializer, Callable] = -0.775,
      d: Union[float, ArrayType, Initializer, Callable] = 1.,
      delta: Union[float, ArrayType, Initializer, Callable] = 0.08,
      mu: Union[float, ArrayType, Initializer, Callable] = 0.0001,
      Vth: Union[float, ArrayType, Initializer, Callable] = 1.8,
      V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.5),
      w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
      y_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
      input_var: bool = True,
      name: str = None,
      keep_size: bool = False,
  ):
    super().__init__(size, keep_size=keep_size, name=name)
    # This model only runs in a non-batching computing mode.
    assert self.mode.is_one_of(bm.NonBatchingMode, )

    # Fractional order (stored as given) and memory-length validation.
    self.alpha = alpha
    is_integer(num_memory, 'num_memory', allow_none=False)

    def _as_param(value):
      # Every model parameter is mandatory and resolved against `varshape`.
      return parameter(value, self.varshape, allow_none=False)

    # Model parameters.
    self.a = _as_param(a)
    self.b = _as_param(b)
    self.c = _as_param(c)
    self.d = _as_param(d)
    self.mu = _as_param(mu)
    self.Vth = _as_param(Vth)
    self.delta = _as_param(delta)
    self.input_var = input_var

    # Validate the initializers and keep them so `reset_state()` can reuse them.
    is_initializer(V_initializer, 'V_initializer', allow_none=False)
    is_initializer(w_initializer, 'w_initializer', allow_none=False)
    is_initializer(y_initializer, 'y_initializer', allow_none=False)
    self._V_initializer = V_initializer
    self._w_initializer = w_initializer
    self._y_initializer = y_initializer

    # State variables.
    self.V = bm.Variable(parameter(V_initializer, self.varshape))
    self.w = bm.Variable(parameter(w_initializer, self.varshape))
    self.y = bm.Variable(parameter(y_initializer, self.varshape))
    self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool))
    if self.input_var:
      self.input = bm.Variable(jnp.zeros(self.varshape))

    # Fractional integrator over the joint (V, w, y) system.
    self.integral = GLShortMemory(
      self.derivative,
      alpha=alpha,
      num_memory=num_memory,
      inits=[self.V, self.w, self.y],
    )

  def reset_state(self, batch_size=None):
    """Re-initialize the state variables, spikes, input and integrator.

    ``batch_size`` is accepted but not used by this method.
    """
    self.V.value = parameter(self._V_initializer, self.varshape)
    self.w.value = parameter(self._w_initializer, self.varshape)
    self.y.value = parameter(self._y_initializer, self.varshape)
    self.spike[:] = False
    if self.input_var:
      self.input[:] = 0
    # The integrator is reset from the freshly re-initialized variables.
    self.integral.reset([self.V, self.w, self.y])

  def dV(self, V, t, w, y, I):
    """Right-hand side of the membrane voltage equation."""
    cubic = V ** 3 / 3
    return V - cubic - w + y + I

  def dw(self, w, t, V):
    """Right-hand side of the recovery variable equation."""
    drift = self.a + V - self.b * w
    return self.delta * drift

  def dy(self, y, t, V):
    """Right-hand side of the slow modulation variable equation."""
    drift = self.c - V - self.d * y
    return self.mu * drift

  @property
  def derivative(self):
    """The three right-hand sides combined into a single ``JointEq``."""
    return JointEq([self.dV, self.dw, self.dy])

  def update(self, x=None):
    """Advance the model by one step and return the spike array.

    When ``input_var`` is true, a non-``None`` ``x`` is added to
    ``self.input`` and the accumulated input drives the model. Otherwise
    ``x`` drives it directly, with ``None`` treated as ``0.``.
    """
    t_now = share.load('t')
    step = share.load('dt')

    # Resolve the driving current for this step.
    if self.input_var:
      if x is not None:
        self.input += x
      current = self.input.value
    elif x is None:
      current = 0.
    else:
      current = x

    V, w, y = self.integral(self.V, self.w, self.y, t_now, I=current, dt=step)

    # A spike is an upward crossing of the threshold between the old and
    # the new voltage; it must be computed before `self.V` is overwritten.
    now_above = V >= self.Vth
    was_below = self.V < self.Vth
    self.spike.value = jnp.logical_and(now_above, was_below)

    self.V.value = V
    self.w.value = w
    self.y.value = y
    return self.spike.value

  def clear_input(self):
    """Zero the accumulated input; a no-op when ``input_var`` is false."""
    if self.input_var:
      self.input[:] = 0.
class FractionalIzhikevich(FractionalNeuron):
  r"""Fractional-order Izhikevich neuron model [10]_.

  The fractional-order Izhikevich model is given by

  .. math::

     \begin{aligned}
     &\tau \frac{d^{\alpha} v}{d t^{\alpha}}=\mathrm{f} v^{2}+g v+h-u+R I \\
     &\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u)
     \end{aligned}

  where :math:`\alpha` is the fractional order (exponent) such that
  :math:`0<\alpha\le1`. It is a commensurate system that reduces to the
  classical Izhikevich model at :math:`\alpha=1`.

  The time :math:`t` is in ms, and the system variable :math:`v`, expressed
  in mV, corresponds to the membrane voltage. Moreover, :math:`u`, expressed
  in mV, is the recovery variable that corresponds to the activation of the
  K+ ionic current and the inactivation of the Na+ ionic current.

  The parameters :math:`f, g, h` are fixed constants (should not be changed)
  such that :math:`f=0.04\ \mathrm{mV}^{-1}`, :math:`g=5, h=140` mV; and
  :math:`a` and :math:`b` are dimensionless parameters. The time constant
  :math:`\tau=1` ms; the resistance :math:`R=1` Ω; and :math:`I`, expressed
  in mA, measures the injected (applied) dc stimulus current to the system.

  When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two
  variables are reset as follows:

  .. math::

     \text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l}
     v \leftarrow c \\
     u \leftarrow u+d
     \end{array}\right.

  We used :math:`v_{peak}=30` mV, and :math:`c` and :math:`d` are parameters
  expressed in mV. When the spike reaches its peak value, the membrane
  voltage :math:`v` and the recovery variable :math:`u` are reset according
  to the above condition.

  Examples
  --------

  - `(Teka, et al., 2018): Fractional-order Izhikevich neuron model
    <https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html>`__

  Parameters
  ----------
  size: int, sequence of int
    The size of the neuron group.
  alpha: float
    The fractional order; checked with ``is_float`` against the bounds
    ``0.`` and ``1.`` (integers are allowed).
  num_memory: int
    The memory length handed to the ``CaputoL1Schema`` integrator.
  a, b, c, d, f, g, h, R, tau: float, ArrayType, Initializer, Callable
    The system parameters appearing in the equations above.
  V_th: float, ArrayType, Initializer, Callable
    The spike peak :math:`v_{peak}` at which the variables are reset.
  V_initializer, u_initializer: Initializer, Callable, ArrayType
    The initializers of the state variables ``V`` and ``u``.
  keep_size: bool
    Forwarded to the parent class constructor.
  input_var: bool
    Whether to create the ``input`` variable that accumulates external
    inputs across calls of ``update()``.
  name: str
    Forwarded to the parent class constructor.

  References
  ----------
  .. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking
          and bursting patterns of fractional-order Izhikevich model."
          Communications in Nonlinear Science and Numerical Simulation
          56 (2018): 161-176.
  """

  def __init__(
      self,
      size: Shape,
      alpha: Union[float, Sequence[float]],
      num_memory: int,
      a: Union[float, ArrayType, Initializer, Callable] = 0.02,
      b: Union[float, ArrayType, Initializer, Callable] = 0.20,
      c: Union[float, ArrayType, Initializer, Callable] = -65.,
      d: Union[float, ArrayType, Initializer, Callable] = 8.,
      f: Union[float, ArrayType, Initializer, Callable] = 0.04,
      g: Union[float, ArrayType, Initializer, Callable] = 5.,
      h: Union[float, ArrayType, Initializer, Callable] = 140.,
      R: Union[float, ArrayType, Initializer, Callable] = 1.,
      tau: Union[float, ArrayType, Initializer, Callable] = 1.,
      V_th: Union[float, ArrayType, Initializer, Callable] = 30.,
      V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.),
      u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.20 * -65.),
      keep_size: bool = False,
      input_var: bool = True,
      name: str = None
  ):
    super().__init__(size=size, keep_size=keep_size, name=name)
    # This model only runs in a non-batching computing mode.
    assert self.mode.is_a(bm.NonBatchingMode)

    # Fractional order: stored first, then range-checked.
    self.alpha = alpha
    is_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True)

    def _as_param(value):
      # Every model parameter is mandatory and resolved against `varshape`.
      return parameter(value, self.varshape, allow_none=False)

    # Model parameters.
    self.a = _as_param(a)
    self.b = _as_param(b)
    self.c = _as_param(c)
    self.d = _as_param(d)
    self.f = _as_param(f)
    self.g = _as_param(g)
    self.h = _as_param(h)
    self.tau = _as_param(tau)
    self.R = _as_param(R)
    self.V_th = _as_param(V_th)
    self.input_var = input_var

    # Validate the initializers and keep them so `reset_state()` can reuse them.
    is_initializer(V_initializer, 'V_initializer', allow_none=False)
    is_initializer(u_initializer, 'u_initializer', allow_none=False)
    self._V_initializer = V_initializer
    self._u_initializer = u_initializer

    # State variables.
    self.V = bm.Variable(parameter(V_initializer, self.varshape))
    self.u = bm.Variable(parameter(u_initializer, self.varshape))
    self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool))
    if self.input_var:
      self.input = bm.Variable(jnp.zeros(self.varshape))

    # Fractional integrator over the joint (V, u) system.
    is_integer(num_memory, 'num_memory', allow_none=False)
    self.integral = CaputoL1Schema(
      f=self.derivative,
      alpha=alpha,
      num_memory=num_memory,
      inits=[self.V, self.u],
    )

  def reset_state(self, batch_size=None):
    """Re-initialize the state variables, spikes, input and integrator.

    ``batch_size`` is accepted but not used by this method.
    """
    self.V.value = parameter(self._V_initializer, self.varshape)
    self.u.value = parameter(self._u_initializer, self.varshape)
    self.spike[:] = False
    if self.input_var:
      self.input[:] = 0
    # The integrator is reset from the freshly re-initialized variables.
    self.integral.reset([self.V, self.u])

  def dV(self, V, t, u, I_ext):
    """Right-hand side of the membrane voltage equation, divided by ``tau``."""
    quadratic = self.f * V * V
    linear = self.g * V
    drive = self.R * I_ext
    return (quadratic + linear + self.h - u + drive) / self.tau

  def du(self, u, t, V):
    """Right-hand side of the recovery variable equation, divided by ``tau``."""
    recovery = self.a * (self.b * V - u)
    return recovery / self.tau

  @property
  def derivative(self):
    """The two right-hand sides combined into a single ``JointEq``."""
    return JointEq(self.dV, self.du)

  def update(self, x=None):
    """Advance the model by one step and return the spike array.

    When ``input_var`` is true, a non-``None`` ``x`` is added to
    ``self.input`` and the accumulated input drives the model. Otherwise
    ``x`` drives it directly, with ``None`` treated as ``0.``.
    """
    # Resolve the driving current for this step.
    if self.input_var:
      if x is not None:
        self.input += x
      current = self.input.value
    elif x is None:
      current = 0.
    else:
      current = x

    t_now = share['t']
    step = share['dt']
    V, u = self.integral(self.V, self.u, t=t_now, I_ext=current, dt=step)

    # Units at or above the peak fire: V is reset to `c` and u is bumped by `d`.
    fired = V >= self.V_th
    self.V.value = jnp.where(fired, self.c, V)
    self.u.value = jnp.where(fired, u + self.d, u)
    self.spike.value = fired
    return fired

  def clear_input(self):
    """Zero the accumulated input; a no-op when ``input_var`` is false."""
    if self.input_var:
      self.input[:] = 0.