Source code for brainpy._src.dyn.projections.plasticity

from typing import Optional, Callable, Union

from brainpy import math as bm, check
from brainpy._src.delay import register_delay_by_return
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay,
                                BindCondData, AlignPost, SupportSTDP)
from brainpy.types import ArrayType
from .align_post import (align_post_add_bef_update, )
from .align_pre import (align_pre2_add_bef_update, )
from .utils import (_get_return, )

# Public API of this module: only the STDP projection is exported by ``import *``.
__all__ = ['STDP_Song2000']


def _init_trace_by_align_pre2(
    target: DynamicalSystem,
    delay: Union[None, int, float],
    syn: ParamDescriber[DynamicalSystem],
):
  """Build a trace of ``target``'s output as an align-pre synapse.

  The synapse described by ``syn`` is attached to the delay buffer returned by
  ``register_delay_by_return(target)`` via ``align_pre2_add_bef_update``, read
  at ``delay``. Whether an already-registered synapse is reused is decided by
  ``align_pre2_add_bef_update`` (presumably keyed by descriptor and delay --
  confirm in ``align_pre``).

  Args:
    target: DynamicalSystem. The group whose output is traced.
    delay: None, int, float. The delay at which the target output is read.
    syn: ParamDescriber[DynamicalSystem]. Descriptor of the trace synapse.

  Returns:
    The synapse object produced by ``align_pre2_add_bef_update``.
  """
  # Fail fast on wrong argument types before touching any delay registry.
  check.is_instance(target, DynamicalSystem)
  check.is_instance(syn, ParamDescriber[DynamicalSystem])
  target_delay = register_delay_by_return(target)
  return align_pre2_add_bef_update(syn, delay, target_delay)


class STDP_Song2000(Projection):
  r"""Spike-timing-dependent plasticity proposed by (Song, et al., 2000).

  This model filters the synaptic currents according to the variable :math:`w`.

  .. math::

     I_{syn}^+(t) = I_{syn}^-(t) * w

  where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents
  before and after STDP filtering, and :math:`w` measures synaptic efficacy:
  each time a presynaptic neuron emits a pulse, the synaptic conductance
  increases by :math:`w`.

  The dynamics of :math:`w` are governed by the following equations:

  .. math::

     \begin{aligned}
     \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}^{pre}) + A_{pre}\delta(t-t_{sp}^{post}), \\
     \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}^{pre}), \\
     \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}^{post}), \\
     \end{aligned}

  where :math:`t_{sp}^{pre}` and :math:`t_{sp}^{post}` denote pre- and
  post-synaptic spike times. :math:`A_1` is the increment of :math:`A_{pre}`
  and :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
  Thus a pre-synaptic spike depresses :math:`w` by the post-synaptic trace,
  and a post-synaptic spike potentiates :math:`w` by the pre-synaptic trace.

  In the implementation, both traces are ``Expon`` synapses with time constants
  ``tau_s`` and ``tau_t``. ``A1`` and ``A2`` are applied as multiplicative
  factors on those traces when the weights are updated.

  Here is an example of the usage of this class::

    import brainpy as bp
    import brainpy.math as bm

    class STDPNet(bp.DynamicalSystem):
      def __init__(self, num_pre, num_post):
        super().__init__()
        self.pre = bp.dyn.LifRef(num_pre)
        self.post = bp.dyn.LifRef(num_post)
        self.syn = bp.dyn.STDP_Song2000(
          pre=self.pre,
          delay=1.,
          comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
                                     weight=bp.init.Uniform(max_val=0.1)),
          syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
          out=bp.dyn.COBA.desc(E=0.),
          post=self.post,
          tau_s=16.8,
          tau_t=33.7,
          A1=0.96,
          A2=0.53,
        )

      def update(self, I_pre, I_post):
        self.syn()
        self.pre(I_pre)
        self.post(I_post)
        conductance = self.syn.refs['syn'].g
        Apre = self.syn.refs['pre_trace'].g
        Apost = self.syn.refs['post_trace'].g
        current = self.post.sum_inputs(self.post.V)
        return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight

    duration = 300.
    I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
                                    [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
    I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
                                     [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])

    net = STDPNet(1, 1)
    def run(i, I_pre, I_post):
      pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
      return pre_spike, post_spike, g, Apre, Apost, current, W

    indices = bm.arange(0, duration, bm.dt)
    pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])

  Args:
    tau_s: float. The time constant of :math:`A_{pre}`.
    tau_t: float. The time constant of :math:`A_{post}`.
    A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
    A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
    W_max: float. The maximum weight.
    W_min: float. The minimum weight.
    pre: DynamicalSystem. The pre-synaptic neuron group (must support auto delay).
    delay: int, float. The pre spike delay length. (ms)
    syn: ParamDescriber[DynamicalSystem]. Descriptor of the synapse model.
    comm: DynamicalSystem. The communication model, for example, dense or sparse
      connection layers. Must support STDP weight updates (``SupportSTDP``).
    out: ParamDescriber[DynamicalSystem]. Descriptor of the synaptic current output model.
    post: DynamicalSystem. The post-synaptic neuron group. Must expose a ``spike`` attribute.
    out_label: str. The output label.
    name: str. The model name.
    mode: Mode. The computing mode, forwarded to the base class.
  """

  def __init__(
      self,
      pre: JointType[DynamicalSystem, SupportAutoDelay],
      delay: Union[None, int, float],
      syn: ParamDescriber[DynamicalSystem],
      comm: JointType[DynamicalSystem, SupportSTDP],
      out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
      post: DynamicalSystem,
      # synapse parameters
      tau_s: Union[float, ArrayType, Callable] = 16.8,
      tau_t: Union[float, ArrayType, Callable] = 33.7,
      A1: Union[float, ArrayType, Callable] = 0.96,
      A2: Union[float, ArrayType, Callable] = 0.53,
      W_max: Optional[float] = None,
      W_min: Optional[float] = None,
      # others
      out_label: Optional[str] = None,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    super().__init__(name=name, mode=mode)

    # Validate every collaborator before any registration side effect happens.
    check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
    check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
    check.is_instance(syn, ParamDescriber[DynamicalSystem])
    check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
    check.is_instance(post, DynamicalSystem)

    self.pre_num = pre.num
    self.post_num = post.num
    self.comm = comm
    # AlignPost synapses receive "spikes @ weights"; otherwise the pre-side
    # synapse variable is projected through ``comm`` (see ``update``).
    self._is_align_post = issubclass(syn.cls, AlignPost)

    # Register this projection as a reader of the pre-synaptic delay buffer.
    pre_delay = register_delay_by_return(pre)
    pre_delay.register_entry(self.name, delay)

    # Build (or obtain) the synapse and output objects for the chosen alignment.
    if not self._is_align_post:
      syn_obj = align_pre2_add_bef_update(syn, delay, pre_delay, self.name + '-pre')
      out_obj = out()
      post.add_inp_fun(self.name, out_obj, label=out_label)
    else:
      syn_obj, out_obj = align_post_add_bef_update(
        out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name
      )

    # Objects kept in this plain dict are intended to stay invisible to ``self.nodes()``.
    self.refs = dict(pre=pre, post=post)
    refs = self.refs
    refs['delay'] = pre_delay
    refs['syn'] = syn_obj
    refs['out'] = out_obj
    refs['comm'] = comm

    # Exponential traces: the pre trace is driven at ``delay``, the post trace without delay.
    refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))
    refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))

    # Plasticity parameters.
    self.W_max, self.W_min = W_max, W_min
    self.tau_s, self.tau_t = tau_s, tau_t
    self.A1, self.A2 = A1, A2

  @property
  def pre(self):
    """The pre-synaptic neuron group."""
    return self.refs['pre']

  @property
  def post(self):
    """The post-synaptic neuron group."""
    return self.refs['post']

  @property
  def syn(self):
    """The synapse object created for this projection."""
    return self.refs['syn']

  @property
  def delay(self):
    """The delay object registered on the pre-synaptic group."""
    return self.refs['delay']

  @property
  def out(self):
    """The synaptic output object."""
    return self.refs['out']

  def update(self):
    """Apply one STDP weight update and deliver the resulting synaptic current.

    Returns:
      The output of ``self.comm`` applied to either the delayed pre-synaptic
      spikes (AlignPost) or the pre-synaptic synapse variable (AlignPre).

    Raises:
      AttributeError: if the post-synaptic group has no ``spike`` attribute.
    """
    refs = self.refs

    # Pre-synaptic spikes, read through this projection's delay entry.
    spk_pre = refs['delay'].at(self.name)
    if self._is_align_post:
      # AlignPost: "pre spikes @ comm matrix" yields the post-synaptic conductance.
      comm_input = spk_pre
    else:
      # AlignPre: "pre synapse variable @ comm matrix" yields the post conductance.
      comm_input = _get_return(refs['syn'].return_info())

    post_group = refs['post']
    if not hasattr(post_group, 'spike'):
      raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
    spk_post = post_group.spike

    # Depression: at pre-synaptic spikes, add -A2 * (post trace) to the weights.
    trace_post = refs['post_trace'].g
    self.comm.stdp_update(on_pre={"spike": spk_pre, "trace": -trace_post * self.A2},
                          w_min=self.W_min, w_max=self.W_max)
    # Potentiation: at post-synaptic spikes, add A1 * (pre trace) to the weights.
    trace_pre = refs['pre_trace'].g
    self.comm.stdp_update(on_post={"spike": spk_post, "trace": trace_pre * self.A1},
                          w_min=self.W_min, w_max=self.W_max)

    # Route the projected current to the synapse (AlignPost) or the output (AlignPre).
    syn_current = self.comm(comm_input)
    if self._is_align_post:
      refs['syn'].add_current(syn_current)
    else:
      refs['out'].bind_cond(syn_current)
    return syn_current
# class PairedSTDP(Projection): # r"""Paired spike-time-dependent plasticity model. # # This model filters the synaptic currents according to the variables: :math:`w`. # # .. math:: # # I_{syn}^+(t) = I_{syn}^-(t) * w # # where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before # and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse, # the conductance of the synapse will increase w. # # The dynamics of :math:`w` is governed by the following equation: # # .. math:: # # \begin{aligned} # \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ # \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\ # \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\ # \end{aligned} # # where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment # of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike. # # Here is an example of the usage of this class:: # # import brainpy as bp # import brainpy.math as bm # # class STDPNet(bp.DynamicalSystem): # def __init__(self, num_pre, num_post): # super().__init__() # self.pre = bp.dyn.LifRef(num_pre) # self.post = bp.dyn.LifRef(num_post) # self.syn = bp.dyn.STDP_Song2000( # pre=self.pre, # delay=1., # comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), # weight=bp.init.Uniform(max_val=0.1)), # syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), # out=bp.dyn.COBA.desc(E=0.), # post=self.post, # tau_s=16.8, # tau_t=33.7, # A1=0.96, # A2=0.53, # ) # # def update(self, I_pre, I_post): # self.syn() # self.pre(I_pre) # self.post(I_post) # conductance = self.syn.refs['syn'].g # Apre = self.syn.refs['pre_trace'].g # Apost = self.syn.refs['post_trace'].g # current = self.post.sum_inputs(self.post.V) # return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight # # duration = 300. 
# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], # [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) # I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], # [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) # # net = STDPNet(1, 1) # def run(i, I_pre, I_post): # pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) # return pre_spike, post_spike, g, Apre, Apost, current, W # # indices = bm.arange(0, duration, bm.dt) # pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) # # Args: # tau_s: float. The time constant of :math:`A_{pre}`. # tau_t: float. The time constant of :math:`A_{post}`. # A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. # A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. # W_max: float. The maximum weight. # W_min: float. The minimum weight. # pre: DynamicalSystem. The pre-synaptic neuron group. # delay: int, float. The pre spike delay length. (ms) # syn: DynamicalSystem. The synapse model. # comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. # out: DynamicalSystem. The synaptic current output models. # post: DynamicalSystem. The post-synaptic neuron group. # out_label: str. The output label. # name: str. The model name. 
# """ # # def __init__( # self, # pre: JointType[DynamicalSystem, SupportAutoDelay], # delay: Union[None, int, float], # syn: ParamDescriber[DynamicalSystem], # comm: JointType[DynamicalSystem, SupportSTDP], # out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], # post: DynamicalSystem, # # synapse parameters # tau_s: float = 16.8, # tau_t: float = 33.7, # lambda_: float = 0.96, # alpha: float = 0.53, # mu: float = 0.53, # W_max: Optional[float] = None, # W_min: Optional[float] = None, # # others # out_label: Optional[str] = None, # name: Optional[str] = None, # mode: Optional[bm.Mode] = None, # ): # super().__init__(name=name, mode=mode) # # # synaptic models # check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) # check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) # check.is_instance(syn, ParamDescriber[DynamicalSystem]) # check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) # check.is_instance(post, DynamicalSystem) # self.pre_num = pre.num # self.post_num = post.num # self.comm = comm # self._is_align_post = issubclass(syn.cls, AlignPost) # # # delay initialization # delay_cls = register_delay_by_return(pre) # delay_cls.register_entry(self.name, delay) # # # synapse and output initialization # if self._is_align_post: # syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, # proj_name=self.name) # else: # syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') # out_cls = out() # add_inp_fun(out_label, self.name, out_cls, post) # # # references # self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` # self.refs['delay'] = delay_cls # self.refs['syn'] = syn_cls # invisible to ``self.node()`` # self.refs['out'] = out_cls # invisible to ``self.node()`` # self.refs['comm'] = comm # # # tracing pre-synaptic spikes using Exponential model # self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, 
tau=tau_s)) # # # tracing post-synaptic spikes using Exponential model # self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) # # # synapse parameters # self.W_max = W_max # self.W_min = W_min # self.tau_s = tau_s # self.tau_t = tau_t # self.A1 = A1 # self.A2 = A2 # # def update(self): # # pre-synaptic spikes # pre_spike = self.refs['delay'].at(self.name) # spike # # pre-synaptic variables # if self._is_align_post: # # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance # x = pre_spike # else: # # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance # x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable # # # post spikes # if not hasattr(self.refs['post'], 'spike'): # raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') # post_spike = self.refs['post'].spike # # # weight updates # Apost = self.refs['post_trace'].g # self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max) # Apre = self.refs['pre_trace'].g # self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max) # # # synaptic currents # current = self.comm(x) # if self._is_align_post: # self.refs['syn'].add_current(current) # synapse post current # else: # self.refs['out'].bind_cond(current) # align pre # return current