{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simulating 1-million-neuron networks with 1GB GPU memory\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/large_scale_modeling/EI_net_with_1m_neurons.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/large_scale_modeling/EI_net_with_1m_neurons.ipynb)\n",
    "\n",
    "This notebook simulates an excitatory/inhibitory (E/I) network of 1,000,000 leaky integrate-and-fire (LIF) neurons. The synapses use BrainPy's just-in-time connectivity operators (`brainpy.math.jitconn`), which generate the random connections from a seed during the simulation instead of storing a weight matrix; this is what keeps the GPU memory footprint small."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "import jax\n",
    "\n",
    "# Uncomment to disable GPU memory preallocation if you want to see the actual GPU memory usage:\n",
    "# bm.disable_gpu_memory_preallocation()\n",
    "\n",
    "# Compare versions numerically: a plain string comparison would wrongly\n",
    "# reject e.g. '2.10.0', because '1' < '4' character by character.\n",
    "_bp_version = tuple(int(part) for part in re.findall(r'\\d+', bp.__version__)[:3])\n",
    "assert _bp_version >= (2, 4, 1), f'brainpy>=2.4.1 is required, found {bp.__version__}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global integration time step (ms).\n",
    "bm.set(dt=0.4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exponential synapse with just-in-time connectivity\n",
    "\n",
    "The synaptic conductance obeys $\\frac{dg}{dt} = -\\frac{g}{\\tau}$ and is incremented by the presynaptic spikes, propagated through a random connectivity that `bm.jitconn` generates on the fly from a connection probability and a seed. The dictionaries below are example connectivity/weight specifications (`g_max_par`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connectivity/weight specifications accepted by `Exponential(g_max_par=...)`.\n",
    "# `prob` is the connection probability and `seed` seeds the just-in-time connectivity.\n",
    "_default_g_max = dict(type='homo', value=1., prob=0.1, seed=123)\n",
    "_default_uniform = dict(type='uniform', w_low=0.1, w_high=1., prob=0.1, seed=123)\n",
    "_default_normal = dict(type='normal', w_mu=0.1, w_sigma=1., prob=0.1, seed=123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Exponential(bp.synapses.TwoEndConn):\n",
    "    \"\"\"Exponential synapse whose connectivity is generated just in time.\n",
    "\n",
    "    The conductance ``g`` decays as ``dg/dt = -g / tau`` and is incremented at every\n",
    "    step by the presynaptic spikes propagated through a random connectivity produced\n",
    "    on the fly by ``bm.jitconn`` from ``g_max_par['prob']`` and ``g_max_par['seed']``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pre, post : bp.NeuGroup\n",
    "        Pre- and postsynaptic neuron groups.\n",
    "    output : bp.SynOut, optional\n",
    "        Synaptic output model applied to ``g``. Defaults to a new\n",
    "        ``bp.synouts.CUBA()`` for each instance.\n",
    "    g_max_par : dict\n",
    "        Connectivity/weight spec: ``type`` is ``'homo'`` (key ``value``),\n",
    "        ``'uniform'`` (keys ``w_low``, ``w_high``) or ``'normal'`` (keys ``w_mu``,\n",
    "        ``w_sigma``); every type also needs ``prob`` and ``seed``.\n",
    "    delay_step : int, optional\n",
    "        Delay (in time steps) applied to the presynaptic spikes.\n",
    "    tau : float\n",
    "        Decay time constant of ``g``.\n",
    "    method : str\n",
    "        Numerical integration method used for ``g``.\n",
    "    name, mode\n",
    "        Forwarded to ``bp.synapses.TwoEndConn``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``g_max_par['type']`` is not one of the supported types.\n",
    "    \"\"\"\n",
    "\n",
    "    _G_MAX_TYPES = ('homo', 'uniform', 'normal')\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        pre: bp.NeuGroup,\n",
    "        post: bp.NeuGroup,\n",
    "        output: bp.SynOut = None,\n",
    "        g_max_par=_default_g_max,\n",
    "        delay_step=None,\n",
    "        tau=8.0,\n",
    "        method: str = 'exp_auto',\n",
    "        name: str = None,\n",
    "        mode: bm.Mode = None,\n",
    "    ):\n",
    "        # Fail at construction time rather than at the first simulation step.\n",
    "        kind = g_max_par['type']\n",
    "        if kind not in self._G_MAX_TYPES:\n",
    "            raise ValueError(f'Unknown g_max_par type {kind!r}; expected one of {self._G_MAX_TYPES}.')\n",
    "        if output is None:\n",
    "            # Build a fresh output object per synapse instead of sharing one\n",
    "            # mutable default instance between all synapses.\n",
    "            output = bp.synouts.CUBA()\n",
    "        super().__init__(pre, post, None, output=output, name=name, mode=mode)\n",
    "        self.tau = tau\n",
    "        self.g_max_par = g_max_par\n",
    "        self.g = bp.init.variable_(bm.zeros, self.post.num, self.mode)\n",
    "        self.delay_step = self.register_delay(f'{self.pre.name}.spike', delay_step, self.pre.spike)\n",
    "        self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)\n",
    "\n",
    "    def reset_state(self, batch_size=None):\n",
    "        \"\"\"Reset the conductance ``g`` to zeros (optionally with a batch axis).\"\"\"\n",
    "        self.g.value = bp.init.variable_(bm.zeros, self.post.num, batch_size)\n",
    "\n",
    "    def _event_matvec(self, spikes):\n",
    "        \"\"\"Propagate one (unbatched) presynaptic spike vector to the postsynaptic neurons.\"\"\"\n",
    "        par = self.g_max_par\n",
    "        # Arguments shared by all jitconn operators.\n",
    "        common = dict(conn_prob=par['prob'],\n",
    "                      shape=(self.pre.num, self.post.num),\n",
    "                      seed=par['seed'],\n",
    "                      transpose=True)\n",
    "        kind = par['type']\n",
    "        if kind == 'homo':\n",
    "            return bm.jitconn.event_mv_prob_homo(spikes, par['value'], **common)\n",
    "        if kind == 'uniform':\n",
    "            return bm.jitconn.event_mv_prob_uniform(spikes, w_low=par['w_low'], w_high=par['w_high'], **common)\n",
    "        if kind == 'normal':\n",
    "            return bm.jitconn.event_mv_prob_normal(spikes, w_mu=par['w_mu'], w_sigma=par['w_sigma'], **common)\n",
    "        raise ValueError(f'Unknown g_max_par type {kind!r}; expected one of {self._G_MAX_TYPES}.')\n",
    "\n",
    "    def update(self):\n",
    "        \"\"\"Advance ``g`` by one step and return the synaptic output computed from it.\"\"\"\n",
    "        t = bp.share.load('t')\n",
    "        dt = bp.share.load('dt')\n",
    "        pre_spike = self.get_delay_data(f'{self.pre.name}.spike', self.delay_step)\n",
    "        matvec = self._event_matvec\n",
    "        if isinstance(self.mode, bm.BatchingMode):\n",
    "            # Batched spikes carry a leading batch axis: map the mat-vec over it.\n",
    "            matvec = jax.vmap(matvec)\n",
    "        post_vs = matvec(pre_spike)\n",
    "        self.g.value = self.integral(self.g.value, t, dt) + post_vs\n",
    "        return self.output(self.g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## E/I network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EINet(bp.DynSysGroup):\n",
    "    \"\"\"Randomly connected E/I network of LIF neurons.\n",
    "\n",
    "    ``scale=1`` gives 3200 excitatory and 800 inhibitory neurons. Population sizes\n",
    "    grow linearly with ``scale`` while synaptic weights are divided by ``scale``;\n",
    "    with the connection probability fixed at 0.02, the expected total synaptic\n",
    "    input per neuron therefore does not depend on ``scale``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, scale=1.0, method='exp_auto'):\n",
    "        super().__init__()\n",
    "        num_exc = int(3200 * scale)\n",
    "        num_inh = int(800 * scale)\n",
    "        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,\n",
    "                    V_initializer=bp.init.Normal(-55., 2.))\n",
    "        # One LIF population: the first `num_exc` neurons are excitatory, the rest inhibitory.\n",
    "        self.N = bp.neurons.LIF(num_exc + num_inh, **pars, method=method)\n",
    "        self.E = Exponential(self.N[:num_exc], self.N,\n",
    "                             g_max_par=dict(type='homo', value=0.6 / scale, prob=0.02, seed=123),\n",
    "                             tau=5., method=method, output=bp.synouts.COBA(E=0.))\n",
    "        self.I = Exponential(self.N[num_exc:], self.N,\n",
    "                             g_max_par=dict(type='homo', value=6.7 / scale, prob=0.02, seed=12345),\n",
    "                             tau=10., method=method, output=bp.synouts.COBA(E=-80.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "duration = 1e2  # simulated time (ms)\n",
    "net = EINet(scale=250)  # 250 * (3200 + 800) = 1,000,000 neurons"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6e54bb9bf3b4dfc813b942ac603c872",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       " 0%| | 0/250 [00:00"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Record the spikes of the whole population; `duration / dt` = 250 steps.\n",
    "runner = bp.DSRunner(net, monitors={'N.spike': net.N.spike})\n",
    "runner.run(duration);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bp.visualize.raster_plot(runner.mon.ts, runner.mon['N.spike'], show=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}