# Source code for brainpy._src.helpers
from typing import Dict, Callable
from brainpy._src import dynsys
from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.dynsys import DynamicalSystem, DynView
from brainpy._src.math.object_transform.base import StateLoadResult
# Public API of this module: the names exported by a star-import.
# NOTE(review): 'clear_input' is listed here but its definition is not visible
# in this part of the file -- confirm it is defined further down.
__all__ = [
'reset_level',
'reset_state',
'load_state',
'save_state',
'clear_input',
]
# Number of distinct reset levels; valid normalised levels are 0 .. _max_level - 1.
_max_level = 10


def reset_level(level: int = 0):
    """Return a decorator that tags a function with a resetting level.

    The decorator stores the normalised level on the decorated function as
    the attribute ``reset_level`` and returns the same function object.
    The lower the level, the earlier the function is called.

    A negative ``level`` counts from the end, like a negative sequence
    index: ``-1`` is the last level (``_max_level - 1``), ``-2`` the one
    before it, and so on.

    >>> import brainpy as bp
    >>> @bp.reset_level(-1)
    ... def reset_state(self, *args, **kwargs):
    ...     ...

    Args:
      level: The resetting level, an integer in ``[-_max_level, _max_level)``.

    Returns:
      A decorator that sets ``fun.reset_level`` and returns ``fun``.

    Raises:
      ValueError: If ``level`` falls outside the accepted range.
    """
    # Keep the caller's value untouched so the error message can report it.
    normalized = level + _max_level if level < 0 else level
    if not 0 <= normalized < _max_level:
        raise ValueError(
            f'"reset_level" must be an integer in [-{_max_level}, {_max_level}), '
            f'but we got {level}'
        )

    def wrap(fun: Callable):
        fun.reset_level = normalized
        return fun

    return wrap
[docs]
def reset_state(target: DynamicalSystem, *args, **kwargs):
    """Reset the states of every child node found in ``target``.

    Nodes whose ``reset_state`` carries no ``reset_level`` attribute are reset
    first, in collection order. The remaining nodes are then reset level by
    level, from ``0`` up to ``_max_level - 1``.

    See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.

    Args:
      target: The target DynamicalSystem.
      *args: Positional arguments forwarded to each node's ``reset_state``.
      **kwargs: Keyword arguments forwarded to each node's ``reset_state``.
    """
    # The flag is cleared for the duration of the traversal and always set
    # back to True afterwards, even if a node's reset raises.
    dynsys.the_top_layer_reset_state = False
    try:
        collected = target.nodes().subset(DynamicalSystem)
        collected = collected.not_subset(DynView).not_subset(IonChaDyn)
        children = list(collected.unique().values())

        # First pass: reset un-levelled nodes right away, defer the others.
        deferred = []
        for child in children:
            if hasattr(child.reset_state, 'reset_level'):
                deferred.append(child)
                continue
            child.reset_state(*args, **kwargs)

        # Second pass: walk the levels in ascending order.
        for current_level in range(_max_level):
            for child in deferred:
                if child.reset_state.reset_level != current_level:
                    continue
                child.reset_state(*args, **kwargs)
    finally:
        dynsys.the_top_layer_reset_state = True
[docs]
def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
    """Load the entries of ``state_dict`` into ``target`` and its descendants.

    Each collected node receives ``state_dict[<node name>]`` through its own
    ``load_state`` method. When that call returns a
    ``(missing, unexpected)`` pair, the keys are prefixed with the node name
    and accumulated into the overall result.

    Args:
      target: DynamicalSystem. The dynamical system to load its states.
      state_dict: dict. A dict containing parameters and persistent buffers,
        keyed by node name.
      **kwargs: Extra keyword arguments forwarded to each node's ``load_state``.

    Returns:
      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:

      * **missing_keys** is a list of str containing the missing keys
      * **unexpected_keys** is a list of str containing the unexpected keys
    """
    all_missing = []
    all_unexpected = []
    members = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()
    for path, member in members.items():
        outcome = member.load_state(state_dict[path], **kwargs)
        if outcome is None:
            # This node reported nothing; move on to the next one.
            continue
        absent, surplus = outcome
        all_missing.extend(f'{path}.{key}' for key in absent)
        all_unexpected.extend(f'{path}.{key}' for key in surplus)
    return StateLoadResult(all_missing, all_unexpected)
[docs]
def save_state(target: DynamicalSystem, **kwargs) -> Dict:
    """Collect the states of ``target`` and its descendants into a dictionary.

    Args:
      target: DynamicalSystem. The node to save its states.
      **kwargs: Extra keyword arguments forwarded to each node's ``save_state``.

    Returns:
      Dict. Maps each collected node's name to the value returned by that
      node's ``save_state``; intended for later disk serialization.
    """
    snapshot = {}
    members = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()
    for path, member in members.items():
        snapshot[path] = member.save_state(**kwargs)
    return snapshot