Source code for brainpy._src.dyn.projections.delta

from typing import Optional, Union

from brainpy import math as bm, check
from brainpy._src.delay import (delay_identifier, register_delay_by_return)
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, SupportAutoDelay)

# Public API of this module: the half- and full-chain Delta projections.
__all__ = [
  'HalfProjDelta', 'FullProjDelta',
]


class _Delta:
  def __init__(self):
    self._cond = None

  def bind_cond(self, cond):
    self._cond = cond

  def __call__(self, *args, **kwargs):
    r = self._cond
    return r


class HalfProjDelta(Projection):
  r"""Defining the half-part of the synaptic projection for the Delta synapse model.

  The synaptic projection requires the input to be spiking data; otherwise
  the synapse is not the Delta synapse model.

  The ``half-part`` means that the model only includes ``comm`` -> ``post``.
  The output of ``comm`` is delivered directly to the post-synaptic group as
  a ``'delta'`` input, with no synaptic dynamics in between. Therefore, the
  spiking input must be passed manually to ``update``.

  **Model Descriptions**

  .. math::

     I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)

  where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
  :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
  :math:`C` the set of neurons connected to the post-synaptic neuron,
  and :math:`D` the transmission delay of chemical synapses.
  For simplicity, the rise and decay phases of post-synaptic currents are
  omitted in this model.

  **Code Examples**

  .. code-block::

      import brainpy as bp
      import brainpy.math as bm

      class Net(bp.DynamicalSystem):
        def __init__(self):
          super().__init__()

          self.pre = bp.dyn.PoissonGroup(10, 100.)
          self.post = bp.dyn.LifRef(1)
          self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)

        def update(self):
          self.syn(self.pre())
          self.post()
          return self.post.V.value

      net = Net()
      indices = bm.arange(1000).to_numpy()
      vs = bm.for_loop(net.step_run, indices, progress_bar=True)
      bp.visualize.line_plot(indices, vs, show=True)

  Args:
    comm: DynamicalSystem. The synaptic communication.
    post: DynamicalSystem. The post-synaptic neuron group.
    name: str. The projection name.
    mode: Mode. The computing mode.
  """

  def __init__(
      self,
      comm: DynamicalSystem,
      post: DynamicalSystem,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    super().__init__(name=name, mode=mode)

    # synaptic models
    check.is_instance(comm, DynamicalSystem)
    check.is_instance(post, DynamicalSystem)
    self.comm = comm

    # output initialization: register a relay as a 'delta' input of ``post``;
    # ``update`` binds the communication output to it on every step.
    out = _Delta()
    post.add_inp_fun(self.name, out, category='delta')

    # references
    self.refs = dict(post=post, out=out)  # invisible to ``self.nodes()``
    self.refs['comm'] = comm  # unify the access

  def update(self, x):
    """Run the communication on ``x`` and deliver the result to ``post``.

    Args:
      x: The presynaptic input fed to ``comm`` (expected to be spiking data).

    Returns:
      The output of ``self.comm(x)``. The same value is also bound to the
      relay registered as the post-synaptic ``'delta'`` input.
    """
    # call the communication
    current = self.comm(x)
    # bind the output
    self.refs['out'].bind_cond(current)
    # return the current, if needed
    return current
class FullProjDelta(Projection):
  r"""Full-chain of the synaptic projection for the Delta synapse model.

  The synaptic projection requires the input to be spiking data; otherwise
  the synapse is not the Delta synapse model.

  The ``full-chain`` means that the model needs to provide all information
  needed for a projection, including ``pre`` -> ``delay`` -> ``comm`` -> ``post``.

  **Model Descriptions**

  .. math::

     I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)

  where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
  :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
  :math:`C` the set of neurons connected to the post-synaptic neuron,
  and :math:`D` the transmission delay of chemical synapses.
  For simplicity, the rise and decay phases of post-synaptic currents are
  omitted in this model.

  **Code Examples**

  .. code-block::

      import brainpy as bp
      import brainpy.math as bm

      class Net(bp.DynamicalSystem):
        def __init__(self):
          super().__init__()

          self.pre = bp.dyn.PoissonGroup(10, 100.)
          self.post = bp.dyn.LifRef(1)
          self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)

        def update(self):
          self.syn()
          self.pre()
          self.post()
          return self.post.V.value

      net = Net()
      indices = bm.arange(1000).to_numpy()
      vs = bm.for_loop(net.step_run, indices, progress_bar=True)
      bp.visualize.line_plot(indices, vs, show=True)

  Args:
    pre: The pre-synaptic neuron group.
    delay: The synaptic delay.
    comm: DynamicalSystem. The synaptic communication.
    post: DynamicalSystem. The post-synaptic neuron group.
    name: str. The projection name.
    mode: Mode. The computing mode.
  """

  def __init__(
      self,
      pre: JointType[DynamicalSystem, SupportAutoDelay],
      delay: Union[None, int, float],
      comm: DynamicalSystem,
      post: DynamicalSystem,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    super().__init__(name=name, mode=mode)

    # synaptic models
    check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
    check.is_instance(comm, DynamicalSystem)
    check.is_instance(post, DynamicalSystem)
    self.comm = comm

    # delay initialization: obtain the delay object attached to ``pre`` and
    # register an entry for this projection under its own name.
    delay_cls = register_delay_by_return(pre)
    delay_cls.register_entry(self.name, delay)

    # output initialization: register a relay as a 'delta' input of ``post``;
    # ``update`` binds the communication output to it on every step.
    out = _Delta()
    post.add_inp_fun(self.name, out, category='delta')

    # references
    self.refs = dict(pre=pre, post=post, out=out)  # invisible to ``self.nodes()``
    self.refs['comm'] = comm  # unify the access
    self.refs['delay'] = pre.get_aft_update(delay_identifier)

  def update(self):
    """Read the delayed presynaptic output, run ``comm``, and deliver it.

    Returns:
      The output of ``self.comm`` applied to the value of this projection's
      delay entry. The same value is also bound to the relay registered as
      the post-synaptic ``'delta'`` input.
    """
    # get delay
    x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
    # call the communication
    current = self.comm(x)
    # bind the output
    self.refs['out'].bind_cond(current)
    # return the current, if needed
    return current