{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# *(Yang, 2020)*: Dynamical system analysis for RNN\n", "\n", "[](https://colab.research.google.com/github/brainpy/examples/blob/main/recurrent_networks/Yang_2020_RNN_Analysis.ipynb)\n", "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/recurrent_networks/Yang_2020_RNN_Analysis.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of the paper:\n", "\n", "- Yang G R, Wang X J. Artificial neural networks for neuroscientists: A primer[J]. Neuron, 2020, 107(6): 1048-1070.\n", "\n", "The original implementation is based on PyTorch: https://github.com/gyyang/nn-brain/blob/master/RNN%2BDynamicalSystemAnalysis.ipynb" ] }, { "cell_type": "code", "execution_count": 161, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.819433800Z", "start_time": "2023-07-22T10:48:44.772536600Z" } }, "outputs": [], "source": [ "import brainpy as bp\n", "import brainpy.math as bm\n", "import brainpy_datasets as bd\n", "\n", "bp.math.set_platform('cpu')" ] }, { "cell_type": "code", "execution_count": 162, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.872101900Z", "start_time": "2023-07-22T10:48:44.787584100Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'2.4.3'" ] }, "execution_count": 162, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bp.__version__" ] }, { "cell_type": "code", "execution_count": 163, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.873659300Z", "start_time": "2023-07-22T10:48:44.803836200Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'0.0.0.6'" ] }, "execution_count": 163, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bd.__version__" ] }, { "cell_type": "code", "execution_count": 164, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.888349800Z", "start_time": "2023-07-22T10:48:44.819433800Z" }, 
"lines_to_next_cell": 2 }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.decomposition import PCA" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we will use supervised learning to train a recurrent neural network on a simple perceptual decision making task, and analyze the trained network using dynamical system analysis." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining a cognitive task" ] }, { "cell_type": "code", "execution_count": 165, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.932902Z", "start_time": "2023-07-22T10:48:44.835824500Z" }, "collapsed": false }, "outputs": [], "source": [ "dataset = bd.cognitive.RatePerceptualDecisionMaking()\n", "task = bd.cognitive.TaskLoader(dataset, batch_size=16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define a vanilla continuous-time recurrent network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we will define a continuous-time neural network but discretize it in time using the Euler method.\n", "\n", "\\begin{align}\n", " \\tau \\frac{d\\mathbf{r}}{dt} = -\\mathbf{r}(t) + f(W_r \\mathbf{r}(t) + W_x \\mathbf{x}(t) + \\mathbf{b}_r).\n", "\\end{align}\n", "\n", "This continuous-time system can then be discretized using the Euler method with a time step of $\\Delta t$, \n", "\n", "\\begin{align}\n", " \\mathbf{r}(t+\\Delta t) = \\mathbf{r}(t) + \\Delta \\mathbf{r} = \\mathbf{r}(t) + \\frac{\\Delta t}{\\tau}[-\\mathbf{r}(t) + f(W_r \\mathbf{r}(t) + W_x \\mathbf{x}(t) + \\mathbf{b}_r)].\n", "\\end{align}" ] }, { "cell_type": "code", "execution_count": 166, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.933935600Z", "start_time": "2023-07-22T10:48:44.856431Z" } }, "outputs": [], "source": [ "class RNN(bp.DynamicalSystem):\n", " def __init__(self, \n", " num_input, \n", " num_hidden, \n", " num_output, \n", " num_batch, \n", " dt=None, 
seed=None,\n", "               w_ir=bp.init.KaimingNormal(scale=1.),\n", "               w_rr=bp.init.KaimingNormal(scale=1.),\n", "               w_ro=bp.init.KaimingNormal(scale=1.)):\n", "    super(RNN, self).__init__()\n", "\n", "    # parameters\n", "    self.tau = 100\n", "    self.num_batch = num_batch\n", "    self.num_input = num_input\n", "    self.num_hidden = num_hidden\n", "    self.num_output = num_output\n", "    if dt is None:\n", "      self.alpha = 1\n", "    else:\n", "      self.alpha = dt / self.tau\n", "    self.rng = bm.random.RandomState(seed=seed)\n", "\n", "    # input weight\n", "    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))\n", "\n", "    # recurrent weight\n", "    bound = 1 / num_hidden ** 0.5\n", "    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))\n", "    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))\n", "\n", "    # readout weight\n", "    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (num_hidden, num_output)))\n", "    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))\n", "\n", "    self.reset_state(self.mode)\n", "\n", "  def reset_state(self, batch_size):\n", "    self.h = bp.init.variable_(bm.zeros, self.num_hidden, batch_size)\n", "    self.o = bp.init.variable_(bm.zeros, self.num_output, batch_size)\n", "\n", "  def cell(self, x, h):\n", "    ins = x @ self.w_ir + h @ self.w_rr + self.b_rr\n", "    state = h * (1 - self.alpha) + ins * self.alpha\n", "    return bm.relu(state)\n", "\n", "  def readout(self, h):\n", "    return h @ self.w_ro + self.b_ro\n", "\n", "  def update(self, x):\n", "    self.h.value = self.cell(x, self.h)\n", "    self.o.value = self.readout(self.h)\n", "    return self.h.value, self.o.value\n", "\n", "  def predict(self, xs):\n", "    self.h[:] = 0.\n", "    return bm.for_loop(self.update, xs)\n", "\n", "  def loss(self, xs, ys):\n", "    hs, os = self.predict(xs)\n", "    os = os.reshape((-1, os.shape[-1]))\n", "    loss = bp.losses.cross_entropy_loss(os, ys.flatten())\n", "    return loss, os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Train the recurrent network on the decision-making task" ] }, { "cell_type": "code", "execution_count": 167, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.933935600Z", "start_time": "2023-07-22T10:48:44.873659300Z" } }, "outputs": [], "source": [ "# Instantiate the network in training mode\n", "hidden_size = 64\n", "with bm.environment(mode=bm.TrainingMode(batch_size=16)):\n", "  net = RNN(num_input=dataset.num_inputs,\n", "            num_hidden=hidden_size,\n", "            num_output=dataset.num_outputs,\n", "            num_batch=task.batch_size,\n", "            dt=dataset.dt)" ] }, { "cell_type": "code", "execution_count": 168, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:44.959876300Z", "start_time": "2023-07-22T10:48:44.891350100Z" } }, "outputs": [], "source": [ "# Adam optimizer\n", "opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())\n", "\n", "# gradient function\n", "grad_f = bm.grad(net.loss,\n", "                 grad_vars=net.train_vars().unique(),\n", "                 return_value=True,\n", "                 has_aux=True)\n", "\n", "# training function\n", "@bm.jit\n", "def train(xs, ys):\n", "  grads, l, os = grad_f(xs, ys)\n", "  opt.update(grads)\n", "  return l, os" ] }, { "cell_type": "code", "execution_count": 169, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:47.166302Z", "start_time": "2023-07-22T10:48:44.933935600Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Batch 1, Loss 0.2494, Acc 0.164\n", "Batch 2, Loss 0.0526, Acc 0.663\n", "Batch 3, Loss 0.0375, Acc 0.766\n", "Batch 4, Loss 0.0314, Acc 0.775\n", "Batch 5, Loss 0.0294, Acc 0.780\n", "Batch 6, Loss 0.0291, Acc 0.796\n", "Batch 7, Loss 0.0249, Acc 0.830\n", "Batch 8, Loss 0.0251, Acc 0.812\n", "Batch 9, Loss 0.0223, Acc 0.827\n", "Batch 10, Loss 0.0209, Acc 0.848\n", "Batch 11, Loss 0.0218, Acc 0.817\n", "Batch 12, Loss 0.0220, Acc 0.822\n", "Batch 13, Loss 0.0191, Acc 0.853\n", "Batch 14, Loss 0.0176, Acc 0.861\n", "Batch 15, Loss 0.0216, Acc 0.832\n", "Batch 16, Loss 0.0177, 
Acc 0.882\n", "Batch 17, Loss 0.0180, Acc 0.864\n", "Batch 18, Loss 0.0166, Acc 0.869\n", "Batch 19, Loss 0.0162, Acc 0.861\n", "Batch 20, Loss 0.0160, Acc 0.871\n" ] } ], "source": [ "running_acc = []\n", "running_loss = []\n", "for i_batch in range(20):\n", "  for X, Y in task:\n", "    # training\n", "    loss, outputs = train(bm.asarray(X), bm.asarray(Y))\n", "    # Compute performance\n", "    output_np = np.asarray(bm.argmax(outputs, axis=-1)).flatten()\n", "    labels_np = np.asarray(Y).flatten()\n", "    ind = labels_np > 0  # 0: fixation, 1: choice 1, 2: choice 2\n", "    running_loss.append(loss)\n", "    running_acc.append(np.mean(labels_np[ind] == output_np[ind]))\n", "  print(f'Batch {i_batch + 1}, Loss {np.mean(running_loss):0.4f}, Acc {np.mean(running_acc):0.3f}')\n", "  running_loss = []\n", "  running_acc = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize neural activity in sample trials\n", "\n", "We will run the network for 100 sample trials, then visualize the neural activity trajectories in PCA space." 
] }, { "cell_type": "code", "execution_count": 170, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:47.227288900Z", "start_time": "2023-07-22T10:48:47.166302Z" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the neural activity: (Time points, Neurons): (2200, 64)\n" ] } ], "source": [ "num_trial = 100\n", "task = bd.cognitive.TaskLoader(dataset, batch_size=num_trial)\n", "inputs, trial_infos = task.get_batch()\n", "\n", "# reset the network state to match the required batch size\n", "net.reset_state(num_trial)\n", "\n", "# get the RNN activity\n", "rnn_activity, _ = net.predict(inputs)\n", "rnn_activity = np.asarray(rnn_activity)\n", "trial_infos = np.asarray(trial_infos)\n", "\n", "# Concatenate activity for PCA\n", "activity = rnn_activity.reshape(-1, hidden_size)\n", "print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)" ] }, { "cell_type": "code", "execution_count": 171, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:48:47.243974400Z", "start_time": "2023-07-22T10:48:47.227288900Z" } }, "outputs": [ { "data": { "text/html": [ "
PCA(n_components=2)