# brainpy.math.jaxarray.JaxArray

class brainpy.math.jaxarray.JaxArray(value)

Multi-dimensional array for the JAX backend.
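JAX arrays are immutable, so a wrapper like `JaxArray` has to emulate the in-place methods it lists (such as `fill`, `sort`, and `update`) by computing a new array and rebinding it to the wrapper's `value` attribute. The sketch below illustrates that idea with a hypothetical `MutableArraySketch` class. It is an assumption about the general design, not BrainPy's actual implementation, and the shape/dtype check in `update` is likewise only illustrative.

```python
import jax.numpy as jnp


class MutableArraySketch:
    """Hypothetical stand-in for a JaxArray-like wrapper (not BrainPy's code).

    JAX arrays are immutable, so "in-place" methods are emulated by
    computing a new array and rebinding it to ``self.value``.
    """

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def update(self, value):
        # Illustrative assumption: keep shape and dtype stable across updates.
        value = jnp.asarray(value)
        if value.shape != self.value.shape or value.dtype != self.value.dtype:
            raise ValueError("shape/dtype mismatch")
        self.value = value

    def fill(self, scalar):
        # Rebind to a new array of the same shape/dtype filled with `scalar`.
        self.value = jnp.full_like(self.value, scalar)

    def sort(self, axis=-1):
        # "In-place" sort: the wrapper now refers to a sorted copy.
        self.value = jnp.sort(self.value, axis=axis)

    def sum(self, axis=None):
        # Non-mutating methods simply return a new JAX array.
        return self.value.sum(axis=axis)


a = MutableArraySketch([3.0, 1.0, 2.0])
before = a.value      # keep a reference to the original JAX array
a.sort()              # `a.value` now points at a new, sorted array
sorted_values = a.value
a.fill(0.5)
total = a.sum()
```

The old array referenced by `before` is never modified; only the wrapper's binding changes. This is what lets such a wrapper offer a NumPy-like mutable interface on top of JAX's functional arrays.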

__init__(value)

Methods

| Method | Description |
| --- | --- |
| `__init__(value)` | |
| `all([axis, keepdims])` | Returns True if all elements evaluate to True. |
| `any([axis, keepdims])` | Returns True if any element evaluates to True. |
| `argmax([axis])` | Return indices of the maximum values along the given axis. |
| `argmin([axis])` | Return indices of the minimum values along the given axis. |
| `argpartition(kth[, axis, kind, order])` | Returns the indices that would partition this array. |
| `argsort([axis, kind, order])` | Returns the indices that would sort this array. |
| `astype(dtype)` | Copy of the array, cast to a specified type. |
| `block_host_until_ready(*args)` | |
| `block_until_ready(*args)` | |
| `byteswap([inplace])` | Swap the bytes of the array elements. |
| `choose(choices[, mode])` | Use an index array to construct a new array from a set of choices. |
| `clip([min, max])` | Return an array whose values are limited to [min, max]. |
| `compress(condition[, axis])` | Return selected slices of this array along the given axis. |
| `conj()` | Complex-conjugate all elements. |
| `conjugate()` | Return the complex conjugate, element-wise. |
| `copy()` | Return a copy of the array. |
| `cumprod([axis, dtype])` | Return the cumulative product of the elements along the given axis. |
| `cumsum([axis, dtype])` | Return the cumulative sum of the elements along the given axis. |
| `diagonal([offset, axis1, axis2])` | Return specified diagonals. |
| `dot(b)` | Dot product of two arrays. |
| `fill(value)` | Fill the array with a scalar value. |
| `flatten([order])` | |
| `item(*args)` | Copy an element of the array to a standard Python scalar and return it. |
| `max([axis, keepdims])` | Return the maximum along a given axis. |
| `mean([axis, dtype, keepdims])` | Returns the average of the array elements along the given axis. |
| `min([axis, keepdims])` | Return the minimum along a given axis. |
| `nonzero()` | Return the indices of the elements that are non-zero. |
| `numpy([dtype])` | Convert to numpy.ndarray. |
| `prod([axis, dtype, keepdims, initial, where])` | Return the product of the array elements over the given axis. |
| `ptp([axis, keepdims])` | Peak-to-peak (maximum - minimum) value along a given axis. |
| `put(indices, values)` | Replace specified elements of the array with given values. |
| `ravel([order])` | Return a flattened array. |
| `repeat(repeats[, axis])` | Repeat elements of the array. |
| `reshape(*shape[, order])` | Returns an array containing the same data with a new shape. |
| `resize(new_shape)` | Change shape and size of the array in-place. |
| `round([decimals])` | Return the array with each element rounded to the given number of decimals. |
| `searchsorted(v[, side, sorter])` | Find indices where elements should be inserted to maintain order. |
| `sort([axis, kind, order])` | Sort the array in-place. |
| `split(indices_or_sections[, axis])` | Split the array into multiple sub-arrays. |
| `squeeze([axis])` | Remove axes of length one from the array. |
| `std([axis, dtype, ddof, keepdims])` | Compute the standard deviation along the specified axis. |
| `sum([axis, dtype, keepdims, initial, where])` | Return the sum of the array elements over the given axis. |
| `swapaxes(axis1, axis2)` | Return a view of the array with `axis1` and `axis2` interchanged. |
| `take(indices[, axis, mode])` | Return an array formed from the elements of this array at the given indices. |
| `tile(reps)` | Construct an array by repeating this array the number of times given by `reps`. |
| `to_jax([dtype])` | Convert to jax.numpy.ndarray. |
| `to_numpy([dtype])` | Convert to numpy.ndarray. |
| `tobytes([order])` | Construct Python bytes containing the raw data bytes in the array. |
| `tolist()` | Return the array as an `ndim`-levels deep nested list of Python scalars. |
| `trace([offset, axis1, axis2, dtype])` | Return the sum along diagonals of the array. |
| `transpose(*axes)` | Returns a view of the array with axes transposed. |
| `update(value)` | Update the value of this JaxArray. |
| `var([axis, dtype, ddof, keepdims])` | Returns the variance of the array elements along the given axis. |
| `view([dtype])` | New view of the array with the same data. |
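The conversion methods (`numpy`/`to_numpy` and `to_jax`) and `block_until_ready` correspond to standard NumPy/JAX operations. The following minimal sketch works on plain `jax.numpy` arrays; `to_numpy_sketch` and `to_jax_sketch` are hypothetical helpers that only approximate what the listed methods do, not BrainPy's code.

```python
import numpy as np
import jax.numpy as jnp


def to_numpy_sketch(value, dtype=None):
    # Hypothetical helper mirroring `to_numpy`: host copy as a NumPy ndarray.
    return np.asarray(value, dtype=dtype)


def to_jax_sketch(value, dtype=None):
    # Hypothetical helper mirroring `to_jax`: convert to a JAX array.
    return jnp.asarray(value, dtype=dtype)


x = jnp.arange(6.0).reshape(2, 3)

# JAX dispatches work asynchronously; block_until_ready() waits for the
# computation to finish and returns the same array.
x = x.block_until_ready()

host = to_numpy_sketch(x, dtype=np.float64)   # NumPy array, float64
back = to_jax_sketch(host, dtype=jnp.float32)  # JAX array, float32
```

Waiting on `block_until_ready` matters mainly when timing code, since otherwise a benchmark may measure only the dispatch and not the computation.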

Attributes

`T`, `at`, `dtype`, `imag`, `ndim`, `real`, `shape`, `size`, `value`
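Most of these attribute names match those of a plain JAX array. Assuming `at` forwards to JAX's indexed-update helper (`jax.Array.at`), an indexed write produces a new array instead of modifying the original. With a mutable wrapper, the result would then be stored back, for example by passing it to `update`. The example uses plain `jax.numpy` arrays only:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# `.at[...]` builds an updated copy; `x` itself is never modified.
y = x.at[1].set(3.0)      # write a single element
z = y.at[2:4].add(1.0)    # add to a slice

# Shape-related attributes behave as on any JAX array.
m = jnp.ones((2, 3))
print(m.T.shape, m.ndim, m.size)

# `real` / `imag` split a complex array into its components.
c = jnp.array([1.0 + 2.0j])
print(c.real, c.imag)
```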