BrainPyObject

class brainpy.math.BrainPyObject(name=None)

The base class for objects across the whole BrainPy ecosystem.

Subclasses of BrainPyObject include, but are not limited to:

  • DynamicalSystem in brainpy.dyn.base.py

  • Integrator in brainpy.integrators.base.py

  • Optimizer in brainpy.optimizers.py

  • Scheduler in brainpy.optimizers.py

Note

A Variable created as an attribute of a BrainPyObject is never replaced: assigning a new value to that attribute updates the existing Variable in place.

For example, suppose we create an object with a Variable attribute a:

>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> class MyObj(bp.BrainPyObject):
...   def __init__(self):
...     super().__init__()
...     self.a = bm.Variable(bm.ones(1))
...
...   def reset1(self):
...     self.a = bm.asarray([10.])
...
...   def reset2(self):
...     self.a = 1.
...
>>> ob = MyObj()
>>> id(ob.a)
2643434845056

After calling ob.reset1(), ob.a is still the original Variable; only its value has changed.

>>> ob.reset1()
>>> id(ob.a)
2643434845056

What actually happens when we execute self.a = bm.asarray([10.]) is self.a.value = bm.asarray([10.]). Therefore, calling ob.reset2() raises an error, because the new value 1. has shape () while the original data has shape (1,):

>>> ob.reset2()
brainpy.errors.MathError: The shape of the original data is (1,), while we got () with batch_axis=None.
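
The redirection described above can be imitated in plain Python. The following is a minimal sketch, not BrainPy's implementation: ToyVariable and ToyObject are hypothetical names, NumPy stands in for brainpy.math, and ValueError stands in for brainpy.errors.MathError. Assignment to an attribute that already holds a ToyVariable is routed to its value setter, which checks the shape.

```python
import numpy as np


class ToyVariable:
    """Hypothetical stand-in for bm.Variable: a mutable box whose shape is fixed."""

    def __init__(self, value):
        self._value = np.asarray(value)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        new = np.asarray(new)
        # Reject data whose shape differs from the original data.
        if new.shape != self._value.shape:
            raise ValueError(
                f"The shape of the original data is {self._value.shape}, "
                f"while we got {new.shape}."
            )
        self._value = new


class ToyObject:
    """Hypothetical stand-in for the attribute-assignment rule of BrainPyObject."""

    def __setattr__(self, name, new):
        old = self.__dict__.get(name)
        if isinstance(old, ToyVariable) and not isinstance(new, ToyVariable):
            # Existing Variable attribute: write into its value instead of rebinding.
            old.value = new
        else:
            super().__setattr__(name, new)


class MyObj(ToyObject):
    def __init__(self):
        self.a = ToyVariable(np.ones(1))

    def reset1(self):
        self.a = np.asarray([10.0])   # same shape (1,): accepted

    def reset2(self):
        self.a = 1.0                  # shape (): rejected


ob = MyObj()
first_id = id(ob.a)
ob.reset1()
print(id(ob.a) == first_id, ob.a.value)   # True [10.]

try:
    ob.reset2()
    err_msg = None
except ValueError as err:
    err_msg = str(err)
print(err_msg)   # The shape of the original data is (1,), while we got ().
```

Because the shape check lives in the value setter, reset1 succeeds and reset2 fails, mirroring the two calls above.
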

cpu()

Move all variables into the CPU device.

cuda()

Move all variables into the GPU device.

load_state(state_dict, **kwargs)

Load states from a dictionary.

Return type:

Optional[Tuple[Sequence[str], Sequence[str]]]

load_state_dict(state_dict, warn=True, compatible='v2', **kwargs)

Copy parameters and buffers from state_dict into this module and its descendants.

Parameters:
  • state_dict (dict) – A dict containing parameters and persistent buffers.

  • warn (bool) – Whether to warn when the external state_dict has missing or unexpected keys.

  • compatible (str) – The API version used for compatibility.

Returns:

out – NamedTuple with missing_keys and unexpected_keys fields:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

StateLoadResult
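
To make the two fields of StateLoadResult concrete, here is a minimal sketch of the key comparison over a flat {name: array} dictionary. The helper load_flat_state and the result type LoadResult are hypothetical; the real method works on the dictionary produced by state_dict(), whose exact layout is not described here.

```python
import warnings
from typing import Dict, List, NamedTuple

import numpy as np


class LoadResult(NamedTuple):
    """Hypothetical mirror of the StateLoadResult fields."""
    missing_keys: List[str]
    unexpected_keys: List[str]


def load_flat_state(params: Dict[str, np.ndarray],
                    state_dict: Dict[str, np.ndarray],
                    warn: bool = True) -> LoadResult:
    """Copy matching entries of state_dict into params in place and report mismatches."""
    missing = [k for k in params if k not in state_dict]
    unexpected = [k for k in state_dict if k not in params]
    for key in params:
        if key in state_dict:
            # In-place copy: the array object held in params is kept.
            params[key][...] = state_dict[key]
    if warn and (missing or unexpected):
        warnings.warn(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return LoadResult(missing, unexpected)


params = {"w": np.zeros(3), "b": np.zeros(1)}
w_before = params["w"]
result = load_flat_state(params, {"w": np.ones(3), "extra": np.ones(2)}, warn=False)
print(result)        # LoadResult(missing_keys=['b'], unexpected_keys=['extra'])
print(params["w"])   # [1. 1. 1.]
```

Copying in place rather than rebinding keeps the existing array objects, in the same spirit as the Variable note at the top of this page.
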

property name

Name of the model.

nodes(method='absolute', level=-1, include_self=True)

Collect all children nodes.

Parameters:
  • method (str) – The method to access the nodes. Either ‘absolute’ or ‘relative’.

  • level (int) – The hierarchy level to find nodes.

  • include_self (bool) – Whether to include this object itself.

Returns:

gather – A collection of (path, node) pairs.

Return type:

Collector
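
The level and include_self parameters describe a depth-limited walk over the object tree. The sketch below is a hypothetical illustration over plain Python objects: ToyNode and collect_nodes are illustrative names, keys are dotted attribute paths (how BrainPy builds keys for the ‘absolute’ and ‘relative’ methods is not modeled), and level counts hops from the root, with -1 meaning no limit.

```python
class ToyNode:
    """Hypothetical node: children are simply attributes holding ToyNode instances."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


def collect_nodes(root, level=-1, include_self=True):
    """Return {dotted_path: node} for nodes within `level` hops of root (-1: no limit)."""
    found = {}
    seen = {id(root)}
    if include_self:
        found["self"] = root

    def visit(node, prefix, depth):
        if level != -1 and depth >= level:
            return
        for attr, child in vars(node).items():
            # Record each node once, under the first path that reaches it.
            if isinstance(child, ToyNode) and id(child) not in seen:
                seen.add(id(child))
                path = f"{prefix}{attr}"
                found[path] = child
                visit(child, path + ".", depth + 1)

    visit(root, "", 0)
    return found


leaf = ToyNode()
net = ToyNode(layer=ToyNode(cell=leaf), head=ToyNode())
print(sorted(collect_nodes(net)))                                # ['head', 'layer', 'layer.cell', 'self']
print(sorted(collect_nodes(net, level=1, include_self=False)))   # ['head', 'layer']
```
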

save_state(**kwargs)

Save states as a dictionary.

Return type:

Dict

state_dict(**kwargs)

Returns a dictionary containing the whole state of the module.

Returns:

out – A dictionary containing the whole state of the module.

Return type:

dict

to(device)

Moves all variables into the given device.

Parameters:

device (Optional[Any]) – The device.

tpu()

Move all variables into the TPU device.

tracing_variable(name, init, shape, batch_or_mode=None, batch_axis=0, axis_names=None, batch_axis_name='batch')

Initialize a variable that can be traced during computations and transformations.

Although this function is designed to initialize tracing variables during computation or compilation, it can also be used for the initialization of variables before computation and compilation.

  • If the variable has not been instantiated, a Variable will be instantiated.

  • If the variable has already been created, subsequent calls of this function return the created variable.

Here is a usage example:

import brainpy.math as bm

class Example(bm.BrainPyObject):
  def fun(self):
    # The first call of `.fun()` creates a Variable instance on this line.
    # Repeated calls of `.fun()` do not initialize the variable again;
    # instead, the variable that has already been created is returned.
    self.tracing_variable('a', bm.zeros, (10,))

    # The created variable can be accessed as self.a
    self.a.value = bm.ones(10)

    # Calling this function again does not reinitialize the
    # variable; instead, it returns the variable that has
    # already been created.
    a = self.tracing_variable('a', bm.zeros, (10,))
    return a

ex = Example()
a = ex.fun()

Added in version 2.4.5.

Parameters:
  • name (str) – The variable name.

  • init (Union[Callable, Array, Array]) – The data to be initialized as a Variable: either a callable or an array.

  • shape (Union[int, Sequence[int]]) – The shape of the variable.

  • batch_or_mode (Union[int, bool, Mode, None]) – The batch size of this variable. If it is a boolean or an instance of Mode, the batch size is 1. If it is None, the variable has no batch axis.

  • batch_axis (int) – The batch axis, used if a batch size is given.

  • axis_names (Optional[Sequence[str]]) – The name of each axis. These names should match the given axes.

  • batch_axis_name (Optional[str]) – The name of the batch axis, used if batch_or_mode is given. Default is brainpy.math.sharding.BATCH_AXIS.

Return type:

Variable

Returns:

The instance of Variable.
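
The create-once behavior and the placement of the batch axis can be sketched in plain Python. This is a hypothetical model, not BrainPy's code: ToyTracingObject is an illustrative name, plain NumPy arrays stand in for Variable, and an integer batch_size replaces batch_or_mode (booleans and Mode instances are not modeled).

```python
import numpy as np


class ToyTracingObject:
    """Hypothetical stand-in illustrating tracing_variable's create-once semantics."""

    def tracing_variable(self, name, init, shape, batch_size=None, batch_axis=0):
        # Later calls return the stored object; `init` is not called again.
        if name in self.__dict__:
            return self.__dict__[name]
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        if batch_size is not None:
            # Insert the batch dimension at a non-negative position batch_axis.
            shape = shape[:batch_axis] + (batch_size,) + shape[batch_axis:]
        # A callable init receives the full shape; an array init is used as-is.
        value = init(shape) if callable(init) else np.asarray(init)
        setattr(self, name, value)
        return value


class Example(ToyTracingObject):
    def fun(self):
        a = self.tracing_variable('a', np.zeros, (10,))
        a[:] = 1.0   # modify the stored array in place
        return self.tracing_variable('a', np.zeros, (10,))


ex = Example()
first = ex.fun()
second = ex.fun()
print(first is second, second.sum())   # True 10.0

batched = ex.tracing_variable('b', np.zeros, (10,), batch_size=4)
print(batched.shape)                   # (4, 10)
```
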

train_vars(method='absolute', level=-1, include_self=True)

The shortcut for retrieving all trainable variables.

Parameters:
  • method (str) – The method to access the variables. Support ‘absolute’ and ‘relative’.

  • level (int) – The hierarchy level to find TrainVar instances.

  • include_self (bool) – Whether to include the TrainVar instances of this object itself.

Returns:

gather – A collection of (path, trainable variable) pairs.

Return type:

ArrayCollector

tree_flatten()

Flattens the object as a PyTree.

The flattening order follows the order in which attributes were added.

Added in version 2.3.1.

Returns:

res – A tuple of dynamic values and static values.

Return type:

tuple

classmethod tree_unflatten(aux, dynamic_values)

Unflatten the data to construct an object of this class.

Added in version 2.3.1.
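
The pairing of tree_flatten and tree_unflatten can be illustrated without JAX. The sketch below is hypothetical (ToyVar and ToyPyTreeObject are illustrative names): Variable-like attributes become dynamic leaves, other attributes are kept as static data, and attribute insertion order fixes the leaf order. The returned pair follows the (children, aux_data) convention of jax.tree_util.register_pytree_node_class, but the class is not registered with JAX here.

```python
import numpy as np


class ToyVar:
    """Hypothetical stand-in for a Variable: only .value is treated as dynamic."""

    def __init__(self, value):
        self.value = np.asarray(value)


class ToyPyTreeObject:
    def tree_flatten(self):
        # Walk attributes in insertion order; Variable values become leaves.
        order, dynamic_values, static = [], [], {}
        for name, attr in vars(self).items():
            order.append(name)
            if isinstance(attr, ToyVar):
                dynamic_values.append(attr.value)
            else:
                static[name] = attr
        return tuple(dynamic_values), (tuple(order), static)

    @classmethod
    def tree_unflatten(cls, aux, dynamic_values):
        order, static = aux
        leaves = iter(dynamic_values)
        obj = cls.__new__(cls)   # skip __init__; rebuild attributes directly
        for name in order:
            # Names absent from `static` were Variables: consume the next leaf.
            value = static[name] if name in static else ToyVar(next(leaves))
            setattr(obj, name, value)
        return obj


class Model(ToyPyTreeObject):
    def __init__(self):
        self.w = ToyVar([1.0, 2.0])
        self.label = "demo"
        self.b = ToyVar([0.5])


leaves, aux = Model().tree_flatten()
print([leaf.tolist() for leaf in leaves])   # [[1.0, 2.0], [0.5]]
doubled = Model.tree_unflatten(aux, [2 * leaf for leaf in leaves])
print(doubled.w.value.tolist(), doubled.label, doubled.b.value.tolist())   # [2.0, 4.0] demo [1.0]
```
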

unique_name(name=None, type_=None)

Get the unique name for this object.

Parameters:
  • name (str, optional) – The expected name. If None, the default unique name will be returned. Otherwise, the provided name will be checked to guarantee its uniqueness.

  • type_ (str, optional) – The name of this class, used for object naming.

Returns:

name – The unique name for this object.

Return type:

str

vars(method='absolute', level=-1, include_self=True, exclude_types=None)

Collect all variables in this node and the children nodes.

Parameters:
  • method (str) – The method to access the variables. Either ‘absolute’ or ‘relative’.

  • level (int) – The hierarchy level to find variables.

  • include_self (bool) – Whether to include the variables of this object itself.

  • exclude_types (tuple of type) – The variable types to exclude.

Returns:

gather – A collection of (path, variable) pairs.

Return type:

ArrayCollector
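
The effect of exclude_types, and one way to picture the relation between vars() and train_vars(), is type filtering. This is a hypothetical illustration over a single object (children are not visited, so it corresponds roughly to level=0): ToyVariable, ToyTrainVar and collect_vars are illustrative names standing in for bm.Variable, bm.TrainVar and the real methods.

```python
import numpy as np


class ToyVariable:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = np.asarray(value)


class ToyTrainVar(ToyVariable):
    """Hypothetical stand-in for bm.TrainVar, a trainable Variable subtype."""


def collect_vars(obj, include_types=(ToyVariable,), exclude_types=()):
    """Return {attribute name: variable} for matching attributes of obj only."""
    return {
        name: attr
        for name, attr in vars(obj).items()
        # isinstance(x, ()) is False, so an empty exclude_types excludes nothing.
        if isinstance(attr, include_types) and not isinstance(attr, exclude_types)
    }


class Layer:
    def __init__(self):
        self.w = ToyTrainVar(np.ones((2, 2)))
        self.b = ToyTrainVar(np.zeros(2))
        self.state = ToyVariable(np.zeros(2))
        self.size = 2   # plain attribute: never collected


layer = Layer()
print(sorted(collect_vars(layer)))                                # ['b', 'state', 'w']
print(sorted(collect_vars(layer, include_types=(ToyTrainVar,))))  # ['b', 'w']
print(sorted(collect_vars(layer, exclude_types=(ToyTrainVar,))))  # ['state']
```

Restricting include_types to the trainable subtype plays the role of train_vars() in this sketch, while exclude_types removes a subtype from the general collection.
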