State Saving and Loading

Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.

import numpy as np

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Saving and loading variables

State saving and loading in BrainPy are handled by two kinds of functions: local functions and global functions.

Local functions save or load the states of the current node only. The node methods save_state() and load_state() are the local functions.

Global functions save or load the states of the current node and all of its child nodes. brainpy.save_state() and brainpy.load_state() are the global functions.

Here’s a simple example:

class SNN(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.var = bm.Variable(bm.zeros(1))
    self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None)
    self.l2 = bp.dyn.Lif(10, V_rest=0., V_reset=0., V_th=1., tau=2.0, spk_fun=bm.surrogate.Arctan())

  def update(self, x):
    return x >> self.l1 >> self.l2
net = SNN()

State saving

To extract the local variables in the net:

net.save_state()
{'SNN0.var': Array([0.], dtype=float32)}

To extract all variables under the net, including the local variables of its sub-nodes:

bp.save_state(net)
{'SNN0': {'SNN0.var': Array([0.], dtype=float32)},
 'Dense0': {},
 'Lif0': {'Lif0.V': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
  'Lif0.spike': Array([False, False, False, False, False, False, False, False, False,
         False], dtype=bool)}}
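The difference between the two calls above can be mimicked without BrainPy. In the sketch below, the Node class and the collect_global helper are hypothetical stand-ins, not BrainPy's implementation: a node's local save_state() returns only its own variables keyed as '<node>.<var>', while the global collector walks the node and all of its descendants and groups each node's local dict under the node's name, mirroring the shape of the outputs above.

```python
import numpy as np


class Node:
    """Hypothetical stand-in for a BrainPy node that owns some variables."""

    def __init__(self, name, variables=None, children=()):
        self.name = name
        self.variables = dict(variables or {})  # this node's own variables
        self.children = list(children)          # sub-nodes

    def save_state(self):
        # Local: only this node's variables, keyed as '<node>.<var>'.
        return {f'{self.name}.{k}': v.copy() for k, v in self.variables.items()}


def collect_global(root):
    """Global: the local state dict of the root and of every descendant."""
    result = {}
    queue = [root]
    while queue:
        node = queue.pop(0)  # breadth-first traversal
        result[node.name] = node.save_state()
        queue.extend(node.children)
    return result


net = Node('SNN0', {'var': np.zeros(1)},
           children=[Node('Dense0'),
                     Node('Lif0', {'V': np.zeros(10),
                                   'spike': np.zeros(10, dtype=bool)})])

print(net.save_state())           # {'SNN0.var': array([0.])}
print(list(collect_global(net)))  # ['SNN0', 'Dense0', 'Lif0']
```

A node without variables (like Dense0 here) still gets an entry, holding an empty dict.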

To save the states of a model to disk, use brainpy.checkpoints.save_pytree:

bp.checkpoints.save_pytree('a.bp', net.state_dict())
Saving checkpoint into a.bp

State loading

To retrieve the saved states from disk, use brainpy.checkpoints.load_pytree:

states = bp.checkpoints.load_pytree('a.bp')
Loading checkpoint from a.bp
states
{'SNN0': {'SNN0.var': array([0.], dtype=float32)},
 'Dense0': {},
 'Lif0': {'Lif0.V': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
  'Lif0.spike': array([False, False, False, False, False, False, False, False, False,
         False])},
 'ExponentialEuler0': {}}
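The on-disk format of save_pytree/load_pytree is not described in this tutorial, so the sketch below substitutes Python's pickle module purely to illustrate the round trip of a nested state dict; the file name a.pkl is hypothetical. Note that the loaded output above shows plain NumPy array values (lowercase array) rather than BrainPy Array objects, so the loaded values still have to be assigned back into the model's variables.

```python
import os
import pickle
import tempfile

import numpy as np

# A nested state dict shaped like the output of bp.save_state(net).
states = {
    'SNN0': {'SNN0.var': np.zeros(1, dtype=np.float32)},
    'Dense0': {},
    'Lif0': {'Lif0.V': np.zeros(10, dtype=np.float32),
             'Lif0.spike': np.zeros(10, dtype=bool)},
}

path = os.path.join(tempfile.mkdtemp(), 'a.pkl')  # hypothetical file name

# Save: serialize the whole nested dict into a single file.
with open(path, 'wb') as f:
    pickle.dump(states, f)

# Load: read it back; the grouping by node name is preserved.
with open(path, 'rb') as f:
    loaded = pickle.load(f)

assert loaded.keys() == states.keys()
assert np.array_equal(loaded['Lif0']['Lif0.V'], states['Lif0']['Lif0.V'])
```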

After loading the states into memory, assign them to the corresponding variables with the brainpy.load_state() function:

bp.load_state(net, states)
StateLoadResult(missing_keys=[], unexpected_keys=[])

  • brainpy.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True) saves target, a state dict object, to filename, the path where the checkpoint files are stored. The optional overwrite argument decides whether to overwrite existing checkpoint files when a checkpoint for the current step or a later one already exists. If an async_manager is provided, the save does not block the main thread; this is only suitable for a single host, and an ongoing save still prevents new saves from starting so that the overwrite logic stays correct. verbose controls whether information about the operation is printed.

  • brainpy.checkpoints.load_pytree(filename: str, parallel: bool = True) restores data from the checkpoint file, or the directory of checkpoints, given by filename. When parallel is True, it tries to load seekable checkpoints in parallel for speed. It returns the target restored from the checkpoint file. If no step is specified and no checkpoint files exist, the input target is returned unchanged. The same happens when the given file path does not exist, which matches the case of a directory path that has not been created yet.

  • brainpy.save_state(target) retrieves the entire state of the target module, including all of its child nodes, and returns it as a dictionary.

  • brainpy.load_state(target, state_dict) loads the variables stored in state_dict into the target module and all of its child modules. state_dict is a dictionary containing the parameters and other states to be loaded. When executed, the function returns a StateLoadResult, a named tuple with two fields:

    • missing_keys: A list of keys that are present in the module but missing in the provided state_dict.

    • unexpected_keys: A list of keys found in the state_dict that don’t correspond to any part of the current module.
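To see where the two fields of StateLoadResult can come from, the plain-NumPy sketch below dispatches each node's sub-dict by node name and records keys that exist on only one side. It is not BrainPy's implementation: load_nested and the local StateLoadResult class are hypothetical stand-ins. In this sketch an empty sub-dict, such as the 'ExponentialEuler0': {} entry in the loaded checkpoint above, contributes no keys, which is consistent with the empty lists printed by bp.load_state.

```python
from typing import Dict, List, NamedTuple

import numpy as np


class StateLoadResult(NamedTuple):
    """Local stand-in with the same two fields as described above."""
    missing_keys: List[str]
    unexpected_keys: List[str]


def load_nested(model: Dict[str, Dict[str, np.ndarray]],
                state_dict: Dict[str, Dict[str, np.ndarray]]) -> StateLoadResult:
    """Hypothetical global loader: match node names, then variable keys."""
    missing, unexpected = [], []
    for node_name, node_vars in model.items():
        saved = state_dict.get(node_name, {})
        for key, var in node_vars.items():
            if key in saved:
                var[...] = saved[key]  # write in place, keep the same array object
            else:
                missing.append(key)
        unexpected.extend(k for k in saved if k not in node_vars)
    # Variables stored under node names that the model does not have.
    for node_name, saved in state_dict.items():
        if node_name not in model:
            unexpected.extend(saved)
    return StateLoadResult(missing, unexpected)


model = {'SNN0': {'SNN0.var': np.zeros(1)}, 'Lif0': {'Lif0.V': np.zeros(3)}}
checkpoint = {'SNN0': {'SNN0.var': np.ones(1)},
              'Lif0': {'Lif0.V': np.full(3, 0.5)},
              'ExponentialEuler0': {}}  # empty entry: contributes no keys

print(load_nested(model, checkpoint))
# StateLoadResult(missing_keys=[], unexpected_keys=[])
```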

A simple example

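Here is an example of saving and loading a model in BrainPy with the bp.checkpoints.save_pytree and bp.checkpoints.load_pytree functions. A small spiking network is trained on random data, and a checkpoint is saved whenever the training loss improves.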

bm.set_dt(1.)

class SNN(bp.DynamicalSystem):
  def __init__(self, num_in, num_rec, num_out):
    super().__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out

    # neuron groups
    self.r = bp.dyn.Lif(num_rec, tau=10., V_reset=0., V_rest=0., V_th=1.)
    self.o = bp.dyn.Integrator(num_out, tau=5.)

    # synapse: i->r
    self.i2r = bp.Sequential(
        comm=bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=20.)),
        syn=bp.dyn.Expon(num_rec, tau=10.),
    )

    # synapse: r->o
    self.r2o = bp.Sequential(
        comm=bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=20.)),
        syn=bp.dyn.Expon(num_out, tau=10.),
    )

  def update(self, spike):
    return spike >> self.i2r >> self.r >> self.r2o >> self.o
num_in = 100
num_rec = 10
with bm.training_environment():
    # our task is a two-class classification task
    net = SNN(num_in, num_rec, 2)  


# This toy task classifies random spiking data into two classes.
num_step = 100
num_sample = 256
freq = 10  # Hz
mask = bm.random.rand(num_step, num_sample, num_in)
x_data = bm.zeros((num_step, num_sample, num_in))
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)
indices = bm.arange(num_step)
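The input spikes above are drawn independently at every time step: an input fires in a step of length dt (in ms) with probability freq * dt / 1000. With freq = 10 Hz and dt = 1 ms, the per-step probability is 0.01, so over 100 steps each input fires about once on average. The NumPy sketch below checks this arithmetic; it uses numpy.random.default_rng in place of bm.random, and the seed is an arbitrary choice for reproducibility.

```python
import numpy as np

num_step, num_sample, num_in = 100, 256, 100
freq, dt = 10.0, 1.0           # Hz and ms, as in the tutorial
p = freq * dt / 1000.0         # per-step firing probability

rng = np.random.default_rng(0)
mask = rng.random((num_step, num_sample, num_in))
x_data = (mask < p).astype(np.float32)  # 1.0 where a spike occurs

# Empirical rate in Hz: mean spikes per step divided by the step length in seconds.
rate_hz = x_data.mean() / (dt / 1000.0)
print(p, round(float(rate_hz), 1))  # p is 0.01; rate_hz should be close to 10
```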


# training process
class Trainer:
  def __init__(self, net, opt):
    self.net = net
    self.opt = opt
    opt.register_train_vars(net.train_vars().unique())
    self.f_grad = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)
  
  @bm.cls_jit(inline=True)
  def f_loss(self):
    self.net.reset(num_sample)
    outs = bm.for_loop(self.net.step_run, (indices, x_data))
    return bp.losses.cross_entropy_loss(bm.max(outs, axis=0), y_data)

  @bm.cls_jit
  def f_train(self):
    grads, loss = self.f_grad()
    self.opt.update(grads)
    return loss


trainer = Trainer(net=net, opt=bp.optim.Adam(lr=4e-3))

loss = np.inf
for i in range(10):
  l = trainer.f_train()
  if l < loss:
    loss = l
    states = {'net': bp.save_state(net), # save the state dict of the network in the checkpoint
              'epoch_i': i,
              'train_loss': loss}
    bp.checkpoints.save_pytree('snn.bp', states, verbose=False) # save the checkpoint
    print(f'Epoch {i}, loss {loss}')
Epoch 0, loss 1.0733333826065063
Epoch 1, loss 0.9526105523109436
Epoch 2, loss 0.8582525253295898
Epoch 3, loss 0.7843770384788513
Epoch 4, loss 0.7399720549583435
Epoch 5, loss 0.7254235744476318
Epoch 9, loss 0.7122021913528442
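Epochs 6 to 8 are absent from the log because the loop only saves (and prints) when the loss improves on the best value seen so far. The sketch below isolates this keep-best pattern in plain Python; the loss values are hypothetical, and pickle stands in for bp.checkpoints.save_pytree.

```python
import os
import pickle
import tempfile

losses = [1.07, 0.95, 0.86, 0.78, 0.74, 0.73, 0.75, 0.74, 0.73, 0.71]  # hypothetical
path = os.path.join(tempfile.mkdtemp(), 'snn.pkl')  # hypothetical file name

best = float('inf')
saved_epochs = []
for epoch, loss in enumerate(losses):
    if loss < best:  # checkpoint only on strict improvement
        best = loss
        with open(path, 'wb') as f:  # overwrite the single "best" checkpoint
            pickle.dump({'epoch_i': epoch, 'train_loss': loss}, f)
        saved_epochs.append(epoch)

with open(path, 'rb') as f:
    ckpt = pickle.load(f)

print(saved_epochs)  # [0, 1, 2, 3, 4, 5, 9]
print(ckpt)          # {'epoch_i': 9, 'train_loss': 0.71}
```

Because the file is overwritten on each improvement, loading it afterwards always returns the best epoch, not the last one.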
# model loading
state_dict = bp.checkpoints.load_pytree('snn.bp') # load the state dict
bp.load_state(net, state_dict['net']) # unpack the state dict and load it into the network
Loading checkpoint from snn.bp
StateLoadResult(missing_keys=[], unexpected_keys=[])

Note

By default, model variables are retrieved by their relative paths. Relative-path retrieval often results in duplicate variables in the returned ArrayCollector, so missing keys may be reported when loading the variables.

Custom saving and loading

You can easily write your own saving and loading functions.

To customize saving and loading, override the save_state and load_state methods.

Here is an example:

class YourClass(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.a = 1                               # plain Python value
    self.b = bm.random.rand(10)              # array, not a Variable
    self.c = bm.Variable(bm.random.rand(3))  # a single Variable
    self.d = bm.var_list([bm.Variable(bm.random.rand(3)),
                          bm.Variable(bm.random.rand(3))])  # a list of Variables

  def save_state(self, **kwargs) -> dict:
    # store raw values rather than the Variable objects themselves
    state_dict = {'a': self.a,
                  'b': self.b,
                  'c': self.c.value}
    for i, elem in enumerate(self.d):
      state_dict[f'd{i}'] = elem.value
    return state_dict

  def load_state(self, state_dict, **kwargs):
    self.a = state_dict['a']
    self.b = bm.asarray(state_dict['b'])
    # assign to .value so that self.c remains a Variable
    self.c.value = bm.asarray(state_dict['c'])
    for i in range(len(self.d)):
      self.d[i].value = bm.asarray(state_dict[f'd{i}'])

The default implementations of these two methods work as follows:

  • save_state(self) saves the state of the object's variables and returns a dictionary whose keys are the variable names and whose values are the variables' contents.

  • load_state(self, state_dict: Dict) loads the state of the object's variables from the provided dictionary state_dict. It first gets the object's current variables, then finds the keys shared by state_dict and those variables. For each shared key, it updates the object's variable with the value from state_dict. Finally, it returns a tuple containing two lists:

    • unexpected_keys: Keys in state_dict that were not found in the object’s variables.

    • missing_keys: Keys that are in the object’s variables but were not found in state_dict.
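The default load_state steps listed above amount to a few lines of set arithmetic. The following is a plain-Python sketch of those steps, not BrainPy's source: default_load_state is a hypothetical name, and writing into existing NumPy arrays in place plays the role of assigning to a variable's .value. The lists are sorted only to make the output deterministic.

```python
from typing import Dict, List, Tuple

import numpy as np


def default_load_state(variables: Dict[str, np.ndarray],
                       state_dict: Dict[str, np.ndarray]) -> Tuple[List[str], List[str]]:
    keys_saved = set(state_dict)
    keys_model = set(variables)
    # Update only the keys present on both sides, in place.
    for key in keys_model & keys_saved:
        variables[key][...] = state_dict[key]
    unexpected_keys = sorted(keys_saved - keys_model)  # in state_dict, not in the object
    missing_keys = sorted(keys_model - keys_saved)     # in the object, not in state_dict
    return unexpected_keys, missing_keys


variables = {'c': np.zeros(3), 'd0': np.zeros(3)}
state = {'c': np.ones(3), 'e': np.ones(2)}
print(default_load_state(variables, state))  # (['e'], ['d0'])
```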