{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# *(Joglekar et al., 2018)*: Inter-areal Balanced Amplification Figure 5\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/large_scale_modeling/Joglekar_2018_InterAreal_Balanced_Amplification_figure5.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/large_scale_modeling/Joglekar_2018_InterAreal_Balanced_Amplification_figure5.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of figure 5 of:\n",
    "\n",
    "- Joglekar, Madhura R., et al. \"Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.\" Neuron 98.1 (2018): 222-234.\n",
    "\n",
    "Each of the 29 cortical areas contains 1600 excitatory and 400 inhibitory LIF neurons. A 150 ms current pulse is injected into the excitatory population of V1, and the excitatory spike raster of all areas is plotted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "from brainpy import neurons\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from jax import vmap\n",
    "from scipy.io import loadmat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This model should be run on a GPU device\n",
    "\n",
    "bm.set_platform('gpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "bp.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiAreaNet(bp.Network):\n",
    "    \"\"\"Spiking multi-area cortical network (Joglekar et al., 2018).\n",
    "\n",
    "    Every area contains an excitatory (E) and an inhibitory (I) LIF population.\n",
    "    E->E and E->I projections between areas are scaled by the FLN matrix ``conn``\n",
    "    and by the hierarchy value of the target area, and are delivered with delays\n",
    "    taken from ``delay_mat``. Inhibition acts only within an area.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    hier : array, shape (num_area,)\n",
    "        Normalized hierarchy value of each area.\n",
    "    conn : array, shape (num_area, num_area)\n",
    "        FLN matrix; ``conn[i, j]`` is the strength from area j to area i.\n",
    "    delay_mat : array, shape (num_area, num_area)\n",
    "        Inter-areal conduction delays (divided by ``bm.get_dt()`` to get steps).\n",
    "    muIE, muEE : float\n",
    "        Scales of the inter-areal E->I and E->E projections.\n",
    "    wII, wEE, wIE, wEI : float\n",
    "        Local I->I, E->E, E->I and I->E coupling strengths.\n",
    "    extE, extI : float\n",
    "        Constant background input to the E and I populations.\n",
    "    alpha : float\n",
    "        Gain of the hierarchy-dependent scaling of excitatory weights.\n",
    "    seed : int, optional\n",
    "        If given, ``bm.random`` is seeded before any random state is drawn.\n",
    "    num_E, num_I : int\n",
    "        Number of E and I neurons per area.\n",
    "    conn_prob : float\n",
    "        Connection probability of every random connectivity matrix.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, hier, conn, delay_mat, muIE=0.0475, muEE=.0375, wII=.075,\n",
    "        wEE=.01, wIE=.075, wEI=.0375, extE=15.4, extI=14.0, alpha=4., seed=None,\n",
    "        num_E=1600, num_I=400, conn_prob=0.1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        # `seed` used to be accepted but ignored; honour it so that the random\n",
    "        # initial voltages and connectivity are reproducible.\n",
    "        if seed is not None:\n",
    "            bm.random.seed(seed)\n",
    "\n",
    "        # data\n",
    "        self.hier = hier\n",
    "        self.conn = conn\n",
    "        self.delay_mat = delay_mat\n",
    "\n",
    "        # parameters\n",
    "        self.muIE = muIE\n",
    "        self.muEE = muEE\n",
    "        self.wII = wII\n",
    "        self.wEE = wEE\n",
    "        self.wIE = wIE\n",
    "        self.wEI = wEI\n",
    "        self.extE = extE\n",
    "        self.extI = extI\n",
    "        self.alpha = alpha\n",
    "        num_area = hier.size\n",
    "        self.num_area = num_area\n",
    "        self.num_E = num_E\n",
    "        self.num_I = num_I\n",
    "\n",
    "        # neuron models: one row of neurons per area (keep_size=True)\n",
    "        self.E = neurons.LIF((num_area, num_E),\n",
    "                             V_th=-50., V_reset=-60.,\n",
    "                             V_rest=-70., tau=20., tau_ref=2.,\n",
    "                             noise=3. / bm.sqrt(20.),\n",
    "                             V_initializer=bp.init.Uniform(-70., -50.),\n",
    "                             method='exp_auto',\n",
    "                             keep_size=True,\n",
    "                             ref_var=True)\n",
    "        self.I = neurons.LIF((num_area, num_I),\n",
    "                             V_th=-50., V_reset=-60.,\n",
    "                             V_rest=-70., tau=10., tau_ref=2.,\n",
    "                             noise=3. / bm.sqrt(10.),\n",
    "                             V_initializer=bp.init.Uniform(-70., -50.),\n",
    "                             method='exp_auto',\n",
    "                             keep_size=True,\n",
    "                             ref_var=True)\n",
    "\n",
    "        # delays, converted to integer simulation steps; within-area delay is 2.0\n",
    "        self.intra_delay_step = int(2. / bm.get_dt())\n",
    "        self.E_delay_steps = bm.asarray(delay_mat.T / bm.get_dt(), dtype=int)\n",
    "        bm.fill_diagonal(self.E_delay_steps, self.intra_delay_step)\n",
    "        self.Edelay = bm.LengthDelay(self.E.spike, delay_len=int(self.E_delay_steps.max()))\n",
    "        self.Idelay = bm.LengthDelay(self.I.spike, delay_len=self.intra_delay_step)\n",
    "\n",
    "        # synapse model: weighted sum of presynaptic spikes through a boolean connection matrix\n",
    "        syn_fun = lambda pre_spike, weight, conn_mat: weight * (pre_spike @ conn_mat)\n",
    "        # E currents: spikes, weights and connections are all mapped over their leading (area) axis\n",
    "        self.f_E_current = vmap(syn_fun)\n",
    "        # I currents: one shared scalar weight, per-area spikes and connections\n",
    "        self.f_I_current = vmap(syn_fun, in_axes=(0, None, 0))\n",
    "\n",
    "        # synapses from I (local only); drawn before the E synapses to keep the RNG order\n",
    "        self.intra_I2E_conn = bm.random.random((num_area, num_I, num_E)) < conn_prob\n",
    "        self.intra_I2I_conn = bm.random.random((num_area, num_I, num_I)) < conn_prob\n",
    "        self.intra_I2E_weight = -wEI\n",
    "        self.intra_I2I_weight = -wII\n",
    "\n",
    "        # synapses from E: list entry i is indexed by the loop variable i in update()\n",
    "        self.E2E_conns = [bm.random.random((num_area, num_E, num_E)) < conn_prob for _ in range(num_area)]\n",
    "        self.E2I_conns = [bm.random.random((num_area, num_E, num_I)) < conn_prob for _ in range(num_area)]\n",
    "        self.E2E_weights = (1 + alpha * hier) * muEE * conn.T  # inter-area connections\n",
    "        bm.fill_diagonal(self.E2E_weights, (1 + alpha * hier) * wEE)  # intra-area connections\n",
    "        self.E2I_weights = (1 + alpha * hier) * muIE * conn.T  # inter-area connections\n",
    "        bm.fill_diagonal(self.E2I_weights, (1 + alpha * hier) * wIE)  # intra-area connections\n",
    "\n",
    "    def update(self, v1_input):\n",
    "        \"\"\"Advance the network by one simulation step.\n",
    "\n",
    "        ``v1_input`` is added to the E population of area 0 (V1); all\n",
    "        populations also receive their constant background input.\n",
    "        \"\"\"\n",
    "        self.E.input[0] += v1_input\n",
    "        self.E.input += self.extE\n",
    "        self.I.input += self.extI\n",
    "        # synaptic kicks are applied only to neurons that are not refractory\n",
    "        E_not_ref = bm.logical_not(self.E.refractory)\n",
    "        I_not_ref = bm.logical_not(self.I.refractory)\n",
    "\n",
    "        # synapses from E (E_delay_steps[i] holds the delays between area i and every area)\n",
    "        for i in range(self.num_area):\n",
    "            delayed_E_spikes = self.Edelay(self.E_delay_steps[i], i).astype(float)\n",
    "            current = self.f_E_current(delayed_E_spikes, self.E2E_weights[i], self.E2E_conns[i])\n",
    "            self.E.V += current * E_not_ref  # E2E\n",
    "            current = self.f_E_current(delayed_E_spikes, self.E2I_weights[i], self.E2I_conns[i])\n",
    "            self.I.V += current * I_not_ref  # E2I\n",
    "\n",
    "        # synapses from I\n",
    "        delayed_I_spikes = self.Idelay(self.intra_delay_step).astype(float)\n",
    "        current = self.f_I_current(delayed_I_spikes, self.intra_I2E_weight, self.intra_I2E_conn)\n",
    "        self.E.V += current * E_not_ref  # I2E\n",
    "        current = self.f_I_current(delayed_I_spikes, self.intra_I2I_weight, self.intra_I2I_conn)\n",
    "        self.I.V += current * I_not_ref  # I2I\n",
    "\n",
    "        # updates: push the current spikes into the delay buffers, then integrate the neurons\n",
    "        self.Edelay.update(self.E.spike)\n",
    "        self.Idelay.update(self.I.spike)\n",
    "        self.E.update()\n",
    "        self.I.update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def raster_plot(xValues, yValues, duration, num_exc=1600, areas=None):\n",
    "    \"\"\"Raster plot of excitatory spikes, one horizontal band per area.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xValues : array\n",
    "        Spike times.\n",
    "    yValues : array\n",
    "        Flat neuron indices; index k is drawn in band k // num_exc\n",
    "        (assumes the monitor flattened (num_area, num_exc) spikes row-major).\n",
    "    duration : float\n",
    "        Right limit of the time axis.\n",
    "    num_exc : int\n",
    "        Excitatory neurons per area (the E population size of MultiAreaNet).\n",
    "    areas : sequence of str, optional\n",
    "        Area labels from bottom to top; defaults to the 29 areas of the model.\n",
    "    \"\"\"\n",
    "    if areas is None:\n",
    "        areas = ['V1', 'V2', 'V4', 'DP', 'MT', '8m', '5', '8l', 'TEO', '2', 'F1',\n",
    "                 'STPc', '7A', '46d', '10', '9/46v', '9/46d', 'F5', 'TEpd', 'PBr',\n",
    "                 '7m', '7B', 'F2', 'STPi', 'PROm', 'F7', '8B', 'STPr', '24c']\n",
    "    num_area = len(areas)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    ax.plot(xValues, np.asarray(yValues) / num_exc, '.', markersize=1)\n",
    "    # black separators between areas at y = 0, 1, ..., num_area\n",
    "    ax.hlines(np.arange(num_area + 1), 0, duration, colors='k')\n",
    "    ax.set_xlabel('Time [ms]')\n",
    "    ax.set_ylabel('Area')\n",
    "    ax.set_xlim(0, duration)\n",
    "    ax.set_ylim(0, num_area)\n",
    "    # put each area label at the centre of its band\n",
    "    ax.set_yticks(np.arange(num_area) + 0.5)\n",
    "    ax.set_yticklabels(areas)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# directory holding the .mat files shipped with this example\n",
    "data_dir = 'Joglekar_2018_data'\n",
    "\n",
    "# hierarchy values\n",
    "hierVals = loadmat(f'{data_dir}/hierValspython.mat')\n",
    "hierValsnew = hierVals['hierVals'].flatten()\n",
    "hier = bm.asarray(hierValsnew / hierValsnew.max())  # hierarchy normalized to a maximum of 1\n",
    "\n",
    "# fraction of labeled neurons (FLN)\n",
    "flnMatp = loadmat(f'{data_dir}/efelenMatpython.mat')\n",
    "conn = bm.asarray(flnMatp['flnMatpython'].squeeze())  # fln values; Cij is strength from j to i\n",
    "\n",
    "# distances -> conduction delays\n",
    "speed = 3.5  # axonal conduction velocity\n",
    "distMatp = loadmat(f'{data_dir}/subgraphWiring29.mat')\n",
    "distMat = distMatp['wiring'].squeeze()  # distances between areas\n",
    "delayMat = bm.asarray(distMat / speed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "pars = dict(extE=14.2, extI=14.7, wII=.075, wEE=.01, wIE=.075, wEI=.0375, muEE=.0375, muIE=0.0475)\n",
    "inps = dict(value=15, duration=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs, length = bp.inputs.section_input(values=[0, inps['value'], 0.],\n",
    "                                         durations=[300., inps['duration'], 500],\n",
    "                                         return_length=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6751d77a59ed4f6cb4b89a10c571372c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       " 0%| | 0/9500 [00:00"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): `runner` is not defined in any cell of this notebook. The cell that\n",
    "# builds the network and runs the simulation (monitoring 'E.spike') appears to be\n",
    "# missing and must be restored before this cell can run.\n",
    "times, indices = np.where(runner.mon['E.spike'])\n",
    "times = runner.mon.ts[times]\n",
    "raster_plot(times, indices, length)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}