# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Any, Sequence
import brainstate
import jax
import numpy as np
from brainstate._state import record_state_value_read, record_state_value_write
from jax import numpy as jnp
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class
from brainpy._errors import MathError
from brainpy.math.ndarray import Array
from brainpy.math.sharding import BATCH_AXIS
__all__ = [
'Variable',
'TrainVar',
'Parameter',
'VariableView',
'VarList', 'var_list',
'VarDict', 'var_dict',
]
[docs]
@register_pytree_node_class
class Variable(brainstate.State, Array):
"""The pointer to specify the dynamical variable.
Initializing an instance of ``Variable`` by two ways:
>>> import brainpy.math as bm
>>> # 1. init a Variable by the concreate data
>>> v1 = bm.Variable(bm.zeros(10))
>>> # 2. init a Variable by the data shape
>>> v2 = bm.Variable(10)
Note that when initializing a `Variable` by the data shape,
all values in this `Variable` will be initialized as zeros.
Args:
value_or_size: Shape, Array, int. The value or the size of the value.
dtype: Any. The type of the data.
batch_axis: optional, int. The batch axis.
axis_names: sequence of str. The name for each axis.
"""
def __init__(
self,
value_or_size: Any,
dtype: type = None,
batch_axis: int = None,
*,
axis_names: Optional[Sequence[str]] = None,
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
value = jnp.zeros(value_or_size, dtype=dtype)
else:
value = value_or_size
if isinstance(value, Array):
value = value.value
Array.__init__(self, value, dtype=dtype)
brainstate.State.__init__(self, value)
# check batch axis
if isinstance(value, Variable):
if value.batch_axis is not None and batch_axis is not None:
if batch_axis != value.batch_axis:
raise ValueError(f'"batch_axis" is not consistent. Got batch_axis in the given value '
f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}')
batch_axis = value.batch_axis
# assign batch axis
self._batch_axis = batch_axis
if batch_axis is not None:
if batch_axis >= np.ndim(self._value):
raise MathError(f'This variables has {np.ndim(self._value)} dimension, '
f'but the batch axis is set to be {batch_axis}.')
# ready to trace the variable
if axis_names is not None:
if len(axis_names) + 1 == self.ndim:
axis_names = list(axis_names)
axis_names.insert(self.batch_axis, BATCH_AXIS)
assert len(axis_names) == self.ndim
axis_names = tuple(axis_names)
self.axis_names = axis_names
@property
def size_without_batch(self):
if self.batch_axis is None:
return self.size
else:
sizes = self.size
return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:]
@property
def batch_axis(self) -> Optional[int]:
return self._batch_axis
@batch_axis.setter
def batch_axis(self, val):
raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.')
@property
def batch_size(self) -> Optional[int]:
if self.batch_axis is None:
return None
else:
return self.shape[self.batch_axis]
@batch_size.setter
def batch_size(self, val):
raise ValueError(f'Cannot set "batch_size" manually.')
def _ensure_value_exists(self):
pass
@property
def value(self):
self._ensure_value_exists()
record_state_value_read(self)
return self._read_value()
@value.setter
def value(self, v):
_value = self.value
ext_shape = jnp.shape(v)
int_shape = jnp.shape(_value)
if self._batch_axis is not None:
ext_shape = ext_shape[:self._batch_axis] + ext_shape[self._batch_axis + 1:]
int_shape = int_shape[:self._batch_axis] + int_shape[self._batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {int_shape}, while we got {ext_shape}"
error += f' with batch_axis={self._batch_axis}.'
raise MathError(error)
ext_dtype = _get_dtype(v)
int_dtype = self.dtype
if ext_dtype != int_dtype:
raise MathError(f"The dtype of the original data is {int_dtype}, "
f"while we got {ext_dtype}.")
if isinstance(v, Array):
v = v.value
elif isinstance(v, np.ndarray):
v = jnp.asarray(v)
else:
v = v
if isinstance(v, brainstate.State): # value checking
v = v.value
self._check_value_tree(v) # check the tree structure
record_state_value_write(self) # record the value by the stack (>= level)
self._been_writen = True # set the flag
self._write_value(v) # write the value
def _get_dtype(v):
if hasattr(v, 'dtype'):
dtype = v.dtype
else:
dtype = canonicalize_dtype(type(v))
return dtype
def _as_jax_array_(obj):
return obj.value if isinstance(obj, Array) else obj
[docs]
@register_pytree_node_class
class TrainVar(Variable):
"""The pointer to specify the trainable variable.
"""
def __init__(
self,
value_or_size: Any,
dtype: type = None,
batch_axis: int = None,
*,
axis_names: Optional[Sequence[str]] = None,
):
super().__init__(
value_or_size,
dtype=dtype,
batch_axis=batch_axis,
axis_names=axis_names,
)
[docs]
@register_pytree_node_class
class Parameter(Variable):
"""The pointer to specify the parameter.
"""
def __init__(
self,
value_or_size: Any,
dtype: type = None,
batch_axis: int = None,
*,
axis_names: Optional[Sequence[str]] = None,
):
super().__init__(
value_or_size,
dtype=dtype,
batch_axis=batch_axis,
axis_names=axis_names,
)
[docs]
class VariableView(Variable):
"""A view of a Variable instance.
This class is used to create a subset view of ``brainpy.math.Variable``.
>>> import brainpy.math as bm
>>> bm.random.seed(123)
>>> origin = bm.Variable(bm.random.random(5))
>>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2]
VariableView([0.02920651, 0.19066381], dtype=float32)
``VariableView`` can be used to update the subset of the original
Variable instance, and make operations on this subset of the Variable.
>>> view[:] = 1.
>>> view
VariableView([1., 1.], dtype=float32)
>>> origin
Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
>>> view + 10
Array([11., 11.], dtype=float32)
>>> view *= 10
VariableView([10., 10.], dtype=float32)
The above example demonstrates that the updating of an ``VariableView`` instance
is actually made in the original ``Variable`` instance.
Moreover, it's worthy to note that ``VariableView`` is not a PyTree.
"""
_need_record = False
def __init__(
self,
value: Variable,
index: Any,
):
self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, Array))
if not isinstance(value, Variable):
raise ValueError('Must be instance of Variable.')
super().__init__(value.value, batch_axis=value.batch_axis)
self._value = value
def __repr__(self) -> str:
print_code = repr(self._value)
prefix = f'{self.__class__.__name__}'
blank = " " * (len(prefix) + 1)
lines = print_code.split("\n")
lines[0] = prefix + "(" + lines[0]
for i in range(1, len(lines)):
lines[i] = blank + lines[i]
lines[-1] += ","
lines.append(blank + f'index={self.index})')
print_code = "\n".join(lines)
return print_code
@property
def value(self):
return self._value[self.index]
@value.setter
def value(self, v):
int_shape = self.shape
if self.batch_axis is None:
ext_shape = v.shape
else:
ext_shape = v.shape[:self.batch_axis] + v.shape[self.batch_axis + 1:]
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {self.shape}, while we got {v.shape}"
if self.batch_axis is None:
error += '. Do you forget to set "batch_axis" when initialize this variable?'
else:
error += f' with batch_axis={self.batch_axis}.'
raise MathError(error)
if v.dtype != self._value.dtype:
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
f"while we got {v.dtype}.")
self._value[self.index] = v.value if isinstance(v, Array) else v
[docs]
@register_pytree_node_class
class VarList(list):
"""A sequence of :py:class:`~.Variable`, which is compatible with
:py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.
Actually, :py:class:`~.VarList` is a python list.
:py:class:`~.VarList` is specifically designed to store Variable instances.
"""
def __init__(self, seq=()):
super().__init__()
self.extend(seq)
[docs]
def append(self, element) -> 'VarList':
if not isinstance(element, Variable):
raise TypeError(f'element must be an instance of {Variable.__name__}.')
super().append(element)
return self
[docs]
def extend(self, iterable) -> 'VarList':
for element in iterable:
self.append(element)
return self
def __setitem__(self, key, value) -> 'VarList':
"""Override the item setting.
This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden,
and only the value can be changed for each element.
>>> import brainpy.math as bm
>>> l = bm.var_list([bm.Variable(1), bm.Variable(2)])
>>> print(id(l[0]), id(l[1]))
2077748389472 2077748389552
>>> l[1] = bm.random.random(2)
>>> l[0] = bm.random.random(1)
>>> print(id(l[0]), id(l[1])) # still the original Variable instances
2077748389472 2077748389552
"""
if isinstance(key, int):
self[key].value = value
else:
super().__setitem__(key, value)
return self
def tree_flatten(self):
return tuple(self), None
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(children)
var_list = VarList
[docs]
@register_pytree_node_class
class VarDict(dict):
"""A dictionary of :py:class:`~.Variable`, which is compatible with
:py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.
Actually, :py:class:`~.VarDict` is a python dict.
:py:class:`~.VarDict` is specifically designed to store Variable instances.
"""
def _check_elem(self, elem):
if not isinstance(elem, Variable):
raise TypeError(f'Element should be {Variable.__name__}, but got {type(elem)}.')
return elem
def __init__(self, *args, **kwargs):
super().__init__()
self.update(*args, **kwargs)
[docs]
def update(self, *args, **kwargs) -> 'VarDict':
for arg in args:
if isinstance(arg, dict):
for k, v in arg.items():
self[k] = v
elif isinstance(arg, tuple):
assert len(arg) == 2
self[arg[0]] = arg[1]
for k, v in kwargs.items():
self[k] = v
return self
def __setitem__(self, key, value) -> 'VarDict':
"""Override the item setting.
This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden.
>>> import brainpy.math as bm
>>> d = bm.var_dict({'a': bm.Variable(1), 'b': bm.Variable(2)})
>>> print(id(d['a']), id(d['b']))
2077667833504 2077748488176
>>> d['b'] = bm.random.random(2)
>>> d['a'] = bm.random.random(1)
>>> print(id(d['a']), id(d['b'])) # still the original Variable instances
2077667833504 2077748488176
"""
if key in self:
self[key].value = value
else:
super().__setitem__(key, self._check_elem(value))
return self
def tree_flatten(self):
return tuple(self.values()), tuple(self.keys())
@classmethod
def tree_unflatten(cls, keys, values):
return cls(jax.util.safe_zip(keys, values))
var_dict = VarDict