# -*- coding: utf-8 -*-
import warnings
import numpy as np
from jax import numpy as jnp
from jax.tree_util import register_pytree_node
from brainpy.errors import MathError
# Names exported by a star-import of this module.
__all__ = [
'JaxArray',
'ndarray', # alias of JaxArray
'Variable',
'TrainVar',
'Parameter',
]
# Ways to change values in a zero-dimensional array
# -----
# Reference: https://stackoverflow.com/questions/56954714/how-do-i-assign-to-a-zero-dimensional-numpy-array
#
# >>> x = np.array(10)
# 1. index the original array with ellipsis or an empty tuple
# >>> x[...] = 2
# >>> x[()] = 2
# Slice equivalent to ``[:]``; ``JaxArray.__getitem__`` compares a slice index against it.
_all_slice = slice(None, None, None)
# Error text raised (as ``MathError``) when an array created while the global JIT
# mode was off is mutated while the global JIT mode is on.
msg = ('JaxArray created outside of the jit function '
'cannot be updated in JIT mode. You should '
'mark it as brainpy.math.Variable instead.')
# Module-wide flag; toggled by ``turn_on_global_jit`` / ``turn_off_global_jit``.
_global_jit_mode = False
def turn_on_global_jit():
    """Enable the global JIT mode.

    While the mode is on, JaxArray instances that were created before it was
    enabled refuse in-place updates.
    """
    global _global_jit_mode
    _global_jit_mode = True
def turn_off_global_jit():
    """Disable the global JIT mode by resetting the module-level flag."""
    global _global_jit_mode
    _global_jit_mode = False
class JaxArray(object):
    """Multiple-dimensional array in JAX backend.

    The wrapped array is stored in ``_value``. Operators and most NumPy-style
    methods unwrap ``JaxArray`` operands, delegate to the wrapped array and
    wrap array results in a new ``JaxArray``.

    In-place operations rebind ``_value``. They raise ``MathError`` when the
    instance was created while the global JIT mode was off and is modified
    while the global JIT mode is on.
    """

    __slots__ = ("_value", "_outside_global_jit")

    def __init__(self, value, dtype=None):
        """Wrap ``value``.

        Parameters
        ----------
        value : JaxArray, tuple, list, numpy.ndarray or array
            Data to wrap. A ``JaxArray`` is unwrapped; tuples, lists and NumPy
            arrays are converted with ``jnp.asarray``; anything else is stored
            as given unless ``dtype`` is provided.
        dtype : dtype, optional
            When given, the data is converted with ``jnp.asarray(value, dtype=dtype)``.
        """
        # array value
        if isinstance(value, JaxArray):
            value = value._value
        elif isinstance(value, (tuple, list, np.ndarray)):
            value = jnp.asarray(value)
        if dtype is not None:
            value = jnp.asarray(value, dtype=dtype)
        self._value = value
        # jit mode: remember whether this instance was created while the
        # global JIT mode was off.
        self._outside_global_jit = not _global_jit_mode

    @property
    def value(self):
        """The wrapped array."""
        return self._value

    @value.setter
    def value(self, value):
        self.update(value)

    def update(self, value):
        """Update the value of this JaxArray.

        Parameters
        ----------
        value : JaxArray or array_like
            New data. It must have the same shape and dtype as the current data.

        Raises
        ------
        MathError
            If the instance cannot be updated in the current JIT mode, or if
            the shape or dtype of ``value`` differs from the current data.
        """
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        if isinstance(value, JaxArray):
            value = value.value
        elif not isinstance(value, jnp.ndarray):
            # NumPy arrays, scalars, lists, ... are converted to JAX arrays.
            value = jnp.asarray(value)
        # check
        if value.shape != self._value.shape:
            raise MathError(f"The shape of the original data is {self._value.shape}, "
                            f"while we got {value.shape}.")
        if value.dtype != self._value.dtype:
            raise MathError(f"The dtype of the original data is {self._value.dtype}, "
                            f"while we got {value.dtype}.")
        self._value = value

    @property
    def dtype(self):
        return self._value.dtype

    @property
    def shape(self):
        return self._value.shape

    @property
    def ndim(self):
        return self._value.ndim

    @property
    def imag(self):
        """Imaginary part of the array, wrapped like :attr:`real`."""
        return JaxArray(self._value.imag)

    @property
    def real(self):
        return JaxArray(self._value.real)

    @property
    def size(self):
        return self.value.size

    @property
    def T(self):
        return JaxArray(self.value.T)

    # ----------------------- #
    # Python inherent methods #
    # ----------------------- #

    def __repr__(self) -> str:
        """Render the wrapped array's repr under this class's name.

        If the repr contains ``'DeviceArray'`` that name is replaced by the
        class name and continuation lines are shifted by the difference in
        name length. Otherwise the whole repr is wrapped in ``ClassName(...)``.
        """
        text = repr(self.value)
        name = self.__class__.__name__
        if 'DeviceArray' in text:
            lines = text.replace('DeviceArray', name).split("\n")
            shift = len(name) - len('DeviceArray')
            if shift > 0:
                lines[1:] = [" " * shift + line for line in lines[1:]]
            else:
                # Name is not longer: drop the surplus leading characters.
                lines[1:] = [line[-shift:] for line in lines[1:]]
        else:
            lines = text.split("\n")
            prefix = name + "("
            pad = " " * len(prefix)
            lines = [prefix + lines[0]] + [pad + line for line in lines[1:]]
            lines[-1] = lines[-1] + ")"
        return "\n".join(lines)

    def __format__(self, format_spec: str) -> str:
        """Format the wrapped array, honouring ``format_spec``.

        An empty ``format_spec`` gives the same text as ``format(self.value)``.
        """
        return format(self.value, format_spec)

    def __iter__(self):
        """Solve the issue of DeviceArray.__iter__.

        Details please see JAX issues:

        - https://github.com/google/jax/issues/7713
        - https://github.com/google/jax/pull/3821
        """
        yield from self._value

    def __getitem__(self, index):
        """Index the wrapped array; ``JaxArray`` indices are unwrapped first."""
        if isinstance(index, slice) and (index == _all_slice):
            return self.value
        elif isinstance(index, tuple):
            index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)
        elif isinstance(index, JaxArray):
            index = index.value
        return self.value[index]

    def __setitem__(self, index, value):
        """Assign ``value`` at ``index`` by rebinding the wrapped array.

        Raises
        ------
        MathError
            If the instance cannot be updated in the current JIT mode.
        """
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        # value is JaxArray
        if isinstance(value, JaxArray):
            value = value.value
        # tuple index
        if isinstance(index, tuple):
            index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)
        # JaxArray index
        elif isinstance(index, JaxArray):
            index = index.value
        # update
        self._value = self._value.at[index].set(value)

    # ---------- #
    # operations #
    # ---------- #

    def __bool__(self) -> bool:
        return self._value.__bool__()

    def __len__(self) -> int:
        return len(self._value)

    def __neg__(self):
        return JaxArray(self._value.__neg__())

    def __pos__(self):
        return JaxArray(self._value.__pos__())

    def __abs__(self):
        return JaxArray(self._value.__abs__())

    def __invert__(self):
        return JaxArray(self._value.__invert__())

    def __eq__(self, oc):
        return JaxArray(self._value == (oc._value if isinstance(oc, JaxArray) else oc))

    def __ne__(self, oc):
        return JaxArray(self._value != (oc._value if isinstance(oc, JaxArray) else oc))

    def __lt__(self, oc):
        return JaxArray(self._value < (oc._value if isinstance(oc, JaxArray) else oc))

    def __le__(self, oc):
        return JaxArray(self._value <= (oc._value if isinstance(oc, JaxArray) else oc))

    def __gt__(self, oc):
        return JaxArray(self._value > (oc._value if isinstance(oc, JaxArray) else oc))

    def __ge__(self, oc):
        return JaxArray(self._value >= (oc._value if isinstance(oc, JaxArray) else oc))

    def __add__(self, oc):
        return JaxArray(self._value + (oc._value if isinstance(oc, JaxArray) else oc))

    def __radd__(self, oc):
        return JaxArray(self._value + (oc._value if isinstance(oc, JaxArray) else oc))

    def __iadd__(self, oc):
        # a += b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value + (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __sub__(self, oc):
        return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))

    def __rsub__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) - self._value)

    def __isub__(self, oc):
        # a -= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value - (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __mul__(self, oc):
        return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))

    def __rmul__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) * self._value)

    def __imul__(self, oc):
        # a *= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __rdiv__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

    def __truediv__(self, oc):
        return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))

    def __rtruediv__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

    def __itruediv__(self, oc):
        # a /= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value / (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __floordiv__(self, oc):
        return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))

    def __rfloordiv__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) // self._value)

    def __ifloordiv__(self, oc):
        # a //= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value // (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __divmod__(self, oc):
        return JaxArray(self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc))

    def __rdivmod__(self, oc):
        return JaxArray(self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc))

    def __mod__(self, oc):
        return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))

    def __rmod__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) % self._value)

    def __imod__(self, oc):
        # a %= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value % (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __pow__(self, oc):
        return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))

    def __rpow__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ** self._value)

    def __ipow__(self, oc):
        # a **= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value ** (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __matmul__(self, oc):
        return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))

    def __rmatmul__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) @ self._value)

    def __imatmul__(self, oc):
        # a @= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value @ (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __and__(self, oc):
        return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))

    def __rand__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) & self._value)

    def __iand__(self, oc):
        # a &= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value & (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __or__(self, oc):
        return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))

    def __ror__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) | self._value)

    def __ior__(self, oc):
        # a |= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value | (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __xor__(self, oc):
        return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))

    def __rxor__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ^ self._value)

    def __ixor__(self, oc):
        # a ^= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value ^ (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __lshift__(self, oc):
        return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))

    def __rlshift__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) << self._value)

    def __ilshift__(self, oc):
        # a <<= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value << (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __rshift__(self, oc):
        return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))

    def __rrshift__(self, oc):
        return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) >> self._value)

    def __irshift__(self, oc):
        # a >>= b
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self._value >> (oc._value if isinstance(oc, JaxArray) else oc)
        return self

    def __round__(self, ndigits=None):
        return JaxArray(self._value.__round__(ndigits))

    # ----------------------- #
    #       JAX methods       #
    # ----------------------- #

    @property
    def at(self):
        """The ``at`` indexing helper of the wrapped array."""
        return self.value.at

    def block_host_until_ready(self, *args):
        self._value.block_host_until_ready(*args)

    def block_until_ready(self, *args):
        self._value.block_until_ready(*args)

    def device(self):
        """Return the result of calling ``device()`` on the wrapped array."""
        return self.value.device()

    @property
    def device_buffer(self):
        """The ``device_buffer`` attribute of the wrapped array."""
        return self.value.device_buffer

    # ----------------------- #
    #      NumPy methods      #
    # ----------------------- #

    def all(self, axis=None, keepdims=False):
        """Returns True if all elements evaluate to True."""
        r = self.value.all(axis=axis, keepdims=keepdims)
        return r if (axis is None or keepdims) else JaxArray(r)

    def any(self, axis=None, keepdims=False):
        """Returns True if any of the elements of a evaluate to True."""
        r = self.value.any(axis=axis, keepdims=keepdims)
        return r if (axis is None or keepdims) else JaxArray(r)

    def argmax(self, axis=None):
        """Return indices of the maximum values along the given axis."""
        return JaxArray(self.value.argmax(axis=axis))

    def argmin(self, axis=None):
        """Return indices of the minimum values along the given axis."""
        return JaxArray(self.value.argmin(axis=axis))

    def argpartition(self, kth, axis=-1, kind='introselect', order=None):
        """Returns the indices that would partition this array."""
        return JaxArray(self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order))

    def argsort(self, axis=-1, kind=None, order=None):
        """Returns the indices that would sort this array."""
        return JaxArray(self.value.argsort(axis=axis, kind=kind, order=order))

    def astype(self, dtype):
        """Copy of the array, cast to a specified type.

        Parameters
        ----------
        dtype: str, dtype
            Typecode or data-type to which the array is cast.
        """
        return JaxArray(self.value.astype(dtype=dtype))

    def byteswap(self, inplace=False):
        """Swap the bytes of the array elements

        Toggle between low-endian and big-endian data representation by
        returning a byteswapped array, optionally swapped in-place.
        Arrays of byte-strings are not swapped. The real and imaginary
        parts of a complex number are swapped individually."""
        return JaxArray(self.value.byteswap(inplace=inplace))

    def choose(self, choices, mode='raise'):
        """Use an index array to construct a new array from a set of choices."""
        choices = choices.value if isinstance(choices, JaxArray) else choices
        return JaxArray(self.value.choose(choices=choices, mode=mode))

    def clip(self, min=None, max=None):
        """Return an array whose values are limited to [min, max]. One of max or min must be given."""
        return JaxArray(self.value.clip(min=min, max=max))

    def compress(self, condition, axis=None):
        """Return selected slices of this array along given axis."""
        condition = condition.value if isinstance(condition, JaxArray) else condition
        return JaxArray(self.value.compress(condition=condition, axis=axis))

    def conj(self):
        """Complex-conjugate all elements."""
        return JaxArray(self.value.conj())

    def conjugate(self):
        """Return the complex conjugate, element-wise."""
        return JaxArray(self.value.conjugate())

    def copy(self):
        """Return a copy of the array."""
        return JaxArray(self.value.copy())

    def cumprod(self, axis=None, dtype=None):
        """Return the cumulative product of the elements along the given axis."""
        return JaxArray(self.value.cumprod(axis=axis, dtype=dtype))

    def cumsum(self, axis=None, dtype=None):
        """Return the cumulative sum of the elements along the given axis."""
        return JaxArray(self.value.cumsum(axis=axis, dtype=dtype))

    def diagonal(self, offset=0, axis1=0, axis2=1):
        """Return specified diagonals."""
        return JaxArray(self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2))

    def dot(self, b):
        """Dot product of two arrays."""
        return JaxArray(self.value.dot(b.value if isinstance(b, JaxArray) else b))

    def fill(self, value):
        """Fill the array with a scalar value."""
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = jnp.ones_like(self.value) * value

    def flatten(self, order='C'):
        """Return the array flattened, wrapped in a new ``JaxArray``."""
        return JaxArray(self.value.flatten(order=order))

    def item(self, *args):
        """Copy an element of an array to a standard Python scalar and return it."""
        return self.value.item(*args)

    def max(self, axis=None, keepdims=False, *args, **kwargs):
        """Return the maximum along a given axis."""
        res = self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs)
        return res if (axis is None or keepdims) else JaxArray(res)

    def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs):
        """Returns the average of the array elements along given axis."""
        res = self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs)
        return res if (axis is None or keepdims) else JaxArray(res)

    def min(self, axis=None, keepdims=False, *args, **kwargs):
        """Return the minimum along a given axis."""
        res = self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs)
        return res if (axis is None or keepdims) else JaxArray(res)

    def nonzero(self):
        """Return the indices of the elements that are non-zero."""
        return tuple(JaxArray(a) for a in self.value.nonzero())

    def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True):
        """Return the product of the array elements over the given axis."""
        res = self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
        return res if (axis is None or keepdims) else JaxArray(res)

    def ptp(self, axis=None, keepdims=False):
        """Peak to peak (maximum - minimum) value along a given axis."""
        r = self.value.ptp(axis=axis, keepdims=keepdims)
        return r if (axis is None or keepdims) else JaxArray(r)

    def put(self, indices, values):
        """Replaces specified elements of an array with given values.

        Parameters
        ----------
        indices: array_like
            Target indices, interpreted as integers.
        values: array_like
            Values to place in the array at target indices.
        """
        self.__setitem__(indices, values)

    def ravel(self, order=None):
        """Return a flattened array."""
        return JaxArray(self.value.ravel(order=order))

    def repeat(self, repeats, axis=None):
        """Repeat elements of an array."""
        return JaxArray(self.value.repeat(repeats=repeats, axis=axis))

    def reshape(self, *shape, order='C'):
        """Returns an array containing the same data with a new shape."""
        return JaxArray(self.value.reshape(*shape, order=order))

    def resize(self, new_shape):
        """Change shape and size of array in-place."""
        self._value = self.value.reshape(new_shape)

    def round(self, decimals=0):
        """Return ``a`` with each element rounded to the given number of decimals."""
        return JaxArray(self.value.round(decimals=decimals))

    def searchsorted(self, v, side='left', sorter=None):
        """Find indices where elements should be inserted to maintain order.

        Find the indices into a sorted array `a` such that, if the
        corresponding elements in `v` were inserted before the indices, the
        order of `a` would be preserved.

        Assuming that `a` is sorted:

        ======  ============================
        `side`  returned index `i` satisfies
        ======  ============================
        left    ``a[i-1] < v <= a[i]``
        right   ``a[i-1] <= v < a[i]``
        ======  ============================

        Parameters
        ----------
        v : array_like
            Values to insert into `a`.
        side : {'left', 'right'}, optional
            If 'left', the index of the first suitable location found is given.
            If 'right', return the last such index. If there is no suitable
            index, return either 0 or N (where N is the length of `a`).
        sorter : 1-D array_like, optional
            Optional array of integer indices that sort array a into ascending
            order. They are typically the result of argsort.

        Returns
        -------
        indices : array of ints
            Array of insertion points with the same shape as `v`.
        """
        v = v.value if isinstance(v, JaxArray) else v
        return JaxArray(self.value.searchsorted(v=v, side=side, sorter=sorter))

    def sort(self, axis=-1, kind='quicksort', order=None):
        """Sort an array in-place.

        Parameters
        ----------
        axis : int, optional
            Axis along which to sort. Default is -1, which means sort along the
            last axis.
        kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
            Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
            and 'mergesort' use timsort under the covers and, in general, the
            actual implementation will vary with datatype. The 'mergesort' option
            is retained for backwards compatibility.
        order : str or list of str, optional
            When `a` is an array with fields defined, this argument specifies
            which fields to compare first, second, etc. A single field can
            be specified as a string, and not all fields need be specified,
            but unspecified fields will still be used, in the order in which
            they come up in the dtype, to break ties.
        """
        if self._outside_global_jit and _global_jit_mode:
            raise MathError(msg)
        self._value = self.value.sort(axis=axis, kind=kind, order=order)

    def squeeze(self, axis=None):
        """Remove axes of length one from ``a``."""
        return JaxArray(self.value.squeeze(axis=axis))

    def std(self, axis=None, dtype=None, ddof=0, keepdims=False):
        """Compute the standard deviation along the specified axis.

        Returns the standard deviation, a measure of the spread of a distribution,
        of the array elements. The standard deviation is computed for the
        flattened array by default, otherwise over the specified axis.

        Parameters
        ----------
        axis : None or int or tuple of ints, optional
            Axis or axes along which the standard deviation is computed. The
            default is to compute the standard deviation of the flattened array.
            If this is a tuple of ints, a standard deviation is performed over
            multiple axes, instead of a single axis or all the axes as before.
        dtype : dtype, optional
            Type to use in computing the standard deviation. For arrays of
            integer type the default is float64, for arrays of float types it is
            the same as the array type.
        ddof : int, optional
            Means Delta Degrees of Freedom. The divisor used in calculations
            is ``N - ddof``, where ``N`` represents the number of elements.
            By default `ddof` is zero.
        keepdims : bool, optional
            If this is set to True, the axes which are reduced are left
            in the result as dimensions with size one. With this option,
            the result will broadcast correctly against the input array.
            If the default value is passed, then `keepdims` will not be
            passed through to the `std` method of sub-classes of
            `ndarray`, however any non-default value will be. If the
            sub-class' method does not implement `keepdims` any
            exceptions will be raised.

        Returns
        -------
        standard_deviation : ndarray, see dtype parameter above.
            If `out` is None, return a new array containing the standard deviation,
            otherwise return a reference to the output array.
        """
        r = self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
        return r if (axis is None or keepdims) else JaxArray(r)

    def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True):
        """Return the sum of the array elements over the given axis."""
        res = self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
        return res if (axis is None or keepdims) else JaxArray(res)

    def swapaxes(self, axis1, axis2):
        """Return a view of the array with `axis1` and `axis2` interchanged."""
        return JaxArray(self.value.swapaxes(axis1, axis2))

    def split(self, indices_or_sections, axis=0):
        """Split an array into multiple sub-arrays as views into ``ary``.

        Parameters
        ----------
        indices_or_sections : int, 1-D array
            If `indices_or_sections` is an integer, N, the array will be divided
            into N equal arrays along `axis`. If such a split is not possible,
            an error is raised.

            If `indices_or_sections` is a 1-D array of sorted integers, the entries
            indicate where along `axis` the array is split. For example,
            ``[2, 3]`` would, for ``axis=0``, result in

            - ary[:2]
            - ary[2:3]
            - ary[3:]

            If an index exceeds the dimension of the array along `axis`,
            an empty sub-array is returned correspondingly.
        axis : int, optional
            The axis along which to split, default is 0.

        Returns
        -------
        sub-arrays : list of ndarrays
            A list of sub-arrays as views into `ary`.
        """
        return [JaxArray(a) for a in self.value.split(indices_or_sections, axis=axis)]

    def take(self, indices, axis=None, mode=None):
        """Return an array formed from the elements of a at the given indices."""
        indices = indices.value if isinstance(indices, JaxArray) else indices
        return JaxArray(self.value.take(indices=indices, axis=axis, mode=mode))

    def tobytes(self, order='C'):
        """Construct Python bytes containing the raw data bytes in the array.

        Constructs Python bytes showing a copy of the raw contents of data memory.
        The bytes object is produced in C-order by default. This behavior is
        controlled by the ``order`` parameter.

        Returns
        -------
        bytes
            The result of ``tobytes`` on the wrapped array, returned as is.
        """
        return self.value.tobytes(order=order)

    def tolist(self):
        """Return the array as an ``a.ndim``-levels deep nested list of Python scalars.

        Return a copy of the array data as a (nested) Python list.
        Data items are converted to the nearest compatible builtin Python type, via
        the `~numpy.ndarray.item` function.

        If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will
        not be a list at all, but a simple Python scalar.
        """
        return self.value.tolist()

    def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
        """Return the sum along diagonals of the array."""
        return JaxArray(self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype))

    def transpose(self, *axes):
        """Returns a view of the array with axes transposed.

        For a 1-D array this has no effect, as a transposed vector is simply the
        same vector. To convert a 1-D array into a 2D column vector, an additional
        dimension must be added. `np.atleast2d(a).T` achieves this, as does
        `a[:, np.newaxis]`.
        For a 2-D array, this is a standard matrix transpose.
        For an n-D array, if axes are given, their order indicates how the
        axes are permuted (see Examples). If axes are not provided and
        ``a.shape = (i[0], i[1], ... i[n-2], i[n-1])``, then
        ``a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0])``.

        Parameters
        ----------
        axes : None, tuple of ints, or `n` ints

            * None or no argument: reverses the order of the axes.
            * tuple of ints: `i` in the `j`-th place in the tuple means `a`'s
              `i`-th axis becomes `a.transpose()`'s `j`-th axis.
            * `n` ints: same as an n-tuple of the same ints (this form is
              intended simply as a "convenience" alternative to the tuple form)

        Returns
        -------
        out : ndarray
            View of `a`, with axes suitably permuted.
        """
        return JaxArray(self.value.transpose(*axes))

    def tile(self, reps):
        """Construct an array by repeating A the number of times given by reps.

        If `reps` has length ``d``, the result will have dimension of
        ``max(d, A.ndim)``.

        If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
        axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
        or shape (1, 1, 3) for 3-D replication. If this is not the desired
        behavior, promote `A` to d-dimensions manually before calling this
        function.

        If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
        Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
        (1, 1, 2, 2).

        Note : Although tile may be used for broadcasting, it is strongly
        recommended to use numpy's broadcasting operations and functions.

        Parameters
        ----------
        reps : array_like
            The number of repetitions of `A` along each axis.

        Returns
        -------
        c : ndarray
            The tiled output array.
        """
        reps = reps.value if isinstance(reps, JaxArray) else reps
        return JaxArray(self.value.tile(reps))

    def var(self, axis=None, dtype=None, ddof=0, keepdims=False):
        """Returns the variance of the array elements, along given axis."""
        r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
        return r if (axis is None or keepdims) else JaxArray(r)

    def view(self, dtype=None, *args, **kwargs):
        """New view of array with the same data."""
        return JaxArray(self.value.view(dtype=dtype, *args, **kwargs))

    # ------------------
    # NumPy support
    # ------------------

    def numpy(self, dtype=None):
        """Convert to numpy.ndarray."""
        # stacklevel=2 attributes the warning to the caller of ``numpy()``.
        warnings.warn('Deprecated since 2.1.12. Please use ".to_numpy()" instead.',
                      DeprecationWarning, stacklevel=2)
        return np.asarray(self.value, dtype=dtype)

    def to_numpy(self, dtype=None):
        """Convert to numpy.ndarray."""
        return np.asarray(self.value, dtype=dtype)

    def to_jax(self, dtype=None):
        """Convert to jax.numpy.ndarray."""
        if dtype is None:
            return self.value
        else:
            return jnp.asarray(self.value, dtype=dtype)

    def __array__(self, dtype=None):
        """Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
        return np.asarray(self.value, dtype=dtype)

    def __jax_array__(self):
        """Return the wrapped array."""
        return self.value
# NumPy-style alias: ``ndarray`` refers to the same class object as ``JaxArray``.
ndarray = JaxArray
[docs]class Variable(JaxArray):
"""The pointer to specify the dynamical variable.
"""
__slots__ = ('_value', '_batch_axis')
def __init__(self, value, dtype=None, batch_axis: int = None):
    """Create a variable, optionally marking one axis as the batch axis.

    Parameters
    ----------
    value : array_like, JaxArray or Variable
        Initial data, handled by the parent initializer.
    dtype : dtype, optional
        Forwarded to the parent initializer.
    batch_axis : int, optional
        Index of the batch axis. When omitted and ``value`` is a ``Variable``,
        the batch axis of ``value`` is inherited.

    Raises
    ------
    ValueError
        If ``value`` is a ``Variable`` whose batch axis conflicts with the
        given ``batch_axis``.
    MathError
        If ``batch_axis`` is not smaller than the number of dimensions.
    """
    super(Variable, self).__init__(value, dtype=dtype)

    # check batch axis
    if isinstance(value, Variable):
        if batch_axis is None:
            # Nothing specified by the caller: inherit from the source variable.
            batch_axis = value.batch_axis
        elif value.batch_axis is not None and batch_axis != value.batch_axis:
            raise ValueError(f'"batch_axis" is not consistent. Got batch_axis in the given value '
                             f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}')

    # assign batch axis
    self._batch_axis = batch_axis
    if batch_axis is not None:
        if batch_axis >= self.ndim:
            raise MathError(f'This variables has {self.ndim} dimension, '
                            f'but the batch axis is set to be {batch_axis}.')
@property
def batch_axis(self):
return self._batch_axis
@batch_axis.setter
def batch_axis(self, val):
raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.')
@property
def batch_size(self):
return self.ndim[self._batch_axis]
@batch_size.setter
def batch_size(self, val):
raise ValueError(f'Cannot set "batch_size" manually.')
def update(self, value):
    """Update the value of this JaxArray.

    The new data must match the current dtype and the current shape. When a
    batch axis is set, that axis is left out of the shape comparison.

    Parameters
    ----------
    value : JaxArray or array_like
        New data. A ``JaxArray`` is unwrapped; anything that is not already a
        JAX array is converted with ``jnp.asarray``.

    Raises
    ------
    MathError
        If the shape (ignoring the batch axis) or the dtype does not match.
    """
    # Normalise the input first so that ``shape``/``dtype`` are always available.
    if isinstance(value, JaxArray):
        value = value.value
    elif not isinstance(value, jnp.ndarray):
        value = jnp.asarray(value)

    axis = self._batch_axis
    if axis is None:
        ext_shape = value.shape
        int_shape = self._value.shape
    else:
        # Drop the batch axis from both shapes before comparing.
        ext_shape = value.shape[:axis] + value.shape[axis + 1:]
        int_shape = self._value.shape[:axis] + self._value.shape[axis + 1:]
    if ext_shape != int_shape:
        error = f"The shape of the original data is {self._value.shape}, while we got {value.shape}"
        if axis is None:
            error += '. Do you forget to set "batch_axis" when initialize this variable?'
        else:
            error += f' with batch_axis={self._batch_axis}.'
        raise MathError(error)
    if value.dtype != self._value.dtype:
        raise MathError(f"The dtype of the original data is {self._value.dtype}, "
                        f"while we got {value.dtype}.")
    self._value = value
def __setitem__(self, index, value):
    """Assign ``value`` at ``index`` by rebinding the wrapped array."""
    # Unwrap a JaxArray value.
    new_data = value.value if isinstance(value, JaxArray) else value
    if isinstance(index, tuple):
        # Unwrap every JaxArray inside a tuple index.
        key = tuple(part.value if isinstance(part, JaxArray) else part for part in index)
    elif isinstance(index, JaxArray):
        key = index.value
    else:
        key = index
    self._value = self._value.at[key].set(new_data)
def __iadd__(self, oc):
    """``a += b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value += other
    return self

def __isub__(self, oc):
    """``a -= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value - other
    return self

def __imul__(self, oc):
    """``a *= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value * other
    return self

def __itruediv__(self, oc):
    """``a /= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value / other
    return self

def __ifloordiv__(self, oc):
    """``a //= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value // other
    return self

def __imod__(self, oc):
    """``a %= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value % other
    return self

def __ipow__(self, oc):
    """``a **= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value ** other
    return self

def __imatmul__(self, oc):
    """``a @= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value @ other
    return self
def __iand__(self, oc):
    """``a &= b``; a JaxArray operand is unwrapped first.

    The ``&`` operator is used rather than a direct ``__and__`` call so that
    an unsupported operand goes through Python's normal operator dispatch
    instead of a ``NotImplemented`` result being stored as the value.
    """
    self._value = self._value & (oc._value if isinstance(oc, JaxArray) else oc)
    return self
def __ior__(self, oc):
    """``a |= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value | other
    return self

def __ixor__(self, oc):
    """``a ^= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value ^ other
    return self

def __ilshift__(self, oc):
    """``a <<= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value << other
    return self

def __irshift__(self, oc):
    """``a >>= b``; a JaxArray operand is unwrapped first."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    self._value = self._value >> other
    return self
def fill(self, value):
"""Fill the array with a scalar value."""
self._value = jnp.ones_like(self.value) * value
def sort(self, axis=-1, kind=None, order=None):
"""Sort an array in-place."""
self._value = self.value.sort(axis=axis, kind=kind, order=order)
# ---------- #
# operations #
# ---------- #
def __bool__(self) -> bool:
return self._value.__bool__()
def __len__(self) -> int:
return len(self._value)
def __neg__(self):
return self._value.__neg__()
def __pos__(self):
return self._value.__pos__()
def __abs__(self):
return self._value.__abs__()
def __invert__(self):
return self._value.__invert__()
def __eq__(self, oc):
    """Return ``self == oc`` evaluated on the raw arrays."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    return self._value == other
def __ne__(self, oc):
    """Return ``self != oc`` evaluated on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value != oc
def __lt__(self, oc):
    """Return ``self < oc`` evaluated on the raw arrays."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    return self._value < other
def __le__(self, oc):
    """Return ``self <= oc`` evaluated on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value <= oc
def __gt__(self, oc):
    """Return ``self > oc`` evaluated on the raw arrays."""
    other = oc._value if isinstance(oc, JaxArray) else oc
    return self._value > other
def __ge__(self, oc):
    """Return ``self >= oc`` evaluated on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value >= oc
def __add__(self, oc):
    """Return ``self + oc`` computed on the raw arrays (result is not re-wrapped)."""
    addend = oc._value if isinstance(oc, JaxArray) else oc
    return self._value + addend
def __radd__(self, oc):
    """Reflected addition; evaluated as ``self._value + oc`` on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value + oc
def __sub__(self, oc):
    """Return ``self - oc`` computed on the raw arrays."""
    subtrahend = oc._value if isinstance(oc, JaxArray) else oc
    return self._value - subtrahend
def __rsub__(self, oc):
    """Return ``oc - self`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return oc - self._value
def __mul__(self, oc):
    """Return ``self * oc`` computed on the raw arrays."""
    factor = oc._value if isinstance(oc, JaxArray) else oc
    return self._value * factor
def __rmul__(self, oc):
    """Return ``oc * self`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return oc * self._value
def __rdiv__(self, oc):
    """Return ``oc / self`` computed on the raw arrays."""
    numerator = oc._value if isinstance(oc, JaxArray) else oc
    return numerator / self._value
def __truediv__(self, oc):
    """Return ``self / oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value / oc
def __rtruediv__(self, oc):
    """Return ``oc / self`` computed on the raw arrays."""
    numerator = oc._value if isinstance(oc, JaxArray) else oc
    return numerator / self._value
def __floordiv__(self, oc):
    """Return ``self // oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value // oc
def __rfloordiv__(self, oc):
    """Return ``oc // self`` computed on the raw arrays."""
    numerator = oc._value if isinstance(oc, JaxArray) else oc
    return numerator // self._value
def __divmod__(self, oc):
    """Return ``divmod(self, oc)`` computed on the raw arrays.

    The builtin ``divmod`` is used instead of calling ``__divmod__``
    directly, matching the operator style of the sibling methods. With
    normal operator dispatch, an operand type the stored array does not
    recognise gets a chance to handle the operation through its reflected
    method, instead of ``NotImplemented`` being handed back as the result.
    """
    other = oc._value if isinstance(oc, JaxArray) else oc
    return divmod(self._value, other)
def __rdivmod__(self, oc):
    """Return ``divmod(oc, self)`` computed on the raw arrays.

    The builtin ``divmod`` is used instead of calling ``__rdivmod__``
    directly, matching the ``oc <op> self._value`` form of the sibling
    reflected methods and avoiding ``NotImplemented`` being handed back as
    the result.
    """
    other = oc._value if isinstance(oc, JaxArray) else oc
    return divmod(other, self._value)
def __mod__(self, oc):
    """Return ``self % oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value % oc
def __rmod__(self, oc):
    """Return ``oc % self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left % self._value
def __pow__(self, oc):
    """Return ``self ** oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value ** oc
def __rpow__(self, oc):
    """Return ``oc ** self`` computed on the raw arrays."""
    base = oc._value if isinstance(oc, JaxArray) else oc
    return base ** self._value
def __matmul__(self, oc):
    """Return ``self @ oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value @ oc
def __rmatmul__(self, oc):
    """Return ``oc @ self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left @ self._value
def __and__(self, oc):
    """Return ``self & oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value & oc
def __rand__(self, oc):
    """Return ``oc & self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left & self._value
def __or__(self, oc):
    """Return ``self | oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value | oc
def __ror__(self, oc):
    """Return ``oc | self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left | self._value
def __xor__(self, oc):
    """Return ``self ^ oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value ^ oc
def __rxor__(self, oc):
    """Return ``oc ^ self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left ^ self._value
def __lshift__(self, oc):
    """Return ``self << oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value << oc
def __rlshift__(self, oc):
    """Return ``oc << self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left << self._value
def __rshift__(self, oc):
    """Return ``self >> oc`` computed on the raw arrays."""
    if isinstance(oc, JaxArray):
        oc = oc._value
    return self._value >> oc
def __rrshift__(self, oc):
    """Return ``oc >> self`` computed on the raw arrays."""
    left = oc._value if isinstance(oc, JaxArray) else oc
    return left >> self._value
def __round__(self, ndigits=None):
    """Delegate rounding to the wrapped value's ``__round__``.

    ``ndigits`` is always passed explicitly, including when it is ``None``.
    """
    raw = self._value
    return raw.__round__(ndigits)
# ----------------------- #
# NumPy methods #
# ----------------------- #
def all(self, axis=None, keepdims=False):
    """Test whether every element evaluates to True (optionally along `axis`)."""
    raw = self.value
    return raw.all(axis=axis, keepdims=keepdims)
def any(self, axis=None, keepdims=False):
    """Test whether at least one element evaluates to True (optionally along `axis`)."""
    raw = self.value
    return raw.any(axis=axis, keepdims=keepdims)
def argmax(self, axis=None):
    """Return the indices of the maximum values along `axis`."""
    raw = self.value
    return raw.argmax(axis=axis)
def argmin(self, axis=None):
    """Return the indices of the minimum values along `axis`."""
    raw = self.value
    return raw.argmin(axis=axis)
def argpartition(self, kth, axis=-1, kind='introselect', order=None):
    """Return the indices that would partition this array.

    All arguments are forwarded as keywords to the wrapped array's
    ``argpartition``.
    """
    options = dict(kth=kth, axis=axis, kind=kind, order=order)
    return self.value.argpartition(**options)
def argsort(self, axis=-1, kind=None, order=None):
    """Return the indices that would sort this array.

    All arguments are forwarded as keywords to the wrapped array's ``argsort``.
    """
    options = dict(axis=axis, kind=kind, order=order)
    return self.value.argsort(**options)
def astype(self, dtype):
    """Return a copy of the array cast to `dtype`.

    Parameters
    ----------
    dtype : str or dtype
        Typecode or data-type to which the array is cast.
    """
    raw = self.value
    return raw.astype(dtype=dtype)
def byteswap(self, inplace=False):
    """Swap the bytes of the array elements.

    Toggles between little-endian and big-endian data representation by
    returning a byteswapped array, optionally swapped in-place. Arrays of
    byte-strings are not swapped. The real and imaginary parts of a complex
    number are swapped individually.
    """
    raw = self.value
    return raw.byteswap(inplace=inplace)
def choose(self, choices, mode='raise'):
    """Use this array as an index array to build a new array from `choices`.

    A ``JaxArray`` passed as `choices` is unwrapped to its raw value first.
    """
    if isinstance(choices, JaxArray):
        choices = choices.value
    return self.value.choose(choices=choices, mode=mode)
def clip(self, min=None, max=None):
    """Return an array whose values are limited to ``[min, max]``.

    One of `max` or `min` must be given.
    """
    raw = self.value
    return raw.clip(min=min, max=max)
def compress(self, condition, axis=None):
    """Return selected slices of this array along the given axis.

    A ``JaxArray`` passed as `condition` is unwrapped to its raw value first.
    """
    if isinstance(condition, JaxArray):
        condition = condition.value
    return self.value.compress(condition=condition, axis=axis)
def conj(self):
    """Return the complex conjugate of every element."""
    raw = self.value
    return raw.conj()
def conjugate(self):
    """Return the element-wise complex conjugate."""
    raw = self.value
    return raw.conjugate()
def copy(self):
    """Return the result of the wrapped array's ``copy()``."""
    raw = self.value
    return raw.copy()
def cumprod(self, axis=None, dtype=None):
    """Return the cumulative product of the elements along the given axis."""
    raw = self.value
    return raw.cumprod(axis=axis, dtype=dtype)
def cumsum(self, axis=None, dtype=None):
    """Return the cumulative sum of the elements along the given axis."""
    raw = self.value
    return raw.cumsum(axis=axis, dtype=dtype)
def diagonal(self, offset=0, axis1=0, axis2=1):
    """Return the specified diagonals of the array."""
    options = dict(offset=offset, axis1=axis1, axis2=axis2)
    return self.value.diagonal(**options)
def dot(self, b):
    """Return the dot product of this array with `b`.

    A ``JaxArray`` passed as `b` is unwrapped to its raw value first.
    """
    if isinstance(b, JaxArray):
        b = b.value
    return self.value.dot(b)
def flatten(self, order='C'):
    """Return the result of the wrapped array's ``flatten`` with the given `order`."""
    raw = self.value
    return raw.flatten(order=order)
def item(self, *args):
    """Copy an element of the array to a standard Python scalar and return it."""
    raw = self.value
    return raw.item(*args)
def max(self, axis=None, keepdims=False, *args, **kwargs):
    """Return the maximum along the given axis.

    `axis` and `keepdims` are forwarded as keywords; extra positional and
    keyword arguments are passed through to the wrapped array's ``max``.
    """
    raw = self.value
    return raw.max(*args, axis=axis, keepdims=keepdims, **kwargs)
def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs):
    """Return the average of the array elements along the given axis.

    `axis`, `dtype` and `keepdims` are forwarded as keywords; extra positional
    and keyword arguments are passed through to the wrapped array's ``mean``.
    """
    raw = self.value
    return raw.mean(*args, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
def min(self, axis=None, keepdims=False, *args, **kwargs):
    """Return the minimum along the given axis.

    `axis` and `keepdims` are forwarded as keywords; extra positional and
    keyword arguments are passed through to the wrapped array's ``min``.
    """
    raw = self.value
    return raw.min(*args, axis=axis, keepdims=keepdims, **kwargs)
def nonzero(self):
    """Return the indices of the elements that are non-zero."""
    raw = self.value
    return raw.nonzero()
def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True):
    """Return the product of the array elements over the given axis."""
    options = dict(axis=axis, dtype=dtype, keepdims=keepdims,
                   initial=initial, where=where)
    return self.value.prod(**options)
def ptp(self, axis=None, keepdims=False):
    """Return the peak-to-peak (maximum - minimum) value along a given axis."""
    raw = self.value
    return raw.ptp(axis=axis, keepdims=keepdims)
def ravel(self, order=None):
    """Return a flattened array; `order` is forwarded to the wrapped array's ``ravel``."""
    raw = self.value
    return raw.ravel(order=order)
def repeat(self, repeats, axis=None):
    """Repeat the elements of the array `repeats` times along `axis`."""
    raw = self.value
    return raw.repeat(repeats=repeats, axis=axis)
def reshape(self, *shape, order='C'):
    """Return an array containing the same data with a new shape.

    The positional `shape` arguments are forwarded unchanged, so both
    ``reshape(2, 3)`` and ``reshape((2, 3))`` reach the wrapped array as given.
    """
    raw = self.value
    return raw.reshape(*shape, order=order)
def round(self, decimals=0):
    """Return the array with each element rounded to `decimals` decimal places."""
    raw = self.value
    return raw.round(decimals=decimals)
def searchsorted(self, v, side='left', sorter=None):
    """Find the indices where elements should be inserted to maintain order.

    For each entry of `v`, locate the index into the sorted array `a` (this
    array) such that inserting the entry before that index keeps `a` ordered.

    Assuming that `a` is sorted:

    ======  ============================
    `side`  returned index `i` satisfies
    ======  ============================
    left    ``a[i-1] < v <= a[i]``
    right   ``a[i-1] <= v < a[i]``
    ======  ============================

    Parameters
    ----------
    v : array_like
        Values to insert into `a`. A ``JaxArray`` is unwrapped to its raw
        value before the call.
    side : {'left', 'right'}, optional
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index. If there is no suitable
        index, return either 0 or N (where N is the length of `a`).
    sorter : 1-D array_like, optional
        Optional array of integer indices that sort `a` into ascending
        order. They are typically the result of argsort.

    Returns
    -------
    indices : array of ints
        Array of insertion points with the same shape as `v`.
    """
    if isinstance(v, JaxArray):
        v = v.value
    return self.value.searchsorted(v=v, side=side, sorter=sorter)
def squeeze(self, axis=None):
    """Remove axes of length one from the array."""
    raw = self.value
    return raw.squeeze(axis=axis)
def std(self, axis=None, dtype=None, ddof=0, keepdims=False):
    """Compute the standard deviation along the specified axis.

    The standard deviation measures the spread of a distribution. By default
    it is computed over the flattened array, otherwise over the given axis.

    Parameters
    ----------
    axis : None or int or tuple of ints, optional
        Axis or axes along which the standard deviation is computed. The
        default is to compute it over the flattened array. With a tuple of
        ints, the reduction is performed over all of the listed axes.
    dtype : dtype, optional
        Type to use in computing the standard deviation. For arrays of
        integer type the default is float64, for arrays of float types it is
        the same as the array type.
    ddof : int, optional
        Delta Degrees of Freedom. The divisor used in the calculation is
        ``N - ddof``, where ``N`` is the number of elements. Defaults to zero.
    keepdims : bool, optional
        If True, the reduced axes are kept in the result as dimensions of
        size one, so the result broadcasts correctly against the input.

    Returns
    -------
    standard_deviation : ndarray
        A new array containing the standard deviation (see `dtype` above).
    """
    options = dict(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
    return self.value.std(**options)
def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True):
    """Return the sum of the array elements over the given axis."""
    options = dict(axis=axis, dtype=dtype, keepdims=keepdims,
                   initial=initial, where=where)
    return self.value.sum(**options)
def swapaxes(self, axis1, axis2):
    """Return the array with `axis1` and `axis2` interchanged."""
    raw = self.value
    return raw.swapaxes(axis1, axis2)
def split(self, indices_or_sections, axis=0):
    """Split the array into multiple sub-arrays.

    The split is done with ``jnp.split`` rather than an array ``.split``
    method, which NumPy arrays never had and which not every JAX release
    exposes on its arrays.

    Parameters
    ----------
    indices_or_sections : int, 1-D array
        If `indices_or_sections` is an integer, N, the array will be divided
        into N equal arrays along `axis`. If such a split is not possible,
        an error is raised.

        If `indices_or_sections` is a 1-D array of sorted integers, the
        entries indicate where along `axis` the array is split. For example,
        ``[2, 3]`` would, for ``axis=0``, result in

          - ary[:2]
          - ary[2:3]
          - ary[3:]

        If an index exceeds the dimension of the array along `axis`,
        an empty sub-array is returned correspondingly.
        A ``JaxArray`` is unwrapped to its raw value first.
    axis : int, optional
        The axis along which to split, default is 0.

    Returns
    -------
    sub-arrays : list of JaxArray
        The pieces, each wrapped in a ``JaxArray``.
    """
    if isinstance(indices_or_sections, JaxArray):
        indices_or_sections = indices_or_sections.value
    pieces = jnp.split(self.value, indices_or_sections, axis=axis)
    return [JaxArray(piece) for piece in pieces]
def take(self, indices, axis=None, mode=None):
    """Return an array formed from the elements at the given indices.

    A ``JaxArray`` passed as `indices` is unwrapped to its raw value first.
    """
    if isinstance(indices, JaxArray):
        indices = indices.value
    return self.value.take(indices=indices, axis=axis, mode=mode)
def tobytes(self, order='C'):
    """Construct Python bytes containing the raw data bytes in the array.

    The bytes object holds a copy of the raw contents of data memory. It is
    produced in C-order by default; this is controlled by `order`.
    """
    raw = self.value
    return raw.tobytes(order=order)
def tolist(self):
    """Return the array as an ``a.ndim``-levels deep nested list of Python scalars.

    The result is a copy of the array data as a (nested) Python list. Data
    items are converted to the nearest compatible builtin Python type, via
    the `~numpy.ndarray.item` function.

    If ``a.ndim`` is 0, the nesting depth is 0, so the result is not a list
    at all but a simple Python scalar.
    """
    raw = self.value
    return raw.tolist()
def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
    """Return the sum along diagonals of the array."""
    options = dict(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
    return self.value.trace(**options)
def transpose(self, *axes):
    """Return the array with its axes transposed.

    For a 1-D array this has no effect, as a transposed vector is simply the
    same vector. To convert a 1-D array into a 2-D column vector, an
    additional dimension must be added, e.g. with `np.atleast2d(a).T` or
    `a[:, np.newaxis]`.
    For a 2-D array, this is a standard matrix transpose.
    For an n-D array, if axes are given, their order indicates how the
    axes are permuted. If axes are not provided and
    ``a.shape = (i[0], i[1], ... i[n-2], i[n-1])``, then
    ``a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0])``.

    Parameters
    ----------
    axes : None, tuple of ints, or `n` ints
        * None or no argument: reverses the order of the axes.
        * tuple of ints: `i` in the `j`-th place in the tuple means `a`'s
          `i`-th axis becomes `a.transpose()`'s `j`-th axis.
        * `n` ints: same as an n-tuple of the same ints (a convenience
          alternative to the tuple form).

    Returns
    -------
    out : ndarray
        The array with its axes suitably permuted.
    """
    raw = self.value
    return raw.transpose(*axes)
def tile(self, reps):
    """Construct an array by repeating A the number of times given by reps.

    The tiling is done with ``jnp.tile`` rather than an array ``.tile``
    method, which NumPy arrays never had and which not every JAX release
    exposes on its arrays.

    If `reps` has length ``d``, the result will have dimension of
    ``max(d, A.ndim)``.

    If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
    axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
    or shape (1, 1, 3) for 3-D replication. If this is not the desired
    behavior, promote `A` to d-dimensions manually before calling this
    function.

    If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to
    it. Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated
    as (1, 1, 2, 2).

    Note : Although tile may be used for broadcasting, it is strongly
    recommended to use numpy's broadcasting operations and functions.

    Parameters
    ----------
    reps : array_like
        The number of repetitions of `A` along each axis. A ``JaxArray`` is
        unwrapped to its raw value first.

    Returns
    -------
    c : ndarray
        The tiled output array.
    """
    if isinstance(reps, JaxArray):
        reps = reps.value
    return jnp.tile(self.value, reps)
def var(self, axis=None, dtype=None, ddof=0, keepdims=False):
    """Return the variance of the array elements along the given axis."""
    options = dict(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
    return self.value.var(**options)
def view(self, dtype=None, *args, **kwargs):
    """Return a new view of the array with the same data.

    `dtype` is forwarded as a keyword; extra positional and keyword
    arguments are passed through to the wrapped array's ``view``.
    """
    raw = self.value
    return raw.view(*args, dtype=dtype, **kwargs)
class TrainVar(Variable):
    """The pointer to specify the trainable variable.

    This subclass adds no behavior of its own: construction is forwarded
    unchanged to ``Variable.__init__``.
    """
    __slots__ = ('_value', '_batch_axis')

    def __init__(self, value, dtype=None, batch_axis: int = None):
        super().__init__(value, dtype=dtype, batch_axis=batch_axis)
class Parameter(Variable):
    """The pointer to specify the parameter.

    This subclass adds no behavior of its own: construction is forwarded
    unchanged to ``Variable.__init__``.
    """
    __slots__ = ('_value', '_batch_axis')

    def __init__(self, value, dtype=None, batch_axis: int = None):
        super().__init__(value, dtype=dtype, batch_axis=batch_axis)
def _register_array_pytree(array_cls):
    """Register `array_cls` as a JAX pytree node with its ``value`` as the only child."""

    def _flatten(arr):
        # One child (the raw value) and no auxiliary data.
        return (arr.value,), None

    def _unflatten(aux_data, flat_contents):
        # Rebuild the wrapper from the single child; aux_data is unused.
        return array_cls(*flat_contents)

    register_pytree_node(array_cls, _flatten, _unflatten)


# Same registration order as before: base class first, then the subclasses.
_register_array_pytree(JaxArray)
_register_array_pytree(Variable)
_register_array_pytree(TrainVar)
_register_array_pytree(Parameter)