# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Sequence, Callable, Optional
from brainpy import math as bm
from brainpy.context import share
from brainpy.dyn import _docs
from brainpy.dyn.base import SynDyn
from brainpy.initialize import parameter
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode.generic import odeint
from brainpy.mixin import AlignPost, ReturnInfo
from brainpy.types import ArrayType
# Public synapse models exported by this module (star-import surface).
__all__ = [
  'Expon', 'DualExpon', 'DualExponV2',
  'Alpha', 'NMDA',
  'STD', 'STP',
]
class Expon(SynDyn, AlignPost):
  r"""Exponential decay synapse model.

  %s

  The model works with the ``brainpy.dyn.ProjAlignPreMg2`` interface, for example:

  .. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm
    import matplotlib.pyplot as plt

    class ExponSparseCOBA(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
          pre=pre,
          delay=delay,
          syn=bp.dyn.Expon.desc(pre.num, tau=tau),
          comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          out=bp.dyn.COBA(E=E),
          post=post,
        )

    class SimpleNet(bp.DynSysGroup):
      def __init__(self, syn_cls, E=0.):
        super().__init__()
        self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.))
        self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Constant(-60.))
        self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau=5., E=E)

      def update(self):
        self.pre()
        self.syn()
        self.post()

        # monitor the following variables
        conductance = self.syn.proj.refs['syn'].g
        current = self.post.sum_inputs(self.post.V)
        return conductance, current, self.post.V

  It can also be used with the ``ProjAlignPostMg2`` interface:

  .. code-block:: python

    class ExponSparseCOBAPost(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPostMg2(
          pre=pre,
          delay=delay,
          comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          syn=bp.dyn.Expon.desc(post.num, tau=tau),
          out=bp.dyn.COBA.desc(E=E),
          post=post,
        )

  Parameters
  ----------
  tau : float, ArrayType, Callable
    The decay time constant [ms].
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      tau: Union[float, ArrayType, Callable] = 8.0,
  ):
    super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

    # Decay time constant, initialized against the population shape via ``init_param``.
    self.tau = self.init_param(tau)

    # Integrator of ``derivative`` (state ``g``, time ``t``).
    self.integral = odeint(self.derivative, method=method)
    self._current = None

    self.reset_state(self.mode)

  def derivative(self, g, t):
    """Right-hand side of the decay ODE: ``dg/dt = -g / tau``."""
    return -g / self.tau

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the conductance variable ``g``, filled with zeros."""
    self.g = self.init_variable(bm.zeros, batch_or_mode)

  def update(self, x=None):
    """Decay ``g`` over one step, then add the optional input ``x``.

    Returns the updated conductance value.
    """
    t, dt = share['t'], share['dt']
    self.g.value = self.integral(self.g.value, t, dt)
    if x is not None:
      self.add_current(x)
    return self.g.value

  def add_current(self, x):
    """Increment the conductance ``g`` by ``x``."""
    self.g.value = self.g.value + x

  def return_info(self):
    """Return the conductance variable ``g``."""
    return self.g


Expon.__doc__ = Expon.__doc__ % (_docs.exp_syn_doc, _docs.pneu_doc,)
def _format_dual_exp_A(self, A):
  r"""Return the peak-normalization factor ``A`` of a dual-exponential synapse.

  ``self`` is the owning synapse; its ``tau_rise``, ``tau_decay``, ``varshape``
  and ``sharding`` attributes are read.

  A user-supplied ``A`` is initialized with ``parameter`` and returned as is.
  When ``A`` is ``None`` it is chosen so that the waveform
  ``A * (exp(-t / tau_decay) - exp(-t / tau_rise))`` peaks at 1:

  .. math::

    A = \frac{\tau_{decay}}{\tau_{decay} - \tau_{rise}}
        \left(\frac{\tau_{rise}}{\tau_{decay}}\right)^{\tau_{rise}/(\tau_{rise}-\tau_{decay})}

  Note that an explicit ``A`` cannot rescue the equal-time-constant case for a
  consumer that returns ``A * (g_decay - g_rise)``: both gates then follow the
  same trajectory, so their difference is identically zero.

  Raises
  ------
  ValueError
    If ``A`` is ``None`` and any element satisfies ``tau_rise == tau_decay``
    (the normalizer above divides by zero there).
  """
  A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding)
  if A is not None:
    return A

  # Fail loudly instead of emitting a ZeroDivisionError / silent NaN. The advice
  # deliberately does not suggest an explicit ``A``: with equal time constants the
  # rise and decay gates coincide, so no value of ``A`` yields a non-zero waveform.
  if bm.any(self.tau_rise == self.tau_decay):
    raise ValueError(
      'The dual-exponential peak normalizer "A" is undefined when '
      '"tau_rise == tau_decay" (the rise and decay gates collapse to the '
      'same trajectory, so their difference is identically zero). Use '
      'brainpy.dyn.Alpha for a single-time-constant alpha synapse, or '
      'brainpy.dyn.DualExpon, which supports the equal-time-constant limit.'
    )
  return (self.tau_decay / (self.tau_decay - self.tau_rise) *
          bm.float_power(self.tau_rise / self.tau_decay, self.tau_rise / (self.tau_rise - self.tau_decay)))
def _dual_exp_a(self, A):
  r"""Return the per-input conductance jump ``a`` used by :class:`DualExpon`.

  ``self`` is the owning synapse; its ``tau_rise``, ``tau_decay``, ``varshape``
  and ``sharding`` attributes are read.

  * Explicit ``A``: ``a = (tau_decay - tau_rise) / tau_rise / tau_decay * A``.
  * ``A is None`` (auto peak normalization): the closed form reduces to

    .. math::

      a = \frac{1}{\tau_{rise}}
          \left(\frac{\tau_{rise}}{\tau_{decay}}\right)^{\tau_{rise}/(\tau_{rise}-\tau_{decay})}

    Element-wise, wherever :math:`\tau_{rise} = \tau_{decay}` the limiting value
    :math:`a = e/\tau` (normalized alpha function) is used instead, so
    heterogeneous time constants and the equal-tau case are both finite.
  """
  A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding)
  if A is None:
    same_tau = self.tau_rise == self.tau_decay
    # Replace the zero denominator by 1 at equal-tau entries so the branch that
    # ``where`` discards stays finite (no 0/0 in values or gradients).
    safe_gap = bm.where(same_tau, bm.ones_like(self.tau_decay), self.tau_rise - self.tau_decay)
    general = bm.float_power(self.tau_rise / self.tau_decay, self.tau_rise / safe_gap) / self.tau_rise
    limit = bm.exp(bm.ones_like(self.tau_rise)) / self.tau_rise
    return bm.where(same_tau, limit, general)
  return (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A
class DualExpon(SynDyn):
  r"""Dual exponential synapse model.

  %s

  This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:

  .. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm
    import matplotlib.pyplot as plt

    class DualExpSparseCOBA(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
          pre=pre,
          delay=delay,
          syn=bp.dyn.DualExpon.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise),
          comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          out=bp.dyn.COBA(E=E),
          post=post,
        )

    class SimpleNet(bp.DynSysGroup):
      def __init__(self, syn_cls, E=0.):
        super().__init__()
        self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.))
        self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Constant(-60.))
        self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1.,
                           tau_decay=5., tau_rise=1., E=E)

      def update(self):
        self.pre()
        self.syn()
        self.post()

        # monitor the following variables
        conductance = self.syn.proj.refs['syn'].g
        current = self.post.sum_inputs(self.post.V)
        return conductance, current, self.post.V

    indices = np.arange(1000)  # 100 ms, dt= 0.1 ms
    net = SimpleNet(DualExpSparseCOBA, E=0.)
    conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True)
    ts = indices * bm.get_dt()
    fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4)
    fig.add_subplot(gs[0, 0])
    plt.plot(ts, conductances)
    plt.title('Syn conductance')
    fig.add_subplot(gs[0, 1])
    plt.plot(ts, currents)
    plt.title('Syn current')
    fig.add_subplot(gs[0, 2])
    plt.plot(ts, potentials)
    plt.title('Post V')
    plt.show()

  See Also
  --------
  DualExponV2

  .. note::

    This implementation can only be used in ``AlignPre`` projections.
    On the contrary, for ``AlignPost`` projections, please use ``DualExponV2``.

  Parameters
  ----------
  %s
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      tau_decay: Union[float, ArrayType, Callable] = 10.0,
      tau_rise: Union[float, ArrayType, Callable] = 1.,
      A: Optional[Union[float, ArrayType, Callable]] = None,
  ):
    super().__init__(name=name,
                     mode=mode,
                     size=size,
                     keep_size=keep_size,
                     sharding=sharding)

    # parameters
    self.tau_rise = self.init_param(tau_rise)
    self.tau_decay = self.init_param(tau_decay)

    # Conductance-jump coefficient ``a``, computed directly by ``_dual_exp_a``. This
    # avoids the ``(tau_decay - tau_rise)`` cancellation that divides by zero for
    # equal time constants, and uses the alpha-function limit ``a = e / tau``
    # element-wise where ``tau_rise == tau_decay``.
    self.a = _dual_exp_a(self, A)

    # integrator: ``g`` (dg) is driven by the auxiliary gate ``h`` (dh)
    self.integral = odeint(JointEq(self.dg, self.dh), method=method)

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the auxiliary gate ``h`` and the conductance ``g``, both zero."""
    self.h = self.init_variable(bm.zeros, batch_or_mode)
    self.g = self.init_variable(bm.zeros, batch_or_mode)

  def dh(self, h, t):
    """Auxiliary gate dynamics: ``dh/dt = -h / tau_rise``."""
    return -h / self.tau_rise

  def dg(self, g, t, h):
    """Conductance dynamics: ``dg/dt = -g / tau_decay + h``."""
    return -g / self.tau_decay + h

  def update(self, x):
    """Integrate ``(g, h)`` over one step, then add ``a * x`` to ``h``.

    ``x`` is the pre-synaptic input (e.g. spikes). Returns the updated ``g``.
    """
    self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
    self.h.value = self.h.value + self.a * x
    return self.g.value

  def return_info(self):
    """Return the conductance variable ``g``."""
    return self.g


DualExpon.__doc__ = DualExpon.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args)
class DualExponV2(SynDyn, AlignPost):
  r"""Dual exponential synapse model.

  %s

  .. note::

    Different from ``DualExpon``, this model can be used in both modes of ``AlignPre`` and ``AlignPost`` projections.

  This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example:

  .. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm
    import matplotlib.pyplot as plt

    class DualExponV2SparseCOBA(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
          pre=pre,
          delay=delay,
          syn=bp.dyn.DualExponV2.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise),
          comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          out=bp.dyn.COBA(E=E),
          post=post,
        )

    class SimpleNet(bp.DynSysGroup):
      def __init__(self, syn_cls, E=0.):
        super().__init__()
        self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.))
        self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Constant(-60.))
        self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau_decay=5., tau_rise=1., E=E)

      def update(self):
        self.pre()
        self.syn()
        self.post()

        # monitor the following variables; the conductance of this model is
        # ``a * (g_decay - g_rise)``, the same quantity ``update`` returns
        syn = self.syn.proj.refs['syn']
        conductance = syn.a * (syn.g_decay.value - syn.g_rise.value)
        current = self.post.sum_inputs(self.post.V)
        return conductance, current, self.post.V

    indices = np.arange(1000)  # 100 ms, dt= 0.1 ms
    net = SimpleNet(DualExponV2SparseCOBA, E=0.)
    conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True)
    ts = indices * bm.get_dt()
    fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4)
    fig.add_subplot(gs[0, 0])
    plt.plot(ts, conductances)
    plt.title('Syn conductance')
    fig.add_subplot(gs[0, 1])
    plt.plot(ts, currents)
    plt.title('Syn current')
    fig.add_subplot(gs[0, 2])
    plt.plot(ts, potentials)
    plt.title('Post V')
    plt.show()

  Moreover, it can also be used with interface ``ProjAlignPostMg2``:

  .. code-block:: python

    class DualExponV2SparseCOBAPost(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPostMg2(
          pre=pre,
          delay=delay,
          comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          syn=bp.dyn.DualExponV2.desc(post.num, tau_decay=tau_decay, tau_rise=tau_rise),
          out=bp.dyn.COBA.desc(E=E),
          post=post,
        )

  See Also
  --------
  DualExpon

  Parameters
  ----------
  %s
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      tau_decay: Union[float, ArrayType, Callable] = 10.0,
      tau_rise: Union[float, ArrayType, Callable] = 1.,
      A: Optional[Union[float, ArrayType, Callable]] = None,
  ):
    super().__init__(name=name,
                     mode=mode,
                     size=size,
                     keep_size=keep_size,
                     sharding=sharding)

    # parameters
    self.tau_rise = self.init_param(tau_rise)
    self.tau_decay = self.init_param(tau_decay)
    # Peak normalizer multiplying ``g_decay - g_rise``; raises ValueError when it
    # must be auto-computed and ``tau_rise == tau_decay``.
    self.a = _format_dual_exp_A(self, A)

    # A single first-order decay integrator shared by both gates; the time
    # constant is passed in as the parameter ``tau``.
    self.integral = odeint(lambda g, t, tau: -g / tau, method=method)

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the rise and decay gates ``g_rise`` and ``g_decay``, both zero."""
    self.g_rise = self.init_variable(bm.zeros, batch_or_mode)
    self.g_decay = self.init_variable(bm.zeros, batch_or_mode)

  def update(self, x=None):
    """Decay both gates over one step, add the optional input ``x``, and
    return the conductance ``a * (g_decay - g_rise)``."""
    self.g_rise.value = self.integral(self.g_rise.value, share['t'], self.tau_rise, share['dt'])
    self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt'])
    if x is not None:
      self.add_current(x)
    return self.a * (self.g_decay - self.g_rise)

  def add_current(self, inp):
    """Add ``inp`` to both gates (same jump for rise and decay)."""
    self.g_rise += inp
    self.g_decay += inp

  def return_info(self):
    """Describe the output as ``a * (g_decay - g_rise)`` with this model's shape,
    sharding and mode."""
    return ReturnInfo(self.varshape, self.sharding, self.mode,
                      lambda shape: self.a * (self.g_decay - self.g_rise))


DualExponV2.__doc__ = DualExponV2.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args,)
class Alpha(SynDyn):
  r"""Alpha synapse model.

  %s

  The model works with the ``brainpy.dyn.ProjAlignPreMg2`` interface, for example:

  .. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm
    import matplotlib.pyplot as plt

    class AlphaSparseCOBA(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau_decay, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
          pre=pre,
          delay=delay,
          syn=bp.dyn.Alpha.desc(pre.num, tau_decay=tau_decay),
          comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          out=bp.dyn.COBA(E=E),
          post=post,
        )

    class SimpleNet(bp.DynSysGroup):
      def __init__(self, syn_cls, E=0.):
        super().__init__()
        self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.))
        self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Constant(-60.))
        self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1.,
                           tau_decay=5., E=E)

      def update(self):
        self.pre()
        self.syn()
        self.post()

        # monitor the following variables
        conductance = self.syn.proj.refs['syn'].g
        current = self.post.sum_inputs(self.post.V)
        return conductance, current, self.post.V

    indices = np.arange(1000)  # 100 ms, dt= 0.1 ms
    net = SimpleNet(AlphaSparseCOBA, E=0.)
    conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True)
    ts = indices * bm.get_dt()
    fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4)
    fig.add_subplot(gs[0, 0])
    plt.plot(ts, conductances)
    plt.title('Syn conductance')
    fig.add_subplot(gs[0, 1])
    plt.plot(ts, currents)
    plt.title('Syn current')
    fig.add_subplot(gs[0, 2])
    plt.plot(ts, potentials)
    plt.title('Post V')
    plt.show()

  Parameters
  ----------
  %s
  tau_decay : float, ArrayType, Callable
    The time constant [ms] of the synaptic decay phase.
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      tau_decay: Union[float, ArrayType, Callable] = 10.0,
  ):
    super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

    # single time constant shared by both state variables
    self.tau_decay = self.init_param(tau_decay)

    # joint integrator: ``g`` (dg) is driven by the auxiliary gate ``h`` (dh)
    self.integral = odeint(JointEq(self.dg, self.dh), method=method)

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the auxiliary gate ``h`` and the conductance ``g``, both zero."""
    self.h = self.init_variable(bm.zeros, batch_or_mode)
    self.g = self.init_variable(bm.zeros, batch_or_mode)

  def dh(self, h, t):
    """Auxiliary gate dynamics: ``dh/dt = -h / tau_decay``."""
    return -h / self.tau_decay

  def dg(self, g, t, h):
    """Conductance dynamics: ``dg/dt = (h - g) / tau_decay`` (written as two terms)."""
    return -g / self.tau_decay + h / self.tau_decay

  def update(self, x):
    """Integrate ``(g, h)`` over one step, then add the input ``x`` to ``h``.

    Returns the updated conductance ``g``.
    """
    t, dt = share['t'], share['dt']
    self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt)
    self.h.value = self.h.value + x
    return self.g.value

  def return_info(self):
    """Return the conductance variable ``g``."""
    return self.g


Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc, _docs.pneu_doc,)
class NMDA(SynDyn):
  r"""NMDA synapse model.

  **Model Descriptions**

  The NMDA receptor is a glutamate receptor and ion channel found in neurons [4]_.
  It is one of the three types of ionotropic glutamate receptors; the other two
  are the AMPA and kainate receptors.

  The NMDA receptor mediated conductance depends on the postsynaptic voltage.
  The voltage dependence is due to the blocking of the pore of the NMDA receptor
  from the outside by a positively charged magnesium ion. The channel is
  nearly completely blocked at resting potential, but the magnesium block is
  relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
  that are not blocked by magnesium can be fitted to

  .. math::

    g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V}
    \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}

  Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
  usually 1 mM. The channel therefore acts as a "coincidence detector": it opens
  only when glutamate is bound to the receptor and, at the same time, the
  postsynaptic membrane is depolarized enough to relieve the magnesium block.
  Once open, it allows positively charged ions (cations) to flow through the
  cell membrane [2]_.

  If we make the approximation that the magnesium block changes
  instantaneously with voltage and is independent of the gating of the channel,
  the net NMDA receptor-mediated synaptic current is given by

  .. math::

    I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}

  where :math:`V(t)` is the post-synaptic neuron potential and :math:`E` is the
  reversal potential.

  Simultaneously, the kinetics of synaptic state :math:`g` is given by

  .. math::

    & g_\mathrm{NMDA} (t) = g_{max} g \\
    & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\
    & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k})

  where the decay time of NMDA currents is usually taken to be
  :math:`\tau_{decay} = 100\,\mathrm{ms}`, :math:`a = 0.5\,\mathrm{ms}^{-1}`, and
  :math:`\tau_{rise} = 2\,\mathrm{ms}` [1]_.

  This class integrates only the gating variables :math:`g` and :math:`x` and
  returns :math:`g`; it does not itself compute the magnesium block or the
  driving force :math:`V(t)-E`.

  The NMDA receptor has been thought to be very important for controlling
  synaptic plasticity and mediating learning and memory functions [3]_.

  The model works with the ``brainpy.dyn.ProjAlignPreMg2`` interface, for example:

  .. code-block:: python

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm
    import matplotlib.pyplot as plt

    class NMDASparseCOBA(bp.Projection):
      def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
        super().__init__()

        self.proj = bp.dyn.ProjAlignPreMg2(
          pre=pre,
          delay=delay,
          syn=bp.dyn.NMDA.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise),
          comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
          out=bp.dyn.COBA(E=E),
          post=post,
        )

    class SimpleNet(bp.DynSysGroup):
      def __init__(self, syn_cls, E=0.):
        super().__init__()
        self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.))
        self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Constant(-60.))
        self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1.,
                           tau_decay=5., tau_rise=1., E=E)

      def update(self):
        self.pre()
        self.syn()
        self.post()

        # monitor the following variables
        conductance = self.syn.proj.refs['syn'].g
        current = self.post.sum_inputs(self.post.V)
        return conductance, current, self.post.V

    indices = np.arange(1000)  # 100 ms, dt= 0.1 ms
    net = SimpleNet(NMDASparseCOBA, E=0.)
    conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True)
    ts = indices * bm.get_dt()
    fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4)
    fig.add_subplot(gs[0, 0])
    plt.plot(ts, conductances)
    plt.title('Syn conductance')
    fig.add_subplot(gs[0, 1])
    plt.plot(ts, currents)
    plt.title('Syn current')
    fig.add_subplot(gs[0, 2])
    plt.plot(ts, potentials)
    plt.title('Post V')
    plt.show()

  .. [1] Brunel N, Wang X J. Effects of neuromodulation in a
         cortical network model of object working memory dominated
         by recurrent inhibition[J].
         Journal of computational neuroscience, 2001, 11(1): 63-85.
  .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
         Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
         Nature 438, no. 7065 (2005): 185-192.
  .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
         England journal of medicine, 361(3), p.302.
  .. [4] https://en.wikipedia.org/wiki/NMDA_receptor

  Parameters
  ----------
  tau_decay : float, ArrayType, Callable
    The time constant of the synaptic decay phase. Default 100 [ms]
  tau_rise : float, ArrayType, Callable
    The time constant of the synaptic rise phase. Default 2 [ms]
  a : float, ArrayType, Callable
    Rate [ms^-1] at which the rise variable ``x`` drives ``g`` towards 1
    (the ``a`` in ``dg/dt``). Default 0.5 ms^-1.
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      a: Union[float, ArrayType, Callable] = 0.5,
      tau_decay: Union[float, ArrayType, Callable] = 100.,
      tau_rise: Union[float, ArrayType, Callable] = 2.,
  ):
    super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

    # parameters
    self.tau_decay = self.init_param(tau_decay)
    self.tau_rise = self.init_param(tau_rise)
    self.a = self.init_param(a)

    # joint integrator over (g, x); ``g`` depends on ``x`` through ``dg``
    self.integral = odeint(method=method, f=JointEq(self.dg, self.dx))

    self.reset_state(self.mode)

  def dg(self, g, t, x):
    """Gating dynamics: ``dg/dt = -g / tau_decay + a * x * (1 - g)``."""
    return -g / self.tau_decay + self.a * x * (1 - g)

  def dx(self, x, t):
    """Rise-variable dynamics: ``dx/dt = -x / tau_rise``."""
    return -x / self.tau_rise

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the gating variable ``g`` and rise variable ``x``, both zero."""
    self.g = self.init_variable(bm.zeros, batch_or_mode)
    self.x = self.init_variable(bm.zeros, batch_or_mode)

  def update(self, pre_spike):
    """Integrate ``(g, x)`` over one step, then add ``pre_spike`` to ``x``.

    Returns the updated gating variable ``g``.
    """
    g_next, x_next = self.integral(self.g.value, self.x.value, share.load('t'), dt=share.load('dt'))
    self.g.value = g_next
    self.x.value = x_next + pre_spike
    return self.g.value

  def return_info(self):
    """Return the gating variable ``g``."""
    return self.g


NMDA.__doc__ = NMDA.__doc__ % (_docs.pneu_doc,)
class STD(SynDyn):
  r"""Synaptic output with short-term depression.

  %s

  Parameters
  ----------
  tau : float, ArrayType, Callable
    The time constant of recovery of the synaptic vesicles.
  U : float, ArrayType, Callable
    The fraction of resources used per action potential.
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      tau: Union[float, ArrayType, Callable] = 200.,
      U: Union[float, ArrayType, Callable] = 0.07,
  ):
    super().__init__(name=name,
                     mode=mode,
                     size=size,
                     keep_size=keep_size,
                     sharding=sharding)

    # parameters
    self.tau = self.init_param(tau)
    self.U = self.init_param(U)

    # recovery of the available resources: dx/dt = (1 - x) / tau
    self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=method)

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create the available-resource variable ``x``, filled with ones."""
    self.x = self.init_variable(bm.ones, batch_or_mode)

  def update(self, pre_spike):
    """Recover ``x`` over one step, then deplete it by ``U * x`` per spike.

    Returns the updated ``x``.
    """
    t = share.load('t')
    dt = share.load('dt')
    # Recover towards 1 over the step first. ``x`` is then the resource level at
    # the current time, just before the spike.
    x = self.integral(self.x.value, t, dt)
    # Apply the depletion jump to that recovered value, not to the previous step's
    # ``self.x``. Using ``self.x`` lagged the jump by one step. This matches the
    # ``STP`` update order in this module, where jumps act on the integrated locals.
    self.x.value = x - pre_spike * self.U * x
    return self.x.value

  def return_info(self):
    """Return the available-resource variable ``x``."""
    return self.x


STD.__doc__ = STD.__doc__ % (_docs.std_doc, _docs.pneu_doc,)
class STP(SynDyn):
  r"""Synaptic output with short-term plasticity.

  %s

  Parameters
  ----------
  tau_f : float, ArrayType, Callable
    The time constant of short-term facilitation.
  tau_d : float, ArrayType, Callable
    The time constant of short-term depression.
  U : float, ArrayType, Callable
    The fraction of resources used per action potential.
  %s
  """

  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      method: str = 'exp_auto',
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,

      # synapse parameters
      U: Union[float, ArrayType, Callable] = 0.15,
      tau_f: Union[float, ArrayType, Callable] = 1500.,
      tau_d: Union[float, ArrayType, Callable] = 200.,
  ):
    super().__init__(name=name, mode=mode, size=size, keep_size=keep_size, sharding=sharding)

    # parameters
    self.tau_f = self.init_param(tau_f)
    self.tau_d = self.init_param(tau_d)
    self.U = self.init_param(U)
    self.method = method

    # joint integrator over (u, x), built from the ``derivative`` property
    self.integral = odeint(self.derivative, method=self.method)

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Create ``x`` filled with ones and ``u`` filled with ``U``."""
    self.x = self.init_variable(bm.ones, batch_or_mode)
    # Start ``u`` at ``U`` by multiplying a ones-variable, which broadcasts a
    # per-neuron array ``U`` and works in batched modes (``Variable.fill_`` would
    # only accept a scalar).
    self.u = self.init_variable(bm.ones, batch_or_mode)
    self.u.value = self.u.value * self.U

  @property
  def derivative(self):
    """Joint ODE: ``du/dt = -u / tau_f`` and ``dx/dt = (1 - x) / tau_d``."""
    du = lambda u, t: -u / self.tau_f
    dx = lambda x, t: (1 - x) / self.tau_d
    return JointEq(du, dx)

  def update(self, pre_spike):
    """Integrate ``(u, x)`` over one step, then apply the spike jumps.

    Returns the product of the updated ``u`` and ``x``.
    """
    t, dt = share.load('t'), share.load('dt')
    u_decayed, x_decayed = self.integral(self.u.value, self.x.value, t, dt)
    # The jumps act on the just-integrated values. Facilitation comes first, and
    # the depletion jump then uses the already-facilitated ``u``.
    u_new = u_decayed + pre_spike * self.U * (1 - u_decayed)
    x_new = x_decayed - pre_spike * u_new * x_decayed
    self.x.value = x_new
    self.u.value = u_new
    return u_new * x_new

  def return_info(self):
    """Describe the output as ``u * x`` with this model's shape, sharding and mode."""
    return ReturnInfo(self.varshape, self.sharding, self.mode,
                      lambda shape: self.u * self.x)


STP.__doc__ = STP.__doc__ % (_docs.stp_doc, _docs.pneu_doc,)