Saving and Loading
Being able to save and load the variables of a model is essential in brain dynamics programming. In this tutorial we describe how to save/load the variables in a model.
import os

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
Saving and loading variables
Model saving and loading in BrainPy are implemented with the bp.checkpoints.save_pytree and bp.checkpoints.load_pytree functions. The state of a model is retrieved with .state_dict() and restored with .load_state_dict().
Here’s a simple example:
class SNN(bp.DynamicalSystem):
    def __init__(self, tau):
        super().__init__()
        self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None)
        # pass the constructor argument through instead of a hard-coded value
        self.l2 = bp.dyn.Lif(10, V_rest=0., V_reset=0., V_th=1., tau=tau, spk_fun=bm.surrogate.arctan)

    def update(self, x):
        return x >> self.l1 >> self.l2

net = SNN(2.0)
# model saving
# (optimizer, out_dir, max_test_acc, train_acc and test_acc come from the surrounding training script)
for epoch_i in range(15):
    # training process...

    if max_test_acc < test_acc:
        max_test_acc = test_acc
        states = {
            'net': net.state_dict(),  # save the state dict of the network in the checkpoint
            'optimizer': optimizer.state_dict(),
            'epoch_i': epoch_i,
            'train_acc': train_acc,
            'test_acc': test_acc,
        }
        bp.checkpoints.save_pytree(os.path.join(out_dir, 'mnist-lif.bp'), states)  # save the checkpoint

# model loading
state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'mnist-lif.bp'))  # load the state dict
net.load_state_dict(state_dict['net'])  # unpack the state dict and load it into the network
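The loop above leaves the training step and several names undefined, so it cannot be run on its own. The following self-contained sketch isolates the "save only when the test accuracy improves" pattern. It is not BrainPy code: Python's pickle stands in for bp.checkpoints.save_pytree and bp.checkpoints.load_pytree, and the helper names save_checkpoint and load_checkpoint, the accuracy values, and the placeholder network state are all made up for illustration.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, states):
    # Hypothetical stand-in for bp.checkpoints.save_pytree (uses pickle here).
    with open(path, 'wb') as f:
        pickle.dump(states, f)


def load_checkpoint(path):
    # Hypothetical stand-in for bp.checkpoints.load_pytree.
    with open(path, 'rb') as f:
        return pickle.load(f)


out_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(out_dir, 'mnist-lif.bp')

# Made-up per-epoch test accuracies, in place of a real training loop.
fake_test_acc = [0.61, 0.72, 0.70, 0.81, 0.79]

max_test_acc = 0.0
for epoch_i, test_acc in enumerate(fake_test_acc):
    if max_test_acc < test_acc:
        max_test_acc = test_acc
        states = {
            'net': {'l1.W': [0.1 * epoch_i]},  # placeholder for net.state_dict()
            'epoch_i': epoch_i,
            'test_acc': test_acc,
        }
        save_checkpoint(ckpt_path, states)  # overwrites the previous best

# Only the best epoch survives on disk.
restored = load_checkpoint(ckpt_path)
print(restored['epoch_i'], restored['test_acc'])
```

Because the file is rewritten only when the accuracy improves, the checkpoint on disk always corresponds to the best epoch seen so far rather than the most recent one.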
bp.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True) requires a filename, which is the path where the checkpoint file will be stored, and a target, which is a state dict object. The optional overwrite argument decides whether to overwrite existing checkpoint files if a checkpoint for the current step or a later one already exists. If you provide an async_manager, the save operation does not block the main thread, but this is only suitable for a single host, and any ongoing save still prevents new saves so that the overwrite logic remains correct. Finally, the verbose argument specifies whether information about the operation is printed.

bp.checkpoints.load_pytree(filename: str, parallel: bool = True) restores data from a given checkpoint file, or from a directory containing multiple checkpoints, specified by the filename argument. If parallel is True, the function attempts to load seekable checkpoints in parallel for quicker results. The function returns the restored target from the checkpoint file. If no step is specified and there are no checkpoint files available, the function returns the input target unchanged. If the specified file path does not exist, the function also returns the original target. This mirrors the case where a directory path is given but the directory has not been created yet.

.state_dict() retrieves the entire state of the module and returns it as a dictionary.

.load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compatible: str = 'v2') imports parameters and buffers from the provided state_dict into the current module and all its child modules. The state_dict argument is a dictionary containing the parameters and persistent buffers to be loaded. The optional warn argument (default True) generates warnings if the provided state_dict contains keys that do not match the current module's structure (unexpected keys) or lacks keys that exist in the module (missing keys). The function returns a StateLoadResult, a named tuple with two fields:

missing_keys: A list of keys that are present in the module but missing in the provided state_dict.

unexpected_keys: A list of keys found in the state_dict that do not correspond to any part of the current module.
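One way to act on the returned missing_keys and unexpected_keys is to turn a partial match into a hard error. The sketch below is plain Python and does not call BrainPy: StateLoadResult is re-created locally with collections.namedtuple using the two field names given above, and check_load_result with its strict flag is a hypothetical helper, not part of the library.

```python
from collections import namedtuple

# Local stand-in with the two fields described above (assumption: the real
# StateLoadResult is defined by BrainPy and may differ in detail).
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])


def check_load_result(result, strict=True):
    """Return True on an exact match; otherwise raise (strict) or return False."""
    problems = []
    if result.missing_keys:
        problems.append(f'missing keys: {sorted(result.missing_keys)}')
    if result.unexpected_keys:
        problems.append(f'unexpected keys: {sorted(result.unexpected_keys)}')
    if problems and strict:
        raise KeyError('; '.join(problems))
    return not problems


ok = check_load_result(StateLoadResult(missing_keys=[], unexpected_keys=[]))
partial = check_load_result(
    StateLoadResult(missing_keys=['l2.V'], unexpected_keys=[]), strict=False)
print(ok, partial)
```

Since the note that follows says missing keys can be reported even for a matching model, the non-strict mode may be the more practical choice in real use.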
Note
By default, the model variables are retrieved by relative path. Relative-path retrieval usually results in duplicate variables in the returned ArrayCollector. Therefore, there will always be missing keys when loading the variables.
Custom saving and loading
You can easily define your own saving and loading behavior. To customize it, override the __save_state__ and __load_state__ methods.
Here is an example:
class YourClass(bp.DynamicalSystem):
    def __init__(self):
        super().__init__()
        self.a = 1
        self.b = bm.random.rand(10)
        self.c = bm.Variable(bm.random.rand(3))
        self.d = bm.var_list([bm.Variable(bm.random.rand(3)),
                              bm.Variable(bm.random.rand(3))])

    def __save_state__(self) -> dict:
        state_dict = {'a': self.a,
                      'b': self.b,
                      'c': self.c}
        # store each element of the variable list under its own key
        for i, elem in enumerate(self.d):
            state_dict[f'd{i}'] = elem.value
        return state_dict

    def __load_state__(self, state_dict):
        self.a = state_dict['a']
        self.b = bm.asarray(state_dict['b'])
        self.c = bm.asarray(state_dict['c'])
        # restore each element of the variable list in place
        for i in range(len(self.d)):
            self.d[i].value = bm.asarray(state_dict[f'd{i}'])
__save_state__(self) saves the state of the object's variables and returns a dictionary whose keys are the names of the variables and whose values are the variables' contents.

__load_state__(self, state_dict: Dict) loads the state of the object's variables from the provided dictionary (state_dict). It first gets the current variables of the object. Then it determines the intersection of the keys in state_dict and the object's variables. For each key in the intersection, it updates the object's variable with the value from state_dict. Finally, it returns a tuple containing two lists:

unexpected_keys: Keys in state_dict that were not found in the object's variables.

missing_keys: Keys that are in the object's variables but were not found in state_dict.
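The key-matching steps described for __load_state__ can be sketched with plain dictionaries. This is an illustration of the described logic only, not BrainPy's implementation: the function name load_matching_keys is hypothetical, and ordinary Python values stand in for BrainPy variables.

```python
def load_matching_keys(variables, state_dict):
    """Update `variables` in place for keys shared with `state_dict`.

    Returns (unexpected_keys, missing_keys), each as a sorted list.
    """
    current_keys = set(variables)
    given_keys = set(state_dict)

    # Only keys present on both sides are updated.
    for key in current_keys & given_keys:
        variables[key] = state_dict[key]

    unexpected_keys = sorted(given_keys - current_keys)  # in state_dict only
    missing_keys = sorted(current_keys - given_keys)     # in the object only
    return unexpected_keys, missing_keys


variables = {'a': 0, 'b': 0, 'c': 0}
state_dict = {'a': 1, 'b': 2, 'z': 9}
unexpected, missing = load_matching_keys(variables, state_dict)
print(variables)
print(unexpected, missing)
```

Here 'a' and 'b' are updated, 'z' is reported as unexpected because the object has no such variable, and 'c' is reported as missing and keeps its old value.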