# -*- coding: utf-8 -*-
from typing import Union, Sequence, Callable
import jax.numpy as jnp
import brainpy.math as bm
from brainpy._src.context import share
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy._src.integrators.fde import CaputoL1Schema
from brainpy._src.integrators.fde import GLShortMemory
from brainpy._src.integrators.joint_eq import JointEq
from brainpy.check import is_float, is_integer, is_initializer
from brainpy.types import Shape, ArrayType
# Public names exported by ``from <module> import *``: the shared base class
# and the two fractional-order neuron models derived from it.
__all__ = ['FractionalNeuron', 'FractionalFHR', 'FractionalIzhikevich']
[docs]
class FractionalNeuron(NeuDyn):
    """Fractional-order neuron model.

    Common parent type of the fractional-order neuron models in this module
    (:class:`FractionalFHR` and :class:`FractionalIzhikevich`). It adds no
    attributes or methods on top of ``NeuDyn``.
    """
[docs]
class FractionalFHR(FractionalNeuron):
    r"""The fractional-order FH-R model [1]_.

    FitzHugh and Rinzel introduced the FH-R model (1976, in an unpublished
    article) as a modification of the classical FHN neuron model. The
    fractional-order FH-R model is described as

    .. math::

        \begin{array}{rcl}
        \frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\
        \frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\
        \frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y),
        \end{array}

    where :math:`v, w` and :math:`y` represent the membrane voltage, the
    recovery variable and the slow modulation of the current, respectively.
    :math:`I` is the magnitude of the constant external stimulus current, and
    :math:`\alpha` is the fractional exponent, which lies in the interval
    :math:`0 < \alpha \le 1`. :math:`a, b, c, d, \delta` and :math:`\mu` are
    the system parameters. The system reduces to the original classical
    (integer-order) system when :math:`\alpha=1`.

    :math:`\mu` is a small parameter that sets the pace of the slow system
    variable :math:`y`. The fast subsystem (:math:`v-w`) is a relaxation
    oscillator in the phase plane, where :math:`\delta` is a small parameter.
    :math:`v` is expressed in mV (millivolt) and time :math:`t` in ms
    (millisecond).

    For a fixed value of :math:`I`, the model exhibits tonic spiking or a
    quiescent state depending on the parameter set. The parameter :math:`a`
    of the 2D FHN model corresponds to the parameter :math:`c` of the FH-R
    model. Decreasing this value produces longer intervals between two bursts,
    while the burst duration stays relatively fixed. Increasing it shortens
    the interburst intervals, and periodic bursting changes to tonic spiking.

    Examples
    --------

    - `(Mondal, et al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model
      <https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html>`_

    Parameters
    ----------
    size: int, sequence of int
      The size of the neuron group.
    alpha: float, sequence of float, tensor
      The fractional order. It is passed unchanged to the ``GLShortMemory``
      integrator.
    num_memory: int
      The length of the short memory kept by the ``GLShortMemory`` integrator.
      Default 1000.
    a, b, c, d, delta, mu: float, ArrayType, Initializer, callable
      The system parameters of the equations above, initialized to the shape
      of the neuron group. Defaults: ``a=0.7``, ``b=0.8``, ``c=-0.775``,
      ``d=1.``, ``delta=0.08``, ``mu=0.0001``.
    Vth: float, ArrayType, Initializer, callable
      The spike threshold (default 1.8). A spike is emitted on the step in
      which :math:`v` crosses ``Vth`` from below.
    V_initializer, w_initializer, y_initializer: ArrayType, Initializer, callable
      The initializers of :math:`v`, :math:`w` and :math:`y`. They are also
      reused by :meth:`reset_state`.
    input_var: bool
      If True, external input given to :meth:`update` is accumulated into an
      ``input`` variable, which :meth:`clear_input` sets back to zero.
    name: str
      The name of this model.
    keep_size: bool
      Forwarded to the ``NeuDyn`` base class.

    References
    ----------
    .. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4
    """

    def __init__(
        self,
        size: Shape,
        alpha: Union[float, Sequence[float]],
        num_memory: int = 1000,
        a: Union[float, ArrayType, Initializer, Callable] = 0.7,
        b: Union[float, ArrayType, Initializer, Callable] = 0.8,
        c: Union[float, ArrayType, Initializer, Callable] = -0.775,
        d: Union[float, ArrayType, Initializer, Callable] = 1.,
        delta: Union[float, ArrayType, Initializer, Callable] = 0.08,
        mu: Union[float, ArrayType, Initializer, Callable] = 0.0001,
        Vth: Union[float, ArrayType, Initializer, Callable] = 1.8,
        V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.5),
        w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
        y_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
        input_var: bool = True,
        name: str = None,
        keep_size: bool = False,
    ):
        super().__init__(size, keep_size=keep_size, name=name)
        # Only the non-batching mode is supported by this model.
        assert self.mode.is_one_of(bm.NonBatchingMode, )

        # fractional order (validated by the integrator constructed below)
        self.alpha = alpha
        is_integer(num_memory, 'num_memory', allow_none=False)

        # parameters, each materialized with the neuron-group shape
        self.a = parameter(a, self.varshape, allow_none=False)
        self.b = parameter(b, self.varshape, allow_none=False)
        self.c = parameter(c, self.varshape, allow_none=False)
        self.d = parameter(d, self.varshape, allow_none=False)
        self.mu = parameter(mu, self.varshape, allow_none=False)
        self.Vth = parameter(Vth, self.varshape, allow_none=False)
        self.delta = parameter(delta, self.varshape, allow_none=False)
        self.input_var = input_var

        # initializers (kept so that reset_state() can re-sample the states)
        is_initializer(V_initializer, 'V_initializer', allow_none=False)
        is_initializer(w_initializer, 'w_initializer', allow_none=False)
        is_initializer(y_initializer, 'y_initializer', allow_none=False)
        self._V_initializer = V_initializer
        self._w_initializer = w_initializer
        self._y_initializer = y_initializer

        # variables
        self.V = bm.Variable(parameter(V_initializer, self.varshape))
        self.w = bm.Variable(parameter(w_initializer, self.varshape))
        self.y = bm.Variable(parameter(y_initializer, self.varshape))
        self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool))
        if self.input_var:
            self.input = bm.Variable(jnp.zeros(self.varshape))

        # integral function; the order of ``inits`` matches the order of the
        # equations in ``self.derivative`` (V, w, y)
        self.integral = GLShortMemory(self.derivative,
                                      alpha=alpha,
                                      num_memory=num_memory,
                                      inits=[self.V, self.w, self.y])

    def reset_state(self, batch_size=None):
        """Re-sample V, w, y from their initializers and clear spikes/input.

        ``batch_size`` is accepted for interface compatibility but unused.
        The fractional integrator is re-initialized with the new states.
        """
        self.V.value = parameter(self._V_initializer, self.varshape)
        self.w.value = parameter(self._w_initializer, self.varshape)
        self.y.value = parameter(self._y_initializer, self.varshape)
        self.spike[:] = False
        if self.input_var:
            self.input[:] = 0
        # integral function reset
        self.integral.reset([self.V, self.w, self.y])

    def dV(self, V, t, w, y, I):
        """Right-hand side :math:`f_1 = v - v^3/3 - w + y + I`."""
        return V - V ** 3 / 3 - w + y + I

    def dw(self, w, t, V):
        r"""Right-hand side :math:`f_2 = \delta (a + v - b w)`."""
        return self.delta * (self.a + V - self.b * w)

    def dy(self, y, t, V):
        r"""Right-hand side :math:`f_3 = \mu (c - v - d y)`."""
        return self.mu * (self.c - V - self.d * y)

    @property
    def derivative(self):
        """The joint system of the three equations, in the order (V, w, y)."""
        return JointEq([self.dV, self.dw, self.dy])

    def update(self, x=None):
        """Advance the model by one step.

        The current time ``t`` and step ``dt`` are read from the shared
        context.

        Parameters
        ----------
        x: optional
          External input current. With ``input_var=True`` it is added to the
          ``input`` variable, and the accumulated ``input`` drives the model.
          Otherwise ``x`` is used directly, or 0 when it is None.

        Returns
        -------
        The boolean spike array. It is True where V crossed ``Vth`` from
        below during this step.
        """
        t = share.load('t')
        dt = share.load('dt')
        if self.input_var:
            if x is not None:
                self.input += x
            x = self.input.value
        else:
            x = 0. if x is None else x
        V, w, y = self.integral(self.V, self.w, self.y, t, I=x, dt=dt)
        # upward threshold crossing: compare the new V with the previous V
        self.spike.value = jnp.logical_and(V >= self.Vth, self.V < self.Vth)
        self.V.value = V
        self.w.value = w
        self.y.value = y
        return self.spike.value

    def clear_input(self):
        """Zero the accumulated input (no-op when ``input_var`` is False)."""
        if self.input_var:
            self.input[:] = 0.
[docs]
class FractionalIzhikevich(FractionalNeuron):
    r"""Fractional-order Izhikevich model [10]_.

    The fractional-order Izhikevich model is given by

    .. math::

        \begin{aligned}
        &\tau \frac{d^{\alpha} v}{d t^{\alpha}}=f v^{2}+g v+h-u+R I \\
        &\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u)
        \end{aligned}

    where :math:`\alpha` is the fractional order (exponent) such that
    :math:`0<\alpha\le1`. It is a commensurate system that reduces to the
    classical Izhikevich model at :math:`\alpha=1`.

    The time :math:`t` is in ms. The system variable :math:`v`, expressed in
    mV, corresponds to the membrane voltage. The variable :math:`u`, also
    expressed in mV, is the recovery variable. It corresponds to the
    activation of the K+ ionic current and the inactivation of the Na+ ionic
    current.

    In the model, :math:`f, g, h` are fixed constants:
    :math:`f=0.04\ \mathrm{mV}^{-1}`, :math:`g=5` and :math:`h=140` mV. They
    are exposed as parameters with these defaults. :math:`a` and :math:`b` are
    dimensionless parameters. The time constant is :math:`\tau=1` ms and the
    resistance is :math:`R=1` Ω. :math:`I`, expressed in mA, is the injected
    (applied) dc stimulus current.

    When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two
    variables are reset as follows:

    .. math::

        \text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l}
        v \leftarrow c \\
        u \leftarrow u+d
        \end{array}\right.

    Here :math:`v_{peak}=30` mV (the ``V_th`` parameter), and :math:`c` and
    :math:`d` are parameters expressed in mV.

    Examples
    --------

    - `(Teka, et al., 2018): Fractional-order Izhikevich neuron model
      <https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html>`_

    Parameters
    ----------
    size: int, sequence of int
      The size of the neuron group.
    alpha: float, sequence of float
      The fractional order, with each value in ``[0, 1]``. It is passed to the
      ``CaputoL1Schema`` integrator.
    num_memory: int
      The memory length of the ``CaputoL1Schema`` integrator.
    a, b, c, d: float, ArrayType, Initializer, callable
      The model parameters (defaults 0.02, 0.20, -65., 8.).
    f, g, h: float, ArrayType, Initializer, callable
      The constants of the voltage equation (defaults 0.04, 5., 140.).
    R: float, ArrayType, Initializer, callable
      The resistance (default 1.).
    tau: float, ArrayType, Initializer, callable
      The time constant (default 1.).
    V_th: float, ArrayType, Initializer, callable
      The spike peak :math:`v_{peak}` (default 30.). Where the new V reaches
      it, V is reset to ``c`` and ``d`` is added to u.
    V_initializer, u_initializer: ArrayType, Initializer, callable
      The initializers of :math:`v` and :math:`u`. They are also reused by
      :meth:`reset_state`.
    keep_size: bool
      Forwarded to the ``NeuDyn`` base class.
    input_var: bool
      If True, external input given to :meth:`update` is accumulated into an
      ``input`` variable, which :meth:`clear_input` sets back to zero.
    name: str
      The name of this model.

    References
    ----------
    .. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and
            bursting patterns of fractional-order Izhikevich model." Communications
            in Nonlinear Science and Numerical Simulation 56 (2018): 161-176.
    """

    def __init__(
        self,
        size: Shape,
        alpha: Union[float, Sequence[float]],
        num_memory: int,
        a: Union[float, ArrayType, Initializer, Callable] = 0.02,
        b: Union[float, ArrayType, Initializer, Callable] = 0.20,
        c: Union[float, ArrayType, Initializer, Callable] = -65.,
        d: Union[float, ArrayType, Initializer, Callable] = 8.,
        f: Union[float, ArrayType, Initializer, Callable] = 0.04,
        g: Union[float, ArrayType, Initializer, Callable] = 5.,
        h: Union[float, ArrayType, Initializer, Callable] = 140.,
        R: Union[float, ArrayType, Initializer, Callable] = 1.,
        tau: Union[float, ArrayType, Initializer, Callable] = 1.,
        V_th: Union[float, ArrayType, Initializer, Callable] = 30.,
        V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.),
        u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.20 * -65.),
        keep_size: bool = False,
        input_var: bool = True,
        name: str = None
    ):
        # initialization
        super().__init__(size=size, keep_size=keep_size, name=name)
        # Only the non-batching mode is supported by this model.
        assert self.mode.is_a(bm.NonBatchingMode)

        # Fractional order. ``is_float`` only accepts scalars, but the
        # annotation allows a sequence (presumably one order per state
        # variable of the integrator -- confirm against CaputoL1Schema), so
        # sequences are validated element by element instead of rejected.
        if isinstance(alpha, (list, tuple)):
            for i, alpha_i in enumerate(alpha):
                is_float(alpha_i, f'alpha[{i}]', min_bound=0., max_bound=1.,
                         allow_none=False, allow_int=True)
        else:
            is_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True)
        self.alpha = alpha
        # validate before any state variable is allocated (fail fast)
        is_integer(num_memory, 'num_memory', allow_none=False)

        # params, each materialized with the neuron-group shape
        self.a = parameter(a, self.varshape, allow_none=False)
        self.b = parameter(b, self.varshape, allow_none=False)
        self.c = parameter(c, self.varshape, allow_none=False)
        self.d = parameter(d, self.varshape, allow_none=False)
        self.f = parameter(f, self.varshape, allow_none=False)
        self.g = parameter(g, self.varshape, allow_none=False)
        self.h = parameter(h, self.varshape, allow_none=False)
        self.tau = parameter(tau, self.varshape, allow_none=False)
        self.R = parameter(R, self.varshape, allow_none=False)
        self.V_th = parameter(V_th, self.varshape, allow_none=False)
        self.input_var = input_var

        # initializers (kept so that reset_state() can re-sample the states)
        is_initializer(V_initializer, 'V_initializer', allow_none=False)
        is_initializer(u_initializer, 'u_initializer', allow_none=False)
        self._V_initializer = V_initializer
        self._u_initializer = u_initializer

        # variables
        self.V = bm.Variable(parameter(V_initializer, self.varshape))
        self.u = bm.Variable(parameter(u_initializer, self.varshape))
        self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool))
        if self.input_var:
            self.input = bm.Variable(jnp.zeros(self.varshape))

        # integral function; the order of ``inits`` matches the order of the
        # equations in ``self.derivative`` (V, u)
        self.integral = CaputoL1Schema(f=self.derivative,
                                       alpha=alpha,
                                       num_memory=num_memory,
                                       inits=[self.V, self.u])

    def reset_state(self, batch_size=None):
        """Re-sample V and u from their initializers and clear spikes/input.

        ``batch_size`` is accepted for interface compatibility but unused.
        The fractional integrator is re-initialized with the new states.
        """
        self.V.value = parameter(self._V_initializer, self.varshape)
        self.u.value = parameter(self._u_initializer, self.varshape)
        self.spike[:] = False
        if self.input_var:
            self.input[:] = 0
        # integral function reset
        self.integral.reset([self.V, self.u])

    def dV(self, V, t, u, I_ext):
        r"""Right-hand side :math:`(f v^2 + g v + h - u + R I) / \tau`."""
        dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext
        return dVdt / self.tau

    def du(self, u, t, V):
        r"""Right-hand side :math:`a (b v - u) / \tau`."""
        dudt = self.a * (self.b * V - u)
        return dudt / self.tau

    @property
    def derivative(self):
        """The joint system of the two equations, in the order (V, u)."""
        return JointEq(self.dV, self.du)

    def update(self, x=None):
        """Advance the model by one step and apply the spike reset.

        The current time ``t`` and step ``dt`` are read from the shared
        context.

        Parameters
        ----------
        x: optional
          External input current. With ``input_var=True`` it is added to the
          ``input`` variable, and the accumulated ``input`` drives the model.
          Otherwise ``x`` is used directly, or 0 when it is None.

        Returns
        -------
        The boolean spike array. It is True where the new V reached ``V_th``.
        """
        t = share.load('t')
        dt = share.load('dt')
        if self.input_var:
            if x is not None:
                self.input += x
            x = self.input.value
        else:
            x = 0. if x is None else x
        V, u = self.integral(self.V, self.u, t=t, I_ext=x, dt=dt)
        spikes = V >= self.V_th
        # reset rule: v <- c, u <- u + d wherever a spike occurred
        self.V.value = jnp.where(spikes, self.c, V)
        self.u.value = jnp.where(spikes, u + self.d, u)
        self.spike.value = spikes
        return spikes

    def clear_input(self):
        """Zero the accumulated input (no-op when ``input_var`` is False)."""
        if self.input_var:
            self.input[:] = 0.