from typing import Optional, Union
from brainpy import math as bm, check
from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return)
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData)
from .utils import _get_return
# Public API of this module: the four full-chain, align-pre projection classes.
__all__ = [
    'FullProjAlignPreSDMg',
    'FullProjAlignPreDSMg',
    'FullProjAlignPreSD',
    'FullProjAlignPreDS',
]
def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None):
    """Fetch (creating on first use) the synapse attached to ``delay_cls``.

    The synapse is stored as a "before update" of ``delay_cls`` under a key
    built from ``delay`` and ``syn_desc.identifier``. Calls that produce the
    same key therefore get the same synapse object back.

    Args:
      syn_desc: Callable describer; must expose ``identifier`` and return a
        synapse when called with no arguments.
      delay: The delay value; its ``str()`` form is part of the key.
      delay_cls: Object providing ``has_bef_update``, ``add_bef_update`` and
        ``get_bef_update``.
      proj_name: Passed to ``DelayAccess`` as ``delay_entry``.

    Returns:
      The ``syn`` attribute of the before-update registered under the key.
    """
    key = f'Delay({delay!s}) // {syn_desc.identifier}'
    if not delay_cls.has_bef_update(key):
        # First request for this (delay, synapse) pair: build the delay
        # accessor, instantiate the synapse, and hook both into the
        # "before updates" of the delay model.
        accessor = DelayAccess(delay_cls, delay, delay_entry=proj_name)
        synapse = syn_desc()
        delay_cls.add_bef_update(key, _AlignPreMg(accessor, synapse))
    return delay_cls.get_bef_update(key).syn
class _AlignPreMg(DynamicalSystem):
    """Pair a delay accessor with a synapse so they update as one unit.

    Args:
      access: Zero-argument callable whose result is fed to ``syn``.
      syn: Callable synapse model applied to the accessed value.
    """

    def __init__(self, access, syn):
        super().__init__()
        self.syn = syn
        self.access = access

    def update(self, *args, **kwargs):
        """Apply ``syn`` to the value from ``access``; arguments are ignored."""
        delayed = self.access()
        return self.syn(delayed)

    def reset_state(self, *args, **kwargs):
        """Do nothing; this wrapper resets no state of its own."""
        return None
def align_pre1_add_bef_update(syn_desc, pre):
    """Fetch (creating on first use) the synapse/delay pair attached to ``pre``.

    The pair is stored as an "after update" of ``pre`` under a key built from
    ``syn_desc.identifier``. Calls with the same identifier on the same
    ``pre`` therefore get the same delay and synapse objects back.

    Args:
      syn_desc: Callable describer; must expose ``identifier`` and return a
        synapse providing ``return_info()`` when called with no arguments.
      pre: Object providing ``has_aft_update``, ``add_aft_update`` and
        ``get_aft_update``.

    Returns:
      A ``(delay, syn)`` tuple taken from the registered after-update.
    """
    key = f'{syn_desc.identifier} // Delay'
    if not pre.has_aft_update(key):
        # The synapse must describe its own output so a delay can be built.
        synapse: SupportAutoDelay = syn_desc()
        delay_model = init_delay_by_return(synapse.return_info())
        # Hook the synapse -> delay chain into the "after updates" of ``pre``.
        pre.add_aft_update(key, _AlignPre(synapse, delay_model))
    node = pre.get_aft_update(key)
    return node.delay, node.syn
class _AlignPre(DynamicalSystem):
    """Pipe an input through a synapse and, if present, on into a delay.

    Args:
      syn: The synapse model that receives the input via ``>>``.
      delay: Optional delay model that receives the synapse output via ``>>``.
    """

    def __init__(self, syn, delay=None):
        super().__init__()
        self.delay = delay
        self.syn = syn

    def update(self, x):
        """Return ``x >> syn``, piped further into ``delay`` when one is set."""
        out = x >> self.syn
        if self.delay is not None:
            out = out >> self.delay
        return out

    def reset_state(self, *args, **kwargs):
        """Do nothing; this wrapper resets no state of its own."""
        return None
class FullProjAlignPreSDMg(Projection):
    """Full-chain align-pre projection: synapse first, then delay, with merging.

    Terminology:

    - ``full-chain``: the caller supplies every stage of the projection,
      ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
    - ``align-pre``: the synaptic variables have the same dimension as the
      pre-synaptic neuron group.
    - ``synapse+delay updating``: the synapse states are computed first, then
      delivered to the delay model, and finally used to compute the synaptic
      current.
    - ``merging``: the same delay model is shared by all synapses, and synapse
      models with the same parameters (such as time constants) also share the
      same synaptic variables.

    Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates
    event-driven computation, because ``comm`` is computed after the synapse
    state, which is a floating-point number rather than a spike. For
    event-driven computation, please use the align-post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSDMg(
                    pre=self.E,
                    syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.E,
                )
                self.E2I = bp.dyn.FullProjAlignPreSDMg(
                    pre=self.E,
                    syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.I,
                )
                self.I2E = bp.dyn.FullProjAlignPreSDMg(
                    pre=self.I,
                    syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.E,
                )
                self.I2I = bp.dyn.FullProjAlignPreSDMg(
                    pre=self.I,
                    syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.I,
                )

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
      pre: The pre-synaptic neuron group.
      syn: The synaptic dynamics, given as a parameter describer.
      delay: The synaptic delay.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed to ``post.add_inp_fun`` when ``out``
        is registered on ``post``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    def __init__(
        self,
        pre: DynamicalSystem,
        syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]],
        delay: Union[None, int, float],
        comm: DynamicalSystem,
        out: JointType[DynamicalSystem, BindCondData],
        post: DynamicalSystem,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # Validate the type of every stage before wiring anything together.
        check.is_instance(pre, DynamicalSystem)
        check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]])
        check.is_instance(comm, DynamicalSystem)
        check.is_instance(out, JointType[DynamicalSystem, BindCondData])
        check.is_instance(post, DynamicalSystem)
        self.comm = comm

        # Obtain the (possibly shared) synapse and delay attached to ``pre``,
        # then register this projection's own delay entry under its name.
        shared_delay, shared_syn = align_pre1_add_bef_update(syn, pre)
        shared_delay.register_entry(self.name, delay)

        # Register the output model as an input function of ``post``.
        post.add_inp_fun(self.name, out, label=out_label)

        # Plain-dict references (invisible to ``self.nodes()``); ``comm`` is
        # included as well so that all components can be accessed uniformly.
        self.refs = dict(
            pre=pre,
            post=post,
            out=out,
            delay=shared_delay,
            syn=shared_syn,
            comm=comm,
        )

    def update(self, x=None):
        """Compute the synaptic current and bind it to the output model.

        Args:
          x: Input to ``comm``. When ``None``, the value is read from the
            delay model at this projection's entry.

        Returns:
          The result of ``comm``.
        """
        source = self.refs['delay'].at(self.name) if x is None else x
        current = self.comm(source)
        self.refs['out'].bind_cond(current)
        return current

    @property
    def pre(self):
        return self.refs['pre']

    @property
    def post(self):
        return self.refs['post']

    @property
    def syn(self):
        return self.refs['syn']

    @property
    def delay(self):
        return self.refs['delay']

    @property
    def out(self):
        return self.refs['out']
class FullProjAlignPreDSMg(Projection):
    """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
    spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.

    The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
    parameters (such as time constants) will also share the same synaptic variables.

    Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.E,
                    delay=0.1,
                    syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.E,
                )
                self.E2I = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.E,
                    delay=0.1,
                    syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.I,
                )
                self.I2E = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.I,
                    delay=0.1,
                    syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.E,
                )
                self.I2I = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.I,
                    delay=0.1,
                    syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.I,
                )

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
      pre: The pre-synaptic neuron group.
      delay: The synaptic delay.
      syn: The synaptic dynamics, given as a parameter describer.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed to ``post.add_inp_fun`` when ``out``
        is registered on ``post``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    def __init__(
        self,
        pre: JointType[DynamicalSystem, SupportAutoDelay],
        delay: Union[None, int, float],
        syn: ParamDescriber[DynamicalSystem],
        comm: DynamicalSystem,
        out: JointType[DynamicalSystem, BindCondData],
        post: DynamicalSystem,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
        check.is_instance(syn, ParamDescriber[DynamicalSystem])
        check.is_instance(comm, DynamicalSystem)
        check.is_instance(out, JointType[DynamicalSystem, BindCondData])
        check.is_instance(post, DynamicalSystem)
        self.comm = comm

        # delay initialization
        delay_cls = register_delay_by_return(pre)

        # synapse initialization: the synapse is attached to the delay model
        # as a "before update" and reused when the same key is requested again
        syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name)

        # output initialization
        post.add_inp_fun(self.name, out, label=out_label)

        # references
        self.refs = dict()
        # invisible to `self.nodes()`
        self.refs['pre'] = pre
        self.refs['post'] = post
        self.refs['syn'] = syn_cls
        # keep the delay model reachable, as the sibling projections do
        self.refs['delay'] = delay_cls
        self.refs['out'] = out
        # unify the access
        self.refs['comm'] = comm

    def update(self):
        """Compute the synaptic current and bind it to the output model.

        The input to ``comm`` is obtained from the synapse's ``return_info()``
        through ``_get_return``.

        Returns:
          The result of ``comm``.
        """
        x = _get_return(self.refs['syn'].return_info())
        current = self.comm(x)
        self.refs['out'].bind_cond(current)
        return current

    pre = property(lambda self: self.refs['pre'])
    post = property(lambda self: self.refs['post'])
    syn = property(lambda self: self.refs['syn'])
    delay = property(lambda self: self.refs['delay'])
    out = property(lambda self: self.refs['out'])
class FullProjAlignPreSD(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
    synapse states to the delay model, and finally computes the synaptic current.

    Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    Unlike ``FullProjAlignPreSDMg``, ``syn`` here is a synapse *instance*
    (not a parameter describer): the constructor type-checks it as a
    ``DynamicalSystem`` and calls ``syn.return_info()`` on it directly.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSD(
                    pre=self.E,
                    syn=bp.dyn.Expon(size=ne, tau=5.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.E,
                )
                self.E2I = bp.dyn.FullProjAlignPreSD(
                    pre=self.E,
                    syn=bp.dyn.Expon(size=ne, tau=5.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.I,
                )
                self.I2E = bp.dyn.FullProjAlignPreSD(
                    pre=self.I,
                    syn=bp.dyn.Expon(size=ni, tau=10.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.E,
                )
                self.I2I = bp.dyn.FullProjAlignPreSD(
                    pre=self.I,
                    syn=bp.dyn.Expon(size=ni, tau=10.),
                    delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.I,
                )

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
      pre: The pre-synaptic neuron group.
      syn: The synaptic dynamics (a synapse instance).
      delay: The synaptic delay.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed to ``post.add_inp_fun`` when ``out``
        is registered on ``post``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    def __init__(
        self,
        pre: DynamicalSystem,
        syn: JointType[DynamicalSystem, SupportAutoDelay],
        delay: Union[None, int, float],
        comm: DynamicalSystem,
        out: JointType[DynamicalSystem, BindCondData],
        post: DynamicalSystem,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        check.is_instance(pre, DynamicalSystem)
        check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay])
        check.is_instance(comm, DynamicalSystem)
        check.is_instance(out, JointType[DynamicalSystem, BindCondData])
        check.is_instance(post, DynamicalSystem)
        self.comm = comm

        # synapse and delay initialization: build a delay from the synapse's
        # return info, register this projection's entry on it, and hook the
        # synapse -> delay chain into the "after updates" of ``pre``
        delay_cls = init_delay_by_return(syn.return_info())
        delay_cls.register_entry(self.name, delay)
        pre.add_aft_update(self.name, _AlignPre(syn, delay_cls))

        # output initialization
        post.add_inp_fun(self.name, out, label=out_label)

        # references
        self.refs = dict()
        # invisible to ``self.nodes()``
        self.refs['pre'] = pre
        self.refs['post'] = post
        self.refs['out'] = out
        self.refs['delay'] = delay_cls
        self.refs['syn'] = syn
        # unify the access
        self.refs['comm'] = comm

    def update(self, x=None):
        """Compute the synaptic current and bind it to the output model.

        Args:
          x: Input to ``comm``. When ``None``, the value is read from the
            delay model at this projection's entry.

        Returns:
          The result of ``comm``.
        """
        if x is None:
            x = self.refs['delay'].at(self.name)
        current = self.comm(x)
        self.refs['out'].bind_cond(current)
        return current

    pre = property(lambda self: self.refs['pre'])
    post = property(lambda self: self.refs['post'])
    syn = property(lambda self: self.refs['syn'])
    delay = property(lambda self: self.refs['delay'])
    out = property(lambda self: self.refs['out'])
class FullProjAlignPreDS(Projection):
    """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
    spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.

    Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    Unlike ``FullProjAlignPreDSMg``, ``syn`` here is a synapse *instance*
    (not a parameter describer): the constructor type-checks it as a
    ``DynamicalSystem`` and ``update`` calls it directly.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDS(
                    pre=self.E,
                    delay=0.1,
                    syn=bp.dyn.Expon(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.E,
                )
                self.E2I = bp.dyn.FullProjAlignPreDS(
                    pre=self.E,
                    delay=0.1,
                    syn=bp.dyn.Expon(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.),
                    post=self.I,
                )
                self.I2E = bp.dyn.FullProjAlignPreDS(
                    pre=self.I,
                    delay=0.1,
                    syn=bp.dyn.Expon(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.E,
                )
                self.I2I = bp.dyn.FullProjAlignPreDS(
                    pre=self.I,
                    delay=0.1,
                    syn=bp.dyn.Expon(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.),
                    post=self.I,
                )

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
      pre: The pre-synaptic neuron group.
      delay: The synaptic delay.
      syn: The synaptic dynamics (a synapse instance).
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed to ``post.add_inp_fun`` when ``out``
        is registered on ``post``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    def __init__(
        self,
        pre: JointType[DynamicalSystem, SupportAutoDelay],
        delay: Union[None, int, float],
        syn: DynamicalSystem,
        comm: DynamicalSystem,
        out: JointType[DynamicalSystem, BindCondData],
        post: DynamicalSystem,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
        check.is_instance(syn, DynamicalSystem)
        check.is_instance(comm, DynamicalSystem)
        check.is_instance(out, JointType[DynamicalSystem, BindCondData])
        check.is_instance(post, DynamicalSystem)
        # ``comm`` and ``syn`` are held as direct attributes of the projection
        self.comm = comm
        self.syn = syn

        # delay initialization: obtain the delay model for ``pre`` and
        # register this projection's entry on it
        delay_cls = register_delay_by_return(pre)
        delay_cls.register_entry(self.name, delay)

        # output initialization
        post.add_inp_fun(self.name, out, label=out_label)

        # references
        self.refs = dict()
        # invisible to ``self.nodes()``
        self.refs['pre'] = pre
        self.refs['post'] = post
        self.refs['out'] = out
        self.refs['delay'] = delay_cls
        # unify the access
        self.refs['syn'] = syn
        self.refs['comm'] = comm

    def update(self):
        """Compute the synaptic current and bind it to the output model.

        Reads the delayed value at this projection's entry, passes it through
        ``syn`` and then ``comm``.

        Returns:
          The result of ``comm``.
        """
        spk = self.refs['delay'].at(self.name)
        g = self.comm(self.syn(spk))
        self.refs['out'].bind_cond(g)
        return g

    pre = property(lambda self: self.refs['pre'])
    post = property(lambda self: self.refs['post'])
    delay = property(lambda self: self.refs['delay'])
    out = property(lambda self: self.refs['out'])