# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any
import brainunit as u
import jax
import numpy as np
from jax import numpy as jnp
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node_class
from brainpy._errors import MathError
from .defaults import defaults
# Placeholder for the ``brainpy.math`` module. It is imported lazily inside
# ``Array.as_variable`` (presumably to avoid a circular import at load time).
bm = None

# Public API of this module. ``ndarray`` and ``JaxArray`` are aliases of ``Array``.
__all__ = [
    'Array', 'ndarray', 'JaxArray',
    'ShardedArray',
]
# Ways to change values in a zero-dimensional array
# -----
# Reference: https://stackoverflow.com/questions/56954714/how-do-i-assign-to-a-zero-dimensional-numpy-array
#
# A zero-dimensional array cannot be indexed with an integer or a slice.
# To assign into it, index the original array with an ellipsis or an empty tuple:
#
# >>> x = np.array(10)
# >>> x[...] = 2
# >>> x[()] = 2

# A full slice, equivalent to the bare ``:`` index (as in ``x[:]``).
_all_slice = slice(None, None, None)
def _check_input_array(array):
    """Return a form of ``array`` that JAX functions can consume directly.

    - A BrainPy ``Array`` is unwrapped to its underlying value.
    - A ``numpy.ndarray`` is converted with ``jnp.asarray``.
    - Any other object is passed through untouched.
    """
    if isinstance(array, Array):
        converted = array.value
    elif isinstance(array, np.ndarray):
        converted = jnp.asarray(array)
    else:
        converted = array
    return converted
def _return(a):
    """Optionally wrap the result of a NumPy-style function into a BrainPy ``Array``.

    Wrapping happens only when all of the following hold:

    - ``defaults.numpy_func_return`` is ``'bp_array'``;
    - the result is a ``jax.Array``;
    - the result has at least one dimension.

    Otherwise the result is returned unchanged.
    """
    should_wrap = (defaults.numpy_func_return == 'bp_array'
                   and isinstance(a, jax.Array)
                   and a.ndim > 0)
    return Array(a) if should_wrap else a
def _as_jax_array_(obj):
    """Unwrap a BrainPy ``Array`` to its stored value; return anything else as is."""
    if isinstance(obj, Array):
        return obj.value
    return obj
def _check_out(out):
    """Validate an ``out=`` argument.

    Raises:
      TypeError: if ``out`` is not a BrainPy ``Array``.
    """
    if isinstance(out, Array):
        return
    raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}')
def _get_dtype(v):
    """Return the dtype of ``v``.

    Objects exposing a ``dtype`` attribute report it directly. For any other
    object (e.g. a Python scalar), the dtype is derived from its type via
    ``jax.dtypes.canonicalize_dtype``.
    """
    if not hasattr(v, 'dtype'):
        return canonicalize_dtype(type(v))
    return v.dtype
@register_pytree_node_class
class Array(u.CustomArray):
    """Multiple-dimensional array in BrainPy.

    Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages:

    - In-place updating is supported.

      >>> import brainpy.math as bm
      >>> a = bm.asarray([1, 2, 3.])
      >>> a[0] = 10.

    - Keep sharding constraints during computation.
    - More dense array operations with PyTorch syntax.

    Args:
      value: the wrapped data.

        - An :py:class:`~.Array` is unwrapped.
        - A tuple, list or ``numpy.ndarray`` is converted with ``jnp.asarray``.
        - Any other object is stored as is.
      dtype: optional dtype. When given, the data is cast with
        ``jnp.asarray(value, dtype=dtype)``.
    """
    __slots__ = ('_value',)

    def __init__(self, value, dtype: Any = None):
        # array value
        if isinstance(value, Array):
            value = value.value
        elif isinstance(value, (tuple, list, np.ndarray)):
            value = jnp.asarray(value)
        if dtype is not None:
            value = jnp.asarray(value, dtype=dtype)
        self._value = value

    @staticmethod
    def _to_jax_value(value):
        """Convert ``value`` into data suitable for storage in an :py:class:`~.Array`.

        - An :py:class:`~.Array` is unwrapped to its stored value.
        - A ``jax.Array`` is returned unchanged.
        - Anything else (NumPy arrays, scalars, ...) goes through ``jnp.asarray``.
        """
        if isinstance(value, Array):
            return value.value
        if isinstance(value, jax.Array):
            return value
        return jnp.asarray(value)

    def __repr__(self) -> str:
        print_code = repr(self.value)
        # Strip the dtype suffix from the wrapped value's repr; the dtype is
        # re-appended below in this class's own format.
        if ', dtype' in print_code:
            print_code = print_code.split(', dtype')[0] + ')'
        prefix = f'{self.__class__.__name__}'
        prefix2 = f'{self.__class__.__name__}(value='
        if '\n' in print_code:
            # Multi-line values: align continuation lines under "value=" and
            # put the dtype on its own final line.
            lines = print_code.split("\n")
            blank1 = " " * len(prefix2)
            lines[0] = prefix2 + lines[0]
            lines[1:] = [blank1 + line for line in lines[1:]]
            lines[-1] += ","
            blank2 = " " * (len(prefix) + 1)
            lines.append(f'{blank2}dtype={self.dtype})')
            print_code = "\n".join(lines)
        else:
            print_code = prefix2 + print_code + f', dtype={self.dtype})'
        return print_code

    def tree_flatten(self):
        """Flatten into a single leaf (the stored value) with no auxiliary data."""
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Rebuild an instance from the leaf produced by :meth:`tree_flatten`."""
        return cls(*flat_contents)

    @property
    def data(self):
        """Alias of :attr:`value`."""
        return self.value

    @data.setter
    def data(self, value):
        self.value = value

    @property
    def value(self):
        """The data stored in this array."""
        return self._value

    @value.setter
    def value(self, value):
        # The new data's shape and dtype are deliberately not checked against
        # the current data, so they may change on assignment.
        self._value = self._to_jax_value(value)

    def update(self, value):
        """Update the value of this Array."""
        self.value = value

    def __array__(self, dtype=None, copy=None):
        """Support ``numpy.array()`` and ``numpy.asarray()`` functions.

        Args:
          dtype: optional target dtype.
          copy: accepted for NumPy 2 compatibility.

            - ``True``: a fresh NumPy array is always returned.
            - ``None`` or ``False``: a copy is made only if the conversion needs one.
        """
        if copy:
            return np.array(self.value, dtype=dtype, copy=True)
        return np.asarray(self.value, dtype=dtype)

    def __jax_array__(self):
        return self.value

    def as_variable(self):
        """As an instance of Variable."""
        global bm
        if bm is None:
            # Imported lazily and cached in the module-level ``bm`` name.
            from brainpy import math as bm
        return bm.Variable(self)

    # ----------------------- #
    #       JAX methods       #
    # ----------------------- #

    @property
    def at(self):
        return self.value.at

    def block_host_until_ready(self, *args):
        return self.value.block_host_until_ready(*args)

    def block_until_ready(self, *args):
        return self.value.block_until_ready(*args)

    def device(self):
        """Return the device that holds the stored data.

        JAX releases differ in how they expose this:

        - Older releases: ``jax.Array.device`` is a method, so it is called.
        - Newer releases: it is a property, so its value is returned directly.
        """
        dev = self.value.device
        return dev() if callable(dev) else dev

    @property
    def device_buffer(self):
        return self.value.device_buffer

    def fill_(self, fill_value):
        """Fill the array in place with a scalar value.

        Args:
          fill_value: the scalar value to fill the array with.

        Returns:
          This array, so calls can be chained.

        Raises:
          MathError: if ``fill_value`` is not zero-dimensional.
        """
        fill_value = self._to_jax_value(fill_value)
        if fill_value.shape != ():
            raise MathError(f"The shape of the fill value must be (), "
                            f"while we got {fill_value.shape}.")
        self.value = jnp.full(self.shape, fill_value, dtype=self.dtype)
        return self
# A high ``__array_priority__`` makes NumPy defer to Array's operators in
# mixed NumPy/Array binary operations.
Array.__array_priority__ = 100

# Aliases of ``Array``.
ndarray = Array
JaxArray = Array
@register_pytree_node_class
class ShardedArray(Array):
    """The sharded array, which stores data across multiple devices.

    A drawback of sharding is that the data may not be evenly distributed on shards.

    Args:
      value: the array value.
      dtype: the array type.
      keep_sharding: keep the array sharding information using
        ``jax.lax.with_sharding_constraint``. Default True.
    """
    # ``_value`` is already a slot of ``Array``. Re-declaring it here would
    # shadow the base-class slot, so only the new attribute is listed.
    __slots__ = ('_keep_sharding',)

    def __init__(self, value, dtype: Any = None, *, keep_sharding: bool = True):
        super().__init__(value, dtype)
        self._keep_sharding = keep_sharding

    def tree_flatten(self):
        """Flatten into the stored value, with ``keep_sharding`` as auxiliary data.

        Carrying the flag in the auxiliary data means unflattening restores it.
        Otherwise an array created with ``keep_sharding=False`` would be
        rebuilt with the constructor default ``True``.
        """
        return (self.value,), self._keep_sharding

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Rebuild an instance from :meth:`tree_flatten` output.

        A missing (``None``) flag falls back to the constructor default ``True``.
        """
        keep_sharding = True if aux_data is None else aux_data
        return cls(*flat_contents, keep_sharding=keep_sharding)

    @property
    def value(self):
        """The value stored in this array.

        When ``keep_sharding`` is enabled and the stored data has a non-``None``
        ``sharding`` attribute, the data is returned through
        ``jax.lax.with_sharding_constraint`` with that sharding.

        Returns:
          The stored data.
        """
        v = self._value
        # keep sharding constraints
        if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None):
            return jax.lax.with_sharding_constraint(v, v.sharding)
        return v

    @value.setter
    def value(self, value):
        # Replace the stored data. Unlike ``Array``, the new data must keep
        # the current shape and dtype, otherwise ``MathError`` is raised.
        self_value = self._value
        if isinstance(value, Array):
            value = value.value
        elif isinstance(value, np.ndarray):
            value = jnp.asarray(value)
        elif isinstance(value, jax.Array):
            pass
        else:
            value = jnp.asarray(value)
        if value.shape != self_value.shape:
            raise MathError(f"The shape of the original data is {self_value.shape}, "
                            f"while we got {value.shape}.")
        if value.dtype != self_value.dtype:
            raise MathError(f"The dtype of the original data is {self_value.dtype}, "
                            f"while we got {value.dtype}.")
        self._value = value