# Source code for brainpy.math.ndarray

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any

import brainunit as u
import jax
import numpy as np
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class

from brainpy._errors import MathError
from .defaults import defaults

bm = None

__all__ = [
    'Array', 'ndarray', 'JaxArray',  # alias of Array
    'ShardedArray',
]


def _return(a):
    """Optionally wrap a result in :py:class:`Array`.

    ``a`` is wrapped only when ``defaults.numpy_func_return`` equals
    ``'bp_array'`` and ``a`` is a ``jax.Array`` with at least one dimension.
    Every other input is handed back unchanged.
    """
    wants_bp_array = defaults.numpy_func_return == 'bp_array'
    if wants_bp_array and isinstance(a, jax.Array) and a.ndim > 0:
        return Array(a)
    return a


def _as_jax_array_(obj):
    """Unwrap an :py:class:`Array` to the value it stores.

    Objects that are not :py:class:`Array` instances are returned as they are.
    """
    if isinstance(obj, Array):
        return obj.value
    return obj


def _check_out(out):
    """Validate that ``out`` is a brainpy :py:class:`Array`.

    Raises
    ------
    TypeError
        If ``out`` is not an instance of :py:class:`Array`.
    """
    if isinstance(out, Array):
        return
    raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}')


@register_pytree_node_class
class Array(u.CustomArray):
    """Multiple-dimensional array in BrainPy.

    Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages:

    - In-place updating is supported.

    >>> import brainpy.math as bm
    >>> a = bm.asarray([1, 2, 3.])
    >>> a[0] = 10.

    - Keep sharding constraints during computation.

    - More dense array operations with PyTorch syntax.

    Parameters
    ----------
    value
        The data to store. Another :py:class:`Array` is unwrapped; anything
        that is not already a ``jax.Array`` is converted with ``jnp.asarray``.
    dtype : Any
        Optional dtype. When given, the stored value is converted to it.
    """

    __slots__ = ('_value',)

    def __init__(self, value, dtype: Any = None):
        # Normalise the input so ``self._value`` is always array-like
        # (mirrors the ``value`` setter): unwrap ``Array``, keep ``jax.Array``
        # untouched, and convert everything else (tuple, list, ``np.ndarray``,
        # raw Python scalars, ...) with ``jnp.asarray``.
        if isinstance(value, Array):
            value = value.value
        elif not isinstance(value, jax.Array):
            value = jnp.asarray(value)
        if dtype is not None:
            value = jnp.asarray(value, dtype=dtype)
        self._value = value

    def __repr__(self) -> str:
        """Return ``ClassName(value=..., dtype=...)``, aligning multi-line values."""
        cls_name = self.__class__.__name__
        head = f'{cls_name}(value='
        body = repr(self.value)
        # The dtype is printed separately at the end, so strip the suffix that
        # the underlying repr may already carry.
        if ', dtype' in body:
            body = body.split(', dtype')[0] + ')'
        if '\n' not in body:
            return f'{head}{body}, dtype={self.dtype})'
        # Multi-line value: indent continuation lines under the first one and
        # put the dtype on its own trailing line.
        first, *rest = body.split('\n')
        pad = ' ' * len(head)
        lines = [head + first] + [pad + line for line in rest]
        lines[-1] += ','
        lines.append(' ' * (len(cls_name) + 1) + f'dtype={self.dtype})')
        return '\n'.join(lines)

    def tree_flatten(self):
        """Flatten into ``((value,), None)`` for the JAX pytree registry."""
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Rebuild an instance from the single leaf in ``flat_contents``."""
        # Reconstruct without going through ``__init__``: during abstract
        # evaluation (``jax.eval_shape``, ``scan``/``for_loop`` tracing) the leaf
        # is a ``ShapedArray``/``ShapeDtypeStruct`` rather than a concrete array,
        # and ``__init__`` would try to ``jnp.asarray`` it and raise. Storing the
        # leaf directly keeps the pytree round-trip transparent.
        ins = object.__new__(cls)
        ins._value = flat_contents[0]
        return ins

    @property
    def data(self):
        """Alias of :py:attr:`value`."""
        return self.value

    @data.setter
    def data(self, value):
        self.value = value

    @property
    def value(self):
        """The value stored in this array."""
        return self._value

    @value.setter
    def value(self, value):
        # Same normalisation as ``__init__``. No shape/dtype check is applied
        # here, so the stored value may change shape or dtype on assignment.
        if isinstance(value, Array):
            value = value.value
        elif not isinstance(value, jax.Array):
            value = jnp.asarray(value)
        self._value = value

    def update(self, value):
        """Update the value of this Array."""
        self.value = value

    def __array__(self, dtype=None, copy=None):
        """Support ``numpy.array()`` and ``numpy.asarray()`` functions.

        Parameters
        ----------
        dtype
            Optional NumPy dtype of the result.
        copy
            Optional copy flag. The NumPy 2 array protocol expects
            ``__array__`` to accept this keyword; when it is ``None`` (the
            default) the conversion behaves exactly like
            ``np.asarray(self.value, dtype=dtype)``.
        """
        if copy is None:
            return np.asarray(self.value, dtype=dtype)
        # ``np.array`` accepts ``copy`` on both NumPy 1.x and 2.x.
        return np.array(self.value, dtype=dtype, copy=copy)

    def __jax_array__(self):
        """Return the stored value so JAX functions can consume this object."""
        return self.value

    def as_variable(self):
        """As an instance of Variable."""
        # ``brainpy.math`` is imported lazily and cached in the module-level
        # ``bm`` on first use.
        global bm
        if bm is None:
            from brainpy import math as bm
        return bm.Variable(self)

    # ----------------------- #
    #       JAX methods       #
    # ----------------------- #

    @property
    def at(self):
        """The ``.at`` accessor of the stored value."""
        return self.value.at

    def block_host_until_ready(self, *args):
        """Call ``block_until_ready`` on the stored value and return its result."""
        # ``jax.Array.block_host_until_ready`` was removed; ``block_until_ready``
        # is the modern equivalent.
        return self.value.block_until_ready(*args)

    def block_until_ready(self, *args):
        """Call ``block_until_ready`` on the stored value and return its result."""
        return self.value.block_until_ready(*args)

    @property
    def device(self):
        """The ``device`` attribute of the stored value."""
        # ``jax.Array.device`` is now a property (it used to be a method).
        return self.value.device

    @property
    def device_buffer(self):
        """The data of addressable shard ``0`` of the stored value."""
        # ``jax.Array.device_buffer`` was removed; the addressable shard's data
        # is the modern equivalent on a single-device array.
        return self.value.addressable_data(0)

    def fill_(self, fill_value):
        """Fill the array with a scalar value.

        Parameters
        ----------
        fill_value
            the scalar value to fill the array.

        Returns
        -------
        Array
            This array itself, after its value has been replaced.

        Raises
        ------
        MathError
            If ``fill_value`` does not have shape ``()``.
        """
        if isinstance(fill_value, Array):
            fill_value = fill_value.value
        elif not isinstance(fill_value, jax.Array):
            fill_value = jnp.asarray(fill_value)
        # Only a scalar (0-d) fill value is accepted.
        if fill_value.shape != ():
            raise MathError(f"The shape of the fill value must be (), "
                            f"while we got {fill_value.shape}.")
        self.value = jnp.full(self.shape, fill_value, dtype=self.dtype)
        return self
# Give ``Array`` a high ``__array_priority__`` (100) as a class attribute.
Array.__array_priority__ = 100

# Aliases of ``Array``.
JaxArray = Array
ndarray = Array
@register_pytree_node_class
class ShardedArray(Array):
    """The sharded array, which stores data across multiple devices.

    A drawback of sharding is that the data may not be evenly distributed on shards.

    Parameters
    ----------
    value
        the array value.
    dtype : Any
        the array type.
    keep_sharding : bool
        keep the array sharding information using ``jax.lax.with_sharding_constraint``. Default True.
    """

    # ``_value`` is already declared as a slot by ``Array``. Re-declaring a
    # base-class slot hides the base descriptor and wastes space, so only the
    # attribute introduced here is listed.
    __slots__ = ('_keep_sharding',)

    def __init__(self, value, dtype: Any = None, *, keep_sharding: bool = True):
        super().__init__(value, dtype)
        self._keep_sharding = keep_sharding

    def tree_flatten(self):
        """Flatten into ``((raw value,), keep_sharding)`` for the JAX pytree registry."""
        # Carry ``_keep_sharding`` in ``aux_data`` so it survives a pytree
        # round-trip (``jit``/``vmap``/``scan``/``grad``). Flatten the *raw*
        # ``_value`` rather than the ``value`` property: the property inserts a
        # ``with_sharding_constraint``, which must not run during the abstract
        # flatten step (the leaf may be a tracer/``ShapeDtypeStruct``).
        return (self._value,), self._keep_sharding

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Rebuild an instance from its leaf and the ``keep_sharding`` aux data."""
        # Reconstruct without ``__init__`` (the leaf may be abstract during
        # tracing) and restore ``_keep_sharding`` from ``aux_data``; otherwise
        # the ``value`` getter raises ``AttributeError`` after any transform.
        ins = object.__new__(cls)
        ins._value = flat_contents[0]
        ins._keep_sharding = True if aux_data is None else aux_data
        return ins

    @property
    def value(self):
        """The value stored in this array.

        Returns
        -------
        The stored data.
        """
        v = self._value
        if not self._keep_sharding:
            return v
        # Keep sharding constraints, but only for genuinely multi-device
        # shardings. A ``SingleDeviceSharding`` (the default on a single device,
        # e.g. CPU) carries no distribution information, so inserting a
        # ``with_sharding_constraint`` on every read is pure overhead.
        sharding = getattr(v, 'sharding', None)
        if sharding is None or isinstance(sharding, jax.sharding.SingleDeviceSharding):
            return v
        return jax.lax.with_sharding_constraint(v, sharding)

    @value.setter
    def value(self, value):
        current = self._value
        # Normalise the incoming value to an array-like object.
        if isinstance(value, Array):
            value = value.value
        elif not isinstance(value, jax.Array):
            value = jnp.asarray(value)
        # The new value must match the shape and dtype of the stored one.
        if value.shape != current.shape:
            raise MathError(f"The shape of the original data is {current.shape}, "
                            f"while we got {value.shape}.")
        if value.dtype != current.dtype:
            raise MathError(f"The dtype of the original data is {current.dtype}, "
                            f"while we got {value.dtype}.")
        self._value = value