brainpy.math.jaxarray.VariableView

class brainpy.math.jaxarray.VariableView(value, index)

A view of a Variable instance.

This class is used to create a view of a subset of a brainpy.math.Variable.

>>> import brainpy.math as bm
>>> bm.random.seed(123)
>>> origin = bm.Variable(bm.random.random(5))
>>> view = bm.VariableView(origin, slice(None, 2, None))  # origin[:2]
>>> view
VariableView([0.02920651, 0.19066381], dtype=float32)
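The index argument in the example is an ordinary Python slice object: the same object Python constructs for the subscript origin[:2]. The hypothetical Probe class below is not part of BrainPy; it simply returns the key it receives, which makes that equivalence visible.

```python
class Probe:
    """Hypothetical helper (not BrainPy): returns whatever key a subscript produces."""

    def __getitem__(self, key):
        return key


probe = Probe()
print(probe[:2])                              # slice(None, 2, None)
print(probe[:2] == slice(None, 2, None))      # True
# The same slice object selects the first two items of any sequence.
print([0.1, 0.2, 0.3][slice(None, 2, None)])  # [0.1, 0.2]
```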

A VariableView can be used to update a subset of the original Variable instance, and to perform operations on that subset.

>>> view[:] = 1.
>>> view
VariableView([1., 1.], dtype=float32)
>>> origin
Variable([1.       , 1.       , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
>>> view + 10
DeviceArray([11., 11.], dtype=float32)
>>> view *= 10
>>> view
VariableView([10., 10.], dtype=float32)

The example above shows that updating a VariableView instance actually modifies the original Variable instance.
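The write-through behaviour can be illustrated with a minimal sketch. The classes ToyVariable and ToyView below are hypothetical and are not BrainPy's implementation. The sketch assumes, as with JAX arrays, that the stored array is never mutated in place, so every write through the view replaces the Variable's whole value with an updated copy.

```python
import numpy as np


class ToyVariable:
    """Hypothetical stand-in for a Variable: a mutable holder of an array value."""

    def __init__(self, value):
        self.value = np.asarray(value)


class ToyView:
    """Hypothetical stand-in for VariableView: keeps a Variable and an index."""

    def __init__(self, variable, index):
        self.variable = variable
        self.index = index

    @property
    def value(self):
        # Reads always go through the Variable's current value.
        return self.variable.value[self.index]

    def __setitem__(self, key, val):
        # Treat the stored array as immutable: update a copy of the subset,
        # write it into a copy of the full array, then replace the Variable's value.
        sub = self.value.copy()
        sub[key] = val
        full = self.variable.value.copy()
        full[self.index] = sub
        self.variable.value = full

    def __imul__(self, other):
        # In-place multiplication is expressed as a write through the view.
        self[...] = self.value * other
        return self


origin = ToyVariable([0.1, 0.2, 0.3, 0.4, 0.5])
view = ToyView(origin, slice(None, 2))   # plays the role of origin[:2]
view[:] = 1.0
print(origin.value.tolist())   # [1.0, 1.0, 0.3, 0.4, 0.5]
view *= 10
print(origin.value.tolist())   # [10.0, 10.0, 0.3, 0.4, 0.5]
```

Because the view stores only a reference and an index, it never holds data of its own; every read and write is resolved against the Variable at the time it happens.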

Note also that VariableView is not a PyTree.

__init__(value, index)

Methods

__init__(value, index)

all([axis, keepdims])

Returns True if all elements evaluate to True.

any([axis, keepdims])

Returns True if any of the elements evaluate to True.

argmax([axis])

Return indices of the maximum values along the given axis.

argmin([axis])

Return indices of the minimum values along the given axis.

argpartition(kth[, axis, kind, order])

Returns the indices that would partition this array.

argsort([axis, kind, order])

Returns the indices that would sort this array.

astype(dtype)

Copy of the array, cast to a specified type.

block_host_until_ready(*args)

block_until_ready(*args)

byteswap([inplace])

Swap the bytes of the array elements.

choose(choices[, mode])

Use an index array to construct a new array from a set of choices.

clip([min, max])

Return an array whose values are limited to [min, max].

compress(condition[, axis])

Return selected slices of this array along the given axis.

conj()

Complex-conjugate all elements.

conjugate()

Return the complex conjugate, element-wise.

copy()

Return a copy of the array.

cumprod([axis, dtype])

Return the cumulative product of the elements along the given axis.

cumsum([axis, dtype])

Return the cumulative sum of the elements along the given axis.

device()

diagonal([offset, axis1, axis2])

Return specified diagonals.

dot(b)

Dot product of two arrays.

fill(value)

Fill the array with a scalar value.

flatten([order])

item(*args)

Copy an element of an array to a standard Python scalar and return it.

max([axis, keepdims])

Return the maximum along a given axis.

mean([axis, dtype, keepdims])

Returns the average of the array elements along the given axis.

min([axis, keepdims])

Return the minimum along a given axis.

nonzero()

Return the indices of the elements that are non-zero.

numpy([dtype])

Convert to numpy.ndarray.

prod([axis, dtype, keepdims, initial, where])

Return the product of the array elements over the given axis.

ptp([axis, keepdims])

Peak to peak (maximum - minimum) value along a given axis.

put(indices, values)

Replaces specified elements of an array with given values.

ravel([order])

Return a flattened array.

repeat(repeats[, axis])

Repeat elements of an array.

reshape(*shape[, order])

Returns an array containing the same data with a new shape.

resize(new_shape)

Change the shape and size of the array in place.

round([decimals])

Return an array with each element rounded to the given number of decimals.

searchsorted(v[, side, sorter])

Find indices where elements should be inserted to maintain order.

sort([axis, kind, order])

Sort an array in-place.

split(indices_or_sections[, axis])

Split the array into multiple sub-arrays.

squeeze([axis])

Remove axes of length one from the array.

std([axis, dtype, ddof, keepdims])

Compute the standard deviation along the specified axis.

sum([axis, dtype, keepdims, initial, where])

Return the sum of the array elements over the given axis.

swapaxes(axis1, axis2)

Return a view of the array with axis1 and axis2 interchanged.

take(indices[, axis, mode])

Return an array formed from the elements of the array at the given indices.

tile(reps)

Construct an array by repeating this array the number of times given by reps.

to_jax([dtype])

Convert to jax.numpy.ndarray.

to_numpy([dtype])

Convert to numpy.ndarray.

tobytes([order])

Construct Python bytes containing the raw data bytes in the array.

tolist()

Return the array as an ndim-levels deep nested list of Python scalars.

trace([offset, axis1, axis2, dtype])

Return the sum along diagonals of the array.

transpose(*axes)

Returns a view of the array with axes transposed.

update(value)

Update the value of this JaxArray.

var([axis, dtype, ddof, keepdims])

Returns the variance of the array elements along the given axis.

view([dtype])

New view of array with the same data.

Attributes

T

at

batch_axis

Return type: Optional[int]

batch_size

Return type: Optional[int]

device_buffer

dtype

Variable dtype.

imag

ndim

real

shape

Variable shape.

shape_nb

Shape without the batch axis.

size

value
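To illustrate how shape_nb relates to shape and batch_axis, the hypothetical helper below drops the dimension at batch_axis and returns the shape unchanged when batch_axis is None. This is an assumption inferred from the one-line attribute descriptions above, not BrainPy's code.

```python
from typing import Optional, Tuple


def shape_without_batch_axis(shape: Tuple[int, ...],
                             batch_axis: Optional[int]) -> Tuple[int, ...]:
    """Hypothetical helper: drop the dimension at ``batch_axis`` from ``shape``."""
    if batch_axis is None:
        # No batch axis: the shape is returned unchanged.
        return tuple(shape)
    dims = list(shape)
    dims.pop(batch_axis)  # negative axes count from the end, as in NumPy
    return tuple(dims)


print(shape_without_batch_axis((32, 5), 0))    # (5,)
print(shape_without_batch_axis((5, 32), -1))   # (5,)
print(shape_without_batch_axis((5,), None))    # (5,)
```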