Source code for brainpy.dyn.projections.align_pre

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Union

from brainpy import math as bm, check
from brainpy.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return)
from brainpy.dynsys import DynamicalSystem, Projection
from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData)
from .utils import _get_return

__all__ = [
    'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg',
    'FullProjAlignPreSD', 'FullProjAlignPreDS',
]


def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None):
    _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}'
    if not delay_cls.has_bef_update(_syn_id):
        # delay
        delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name)
        # synapse
        syn_cls = syn_desc()
        # add to "after_updates"
        delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
    syn = delay_cls.get_bef_update(_syn_id).syn
    return syn


class _AlignPreMg(DynamicalSystem):
    def __init__(self, access, syn):
        super().__init__()
        self.access = access
        self.syn = syn

    def update(self, *args, **kwargs):
        return self.syn(self.access())

    def reset_state(self, *args, **kwargs):
        pass


def align_pre1_add_bef_update(syn_desc, pre):
    _syn_id = f'{syn_desc.identifier} // Delay'
    if not pre.has_aft_update(_syn_id):
        # "syn_cls" needs an instance of "ProjAutoDelay"
        syn_cls: SupportAutoDelay = syn_desc()
        delay_cls = init_delay_by_return(syn_cls.return_info())
        # add to "after_updates"
        pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls))
    delay_cls: Delay = pre.get_aft_update(_syn_id).delay
    syn = pre.get_aft_update(_syn_id).syn
    return delay_cls, syn


class _AlignPre(DynamicalSystem):
    def __init__(self, syn, delay=None):
        super().__init__()
        self.syn = syn
        self.delay = delay

    def update(self, x):
        if self.delay is None:
            return x >> self.syn
        else:
            return x >> self.syn >> self.delay

    def reset_state(self, *args, **kwargs):
        pass


[docs] class FullProjAlignPreSDMg(Projection): """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging. The ``full-chain`` means that the model needs to provide all information needed for a projection, including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the synapse states to the delay model, and finally computes the synaptic current. The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same parameters (such like time constants) will also share the same synaptic variables. Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates the event-driven computation. This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather than the spiking. To facilitate the event-driven computation, please use align post projections. To simulate an E/I balanced network model: .. code-block:: python class EINet(bp.DynSysGroup): def __init__(self): super().__init__() ne, ni = 3200, 800 self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) def update(self, inp): self.E2E() self.E2I() self.I2E() self.I2I() self.E(inp) self.I(inp) return self.E.spike model = EINet() indices = bm.arange(1000) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) Args: pre: The pre-synaptic neuron group. syn: The synaptic dynamics. delay: The synaptic delay. comm: The synaptic communication. out: The synaptic output. post: The post-synaptic neuron group. name: str. The projection name. mode: Mode. The computing mode. """ def __init__( self, pre: DynamicalSystem, syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): super().__init__(name=name, mode=mode) # synaptic models check.is_instance(pre, DynamicalSystem) check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and delay initialization delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) delay_cls.register_entry(self.name, delay) # output initialization post.add_inp_fun(self.name, out, label=out_label) # references self.refs = dict() # invisible to ``self.nodes()`` self.refs['pre'] = pre self.refs['post'] = post self.refs['out'] = out self.refs['delay'] = delay_cls self.refs['syn'] = syn_cls # unify the access self.refs['comm'] = comm
[docs] def update(self, x=None): if x is None: x = self.refs['delay'].at(self.name) current = self.comm(x) self.refs['out'].bind_cond(current) return current
pre = property(lambda self: self.refs['pre']) post = property(lambda self: self.refs['post']) syn = property(lambda self: self.refs['syn']) delay = property(lambda self: self.refs['delay']) out = property(lambda self: self.refs['out'])
[docs] class FullProjAlignPreDSMg(Projection): """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging. The ``full-chain`` means that the model needs to provide all information needed for a projection, including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged. The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same parameters (such like time constants) will also share the same synaptic variables. Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation. This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather than the spiking. To facilitate the event-driven computation, please use align post projections. To simulate an E/I balanced network model: .. code-block:: python class EINet(bp.DynSysGroup): def __init__(self): super().__init__() ne, ni = 3200, 800 self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) def update(self, inp): self.E2E() self.E2I() self.I2E() self.I2I() self.E(inp) self.I(inp) return self.E.spike model = EINet() indices = bm.arange(1000) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) Args: pre: The pre-synaptic neuron group. delay: The synaptic delay. syn: The synaptic dynamics. comm: The synaptic communication. out: The synaptic output. post: The post-synaptic neuron group. name: str. The projection name. mode: Mode. The computing mode. """ def __init__( self, pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], syn: ParamDescriber[DynamicalSystem], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): super().__init__(name=name, mode=mode) # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm # delay initialization delay_cls = register_delay_by_return(pre) # synapse initialization syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) # output initialization post.add_inp_fun(self.name, out, label=out_label) # references self.refs = dict() # invisible to `self.nodes()` self.refs['pre'] = pre self.refs['post'] = post self.refs['syn'] = syn_cls self.refs['out'] = out # unify the access self.refs['comm'] = comm
[docs] def update(self): x = _get_return(self.refs['syn'].return_info()) current = self.comm(x) self.refs['out'].bind_cond(current) return current
pre = property(lambda self: self.refs['pre']) post = property(lambda self: self.refs['post']) syn = property(lambda self: self.refs['syn']) out = property(lambda self: self.refs['out'])
[docs] class FullProjAlignPreSD(Projection): """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating. The ``full-chain`` means that the model needs to provide all information needed for a projection, including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the synapse states to the delay model, and finally computes the synaptic current. Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates the event-driven computation. This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather than the spiking. To facilitate the event-driven computation, please use align post projections. To simulate an E/I balanced network model: .. code-block:: python class EINet(bp.DynSysGroup): def __init__(self): super().__init__() ne, ni = 3200, 800 self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) def update(self, inp): self.E2E() self.E2I() self.I2E() self.I2I() self.E(inp) self.I(inp) return self.E.spike model = EINet() indices = bm.arange(1000) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) Args: pre: The pre-synaptic neuron group. syn: The synaptic dynamics. delay: The synaptic delay. comm: The synaptic communication. out: The synaptic output. post: The post-synaptic neuron group. name: str. The projection name. mode: Mode. The computing mode. """ def __init__( self, pre: DynamicalSystem, syn: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): super().__init__(name=name, mode=mode) # synaptic models check.is_instance(pre, DynamicalSystem) check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and delay initialization delay_cls = init_delay_by_return(syn.return_info()) delay_cls.register_entry(self.name, delay) pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) # output initialization post.add_inp_fun(self.name, out, label=out_label) # references self.refs = dict() # invisible to ``self.nodes()`` self.refs['pre'] = pre self.refs['post'] = post self.refs['out'] = out self.refs['delay'] = delay_cls self.refs['syn'] = syn # unify the access self.refs['comm'] = comm
[docs] def update(self, x=None): if x is None: x = self.refs['delay'].at(self.name) current = self.comm(x) self.refs['out'].bind_cond(current) return current
pre = property(lambda self: self.refs['pre']) post = property(lambda self: self.refs['post']) syn = property(lambda self: self.refs['syn']) delay = property(lambda self: self.refs['delay']) out = property(lambda self: self.refs['out'])
[docs] class FullProjAlignPreDS(Projection): """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating. The ``full-chain`` means that the model needs to provide all information needed for a projection, including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged. The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation. This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather than the spiking. To facilitate the event-driven computation, please use align post projections. To simulate an E/I balanced network model: .. code-block:: python class EINet(bp.DynSysGroup): def __init__(self): super().__init__() ne, ni = 3200, 800 self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) def update(self, inp): self.E2E() self.E2I() self.I2E() self.I2I() self.E(inp) self.I(inp) return self.E.spike model = EINet() indices = bm.arange(1000) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) Args: pre: The pre-synaptic neuron group. delay: The synaptic delay. syn: The synaptic dynamics. comm: The synaptic communication. out: The synaptic output. post: The post-synaptic neuron group. name: str. The projection name. mode: Mode. The computing mode. """ def __init__( self, pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], syn: DynamicalSystem, comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): super().__init__(name=name, mode=mode) # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(syn, DynamicalSystem) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm self.syn = syn # delay initialization delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # output initialization post.add_inp_fun(self.name, out, label=out_label) # references self.refs = dict() # invisible to ``self.nodes()`` self.refs['pre'] = pre self.refs['post'] = post self.refs['out'] = out self.refs['delay'] = delay_cls # unify the access self.refs['syn'] = syn self.refs['comm'] = comm
[docs] def update(self): spk = self.refs['delay'].at(self.name) g = self.comm(self.syn(spk)) self.refs['out'].bind_cond(g) return g
pre = property(lambda self: self.refs['pre']) post = property(lambda self: self.refs['post']) delay = property(lambda self: self.refs['delay']) out = property(lambda self: self.refs['out'])