Saving and Loading

@Chaoming Wang @Sichao He

Saving and loading the variables of a model is essential in brain dynamics programming. This tutorial describes how to save and load model variables in BrainPy.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Saving and loading variables

Model saving and loading in BrainPy are implemented with the bp.checkpoints.save_pytree and bp.checkpoints.load_pytree functions. The .state_dict() and .load_state_dict() methods extract the state of a model and load it back.

Here’s a simple example:

import os

class SNN(bp.DynamicalSystem):
  def __init__(self, tau):
    super().__init__()
    self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None)
    self.l2 = bp.dyn.Lif(10, V_rest=0., V_reset=0., V_th=1., tau=tau, spk_fun=bm.surrogate.arctan)

  def update(self, x):
    return x >> self.l1 >> self.l2


net = SNN(2.0)
optimizer = bp.optim.Adam(lr=1e-3, train_vars=net.train_vars().unique())
out_dir = './checkpoints'  # directory where the checkpoint file is written
os.makedirs(out_dir, exist_ok=True)

# model saving: keep the checkpoint with the best test accuracy
max_test_acc = 0.
for epoch_i in range(15):
  """
  training process...
  (trains the network for one epoch and computes train_acc and test_acc)
  """
  if max_test_acc < test_acc:
    max_test_acc = test_acc
    states = {
      'net': net.state_dict(),  # save the state dict of the network in the checkpoint
      'optimizer': optimizer.state_dict(),
      'epoch_i': epoch_i,
      'train_acc': train_acc,
      'test_acc': test_acc,
    }
    bp.checkpoints.save_pytree(os.path.join(out_dir, 'mnist-lif.bp'), states)  # save the checkpoint

# model loading
state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'mnist-lif.bp'))  # load the checkpoint
net.load_state_dict(state_dict['net'])  # load the network's part of the checkpoint into the network
  • bp.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True) saves target to the file at filename, which is the path where the checkpoint is stored. target is the PyTree to store, such as the dictionary of state dicts built in the example above. The optional overwrite argument controls whether existing checkpoint files are overwritten when a checkpoint for the current step or a later one already exists. If you provide an async_manager, the save does not block the main thread. This is only suitable for a single host, and an ongoing save still blocks new saves so that the overwrite logic stays correct. Finally, verbose controls whether information about the operation is printed.

  • bp.checkpoints.load_pytree(filename: str, parallel: bool = True) restores data from the checkpoint file, or the directory containing multiple checkpoints, given by filename. If parallel is True, the function tries to load seekable checkpoints in parallel for speed. It returns the target restored from the checkpoint file. If no step is specified and no checkpoint files are available, the input target is returned unchanged. Likewise, if the given file path does not exist, the original target is returned. This mirrors the case where a directory path is given but the directory has not been created yet.

  • .state_dict() function retrieves the entire state of the module and returns it as a dictionary.

  • .load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compatible: str = 'v2') copies parameters and buffers from state_dict into the current module and all of its child modules. state_dict is a dictionary containing the parameters and persistent buffers to load. The optional warn parameter (default True) emits warnings in two cases. The first is when state_dict contains keys that do not match the module's structure (unexpected keys). The second is when it lacks keys that exist in the module (missing keys). The method returns a StateLoadResult, a named tuple with two fields:

    • missing_keys: A list of keys that are present in the module but missing in the provided state_dict.

    • unexpected_keys: A list of keys found in the state_dict that don’t correspond to any part of the current module.
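The following plain-Python sketch makes the missing_keys / unexpected_keys bookkeeping concrete without needing BrainPy. It is not BrainPy's implementation. The function load_nested_state, the stand-in StateLoadResult tuple, and the 'module.key' naming of reported keys are all hypothetical. Each sub-module's state is modeled as an ordinary dict, mimicking the nested layout that .state_dict() produces.

```python
from collections import namedtuple
import warnings

# Hypothetical stand-in for the named tuple returned by load_state_dict().
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])


def load_nested_state(module_states, state_dict, warn=True):
  """Hypothetical sketch: copy matching entries of `state_dict` into `module_states`.

  `module_states` maps each sub-module name to a dict of its variables.
  """
  missing, unexpected = [], []
  for name, target in module_states.items():
    if name not in state_dict:
      missing.append(name)  # the whole sub-module is absent from the checkpoint
      continue
    source = state_dict[name]
    for key in target.keys() & source.keys():
      target[key] = source[key]  # overwrite only the keys both sides share
    missing.extend(f'{name}.{k}' for k in target.keys() - source.keys())
    unexpected.extend(f'{name}.{k}' for k in source.keys() - target.keys())
  # sub-modules stored in the checkpoint that the model does not have
  unexpected.extend(k for k in state_dict if k not in module_states)
  if warn and (missing or unexpected):
    warnings.warn(f'missing keys: {missing}, unexpected keys: {unexpected}')
  return StateLoadResult(sorted(missing), sorted(unexpected))


# usage: the checkpoint lacks l2.spike and carries an extra sub-module l3
model = {'l1': {'W': 0.0}, 'l2': {'V': 0.0, 'spike': 0}}
ckpt = {'l1': {'W': 1.5}, 'l2': {'V': -0.2}, 'l3': {'W': 9.9}}
res = load_nested_state(model, ckpt, warn=False)
print(res)
```

Only keys present on both sides are copied. Everything else is reported instead of raising, which is why loading can succeed partially.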

Note

By default, model variables are retrieved by their relative paths. Retrieval by relative path usually produces duplicate variables in the returned ArrayCollector, so there will always be missing keys when the variables are loaded.

Custom saving and loading

You can easily write your own saving and loading logic.

To customize how an object is saved and loaded, override its __save_state__ and __load_state__ methods.

Here is an example:

class YourClass(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.a = 1                               # plain Python value
    self.b = bm.random.rand(10)              # plain array (not a Variable)
    self.c = bm.Variable(bm.random.rand(3))  # a single Variable
    self.d = bm.var_list([bm.Variable(bm.random.rand(3)),
                          bm.Variable(bm.random.rand(3))])  # a list of Variables

  def __save_state__(self) -> dict:
    # store plain values under explicit keys
    state_dict = {'a': self.a,
                  'b': self.b,
                  'c': self.c.value}
    for i, elem in enumerate(self.d):
      state_dict[f'd{i}'] = elem.value

    return state_dict

  def __load_state__(self, state_dict):
    self.a = state_dict['a']
    self.b = bm.asarray(state_dict['b'])
    # update Variables in place so that they remain Variables
    self.c.value = bm.asarray(state_dict['c'])

    for i in range(len(self.d)):
      self.d[i].value = bm.asarray(state_dict[f'd{i}'])
  • __save_state__(self) saves the state of the object's variables. It returns a dictionary whose keys are the variable names and whose values are the variables' contents.

  • __load_state__(self, state_dict: Dict) loads the state of the object's variables from the provided dictionary state_dict. In the default implementation, it first collects the object's current variables. It then finds the keys shared by state_dict and those variables, and updates each shared variable with the corresponding value from state_dict. Finally, it returns a tuple containing two lists:

    • unexpected_keys: Keys in state_dict that were not found in the object’s variables.

    • missing_keys: Keys that are in the object’s variables but were not found in state_dict.
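The default loading steps above can be sketched in plain Python. FakeVariable (a stand-in for bm.Variable) and default_load_state are hypothetical names used only for illustration. The sketch assumes the object's current variables are given as a name-to-variable dict.

```python
class FakeVariable:
  """Hypothetical stand-in for bm.Variable: it only holds a mutable `.value`."""

  def __init__(self, value):
    self.value = value


def default_load_state(variables, state_dict):
  """Hypothetical sketch of the default __load_state__ steps.

  Returns (unexpected_keys, missing_keys), in that order.
  """
  given, owned = set(state_dict), set(variables)
  for key in given & owned:
    variables[key].value = state_dict[key]  # update in place; the object is kept
  unexpected_keys = sorted(given - owned)   # in state_dict, not on the object
  missing_keys = sorted(owned - given)      # on the object, not in state_dict
  return unexpected_keys, missing_keys


# usage: 'z' is unknown to the object and 'd0' is absent from the dict
vars_ = {'c': FakeVariable([0., 0., 0.]), 'd0': FakeVariable([0., 0., 0.])}
c_before = vars_['c']
unexpected, missing = default_load_state(vars_, {'c': [1., 2., 3.], 'z': [7.]})
print(unexpected, missing)
```

Updating .value in place rather than rebinding the attribute keeps the variable object itself. Any other code holding a reference to that variable therefore sees the loaded values.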