# Source code for brainpy._src.math.object_transform.variables

from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple

import jax
import numpy as np
from jax import numpy as jnp
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class

from brainpy._src.math.ndarray import Array
from brainpy._src.math.sharding import BATCH_AXIS
from brainpy.errors import MathError

__all__ = [
  # variable types
  'Variable',
  'TrainVar',
  'Parameter',
  'VariableView',

  # containers of variables, with their lowercase aliases
  'VarList',
  'var_list',
  'VarDict',
  'var_dict',
]


class VariableStack(dict):
  """Variable stack, for collecting all :py:class:`~.Variable` used in the program.

  :py:class:`~.VariableStack` supports all features of python dict.

  :py:meth:`add` keys each collected variable by its ``id()``. The stack also
  caches, in ``self._values``, the raw value (``_value``) each variable had when
  it was collected, so :py:meth:`assign_org_values` (called on ``__exit__``)
  can restore it.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # id -> raw value captured at collection time; restored on ``__exit__``.
    self._values = dict()

  def add(self, var: 'Variable'):
    """Add a new :py:class:`~.Variable` (no-op if it is already collected)."""
    assert isinstance(var, Variable), f'must be instance of {Variable}'
    id_ = id(var)
    if id_ not in self:
      self[id_] = var
      self._values[id_] = var._value

  def collect_values(self):
    """Collect the value of each variable once again."""
    for id_, var in self.items():
      self._values[id_] = var._value

  def assign_org_values(self):
    """Assign the original value for each variable."""
    for id_, var in self.items():
      if id_ in self._values:
        var._value = self._values[id_]

  def assign(self, data: Union[Dict, Sequence], check: bool = True):
    """Assign the value for each :math:`~.Variable` according to the given ``data``.

    Args:
      data: dict, list, tuple. The data of all variables. A dict is looked up
        by the same keys as this stack. A list/tuple is matched positionally
        with this stack's insertion order.
      check: bool. Check whether the shape and type of the given data are
        consistent with original data. When ``True`` the data is written through
        each variable's ``value`` setter; otherwise ``_value`` is set directly.

    Raises:
      TypeError: If ``data`` is not a dict, tuple or list.
    """
    if isinstance(data, dict):
      assert len(data) == len(self), 'Data length mismatch. '
      # Lazily look up each key in the same order the variables are visited.
      new_values = (data[id_] for id_ in self.keys())
    elif isinstance(data, (tuple, list)):
      assert len(data) == len(self), 'Data length mismatch. '
      new_values = iter(data)
    else:
      raise TypeError(f'"data" must be a dict, tuple or list, but got {type(data)}.')
    for elem, new_value in zip(self.values(), new_values):
      if check:
        elem.value = new_value
      else:
        elem._value = new_value

  def call_on_subset(self, cond: Callable, call: Callable) -> dict:
    """Call a function on the subset of this :py:class:`~VariableStack`.

    >>> import brainpy.math as bm
    >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1))
    >>> stack.call_on_subset(lambda a: isinstance(a, bm.random.RandomState),
    ...                      lambda a: a.split_key())
    {'b': Array([3819641963, 2025898573], dtype=uint32)}

    Args:
      cond: The function to determine whether the element belongs to the wanted subset.
      call: The function to call if the element belongs to the wanted subset.

    Returns:
      A dict containing the results of ``call`` function for each element in the ``cond`` constrained subset.
    """
    res = dict()
    for id_, elem in self.items():
      if cond(elem):
        res[id_] = call(elem)
    return res

  def separate_by_instance(self, cls: type) -> Tuple['VariableStack', 'VariableStack']:
    """Separate all variables into two groups: (variables that are instances of the given ``cls``,
    variables that are not instances of the given ``cls``).

    >>> import brainpy.math as bm
    >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1))
    >>> stack.separate_by_instance(bm.random.RandomState)
    ({'b': RandomState(key=([0, 1], dtype=uint32))},
     {'a': Variable(value=Array([0.]), dtype=float32)})
    >>> stack.separate_by_instance(bm.Variable)
    ({'a': Variable(value=Array([0.]), dtype=float32),
      'b': RandomState(key=([0, 1], dtype=uint32))},
     {})

    Args:
      cls: The class type.

    Returns:
      A tuple with two elements:

      - VariableStack of variables that are instances of the given ``cls``
      - VariableStack of variables that are not instances of the given ``cls``
    """
    is_instances = type(self)()
    not_instances = type(self)()
    for id_, elem in self.items():
      if isinstance(elem, cls):
        is_instances[id_] = elem
      else:
        not_instances[id_] = elem
    return is_instances, not_instances

  def subset_by_instance(self, cls: type) -> 'VariableStack':
    """Collect all variables which are instances of the given class type."""
    new_dict = type(self)()
    for id_, elem in self.items():
      if isinstance(elem, cls):
        new_dict[id_] = elem
    return new_dict

  def subset_by_not_instance(self, cls: type) -> 'VariableStack':
    """Collect all variables which are not instance of the given class type."""
    new_dict = type(self)()
    for id_, elem in self.items():
      if not isinstance(elem, cls):
        new_dict[id_] = elem
    return new_dict

  instance_of = subset_by_instance
  not_instance_of = subset_by_not_instance

  def dict_data_of_subset(self, subset_cond: Callable) -> dict:
    """Get data of the given subset constrained by function ``subset_cond``.

    Args:
      subset_cond: A function to determine whether the element is in the subset wanted.

    Returns:
      A dict of data for elements of the wanted subset.
    """
    res = dict()
    for id_, elem in self.items():
      if subset_cond(elem):
        res[id_] = elem.value
    return res

  def dict_data(self) -> dict:
    """Get all data in the collected variables with a python dict structure."""
    new_dict = dict()
    for id_, elem in tuple(self.items()):
      new_dict[id_] = elem.value
    return new_dict

  def list_data(self) -> list:
    """Get all data in the collected variables with a python list structure."""
    new_list = list()
    for elem in tuple(self.values()):
      new_list.append(elem.value if isinstance(elem, Array) else elem)
    return new_list

  def remove_by_id(self, *ids, error_when_absent=False):
    """Remove or pop variables in the stack by the given ids.

    Args:
      *ids: The keys to remove.
      error_when_absent: If ``True``, a missing key raises ``KeyError``;
        otherwise missing keys are ignored.
    """
    if error_when_absent:
      for id_ in ids:
        self.pop(id_)
    else:
      for id_ in ids:
        self.pop(id_, None)

  remove_var_by_id = remove_by_id

  @classmethod
  def num_of_stack(cls):
    """Number of currently active (entered) stacks."""
    return len(var_stack_list)

  @classmethod
  def is_first_stack(cls):
    """Whether no stack is currently active."""
    return len(var_stack_list) == 0

  def __enter__(self) -> 'VariableStack':
    self.collect_values()  # recollect the original value of each variable
    var_stack_list.append(self)
    return self

  def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    var_stack_list.pop()
    self.assign_org_values()  # reassign the original value for each variable
    self._values.clear()

  def __add__(self, other: dict):
    """Merge two stacks into a new one of the same type as ``self``.

    Cached original values of both operands are merged as well (``other``'s
    cached values win on conflicting keys when ``other`` is a VariableStack).
    """
    # ``type(self)`` (not a hard-coded ``VariableStack``) so subclasses are preserved,
    # consistent with ``separate_by_instance`` / ``subset_by_instance``.
    new_dict = type(self)(self)
    new_dict.update(other)
    new_dict._values.update(self._values)
    if isinstance(other, VariableStack):
      new_dict._values.update(other._values)
    return new_dict


# Stack of currently entered ``VariableStack`` contexts (innermost last).
var_stack_list: List[VariableStack] = []
@register_pytree_node_class
class Variable(Array):
  """The pointer to specify the dynamical variable.

  Initializing an instance of ``Variable`` by two ways:

  >>> import brainpy.math as bm
  >>> # 1. init a Variable by the concrete data
  >>> v1 = bm.Variable(bm.zeros(10))
  >>> # 2. init a Variable by the data shape
  >>> v2 = bm.Variable(10)

  Note that when initializing a `Variable` by the data shape,
  all values in this `Variable` will be initialized as zeros.

  Args:
    value_or_size: Shape, Array, int. The value or the size of the value.
    dtype: Any. The type of the data.
    batch_axis: optional, int. The batch axis. If ``value_or_size`` is itself a
      ``Variable`` and ``batch_axis`` is not given, its batch axis is inherited.
    axis_names: sequence of str. The name for each axis. If exactly one name
      fewer than ``ndim`` is given and a batch axis is set, ``BATCH_AXIS`` is
      inserted at the batch axis position.
    ready_to_trace: optional, bool. Whether reading or writing ``value`` adds this
      variable to every active :py:class:`~.VariableStack`. Defaults to ``True``
      if no stack is active when the variable is created, else ``False``.
  """

  __slots__ = ('_value', '_batch_axis', 'ready_to_trace', 'axis_names')

  def __init__(
      self,
      value_or_size: Any,
      dtype: type = None,
      batch_axis: int = None,
      *,
      axis_names: Optional[Sequence[str]] = None,
      ready_to_trace: bool = None
  ):
    if isinstance(value_or_size, int):
      value = jnp.zeros(value_or_size, dtype=dtype)
    elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
      value = jnp.zeros(value_or_size, dtype=dtype)
    else:
      value = value_or_size
    super().__init__(value, dtype=dtype)

    # check batch axis
    if isinstance(value, Variable):
      if value.batch_axis is not None and batch_axis is not None:
        if batch_axis != value.batch_axis:
          raise ValueError(f'"batch_axis" is not consistent. Got batch_axis in the given value '
                           f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}')
      # Inherit the wrapped variable's batch axis only when none was given explicitly;
      # never discard a user-specified batch axis.
      if batch_axis is None:
        batch_axis = value.batch_axis

    # assign batch axis
    self._batch_axis = batch_axis
    if batch_axis is not None:
      if batch_axis >= np.ndim(self._value):
        raise MathError(f'This variables has {np.ndim(self._value)} dimension, '
                        f'but the batch axis is set to be {batch_axis}.')

    # ready to trace the variable
    if ready_to_trace is None:
      self.ready_to_trace = len(var_stack_list) == 0
    else:
      self.ready_to_trace = ready_to_trace

    if axis_names is not None:
      # Insert the batch-axis name only when there is a batch axis to insert it at.
      if self.batch_axis is not None and len(axis_names) + 1 == self.ndim:
        axis_names = list(axis_names)
        axis_names.insert(self.batch_axis, BATCH_AXIS)
      assert len(axis_names) == self.ndim, (f'Got {len(axis_names)} axis names for '
                                            f'a variable with {self.ndim} dimensions.')
      axis_names = tuple(axis_names)
    self.axis_names = axis_names

  @property
  def size_without_batch(self):
    """Number of elements, excluding the batch axis (if there is one)."""
    if self.batch_axis is None:
      return self.size
    shape = self.shape
    return int(np.prod(shape[:self.batch_axis] + shape[self.batch_axis + 1:]))

  @property
  def batch_axis(self) -> Optional[int]:
    return self._batch_axis

  @batch_axis.setter
  def batch_axis(self, val):
    raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.')

  @property
  def batch_size(self) -> Optional[int]:
    if self.batch_axis is None:
      return None
    else:
      return self.shape[self.batch_axis]

  @batch_size.setter
  def batch_size(self, val):
    raise ValueError('Cannot set "batch_size" manually.')

  @property
  def value(self):
    self._append_to_stack()
    return self._value

  @value.setter
  def value(self, v):
    """Set the data, checking that shape (ignoring the batch axis) and dtype match."""
    _value = self.value
    ext_shape = jnp.shape(v)
    int_shape = jnp.shape(_value)
    if self._batch_axis is not None:
      # The batch dimension may change size; compare only the other dimensions.
      ext_shape = ext_shape[:self._batch_axis] + ext_shape[self._batch_axis + 1:]
      int_shape = int_shape[:self._batch_axis] + int_shape[self._batch_axis + 1:]
    if ext_shape != int_shape:
      error = f"The shape of the original data is {int_shape}, while we got {ext_shape}"
      error += f' with batch_axis={self._batch_axis}.'
      raise MathError(error)
    ext_dtype = _get_dtype(v)
    int_dtype = self.dtype
    if ext_dtype != int_dtype:
      raise MathError(f"The dtype of the original data is {int_dtype}, "
                      f"while we got {ext_dtype}.")
    self._append_to_stack()
    if isinstance(v, Array):
      v = v.value
    elif isinstance(v, np.ndarray):
      v = jnp.asarray(v)
    self._value = v

  def _append_to_stack(self):
    """Register this variable in every active VariableStack (if tracing is enabled)."""
    if self.ready_to_trace:
      for stack in var_stack_list:
        stack.add(self)

  def tree_flatten(self):
    """Flattens this variable.

    Returns:
      A pair where the first element is a list of leaf values
      and the second element is a treedef representing the
      structure of the flattened tree.
    """
    return (self._value,), None

  @classmethod
  def tree_unflatten(cls, aux_data, flat_contents):
    """Reconstructs a variable from the aux_data and the leaves.

    Args:
      aux_data: Unused (``tree_flatten`` stores ``None``).
      flat_contents: The single leaf holding the data.

    Returns:
      The variable, created with ``ready_to_trace=False``.
    """
    return cls(*flat_contents, ready_to_trace=False)

  def clone(self) -> 'Variable':
    """Clone the variable: copies the data, batch axis and ``ready_to_trace`` flag."""
    r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
    r.ready_to_trace = self.ready_to_trace
    return r
def _get_dtype(v):
  """Return ``v.dtype``; for objects without a ``dtype`` attribute (e.g. Python
  scalars) fall back to the canonical JAX dtype of ``type(v)``."""
  try:
    return v.dtype
  except AttributeError:
    return canonicalize_dtype(type(v))


def _as_jax_array_(obj):
  """Unwrap a brainpy ``Array`` into its underlying value; pass anything else through."""
  if isinstance(obj, Array):
    return obj.value
  return obj
@register_pytree_node_class
class TrainVar(Variable):
  """The pointer to specify the trainable variable.

  Same constructor as :py:class:`~.Variable`, except that ``ready_to_trace``
  defaults to ``True`` rather than being decided by the active stacks.
  """

  def __init__(
      self,
      value_or_size: Any,
      dtype: type = None,
      batch_axis: int = None,
      *,
      axis_names: Optional[Sequence[str]] = None,
      ready_to_trace: bool = True
  ):
    super().__init__(value_or_size, dtype, batch_axis,
                     axis_names=axis_names, ready_to_trace=ready_to_trace)
@register_pytree_node_class
class Parameter(Variable):
  """The pointer to specify the parameter.

  Same constructor as :py:class:`~.Variable`, except that ``ready_to_trace``
  defaults to ``True`` rather than being decided by the active stacks.
  """

  def __init__(
      self,
      value_or_size: Any,
      dtype: type = None,
      batch_axis: int = None,
      *,
      axis_names: Optional[Sequence[str]] = None,
      ready_to_trace: bool = True
  ):
    super().__init__(value_or_size, dtype, batch_axis,
                     ready_to_trace=ready_to_trace, axis_names=axis_names)
class VariableView(Variable):
  """A view of a Variable instance.

  This class is used to create a subset view of ``brainpy.math.Variable``.

  >>> import brainpy.math as bm
  >>> bm.random.seed(123)
  >>> origin = bm.Variable(bm.random.random(5))
  >>> view = bm.VariableView(origin, slice(None, 2, None))  # origin[:2]
  VariableView([0.02920651, 0.19066381], dtype=float32)

  ``VariableView`` can be used to update the subset of the original
  Variable instance, and make operations on this subset of the Variable.

  >>> view[:] = 1.
  >>> view
  VariableView([1., 1.], dtype=float32)
  >>> origin
  Variable([1.       , 1.       , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
  >>> view + 10
  Array([11., 11.], dtype=float32)
  >>> view *= 10
  VariableView([10., 10.], dtype=float32)

  The above example demonstrates that the updating of an ``VariableView`` instance
  is actually made in the original ``Variable`` instance.

  Moreover, it's worthy to note that ``VariableView`` is not a PyTree.
  """

  _need_record = False

  def __init__(
      self,
      value: Variable,
      index: Any,
  ):
    if not isinstance(value, Variable):
      raise ValueError('Must be instance of Variable.')
    # Unwrap any brainpy Arrays inside the (possibly nested) index into raw values.
    self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, Array))
    super().__init__(value.value, batch_axis=value.batch_axis, ready_to_trace=False)
    # Keep the viewed Variable itself: reads/writes go to ``self._value[self.index]``.
    self._value = value

  def __repr__(self) -> str:
    print_code = repr(self._value)
    prefix = f'{self.__class__.__name__}'
    blank = " " * (len(prefix) + 1)
    lines = print_code.split("\n")
    lines[0] = prefix + "(" + lines[0]
    for i in range(1, len(lines)):
      lines[i] = blank + lines[i]
    lines[-1] += ","
    lines.append(blank + f'index={self.index})')
    print_code = "\n".join(lines)
    return print_code

  @property
  def value(self):
    return self._value[self.index]

  @value.setter
  def value(self, v):
    """Write ``v`` into the viewed region of the original Variable.

    Raises:
      MathError: If the shape (ignoring the batch axis) or dtype of ``v`` does
        not match the view.
    """
    # ``jnp.shape``/``_get_dtype`` (rather than ``v.shape``/``v.dtype``) also accept
    # Python scalars and other array-likes, matching ``Variable.value``'s setter.
    v_shape = jnp.shape(v)
    int_shape = self.shape
    if self.batch_axis is None:
      ext_shape = v_shape
    else:
      ext_shape = v_shape[:self.batch_axis] + v_shape[self.batch_axis + 1:]
      int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
    if ext_shape != int_shape:
      error = f"The shape of the original data is {self.shape}, while we got {v_shape}"
      if self.batch_axis is None:
        error += '. Do you forget to set "batch_axis" when initialize this variable?'
      else:
        error += f' with batch_axis={self.batch_axis}.'
      raise MathError(error)
    v_dtype = _get_dtype(v)
    if v_dtype != self._value.dtype:
      raise MathError(f"The dtype of the original data is {self._value.dtype}, "
                      f"while we got {v_dtype}.")
    self._value[self.index] = v.value if isinstance(v, Array) else v
@register_pytree_node_class
class VarList(list):
  """A sequence of :py:class:`~.Variable`, which is compatible with
  :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.

  Actually, :py:class:`~.VarList` is a python list.

  :py:class:`~.VarList` is specifically designed to store Variable instances.
  """

  def __init__(self, seq=()):
    super().__init__()
    self.extend(seq)

  @staticmethod
  def _check_elem(element):
    """Raise ``TypeError`` unless ``element`` is a Variable; return it otherwise."""
    if not isinstance(element, Variable):
      raise TypeError(f'element must be an instance of {Variable.__name__}.')
    return element

  def append(self, element) -> 'VarList':
    super().append(self._check_elem(element))
    return self

  def extend(self, iterable) -> 'VarList':
    for element in iterable:
      self.append(element)
    return self

  def __setitem__(self, key, value) -> 'VarList':
    """Override the item setting.

    This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden,
    and only the value can be changed for each element.

    >>> import brainpy.math as bm
    >>> l = bm.var_list([bm.Variable(1), bm.Variable(2)])
    >>> print(id(l[0]), id(l[1]))
    2077748389472 2077748389552
    >>> l[1] = bm.random.random(2)
    >>> l[0] = bm.random.random(1)
    >>> print(id(l[0]), id(l[1]))  # still the original Variable instances
    2077748389472 2077748389552

    Slice assignment replaces elements, and every new element must be a Variable.
    """
    if isinstance(key, slice):
      # Validate like ``append`` so slices cannot smuggle in non-Variables.
      super().__setitem__(key, [self._check_elem(element) for element in value])
    else:
      # Any integer-like index (including numpy ints): update in place.
      self[key].value = value
    return self

  def tree_flatten(self):
    return tuple(self), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(children)


var_list = VarList
@register_pytree_node_class
class VarDict(dict):
  """A dictionary of :py:class:`~.Variable`, which is compatible with
  :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.

  Actually, :py:class:`~.VarDict` is a python dict.

  :py:class:`~.VarDict` is specifically designed to store Variable instances.
  """

  def _check_elem(self, elem):
    if not isinstance(elem, Variable):
      raise TypeError(f'Element should be {Variable.__name__}, but got {type(elem)}.')
    return elem

  def __init__(self, *args, **kwargs):
    super().__init__()
    self.update(*args, **kwargs)

  def update(self, *args, **kwargs) -> 'VarDict':
    """Insert or update entries.

    Each positional argument may be a mapping (a dict or any object with
    ``keys()``), a single ``(key, value)`` tuple, or any other iterable of
    ``(key, value)`` pairs (e.g. a list or ``zip`` object). Keyword arguments
    are applied last. Existing keys only have their Variable's value updated
    (see ``__setitem__``).
    """
    for arg in args:
      if isinstance(arg, dict):
        pairs = arg.items()
      elif hasattr(arg, 'keys'):
        pairs = ((k, arg[k]) for k in arg.keys())
      elif isinstance(arg, tuple):
        # A tuple argument is treated as a single (key, value) pair.
        assert len(arg) == 2
        pairs = (arg,)
      else:
        pairs = arg
      for k, v in pairs:
        self[k] = v
    for k, v in kwargs.items():
      self[k] = v
    return self

  def __setitem__(self, key, value) -> 'VarDict':
    """Override the item setting.

    This function ensures that the Variable appended in the :py:class:`~.VarDict` will not be overridden.

    >>> import brainpy.math as bm
    >>> d = bm.var_dict({'a': bm.Variable(1), 'b': bm.Variable(2)})
    >>> print(id(d['a']), id(d['b']))
    2077667833504 2077748488176
    >>> d['b'] = bm.random.random(2)
    >>> d['a'] = bm.random.random(1)
    >>> print(id(d['a']), id(d['b']))  # still the original Variable instances
    2077667833504 2077748488176
    """
    if key in self:
      self[key].value = value
    else:
      super().__setitem__(key, self._check_elem(value))
    return self

  def tree_flatten(self):
    return tuple(self.values()), tuple(self.keys())

  @classmethod
  def tree_unflatten(cls, keys, values):
    values = tuple(values)
    if len(keys) != len(values):
      raise ValueError(f'length mismatch: got {len(keys)} keys but {len(values)} values.')
    return cls(zip(keys, values))


var_dict = VarDict