{ "cells": [ { "cell_type": "markdown", "id": "46250053", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Integrator RNN Model\n", "\n", "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/recurrent_networks/integrator_rnn.ipynb)\n", "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/recurrent_networks/integrator_rnn.ipynb)" ] }, { "cell_type": "markdown", "id": "3f57ee5e", "metadata": {}, "source": [ "In this notebook, we train a vanilla RNN to integrate white noise: at every time step the network must output the running sum of all the noisy inputs it has received so far. Because the task is small and self-contained, it is a useful first example of how RNNs are trained with backpropagation through time (BPTT)." ] }, { "cell_type": "code", "execution_count": 1, "id": "1d5361e4", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.150086600Z", "start_time": "2023-07-22T10:12:06.319898300Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "import brainpy as bp\n", "import brainpy.math as bm\n", "\n", "bm.set_environment(bm.training_mode)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.167543600Z", "start_time": "2023-07-22T10:12:07.150086600Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'2.4.3'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bp.__version__" ] }, { "cell_type": "markdown", "id": "19cb5aaf", "metadata": { "lines_to_next_cell": 2 }, "source": [ "## Parameters" ] }, { "cell_type": "code", "execution_count": 3, "id": "a3df36a0", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.214464800Z", "start_time": "2023-07-22T10:12:07.167543600Z" } }, "outputs": [], "source": [ "dt = 0.04\n", "num_step = int(1.0 / dt)\n", "num_batch = 128" ] }, { "cell_type": "markdown", "id": "4c0d1cac", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", 
"execution_count": 4, "id": "1993c793", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.217442200Z", "start_time": "2023-07-22T10:12:07.183211800Z" } }, "outputs": [], "source": [ "@bm.jit(static_argnames=['batch_size'])\n", "def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):\n", "    # Sample one batch of white-noise inputs and their integration targets.\n", "    # Returns (inputs, targets), both of shape (batch_size, num_step, 1).\n", "    #\n", "    # Per-trial constant bias, uniform in [-mean, mean) and hence zero-centred.\n", "    # (A `2 * (u - 0.5)` rescaling is only symmetric for u ~ U[0, 1); applied to a\n", "    # standard-normal sample it would shift every bias by -mean.)\n", "    bias = bm.random.uniform(low=-mean, high=mean, size=(batch_size, 1, 1))\n", "    # Gaussian white noise whose standard deviation scales as 1 / sqrt(dt).\n", "    samples = bm.random.normal(size=(batch_size, num_step, 1))\n", "    noise_t = scale / dt ** 0.5 * samples\n", "    # The bias broadcasts over the time axis (axis 1).\n", "    inputs = bias + noise_t\n", "    # Target is the running sum over time; no `* dt` factor is applied.\n", "    targets = bm.cumsum(inputs, axis=1)\n", "    return inputs, targets" ] }, { "cell_type": "code", "execution_count": 5, "id": "9a4c53b6", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.217442200Z", "start_time": "2023-07-22T10:12:07.214464800Z" } }, "outputs": [], "source": [ "def train_data():\n", "    # One training epoch: 100 freshly sampled batches of `num_batch` trials each.\n", "    for _ in range(100):\n", "        yield build_inputs_and_targets(batch_size=num_batch)" ] }, { "cell_type": "markdown", "id": "6c2b759d", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 6, "id": "1de8d486", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:07.812586800Z", "start_time": "2023-07-22T10:12:07.217442200Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "class RNN(bp.DynamicalSystem):\n", "    # Vanilla RNN: one recurrent cell followed by a dense readout to a single output unit.\n", "    def __init__(self, num_in, num_hidden):\n", "        super().__init__()\n", "        # NOTE(review): `train_state=True` presumably makes the cell's initial hidden\n", "        # state trainable -- confirm against the brainpy RNNCell documentation.\n", "        self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)\n", "        self.out = bp.layers.Dense(num_hidden, 1)\n", "\n", "    def update(self, x):\n", "        # Feed x through the recurrent cell, then project the cell output to 1 unit.\n", "        hidden = self.rnn(x)\n", "        return self.out(hidden)\n", "\n", "model = RNN(1, 100)" ] }, { "cell_type": "markdown", "id": "5aac4126", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 7, "id": "9e63f9a7", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:25.596371Z", "start_time": "2023-07-22T10:12:07.812586800Z" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train 200 steps, use 1.9202 s, loss 0.5666112303733826\n", "Train 400 steps, use 1.1134 s, loss 0.024968581274151802\n", "Train 600 steps, use 1.1003 s, loss 0.02216172218322754\n", "Train 800 steps, use 1.1105 s, loss 0.021766066551208496\n", "Train 1000 steps, use 1.1119 s, loss 0.021580681204795837\n", "Train 1200 steps, use 1.1100 s, loss 0.0214251521974802\n", "Train 1400 steps, use 1.1109 s, loss 0.02127397060394287\n", "Train 1600 steps, use 1.1103 s, loss 0.02117355354130268\n", "Train 1800 steps, use 1.1075 s, loss 0.021074647083878517\n", "Train 2000 steps, use 1.0969 s, loss 0.020986948162317276\n", "Train 2200 steps, use 1.0962 s, loss 0.02090354822576046\n", "Train 2400 steps, use 1.1340 s, loss 0.020846327766776085\n", "Train 2600 steps, use 1.2002 s, loss 0.020739112049341202\n", "Train 2800 steps, use 1.1426 s, loss 0.020664572715759277\n", "Train 3000 steps, use 1.0991 s, loss 0.020596308633685112\n" ] } ], "source": [ "# define loss function\n", "def loss(predictions, targets, l2_reg=2e-4):\n", "    # Mean-squared error plus l2_reg times the squared L2 norm of all unique\n", "    # trainable variables of the global `model`.\n", "    data_term = bp.losses.mean_squared_error(predictions, targets)\n", "    trainable = model.train_vars().unique().dict()\n", "    weight_penalty = l2_reg * bp.losses.l2_norm(trainable) ** 2\n", "    return data_term + weight_penalty\n", "\n", "\n", "# define 
optimizer\n", "lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)\n", "opt = bp.optim.Adam(lr=lr, eps=1e-1)\n", "\n", "\n", "# create a trainer\n", "trainer = bp.BPTT(model, loss_fun=loss, optimizer=opt)\n", "trainer.fit(train_data,\n", " num_epoch=30,\n", " num_report=200)" ] }, { "cell_type": "code", "execution_count": 8, "id": "6a4f7da4", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:25.659171500Z", "start_time": "2023-07-22T10:12:25.596371Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqWklEQVR4nO3df3Rb9X3/8de1ZMmOY7tN/CNxcYzZSZM0pi2xN+pAylrA+wYOG+s60mZ12Cmc4dNQMN74kjTtYPmumLI2DR044J2e9qQd1GeDMrqlp3gtC2ahAxyH0kKBrgG7wa7jQC07IZIt3e8ftmTLlmxJlu7Hjp6PM51GV/devaVk9ot735/Px7Jt2xYAAIAhOaYLAAAA2Y0wAgAAjCKMAAAAowgjAADAKMIIAAAwijACAACMIowAAACjCCMAAMAot+kCEhEKhfTWW2+psLBQlmWZLgcAACTAtm2NjIyooqJCOTnxr38siTDy1ltvqbKy0nQZAAAgBX19fTrvvPPivr4kwkhhYaGkiQ9TVFRkuBoAAJAIn8+nysrKyO/xeJZEGAnfmikqKiKMAACwxMzXYkEDKwAAMIowAgAAjCKMAAAAowgjAADAKMIIAAAwijACAACMIowAAACjCCMAAMAowggAADCKMAIAAIwijAAAAKMIIwAAwKisDiPf7/mNvvT4z/X8G2+bLgUAgKyV1WHkx68M6js/fVMv9v3OdCkAAGStrA4jZYV5kqSTo37DlQAAkL2yOoyUFnolSSd9hBEAAEzJ6jBSNhlGBkcIIwAAmJLdYaQoHEbOGq4EAIDsldVhJHKbhisjAAAYk9VhJNzA+s6ZMQXGQ4arAQAgO2V1GHnvslzluixJjKgBAMCUrA4jlmWpdDm3agAAMCmrw4g01Tcy6KOJFQAAEwgjk30jDO8FAMCMrA8jU8N7CSMAAJhAGGF4LwAARmV9GJmaa4SeEQAATMj6MFJGzwgAAEYRRiKjaQgjAACYQBiZbGAdGvUrFLINVwMAQPbJ+jCysmAijIyHbL1zJmC4GgAAsk/WhxGPO0crCjyS6BsBAMCErA8j0rS+EcIIAACOI4xo+vBewggAAE4jjGja+jTMNQIAgOMII5o21wjDewEAcBxhREwJDwCASYQRTc01QhgBAMB5hBFJpcvpGQEAwBTCiKSyItanAQDAFMKIpnpGzgSCOu0fN1wNAADZhTAiqcDr1jKPSxJXRwAAcBphZNLU6r30jQAA4CTCyKTIXCNcGQEAwFGEkUmlDO8FAMAIwsikqeG9hBEAAJxEGJkUnviMuUYAAHAWYWRSuGeE2zQAADiLMDKJ9WkAADCDMDKptJCeEQAATEgpjLS1tam6ulp5eXmqra1VV1dX3H3/67/+S5ZlzXr88pe/TLnoTAhfGXn7dECB8ZDhagAAyB5Jh5GOjg41Nzdrz5496unp0ZYtW7R161b19vbOedyrr76q/v
7+yGPt2rUpF50J713mkTvHkiQNjXJ1BAAApyQdRvbt26cbbrhBN954ozZs2KD9+/ersrJSBw4cmPO4srIyrVq1KvJwuVwpF50JOTlW5FYNfSMAADgnqTASCATU3d2thoaGqO0NDQ06cuTInMdedNFFWr16tS6//HI99dRTc+7r9/vl8/miHk6gbwQAAOclFUaGhoYUDAZVXl4etb28vFwDAwMxj1m9erXa29v16KOP6rHHHtO6det0+eWX6+mnn477Pq2trSouLo48KisrkykzZZH1aZhrBAAAx7hTOciyrKjntm3P2ha2bt06rVu3LvK8vr5efX19+upXv6qPfvSjMY/ZvXu3WlpaIs99Pp8jgaQ0vD6NjysjAAA4JakrIyUlJXK5XLOuggwODs66WjKXj3zkI3r99dfjvu71elVUVBT1cEJkrhEaWAEAcExSYcTj8ai2tladnZ1R2zs7O7V58+aEz9PT06PVq1cn89aOiPSMcGUEAADHJH2bpqWlRY2Njaqrq1N9fb3a29vV29urpqYmSRO3WE6cOKGDBw9Kkvbv36/zzz9fGzduVCAQ0He/+109+uijevTRR9P7SdJgahZWekYAAHBK0mFk27ZtOnXqlPbu3av+/n7V1NTo0KFDqqqqkiT19/dHzTkSCAT0N3/zNzpx4oTy8/O1ceNG/cd//Ieuuuqq9H2KNCkrmuwZYTQNAACOsWzbtk0XMR+fz6fi4mINDw9ntH/kxO/e1SX3/ES5Lkuv/r+tysmJ3ZQLAADml+jvb9ammaZ0+cRtmrGgrd+9O2a4GgAAsgNhZBqPO0fvXZYriblGAABwCmFkhrLJuUaYEh4AAGcQRmZgeC8AAM4ijMxQxvo0AAA4ijAyQ2kR69MAAOAkwsgM9IwAAOAswsgMpdymAQDAUYSRGaamhCeMAADgBMLIDJEGVh89IwAAOIEwMkN4fZrTgaBO+8cNVwMAwLmPMDJDgcel/FyXJG7VAADgBMLIDJZlqayIJlYAAJxCGIlhauIz+kYAAMg0wkgMzDUCAIBzCCMxMNcIAADOIYzEwGJ5AAA4hzASAz0jAAA4hzASQ3iuEXpGAADIPMJIDKXLmRIeAACnEEZiCM8zcup0QGPBkOFqAAA4txFGYlixzCNXjiVJOjUaMFwNAADnNsJIDDk5lkqWeyTRxAoAQKYRRuIIT3zG8F4AADKLMBJHGROfAQDgCMJIHOEmVkbUAACQWYSROMLDe+kZAQAgswgjcZROTnzGbRoAADKLMBIHPSMAADiDMBJHOIwMEUYAAMgowkgc4ZV7T474Zdu24WoAADh3EUbiCIeRQDCk350ZM1wNAADnLsJIHF63S+9ZliuJvhEAADKJMDKHskLmGgEAINMII3MoLWSuEQAAMo0wMofI+jRcGQEAIGMII3OIzDXCYnkAAGQMYWQOkeG9o4QRAAAyhTAyh0jPiI+eEQAAMoUwModwzwijaQAAyBzCyBzKihjaCwBAphFG5hC+TTPiH9e7gaDhagAAODcRRuZQ6HUrL3fiK2KuEQAAMoMwMgfLsphrBACADCOMzIMp4QEAyCzCyDwY3gsAQGYRRuYRmYWVKyMAAGQEYWQeZUX0jAAAkEmEkXmU0jMCAEBGEUbmUcptGgAAMoowMo+p0TQ0sAIAkAkphZG2tjZVV1crLy9PtbW16urqSui4//7v/5bb7daHP/zhVN7WiPA8I6dOBzQeDBmuBgCAc0/SYaSjo0PNzc3as2ePenp6tGXLFm3dulW9vb1zHjc8PKwdO3bo8ssvT7lYE1YUeOTKsWTbE4EEAACkV9JhZN++fbrhhht04403asOGDdq/f78qKyt14MCBOY+76aabtH37dtXX16dcrAmuHEsrCzySpEEffSMAAKRbUmEkEAiou7tbDQ0NUdsbGhp05MiRuMd961vf0v/+7//qzjvvTOh9/H6/fD5f1MOk8Oq9rE8DAED6JRVGhoaGFAwGVV5eHrW9vLxcAwMDMY
95/fXXtWvXLv3zP/+z3G53Qu/T2tqq4uLiyKOysjKZMtOO9WkAAMiclBpYLcuKem7b9qxtkhQMBrV9+3b93d/9nd7//vcnfP7du3dreHg48ujr60ulzLRhfRoAADInsUsVk0pKSuRyuWZdBRkcHJx1tUSSRkZG9MILL6inp0c333yzJCkUCsm2bbndbj355JP6+Mc/Pus4r9crr9ebTGkZNTXXCLdpAABIt6SujHg8HtXW1qqzszNqe2dnpzZv3jxr/6KiIr300ks6duxY5NHU1KR169bp2LFjuvjiixdWvUMi69PQwAoAQNoldWVEklpaWtTY2Ki6ujrV19ervb1dvb29ampqkjRxi+XEiRM6ePCgcnJyVFNTE3V8WVmZ8vLyZm1fzErpGQEAIGOSDiPbtm3TqVOntHfvXvX396umpkaHDh1SVVWVJKm/v3/eOUeWmvBoGnpGAABIP8u2bdt0EfPx+XwqLi7W8PCwioqKHH//vrfPaMu9T8njytGrf/9/YjbrAgCAaIn+/mZtmgSEG1gDwZCG3x0zXA0AAOcWwkgC8nJdKs7PlcStGgAA0o0wkqCp4b2EEQAA0okwkqAy5hoBACAjCCMJYq4RAAAygzCSoLKiiblG6BkBACC9CCMJKl1OzwgAAJlAGElQeOIzekYAAEgvwkiCGE0DAEBmEEYSVFZIzwgAAJlAGElQ+MrIyNlxnR0LGq4GAIBzB2EkQUV5bnndE18Xw3sBAEgfwkiCLMuiiRUAgAwgjCSBvhEAANKPMJIE5hoBACD9CCNJ4DYNAADpRxhJAuvTAACQfoSRJER6RkYJIwAApAthJAmlXBkBACDtCCNJYEp4AADSjzCShHAD69un/QqGbMPVAABwbiCMJGFlgVc5lhSypVP0jQAAkBaEkSS4ciytZK4RAADSijCSpMjwXuYaAQAgLQgjSQqHEaaEBwAgPQgjSWJ4LwAA6UUYSVJ44jN6RgAASA/CSJJYnwYAgPQijCSJnhEAANKLMJIkZmEFACC9CCNJmt4zYtvMwgoAwEIRRpIUvjISGA/J9+644WoAAFj6CCNJyst1qSjPLUk6OUoTKwAAC0UYSQFzjQAAkD6EkRQw1wgAAOlDGEkBc40AAJA+hJEUMNcIAADpQxhJAXONAACQPoSRFER6RmhgBQBgwQgjKSgrpGcEAIB0IYykINzASs8IAAALRxhJQenyids0vrPjOjsWNFwNAABLG2EkBUX5bnncE18dV0cAAFgYwkgKLMua1jdCGAEAYCEIIykqjcw1QhMrAAALQRhJEVdGAABID8JIiphrBACA9CCMpIgp4QEASA/CSIpKmfgMAIC0IIykaGrlXq6MAACwECmFkba2NlVXVysvL0+1tbXq6uqKu+8zzzyjSy65RCtXrlR+fr7Wr1+vr3/96ykXvFhEekYIIwAALIg72QM6OjrU3NystrY2XXLJJXrooYe0detWvfzyy1qzZs2s/QsKCnTzzTfrgx/8oAoKCvTMM8/opptuUkFBgf7qr/4qLR/ChHDPyKlRv4IhW64cy3BFAAAsTZZt23YyB1x88cXatGmTDhw4ENm2YcMGXXvttWptbU3oHJ/4xCdUUFCg73znOwnt7/P5VFxcrOHhYRUVFSVTbsaMB0Na+8Ufyral5/ZcHrlSAgAAJiT6+zup2zSBQEDd3d1qaGiI2t7Q0KAjR44kdI6enh4dOXJEl112WTJvvei4XTlaWTDZN8LwXgAAUpbUbZqhoSEFg0GVl5dHbS8vL9fAwMCcx5533nk6efKkxsfHddddd+nGG2+Mu6/f75ffP/UL3ufzJVOmY8oKvRoa9TO8FwCABUipgdWyovsjbNuetW2mrq4uvfDCC3rwwQe1f/9+PfLII3H3bW1tVXFxceRRWVmZSpkZFx5RQxgBACB1SV0ZKSkpkcvlmnUVZHBwcNbVkpmqq6slSRdeeKF++9vf6q677tKnP/3pmPvu3r1bLS0tkec+n29RBpLS5c
w1AgDAQiV1ZcTj8ai2tladnZ1R2zs7O7V58+aEz2PbdtRtmJm8Xq+KioqiHosRc40AALBwSQ/tbWlpUWNjo+rq6lRfX6/29nb19vaqqalJ0sRVjRMnTujgwYOSpAceeEBr1qzR+vXrJU3MO/LVr35Vn//859P4McxgfRoAABYu6TCybds2nTp1Snv37lV/f79qamp06NAhVVVVSZL6+/vV29sb2T8UCmn37t06fvy43G63fu/3fk/33HOPbrrppvR9CkMi69OMEkYAAEhV0vOMmLAY5xmRpBfeeFuffPBZVa7IV9f//bjpcgAAWFQyMs8Iok2/TbMEMh0AAIsSYWQBwiv3+sdDGvGPG64GAICliTCyAPkelwq9E203NLECAJAawsgClRYx1wgAAAtBGFmgyIga5hoBACAlhJEFCjexEkYAAEgNYWSBwk2szMIKAEBqCCMLFL5NM+ijZwQAgFQQRhaI9WkAAFgYwsgC0TMCAMDCEEYWiJ4RAAAWhjCyQOGekeF3x3R2LGi4GgAAlh7CyAIV5+fK4574GrlVAwBA8ggjC2RZlkqXT058NkoYAQAgWYSRNIj0jbA+DQAASSOMpMHUlPDMNQIAQLIII2nAXCMAAKSOMJIGzDUCAEDqCCNpwFwjAACkjjCSBpH1aegZAQAgaYSRNOA2DQAAqSOMpEG4gXVoNKBgyDZcDQAASwthJA1WFnhkWVIwZOvt0wHT5QAAsKQQRtLA7crRygKPJPpGAABIFmEkTUrpGwEAICWEkTRheC8AAKkhjKTJ1JTwhBEAAJJBGEmTyFwjPnpGAABIBmEkTSJXRka5MgIAQDIII2kSbmAd9BFGAABIBmEkTVi5FwCA1BBG0mT6+jS2zSysAAAkijCSJuGhvWfHQhr1jxuuBgCApYMwkibLPG4t97olcasGAIBkEEbSaGp4L2EEAIBEEUbSqHRa3wgAAEgMYSSNyopYnwYAgGQRRtKodDlTwgMAkCzCSBox1wgAAMkjjKRRGT0jAAAkjTCSRmWF9IwAAJAswkgaTY2mIYwAAJAowkgahW/T/O7MmPzjQcPVAACwNBBG0ug9y3KV67IkSUOjAcPVAACwNBBG0siyrMjw3kEfTawAACSCMJJmpZMTn9E3AgBAYggjaVZGEysAAEkhjKRZOIwwvBcAgMQQRtKsNBJG6BkBACARhJE0C098NujjyggAAIkgjKQZPSMAACQnpTDS1tam6upq5eXlqba2Vl1dXXH3feyxx3TllVeqtLRURUVFqq+v149+9KOUC17swovl0TMCAEBikg4jHR0dam5u1p49e9TT06MtW7Zo69at6u3tjbn/008/rSuvvFKHDh1Sd3e3Pvaxj+maa65RT0/PgotfjMI9I0OjfoVCtuFqAABY/CzbtpP6jXnxxRdr06ZNOnDgQGTbhg0bdO2116q1tTWhc2zcuFHbtm3T3/7t3ya0v8/nU3FxsYaHh1VUVJRMuY4bC4b0/i/+ULYtvfDFK1QyOQkaAADZJtHf30ldGQkEAuru7lZDQ0PU9oaGBh05ciShc4RCIY2MjGjFihVx9/H7/fL5fFGPpSLXlaMVyzySaGIFACARSYWRoaEhBYNBlZeXR20vLy/XwMBAQuf42te+ptOnT+u6666Lu09ra6uKi4sjj8rKymTKNC4yvHeUMAIAwHxSamC1LCvquW3bs7bF8sgjj+iuu+5SR0eHysrK4u63e/duDQ8PRx59fX2plGlMOIywPg0AAPNzJ7NzSUmJXC7XrKsgg4ODs66WzNTR0aEbbrhB//Iv/6Irrrhizn29Xq+83qXbaxGZa4QRNQAAzCupKyMej0e1tbXq7OyM2t7Z2anNmzfHPe6RRx7RX/7lX+rhhx/W1VdfnVqlSwjDewEASFxSV0YkqaWlRY2Njaqrq1N9fb3a29vV29urpqYmSRO3WE6cOKGDBw9KmggiO3bs0H333aePfOQjkasq+fn5Ki4uTuNHWTxYnwYAgMQlHU
a2bdumU6dOae/everv71dNTY0OHTqkqqoqSVJ/f3/UnCMPPfSQxsfHtXPnTu3cuTOy/frrr9e3v/3thX+CRSjSM8L6NAAAzCvpeUZMWErzjEjSc8ff1nUPPauqlct0+PaPmS4HAAAjMjLPCBLDbRoAABJHGMmA8G2aM4GgRv3jhqsBAGBxI4xkQIHXrQKPSxJzjQAAMB/CSIaUFTHXCAAAiSCMZEgpfSMAACSEMJIhU8N7CSMAAMyFMJIhZcw1AgBAQggjGRJen+akjysjAADMhTCSIZG5RkYJIwAAzIUwkiGRnhGujAAAMCfCSIaEV+6lZwQAgLkRRjIk3DPyzpkxBcZDhqsBAGDxIoxkyHvyc5XrsiRJQ/SNAAAQF2EkQ3JyLJUsZ64RAADmQxjJoMhcI6xPAwBAXISRDCotZH0aAADmQxjJoPCIGtanAQAgPsJIBpXSMwIAwLwIIxk0dWWEnhEAAOIhjGRQZH0arowAABAXYSSDplbuJYwAABAPYSSDwuvTnBzxKxSyDVcDAMDiRBjJoPCkZ+MhW++cCRiuBgCAxYkwkkEed45WFHgkSSeZEh4AgJgIIxkWGd7rI4wAABALYSTDwsN7aWIFACA2wkiGlUZG1DDXCAAAsRBGMoy5RgAAmBthJMNKmWsEAIA5EUYyLDzx2UkaWAEAiIkwkmFl9IwAADAnwkiGlRXRMwIAwFwIIxkW7hk5HQjqtH/ccDUAACw+hJEMW+51a5nHJYkmVgAAYiGMOCDSN+KjbwQAgJkIIw6IzDXC+jQAAMxCGHFAZK4RhvcCADALYcQBTHwGAEB8hBEHTC2WR88IAAAzEUYcwPo0AADERxhxQPg2DWEEAIDZCCMOKKNnBACAuAgjDgiHkbdPBzQWDBmuBgCAxYUw4oD3LvPInWNJkoaYawQAgCiEEQfk5FgqWc5cIwAAxEIYccjU8F7CCAAA0xFGHFLGiBoAAGIijDhkahZWJj4DAGA6wohDSicnPuM2DQAA0QgjDiljsTwAAGJKKYy0tbWpurpaeXl5qq2tVVdXV9x9+/v7tX37dq1bt045OTlqbm5OtdYlLdIzwtBeAACiJB1GOjo61NzcrD179qinp0dbtmzR1q1b1dvbG3N/v9+v0tJS7dmzRx/60IcWXPBSFZkS3kfPCAAA0yUdRvbt26cbbrhBN954ozZs2KD9+/ersrJSBw4ciLn/+eefr/vuu087duxQcXHxggteqsqKJhfLG/XLtm3D1QAAsHgkFUYCgYC6u7vV0NAQtb2hoUFHjhxJa2HnmtLJSc/GgrbeOTNmuBoAABYPdzI7Dw0NKRgMqry8PGp7eXm5BgYG0laU3++X3z/VW+Hz+dJ2blM87hy9d1mu3jkzppMjfq0o8JguCQCARSGlBlbLsqKe27Y9a9tCtLa2qri4OPKorKxM27lNYq4RAABmSyqMlJSUyOVyzboKMjg4OOtqyULs3r1bw8PDkUdfX1/azm1SWXiuEYb3AgAQkVQY8Xg8qq2tVWdnZ9T2zs5Obd68OW1Feb1eFRUVRT3OBZG5Rpj4DACAiKR6RiSppaVFjY2NqqurU319vdrb29Xb26umpiZJE1c1Tpw4oYMHD0aOOXbsmCRpdHRUJ0+e1LFjx+TxePSBD3wgPZ9iiSgtYn0aAABmSjqMbNu2TadOndLevXvV39+vmpoaHTp0SFVVVZImJjmbOefIRRddFPlzd3e3Hn74YVVVVemNN95YWPVLTHhEDT0jAABMSTqMSNLnPvc5fe5zn4v52re//e1Z25hXY0J4rhFu0wAAMIW1aRwU7hkZIowAABBBGHEQDawAAMxGGHFQeJ6RUf+4zgTGDVcDAMDiQBhx0HKvW/m5LknMNQIAQBhhxEGWZaksPLx3lDACAIBEGHFcZHgvV0YAAJBEGHFc+MoIc40AADCBMOKwyPo0jKgBAEASYcRx4RE1TAkPAMAEwojDSplrBACAKIQRh0UmPvPRMwIAgEQYcVy4Z4
TbNAAATCCMOCw8mubtMwGNBUOGqwEAwDzCiMNWLPPIlWPJtqVTowHT5QAAYBxhxGE5OZZKlnskMdcIAAASYcSIyFwjzMIKAABhxITwiBrWpwEAgDBiRGSuEa6MAABAGDEhMtcIPSMAABBGTCgtYn0aAADCCCMGlLE+DQAAEYQRA1gsDwCAKYQRA6ZfGbFt23A1AACYRRgxIHxlJBAMafjdMcPVAABgFmHEAK/bpeL8XEk0sQIAQBgxpIy5RgAAkEQYMSa8ei9zjQAAsh1hxJDw+jSMqAEAZDvCiCGRKeEJIwCALEcYMaSMMAIAgCTCiDFTi+XRMwIAyG6EEUMiPSOjXBkBAGQ3woghkSnhGdoLAMhyhBFDwkN7R/zjejcQNFwNAADmEEYMKfS6lZc78fUz1wgAIJsRRgyxLIu5RgAAEGHEKOYaAQCAMGJUGcN7AQAgjJgUDiOv9I/oV4Oj+q3vrE77x2XbtuHKAABwjtt0AdmsrGiiZ6TjhT51vNAX2W5Z0nKPW8vz3FrudavA61bh5J9nPp/55+UznnvdObIsy9RHBABgXoQRg66+cLV+8stBDQyf1ah/XKP+cQVDtmx7YsjviH98we+R67LiBpXCPLcKPG553DnKsSzlWBONtZalqOfhP+fMeC0nx5p8XbIUY5+c8POp12ae05o8dvL/Jt5fU9utye2a8Txc5+RLk3+e8fqM82ja88ifZ503/M1Ne+8Z55t8ddq+4XNZUfVEvU/4BUW/X7xzy9KMz2dN2z/O+894v1n1EEoBLFKWvQTuCfh8PhUXF2t4eFhFRUWmy8kY27Z1diykEf+YTvuDGj07PvVn/9jk83Gd9o9H/3nm87PjOs3cJZhDrDAU3i5NBZsZ/zPrdWvW61bU89nHJ3jctH0U971m1DrHeWe9nsSxUWeY4/yJ1DX1PP7njQqZMz571DFzvHcqNWpmjTFqm1lXzNdi1B9v3+jXor+T+f4+wv8BEnvf6PeP+/duzf09xTp33H3jfO7YfycJ7BvrH16c8yV8zqj9Z/9lfLL2PNW8r3jW9oVI9Pc3V0YWEcuylO9xKd/jkgoXdq5gyNaZQOygMv3Po/5xjQVthWxbtm0rZEuhyf+deD61TdNem9g//Hz6/jGOD829v62J1ybfItIzY097TZp6zZ58Mv15eN/IceFjZrw+9R7hb8qe/b6Rc0w9D7/fzHNOniGqRs2sJca5TbOnfZ7YBS2CIgE4alPVe9MeRhJFGDlHuXIsFeblqjAvVzLzbwsJsG07KhjMFXSkeQJa+LgEwlDkiKj3nnrP6OdTtU3frmSPm7G/4u4f53xxts9XRyq1xKxn2ouzj4/+ThPZL5H6pj+Jedw833X0/vN/xlmfJeZr088dP7TOdc5Zdcz3Oeb9Nzb364rz9z739zW7nplmvm+84xPZN/q8iZ8v0fef/kK886wtWx67IAcQRgCDpve+TG4xVQoAGMPQXgAAYBRhBAAAGEUYAQAARhFGAACAUYQRAABgFGEEAAAYlVIYaWtrU3V1tfLy8lRbW6uurq459z98+LBqa2uVl5enCy64QA8++GBKxQIAgHNP0mGko6NDzc3N2rNnj3p6erRlyxZt3bpVvb29Mfc/fvy4rrrqKm3ZskU9PT36whe+oFtuuUWPPvrogosHAABLX9Jr01x88cXatGmTDhw4ENm2YcMGXXvttWptbZ21/x133KEnnnhCr7zySmRbU1OTXnzxRT377LMJvWe2rE0DAMC5JNHf30ldGQkEAuru7lZDQ0PU9oaGBh05ciTmMc8+++ys/f/oj/5IL7zwgsbGxmIe4/f75fP5oh4AAODclFQYGRoaUjAYVHl5edT28vJyDQwMxDxmYGAg5v7j4+MaGhqKeUxra6uKi4sjj8rKymTKBAAAS0hKDawzlx62bTvmcsRz7R9re9ju3bs1PDwcefT19aVSJgAAWAKSWiivpKRELpdr1lWQwcHBWVc/wlatWhVzf7fbrZUrV8
Y8xuv1yuv1JlMaAABYopIKIx6PR7W1ters7NSf/umfRrZ3dnbqT/7kT2IeU19frx/84AdR25588knV1dUpNzc3ofcNX0mhdwQAgKUj/Ht73rEydpK+973v2bm5ufY3v/lN++WXX7abm5vtgoIC+4033rBt27Z37dplNzY2Rvb/9a9/bS9btsy+7bbb7Jdfftn+5je/aefm5tr/+q//mvB79vX12ZJ48ODBgwcPHkvw0dfXN+fv+aSujEjStm3bdOrUKe3du1f9/f2qqanRoUOHVFVVJUnq7++PmnOkurpahw4d0m233aYHHnhAFRUV+sY3vqE/+7M/S/g9Kyoq1NfXp8LCwjl7U5Ll8/lUWVmpvr6+rB0ynO3fQbZ/fonvgM+f3Z9f4jvI5Oe3bVsjIyOqqKiYc7+k5xk5lzB/Cd9Btn9+ie+Az5/dn1/iO1gMn5+1aQAAgFGEEQAAYFRWhxGv16s777wzq4cRZ/t3kO2fX+I74PNn9+eX+A4Ww+fP6p4RAABgXlZfGQEAAOYRRgAAgFGEEQAAYBRhBAAAGJXVYaStrU3V1dXKy8tTbW2turq6TJfkiNbWVv3+7/++CgsLVVZWpmuvvVavvvqq6bKMaW1tlWVZam5uNl2Ko06cOKHPfOYzWrlypZYtW6YPf/jD6u7uNl2WY8bHx/XFL35R1dXVys/P1wUXXKC9e/cqFAqZLi0jnn76aV1zzTWqqKiQZVl6/PHHo163bVt33XWXKioqlJ+frz/8wz/UL37xCzPFZshc38HY2JjuuOMOXXjhhSooKFBFRYV27Niht956y1zBaTbfv4HpbrrpJlmWpf379ztSW9aGkY6ODjU3N2vPnj3q6enRli1btHXr1qip7M9Vhw8f1s6dO/XTn/5UnZ2dGh8fV0NDg06fPm26NMc9//zzam9v1wc/+EHTpTjqnXfe0SWXXKLc3Fz98Ic/1Msvv6yvfe1res973mO6NMd85Stf0YMPPqj7779fr7zyiu699179wz/8g/7xH//RdGkZcfr0aX3oQx/S/fffH/P1e++9V/v27dP999+v559/XqtWrdKVV16pkZERhyvNnLm+gzNnzujo0aP60pe+pKNHj+qxxx7Ta6+9pj/+4z82UGlmzPdvIOzxxx/X//zP/8w7hXtaJbtQ3rniD/7gD+ympqaobevXr7d37dplqCJzBgcHbUn24cOHTZfiqJGREXvt2rV2Z2enfdlll9m33nqr6ZIcc8cdd9iXXnqp6TKMuvrqq+3PfvazUds+8YlP2J/5zGcMVeQcSfb3v//9yPNQKGSvWrXKvueeeyLbzp49axcXF9sPPviggQozb+Z3EMtzzz1nS7LffPNNZ4pyULzP/5vf/MZ+3/veZ//85z+3q6qq7K9//euO1JOVV0YCgYC6u7vV0NAQtb2hoUFHjhwxVJU5w8PDkqQVK1YYrsRZO3fu1NVXX60rrrjCdCmOe+KJJ1RXV6c///M/V1lZmS666CL90z/9k+myHHXppZfqxz/+sV577TVJ0osvvqhnnnlGV111leHKnHf8+HENDAxE/Uz0er267LLLsvJnYtjw8LAsy8qaK4ahUEiNjY26/fbbtXHjRkffO+lVe88FQ0NDCgaDKi8vj9peXl6ugYEBQ1WZYdu2WlpadOmll6qmpsZ0OY753ve+p6NHj+r55583XYoRv/71r3XgwAG1tLToC1/4gp577jndcsst8nq92rFjh+nyHHHHHXdoeHhY69evl8vlUjAY1Je//GV9+tOfNl2a48I/92L9THzzzTdNlGTc2bNntWvXLm3fvj1rFs/7yle+IrfbrVtuucXx987KMBJmWVbUc9u2Z207191888362c9+pmeeecZ0KY7p6+vTrbfeqieffFJ5eXmmyzEiFAqprq5Od999tyTpoosu0i9+8QsdOHAga8JIR0eHvvvd7+rhhx/Wxo0bdezYMTU3N6uiokLXX3+96fKM4GfihLGxMX3qU59SKBRSW1
ub6XIc0d3drfvuu09Hjx418neelbdpSkpK5HK5Zl0FGRwcnPVfBueyz3/+83riiSf01FNP6bzzzjNdjmO6u7s1ODio2tpaud1uud1uHT58WN/4xjfkdrsVDAZNl5hxq1ev1gc+8IGobRs2bMiKBu6w22+/Xbt27dKnPvUpXXjhhWpsbNRtt92m1tZW06U5btWqVZKU9T8TpYkgct111+n48ePq7OzMmqsiXV1dGhwc1Jo1ayI/F99880399V//tc4///yMv39WhhGPx6Pa2lp1dnZGbe/s7NTmzZsNVeUc27Z1880367HHHtNPfvITVVdXmy7JUZdffrleeuklHTt2LPKoq6vTX/zFX+jYsWNyuVymS8y4Sy65ZNZw7tdee01VVVWGKnLemTNnlJMT/SPQ5XKds0N751JdXa1Vq1ZF/UwMBAI6fPhwVvxMDAsHkddff13/+Z//qZUrV5ouyTGNjY362c9+FvVzsaKiQrfffrt+9KMfZfz9s/Y2TUtLixobG1VXV6f6+nq1t7ert7dXTU1NpkvLuJ07d+rhhx/Wv/3bv6mwsDDyX0PFxcXKz883XF3mFRYWzuqPKSgo0MqVK7Omb+a2227T5s2bdffdd+u6667Tc889p/b2drW3t5suzTHXXHONvvzlL2vNmjXauHGjenp6tG/fPn32s581XVpGjI6O6le/+lXk+fHjx3Xs2DGtWLFCa9asUXNzs+6++26tXbtWa9eu1d13361ly5Zp+/btBqtOr7m+g4qKCn3yk5/U0aNH9e///u8KBoORn40rVqyQx+MxVXbazPdvYGb4ys3N1apVq7Ru3brMF+fImJ1F6oEHHrCrqqpsj8djb9q0KWuGtkqK+fjWt75lujRjsm1or23b9g9+8AO7pqbG9nq99vr16+329nbTJTnK5/PZt956q71mzRo7Ly/PvuCCC+w9e/bYfr/fdGkZ8dRTT8X8//vrr7/etu2J4b133nmnvWrVKtvr9dof/ehH7Zdeesls0Wk213dw/PjxuD8bn3rqKdOlp8V8/wZmcnJor2Xbtp35yAMAABBbVvaMAACAxYMwAgAAjCKMAAAAowgjAADAKMIIAAAwijACAACMIowAAACjCCMAAMAowggAADCKMAIAAIwijAAAAKMIIwAAwKj/Dzuq9JqvOyTMAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(bm.as_numpy(trainer.get_hist_metric()))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "359cf219", "metadata": { "lines_to_next_cell": 2 }, "source": [ "## Testing" ] }, { "cell_type": "code", "execution_count": 9, "id": "24dae0a4", "metadata": { "ExecuteTime": { "end_time": "2023-07-22T10:12:26.148354200Z", "start_time": "2023-07-22T10:12:25.663301600Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42560aaa4bda42d1854b0fa17f23155d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/25 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model.reset_state(1)\n", "x, y = build_inputs_and_targets(batch_size=1)\n", "predicts = trainer.predict(x)\n", "\n", "plt.figure(figsize=(8, 2))\n", "plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')\n", "plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')\n", "plt.legend()\n", "plt.show()" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "encoding": "# -*- coding: utf-8 -*-", "formats": "ipynb,auto:percent", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": "brainpy", "language": "python", "name": "brainpy" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, 
"skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }