{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# *(Susin & Destexhe, 2021)*: Asynchronous Network\n", "\n", "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/oscillation_synchronization/Susin_Destexhe_2021_gamma_oscillation_AI.ipynb)\n", "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/oscillation_synchronization/Susin_Destexhe_2021_gamma_oscillation_AI.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of the paper:\n", "\n", "- Susin, Eduarda, and Alain Destexhe. \"Integration, coincidence detection and\n", " resonance in networks of spiking neurons expressing gamma oscillations and\n", " asynchronous states.\" PLoS computational biology 17.9 (2021): e1009416.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.864979700Z", "start_time": "2023-07-21T14:56:14.032870900Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from scipy.signal import kaiserord, lfilter, firwin, hilbert\n", "\n", "import brainpy as bp\n", "import brainpy.math as bm" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:57:06.248615900Z", "start_time": "2023-07-21T14:57:06.241284800Z" } }, "outputs": [ { "data": { "text/plain": [ "'2.4.3'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Table 1: specific neuron model parameters\n", "RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,\n", " E_e=0., E_i=-80.)\n", "FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,\n", " E_e=0., E_i=-80.)\n", "Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,\n", " E_e=0., E_i=-80.)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.898858400Z", "start_time": "2023-07-21T14:56:15.880843900Z" } }, "outputs": [], "source": [ "class AdEx(bp.dyn.NeuDyn):\n", " def __init__(\n", " self,\n", " size,\n", "\n", " # neuronal parameters\n", " Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,\n", " gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,\n", "\n", " # synaptic parameters\n", " tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,\n", "\n", " # other parameters\n", " name=None, method='exp_euler',\n", " V_initializer=bp.init.Uniform(-65, -50),\n", " w_initializer=bp.init.Constant(0.),\n", " ):\n", " super(AdEx, self).__init__(size=size, name=name)\n", "\n", " # neuronal parameters\n", " self.Vth = Vth\n", " self.delta = delta\n", " self.tau_ref = tau_ref\n", " self.tau_w = tau_w\n", " self.a = a\n", " self.b = b\n", " self.C = C\n", " self.gL = gL\n", " self.EL = EL\n", " self.V_reset = V_reset\n", " self.V_sp_th = V_sp_th\n", "\n", " # synaptic parameters\n", " self.tau_e = tau_e\n", " self.tau_i = tau_i\n", " self.E_e = E_e\n", " self.E_i = E_i\n", "\n", " # neuronal variables\n", " self.V = bp.init.variable_(V_initializer, self.num)\n", " self.w = bp.init.variable_(w_initializer, self.num)\n", " self.spike = bm.Variable(self.num, dtype=bool)\n", " self.refractory = bm.Variable(self.num, dtype=bool)\n", " self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)\n", "\n", " # synaptic parameters\n", " self.ge = bm.Variable(self.num)\n", " self.gi = bm.Variable(self.num)\n", "\n", " # integral\n", " self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)\n", "\n", " def dge(self, ge, t):\n", " return -ge / self.tau_e\n", "\n", " def dgi(self, gi, t):\n", " return -gi / self.tau_i\n", "\n", " def dV(self, V, t, w, ge, gi, Iext=None):\n", " I = ge * (self.E_e - V) + gi * (self.E_i - V)\n", " if Iext is not None: I += Iext\n", " dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)\n", " - w + self.gL * (self.EL - V) + I) / self.C\n", " return dVdt\n", "\n", " def dw(self, w, t, V):\n", " dwdt = (self.a * (V - self.EL) - w) / self.tau_w\n", " return dwdt\n", "\n", " def update(self, x=None):\n", " t = bp.share['t']\n", " dt = bp.share['dt']\n", " V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,\n", " t, Iext=x, dt=dt)\n", " refractory = (t - self.t_last_spike) <= self.tau_ref\n", " V = bm.where(refractory, self.V.value, V)\n", " spike = V >= self.V_sp_th\n", " self.V.value = bm.where(spike, self.V_reset, V)\n", " self.w.value = bm.where(spike, w + self.b, w)\n", " self.ge.value = ge\n", " self.gi.value = gi\n", " self.spike.value = spike\n", " self.refractory.value = bm.logical_or(refractory, spike)\n", " self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.911856400Z", "start_time": "2023-07-21T14:56:15.905858900Z" } }, "outputs": [], "source": [ "class AINet(bp.DynSysGroup):\n", " def __init__(self, ext_varied_rates, ext_weight=1., method='exp_euler', dt=bm.get_dt()):\n", " super(AINet, self).__init__()\n", "\n", " self.num_exc = 20000\n", " self.num_inh = 5000\n", " self.exc_syn_tau = 5. # ms\n", " self.inh_syn_tau = 5. # ms\n", " self.exc_syn_weight = 1. # nS\n", " self.inh_syn_weight = 5. # nS\n", " self.num_delay_step = int(1.5 / dt)\n", " self.ext_varied_rates = ext_varied_rates\n", "\n", " # neuronal populations\n", " RS_par_ = RS_par.copy()\n", " FS_par_ = FS_par.copy()\n", " RS_par_.update(Vth=-50, V_sp_th=-40)\n", " FS_par_.update(Vth=-50, V_sp_th=-40)\n", " self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)\n", " self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)\n", " self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1))\n", "\n", " # Poisson inputs\n", " self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='ge'),\n", " g_max=ext_weight)\n", " self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='ge'),\n", " g_max=ext_weight)\n", "\n", " # synaptic projections\n", " self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='ge'),\n", " g_max=self.exc_syn_weight,\n", " delay_step=self.num_delay_step)\n", " self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='ge'),\n", " g_max=self.exc_syn_weight,\n", " delay_step=self.num_delay_step)\n", " self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='gi'),\n", " g_max=self.inh_syn_weight,\n", " delay_step=self.num_delay_step)\n", " self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),\n", " output=bp.synouts.CUBA(target_var='gi'),\n", " g_max=self.inh_syn_weight,\n", " delay_step=self.num_delay_step)\n", "\n", " def change_freq(self):\n", " self.ext_pop.freqs[0] = self.ext_varied_rates[bp.share['i']]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.935059500Z", "start_time": "2023-07-21T14:56:15.912856100Z" } }, "outputs": [], "source": [ "def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):\n", " dt = bm.get_dt() if dt is None else dt\n", " t = 0\n", " num_gap = int(t_gap / dt)\n", " num_total = int(t_total / dt)\n", " num_transition = int(t_transition / dt)\n", "\n", " inputs = []\n", " ramp_up = np.linspace(c_low, c_high, num_transition)\n", " ramp_down = np.linspace(c_high, c_low, num_transition)\n", " plato_base = np.ones(num_gap) * c_low\n", " while t < num_total:\n", " num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)\n", " inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])\n", " t += (num_gap + num_transition + num_plato + num_transition)\n", " return bm.asarray(np.concatenate(inputs)[:num_total])\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.943060Z", "start_time": "2023-07-21T14:56:15.929062Z" } }, "outputs": [], "source": [ "def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):\n", " # sampling_space: in seconds (no units)\n", " # signal_time: in seconds (no units)\n", " # low_cut: in Hz (no units)(band to filter)\n", " # high_cut: in Hz (no units)(band to filter)\n", "\n", " signal = signal - np.mean(signal)\n", " width = 5.0 # The desired width in Hz of the transition from pass to stop\n", " ripple_db = 60.0 # The desired attenuation in the stop band, in dB.\n", " sampling_rate = 1. / sampling_space\n", " Nyquist = sampling_rate / 2.\n", "\n", " num_taps, beta = kaiserord(ripple_db, width / Nyquist)\n", " if num_taps % 2 == 0:\n", " num_taps = num_taps + 1 # Numtaps must be odd\n", " taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), nyq=1.0,\n", " pass_zero=False, scale=True)\n", " filtered_signal = lfilter(taps, 1.0, signal)\n", " delay = 0.5 * (num_taps - 1) / sampling_rate # To corrected to zero-phase\n", " delay_index = int(np.floor(delay * sampling_rate))\n", " filtered_signal = filtered_signal[num_taps - 1:] # taking out the \"corrupted\" signal\n", " # correcting the delay and taking out the \"corrupted\" signal part\n", " filtered_time = signal_time[num_taps - 1:] - delay\n", " cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]\n", "\n", " # --------------------------------------------------------------------------\n", " # The hilbert transform are very slow when the signal has odd lenght,\n", " # This part check if the length is odd, and if this is the case it adds a zero in the end\n", " # of all the vectors related to the filtered Signal:\n", " if len(filtered_signal) % 2 != 0: # If the lengh is odd\n", " tmp1 = filtered_signal.tolist()\n", " tmp1.append(0)\n", " tmp2 = filtered_time.tolist()\n", " tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])\n", " tmp3 = cutted_signal.tolist()\n", " tmp3.append(0)\n", " filtered_signal = np.asarray(tmp1)\n", " filtered_time = np.asarray(tmp2)\n", " cutted_signal = np.asarray(tmp3)\n", " # --------------------------------------------------------------------------\n", "\n", " ht_filtered_signal = hilbert(filtered_signal)\n", " envelope = np.abs(ht_filtered_signal)\n", " phase = np.angle(ht_filtered_signal) # The phase is between -pi and pi in radians\n", "\n", " return filtered_time, filtered_signal, cutted_signal, envelope, phase\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.961064200Z", "start_time": "2023-07-21T14:56:15.944061600Z" } }, "outputs": [], "source": [ "def visualize_simulation_results(times, spikes, example_potentials, varied_rates,\n", " xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):\n", " fig, gs = bp.visualize.get_figure(7, 1, 1, 12)\n", " # 1. input firing rate\n", " ax = fig.add_subplot(gs[0])\n", " plt.plot(times, varied_rates)\n", " if xlim is None:\n", " xlim = (0, times[-1])\n", " ax.set_xlim(*xlim)\n", " ax.set_xticks([])\n", " ax.set_ylabel('External\\nRate (Hz)')\n", "\n", " # 2. inhibitory cell rater plot\n", " ax = fig.add_subplot(gs[1: 3])\n", " i = 0\n", " y_ticks = ([], [])\n", " for key, (sp_matrix, sp_type) in spikes.items():\n", " iis, sps = np.where(sp_matrix)\n", " tts = times[iis]\n", " plt.plot(tts, sps + i, '.', markersize=1, label=key)\n", " y_ticks[0].append(i + sp_matrix.shape[1] / 2)\n", " y_ticks[1].append(key)\n", " i += sp_matrix.shape[1]\n", " ax.set_xlim(*xlim)\n", " ax.set_xlabel('')\n", " ax.set_ylabel('Neuron Index')\n", " ax.set_xticks([])\n", " ax.set_yticks(*y_ticks)\n", " # ax.legend()\n", "\n", " # 3. example membrane potential\n", " ax = fig.add_subplot(gs[3: 5])\n", " for key, potential in example_potentials.items():\n", " vs = np.where(spikes[key][0][:, 0], 0, potential)\n", " plt.plot(times, vs, label=key)\n", " ax.set_xlim(*xlim)\n", " ax.set_xticks([])\n", " ax.set_ylabel('V (mV)')\n", " ax.legend()\n", "\n", " # 4. LFP\n", " ax = fig.add_subplot(gs[5:7])\n", " ax.set_xlim(*xlim)\n", " t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0\n", " t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)\n", " times = times[t1: t2]\n", " lfp = 0\n", " for sp_matrix, sp_type in spikes.values():\n", " lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)\n", " phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,\n", " bm.get_dt() * 1e-3)\n", " plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')\n", " plt.plot(phase_ts * 1e3, filtered, color='orange', label=\"Filtered LFP (30-50 Hz)\")\n", " plt.plot(phase_ts * 1e3, envelope, color='purple', label=\"Hilbert Envelope\")\n", " plt.legend(loc='best')\n", " plt.xlabel('Time (ms)')\n", "\n", " # save or show\n", " if filename:\n", " plt.savefig(filename, dpi=500)\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:15.974061100Z", "start_time": "2023-07-21T14:56:15.962061Z" } }, "outputs": [], "source": [ "def simulate_ai_net():\n", " duration = 2e3\n", " varied_rates = get_inputs(2., 2., 50., 150, 600, 1e3, duration)\n", "\n", " net = AINet(varied_rates, ext_weight=1.)\n", " runner = bp.DSRunner(\n", " net,\n", " inputs=net.change_freq,\n", " monitors={'FS.V0': lambda: net.fs_pop.V[0],\n", " 'RS.V0': lambda: net.rs_pop.V[0],\n", " 'FS.spike': lambda: net.fs_pop.spike,\n", " 'RS.spike': lambda: net.rs_pop.spike}\n", " )\n", " runner.run(duration)\n", "\n", " visualize_simulation_results(times=runner.mon.ts,\n", " spikes={'FS': (runner.mon['FS.spike'], 'inh'),\n", " 'RS': (runner.mon['RS.spike'], 'exc')},\n", " example_potentials={'FS': runner.mon['FS.V0'],\n", " 'RS': runner.mon['RS.V0']},\n", " varied_rates=varied_rates.to_numpy())\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-07-21T14:56:50.189660500Z", "start_time": "2023-07-21T14:56:15.975063Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2301e0487d6141e09241873234a03236", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "simulate_ai_net()" ] } ], "metadata": { "kernelspec": { "display_name": "brainpy", "language": "python", "name": "brainpy" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 1 }