State Resetting



State resetting is useful when simulating and training recurrent neural networks.

Similar to state saving and loading, state resetting is implemented with two functions:

  • a local method, .reset_state(), which resets only the variables defined in the current node;

  • a global function, brainpy.reset_state(), which resets the variables of a node and of all its child nodes.
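The difference between the two functions comes down to whether the reset walks the tree of child nodes. The toy model below is a minimal sketch of that idea in plain Python. It is not BrainPy's implementation: the Node class and the reset_all helper are hypothetical stand-ins for a BrainPy node's .reset_state() and for brainpy.reset_state().

```python
# Toy model of local vs. global state resetting (hypothetical, NOT BrainPy code).
class Node:
    """A hypothetical node holding its own variables and a list of child nodes."""

    def __init__(self, name, init_values, children=()):
        self.name = name
        self._init = dict(init_values)   # initial value of each local variable
        self.vars = dict(init_values)    # current value of each local variable
        self.children = list(children)

    def reset_state(self):
        # Local reset: only this node's own variables return to their initial values.
        self.vars = dict(self._init)


def reset_all(node):
    # Global reset (hypothetical helper): reset this node, then recurse into every child.
    node.reset_state()
    for child in node.children:
        reset_all(child)


neuron = Node("N", {"V": -60.0})
synapse = Node("syn", {"g": 0.0})
net = Node("net", {"t": 0.0}, children=[neuron, synapse])

# Simulate some activity that changes every variable in the tree.
net.vars["t"] = 100.0
neuron.vars["V"] = -52.0
synapse.vars["g"] = 0.3

net.reset_state()    # local: only net.vars is restored
print(net.vars, neuron.vars, synapse.vars)
# {'t': 0.0} {'V': -52.0} {'g': 0.3}

reset_all(net)       # global: the whole tree is restored
print(net.vars, neuron.vars, synapse.vars)
# {'t': 0.0} {'V': -60.0} {'g': 0.0}
```

The recursive traversal is the design choice that matters: a node's own reset_state() knows nothing about its children, so a separate function that visits every node is needed to reset a whole network.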

Let’s define a simple example:

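import brainpy as bp
import brainpy.math as bm

class EINet(bp.DynSysGroup):
    def __init__(self):
        super().__init__()
        # 4000 LIF neurons with a refractory period: the first 3200 are excitatory,
        # the last 800 inhibitory. V starts from a normal distribution N(-55, 2).
        self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                  V_initializer=bp.init.Normal(-55., 2.))
        # Delay buffer of the spikes, read back through the named entry 'I'.
        self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
        # Excitatory projection (conductance-based, reversal potential 0 mV).
        self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
                                       syn=bp.dyn.Expon(size=4000, tau=5.),
                                       out=bp.dyn.COBA(E=0.),
                                       post=self.N)
        # Inhibitory projection (conductance-based, reversal potential -80 mV).
        self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
                                       syn=bp.dyn.Expon(size=4000, tau=10.),
                                       out=bp.dyn.COBA(E=-80.),
                                       post=self.N)

    def update(self, input):
        spk = self.delay.at('I')
        self.E(spk[:3200])         # spikes from the excitatory population
        self.I(spk[3200:])         # spikes from the inhibitory population
        self.delay(self.N(input))  # update the neurons and store their new spikes
        return self.N.spike.value

net = EINet()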

Calling brainpy.reset_state(net) resets all states in this network, including the variables of its neurons, synapses, and sub-networks. Calling net.reset_state() resets only the local variables defined directly in the network itself.

print('Before reset:', net.N.V.value)
bp.reset_state(net)
print('After reset:', net.N.V.value)
Before reset: [-57.487705 -51.873276 -56.49933  ... -58.255264 -54.304092 -54.878036]
After reset: [-52.170876 -57.16759  -53.589947 ... -55.548622 -55.703842 -53.661095]
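The values of V after bp.reset_state(net) differ from those before it. A likely explanation is that resetting re-applies each variable's initializer, and here V_initializer is the random bp.init.Normal(-55., 2.), so every reset draws a fresh sample rather than restoring the previous numbers. The sketch below illustrates that effect with a hypothetical ToyLIF class built on Python's random module; it only mimics the behavior and is not how BrainPy samples its initial values.

```python
import random

# Hypothetical toy neuron group whose reset re-draws V from Normal(mean, std),
# mimicking V_initializer=bp.init.Normal(-55., 2.) (illustrative only).
class ToyLIF:
    def __init__(self, n, mean=-55.0, std=2.0, rng=None):
        self.n, self.mean, self.std = n, mean, std
        self.rng = rng if rng is not None else random.Random()
        self.reset_state()

    def reset_state(self):
        # Every reset samples a new initial membrane potential for each neuron.
        self.V = [self.rng.gauss(self.mean, self.std) for _ in range(self.n)]


group = ToyLIF(4000, rng=random.Random(0))
before = list(group.V)
group.reset_state()
after = list(group.V)
print(before == after)   # False: a new sample was drawn

# Two groups built from generators with the same seed start from identical states.
a = ToyLIF(5, rng=random.Random(42)).V
b = ToyLIF(5, rng=random.Random(42)).V
print(a == b)            # True
```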
print('Before reset_state:', net.N.V.value)
net.reset_state()
print('After reset_state:', net.N.V.value)
Before reset_state: [-52.170876 -57.16759  -53.589947 ... -55.548622 -55.703842 -53.661095]
After reset_state: [-52.170876 -57.16759  -53.589947 ... -55.548622 -55.703842 -53.661095]

The values of V are unchanged, which shows that net.reset_state() does not reset the states of child nodes such as the neuron group net.N. To reset all states of the network, use the brainpy.reset_state() function instead.
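A common reason to reset the whole network is running several independent trials with the same model. The sketch below uses hypothetical toy classes (Leaky, Group) and helpers (reset_all, run_trial), not BrainPy's API, to show what goes wrong when only the local reset is used: the child's state carries over from one trial to the next.

```python
# Hypothetical trial loop illustrating local vs. global reset (toy code, not BrainPy).
class Leaky:
    def __init__(self, tau=20.0, v0=0.0):
        self.tau, self.v0 = tau, v0
        self.v = v0

    def reset_state(self):
        self.v = self.v0

    def step(self, inp, dt=1.0):
        # Forward-Euler update of dv/dt = (-v + inp) / tau.
        self.v += dt * (-self.v + inp) / self.tau
        return self.v


class Group:
    def __init__(self):
        self.cell = Leaky()  # child node
        self.t = 0           # local variable

    def reset_state(self):
        # Local reset only: the child `cell` is deliberately left untouched.
        self.t = 0

    def step(self, inp):
        self.t += 1
        return self.cell.step(inp)


def reset_all(group):
    # Global reset: the group itself plus its child.
    group.reset_state()
    group.cell.reset_state()


def run_trial(group, inputs, reset):
    reset(group)
    return [group.step(x) for x in inputs]


inputs = [1.0] * 10

g = Group()
trial1 = run_trial(g, inputs, reset_all)
trial2 = run_trial(g, inputs, reset_all)
print(trial1 == trial2)  # True: both trials start from v = 0

g = Group()
trial1 = run_trial(g, inputs, Group.reset_state)
trial2 = run_trial(g, inputs, Group.reset_state)
print(trial1 == trial2)  # False: trial 2 starts where trial 1 ended
```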