Monitor Every Multiple Steps



Sometimes it is not necessary to record the system’s behavior with very high temporal precision. When the simulation is long, monitoring variables at every time step can lead to out-of-memory errors. Recording the values only once every few steps greatly reduces the memory requirement.
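As a rough back-of-the-envelope sketch (plain Python, not a BrainPy API; we assume 4-byte float32 values and ignore any framework overhead), the memory needed to store one monitored variable scales with the number of recorded samples:

```python
def monitor_bytes(n_steps, n_neurons, bytes_per_value, every=1):
    """Approximate memory (in bytes) to record one variable of shape (n_neurons,)
    once every `every` simulation steps. Assumes a fixed size per value."""
    n_records = -(-n_steps // every)  # ceiling division: number of recorded samples
    return n_records * n_neurons * bytes_per_value

# 10,000 steps, 4,000 neurons, float32 membrane potential (4 bytes per value)
full = monitor_bytes(10_000, 4_000, 4)              # record at every step
sparse = monitor_bytes(10_000, 4_000, 4, every=10)  # record once every 10 steps
print(full / 1e6, "MB vs", sparse / 1e6, "MB")
```

Recording once every 10 steps cuts the storage by the same factor of 10, independently of the network size.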

In this tutorial, we show how to record (monitor) variables once every several simulation time steps.

import brainpy as bp
import brainpy.math as bm
import numpy as np

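First, define the dynamical system you want to simulate. Here we use an excitatory/inhibitory (E/I) balanced network of 4000 leaky integrate-and-fire neurons with refractory periods (bp.dyn.LifRef), of which 3200 are excitatory and 800 are inhibitory.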

class EINet(bp.DynSysGroup):
  def __init__(self):
    super().__init__()
    self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                           V_initializer=bp.init.Normal(-55., 2.))
    self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
    self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
                                     syn=bp.dyn.Expon.desc(size=4000, tau=5.),
                                     out=bp.dyn.COBA.desc(E=0.),
                                     post=self.N)
    self.I = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
                                     syn=bp.dyn.Expon.desc(size=4000, tau=10.),
                                     out=bp.dyn.COBA.desc(E=-80.),
                                     post=self.N)

  def update(self, input):
    spk = self.delay.at('I')
    self.E(spk[:3200])
    self.I(spk[3200:])
    self.delay(self.N(input))
    return self.N.spike.value
 
  def run(self, ids, inputs):  # the key function: runs several steps per call
    for i, inp in zip(ids, inputs):
      bp.share.save(i=i, t=bm.get_dt() * i)
      self.update(inp)
    return self.N.spike.value

In this example, we monitor the spikes of the neuron group every 1 ms, i.e., every 10 time steps with the default time step dt = 0.1 ms.

n_step_per_monitor = 10

brainpy.math.for_loop

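The key to using brainpy.math.for_loop for multi-step monitoring is to reshape the running indices and inputs so that the leading axis indexes the monitoring points and the second axis indexes the steps executed between two monitoring points, i.e., the shape [n_time, n_step_per_monitor, ...]. brainpy.math.for_loop iterates over the leading axis, and the model's run method iterates over the second one.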

indices = np.arange(10000).reshape(-1, n_step_per_monitor)
inputs = np.ones(indices.shape) * 20.
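To make this layout concrete, here is a plain-NumPy sketch (independent of BrainPy; the small sizes and the per-neuron input array are hypothetical) of what a single iteration receives. Each row of indices holds consecutive step indices, and for inputs that also carry a neuron axis, the step axis comes second so that zip(ids, inputs) inside run() pairs one step index with one input vector:

```python
import numpy as np

n_step_per_monitor = 10
n_neuron = 4  # hypothetical small population, for illustration only

# 100 steps grouped into 10 monitoring points: shape (n_time, n_step_per_monitor)
indices = np.arange(100).reshape(-1, n_step_per_monitor)
print(indices.shape)   # (10, 10)
print(indices[1])      # steps 10..19 run during the 2nd iteration

# Per-neuron inputs: keep the step axis second, the neuron axis last,
# i.e. shape (n_time, n_step_per_monitor, n_neuron)
inputs = np.ones((100, n_neuron)) * 20.
inputs = inputs.reshape(-1, n_step_per_monitor, n_neuron)

# What one iteration over the leading axis hands to run(ids, inputs)
ids_k, inps_k = indices[1], inputs[1]
pairs = list(zip(ids_k, inps_k))  # one (step index, input vector) pair per step
assert len(pairs) == n_step_per_monitor and pairs[0][1].shape == (n_neuron,)
```

If the step axis were placed last instead, zip(ids, inputs) would iterate over neurons rather than steps.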

Next, we write a run function in which the model runs the desired number of steps per call. This is the same run method already included in the class definition above.


class EINet(bp.DynSysGroup):
  ...  # __init__() and update() as defined above

  def run(self, ids, inputs):
    for i, inp in zip(ids, inputs):   # run the model for multiple steps in a single call
      bp.share.save(i=i, t=bm.get_dt() * i)
      self.update(inp)
    return self.N.spike.value         # only the spikes of the last step are returned (monitored)
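The control flow of this pattern can be mimicked in plain NumPy. The sketch below is not BrainPy code: for_loop_emulated and ToyModel are hypothetical stand-ins (a leaky integrator instead of the E/I network, and an uncompiled Python loop instead of bm.for_loop), used only to show that run() is called once per row of indices, executes n_step_per_monitor updates, and returns only the final state:

```python
import numpy as np

def for_loop_emulated(body, operands):
    """Hypothetical stand-in for the idea behind bm.for_loop: call `body` on each
    slice along the leading axis and stack the returned values (no compilation)."""
    return np.stack([body(*xs) for xs in zip(*operands)])

class ToyModel:
    """Hypothetical leaky integrator, used only to illustrate the pattern."""
    def __init__(self, n, dt=0.1, tau=20.):
        self.v = np.zeros(n)
        self.dt, self.tau = dt, tau
        self.n_updates = 0

    def update(self, inp):
        self.v += self.dt / self.tau * (-self.v + inp)  # one forward Euler step
        self.n_updates += 1

    def run(self, ids, inputs):
        for i, inp in zip(ids, inputs):  # the BrainPy version also calls bp.share.save here
            self.update(inp)
        return self.v.copy()             # only the state after the last step is recorded

n_step_per_monitor = 10
indices = np.arange(1000).reshape(-1, n_step_per_monitor)
inputs = np.ones(indices.shape) * 20.

model = ToyModel(n=3)
vs = for_loop_emulated(model.run, (indices, inputs))
print(vs.shape, model.n_updates)  # (100, 3) 1000
```

Because only the value returned by run() is stacked, the recording grows with the number of rows (100 here) rather than with the total number of steps (1000).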

Finally, let’s run the model with brainpy.math.for_loop.

model = EINet()
spks = bm.for_loop(model.run, (indices, inputs), progress_bar=True)
spks.shape
(1000, 4000)

The result has 1000 rows (one recording every 10 steps) and 4000 columns (one per neuron). A raster plot visualizes the recorded spikes:

bp.visualize.raster_plot(indices[:, 0], spks, show=True)
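The raster_plot call above uses the index of the first step in each block (indices[:, 0]) as the x coordinate. Since run() returns the spikes of the last step in each block, one option, assuming BrainPy's default time step of 0.1 ms (what bm.get_dt() returns unless it was changed), is to label each recorded row by the time of its last step in milliseconds:

```python
import numpy as np

dt = 0.1                 # ms; assumed to equal bm.get_dt() (BrainPy's default)
n_step_per_monitor = 10
indices = np.arange(10000).reshape(-1, n_step_per_monitor)

# run() returns the state after the *last* step of each block,
# so that step's index gives the time stamp of each recorded row.
t_record = indices[:, -1] * dt
print(t_record[:3])      # [0.9 1.9 2.9]
```

The resulting t_record could then be passed as the first argument of bp.visualize.raster_plot in place of indices[:, 0].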

brainpy.math.jit

Another, more flexible way to monitor variables is to use brainpy.math.jit.

The drawback of the multi-step approach above is that all monitored variables share the same sampling interval. Sometimes, however, we want to record spikes at every time step while recording the membrane potential only every ten time steps. For such scenarios, brainpy.math.jit is more suitable: the model is stepped from an ordinary Python loop, so plain Python code can decide what to record at each step.
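The per-variable decision can be factored into a small helper. The sketch below is plain NumPy; MultiRateRecorder is a hypothetical class (not part of BrainPy), and random arrays stand in for model.N.V.value and model.N.spike.value, which in the tutorial's loop would be read after each call to model.jit_step_run(i, inputs[i]):

```python
import numpy as np

class MultiRateRecorder:
    """Hypothetical helper (not a BrainPy API): record each variable with its
    own sampling interval, measured in simulation steps."""
    def __init__(self, intervals):
        self.intervals = intervals                   # e.g. {'spike': 1, 'V': 10}
        self.data = {name: [] for name in intervals}

    def __call__(self, i, **values):
        for name, every in self.intervals.items():
            if i % every == 0:                       # this variable is due at step i
                self.data[name].append(np.array(values[name], copy=True))

    def as_arrays(self):
        return {name: np.stack(vals) for name, vals in self.data.items()}

rng = np.random.default_rng(0)
rec = MultiRateRecorder({'spike': 1, 'V': 10})
for i in range(100):
    v = rng.normal(-55., 2., size=5)   # stand-in for model.N.V.value
    spike = v > -52.                   # stand-in for model.N.spike.value
    rec(i, spike=spike, V=v)

out = rec.as_arrays()
print(out['spike'].shape, out['V'].shape)  # (100, 5) (10, 5)
```

Keeping the intervals in one dictionary makes it easy to add further variables with their own sampling rates without touching the loop body.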

In this example, we directly call the model's jit-compiled step function, .jit_step_run(i, ...), which advances the model by one time step.

indices = np.arange(10000)
inputs = np.ones(indices.shape) * 20.

model = EINet()

spks = []
mems = []
for i in indices:
    # run the model for one step
    model.jit_step_run(i, inputs[i])

    # monitoring
    if i % n_step_per_monitor == 0:  # monitor the membrane potential every ten steps
        mems.append(model.N.V.value)
    spks.append(model.N.spike.value)  # monitor spikes at every step

spks = bm.as_numpy(spks)
mems = bm.as_numpy(mems)
bp.visualize.raster_plot(indices, spks, show=True)
bp.visualize.line_plot(indices[0::n_step_per_monitor], mems, show=True)