# -*- coding: utf-8 -*-
import warnings
from collections import namedtuple
from functools import partial
from operator import index
from typing import Optional, Union
import jax
import numpy as np
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
from jax.experimental.host_callback import call
from jax.tree_util import register_pytree_node_class
from jax._src.array import ArrayImpl
from brainpy.check import jit_error
from .compat_numpy import shape
from .environment import get_int
from .ndarray import Array, _return
from .object_transform.variables import Variable
__all__ = [
'RandomState', 'Generator', 'DEFAULT',
'seed', 'default_rng', 'split_key', 'split_keys',
# numpy compatibility
'rand', 'randint', 'random_integers', 'randn', 'random',
'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta',
'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto',
'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma',
'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli',
'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f',
'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical',
# pytorch compatibility
'rand_like', 'randint_like', 'randn_like',
]
def _formalize_key(key):
if isinstance(key, int):
return jr.PRNGKey(key)
elif isinstance(key, (Array, jnp.ndarray, np.ndarray)):
if key.dtype != jnp.uint32:
      raise TypeError('key must be an int or an array of two uint32 values.')
if key.size != 2:
      raise TypeError('key must be an int or an array of two uint32 values.')
return jnp.asarray(key)
else:
    raise TypeError('key must be an int or an array of two uint32 values.')
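# Illustrative sketch of `_formalize_key` (hedged): an int seed such as
# `_formalize_key(0)` becomes `jr.PRNGKey(0)`; an existing two-element uint32
# key, e.g. `np.array([0, 42], dtype=np.uint32)`, passes through as a JAX
# array; anything else raises ``TypeError``.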
def _size2shape(size):
if size is None:
return ()
elif isinstance(size, (tuple, list)):
return tuple(size)
else:
return (size, )
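# For example: _size2shape(None) -> (), _size2shape(3) -> (3,),
# _size2shape([2, 3]) -> (2, 3).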
def _check_shape(name, shape, *param_shapes):
shape = core.as_named_shape(shape)
if param_shapes:
shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
if shape.positional != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
raise ValueError(msg.format(name, shape_, shape))
def _as_jax_array(a):
return a.value if isinstance(a, Array) else a
def _is_python_scalar(x):
if hasattr(x, 'aval'):
return x.aval.weak_type
elif np.ndim(x) == 0:
return True
elif isinstance(x, (bool, int, float, complex)):
return True
else:
return False
python_scalar_dtypes = {
bool: np.dtype('bool'),
int: np.dtype('int64'),
float: np.dtype('float64'),
complex: np.dtype('complex128'),
}
def _dtype(x, *, canonicalize: bool = False):
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
if x is None:
raise ValueError(f"Invalid argument to dtype: {x}.")
elif isinstance(x, type) and x in python_scalar_dtypes:
dt = python_scalar_dtypes[x]
elif type(x) in python_scalar_dtypes:
dt = python_scalar_dtypes[type(x)]
elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
dt = x.dtype
else:
dt = np.result_type(x)
return dtypes.canonicalize_dtype(dt) if canonicalize else dt
def _const(example, val):
if _is_python_scalar(example):
dtype = dtypes.canonicalize_dtype(type(example))
val = dtypes.scalar_type_of(example)(val)
return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
else:
dtype = dtypes.canonicalize_dtype(example.dtype)
return np.array(val, dtype)
_tr_params = namedtuple(
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
)
def _get_tr_params(n, p):
  # See Table 1. Additionally, we pre-compute log(p), log1p(-p) and the
  # constant terms that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
mu = n * p
spq = jnp.sqrt(mu * (1 - p))
c = mu + 0.5
b = 1.15 + 2.53 * spq
a = -0.0873 + 0.0248 * b + 0.01 * p
alpha = (2.83 + 5.1 / b) * spq
u_r = 0.43
v_r = 0.92 - 4.2 / b
m = jnp.floor((n + 1) * p).astype(n.dtype)
log_p = jnp.log(p)
log1_p = jnp.log1p(-p)
log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
_stirling_approx_tail(m) + _stirling_approx_tail(n - m))
return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
def _stirling_approx_tail(k):
precomputed = jnp.array([0.08106146679532726,
0.04134069595540929,
0.02767792568499834,
0.02079067210376509,
0.01664469118982119,
0.01387612882307075,
0.01189670994589177,
0.01041126526197209,
0.009255462182712733,
0.008330563433362871, ])
kp1 = k + 1
kp1sq = (k + 1) ** 2
return jnp.where(k < 10,
precomputed[k],
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
def _binomial_btrs(key, p, n):
"""
Based on the transformed rejection sampling algorithm (BTRS) from the
following reference:
  Hormann, "The Generation of Binomial Random Variates"
(https://core.ac.uk/download/pdf/11007254.pdf)
"""
def _btrs_body_fn(val):
_, key, _, _ = val
key, key_u, key_v = jr.split(key, 3)
u = jr.uniform(key_u)
v = jr.uniform(key_v)
u = u - 0.5
k = jnp.floor(
(2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
).astype(n.dtype)
return k, key, u, v
def _btrs_cond_fn(val):
def accept_fn(k, u, v):
# See acceptance condition in Step 3. (Page 3) of TRS algorithm
# v <= f(k) * g_grad(u) / alpha
m = tr_params.m
log_p = tr_params.log_p
log1_p = tr_params.log1_p
# See: formula for log(f(k)) at bottom of Page 5.
log_f = (
(n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
+ (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
+ (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
+ tr_params.log_h
)
g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
return jnp.log((v * tr_params.alpha) / g) <= log_f
k, key, u, v = val
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
early_reject = (k < 0) | (k > n)
return lax.cond(
early_accept | early_reject,
(),
lambda _: ~early_accept,
(k, u, v),
lambda x: ~accept_fn(*x),
)
tr_params = _get_tr_params(n, p)
ret = lax.while_loop(
_btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
) # use k=-1 initially so that cond_fn returns True
return ret[0]
def _binomial_inversion(key, p, n):
def _binom_inv_body_fn(val):
i, key, geom_acc = val
key, key_u = jr.split(key)
u = jr.uniform(key_u)
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
geom_acc = geom_acc + geom
return i + 1, key, geom_acc
def _binom_inv_cond_fn(val):
i, _, geom_acc = val
return geom_acc <= n
log1_p = jnp.log1p(-p)
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
return ret[0]
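# The inversion sampler above relies on the gaps between successes in a
# sequence of Bernoulli(p) trials being Geometric(p): geometric gaps are
# accumulated until the running trial count exceeds `n`, and the number of
# completed gaps is the binomial sample.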
def _binomial_dispatch(key, p, n):
def dispatch(key, p, n):
is_le_mid = p <= 0.5
pq = jnp.where(is_le_mid, p, 1 - p)
mu = n * pq
k = lax.cond(
mu < 10,
(key, pq, n),
lambda x: _binomial_inversion(*x),
(key, pq, n),
lambda x: _binomial_btrs(*x),
)
return jnp.where(is_le_mid, k, n - k)
# Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
return lax.cond(
cond0 & (p < 1),
(key, p, n),
lambda x: dispatch(*x),
(),
lambda _: jnp.where(cond0, n, 0),
)
@partial(jit, static_argnums=(3,))
def _binomial(key, p, n, shape):
shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
# reshape to map over axis 0
p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
key = jr.split(key, jnp.size(p))
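  # On CPU, a sequential `lax.map` avoids vectorizing the per-sample rejection
  # loops; on accelerators, `vmap` evaluates all samples in parallel.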
if jax.default_backend() == "cpu":
ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
else:
ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
return jnp.reshape(ret, shape)
@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
# this implementation is fast when event shape is small, and slow otherwise
# Ref: https://stackoverflow.com/a/34190035
shape = shape or p.shape[:-1]
s = jnp.cumsum(p, axis=-1)
r = jr.uniform(key, shape=shape + (1,))
return jnp.sum(s < r, axis=-1)
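# A worked illustration of the cumsum trick above (hypothetical values):
# for p = [0.2, 0.5, 0.3], s = cumsum(p) = [0.2, 0.7, 1.0]; a uniform draw
# r = 0.6 gives (s < r) = [True, False, False], so sum(s < r) = 1 selects
# category 1, which indeed carries probability p[1] = 0.5.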
def _scatter_add_one(operand, indices, updates):
return lax.scatter_add(
operand,
indices,
updates,
lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
),
)
def _reshape(x, shape):
if isinstance(x, (int, float, np.ndarray, np.generic)):
return np.reshape(x, shape)
else:
return jnp.reshape(x, shape)
def _promote_shapes(*args, shape=()):
# adapted from lax.lax_numpy
if len(args) < 2 and not shape:
return args
else:
shapes = [jnp.shape(arg) for arg in args]
num_dims = len(lax.broadcast_shapes(shape, *shapes))
return [
_reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
for arg, s in zip(args, shapes)
]
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
if jnp.shape(n) != jnp.shape(p)[:-1]:
broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
shape = shape or p.shape[:-1]
if n_max == 0:
return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
# get indices from categorical distribution then gather the result
indices = _categorical(key, p, (n_max,) + shape)
# mask out values when counts is heterogeneous
if jnp.ndim(n) > 0:
mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
-1)
else:
mask = 1
excess = 0
# NB: we transpose to move batch shape to the front
indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
jnp.expand_dims(indices_2D, axis=-1),
jnp.ones(indices_2D.shape, dtype=indices.dtype))
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
@partial(jit, static_argnums=(2, 3))
def _von_mises_centered(key, concentration, shape, dtype=jnp.float64):
"""Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
Returns
-------
out: array_like
centered samples from von Mises
References
----------
.. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
"""
shape = shape or jnp.shape(concentration)
dtype = jnp.result_type(dtype)
concentration = lax.convert_element_type(concentration, dtype)
concentration = jnp.broadcast_to(concentration, shape)
s_cutoff_map = {
jnp.dtype(jnp.float16): 1.8e-1,
jnp.dtype(jnp.float32): 2e-2,
jnp.dtype(jnp.float64): 1.2e-4,
}
s_cutoff = s_cutoff_map.get(dtype)
r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
s_exact = (1.0 + rho ** 2) / (2.0 * rho)
s_approximate = 1.0 / concentration
s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
def cond_fn(*args):
"""check if all are done or reached max number of iterations"""
i, _, done, _, _ = args[0]
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
def body_fn(*args):
i, key, done, _, w = args[0]
uni_ukey, uni_vkey, key = jr.split(key, 3)
u = jr.uniform(
key=uni_ukey,
shape=shape,
dtype=concentration.dtype,
minval=-1.0,
maxval=1.0,
)
z = jnp.cos(jnp.pi * u)
w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
y = concentration * (s - w)
v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
return i + 1, key, accept | done, u, w
init_done = jnp.zeros(shape, dtype=bool)
init_u = jnp.zeros(shape)
init_w = jnp.zeros(shape)
_, _, done, u, w = lax.while_loop(
cond_fun=cond_fn,
body_fun=body_fn,
init_val=(jnp.array(0), key, init_done, init_u, init_w),
)
return jnp.sign(u) * jnp.arccos(w)
def _loc_scale(loc, scale, value):
if loc is None:
if scale is None:
return value
else:
return value * scale
else:
if scale is None:
return value + loc
else:
return value * scale + loc
def _check_py_seq(seq):
return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
@register_pytree_node_class
class RandomState(Variable):
"""RandomState that track the random generator state. """
__slots__ = ()
def __init__(
self,
seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None,
seed: Optional[int] = None,
ready_to_trace: bool = True,
):
"""RandomState constructor.
Parameters
----------
seed_or_key: int, Array, optional
      It can be an integer for the initial seed of the random number generator,
      or it can be a JAX PRNGKey, which is an array with two elements and a ``uint32`` dtype.
.. versionadded:: 2.2.3.4
seed : int, ArrayType, optional
Same as `seed_or_key`.
.. deprecated:: 2.2.3.4
      Will be removed in version 2.4.
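
    Examples
    --------
    A minimal, illustrative sketch (concrete key values are platform-dependent)::

      >>> import jax
      >>> import brainpy.math as bm
      >>> rs = bm.random.RandomState(42)                       # integer seed
      >>> rs = bm.random.RandomState(jax.random.PRNGKey(42))   # existing PRNG key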
"""
if seed is not None:
if seed_or_key is not None:
raise ValueError('Please set "seed_or_key" or "seed", not both.')
seed_or_key = seed
      warnings.warn('Please use `seed_or_key` instead. '
                    '`seed` will be removed in version 2.4.0.', UserWarning)
with jax.ensure_compile_time_eval():
if seed_or_key is None:
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed_or_key, int):
key = jr.PRNGKey(seed_or_key)
else:
        if len(seed_or_key) != 2 or seed_or_key.dtype != np.uint32:
raise ValueError('key must be an array with dtype uint32. '
f'But we got {seed_or_key}')
key = seed_or_key
super(RandomState, self).__init__(key, ready_to_trace=ready_to_trace)
def __repr__(self) -> str:
print_code = repr(self.value)
i = print_code.index('(')
name = self.__class__.__name__
return f'{name}(key={print_code[i:]})'
@property
def value(self):
if isinstance(self._value, ArrayImpl):
if self._value.is_deleted():
self.seed()
self._append_to_stack()
return self._value
# ------------------- #
# seed and random key #
# ------------------- #
  def clone(self):
return type(self)(self.split_key())
  def seed(self, seed_or_key=None, seed=None):
"""Sets a new random seed.
Parameters
----------
seed_or_key: int, ArrayType, optional
      It can be an integer for the initial seed of the random number generator,
      or it can be a JAX PRNGKey, which is an array with two elements and a ``uint32`` dtype.
.. versionadded:: 2.2.3.4
seed : int, ArrayType, optional
Same as `seed_or_key`.
.. deprecated:: 2.2.3.4
      Will be removed in version 2.4.
"""
if seed is not None:
if seed_or_key is not None:
raise ValueError('Please set "seed_or_key" or "seed", not both.')
seed_or_key = seed
      warnings.warn('Please use `seed_or_key` instead. '
                    '`seed` will be removed in version 2.4.0.', UserWarning)
if seed_or_key is None:
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed_or_key, int):
key = jr.PRNGKey(seed_or_key)
else:
      if len(seed_or_key) != 2 or seed_or_key.dtype != np.uint32:
raise ValueError('key must be an array with dtype uint32. '
f'But we got {seed_or_key}')
key = seed_or_key
self._value = key
  def split_key(self):
"""Create a new seed from the current seed.
"""
if not isinstance(self.value, jnp.ndarray):
self._value = jnp.asarray(self.value)
keys = jr.split(self.value, num=2)
self._value = keys[0]
return keys[1]
  def split_keys(self, n):
"""Create multiple seeds from the current seed. This is used
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.
Parameters
----------
n : int
The number of seeds to generate.
"""
keys = jr.split(self.value, n + 1)
self._value = keys[0]
return keys[1:]
# ---------------- #
# random functions #
# ---------------- #
  def rand(self, *dn, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.uniform(key, shape=dn, minval=0., maxval=1.)
return _return(r)
  def randint(self, low, high=None, size=None, dtype=int, key=None):
dtype = get_int() if dtype is None else dtype
low = _as_jax_array(low)
high = _as_jax_array(high)
if high is None:
high = low
low = 0
high = _check_py_seq(high)
low = _check_py_seq(low)
if size is None:
size = lax.broadcast_shapes(jnp.shape(low),
jnp.shape(high))
key = self.split_key() if key is None else _formalize_key(key)
r = jr.randint(key,
shape=_size2shape(size),
minval=low, maxval=high, dtype=dtype)
return _return(r)
  def random_integers(self, low, high=None, size=None, key=None):
low = _as_jax_array(low)
high = _as_jax_array(high)
low = _check_py_seq(low)
high = _check_py_seq(high)
if high is None:
high = low
low = 1
high += 1
if size is None:
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
key = self.split_key() if key is None else _formalize_key(key)
r = jr.randint(key,
shape=_size2shape(size),
minval=low,
maxval=high)
return _return(r)
  def randn(self, *dn, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.normal(key, shape=dn)
return _return(r)
  def random(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.)
return _return(r)
  def random_sample(self, size=None, key=None):
r = self.random(size=size, key=key)
return _return(r)
  def ranf(self, size=None, key=None):
r = self.random(size=size, key=key)
return _return(r)
  def sample(self, size=None, key=None):
r = self.random(size=size, key=key)
return _return(r)
  def choice(self, a, size=None, replace=True, p=None, key=None):
a = _as_jax_array(a)
p = _as_jax_array(p)
a = _check_py_seq(a)
p = _check_py_seq(p)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
return _return(r)
  def permutation(self, x, axis: int = 0, independent: bool = False, key=None):
x = x.value if isinstance(x, Array) else x
x = _check_py_seq(x)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.permutation(key, x, axis=axis, independent=independent)
return _return(r)
  def shuffle(self, x, axis=0, key=None):
if not isinstance(x, Array):
raise TypeError('This numpy operator needs in-place updating, therefore '
'inputs should be brainpy Array.')
key = self.split_key() if key is None else _formalize_key(key)
x.value = jr.permutation(key, x.value, axis=axis)
  def beta(self, a, b, size=None, key=None):
a = a.value if isinstance(a, Array) else a
b = b.value if isinstance(b, Array) else b
a = _check_py_seq(a)
b = _check_py_seq(b)
if size is None:
size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
key = self.split_key() if key is None else _formalize_key(key)
r = jr.beta(key, a=a, b=b, shape=_size2shape(size))
return _return(r)
def exponential(self, scale=None, size=None, key=None):
scale = _as_jax_array(scale)
scale = _check_py_seq(scale)
if size is None:
size = jnp.shape(scale)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.exponential(key, shape=_size2shape(size))
    if scale is not None:
      r = r * scale  # NumPy semantics: `scale` is the mean (1 / rate)
return _return(r)
def gamma(self, shape, scale=None, size=None, key=None):
shape = _as_jax_array(shape)
scale = _as_jax_array(scale)
shape = _check_py_seq(shape)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
key = self.split_key() if key is None else _formalize_key(key)
r = jr.gamma(key, a=shape, shape=_size2shape(size))
if scale is not None:
r = r * scale
return _return(r)
def gumbel(self, loc=None, scale=None, size=None, key=None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
key = self.split_key() if key is None else _formalize_key(key)
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size)))
return _return(r)
def laplace(self, loc=None, scale=None, size=None, key=None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
key = self.split_key() if key is None else _formalize_key(key)
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size)))
return _return(r)
def logistic(self, loc=None, scale=None, size=None, key=None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
key = self.split_key() if key is None else _formalize_key(key)
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size)))
return _return(r)
def normal(self, loc=None, scale=None, size=None, key=None):
loc = _as_jax_array(loc)
scale = _as_jax_array(scale)
loc = _check_py_seq(loc)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
key = self.split_key() if key is None else _formalize_key(key)
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size)))
return _return(r)
def pareto(self, a, size=None, key=None):
a = _as_jax_array(a)
a = _check_py_seq(a)
if size is None:
size = jnp.shape(a)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.pareto(key, b=a, shape=_size2shape(size))
return _return(r)
def poisson(self, lam=1.0, size=None, key=None):
lam = _check_py_seq(_as_jax_array(lam))
if size is None:
size = jnp.shape(lam)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.poisson(key, lam=lam, shape=_size2shape(size))
return _return(r)
def standard_cauchy(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.cauchy(key, shape=_size2shape(size))
return _return(r)
def standard_exponential(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.exponential(key, shape=_size2shape(size))
return _return(r)
def standard_gamma(self, shape, size=None, key=None):
shape = _as_jax_array(shape)
shape = _check_py_seq(shape)
if size is None:
size = jnp.shape(shape)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.gamma(key, a=shape, shape=_size2shape(size))
return _return(r)
def standard_normal(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
r = jr.normal(key, shape=_size2shape(size))
return _return(r)
def standard_t(self, df, size=None, key=None):
df = _as_jax_array(df)
df = _check_py_seq(df)
if size is None:
      size = jnp.shape(df)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.t(key, df=df, shape=_size2shape(size))
return _return(r)
def uniform(self, low=0.0, high=1.0, size=None, key=None):
low = _as_jax_array(low)
high = _as_jax_array(high)
low = _check_py_seq(low)
high = _check_py_seq(high)
if size is None:
size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
key = self.split_key() if key is None else _formalize_key(key)
r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high)
return _return(r)
  def truncated_normal(self, lower, upper, size=None, scale=None, key=None):
lower = _as_jax_array(lower)
lower = _check_py_seq(lower)
upper = _as_jax_array(upper)
upper = _check_py_seq(upper)
scale = _as_jax_array(scale)
scale = _check_py_seq(scale)
if size is None:
size = lax.broadcast_shapes(jnp.shape(lower),
jnp.shape(upper),
jnp.shape(scale))
key = self.split_key() if key is None else _formalize_key(key)
rands = jr.truncated_normal(key,
lower=lower,
upper=upper,
shape=_size2shape(size))
if scale is not None:
rands = rands * scale
return _return(rands)
def _check_p(self, p):
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
  def bernoulli(self, p, size=None, key=None):
p = _check_py_seq(_as_jax_array(p))
    jit_error(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
if size is None:
size = jnp.shape(p)
key = self.split_key() if key is None else _formalize_key(key)
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
return _return(r)
def lognormal(self, mean=None, sigma=None, size=None, key=None):
mean = _check_py_seq(_as_jax_array(mean))
sigma = _check_py_seq(_as_jax_array(sigma))
if size is None:
size = jnp.broadcast_shapes(jnp.shape(mean),
jnp.shape(sigma))
key = self.split_key() if key is None else _formalize_key(key)
samples = jr.normal(key, shape=_size2shape(size))
samples = _loc_scale(mean, sigma, samples)
samples = jnp.exp(samples)
return _return(samples)
def binomial(self, n, p, size=None, key=None):
n = _check_py_seq(n.value if isinstance(n, Array) else n)
p = _check_py_seq(p.value if isinstance(p, Array) else p)
    jit_error(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
if size is None:
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
key = self.split_key() if key is None else _formalize_key(key)
r = _binomial(key, p, n, shape=_size2shape(size))
return _return(r)
def chisquare(self, df, size=None, key=None):
df = _check_py_seq(_as_jax_array(df))
key = self.split_key() if key is None else _formalize_key(key)
if size is None:
if jnp.ndim(df) == 0:
dist = jr.normal(key, (df,)) ** 2
dist = dist.sum()
else:
raise NotImplementedError('Do not support non-scale "df" when "size" is None')
else:
dist = jr.normal(key, (df,) + _size2shape(size)) ** 2
dist = dist.sum(axis=0)
return _return(dist)
def dirichlet(self, alpha, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
alpha = _check_py_seq(_as_jax_array(alpha))
r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size))
return _return(r)
def geometric(self, p, size=None, key=None):
p = _as_jax_array(p)
p = _check_py_seq(p)
if size is None:
size = jnp.shape(p)
key = self.split_key() if key is None else _formalize_key(key)
    u = jr.uniform(key, _size2shape(size))
r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
return _return(r)
def _check_p2(self, p):
raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
def multinomial(self, n, pvals, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
n = _check_py_seq(_as_jax_array(n))
pvals = _check_py_seq(_as_jax_array(pvals))
jit_error(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
if isinstance(n, jax.core.Tracer):
raise ValueError("The total count parameter `n` should not be a jax abstract array.")
size = _size2shape(size)
n_max = int(np.max(jax.device_get(n)))
batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
r = _multinomial(key, pvals, n, n_max, batch_shape + size)
return _return(r)
def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', key=None):
if method not in {'svd', 'eigh', 'cholesky'}:
raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
mean = _check_py_seq(_as_jax_array(mean))
cov = _check_py_seq(_as_jax_array(cov))
key = self.split_key() if key is None else _formalize_key(key)
if not jnp.ndim(mean) >= 1:
raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
if not jnp.ndim(cov) >= 2:
raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
n = mean.shape[-1]
if jnp.shape(cov)[-2:] != (n, n):
raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
f"but got cov.shape == {jnp.shape(cov)}.")
if size is None:
size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
else:
size = _size2shape(size)
_check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
if method == 'svd':
(u, s, _) = jnp.linalg.svd(cov)
factor = u * jnp.sqrt(s[..., None, :])
elif method == 'eigh':
(w, v) = jnp.linalg.eigh(cov)
factor = v * jnp.sqrt(w[..., None, :])
else: # 'cholesky'
factor = jnp.linalg.cholesky(cov)
normal_samples = jr.normal(key, size + mean.shape[-1:])
r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return _return(r)
def rayleigh(self, scale=1.0, size=None, key=None):
scale = _check_py_seq(_as_jax_array(scale))
if size is None:
size = jnp.shape(scale)
key = self.split_key() if key is None else _formalize_key(key)
x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1)))
r = x * scale
return _return(r)
def triangular(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
r = 2 * bernoulli_samples - 1
return _return(r)
def vonmises(self, mu, kappa, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
mu = _check_py_seq(_as_jax_array(mu))
kappa = _check_py_seq(_as_jax_array(kappa))
if size is None:
size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
size = _size2shape(size)
samples = _von_mises_centered(key, kappa, size)
samples = samples + mu
samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
return _return(samples)
  def weibull(self, a, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
a = _check_py_seq(_as_jax_array(a))
if size is None:
size = jnp.shape(a)
else:
if jnp.size(a) > 1:
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
size = _size2shape(size)
random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1)
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
return _return(r)
  def weibull_min(self, a, scale=None, size=None, key=None):
"""Sample from a Weibull minimum distribution.
Parameters
----------
a: float, array_like
The concentration parameter of the distribution.
scale: float, array_like
The scale parameter of the distribution.
    size: optional, int, tuple of int
      The output shape, broadcast-compatible with the shapes of ``a`` and ``scale``.
Returns
-------
out: array_like
The sampling results.
"""
key = self.split_key() if key is None else _formalize_key(key)
a = _check_py_seq(_as_jax_array(a))
scale = _check_py_seq(_as_jax_array(scale))
if size is None:
size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
else:
if jnp.size(a) > 1:
raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
size = _size2shape(size)
random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1)
r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
    if scale is not None:
      r = r * scale  # scale multiplies the unit-scale sample (SciPy's weibull_min convention)
return _return(r)
  def maxwell(self, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
shape = core.canonicalize_shape(_size2shape(size)) + (3,)
norm_rvs = jr.normal(key=key, shape=shape)
r = jnp.linalg.norm(norm_rvs, axis=-1)
return _return(r)
def negative_binomial(self, n, p, size=None, key=None):
n = _check_py_seq(_as_jax_array(n))
p = _check_py_seq(_as_jax_array(p))
if size is None:
size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
size = _size2shape(size)
logits = jnp.log(p) - jnp.log1p(-p)
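    # The draw below is a Gamma-Poisson mixture: rate ~ Gamma(n, scale=(1 - p) / p)
    # and r ~ Poisson(rate), which marginally gives NegativeBinomial(n, p)
    # counting failures, matching NumPy's convention.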
if key is None:
keys = self.split_keys(2)
else:
keys = jr.split(_formalize_key(key), 2)
rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0])
r = self.poisson(lam=rate, key=keys[1])
return _return(r)
def wald(self, mean, scale, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
mean = _check_py_seq(_as_jax_array(mean))
scale = _check_py_seq(_as_jax_array(scale))
if size is None:
size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
size = _size2shape(size)
sampled_chi2 = jnp.square(_as_jax_array(self.randn(*size)))
sampled_uniform = _as_jax_array(self.uniform(size=size, key=key))
# Wikipedia defines an intermediate x with the formula
# x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
# where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
# Let us write
# w = loc * y / (2 * conc)
# Then we can extract the common factor in the last two terms to obtain
# x = loc + loc * w * (1 - sqrt(2 / w + 1))
    # Now we see that the Wikipedia formula suffers from catastrophic
# cancellation for large w (e.g., if conc << loc).
#
# Fortunately, we can fix this by multiplying both sides
# by 1 + sqrt(2 / w + 1). We get
# x * (1 + sqrt(2 / w + 1)) =
# = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
# = loc * (sqrt(2 / w + 1) - 1)
# The term sqrt(2 / w + 1) + 1 no longer presents numerical
# difficulties for large w, and sqrt(2 / w + 1) - 1 is just
# sqrt1pm1(2 / w), which we know how to compute accurately.
# This just leaves the matter of small w, where 2 / w may
    # overflow. In the limit as w -> 0, x -> loc, so we just mask
# that case.
sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
res = jnp.where(sampled_uniform <= mean / (mean + sampled),
sampled,
jnp.square(mean) / sampled)
return _return(res)
  def t(self, df, size=None, key=None):
df = _check_py_seq(_as_jax_array(df))
if size is None:
size = np.shape(df)
else:
size = _size2shape(size)
_check_shape("t", size, np.shape(df))
if key is None:
keys = self.split_keys(2)
else:
keys = jr.split(_formalize_key(key), 2)
n = jr.normal(keys[0], size)
two = _const(n, 2)
half_df = lax.div(df, two)
g = jr.gamma(keys[1], half_df, size)
r = n * jnp.sqrt(half_df / g)
return _return(r)
  def orthogonal(self, n: int, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
size = _size2shape(size)
_check_shape("orthogonal", size)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = jr.normal(key, size + (n, n))
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
r = q * jnp.expand_dims(d / abs(d), -2)
return _return(r)
def noncentral_chisquare(self, df, nonc, size=None, key=None):
df = _check_py_seq(_as_jax_array(df))
nonc = _check_py_seq(_as_jax_array(nonc))
if size is None:
size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
size = _size2shape(size)
if key is None:
keys = self.split_keys(3)
else:
keys = jr.split(_formalize_key(key), 3)
i = jr.poisson(keys[0], 0.5 * nonc, shape=size)
n = jr.normal(keys[1], shape=size) + jnp.sqrt(nonc)
cond = jnp.greater(df, 1.0)
df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size)
r = jnp.where(cond, chi2 + n * n, chi2)
return _return(r)
  def loggamma(self, a, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
a = _check_py_seq(_as_jax_array(a))
if size is None:
size = jnp.shape(a)
r = jr.loggamma(key, a, shape=_size2shape(size))
return _return(r)
  def categorical(self, logits, axis: int = -1, size=None, key=None):
key = self.split_key() if key is None else _formalize_key(key)
logits = _check_py_seq(_as_jax_array(logits))
if size is None:
size = list(jnp.shape(logits))
size.pop(axis)
r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
return _return(r)
  def zipf(self, a, size=None, key=None):
a = _check_py_seq(_as_jax_array(a))
    if size is None:
      size = jnp.shape(a)
    size = _size2shape(size)
r = call(lambda x: np.random.zipf(x, size),
a,
result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
def power(self, a, size=None, key=None):
a = _check_py_seq(_as_jax_array(a))
if size is None:
size = jnp.shape(a)
size = _size2shape(size)
r = call(lambda a: np.random.power(a=a, size=size),
a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))
return _return(r)
def f(self, dfnum, dfden, size=None, key=None):
dfnum = _as_jax_array(dfnum)
dfden = _as_jax_array(dfden)
dfnum = _check_py_seq(dfnum)
dfden = _check_py_seq(dfden)
if size is None:
size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden}
r = call(lambda x: np.random.f(dfnum=x['dfnum'],
dfden=x['dfden'],
size=size),
d,
result_shape=jax.ShapeDtypeStruct(size, jnp.float_))
return _return(r)
def hypergeometric(self, ngood, nbad, nsample, size=None, key=None):
ngood = _check_py_seq(_as_jax_array(ngood))
nbad = _check_py_seq(_as_jax_array(nbad))
nsample = _check_py_seq(_as_jax_array(nsample))
if size is None:
size = lax.broadcast_shapes(jnp.shape(ngood),
jnp.shape(nbad),
jnp.shape(nsample))
size = _size2shape(size)
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
r = call(lambda x: np.random.hypergeometric(ngood=x['ngood'],
nbad=x['nbad'],
nsample=x['nsample'],
size=size),
d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
def logseries(self, p, size=None, key=None):
p = _check_py_seq(_as_jax_array(p))
if size is None:
size = jnp.shape(p)
size = _size2shape(size)
r = call(lambda p: np.random.logseries(p=p, size=size),
p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))
return _return(r)
def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None):
dfnum = _check_py_seq(_as_jax_array(dfnum))
dfden = _check_py_seq(_as_jax_array(dfden))
nonc = _check_py_seq(_as_jax_array(nonc))
if size is None:
size = lax.broadcast_shapes(jnp.shape(dfnum),
jnp.shape(dfden),
jnp.shape(nonc))
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
dfden=x['dfden'],
nonc=x['nonc'],
size=size),
d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))
return _return(r)
  # --------------------- #
  # PyTorch compatibility #
  # --------------------- #
  def rand_like(self, input, *, dtype=None, key=None):
"""Returns a tensor with the same size as input that is filled with random
numbers from a uniform distribution on the interval ``[0, 1)``.
Args:
      input: the size of ``input`` will determine the size of the output tensor.
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
key: the seed or key for the random.
Returns:
The random data.
"""
return self.random(shape(input), key=key).astype(dtype)
  def randn_like(self, input, *, dtype=None, key=None):
"""Returns a tensor with the same size as ``input`` that is filled with
random numbers from a normal distribution with mean 0 and variance 1.
Args:
      input: the size of ``input`` will determine the size of the output tensor.
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
key: the seed or key for the random.
Returns:
The random data.
"""
return self.randn(*shape(input), key=key).astype(dtype)
  def randint_like(self, input, low=0, high=None, *, dtype=None, key=None):
if high is None:
high = max(input)
return self.randint(low, high=high, size=shape(input), dtype=dtype, key=key)
# alias
Generator = RandomState
# default random generator
__a = Array(None)
__a._value = np.random.randint(0, 10000, size=2, dtype=np.uint32)
DEFAULT = RandomState(__a)
del __a
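# All functional wrappers below (`seed`, `rand`, `randn`, `normal`, ...)
# delegate to this shared `DEFAULT` generator, so, as an illustrative sketch,
# calling `brainpy.math.random.seed(0)` makes subsequent draws reproducible.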
def split_key():
return DEFAULT.split_key()
def split_keys(n):
"""Create multiple seeds from the current seed. This is used
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.
.. versionadded:: 2.4.5
Parameters
----------
n : int
The number of seeds to generate.
"""
return DEFAULT.split_keys(n)
def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
"""Clone the random state according to the given setting.
Args:
seed_or_key: The seed (an integer) or the random key.
clone: Bool. Whether clone the default random state.
Returns:
The random state.
"""
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
return RandomState(seed_or_key)
def default_rng(seed_or_key=None, clone=True) -> RandomState:
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
return RandomState(seed_or_key)
def seed(seed: int = None):
"""Sets a new random seed.
Parameters
----------
seed: int, optional
The random seed.
"""
with jax.ensure_compile_time_eval():
if seed is None:
seed = np.random.randint(0, 100000)
np.random.seed(seed)
DEFAULT.seed(seed)
def rand(*dn, key=None):
r"""Random values in a given shape.
.. note::
This is a convenience function for users porting code from Matlab,
and wraps `random_sample`. That function takes a
tuple to specify the size of the output, which is consistent with
other NumPy functions like `numpy.zeros` and `numpy.ones`.
Create an array of the given shape and populate it with
random samples from a uniform distribution
over ``[0, 1)``.
Parameters
----------
d0, d1, ..., dn : int, optional
The dimensions of the returned array, must be non-negative.
If no argument is given a single Python float is returned.
Returns
-------
out : ndarray, shape ``(d0, d1, ..., dn)``
Random values.
See Also
--------
random
Examples
--------
>>> brainpy.math.random.rand(3,2)
array([[ 0.14022471, 0.96360618], #random
[ 0.37601032, 0.25528411], #random
[ 0.49313049, 0.94909878]]) #random
"""
return DEFAULT.rand(*dn, key=key)
def randint(low, high=None, size=None, dtype=int, key=None):
r"""Return random integers from `low` (inclusive) to `high` (exclusive).
Return random integers from the "discrete uniform" distribution of
the specified dtype in the "half-open" interval [`low`, `high`). If
`high` is None (the default), then results are from [0, `low`).
Parameters
----------
low : int or array-like of ints
Lowest (signed) integers to be drawn from the distribution (unless
``high=None``, in which case this parameter is one above the
*highest* such integer).
high : int or array-like of ints, optional
If provided, one above the largest (signed) integer to be drawn
from the distribution (see above for behavior if ``high=None``).
If array-like, must contain integer values
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
dtype : dtype, optional
Desired dtype of the result. Byteorder must be native.
The default value is int.
Returns
-------
out : int or ndarray of ints
`size`-shaped array of random integers from the appropriate
distribution, or a single such random int if `size` not provided.
See Also
--------
random_integers : similar to `randint`, only for the closed
interval [`low`, `high`], and 1 is the lowest value if `high` is
omitted.
Generator.integers: which should be used for new code.
Examples
--------
>>> import brainpy.math as bm
>>> bm.random.randint(2, size=10)
array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
>>> bm.random.randint(1, size=10)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Generate a 2 x 4 array of ints between 0 and 4, inclusive:
>>> bm.random.randint(5, size=(2, 4))
array([[4, 0, 2, 1], # random
[3, 2, 2, 0]])
Generate a 1 x 3 array with 3 different upper bounds
>>> bm.random.randint(1, [3, 5, 10])
array([2, 2, 9]) # random
Generate a 1 by 3 array with 3 different lower bounds
>>> bm.random.randint([1, 5, 7], 10)
array([9, 8, 7]) # random
Generate a 2 by 4 array using broadcasting with dtype of uint8
>>> bm.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
array([[ 8, 6, 9, 7], # random
[ 1, 16, 9, 12]], dtype=uint8)
"""
return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key)
def random_integers(low, high=None, size=None, key=None):
r"""
Random integers of type `np.int_` between `low` and `high`, inclusive.
Return random integers of type `np.int_` from the "discrete uniform"
distribution in the closed interval [`low`, `high`]. If `high` is
None (the default), then results are from [1, `low`]. The `np.int_`
type translates to the C long integer type and its precision
is platform dependent.
Parameters
----------
low : int
Lowest (signed) integer to be drawn from the distribution (unless
``high=None``, in which case this parameter is the *highest* such
integer).
high : int, optional
If provided, the largest (signed) integer to be drawn from the
distribution (see above for behavior if ``high=None``).
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
Returns
-------
out : int or ndarray of ints
`size`-shaped array of random integers from the appropriate
distribution, or a single such random int if `size` not provided.
See Also
--------
randint : Similar to `random_integers`, only for the half-open
interval [`low`, `high`), and 0 is the lowest value if `high` is
omitted.
Notes
-----
To sample from N evenly spaced floating-point numbers between a and b,
use::
a + (b - a) * (bm.random.random_integers(N) - 1) / (N - 1.)
Examples
--------
>>> import brainpy.math as bm
>>> bm.random.random_integers(5)
4 # random
>>> type(bm.random.random_integers(5))
<class 'numpy.int64'>
>>> bm.random.random_integers(5, size=(3,2))
array([[5, 4], # random
[3, 3],
[4, 5]])
Choose five random numbers from the set of five evenly-spaced
numbers between 0 and 2.5, inclusive (*i.e.*, from the set
:math:`{0, 5/8, 10/8, 15/8, 20/8}`):
>>> 2.5 * (bm.random.random_integers(5, size=(5,)) - 1) / 4.
array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random
Roll two six sided dice 1000 times and sum the results:
>>> d1 = bm.random.random_integers(1, 6, 1000)
>>> d2 = bm.random.random_integers(1, 6, 1000)
>>> dsums = d1 + d2
Display results as a histogram:
>>> import matplotlib.pyplot as plt
>>> count, bins, ignored = plt.hist(dsums, 11, density=True)
>>> plt.show()
"""
return DEFAULT.random_integers(low, high=high, size=size, key=key)
def randn(*dn, key=None):
r"""
Return a sample (or samples) from the "standard normal" distribution.
.. note::
This is a convenience function for users porting code from Matlab,
and wraps `standard_normal`. That function takes a
tuple to specify the size of the output, which is consistent with
other NumPy functions like `numpy.zeros` and `numpy.ones`.
.. note::
New code should use the ``standard_normal`` method of a ``default_rng()``
instance instead; please see the :ref:`random-quick-start`.
If positive int_like arguments are provided, `randn` generates an array
of shape ``(d0, d1, ..., dn)``, filled
with random floats sampled from a univariate "normal" (Gaussian)
distribution of mean 0 and variance 1. A single float randomly sampled
from the distribution is returned if no argument is provided.
Parameters
----------
d0, d1, ..., dn : int, optional
The dimensions of the returned array, must be non-negative.
If no argument is given a single Python float is returned.
Returns
-------
Z : ndarray or float
A ``(d0, d1, ..., dn)``-shaped array of floating-point samples from
the standard normal distribution, or a single such float if
no parameters were supplied.
See Also
--------
standard_normal : Similar, but takes a tuple as its argument.
normal : Also accepts mu and sigma arguments.
random.Generator.standard_normal: which should be used for new code.
Notes
-----
For random samples from :math:`N(\mu, \sigma^2)`, use:
``sigma * bm.random.randn(...) + mu``
Examples
--------
>>> import brainpy.math as bm
>>> bm.random.randn()
2.1923875335537315 # random
Two-by-four array of samples from N(3, 6.25):
>>> 3 + 2.5 * bm.random.randn(2, 4)
array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random
[ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random
"""
return DEFAULT.randn(*dn, key=key)
def random(size=None, key=None):
"""
Return random floats in the half-open interval [0.0, 1.0). Alias for
`random_sample` to ease forward-porting to the new random API.
"""
return DEFAULT.random(size, key=key)
def random_sample(size=None, key=None):
r"""
Return random floats in the half-open interval [0.0, 1.0).
Results are from the "continuous uniform" distribution over the
stated interval. To sample :math:`Unif[a, b), b > a` multiply
the output of `random_sample` by `(b-a)` and add `a`::
(b - a) * random_sample() + a
.. note::
New code should use the ``random`` method of a ``default_rng()``
instance instead; please see the :ref:`random-quick-start`.
Parameters
----------
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
Returns
-------
out : float or ndarray of floats
Array of random floats of shape `size` (unless ``size=None``, in which
case a single float is returned).
See Also
--------
Generator.random: which should be used for new code.
Examples
--------
>>> import brainpy.math as bm
>>> bm.random.random_sample()
0.47108547995356098 # random
>>> type(bm.random.random_sample())
<class 'float'>
>>> bm.random.random_sample((5,))
array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random
Three-by-two array of random numbers from [-5, 0):
>>> 5 * bm.random.random_sample((3, 2)) - 5
array([[-3.99149989, -0.52338984], # random
[-2.99091858, -0.79479508],
[-1.23204345, -1.75224494]])
"""
return DEFAULT.random_sample(size, key=key)
def ranf(size=None, key=None):
"""
This is an alias of `random_sample`. See `random_sample` for the complete
documentation.
"""
return DEFAULT.ranf(size, key=key)
def sample(size=None, key=None):
"""
This is an alias of `random_sample`. See `random_sample` for the complete
documentation.
"""
return DEFAULT.sample(size, key=key)
def choice(a, size=None, replace=True, p=None, key=None):
r"""
Generates a random sample from a given 1-D array
Parameters
----------
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements.
If an int, the random sample is generated as if it were ``np.arange(a)``
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement. Default is True,
meaning that a value of ``a`` can be selected multiple times.
p : 1-D array-like, optional
The probabilities associated with each entry in a.
If not given, the sample assumes a uniform distribution over all
entries in ``a``.
Returns
-------
samples : single item or ndarray
The generated random samples
Raises
------
ValueError
If a is an int and less than zero, if a or p are not 1-dimensional,
if a is an array-like of size 0, if p is not a vector of
probabilities, if a and p have different lengths, or if
replace=False and the sample size is greater than the population
size
See Also
--------
randint, shuffle, permutation
Generator.choice: which should be used in new code
Notes
-----
Setting user-specified probabilities through ``p`` uses a more general but less
efficient sampler than the default. The general sampler produces a different sample
than the optimized sampler even if each element of ``p`` is 1 / len(a).
Sampling random rows from a 2-D array is not possible with this function,
but is possible with `Generator.choice` through its ``axis`` keyword.
Examples
--------
Generate a uniform random sample from np.arange(5) of size 3:
>>> import brainpy.math as bm
>>> bm.random.choice(5, 3)
array([0, 3, 4]) # random
>>> #This is equivalent to brainpy.math.random.randint(0,5,3)
Generate a non-uniform random sample from np.arange(5) of size 3:
>>> bm.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
array([3, 3, 0]) # random
Generate a uniform random sample from np.arange(5) of size 3 without
replacement:
>>> bm.random.choice(5, 3, replace=False)
array([3,1,0]) # random
>>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]
Generate a non-uniform random sample from np.arange(5) of size
3 without replacement:
>>> bm.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
array([2, 3, 0]) # random
Any of the above can be repeated with an arbitrary array-like
instead of just integers. For instance:
>>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
>>> bm.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
dtype='<U11')
"""
a = _as_jax_array(a)
return DEFAULT.choice(a=a, size=size, replace=replace, p=p, key=key)
def permutation(x, axis: int = 0, independent: bool = False, key=None):
r"""
Randomly permute a sequence, or return a permuted range.
If `x` is a multi-dimensional array, it is only shuffled along its
first index.
Parameters
----------
x : int or array_like
If `x` is an integer, randomly permute ``np.arange(x)``.
If `x` is an array, make a copy and shuffle the elements
randomly.
Returns
-------
out : ndarray
Permuted sequence or array range.
See Also
--------
random.Generator.permutation: which should be used for new code.
Examples
--------
>>> import brainpy.math as bm
>>> bm.random.permutation(10)
array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random
>>> bm.random.permutation([1, 4, 9, 12, 15])
array([15, 1, 9, 4, 12]) # random
>>> arr = np.arange(9).reshape((3, 3))
>>> bm.random.permutation(arr)
array([[6, 7, 8], # random
[0, 1, 2],
[3, 4, 5]])
"""
return DEFAULT.permutation(x, axis=axis, independent=independent, key=key)
def shuffle(x, axis=0, key=None):
r"""
Modify a sequence in-place by shuffling its contents.
This function only shuffles the array along the first axis of a
multi-dimensional array. The order of sub-arrays is changed but
their contents remains the same.
Parameters
----------
x : ndarray or MutableSequence
The array, list or mutable sequence to be shuffled.
Returns
-------
None
See Also
--------
random.Generator.shuffle: which should be used for new code.
Examples
--------
>>> import brainpy.math as bm
>>> arr = np.arange(10)
>>> bm.random.shuffle(arr)
>>> arr
[1 7 5 2 9 4 3 6 0 8] # random
Multi-dimensional arrays are only shuffled along the first axis:
>>> arr = np.arange(9).reshape((3, 3))
>>> bm.random.shuffle(arr)
>>> arr
array([[3, 4, 5], # random
[6, 7, 8],
[0, 1, 2]])
"""
DEFAULT.shuffle(x, axis, key=key)
def beta(a, b, size=None, key=None):
r"""
Draw samples from a Beta distribution.
The Beta distribution is a special case of the Dirichlet distribution,
and is related to the Gamma distribution. It has the probability
distribution function
.. math:: f(x; a,b) = \frac{1}{B(\alpha, \beta)} x^{\alpha - 1}
(1 - x)^{\beta - 1},
where the normalization, B, is the beta function,
.. math:: B(\alpha, \beta) = \int_0^1 t^{\alpha - 1}
(1 - t)^{\beta - 1} dt.
It is often seen in Bayesian inference and order statistics.
Parameters
----------
a : float or array_like of floats
Alpha, positive (>0).
b : float or array_like of floats
Beta, positive (>0).
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``a`` and ``b`` are both scalars.
Otherwise, ``np.broadcast(a, b).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized beta distribution.
See Also
--------
random.Generator.beta: which should be used for new code.
"""
return DEFAULT.beta(a, b, size=size, key=key)
# @wraps(np.random.exponential)
def exponential(scale=None, size=None, key=None):
return DEFAULT.exponential(scale, size, key=key)
# @wraps(np.random.gamma)
def gamma(shape, scale=None, size=None, key=None):
return DEFAULT.gamma(shape, scale, size=size, key=key)
# @wraps(np.random.gumbel)
def gumbel(loc=None, scale=None, size=None, key=None):
return DEFAULT.gumbel(loc, scale, size=size, key=key)
# @wraps(np.random.laplace)
def laplace(loc=None, scale=None, size=None, key=None):
return DEFAULT.laplace(loc, scale, size, key=key)
# @wraps(np.random.logistic)
def logistic(loc=None, scale=None, size=None, key=None):
return DEFAULT.logistic(loc, scale, size, key=key)
# @wraps(np.random.normal)
def normal(loc=None, scale=None, size=None, key=None):
return DEFAULT.normal(loc, scale, size, key=key)
# @wraps(np.random.pareto)
def pareto(a, size=None, key=None):
return DEFAULT.pareto(a, size, key=key)
# @wraps(np.random.poisson)
def poisson(lam=1.0, size=None, key=None):
return DEFAULT.poisson(lam, size, key=key)
# @wraps(np.random.standard_cauchy)
def standard_cauchy(size=None, key=None):
return DEFAULT.standard_cauchy(size, key=key)
# @wraps(np.random.standard_exponential)
def standard_exponential(size=None, key=None):
return DEFAULT.standard_exponential(size, key=key)
# @wraps(np.random.standard_gamma)
def standard_gamma(shape, size=None, key=None):
return DEFAULT.standard_gamma(shape, size, key=key)
# @wraps(np.random.standard_normal)
def standard_normal(size=None, key=None):
return DEFAULT.standard_normal(size, key=key)
# @wraps(np.random.standard_t)
def standard_t(df, size=None, key=None):
return DEFAULT.standard_t(df, size, key=key)
# @wraps(np.random.uniform)
def truncated_normal(lower, upper, size=None, scale=None, key=None):
"""Sample truncated standard normal random values with given shape and dtype.
Parameters
----------
lower : float, ndarray
A float or array of floats representing the lower bound for
truncation. Must be broadcast-compatible with ``upper``.
upper : float, ndarray
A float or array of floats representing the upper bound for
truncation. Must be broadcast-compatible with ``lower``.
size : optional, list of int, tuple of int
A tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
default (None) produces a result shape by broadcasting ``lower`` and
``upper``.
scale : float, ndarray
Standard deviation (spread or "width") of the distribution. Must be
non-negative.
Returns
-------
out : Array
A random array with shape given by ``size`` if ``size`` is not None,
or else by broadcasting ``lower`` and ``upper``.
Returns values in the open interval ``(lower, upper)``.
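Examples
--------
A minimal sketch (the bounds and shape are illustrative; values are
random, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.truncated_normal(-2., 2., size=(1000,))
>>> s.shape
(1000,)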
"""
return DEFAULT.truncated_normal(lower, upper, size, scale, key=key)
def bernoulli(p=0.5, size=None, key=None):
"""Sample Bernoulli random values with given shape and mean.
Parameters
----------
p: float, array_like, optional
A float or array of floats for the mean of the random
variables. Must be broadcast-compatible with ``size`` and the values
should be within [0, 1]. Default 0.5.
size: optional, tuple of int, int
A tuple of nonnegative integers representing the result
shape. Must be broadcast-compatible with ``p.shape``. The default (None)
produces a result shape equal to ``p.shape``.
Returns
-------
out: array_like
A random array with boolean dtype and shape given by ``size`` if ``size``
is not None, or else ``p.shape``.
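Examples
--------
A minimal sketch (the mean below is illustrative; samples are random
booleans, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.bernoulli(0.3, size=(4,))
>>> s.shape
(4,)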
"""
return DEFAULT.bernoulli(p, size, key=key)
# @wraps(np.random.lognormal)
def lognormal(mean=None, sigma=None, size=None, key=None):
return DEFAULT.lognormal(mean, sigma, size, key=key)
# @wraps(np.random.binomial)
def binomial(n, p, size=None, key=None):
return DEFAULT.binomial(n, p, size, key=key)
# @wraps(np.random.chisquare)
def chisquare(df, size=None, key=None):
return DEFAULT.chisquare(df, size, key=key)
# @wraps(np.random.dirichlet)
def dirichlet(alpha, size=None, key=None):
return DEFAULT.dirichlet(alpha, size, key=key)
# @wraps(np.random.geometric)
def geometric(p, size=None, key=None):
return DEFAULT.geometric(p, size, key=key)
# @wraps(np.random.f)
def f(dfnum, dfden, size=None, key=None):
return DEFAULT.f(dfnum, dfden, size, key=key)
# @wraps(np.random.hypergeometric)
def hypergeometric(ngood, nbad, nsample, size=None, key=None):
return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key)
# @wraps(np.random.logseries)
def logseries(p, size=None, key=None):
return DEFAULT.logseries(p, size, key=key)
# @wraps(np.random.multinomial)
def multinomial(n, pvals, size=None, key=None):
return DEFAULT.multinomial(n, pvals, size, key=key)
# @wraps(np.random.multivariate_normal)
def multivariate_normal(mean, cov, size=None, method: str = 'cholesky', key=None):
return DEFAULT.multivariate_normal(mean, cov, size, method, key=key)
# @wraps(np.random.negative_binomial)
def negative_binomial(n, p, size=None, key=None):
return DEFAULT.negative_binomial(n, p, size, key=key)
# @wraps(np.random.noncentral_chisquare)
def noncentral_chisquare(df, nonc, size=None, key=None):
return DEFAULT.noncentral_chisquare(df, nonc, size, key=key)
# @wraps(np.random.noncentral_f)
def noncentral_f(dfnum, dfden, nonc, size=None, key=None):
return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key)
# @wraps(np.random.power)
def power(a, size=None, key=None):
return DEFAULT.power(a, size, key=key)
# @wraps(np.random.rayleigh)
def rayleigh(scale=1.0, size=None, key=None):
return DEFAULT.rayleigh(scale, size, key=key)
# @wraps(np.random.triangular)
def triangular(size=None, key=None):
return DEFAULT.triangular(size, key=key)
# @wraps(np.random.vonmises)
def vonmises(mu, kappa, size=None, key=None):
return DEFAULT.vonmises(mu, kappa, size, key=key)
# @wraps(np.random.wald)
def wald(mean, scale, size=None, key=None):
return DEFAULT.wald(mean, scale, size, key=key)
def weibull(a, size=None, key=None):
r"""
Draw samples from a Weibull distribution.
Draw samples from a 1-parameter Weibull distribution with the given
shape parameter `a`.
.. math:: X = (-\ln(U))^{1/a}
Here, U is drawn from the uniform distribution over (0,1].
The more common 2-parameter Weibull, including a scale parameter
:math:`\lambda` is just :math:`X = \lambda(-\ln(U))^{1/a}`.
.. note::
New code should use the ``weibull`` method of a ``default_rng()``
instance instead; please see the :ref:`random-quick-start`.
Parameters
----------
a : float or array_like of floats
Shape parameter of the distribution. Must be nonnegative.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``a`` is a scalar. Otherwise,
``np.array(a).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized Weibull distribution.
See Also
--------
scipy.stats.weibull_max
scipy.stats.weibull_min
scipy.stats.genextreme
gumbel
random.Generator.weibull: which should be used for new code.
Notes
-----
The Weibull (or Type III asymptotic extreme value distribution
for smallest values, SEV Type III, or Rosin-Rammler
distribution) is one of a class of Generalized Extreme Value
(GEV) distributions used in modeling extreme value problems.
This class includes the Gumbel and Frechet distributions.
The probability density for the Weibull distribution is
.. math:: p(x) = \frac{a}
{\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a},
where :math:`a` is the shape and :math:`\lambda` the scale.
The function has its peak (the mode) at
:math:`\lambda(\frac{a-1}{a})^{1/a}`.
When ``a = 1``, the Weibull distribution reduces to the exponential
distribution.
References
----------
.. [1] Waloddi Weibull, Royal Technical University, Stockholm,
1939 "A Statistical Theory Of The Strength Of Materials",
Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939,
Generalstabens Litografiska Anstalts Forlag, Stockholm.
.. [2] Waloddi Weibull, "A Statistical Distribution Function of
Wide Applicability", Journal Of Applied Mechanics ASME Paper
1951.
.. [3] Wikipedia, "Weibull distribution",
https://en.wikipedia.org/wiki/Weibull_distribution
Examples
--------
Draw samples from the distribution:
>>> a = 5. # shape
>>> s = brainpy.math.random.weibull(a, 1000)
Display the histogram of the samples, along with
the probability density function:
>>> import matplotlib.pyplot as plt
>>> x = np.arange(1,100.)/50.
>>> def weib(x,n,a):
... return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a)
>>> count, bins, ignored = plt.hist(brainpy.math.random.weibull(5.,1000))
>>> x = np.arange(1,100.)/50.
>>> scale = count.max()/weib(x, 1., 5.).max()
>>> plt.plot(x, weib(x, 1., 5.)*scale)
>>> plt.show()
"""
return DEFAULT.weibull(a, size, key=key)
def weibull_min(a, scale=None, size=None, key=None):
"""Sample from a Weibull minimum distribution.
The scipy counterpart is `scipy.stats.weibull_min`.
Args:
a: The concentration (shape) parameter of the distribution.
scale: The scale parameter of the distribution.
size: The shape of the returned samples.
key: a PRNG key or a seed.
Returns:
A jnp.array of samples.
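Examples:
A minimal sketch (the parameter values are illustrative; samples are
random, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.weibull_min(2.0, scale=1.5, size=(1000,))
>>> s.shape
(1000,)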
"""
return DEFAULT.weibull_min(a, scale, size, key=key)
def zipf(a, size=None, key=None):
r"""
Draw samples from a Zipf distribution.
Samples are drawn from a Zipf distribution with specified parameter
`a` > 1.
The Zipf distribution (also known as the zeta distribution) is a
discrete probability distribution that satisfies Zipf's law: the
frequency of an item is inversely proportional to its rank in a
frequency table.
.. note::
New code should use the ``zipf`` method of a ``default_rng()``
instance instead; please see the :ref:`random-quick-start`.
Parameters
----------
a : float or array_like of floats
Distribution parameter. Must be greater than 1.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``a`` is a scalar. Otherwise,
``np.array(a).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized Zipf distribution.
See Also
--------
scipy.stats.zipf : probability density function, distribution, or
cumulative density function, etc.
random.Generator.zipf: which should be used for new code.
Notes
-----
The probability density for the Zipf distribution is
.. math:: p(k) = \frac{k^{-a}}{\zeta(a)},
for integers :math:`k \geq 1`, where :math:`\zeta` is the Riemann Zeta
function.
It is named for the American linguist George Kingsley Zipf, who noted
that the frequency of any word in a sample of a language is inversely
proportional to its rank in the frequency table.
References
----------
.. [1] Zipf, G. K., "Selected Studies of the Principle of Relative
Frequency in Language," Cambridge, MA: Harvard Univ. Press,
1932.
Examples
--------
Draw samples from the distribution:
>>> a = 4.0
>>> n = 20000
>>> s = brainpy.math.random.zipf(a, n)
Display the histogram of the samples, along with
the expected histogram based on the probability
density function:
>>> import matplotlib.pyplot as plt
>>> from scipy.special import zeta # doctest: +SKIP
`bincount` provides a fast histogram for small integers.
>>> count = np.bincount(s)
>>> k = np.arange(1, s.max() + 1)
>>> plt.bar(k, count[1:], alpha=0.5, label='sample count')
>>> plt.plot(k, n*(k**-a)/zeta(a), 'k.-', alpha=0.5,
... label='expected count') # doctest: +SKIP
>>> plt.semilogy()
>>> plt.grid(alpha=0.4)
>>> plt.legend()
>>> plt.title(f'Zipf sample, a={a}, size={n}')
>>> plt.show()
"""
return DEFAULT.zipf(a, size, key=key)
def maxwell(size=None, key=None):
"""Sample from a one-sided Maxwell distribution.
The scipy counterpart is `scipy.stats.maxwell`.
Args:
size: The shape of the returned samples.
key: a PRNG key.
Returns:
A jnp.array of samples, of shape `size`.
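Examples:
A minimal sketch (the requested shape is illustrative; samples are
random, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.maxwell(size=(3, 3))
>>> s.shape
(3, 3)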
"""
return DEFAULT.maxwell(size, key=key)
def t(df, size=None, key=None):
"""Sample Student’s t random values.
Parameters
----------
df: float, array_like
A float or array of floats, broadcast-compatible with ``size``, representing the degrees of freedom of the distribution.
size: optional, int, tuple of int
A tuple of non-negative integers specifying the result shape.
Must be broadcast-compatible with `df`. The default (None) produces a result shape equal to `df.shape`.
Returns
-------
out: array_like
The sampled value.
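Examples
--------
A minimal sketch (the degrees of freedom below are illustrative; values
are random, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.t(3.0, size=(2, 5))
>>> s.shape
(2, 5)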
"""
return DEFAULT.t(df, size, key=key)
def orthogonal(n: int, size=None, key=None):
"""Sample uniformly from the orthogonal group `O(n)`.
Parameters
----------
n: int
An integer indicating the resulting dimension.
size: optional, int, tuple of int
The batch dimensions of the result.
Returns
-------
out: Array
The sampled results.
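Examples
--------
A minimal sketch; with batch dimensions ``size=(2,)`` the result stacks
two 3x3 orthogonal matrices (the matrices themselves are random):
>>> import brainpy.math as bm
>>> q = bm.random.orthogonal(3, size=(2,))
>>> q.shape
(2, 3, 3)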
"""
return DEFAULT.orthogonal(n, size, key=key)
def loggamma(a, size=None, key=None):
"""Sample log-gamma random values.
Parameters
----------
a: float, array_like
A float or array of floats, broadcast-compatible with ``size``, representing the shape parameter of the distribution.
size: optional, int, tuple of int
A tuple of nonnegative integers specifying the result shape.
Must be broadcast-compatible with `a`. The default (None) produces a result shape equal to `a.shape`.
Returns
-------
out: array_like
The sampled results.
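Examples
--------
A minimal sketch (the parameter below is illustrative; values are
random, so only the result shape is checked):
>>> import brainpy.math as bm
>>> s = bm.random.loggamma(2.0, size=(4,))
>>> s.shape
(4,)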
"""
return DEFAULT.loggamma(a, size, key=key)
def categorical(logits, axis: int = -1, size=None, key=None):
"""Sample random values from categorical distributions.
Args:
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
so that `softmax(logits, axis)` gives the corresponding probabilities.
axis: Axis along which logits belong to the same categorical distribution.
size: Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
key: a PRNG key used as the random key.
Returns:
A random array with int dtype and shape given by ``size`` if ``size``
is not None, or else ``np.delete(logits.shape, axis)``.
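Examples:
A minimal sketch (the probabilities behind the logits are illustrative;
sampled indices are random, so only the result shape is checked):
>>> import numpy as np
>>> import brainpy.math as bm
>>> logits = np.log(np.array([0.1, 0.3, 0.6]))
>>> s = bm.random.categorical(logits, size=(5,))
>>> s.shape
(5,)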
"""
return DEFAULT.categorical(logits, axis, size, key=key)
def rand_like(input, *, dtype=None, key=None):
"""Similar to ``rand_like`` in torch.
Returns a tensor with the same size as input that is filled with random
numbers from a uniform distribution on the interval ``[0, 1)``.
Args:
input: the ``size`` of input will determine size of the output tensor.
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
key: the seed or key for the random.
Returns:
The random data.
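Examples:
A minimal sketch (the input array is illustrative; the output shape
simply mirrors it):
>>> import brainpy.math as bm
>>> x = bm.zeros((2, 3))
>>> y = bm.random.rand_like(x)
>>> y.shape
(2, 3)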
"""
return DEFAULT.rand_like(input, dtype=dtype, key=key)
def randn_like(input, *, dtype=None, key=None):
"""Similar to ``randn_like`` in torch.
Returns a tensor with the same size as ``input`` that is filled with
random numbers from a normal distribution with mean 0 and variance 1.
Args:
input: the ``size`` of input will determine size of the output tensor.
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
key: the seed or key for the random.
Returns:
The random data.
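Examples:
A minimal sketch (the input array is illustrative; the output shape
simply mirrors it):
>>> import brainpy.math as bm
>>> x = bm.ones((4,))
>>> y = bm.random.randn_like(x)
>>> y.shape
(4,)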
"""
return DEFAULT.randn_like(input, dtype=dtype, key=key)
def randint_like(input, low=0, high=None, *, dtype=None, key=None):
"""Similar to ``randint_like`` in torch.
Returns a tensor with the same shape as Tensor ``input`` filled with
random integers generated uniformly between ``low`` (inclusive) and ``high`` (exclusive).
Args:
input: the ``size`` of input will determine size of the output tensor.
low: Lowest integer to be drawn from the distribution. Default: 0.
high: One above the highest integer to be drawn from the distribution.
dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
key: the seed or key for the random.
Returns:
The random data.
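Examples:
A minimal sketch (the input array, the bounds, and the ``bm.int32``
dtype alias are illustrative assumptions; only the result shape is
checked):
>>> import brainpy.math as bm
>>> x = bm.zeros((2, 2), dtype=bm.int32)
>>> y = bm.random.randint_like(x, low=0, high=10)
>>> y.shape
(2, 2)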
"""
return DEFAULT.randint_like(input=input, low=low, high=high, dtype=dtype, key=key)
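# Propagate the docstrings of the module-level wrappers above onto the
# matching ``RandomState`` methods that do not have docstrings of their own.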
for __k in dir(RandomState):
    __t = getattr(RandomState, __k)
    if not __k.startswith('__') and callable(__t) and (not __t.__doc__):
        __r = globals().get(__k, None)
        if __r is not None and callable(__r):
            __t.__doc__ = __r.__doc__