# -*- coding: utf-8 -*-
from typing import Union, Dict, Callable, Optional
import jax
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One
from brainpy._src.dnn import linear
from brainpy._src.dyn import _docs
from brainpy._src.dyn import synapses
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.dynold.synouts import MgBlock, CUBA
from brainpy._src.initialize import Initializer
from brainpy.types import ArrayType
from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre
__all__ = [
'Delta',
'Exponential',
'DualExponential',
'Alpha',
'NMDA',
]
class Delta(TwoEndConn):
r"""Voltage Jump Synapse Model, or alias of Delta Synapse Model.
**Model Descriptions**
.. math::
I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D)
where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
:math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
:math:`C` the set of neurons connected to the post-synaptic neuron,
:math:`D` the transmission delay of chemical synapses,
and :math:`\mathrm{STP}` the short-term plasticity effect.
For simplicity, the rise and decay phases of post-synaptic currents are
omitted in this model.
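In discrete time this is a simple jump rule: at each step, every delayed
pre-synaptic spike adds :math:`g_{\mathrm{max}}` (scaled by the STP factor)
to the target variable of the post-synaptic neurons. A minimal plain-Python
sketch of this rule (hypothetical values; all-to-all connectivity onto a
single post-synaptic neuron):

>>> import numpy as np
>>> g_max, stp_factor = 1., 1.
>>> delayed_pre_spike = np.array([1., 0., 1.])  # spikes of 3 pre neurons
>>> post_jump = g_max * stp_factor * delayed_pre_spike.sum()
>>> print(post_jump)  # added directly to the post target variable (e.g., V)
2.0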
**Model Examples**
>>> import brainpy as bp
>>> from brainpy import synapses, neurons
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Delta(neu1, neu2, bp.connect.All2All(), g_max=5.)
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8)
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.xlim(40, 150)
>>> plt.legend()
>>> plt.show()
Parameters
----------
pre: NeuDyn
The pre-synaptic neuron group.
post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `sparse`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
g_max: float, ArrayType, Initializer, Callable
The synaptic strength. Default is 1.
post_ref_key: str
The key of the post-synaptic refractory state variable. If provided,
synaptic outputs onto neurons in their refractory period are masked out.
"""
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
output: _SynOut = CUBA(target_var='V'),
stp: Optional[_SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
post_ref_key: Optional[str] = None,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
super().__init__(name=name,
pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
mode=mode)
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.post_ref_key = post_ref_key
if post_ref_key:
self.check_post_attrs(post_ref_key)
self.comp_method = comp_method
# connections and weights
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')
# register delay
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)
def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
# update sub-components
if self.stp is not None:
self.stp.update(pre_spike)
# synaptic values onto the post
if isinstance(self.conn, All2All):
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
if self.stp is not None:
syn_value = self.stp(syn_value)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
if self.stp is not None:
syn_value = self.stp(syn_value)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
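# Event-driven kernels exploit binary spike vectors; once STP rescales the
# spikes into real values, a generic sparse matrix-vector product is used.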
if self.stp is not None:
syn_value = self.stp(pre_spike)
f = lambda s: bm.sparse.csrmv(
self.g_max, self.conn_mask[0], self.conn_mask[1], s,
shape=(self.pre.num, self.post.num), transpose=True
)
else:
syn_value = pre_spike
f = lambda s: bm.event.csrmv(
self.g_max, self.conn_mask[0], self.conn_mask[1], s,
shape=(self.pre.num, self.post.num), transpose=True
)
if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
post_vs = f(syn_value)
else:
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
if self.stp is not None:
syn_value = self.stp(syn_value)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
if self.post_ref_key:
post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))
# update outputs
return self.output(post_vs)
class Exponential(TwoEndConn):
r"""Exponential decay synapse model.
%s
**Model Examples**
- `(Brunel & Hakim, 1999) Fast Global Oscillation <https://brainpy-examples.readthedocs.io/en/latest/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html>`_
- `(Vreeswijk & Sompolinsky, 1996) E/I balanced network <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Vreeswijk_1996_EI_net.html>`_
- `(Brette, et, al., 2007) CUBA <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Brette_2007_CUBA.html>`_
- `(Tian, et al., 2020) E/I Net for fast response <https://brainpy-examples.readthedocs.io/en/latest/ei_nets/Tian_2020_EI_net_for_fast_response.html>`_
>>> import brainpy as bp
>>> from brainpy import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Exponential(neu1, neu2, bp.conn.All2All(),
...                             g_max=5., output=synouts.CUBA())
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.legend()
>>> plt.show()
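The response of :math:`g` to a single spike is a jump followed by an
exponential decay with time constant :math:`\tau`. A quick plain-Python
check of the decay (hypothetical values; the class itself integrates the
underlying ODE with the scheme given by ``method``):

>>> import numpy as np
>>> tau, dt = 8., 0.1
>>> g = 1.0  # value of g right after a spike
>>> for _ in range(80):  # 80 steps of 0.1 ms = one time constant
...     g *= np.exp(-dt / tau)
>>> print(f'{g:.3f}')  # ~ 1/e after one time constant
0.368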
Parameters
----------
pre: NeuDyn
The pre-synaptic neuron group.
post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `sparse`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau: float, ArrayType
The time constant of decay. [ms]
g_max: float, ArrayType, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
"""
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
output: Optional[_SynOut] = CUBA(),
stp: Optional[_SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
tau: Union[float, ArrayType] = 8.0,
method: str = 'exp_auto',
# other parameters
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
super().__init__(pre=pre,
post=post,
conn=conn,
output=output,
stp=stp,
name=name,
mode=mode)
# parameters
self.stop_spike_gradient = stop_spike_gradient
# synapse dynamics
self.syn = synapses.Expon(post.varshape, tau=tau, method=method)
# Projection
if isinstance(conn, All2All):
self.comm = linear.AllToAll(pre.num, post.num, g_max)
elif isinstance(conn, One2One):
assert post.num == pre.num, 'One2One connection requires pre.num == post.num.'
self.comm = linear.OneToOne(pre.num, g_max)
else:
if comp_method == 'dense':
self.comm = linear.MaskedLinear(conn, g_max)
elif comp_method == 'sparse':
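# Binary spikes permit an event-driven CSR kernel; with STP the spike
# vector becomes real-valued, so a generic CSR linear operator is used.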
if self.stp is None:
self.comm = linear.EventCSRLinear(conn, g_max)
else:
self.comm = linear.CSRLinear(conn, g_max)
else:
raise ValueError(f'Unsupported comp_method: {comp_method}. Only "sparse" and "dense" are supported.')
# delay
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)
@property
def g(self):
return self.syn.g
@g.setter
def g(self, value):
self.syn.g = value
def update(self, pre_spike=None):
# delays
if pre_spike is None:
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
# update sub-components
self.output.update()
if self.stp is not None:
self.stp.update(pre_spike)
pre_spike = self.stp(pre_spike)
# post values
g = self.syn(self.comm(pre_spike))
# output
return self.output(g)
Exponential.__doc__ = Exponential.__doc__ % (_docs.exp_syn_doc,)
class DualExponential(_TwoEndConnAlignPre):
r"""Dual exponential synapse model.
%s
**Model Examples**
>>> import brainpy as bp
>>> from brainpy import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.DualExponential(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()
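The post-synaptic response to a single spike rises with :math:`\tau_{rise}`
and decays with :math:`\tau_{decay}`, i.e. a difference of two exponentials.
A plain-Python evaluation of this kernel shape (unit amplitude, hypothetical
time constants; the class integrates the equivalent ODE system with
``method``):

>>> import numpy as np
>>> tau_decay, tau_rise = 10., 1.
>>> t = np.arange(0., 50., 0.1)
>>> kernel = np.exp(-t / tau_decay) - np.exp(-t / tau_rise)
>>> print(f'peak at t = {t[kernel.argmax()]:.1f} ms')
peak at t = 2.6 ms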
Parameters
----------
pre: NeuDyn
The pre-synaptic neuron group.
post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `dense`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau_decay: float, ArrayType
The time constant of the synaptic decay phase. [ms]
tau_rise: float, ArrayType
The time constant of the synaptic rise phase. [ms]
g_max: float, ArrayType, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
"""
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
stp: Optional[_SynSTP] = None,
output: Optional[_SynOut] = None,  # CUBA()
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
tau_decay: Union[float, ArrayType] = 10.0,
tau_rise: Union[float, ArrayType] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
A: Optional[Union[float, ArrayType, Callable]] = None,
method: str = 'exp_auto',
# other parameters
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau_rise = tau_rise
self.tau_decay = tau_decay
if bm.size(self.tau_rise) != 1:
raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau_rise}')
if bm.size(self.tau_decay) != 1:
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau_decay}')
syn = synapses.DualExpon(pre.size,
pre.keep_size,
A=A,
mode=mode,
tau_decay=tau_decay,
tau_rise=tau_rise,
method=method)
super().__init__(pre=pre,
post=post,
syn=syn,
conn=conn,
output=output,
stp=stp,
comp_method=comp_method,
g_max=g_max,
delay_step=delay_step,
name=name,
mode=mode)
self.check_post_attrs('input')
# copy the references
self.g = syn.g
self.h = syn.h
def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
DualExponential.__doc__ = DualExponential.__doc__ % (_docs.dual_exp_syn_doc,)
class Alpha(_TwoEndConnAlignPre):
r"""Alpha synapse model.
%s
**Model Examples**
>>> import brainpy as bp
>>> from brainpy import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()
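The alpha kernel is the limiting case of the dual exponential with equal
rise and decay time constants; its response to a single spike peaks at
:math:`t = \tau_{decay}`. A plain-Python check of the peak location
(hypothetical time constant; the class integrates the corresponding ODEs
with ``method``):

>>> import numpy as np
>>> tau = 10.
>>> t = np.arange(0., 100., 0.1)
>>> kernel = (t / tau) * np.exp(1 - t / tau)  # normalized to peak at 1
>>> print(f'peak of {kernel.max():.1f} at t = {t[kernel.argmax()]:.1f} ms')
peak of 1.0 at t = 10.0 ms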
Parameters
----------
pre: NeuDyn
The pre-synaptic neuron group.
post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `dense`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau_decay: float, ArrayType
The time constant of the synaptic decay phase. [ms]
g_max: float, ArrayType, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
"""
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
output: Optional[_SynOut] = None,  # CUBA()
stp: Optional[_SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
tau_decay: Union[float, ArrayType] = 10.0,
method: str = 'exp_auto',
# other parameters
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau_decay = tau_decay
if bm.size(self.tau_decay) != 1:
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau_decay}')
syn = synapses.Alpha(pre.size,
pre.keep_size,
mode=mode,
tau_decay=tau_decay,
method=method)
super().__init__(pre=pre,
post=post,
syn=syn,
conn=conn,
comp_method=comp_method,
delay_step=delay_step,
g_max=g_max,
output=output,
stp=stp,
name=name,
mode=mode)
self.check_post_attrs('input')
# copy the references
self.g = syn.g
self.h = syn.h
def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc,)
class NMDA(_TwoEndConnAlignPre):
r"""NMDA synapse model.
**Model Descriptions**
The NMDA receptor is a glutamate receptor and ion channel found in neurons.
The NMDA receptor is one of three types of ionotropic glutamate receptors,
the other two being AMPA and kainate receptors.
The NMDA receptor-mediated conductance depends on the post-synaptic voltage.
The voltage dependence is due to the blocking of the pore of the NMDA receptor
from the outside by a positively charged magnesium ion. The channel is
nearly completely blocked at resting potential, but the magnesium block is
relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
that are not blocked by magnesium can be fitted to
.. math::
g_{\infty}(V,[{Mg}^{2+}]_{o}) = \left(1+e^{-\alpha V}
\frac{[{Mg}^{2+}]_{o}}{\beta}\right)^{-1}
Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
usually 1 mM. The channel thus acts as a "coincidence detector": only when
glutamate is bound to the receptor *and* the post-synaptic cell is
depolarized (relieving the magnesium block) does the channel open, allowing
positively charged ions (cations) to flow through the cell membrane [2]_.
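For instance, with the default parameters of the :class:`MgBlock` output
used by this class (:math:`\alpha=0.062`, :math:`\beta=3.57`,
:math:`[{Mg}^{2+}]_{o}=1.2` mM) and a hypothetical resting potential of
:math:`V=-65` mV, only about 5% of channels are unblocked (a plain-Python
check of the formula above):

>>> import numpy as np
>>> alpha, beta, cc_Mg, V = 0.062, 3.57, 1.2, -65.
>>> g_inf = 1. / (1. + np.exp(-alpha * V) * cc_Mg / beta)
>>> print(f'{g_inf:.3f}')
0.050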
If we make the approximation that the magnesium block changes
instantaneously with voltage and is independent of the gating of the channel,
the net NMDA receptor-mediated synaptic current is given by
.. math::
I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
reversal potential.
Simultaneously, the kinetics of synaptic state :math:`g` is given by
.. math::
& g_\mathrm{NMDA} (t) = g_{max} g \\
& \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\
& \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k})
where the decay time of NMDA currents is usually taken to be
:math:`\tau_{decay}=100` ms, :math:`a=0.5\ \mathrm{ms}^{-1}`, and :math:`\tau_{rise}=2` ms.
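A single explicit-Euler step of these kinetics (a rough sketch with
hypothetical values; the class itself uses the integrator selected by
``method``):

>>> dt, tau_decay, tau_rise, a = 0.1, 100., 2., 0.5
>>> g, x = 0., 1.  # x has just been incremented by a pre-synaptic spike
>>> g = g + dt * (-g / tau_decay + a * x * (1. - g))
>>> x = x + dt * (-x / tau_rise)
>>> print(f'g={g:.3f}, x={x:.3f}')
g=0.050, x=0.950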
The NMDA receptor has been thought to be very important for controlling
synaptic plasticity and mediating learning and memory functions [3]_.
**Model Examples**
- `(Wang, 2002) Decision making spiking model <https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2002_decision_making_spiking.html>`_
>>> import brainpy as bp
>>> from brainpy import synapses, neurons
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.HH(1)
>>> neu2 = neurons.HH(1)
>>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All())
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
>>> plt.legend()
>>> plt.show()
Parameters
----------
pre: NeuDyn
The pre-synaptic neuron group.
post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `dense`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
g_max: float, ArrayType, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
tau_decay: float, ArrayType
The time constant of the synaptic decay phase. Default 100 [ms]
tau_rise: float, ArrayType
The time constant of the synaptic rise phase. Default 2 [ms]
a: float, ArrayType
The rate constant :math:`a` in the equations above. Default 0.5 [ms^{-1}]
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
References
----------
.. [1] Brunel N, Wang X J. Effects of neuromodulation in a
cortical network model of object working memory dominated
by recurrent inhibition[J].
Journal of computational neuroscience, 2001, 11(1): 63-85.
.. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
Nature 438, no. 7065 (2005): 185-192.
.. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
England journal of medicine, 361(3), p.302.
.. [4] https://en.wikipedia.org/wiki/NMDA_receptor
"""
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
output: _SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2),
stp: Optional[_SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 0.15,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
tau_decay: Union[float, ArrayType] = 100.,
a: Union[float, ArrayType] = 0.5,
tau_rise: Union[float, ArrayType] = 2.,
method: str = 'exp_auto',
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
# parameters
self.tau_decay = tau_decay
self.tau_rise = tau_rise
self.a = a
if bm.size(a) != 1:
raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}')
if bm.size(tau_decay) != 1:
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}')
if bm.size(tau_rise) != 1:
raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}')
self.comp_method = comp_method
self.stop_spike_gradient = stop_spike_gradient
syn = synapses.NMDA(pre.size,
pre.keep_size,
mode=mode,
a=a,
tau_decay=tau_decay,
tau_rise=tau_rise,
method=method)
super().__init__(pre=pre,
post=post,
syn=syn,
conn=conn,
output=output,
stp=stp,
comp_method=comp_method,
g_max=g_max,
delay_step=delay_step,
name=name,
mode=mode)
# copy the references
self.g = syn.g
self.x = syn.x
def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)