brainpy.math.jaxarray.JaxArray
class brainpy.math.jaxarray.JaxArray(value)
Multi-dimensional array in the JAX backend.
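This page does not show how the wrapper is built. The sketch below illustrates the general pattern under a few assumptions: the class keeps the wrapped array in a private field, exposes it through the `value` attribute, forwards read-only attributes such as `shape`, `ndim` and `dtype`, and swaps the whole buffer in `update`. `MiniJaxArray` is a hypothetical name, not BrainPy's implementation, and the shape check in `update` is also an assumption. NumPy stands in for `jax.numpy` when JAX is not installed.

```python
# Hypothetical sketch of the wrapper pattern; not BrainPy's implementation.
try:
    import jax.numpy as jnp  # JAX arrays are immutable
except ImportError:  # fall back to NumPy so the sketch still runs
    import numpy as jnp


class MiniJaxArray:
    """Hypothetical stand-in: holds an array and exposes it as `value`."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    @property
    def value(self):
        return self._value

    # Read-only attributes are forwarded to the wrapped buffer.
    @property
    def shape(self):
        return self._value.shape

    @property
    def ndim(self):
        return self._value.ndim

    @property
    def dtype(self):
        return self._value.dtype

    def update(self, value):
        # Replace the buffer wholesale; rejecting shape changes is an assumption.
        value = jnp.asarray(value)
        if value.shape != self._value.shape:
            raise ValueError(f"shape mismatch: {value.shape} vs {self._value.shape}")
        self._value = value


x = MiniJaxArray([[1.0, 2.0], [3.0, 4.0]])
x.update(x.value * 2)  # the old buffer is replaced, not mutated
print(x.shape, x.ndim, x.value.tolist())
```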
Methods
__init__(value)
all([axis, keepdims]) Returns True if all elements evaluate to True.
any([axis, keepdims]) Returns True if any of the elements of the array evaluate to True.
argmax([axis]) Return indices of the maximum values along the given axis.
argmin([axis]) Return indices of the minimum values along the given axis.
argpartition(kth[, axis, kind, order]) Returns the indices that would partition this array.
argsort([axis, kind, order]) Returns the indices that would sort this array.
astype(dtype) Copy of the array, cast to a specified type.
block_host_until_ready(*args)
block_until_ready(*args)
byteswap([inplace]) Swap the bytes of the array elements.
choose(choices[, mode]) Use an index array to construct a new array from a set of choices.
clip([min, max]) Return an array whose values are limited to [min, max].
compress(condition[, axis]) Return selected slices of this array along the given axis.
conj() Complex-conjugate all elements.
conjugate() Return the complex conjugate, element-wise.
copy() Return a copy of the array.
cumprod([axis, dtype]) Return the cumulative product of the elements along the given axis.
cumsum([axis, dtype]) Return the cumulative sum of the elements along the given axis.
diagonal([offset, axis1, axis2]) Return specified diagonals.
dot(b) Dot product of two arrays.
fill(value) Fill the array with a scalar value.
flatten([order])
item(*args) Copy an element of an array to a standard Python scalar and return it.
max([axis, keepdims]) Return the maximum along a given axis.
mean([axis, dtype, keepdims]) Returns the average of the array elements along the given axis.
min([axis, keepdims]) Return the minimum along a given axis.
nonzero() Return the indices of the elements that are non-zero.
numpy([dtype]) Convert to numpy.ndarray.
prod([axis, dtype, keepdims, initial, where]) Return the product of the array elements over the given axis.
ptp([axis, keepdims]) Peak to peak (maximum - minimum) value along a given axis.
put(indices, values) Replaces specified elements of an array with given values.
ravel([order]) Return a flattened array.
repeat(repeats[, axis]) Repeat elements of an array.
reshape(*shape[, order]) Returns an array containing the same data with a new shape.
resize(new_shape) Change shape and size of array in-place.
round([decimals]) Return the array with each element rounded to the given number of decimals.
searchsorted(v[, side, sorter]) Find indices where elements should be inserted to maintain order.
sort([axis, kind, order]) Sort an array in-place.
split(indices_or_sections[, axis]) Split the array into multiple sub-arrays as views into it.
squeeze([axis]) Remove axes of length one from the array.
std([axis, dtype, ddof, keepdims]) Compute the standard deviation along the specified axis.
sum([axis, dtype, keepdims, initial, where]) Return the sum of the array elements over the given axis.
swapaxes(axis1, axis2) Return a view of the array with axis1 and axis2 interchanged.
take(indices[, axis, mode]) Return an array formed from the elements of the array at the given indices.
tile(reps) Construct an array by repeating the array the number of times given by reps.
to_jax([dtype]) Convert to jax.numpy.ndarray.
to_numpy([dtype]) Convert to numpy.ndarray.
tobytes([order]) Construct Python bytes containing the raw data bytes in the array.
tolist() Return the array as a nested list of Python scalars, ndim levels deep.
trace([offset, axis1, axis2, dtype]) Return the sum along diagonals of the array.
transpose(*axes) Returns a view of the array with axes transposed.
update(value) Update the value of this JaxArray.
var([axis, dtype, ddof, keepdims]) Returns the variance of the array elements along the given axis.
view([dtype]) New view of the array with the same data.
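Several methods in the list (`fill`, `sort`, `resize`) are described as in-place, yet JAX arrays are immutable. One plausible way to reconcile the two, sketched below purely as an assumption, is to compute a new array and rebind the wrapper's internal reference to it. `InPlaceSketch` is a hypothetical class that mirrors only `fill` and `sort`; it is not BrainPy's code. NumPy stands in for `jax.numpy` when JAX is not installed.

```python
# Hypothetical sketch: emulating "in-place" methods on an immutable array.
try:
    import jax.numpy as jnp  # JAX arrays cannot be mutated
except ImportError:  # fall back to NumPy so the sketch still runs
    import numpy as jnp


class InPlaceSketch:
    """Hypothetical holder whose 'mutating' methods rebind `_value`."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    @property
    def value(self):
        return self._value

    def fill(self, scalar):
        # Build a new buffer with the same shape and dtype, then swap it in.
        self._value = jnp.full_like(self._value, scalar)

    def sort(self, axis=-1):
        # jnp.sort returns a new sorted array; rebinding makes it look in-place.
        self._value = jnp.sort(self._value, axis=axis)


a = InPlaceSketch([3, 1, 2])
alias = a  # a second reference to the same wrapper object
a.sort()
print(alias.value.tolist())  # the alias sees the sorted data
a.fill(7)
print(a.value.tolist())
```

Because `alias` and `a` name the same wrapper object, the rebinding is visible through both. That is what lets such methods behave as if they mutated the array even though no JAX buffer is ever modified.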
Attributes
T
at
dtype
imag
ndim
real
shape
size
value