Saving and Loading#

@Chaoming Wang

Being able to save and load the variables of a model is essential in brain dynamics programming. This tutorial describes how to save and load the variables of a model.

import brainpy as bp

bp.math.set_platform('cpu')

Saving and loading variables#

Model saving and loading in BrainPy are implemented with the .save_states() and .load_states() methods.

BrainPy supports saving and loading model variables in several common file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)
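The file name's extension indicates which format is used. As a rough illustration of that idea, and not BrainPy's actual implementation, the sketch below maps an extension to a format label. The helper `guess_format`, the table `SUPPORTED_FORMATS`, and the labels are hypothetical names introduced only for this example.

```python
import os

# Hypothetical lookup table: extension -> format label (labels are illustrative only).
SUPPORTED_FORMATS = {
    '.h5': 'hdf5',
    '.hdf5': 'hdf5',
    '.npz': 'npz',
    '.pkl': 'pickle',
    '.mat': 'matlab',
}


def guess_format(filename):
    """Return the format label implied by the extension of `filename`."""
    ext = os.path.splitext(filename)[1].lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f'Unsupported file extension: {ext!r}')
    return SUPPORTED_FORMATS[ext]


print(guess_format('./data/net.h5'))   # hdf5
print(guess_format('./data/net.npz'))  # npz
```

Rejecting unknown extensions early, as this sketch does, avoids writing a file that cannot be read back later.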

Here’s a simple example:

class EINet(bp.dyn.Network):
    def __init__(self, num_exc=3200, num_inh=800, method='exp_auto'):
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
        E = bp.models.LIF(num_exc, **pars, method=method)
        I = bp.models.LIF(num_inh, **pars, method=method)
        E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
        I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.

        # synapses
        E2E = bp.models.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        E2I = bp.models.ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02),
                                E=0., g_max=0.6, tau=5., method=method)
        I2E = bp.models.ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)
        I2I = bp.models.ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02),
                                E=-80., g_max=6.7, tau=10., method=method)

        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
        
        
net = EINet()
import os

# create the output directory if it does not exist
os.makedirs('./data', exist_ok=True)

# model saving
net.save_states('./data/net.h5')

# model loading
net.load_states('./data/net.h5')
  • .save_states(filename, all_vars=None) receives a string that specifies the output file name. If all_vars is not provided, BrainPy retrieves all variables in the model through their relative paths.

  • .load_states(filename, verbose, check_missing) receives several arguments. The first is a string giving the name of the saved file to load. The second, “verbose”, specifies whether to report the loading progress. The last, “check_missing”, makes BrainPy warn about model variables that are missing from the file.

# model loading with warning and checking

net.load_states('./data/net.h5', verbose=True)
WARNING:brainpy.base.io:There are variable states missed in ./data/net.h5. The missed variables are: ['ExpCOBA0.pre.V', 'ExpCOBA0.pre.input', 'ExpCOBA0.pre.refractory', 'ExpCOBA0.pre.spike', 'ExpCOBA0.pre.t_last_spike', 'ExpCOBA1.pre.V', 'ExpCOBA1.pre.input', 'ExpCOBA1.pre.refractory', 'ExpCOBA1.pre.spike', 'ExpCOBA1.pre.t_last_spike', 'ExpCOBA1.post.V', 'ExpCOBA1.post.input', 'ExpCOBA1.post.refractory', 'ExpCOBA1.post.spike', 'ExpCOBA1.post.t_last_spike', 'ExpCOBA2.pre.V', 'ExpCOBA2.pre.input', 'ExpCOBA2.pre.refractory', 'ExpCOBA2.pre.spike', 'ExpCOBA2.pre.t_last_spike', 'ExpCOBA2.post.V', 'ExpCOBA2.post.input', 'ExpCOBA2.post.refractory', 'ExpCOBA2.post.spike', 'ExpCOBA2.post.t_last_spike', 'ExpCOBA3.pre.V', 'ExpCOBA3.pre.input', 'ExpCOBA3.pre.refractory', 'ExpCOBA3.pre.spike', 'ExpCOBA3.pre.t_last_spike'].
Loading E.V ...
Loading E.input ...
Loading E.refractory ...
Loading E.spike ...
Loading E.t_last_spike ...
Loading ExpCOBA0.g ...
Loading ExpCOBA0.pre_spike.data ...
Loading ExpCOBA0.pre_spike.in_idx ...
Loading ExpCOBA0.pre_spike.out_idx ...
Loading ExpCOBA1.g ...
Loading ExpCOBA1.pre_spike.data ...
Loading ExpCOBA1.pre_spike.in_idx ...
Loading ExpCOBA1.pre_spike.out_idx ...
Loading ExpCOBA2.g ...
Loading ExpCOBA2.pre_spike.data ...
Loading ExpCOBA2.pre_spike.in_idx ...
Loading ExpCOBA2.pre_spike.out_idx ...
Loading ExpCOBA3.g ...
Loading ExpCOBA3.pre_spike.data ...
Loading ExpCOBA3.pre_spike.in_idx ...
Loading ExpCOBA3.pre_spike.out_idx ...
Loading I.V ...
Loading I.input ...
Loading I.refractory ...
Loading I.spike ...
Loading I.t_last_spike ...
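The roles of “verbose” and “check_missing” can be mimicked with plain dictionaries. The following is a minimal sketch under stated assumptions, not BrainPy's loader: `load_into` is a hypothetical helper, the model variables and the file contents are ordinary dicts, and the toy keys and values are made up for illustration.

```python
import logging

logger = logging.getLogger(__name__)


def load_into(model_vars, stored, verbose=False, check_missing=False):
    """Copy values from `stored` into `model_vars` for every key present in both.

    Returns the sorted list of model keys that were not found in `stored`.
    """
    missing = sorted(key for key in model_vars if key not in stored)
    if check_missing and missing:
        # report model variables that the file does not contain
        logger.warning('Variable states missing from the file: %s', missing)
    for key in sorted(model_vars):
        if key in stored:
            if verbose:
                print(f'Loading {key} ...')
            model_vars[key] = stored[key]
    return missing


# Toy data (hypothetical): the "file" lacks E.spike.
model_vars = {'E.V': [0.0, 0.0], 'E.spike': [False, False], 'I.V': [0.0]}
stored = {'E.V': [-55.0, -54.0], 'I.V': [-56.0]}
missing = load_into(model_vars, stored, verbose=True, check_missing=True)
```

Here the variables found in `stored` are overwritten, while `E.spike` keeps its current value and is reported as missing.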

Note

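By default, the model variables are retrieved by their relative paths. Relative-path retrieval usually yields duplicate variables in the returned TensorCollector, because the same variable can be reached through several paths (for example, E.V and ExpCOBA0.pre.V). Each variable is stored only once, so the duplicate paths are reported as missing keys when the variables are loaded.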
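A toy illustration of how duplicate paths lead to "missing" keys is given below. It uses plain Python lists in place of BrainPy variables, and the helper `unique_by_identity` is a hypothetical name for this sketch. It is not the TensorCollector implementation.

```python
# Toy stand-ins: one object reachable under two different paths.
E_V = [-55.0, -54.0]              # plays the role of the variable E.V
collected = {
    'E.V': E_V,
    'ExpCOBA0.pre.V': E_V,        # the same object, reached through the synapse
    'ExpCOBA0.g': [0.0, 0.0],
}


def unique_by_identity(variables):
    """Keep only the first path under which each object appears."""
    seen = set()
    unique = {}
    for key, value in variables.items():
        if id(value) not in seen:
            seen.add(id(value))
            unique[key] = value
    return unique


saved = unique_by_identity(collected)
missing = [key for key in collected if key not in saved]

print(sorted(saved))  # ['E.V', 'ExpCOBA0.g']
print(missing)        # ['ExpCOBA0.pre.V']
```

The key reported as missing still refers to data that was saved under another path, which is why such warnings are harmless in this situation.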

Custom saving and loading#

You can easily write your own saving and loading functions, because all variables in the model can be collected through .vars(). Saving variables is then just a matter of converting them to numpy.ndarray and storing them on disk. Similarly, to load variables, you only need to read the NumPy arrays from disk and convert them back into instances of Variables.

The only gotcha is to avoid saving duplicated variables.
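The recipe above can be sketched with NumPy alone. This is a minimal sketch under stated assumptions: plain NumPy arrays stand in for BrainPy Variables, the dict `variables` stands in for what .vars() would return, and `save_vars` and `load_vars` are hypothetical helper names, not part of BrainPy.

```python
import os
import tempfile

import numpy as np


def save_vars(filename, variables):
    """Save a dict of name -> array to an .npz file, skipping duplicate objects."""
    seen, unique = set(), {}
    for name, value in variables.items():
        if id(value) in seen:      # same object already stored under another name
            continue
        seen.add(id(value))
        unique[name] = np.asarray(value)
    np.savez(filename, **unique)


def load_vars(filename, variables):
    """Copy stored arrays back into the matching entries of `variables`, in place."""
    with np.load(filename) as data:
        for name in data.files:
            if name in variables:
                variables[name][...] = data[name]


# Toy state: the same array is reachable under two names.
V = np.array([-55.0, -54.0])
variables = {'E.V': V, 'ExpCOBA0.pre.V': V, 'E.spike': np.zeros(2, dtype=bool)}

path = os.path.join(tempfile.mkdtemp(), 'net.npz')
save_vars(path, variables)

V[:] = 0.0                      # overwrite the state ...
load_vars(path, variables)      # ... and restore it from disk
print(variables['E.V'])
```

Writing into the existing arrays in place keeps every alias of a variable in sync, so restoring `E.V` also restores `ExpCOBA0.pre.V` even though only one copy was saved.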