# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This file defines the basic classes for BrainPy object-oriented transformations.
These transformations include JAX's JIT, autograd, vectorization, parallelization, etc.
"""
import numbers
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional
import jax
import numpy as np
from jax.tree_util import register_pytree_node_class
from brainpy.math.defaults import defaults
from brainpy.math.modes import Mode
from brainpy.math.ndarray import (Array, )
from brainpy.math.object_transform.collectors import ArrayCollector, Collector
from brainpy.math.object_transform.naming import get_unique_name, check_name_uniqueness
from brainpy.math.object_transform.variables import (
Variable, VariableView, TrainVar, VarList, VarDict
)
from brainpy.math.sharding import BATCH_AXIS
# NOTE(review): not referenced anywhere in the visible part of this module;
# purpose unclear from here -- confirm before relying on or removing it.
variable_ = None
# Result type returned by ``BrainPyObject.load_state_dict``: two key lists.
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
# Classes that ``BrainPyObject.__init__`` has already passed to
# ``register_pytree_node_class``; consulted so each class is registered once.
registered = set()
# Names exported by ``from <this module> import *``.
__all__ = [
    'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform',
    'NodeDict', 'node_dict', 'NodeList', 'node_list',
]
[docs]
class BrainPyObject(object):
    """The BrainPyObject class for the whole BrainPy ecosystem.

    The subclasses of BrainPyObject include but are not limited to:

    - ``DynamicalSystem`` in *brainpy.dyn.base.py*
    - ``Integrator`` in *brainpy.integrators.base.py*
    - ``Optimizer`` in *brainpy.optimizers.py*
    - ``Scheduler`` in *brainpy.optimizers.py*

    .. note::
      Note that a variable created in the ``BrainPyObject`` will never be replaced.

      For example, if here we create an object which has an attribute ``a``:

      >>> import brainpy as bp
      >>> import brainpy.math as bm
      >>>
      >>> class MyObj(bp.BrainPyObject):
      >>>   def __init__(self):
      >>>     super().__init__()
      >>>     self.a = bm.Variable(bm.ones(1))
      >>>
      >>>   def reset1(self):
      >>>     self.a = bm.asarray([10.])
      >>>
      >>>   def reset2(self):
      >>>     self.a = 1.
      >>>
      >>> ob = MyObj()
      >>> id(ob.a)
      2643434845056

      After we call the ``ob.reset1()`` function, ``ob.a`` is still the original
      Variable. What changes is its value.

      >>> ob.reset1()
      >>> id(ob.a)
      2643434845056

      What really happens when we call ``self.a = bm.asarray([10.])`` is
      ``self.a.value = bm.asarray([10.])``. Therefore, when we call ``ob.reset2()``,
      there will be an error.

      >>> ob.reset2()
      brainpy.errors.MathError: The shape of the original data is (1,), while we got () with batch_axis=None.
    """

    # Attribute names that ``vars()`` skips while collecting variables.
    _excluded_vars = ()
def __init__(self, name=None):
    """Prepare pytree registration, the object name and the implicit holders.

    Parameters
    ----------
    name : str, optional
        The requested name. When ``None``, ``unique_name`` generates one.
    """
    super().__init__()

    # Register the concrete class as a JAX pytree node, at most once per class.
    klass = self.__class__
    if defaults.bp_object_as_pytree and klass not in registered:
        register_pytree_node_class(klass)
        registered.add(klass)

    # Resolve the name, then check that it is unique.
    self._name = None
    self._name = self.unique_name(name=name)
    check_name_uniqueness(name=self._name, obj=self)

    # Holder for variables that are not reachable as ``self.xxx``;
    # created on first access of the ``implicit_vars`` property.
    self._implicit_vars: Optional[ArrayCollector] = None

    # Holder for child nodes that are not reachable as ``self.xxx``;
    # created on first access of the ``implicit_nodes`` property.
    self._implicit_nodes: Optional[Collector] = None
@property
def implicit_vars(self):
    """The collector of implicitly registered variables, created on first use."""
    collector = self._implicit_vars
    if collector is None:
        collector = ArrayCollector()
        self._implicit_vars = collector
    return collector
@property
def implicit_nodes(self):
    """The collector of implicitly registered child nodes, created on first use."""
    collector = self._implicit_nodes
    if collector is None:
        collector = Collector()
        self._implicit_nodes = collector
    return collector
def setattr(self, key: str, value: Any) -> None:
    """Bind ``value`` to the attribute ``key`` via the parent ``__setattr__``.

    The call goes straight to ``super().__setattr__``, so the ``Variable``
    handling implemented in this class's own ``__setattr__`` is not applied.

    Parameters
    ----------
    key : str
        The attribute name.
    value : Any
        The object to bind.
    """
    super().__setattr__(key, value)
[docs]
def tracing_variable(
    self,
    name: str,
    init: Union[Callable, Array, jax.Array],
    shape: Union[int, Sequence[int]],
    batch_or_mode: Union[int, bool, Mode] = None,
    batch_axis: int = 0,
    axis_names: Optional[Sequence[str]] = None,
    batch_axis_name: Optional[str] = BATCH_AXIS,
) -> Variable:
    """Unsupported entry point kept only for signature compatibility.

    .. deprecated:: 3.0.0
       Since BrainPy 3.0.0 the library has been rewritten on top of
       ``brainstate``. Every call raises :class:`NotImplementedError`;
       none of the arguments is inspected.

    Parameters
    ----------
    name : str
        The variable name.
    init : callable, Array
        The data to be initialized as a ``Variable``.
    shape : int, sequence of int
        The shape of the variable.
    batch_or_mode : int, bool, Mode
        The batch size of this variable.
    batch_axis : int
        The batch axis, if batch size is given.
    axis_names : sequence of str
        The name for each axis.
    batch_axis_name : str
        The name for the batch axis.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = (
        'Since 3.0.0, brainpy is rewritten with brainstate. '
        'The feature tracing_variable is no longer supported. '
    )
    raise NotImplementedError(message)
def __setattr__(self, key: str, value: Any) -> None:
    """Assign an attribute, updating an existing :py:class:`~.Variable` in place.

    When ``key`` already refers to a ``Variable`` stored in the instance
    ``__dict__``, the variable object is kept and only its ``.value`` is
    replaced. Every other assignment is a normal attribute binding.

    .. versionadded:: 2.3.1

    Parameters
    ----------
    key : str
        The attribute.
    value : Any
        The value.
    """
    current = self.__dict__.get(key, None)
    if isinstance(current, Variable):
        # Keep the Variable's identity; only swap the data it holds.
        current.value = value
    else:
        super().__setattr__(key, value)
[docs]
def tree_flatten(self):
    """Split the instance attributes into pytree children and auxiliary data.

    Attributes are visited in the order they were added to ``__dict__``.
    Nodes, variables and their list/dict containers become the dynamic
    children; everything else is carried as static auxiliary data.

    .. versionadded:: 2.3.1

    Returns
    -------
    res : tuple
        ``(dynamic_values, (dynamic_names, static_names, static_values))``,
        each element being a tuple.
    """
    dynamic = []
    static = []
    for attr, obj in self.__dict__.items():
        if isinstance(obj, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)):
            dynamic.append((attr, obj))
        else:
            static.append((attr, obj))

    dynamic_names = tuple(attr for attr, _ in dynamic)
    dynamic_values = tuple(obj for _, obj in dynamic)
    static_names = tuple(attr for attr, _ in static)
    static_values = tuple(obj for _, obj in static)
    return dynamic_values, (dynamic_names, static_names, static_values)
[docs]
@classmethod
def tree_unflatten(cls, aux, dynamic_values):
    """Rebuild an instance from the output of ``tree_flatten``.

    The instance is created with ``cls.__new__`` (``__init__`` is not run)
    and attributes are restored with ``object.__setattr__``, which bypasses
    this class's own ``__setattr__``.

    .. versionadded:: 2.3.1

    Parameters
    ----------
    aux : tuple
        ``(dynamic_names, static_names, static_values)``.
    dynamic_values : sequence
        The values matching ``dynamic_names``.
    """
    instance = cls.__new__(cls)
    dynamic_names, static_names, static_values = aux
    # Dynamic attributes are restored first, then the static ones.
    for names, values in ((dynamic_names, dynamic_values),
                          (static_names, static_values)):
        for attr, obj in zip(names, values):
            object.__setattr__(instance, attr, obj)
    return instance
@property
def name(self):
    """Name of the model."""
    return self._name


@name.setter
def name(self, name: str = None):
    # Resolve the requested name (``None`` asks ``unique_name`` for a
    # generated one), store it, then check that it is unique.
    resolved = self.unique_name(name=name)
    self._name = resolved
    check_name_uniqueness(name=resolved, obj=self)
def register_implicit_vars(self, *variables, var_cls: type = None, **named_variables):
    """Register variables that are not reachable as ``self.xxx``.

    Parameters
    ----------
    *variables
        Each item is an accepted instance, a list/tuple of them, or a dict
        mapping names to them. Items without a name are stored under
        ``'var' + str(id(item))``.
    var_cls : type or tuple of type, optional
        The accepted classes. Defaults to ``(Variable, VarList, VarDict)``.
    **named_variables
        Accepted instances stored under the given keyword names.

    Raises
    ------
    ValueError
        If an item has an unsupported type, or an element is not an
        instance of ``var_cls``.
    """
    accepted = (Variable, VarList, VarDict) if var_cls is None else var_cls

    def _require(candidate):
        if not isinstance(candidate, accepted):
            raise ValueError(f'Must be instance of {accepted}, but we got {type(candidate)}')

    def _put(key, value):
        # ``VarList``/``VarDict`` containers are expanded so that each
        # contained element gets its own ``implicit_vars`` entry.
        if isinstance(value, VarList):
            for index, element in enumerate(value):
                self.implicit_vars[f'{key}-{index}'] = element
        elif isinstance(value, VarDict):
            for sub_key, element in value.items():
                self.implicit_vars[f'{key}-{sub_key}'] = element
        else:
            self.implicit_vars[key] = value

    for item in variables:
        if isinstance(item, accepted):
            _put(f'var{id(item)}', item)
        elif isinstance(item, (tuple, list)):
            for element in item:
                _require(element)
                _put(f'var{id(element)}', element)
        elif isinstance(item, dict):
            for key, element in item.items():
                _require(element)
                _put(key, element)
        else:
            raise ValueError(f'Unknown type: {type(item)}')

    for key, element in named_variables.items():
        _require(element)
        _put(key, element)
def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
    """Register child nodes that are not reachable as ``self.xxx``.

    Parameters
    ----------
    *nodes
        Each item is an accepted instance, a list/tuple of them, or a dict
        mapping names to them. An item without an explicit key is stored
        under its own ``name``.
    node_cls : type or tuple of type, optional
        The accepted classes. Defaults to
        ``(BrainPyObject, NodeList, NodeDict)``.
    **named_nodes
        Accepted instances stored under the given keyword names.

    Raises
    ------
    ValueError
        If an item has an unsupported type, or an element is not an
        instance of ``node_cls``.

    Notes
    -----
    ``NodeList``/``NodeDict`` containers have no ``name`` and cannot be
    traversed as nodes, so they are expanded and each contained
    ``BrainPyObject`` is stored individually. Keyed containers use the
    naming of the relative node search (``key-i`` for lists, ``key.sub``
    for dicts). Elements that are not ``BrainPyObject`` are skipped.
    """
    if node_cls is None:
        node_cls = (BrainPyObject, NodeList, NodeDict)

    def _mismatch(obj):
        # ``node_cls`` may be a tuple of classes, which has no ``__name__``.
        if isinstance(node_cls, tuple):
            expected = ', '.join(getattr(c, '__name__', repr(c)) for c in node_cls)
        else:
            expected = getattr(node_cls, '__name__', repr(node_cls))
        return ValueError(f'Must be instance of {expected}, but we got {type(obj)}')

    def _store(key, value):
        # ``key is None`` means the caller supplied no explicit key.
        if isinstance(value, NodeList):
            for index, element in enumerate(value):
                if isinstance(element, BrainPyObject):
                    sub_key = element.name if key is None else f'{key}-{index}'
                    self.implicit_nodes[sub_key] = element
        elif isinstance(value, NodeDict):
            for inner_key, element in value.items():
                if isinstance(element, BrainPyObject):
                    sub_key = inner_key if key is None else f'{key}.{inner_key}'
                    self.implicit_nodes[sub_key] = element
        else:
            self.implicit_nodes[value.name if key is None else key] = value

    for node in nodes:
        if isinstance(node, node_cls):
            _store(None, node)
        elif isinstance(node, (tuple, list)):
            for element in node:
                if not isinstance(element, node_cls):
                    raise _mismatch(element)
                _store(None, element)
        elif isinstance(node, dict):
            for key, element in node.items():
                if not isinstance(element, node_cls):
                    raise _mismatch(element)
                _store(key, element)
        else:
            raise ValueError(f'Unknown type: {type(node)}')

    for key, node in named_nodes.items():
        if not isinstance(node, node_cls):
            raise _mismatch(node)
        _store(key, node)
[docs]
def vars(
    self,
    method: str = 'absolute',
    level: int = -1,
    include_self: bool = True,
    exclude_types: Tuple[type, ...] = None
):
    """Collect all variables in this node and the children nodes.

    Parameters
    ----------
    method : str
        The method to access the variables.
    level : int
        The hierarchy level to find variables.
    include_self : bool
        Whether include the variables in the self.
    exclude_types : tuple of type
        The type to exclude. Defaults to ``(VariableView,)``. The filter is
        applied to attribute-stored variables and container elements only;
        implicit variables are always collected.

    Returns
    -------
    gather : ArrayCollector
        The collection contained (the path, the variable).

    Notes
    -----
    Keys have the form ``<node_path>.<attr>``. Elements of ``VarList`` and
    ``VarDict`` attributes get ``<attr>-<index>`` / ``<attr>-<key>``. When
    the node path is empty the prefix (and its dot) is omitted, so every
    container element and implicit variable still gets a distinct key.
    """
    if exclude_types is None:
        exclude_types = (VariableView,)
    nodes = self.nodes(method=method, level=level, include_self=include_self)
    gather = ArrayCollector()
    for node_path, node in nodes.items():
        # An empty path (the root in the relative method) must not add a
        # leading dot, nor collapse all container elements onto one key.
        prefix = f'{node_path}.' if node_path else ''
        for k in node.__dict__.keys():
            if k in node._excluded_vars:
                continue
            v = getattr(node, k)
            if isinstance(v, Variable) and not isinstance(v, exclude_types):
                gather[f'{prefix}{k}'] = v
            elif isinstance(v, VarList):
                for i, vv in enumerate(v):
                    if not isinstance(vv, exclude_types):
                        gather[f'{prefix}{k}-{i}'] = vv
            elif isinstance(v, VarDict):
                for kk, vv in v.items():
                    if not isinstance(vv, exclude_types):
                        gather[f'{prefix}{k}-{kk}'] = vv
        # implicit vars
        gather.update({f'{prefix}{k}': v for k, v in node.implicit_vars.items()})
    return gather
[docs]
def train_vars(self, method='absolute', level=-1, include_self=True):
    """Collect the variables via ``vars`` and keep only the ``TrainVar`` subset.

    Parameters
    ----------
    method : str
        The method to access the variables. Support 'absolute' and 'relative'.
    level : int
        The hierarchy level to find TrainVar instances.
    include_self : bool
        Whether include the TrainVar instances in the self.

    Returns
    -------
    gather : ArrayCollector
        The collection contained (the path, the trainable variable).
    """
    collected = self.vars(method=method, level=level, include_self=include_self)
    return collected.subset(TrainVar)
def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _paths=None):
    """Recursively gather this node and its descendants into a ``Collector``.

    Parameters
    ----------
    method : str
        ``'absolute'`` keys every node by its ``name``; ``'relative'`` keys
        nodes by the attribute path leading to them (this node itself is
        stored under the empty string).
    level : int
        The depth at which descent stops; ``-1`` disables the limit.
    include_self : bool
        Whether the node on which the call is made is put into the result.
        The same flag is forwarded to every recursive call.
    _lid : int
        Internal. The depth of the current call.
    _paths : set, optional
        Internal. The ``(id(parent), id(child))`` pairs already visited,
        shared by all recursive calls.

    Returns
    -------
    gather : Collector
        The collected (key, node) pairs.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'absolute'`` nor ``'relative'``.
    """
    if _paths is None:
        _paths = set()
    gather = Collector()
    if include_self:
        if method == 'absolute':
            gather[self.name] = self
        elif method == 'relative':
            gather[''] = self
        else:
            raise ValueError(f'No support for the method of "{method}".')
    # Stop descending once the requested depth has been reached.
    if (level > -1) and (_lid >= level):
        return gather
    if method == 'absolute':
        nodes = []
        for k, v in self.__dict__.items():
            if isinstance(v, BrainPyObject):
                _add_node2(self, v, _paths, gather, nodes)
            elif isinstance(v, NodeList):
                for v2 in v:
                    _add_node2(self, v2, _paths, gather, nodes)
            elif isinstance(v, NodeDict):
                for v2 in v.values():
                    # Only node values are followed; other values are ignored.
                    if isinstance(v2, BrainPyObject):
                        _add_node2(self, v2, _paths, gather, nodes)
        # implicit nodes
        for node in self.implicit_nodes.values():
            _add_node2(self, node, _paths, gather, nodes)
        # finding nodes recursively
        for v in nodes:
            gather.update(v._find_nodes(method=method,
                                        level=level,
                                        _lid=_lid + 1,
                                        _paths=_paths,
                                        include_self=include_self))
    elif method == 'relative':
        nodes = []
        for k, v in self.__dict__.items():
            if isinstance(v, BrainPyObject):
                _add_node1(self, k, v, _paths, gather, nodes)
            elif isinstance(v, NodeList):
                # List elements are keyed as ``<attr>-<index>``.
                for i, v2 in enumerate(v):
                    _add_node1(self, k + '-' + str(i), v2, _paths, gather, nodes)
            elif isinstance(v, NodeDict):
                # Dict values are keyed as ``<attr>.<key>``.
                for k2, v2 in v.items():
                    if isinstance(v2, BrainPyObject):
                        _add_node1(self, k + '.' + k2, v2, _paths, gather, nodes)
        # implicit nodes
        for key, node in self.implicit_nodes.items():
            _add_node1(self, key, node, _paths, gather, nodes)
        # finding nodes recursively
        for k1, v1 in nodes:
            for k2, v2 in v1._find_nodes(method=method,
                                         _paths=_paths,
                                         _lid=_lid + 1,
                                         level=level,
                                         include_self=include_self).items():
                # The child's own entry has an empty key and was already
                # stored under ``k1``; only its descendants get a prefix.
                if k2:
                    gather[f'{k1}.{k2}'] = v2
    else:
        raise ValueError(f'No support for the method of "{method}".')
    return gather
[docs]
def nodes(self, method='absolute', level=-1, include_self=True):
    """Collect this node's children by delegating to ``_find_nodes``.

    Parameters
    ----------
    method : str
        The method to access the nodes.
    level : int
        The hierarchy level to find nodes.
    include_self : bool
        Whether include the self.

    Returns
    -------
    gather : Collector
        The collection contained (the path, the node).
    """
    found = self._find_nodes(
        method=method,
        level=level,
        include_self=include_self,
    )
    return found
[docs]
def unique_name(self, name=None, type_=None):
    """Return a name for this object, generating one when none is given.

    Parameters
    ----------
    name : str, optional
        The expected name. If None, the default unique name will be returned.
        Otherwise, the provided name will be checked to guarantee its uniqueness.
    type_ : str, optional
        The name of this class, used for object naming. Defaults to the
        class name of ``self``.

    Returns
    -------
    name : str
        The unique name for this object.
    """
    if name is not None:
        # An explicit name is checked and returned unchanged.
        check_name_uniqueness(name=name, obj=self)
        return name
    prefix = self.__class__.__name__ if type_ is None else type_
    return get_unique_name(type_=prefix)
[docs]
def save_state(self, **kwargs) -> Dict:
    """Return this object's state as a dictionary.

    Delegates to the ``__save_state__`` hook, forwarding ``kwargs``.
    """
    snapshot = self.__save_state__(**kwargs)
    return snapshot
[docs]
def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
    """Load this object's state from a dictionary.

    Delegates to the ``__load_state__`` hook, forwarding ``kwargs``, and
    returns whatever the hook returns.
    """
    outcome = self.__load_state__(state_dict, **kwargs)
    return outcome
def __save_state__(self, **kwargs) -> Dict:
    """Return the unique variables found at ``level=0``, as a dict.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    own_vars = self.vars(include_self=True, level=0)
    deduplicated = own_vars.unique()
    return deduplicated.dict()
def __load_state__(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
    """Copy matching entries of ``state_dict`` into the ``level=0`` variables.

    Parameters
    ----------
    state_dict : dict
        Maps variable keys to the data to load.
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    missing_keys, unexpected_keys : list of str
        Keys of variables absent from ``state_dict``, and keys of
        ``state_dict`` that match no variable.
    """
    variables = self.vars(include_self=True, level=0).unique()
    provided = set(state_dict.keys())
    expected = set(variables.keys())
    for key in expected & provided:
        # Each loaded entry is converted to a JAX array before assignment.
        variables[key].value = jax.numpy.asarray(state_dict[key])
    return list(expected - provided), list(provided - expected)
[docs]
def state_dict(self, **kwargs) -> dict:
    """Return a dictionary containing the whole state of the module.

    Each node returned by ``nodes()`` contributes the result of its
    ``save_state(**kwargs)`` under the node's key.

    Returns
    -------
    out : dict
        A dictionary containing a whole state of the module.
    """
    states = {}
    for key, node in self.nodes().items():
        states[key] = node.save_state(**kwargs)
    return states
[docs]
def load_state_dict(
    self,
    state_dict: Dict[str, Any],
    warn: bool = True,
    compatible: str = 'v2',
    **kwargs,
):
    """Copy parameters and buffers from :attr:`state_dict` into
    this module and its descendants.

    Parameters
    ----------
    state_dict : dict
        A dict containing parameters and persistent buffers.
    warn : bool
        Warnings when there are missing keys or unexpected keys in the external ``state_dict``.
    compatible : str
        The version of API for compatibility. ``'v1'`` treats
        ``state_dict`` as a flat mapping from variable keys to data;
        ``'v2'`` treats it as a mapping from node keys to per-node states
        and calls ``load_state`` on every node.

    Returns
    -------
    out : StateLoadResult
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:

        * **missing_keys** is a list of str containing the missing keys
        * **unexpected_keys** is a list of str containing the unexpected keys

    Raises
    ------
    ValueError
        If ``compatible`` is neither ``'v1'`` nor ``'v2'``.
    """
    if compatible not in ('v1', 'v2'):
        raise ValueError(f'Unknown compatible version: {compatible}')

    if compatible == 'v1':
        variables = self.vars().unique()
        provided = set(state_dict.keys())
        expected = set(variables.keys())
        unexpected_keys = list(provided - expected)
        missing_keys = list(expected - provided)
        for key in expected.intersection(provided):
            variables[key].value = jax.numpy.asarray(state_dict[key])
    else:
        missing_keys = []
        unexpected_keys = []
        for name, node in self.nodes().items():
            # A node without an entry is given an empty state.
            node_state = state_dict[name] if name in state_dict else {}
            outcome = node.load_state(node_state, **kwargs)
            if outcome is None:
                continue
            missing, unexpected = outcome
            missing_keys.extend([f'{name}.{key}' for key in missing])
            unexpected_keys.extend([f'{name}.{key}' for key in unexpected])

    if warn and len(unexpected_keys):
        warnings.warn(f'Unexpected keys in state_dict: {unexpected_keys}', UserWarning)
    if warn and len(missing_keys):
        warnings.warn(f'Missing keys in state_dict: {missing_keys}', UserWarning)
    return StateLoadResult(missing_keys, unexpected_keys)
[docs]
def to(self, device: Optional[Any]):
    """Moves all variables into the given device.

    Parameters
    ----------
    device : Optional[Any]
        The device, passed as ``device=`` to ``jax.device_put``.

    Returns
    -------
    self
        This object, so that calls can be chained.
    """
    # Work on the Variable objects from ``vars()`` and rebind each one's
    # ``.value``; the variables themselves stay in place.
    collected = self.vars()
    for holder in collected.values():
        moved = jax.device_put(holder.value, device=device)
        holder.value = moved
    return self
[docs]
def cpu(self):
    """Move all variables into the first CPU device reported by JAX."""
    target = jax.devices('cpu')[0]
    return self.to(device=target)
[docs]
def cuda(self):
    """Move all variables into the first GPU device reported by JAX."""
    target = jax.devices('gpu')[0]
    return self.to(device=target)
[docs]
def tpu(self):
    """Move all variables into the first TPU device reported by JAX."""
    target = jax.devices('tpu')[0]
    return self.to(device=target)
def _add_node2(self, v, _paths, gather, nodes):
    """Record child ``v`` of ``self`` under its ``name``, once per parent-child pair.

    ``_paths`` holds the ``(id(parent), id(child))`` pairs seen so far. A
    new pair is added to it, ``v`` is stored in ``gather`` under ``v.name``
    and appended to ``nodes``. A known pair changes nothing.
    """
    edge = (id(self), id(v))
    if edge in _paths:
        return
    _paths.add(edge)
    gather[v.name] = v
    nodes.append(v)
def _add_node1(self, k, v, _paths, gather, nodes):
    """Record child ``v`` of ``self`` under the key ``k``, once per parent-child pair.

    ``_paths`` holds the ``(id(parent), id(child))`` pairs seen so far. A
    new pair is added to it, ``v`` is stored in ``gather`` under ``k`` and
    ``(k, v)`` is appended to ``nodes``. A known pair changes nothing.
    """
    edge = (id(self), id(v))
    if edge in _paths:
        return
    _paths.add(edge)
    gather[k] = v
    nodes.append((k, v))
# ``Base`` is a second name for the ``BrainPyObject`` class; both names are
# listed in ``__all__``.
Base = BrainPyObject
[docs]
class FunAsObject(BrainPyObject):
    """Transform a Python function as a :py:class:`~.BrainPyObject`.

    Calling the instance calls ``target`` with the same arguments. The
    given child objects and variables are registered as implicit nodes and
    implicit variables.

    Parameters
    ----------
    target : callable
        The function to wrap.
    child_objs : optional, BrainPyObject, sequence of BrainPyObject, dict
        The nodes in the defined function ``f``.
    dyn_vars : optional, Variable, sequence of Variable, dict
        The dynamically changed variables.
    name : optional, str
        The function name.
    """

    def __init__(
        self,
        target: Callable,
        # Dict keys are names, hence ``str`` (a ``dict`` is unhashable and
        # cannot be a key).
        child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
        dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
        name: str = None
    ):
        super(FunAsObject, self).__init__(name=name)
        self.target = target
        if child_objs is not None:
            self.register_implicit_nodes(child_objs)
        if dyn_vars is not None:
            self.register_implicit_vars(dyn_vars)

    def __call__(self, *args, **kwargs):
        """Call the wrapped ``target`` and return its result."""
        return self.target(*args, **kwargs)

    def __repr__(self) -> str:
        """List the implicit nodes and the number of implicit variables."""
        # Imported here rather than at module level, as in the original.
        from brainpy.tools import repr_context
        name = self.__class__.__name__
        # Continuation lines are aligned just after ``<name>(``.
        indent = " " * (len(name) + 1)
        indent2 = indent + " " * len('nodes=')
        nodes = [repr_context(str(n), indent2) for n in self.implicit_nodes.values()]
        node_string = ", \n".join(nodes)
        return (f'{name}(nodes=[{node_string}],\n' +
                indent + f'num_of_vars={len(self.implicit_vars)})')
def _check_elem(elem):
    """Return ``elem`` if it is a number, an array or a ``BrainPyObject``.

    Raises
    ------
    TypeError
        If ``elem`` is of any other type.
    """
    allowed = (numbers.Number, jax.Array, Array, np.ndarray, BrainPyObject)
    if isinstance(elem, allowed):
        return elem
    raise TypeError('Element should be a numerical value or a BrainPyObject.')
def _check_num_elem(elem):
    """Return ``elem`` if it is a number or an array.

    Raises
    ------
    TypeError
        If ``elem`` is of any other type.
    """
    allowed = (numbers.Number, jax.Array, Array, np.ndarray)
    if isinstance(elem, allowed):
        return elem
    raise TypeError('Element should be a numerical value.')
def _check_obj_elem(elem):
    """Return ``elem`` if it is a ``BrainPyObject``.

    Raises
    ------
    TypeError
        If ``elem`` is not a ``BrainPyObject`` instance.
    """
    if not isinstance(elem, BrainPyObject):
        # ``BrainPyObject.__class__`` is the metaclass, so its ``__name__``
        # would be "type"; the class's own ``__name__`` is what is meant.
        raise TypeError(f'Element should be a {BrainPyObject.__name__}.')
    return elem
class ObjectTransform(BrainPyObject):
    """Object-oriented JAX transformation for BrainPy computation.

    This class itself cannot be called: ``__call__`` raises
    ``NotImplementedError``.
    """

    def __init__(self, name: str = None):
        super().__init__(name=name)

    def __call__(self, *args, **kwargs):
        """Not implemented in this class."""
        raise NotImplementedError

    def __repr__(self):
        """Return the class name only."""
        cls_name = self.__class__.__name__
        return cls_name
[docs]
class NodeList(list):
    """A sequence of :py:class:`~.BrainPyObject`, which is compatible with
    :py:func:`~.vars()` and :py:func:`~.nodes()` operations in a :py:class:`~.BrainPyObject`.

    That is to say, any nodes that are wrapped into :py:class:`~.NodeList` will be automatically
    retrieved when using :py:func:`~.nodes()` function.

    >>> import brainpy as bp
    >>> import brainpy.math as bm
    >>> l = bm.node_list([bp.dnn.Dense(1, 2),
    >>>                   bp.dnn.LSTMCell(2, 3)])
    """

    def __init__(self, seq=()):
        super().__init__()
        # Initial items go through ``extend`` and therefore ``append``.
        self.extend(seq)

    def append(self, element) -> 'NodeList':
        """Add ``element`` at the end and return the list itself.

        The element type is not validated.
        """
        super().append(element)
        return self

    def extend(self, iterable) -> 'NodeList':
        """Append every item of ``iterable`` in order and return the list itself."""
        for item in iterable:
            self.append(item)
        return self


# Lower-case name for the same class.
node_list = NodeList
[docs]
class NodeDict(dict):
    """A dictionary of :py:class:`~.BrainPyObject`, which is compatible with
    :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.

    Parameters
    ----------
    *args
        Initial content, in any form accepted by :py:meth:`update`.
    check_unique : bool
        When ``True``, binding an existing key to a different object raises
        ``KeyError``.
    **kwargs
        Additional initial (key, value) pairs.
    """

    def __init__(self, *args, check_unique: bool = False, **kwargs):
        super().__init__()
        # Must be set before ``update``, because ``__setitem__`` reads it.
        self.check_unique = check_unique
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs) -> 'NodeDict':
        """Insert items through ``__setitem__`` and return the dict itself.

        Parameters
        ----------
        *args
            Each positional argument is one of: a ``dict``; a 2-tuple
            ``(key, value)`` taken as a single pair; another mapping (any
            object with ``keys()``); or an iterable of ``(key, value)``
            pairs. ``None`` is ignored.
        **kwargs
            Additional (key, value) pairs, applied after ``args``.

        Raises
        ------
        ValueError
            If a tuple argument does not have exactly two elements.
        """
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
            elif isinstance(arg, tuple):
                # A tuple is a single (key, value) pair. This is validated
                # explicitly because ``assert`` is removed under ``-O``.
                if len(arg) != 2:
                    raise ValueError(
                        f'A tuple argument must be a (key, value) pair, '
                        f'but got {len(arg)} elements.'
                    )
                self[arg[0]] = arg[1]
            elif hasattr(arg, 'keys'):
                for k in arg.keys():
                    self[k] = arg[k]
            else:
                # Any other iterable is read as (key, value) pairs instead
                # of being dropped silently.
                for k, v in arg:
                    self[k] = v
        for k, v in kwargs.items():
            self[k] = v
        return self

    def __setitem__(self, key, value) -> 'NodeDict':
        if self.check_unique:
            exist = self.get(key, None)
            # Re-binding a key to the same object is allowed.
            if exist is not None and exist is not value:
                raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.')
        super().__setitem__(key, value)
        return self


# Lower-case name for the same class.
node_dict = NodeDict