# -*- coding: utf-8 -*-
import warnings
from collections import namedtuple
from functools import partial
from operator import index
from typing import Optional, Union, Sequence, Any
import jax
import numpy as np
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
from jax._src.array import ArrayImpl
from jax._src.core import _canonicalize_dimension, _invalid_shape_error
from jax._src.typing import Shape
from jax.tree_util import register_pytree_node_class
from brainpy.check import jit_error_checking, jit_error_checking_no_args
from .compat_numpy import shape
from .environment import get_int
from .ndarray import Array, _return
from .object_transform.variables import Variable
# Names exported by ``from <module> import *``. Several of them
# (e.g. 'Generator', 'DEFAULT', the module-level sampling functions) are not
# defined in this part of the file.
__all__ = [
    'RandomState', 'Generator', 'DEFAULT',
    'seed', 'default_rng', 'split_key', 'split_keys',
    # numpy compatibility
    'rand', 'randint', 'random_integers', 'randn', 'random',
    'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta',
    'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto',
    'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma',
    'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli',
    'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f',
    'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
    'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
    'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
    'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical', 'canonicalize_shape',
    # pytorch compatibility
    'rand_like', 'randint_like', 'randn_like',
]
# Alias used in the ``key`` type annotations of the sampling methods.
JAX_RAND_KEY = jax.Array
def _formalize_key(key):
if isinstance(key, int):
return jr.PRNGKey(key)
elif isinstance(key, (Array, jnp.ndarray, np.ndarray)):
if key.dtype != jnp.uint32:
raise TypeError('key must be a int or an array with two uint32.')
if key.size != 2:
raise TypeError('key must be a int or an array with two uint32.')
return jnp.asarray(key)
else:
raise TypeError('key must be a int or an array with two uint32.')
def _size2shape(size):
if size is None:
return ()
elif isinstance(size, (tuple, list)):
return tuple(size)
else:
return (size,)
def _check_shape(name, shape, *param_shapes):
if param_shapes:
shape_ = lax.broadcast_shapes(shape, *param_shapes)
if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
raise ValueError(msg.format(name, shape_, shape))
def _as_jax_array(a):
    """Return ``a.value`` for an ``Array`` instance; return ``a`` unchanged otherwise."""
    if isinstance(a, Array):
        return a.value
    return a
def _is_python_scalar(x):
if hasattr(x, 'aval'):
return x.aval.weak_type
elif np.ndim(x) == 0:
return True
elif isinstance(x, (bool, int, float, complex)):
return True
else:
return False
# Maps each built-in Python scalar type to the NumPy dtype used for it by
# ``_dtype`` (before any optional canonicalization).
python_scalar_dtypes = {
    bool: np.dtype('bool'),
    int: np.dtype('int64'),
    float: np.dtype('float64'),
    complex: np.dtype('complex128'),
}
def _dtype(x, *, canonicalize: bool = False):
    """Return the dtype object for a value or type, optionally canonicalized.

    Lookup order: Python scalar types (or instances of them) via
    ``python_scalar_dtypes``, then an opaque ``x.dtype``, then
    ``np.result_type(x)``.

    Raises
    ------
    ValueError
        If ``x`` is ``None``.
    """
    if x is None:
        raise ValueError(f"Invalid argument to dtype: {x}.")

    value_type = type(x)
    if isinstance(x, type) and x in python_scalar_dtypes:
        resolved = python_scalar_dtypes[x]
    elif value_type in python_scalar_dtypes:
        resolved = python_scalar_dtypes[value_type]
    elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
        resolved = x.dtype
    else:
        resolved = np.result_type(x)

    if canonicalize:
        return dtypes.canonicalize_dtype(resolved)
    return resolved
def _const(example, val):
    """Build the constant ``val`` with a dtype matching ``example``.

    For array-like examples the result is ``np.array(val, dtype)`` with the
    canonicalized ``example.dtype``. For Python-scalar examples ``val`` is
    first cast with the example's scalar type and only wrapped into a NumPy
    array when its canonical dtype differs from the target dtype.
    """
    if not _is_python_scalar(example):
        return np.array(val, dtypes.canonicalize_dtype(example.dtype))

    target = dtypes.canonicalize_dtype(type(example))
    scalar = dtypes.scalar_type_of(example)(val)
    if target == _dtype(scalar, canonicalize=True):
        return scalar
    return np.array(scalar, target)
# Container for the pre-computed constants of the BTRS binomial sampler;
# instances are built by ``_get_tr_params`` and read by ``_binomial_btrs``.
_tr_params = namedtuple(
    "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
)
def _get_tr_params(n, p):
    """Pre-compute the constants used by the BTRS binomial sampler.

    Parameters
    ----------
    n : array
        Number of trials; its ``dtype`` is used to cast the value ``m``.
    p : array or float
        Success probability.

    Returns
    -------
    _tr_params
        Named tuple with fields ``c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h``.
    """
    # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
    # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
    mu = n * p
    spq = jnp.sqrt(mu * (1 - p))
    c = mu + 0.5
    b = 1.15 + 2.53 * spq
    a = -0.0873 + 0.0248 * b + 0.01 * p
    alpha = (2.83 + 5.1 / b) * spq
    u_r = 0.43
    v_r = 0.92 - 4.2 / b
    # floor((n + 1) * p), cast back to the dtype of ``n``
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)
    log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
             _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
def _stirling_approx_tail(k):
precomputed = jnp.array([0.08106146679532726,
0.04134069595540929,
0.02767792568499834,
0.02079067210376509,
0.01664469118982119,
0.01387612882307075,
0.01189670994589177,
0.01041126526197209,
0.009255462182712733,
0.008330563433362871, ])
kp1 = k + 1
kp1sq = (k + 1) ** 2
return jnp.where(k < 10,
precomputed[k],
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
def _binomial_btrs(key, p, n):
    """
    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:
    Hormann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    Parameters
    ----------
    key : PRNG key, split on every loop iteration.
    p : success probability.
    n : number of trials; ``n.dtype`` is used to cast the candidate ``k``.

    Returns
    -------
    The candidate ``k`` held in the loop state when the loop terminates.
    """

    def _btrs_body_fn(val):
        # Draw a fresh candidate: only the key is carried over from the
        # previous state, everything else is regenerated.
        _, key, _, _ = val
        key, key_u, key_v = jr.split(key, 3)
        u = jr.uniform(key_u)
        v = jr.uniform(key_v)
        u = u - 0.5
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        # Loop condition: True means "keep sampling" (candidate rejected).
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha
            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (
                (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
                + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
                + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
                + tr_params.log_h
            )
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # Five-argument ``lax.cond`` call: (pred, true_operand, true_fun,
        # false_operand, false_fun). When a quick decision is available the
        # loop continues only if it was not an early accept; otherwise the
        # full acceptance test decides.
        return lax.cond(
            early_accept | early_reject,
            (),
            lambda _: ~early_accept,
            (k, u, v),
            lambda x: ~accept_fn(*x),
        )

    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(
        _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
    )  # use k=-1 initially so that cond_fn returns True
    return ret[0]
def _binomial_inversion(key, p, n):
def _binom_inv_body_fn(val):
i, key, geom_acc = val
key, key_u = jr.split(key)
u = jr.uniform(key_u)
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
geom_acc = geom_acc + geom
return i + 1, key, geom_acc
def _binom_inv_cond_fn(val):
i, _, geom_acc = val
return geom_acc <= n
log1_p = jnp.log1p(-p)
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
return ret[0]
def _binomial_dispatch(key, p, n):
    """Draw one binomial sample, choosing the algorithm from ``n`` and ``p``.

    For ``p`` strictly between 0 and 1 (finite) and ``n > 0`` the inner
    ``dispatch`` is used; otherwise the result is ``n`` when ``p`` is finite
    and positive with ``n > 0`` (i.e. ``p >= 1``), and ``0`` in all other cases.
    """

    def dispatch(key, p, n):
        # Sample with min(p, 1 - p) and mirror the result afterwards.
        is_le_mid = p <= 0.5
        pq = jnp.where(is_le_mid, p, 1 - p)
        mu = n * pq
        # Five-argument ``lax.cond`` call: inversion for small means,
        # BTRS rejection sampling otherwise.
        k = lax.cond(
            mu < 10,
            (key, pq, n),
            lambda x: _binomial_inversion(*x),
            (key, pq, n),
            lambda x: _binomial_btrs(*x),
        )
        return jnp.where(is_le_mid, k, n - k)

    # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(
        cond0 & (p < 1),
        (key, p, n),
        lambda x: dispatch(*x),
        (),
        lambda _: jnp.where(cond0, n, 0),
    )
@partial(jit, static_argnums=(3,))
def _binomial(key, p, n, shape):
    """Draw binomial samples element-wise over broadcast ``p`` and ``n``.

    ``shape`` is static; when falsy, the broadcast shape of ``p`` and ``n`` is
    used. Inputs are flattened, one key is created per element, and
    ``_binomial_dispatch`` is applied with ``lax.map`` on the CPU backend or
    ``vmap`` elsewhere. The result is reshaped to ``shape``.
    """
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # flatten so the sampler can be mapped over axis 0
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    keys = jr.split(key, jnp.size(flat_p))
    if jax.default_backend() == "cpu":
        flat = lax.map(lambda args: _binomial_dispatch(*args), (keys, flat_p, flat_n))
    else:
        flat = vmap(lambda *args: _binomial_dispatch(*args))(keys, flat_p, flat_n)
    return jnp.reshape(flat, shape)
@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
# this implementation is fast when event shape is small, and slow otherwise
# Ref: https://stackoverflow.com/a/34190035
shape = shape or p.shape[:-1]
s = jnp.cumsum(p, axis=-1)
r = jr.uniform(key, shape=shape + (1,))
return jnp.sum(s < r, axis=-1)
def _scatter_add_one(operand, indices, updates):
return lax.scatter_add(
operand,
indices,
updates,
lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
),
)
def _reshape(x, shape):
if isinstance(x, (int, float, np.ndarray, np.generic)):
return np.reshape(x, shape)
else:
return jnp.reshape(x, shape)
def _promote_shapes(*args, shape=()):
    """Left-pad the shapes of ``args`` with ones up to a common rank.

    The target rank is that of the broadcast of ``shape`` with all argument
    shapes. With fewer than two arguments and an empty ``shape`` the arguments
    are returned untouched (as the original tuple); otherwise a list is returned.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args

    arg_shapes = [jnp.shape(arg) for arg in args]
    rank = len(lax.broadcast_shapes(shape, *arg_shapes))
    promoted = []
    for arg, arg_shape in zip(args, arg_shapes):
        missing = rank - len(arg_shape)
        if missing > 0:
            promoted.append(_reshape(arg, (1,) * missing + arg_shape))
        else:
            promoted.append(arg)
    return promoted
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by summing categorical draws.

    Parameters
    ----------
    key : PRNG key forwarded to ``_categorical``.
    p : probabilities; the last axis indexes the categories.
    n : number of trials, broadcast against ``p.shape[:-1]``.
    n_max : static int, the number of categorical draws taken per batch entry.
    shape : static batch shape; defaults to ``p.shape[:-1]`` when falsy.

    Returns
    -------
    Array of shape ``shape + p.shape[-1:]`` holding per-category counts.
    """
    # Bring ``n`` and the batch part of ``p`` to a common shape.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = _categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws are forced to index 0 below, so category 0 is
        # over-counted by ``n_max - n``; ``excess`` removes that surplus.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
                                  jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
                                 -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    # One scatter-add per batch row: add 1 at every drawn category index.
    samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
                                        jnp.expand_dims(indices_2D, axis=-1),
                                        jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
@partial(jit, static_argnums=(2, 3))
def _von_mises_centered(key, concentration, shape, dtype=jnp.float64):
    """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

    Parameters
    ----------
    key : PRNG key, split on every loop iteration.
    concentration : concentration parameter, cast to ``dtype`` and broadcast to ``shape``.
    shape : static output shape; defaults to the shape of ``concentration`` when falsy.
    dtype : static dtype, passed through ``jnp.result_type``.

    Returns
    -------
    out: array_like
        centered samples from von Mises

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
           Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    """
    shape = shape or jnp.shape(concentration)
    dtype = jnp.result_type(dtype)
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)
    # Per-dtype threshold below which ``s`` is approximated by 1 / concentration.
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = s_cutoff_map.get(dtype)
    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)
    s_approximate = 1.0 / concentration
    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        # One rejection round: propose ``w`` for entries not yet accepted.
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)
    # Loop state: (iteration, key, done mask, last u, accepted w).
    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )
    return jnp.sign(u) * jnp.arccos(w)
def _loc_scale(loc, scale, value):
if loc is None:
if scale is None:
return value
else:
return value * scale
else:
if scale is None:
return value + loc
else:
return value * scale + loc
def _check_py_seq(seq):
return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
def canonicalize_shape(shape: Shape, context: str = "") -> tuple[Any, ...]:
    """Canonicalize a user-provided shape and report invalid ones.

    Args:
      shape: a Python value that represents a shape.
      context: extra text forwarded to the error raised for an invalid shape.

    Returns:
      A tuple of canonical dimension values.
    """
    dims = None
    try:
        dims = tuple(map(_canonicalize_dimension, shape))
    except TypeError:
        pass
    if dims is None:
        raise _invalid_shape_error(shape, context)
    return dims
[docs]
@register_pytree_node_class
class RandomState(Variable):
"""RandomState that track the random generator state. """
__slots__ = ()
def __init__(
    self,
    seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None,
    seed: Optional[int] = None,
    ready_to_trace: bool = True,
):
    """RandomState constructor.

    Parameters
    ----------
    seed_or_key: int, Array, optional
      It can be an integer for initial seed of the random number generator,
      or it can be a JAX's PRNGKey, which is an array with two elements and `uint32` dtype.

      .. versionadded:: 2.2.3.4

    seed : int, ArrayType, optional
      Same as `seed_or_key`.

      .. deprecated:: 2.2.3.4
         Will be removed since version 2.4.
    ready_to_trace: bool
      Forwarded unchanged to the parent constructor.

    Raises
    ------
    ValueError
      If both `seed_or_key` and `seed` are given, or if a non-integer key
      does not have exactly two elements of dtype `uint32`.
    """
    if seed is not None:
        if seed_or_key is not None:
            raise ValueError('Please set "seed_or_key" or "seed", not both.')
        seed_or_key = seed
        warnings.warn('Please use `seed_or_key` instead. '
                      'seed will be removed since 2.4.0', UserWarning)

    with jax.ensure_compile_time_eval():
        if seed_or_key is None:
            seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
    if isinstance(seed_or_key, int):
        key = jr.PRNGKey(seed_or_key)
    else:
        # A key is invalid when EITHER its length or its dtype is wrong
        # (the previous `and` let half-valid keys through).
        if len(seed_or_key) != 2 or getattr(seed_or_key, 'dtype', None) != np.uint32:
            raise ValueError('key must be an array with dtype uint32. '
                             f'But we got {seed_or_key}')
        key = seed_or_key
    super(RandomState, self).__init__(key, ready_to_trace=ready_to_trace)
def __repr__(self) -> str:
    """Render as ``<ClassName>(key=...)`` using the repr of the current key
    from its first ``(`` onwards."""
    key_text = repr(self.value)
    start = key_text.index('(')
    return f'{self.__class__.__name__}(key={key_text[start:]})'
@property
def value(self):
    """Current key. A deleted ``ArrayImpl`` key is replaced by re-seeding first."""
    stored = self._value
    if isinstance(stored, ArrayImpl) and stored.is_deleted():
        self.seed()
    self._append_to_stack()
    return self._value
# ------------------- #
# seed and random key #
# ------------------- #
[docs]
def clone(self):
    """Create a new instance of the same class seeded with a key split off this one."""
    cls = type(self)
    return cls(self.split_key())
[docs]
def seed(self, seed_or_key=None, seed=None):
    """Sets a new random seed.

    Parameters
    ----------
    seed_or_key: int, ArrayType, optional
      It can be an integer for initial seed of the random number generator,
      or it can be a JAX's PRNGKey, which is an array with two elements and `uint32` dtype.
      When omitted, two random `uint32` values are drawn with NumPy.

      .. versionadded:: 2.2.3.4

    seed : int, ArrayType, optional
      Same as `seed_or_key`.

      .. deprecated:: 2.2.3.4
         Will be removed since version 2.4.

    Raises
    ------
    ValueError
      If both arguments are given, or if a non-integer key does not have
      exactly two elements of dtype `uint32`.
    """
    if seed is not None:
        if seed_or_key is not None:
            raise ValueError('Please set "seed_or_key" or "seed", not both.')
        seed_or_key = seed
        warnings.warn('Please use seed_or_key instead. '
                      'seed will be removed since 2.4.0', UserWarning)

    if seed_or_key is None:
        seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
    if isinstance(seed_or_key, int):
        key = jr.PRNGKey(seed_or_key)
    else:
        # Reject the key when EITHER the length or the dtype is wrong
        # (the previous `and` only rejected keys failing both checks).
        if len(seed_or_key) != 2 or getattr(seed_or_key, 'dtype', None) != np.uint32:
            raise ValueError('key must be an array with dtype uint32. '
                             f'But we got {seed_or_key}')
        key = seed_or_key
    self._value = key
[docs]
def split_key(self):
    """Split the current key in two: keep the first half, return the second.

    A key that is not a ``jnp.ndarray`` is converted with ``jnp.asarray`` first.
    """
    if not isinstance(self.value, jnp.ndarray):
        self._value = jnp.asarray(self.value)
    pair = jr.split(self.value, num=2)
    self._value = pair[0]
    return pair[1]
[docs]
def split_keys(self, n):
    """Create multiple seeds from the current seed. This is used
    internally by `pmap` and `vmap` to ensure that random numbers
    are different in parallel threads.

    The current key is split into ``n + 1`` keys; the first one becomes the
    new state and the remaining ``n`` are returned.

    Parameters
    ----------
    n : int
      The number of seeds to generate.
    """
    all_keys = jr.split(self.value, n + 1)
    self._value = all_keys[0]
    return all_keys[1:]
# ---------------- #
# random functions #
# ---------------- #
[docs]
def rand(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Uniform samples of shape ``dn`` with ``minval=0`` and ``maxval=1``.

    ``key`` overrides the internal state when given; otherwise a key is split
    off the internal state.
    """
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.uniform(key, shape=dn, minval=0., maxval=1.))
[docs]
def randint(self,
            low,
            high=None,
            size: Optional[Union[int, Sequence[int]]] = None,
            dtype=int, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Random integers drawn with ``jr.randint`` between ``low`` and ``high``.

    When ``high`` is omitted the range is ``0`` to ``low``. ``size`` defaults
    to the broadcast shape of the bounds; a ``None`` ``dtype`` falls back to
    ``get_int()``.
    """
    if dtype is None:
        dtype = get_int()
    low, high = _as_jax_array(low), _as_jax_array(high)
    if high is None:
        low, high = 0, low
    high = _check_py_seq(high)
    low = _check_py_seq(low)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    samples = jr.randint(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
    return _return(samples)
[docs]
def random_integers(self,
                    low,
                    high=None,
                    size: Optional[Union[int, Sequence[int]]] = None,
                    key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Random integers between ``low`` and ``high``, with ``high`` included.

    When ``high`` is omitted the range is ``1`` to ``low``. ``size`` defaults
    to the broadcast shape of the bounds. The inputs are never modified.
    """
    low = _as_jax_array(low)
    high = _as_jax_array(high)
    low = _check_py_seq(low)
    high = _check_py_seq(high)
    if high is None:
        high = low
        low = 1
    # Build a new value instead of `high += 1`: the in-place form would
    # mutate a NumPy array owned by the caller.
    high = high + 1
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.randint(key,
                   shape=_size2shape(size),
                   minval=low,
                   maxval=high)
    return _return(r)
[docs]
def randn(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Standard-normal samples of shape ``dn`` drawn with ``jr.normal``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.normal(key, shape=dn))
[docs]
def random(self,
           size: Optional[Union[int, Sequence[int]]] = None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Uniform samples with ``minval=0`` and ``maxval=1`` of shape ``size``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    samples = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.)
    return _return(samples)
[docs]
def random_sample(self,
                  size: Optional[Union[int, Sequence[int]]] = None,
                  key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Delegate to :meth:`random` and pass the result through ``_return``."""
    return _return(self.random(size=size, key=key))
[docs]
def ranf(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Delegate to :meth:`random` and pass the result through ``_return``."""
    return _return(self.random(size=size, key=key))
[docs]
def sample(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Delegate to :meth:`random` and pass the result through ``_return``."""
    return _return(self.random(size=size, key=key))
[docs]
def choice(self, a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample elements from ``a`` with ``jr.choice``.

    ``replace`` and the optional probabilities ``p`` are forwarded unchanged.
    """
    a = _check_py_seq(_as_jax_array(a)) if False else _as_jax_array(a)
    p = _as_jax_array(p)
    a = _check_py_seq(a)
    p = _check_py_seq(p)
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    picked = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
    return _return(picked)
[docs]
def permutation(self, x, axis: int = 0, independent: bool = False, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Return a permuted copy of ``x`` (or a permuted range) via ``jr.permutation``."""
    if isinstance(x, Array):
        x = x.value
    x = _check_py_seq(x)
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.permutation(key, x, axis=axis, independent=independent))
[docs]
def shuffle(self, x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Permute ``x`` in place along ``axis`` by overwriting ``x.value``.

    Raises
    ------
    TypeError
        If ``x`` is not an ``Array`` (in-place update is impossible otherwise).
    """
    if not isinstance(x, Array):
        raise TypeError('This numpy operator needs in-place updating, therefore '
                        'inputs should be brainpy Array.')
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    shuffled = jr.permutation(key, x.value, axis=axis)
    x.value = shuffled
[docs]
def beta(self, a, b,
         size: Optional[Union[int, Sequence[int]]] = None,
         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Beta samples drawn with ``jr.beta``.

    ``size`` defaults to the broadcast shape of ``a`` and ``b``.
    """
    if isinstance(a, Array):
        a = a.value
    if isinstance(b, Array):
        b = b.value
    a, b = _check_py_seq(a), _check_py_seq(b)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.beta(key, a=a, b=b, shape=_size2shape(size)))
[docs]
def exponential(self, scale=None,
                size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Exponential samples with an optional ``scale``.

    A standard exponential draw is multiplied by ``scale`` (NumPy's
    convention, ``scale = 1 / rate``), matching how ``scale`` is applied by
    the other samplers of this class. ``size`` defaults to the shape of
    ``scale``, or ``()`` when ``scale`` is omitted.
    """
    scale = _as_jax_array(scale)
    scale = _check_py_seq(scale)
    if size is None:
        # Avoid asking for the shape of ``None``.
        size = () if scale is None else jnp.shape(scale)
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.exponential(key, shape=_size2shape(size))
    if scale is not None:
        # scale * Exp(1): the previous division treated ``scale`` as a rate.
        r = r * scale
    return _return(r)
[docs]
def gamma(self, shape, scale=None,
          size: Optional[Union[int, Sequence[int]]] = None,
          key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Gamma samples: ``jr.gamma`` with ``a=shape``, multiplied by ``scale`` when given.

    ``size`` defaults to the broadcast shape of ``shape`` and ``scale``.
    """
    shape = _check_py_seq(_as_jax_array(shape)) if False else _as_jax_array(shape)
    scale = _as_jax_array(scale)
    shape = _check_py_seq(shape)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    samples = jr.gamma(key, a=shape, shape=_size2shape(size))
    if scale is None:
        return _return(samples)
    return _return(samples * scale)
[docs]
def gumbel(self, loc=None, scale=None,
           size: Optional[Union[int, Sequence[int]]] = None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Gumbel samples from ``jr.gumbel``, shifted and scaled by ``loc``/``scale``.

    ``size`` defaults to the broadcast shape of ``loc`` and ``scale``.
    """
    loc = _as_jax_array(loc)
    scale = _as_jax_array(scale)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    base = jr.gumbel(key, shape=_size2shape(size))
    return _return(_loc_scale(loc, scale, base))
[docs]
def laplace(self, loc=None, scale=None,
            size: Optional[Union[int, Sequence[int]]] = None,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Laplace samples from ``jr.laplace``, shifted and scaled by ``loc``/``scale``.

    ``size`` defaults to the broadcast shape of ``loc`` and ``scale``.
    """
    loc = _as_jax_array(loc)
    scale = _as_jax_array(scale)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    base = jr.laplace(key, shape=_size2shape(size))
    return _return(_loc_scale(loc, scale, base))
[docs]
def logistic(self, loc=None, scale=None,
             size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Logistic samples from ``jr.logistic``, shifted and scaled by ``loc``/``scale``.

    ``size`` defaults to the broadcast shape of ``loc`` and ``scale``.
    """
    loc = _as_jax_array(loc)
    scale = _as_jax_array(scale)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    base = jr.logistic(key, shape=_size2shape(size))
    return _return(_loc_scale(loc, scale, base))
[docs]
def normal(self, loc=None, scale=None,
           size: Optional[Union[int, Sequence[int]]] = None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Normal samples from ``jr.normal``, shifted and scaled by ``loc``/``scale``.

    ``size`` defaults to the broadcast shape of ``scale`` and ``loc``.
    """
    loc = _as_jax_array(loc)
    scale = _as_jax_array(scale)
    loc = _check_py_seq(loc)
    scale = _check_py_seq(scale)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    base = jr.normal(key, shape=_size2shape(size))
    return _return(_loc_scale(loc, scale, base))
[docs]
def pareto(self, a,
           size: Optional[Union[int, Sequence[int]]] = None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Pareto samples drawn with ``jr.pareto`` using ``b=a``.

    ``size`` defaults to the shape of ``a``.
    """
    a = _check_py_seq(_as_jax_array(a))
    if size is None:
        size = jnp.shape(a)
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.pareto(key, b=a, shape=_size2shape(size)))
[docs]
def poisson(self, lam=1.0,
            size: Optional[Union[int, Sequence[int]]] = None,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Poisson samples drawn with ``jr.poisson``.

    ``size`` defaults to the shape of ``lam``.
    """
    lam = _as_jax_array(lam)
    lam = _check_py_seq(lam)
    if size is None:
        size = jnp.shape(lam)
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.poisson(key, lam=lam, shape=_size2shape(size)))
[docs]
def standard_cauchy(self,
                    size: Optional[Union[int, Sequence[int]]] = None,
                    key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Samples of shape ``size`` drawn with ``jr.cauchy``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.cauchy(key, shape=_size2shape(size)))
[docs]
def standard_exponential(self,
                         size: Optional[Union[int, Sequence[int]]] = None,
                         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Samples of shape ``size`` drawn with ``jr.exponential``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.exponential(key, shape=_size2shape(size)))
[docs]
def standard_gamma(self,
                   shape,
                   size: Optional[Union[int, Sequence[int]]] = None,
                   key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Gamma samples drawn with ``jr.gamma`` using ``a=shape`` and no scaling.

    ``size`` defaults to the shape of the ``shape`` parameter.
    """
    shape = _check_py_seq(_as_jax_array(shape))
    if size is None:
        size = jnp.shape(shape)
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.gamma(key, a=shape, shape=_size2shape(size)))
[docs]
def standard_normal(self,
                    size: Optional[Union[int, Sequence[int]]] = None,
                    key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Samples of shape ``size`` drawn with ``jr.normal``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    return _return(jr.normal(key, shape=_size2shape(size)))
[docs]
def standard_t(self, df,
               size: Optional[Union[int, Sequence[int]]] = None,
               key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Student's t samples drawn with ``jr.t``.

    Parameters
    ----------
    df : degrees of freedom (scalar, sequence or array).
    size : output shape; defaults to the shape of ``df``.
    key : optional seed or PRNG key overriding the internal state.
    """
    df = _as_jax_array(df)
    df = _check_py_seq(df)
    if size is None:
        # Default to the shape of the parameter (previously this took the
        # shape of ``size`` itself, which is always ``None`` here).
        size = jnp.shape(df)
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.t(key, df=df, shape=_size2shape(size))
    return _return(r)
def __norm_cdf(self, x, sqrt2, dtype):
    """Standard normal CDF: ``(1 + erf(x / sqrt2)) / 2`` with constants of ``dtype``."""
    one = np.asarray(1., dtype)
    two = np.asarray(2., dtype)
    return (one + lax.erf(x / sqrt2)) / two
[docs]
def truncated_normal(self,
                     lower,
                     upper,
                     size: Optional[Union[int, Sequence[int]]] = None,
                     loc=0.,
                     scale=1.,
                     dtype=float,
                     key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Normal samples with mean ``loc`` and std ``scale`` restricted to ``lower``..``upper``.

    Parameters
    ----------
    lower, upper : truncation bounds.
    size : int or sequence of int, optional
        Output shape; defaults to the broadcast shape of all four parameters.
    loc, scale : mean and standard deviation before truncation.
    dtype : dtype all parameters are converted to.
    key : optional seed or PRNG key overriding the internal state.
    """
    lower = _check_py_seq(_as_jax_array(lower))
    upper = _check_py_seq(_as_jax_array(upper))
    loc = _check_py_seq(_as_jax_array(loc))
    scale = _check_py_seq(_as_jax_array(scale))

    lower = lax.convert_element_type(lower, dtype)
    upper = lax.convert_element_type(upper, dtype)
    loc = lax.convert_element_type(loc, dtype)
    scale = lax.convert_element_type(scale, dtype)

    jit_error_checking_no_args(
        jnp.any(jnp.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
        ValueError("mean is more than 2 std from [lower, upper] in truncated_normal. "
                   "The distribution of values may be incorrect.")
    )

    if size is None:
        size = lax.broadcast_shapes(jnp.shape(lower),
                                    jnp.shape(upper),
                                    jnp.shape(loc),
                                    jnp.shape(scale))
    # Normalize so that a plain int ``size`` is accepted like everywhere else.
    size = _size2shape(size)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    sqrt2 = np.array(np.sqrt(2), dtype)
    l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
    u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    key = self.split_key() if key is None else _formalize_key(key)
    out = jr.uniform(key, size, dtype,
                     minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
                     maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype)))

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    out = lax.erf_inv(out)

    # Transform to proper mean, std
    out = out * scale * sqrt2 + loc

    # Clamp to ensure it's in the proper range
    out = jnp.clip(out,
                   lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
                   lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
    return _return(out)
def _check_p(self, p):
raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
[docs]
def bernoulli(self, p, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Bernoulli samples drawn with ``jr.bernoulli``.

    Parameters
    ----------
    p : success probability; any value below 0 or above 1 is reported
        through ``jit_error_checking`` with ``_check_p``.
    size : output shape; defaults to the shape of ``p``.
    key : optional seed or PRNG key overriding the internal state.
    """
    p = _check_py_seq(_as_jax_array(p))
    # A value is invalid when it is below 0 OR above 1; the previous
    # ``logical_and`` could never be true, so the check never fired.
    jit_error_checking(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
    if size is None:
        size = jnp.shape(p)
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.bernoulli(key, p=p, shape=_size2shape(size))
    return _return(r)
[docs]
def lognormal(self, mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Log-normal samples: ``exp`` of normal draws shifted/scaled by ``mean``/``sigma``.

    ``size`` defaults to the broadcast shape of ``mean`` and ``sigma``.
    """
    mean = _as_jax_array(mean)
    mean = _check_py_seq(mean)
    sigma = _as_jax_array(sigma)
    sigma = _check_py_seq(sigma)
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(mean), jnp.shape(sigma))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    gaussian = jr.normal(key, shape=_size2shape(size))
    return _return(jnp.exp(_loc_scale(mean, sigma, gaussian)))
[docs]
def binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Binomial samples computed by the module-level ``_binomial`` helper.

    Parameters
    ----------
    n : number of trials.
    p : success probability; any value below 0 or above 1 is reported
        through ``jit_error_checking`` with ``_check_p``.
    size : output shape; defaults to the broadcast shape of ``n`` and ``p``.
    key : optional seed or PRNG key overriding the internal state.
    """
    n = _check_py_seq(n.value if isinstance(n, Array) else n)
    p = _check_py_seq(p.value if isinstance(p, Array) else p)
    # A value is invalid when it is below 0 OR above 1; the previous
    # ``logical_and`` could never be true, so the check never fired.
    jit_error_checking(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p)
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
    key = self.split_key() if key is None else _formalize_key(key)
    r = _binomial(key, p, n, shape=_size2shape(size))
    return _return(r)
[docs]
def chisquare(self, df, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Chi-square samples built as the sum of ``df`` squared normal draws.

    With ``size=None`` only a zero-dimensional ``df`` is supported and a
    single value is returned; otherwise the sum runs over the leading axis of
    a ``(df,) + size`` normal sample.

    Raises
    ------
    NotImplementedError
        If ``size`` is ``None`` and ``df`` is not zero-dimensional.
    """
    df = _check_py_seq(_as_jax_array(df))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    if size is not None:
        squares = jr.normal(key, (df,) + _size2shape(size)) ** 2
        return _return(squares.sum(axis=0))
    if jnp.ndim(df) != 0:
        raise NotImplementedError('Do not support non-scale "df" when "size" is None')
    squares = jr.normal(key, (df,)) ** 2
    return _return(squares.sum())
[docs]
def dirichlet(self, alpha, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Dirichlet samples drawn with ``jr.dirichlet`` for concentration ``alpha``."""
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    alpha = _as_jax_array(alpha)
    alpha = _check_py_seq(alpha)
    return _return(jr.dirichlet(key, alpha=alpha, shape=_size2shape(size)))
[docs]
def geometric(self, p, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Geometric samples computed as ``floor(log1p(-u) / log1p(-p))`` for uniform ``u``.

    Parameters
    ----------
    p : success probability.
    size : int or sequence of int, optional
        Output shape; defaults to the shape of ``p``.
    key : optional seed or PRNG key overriding the internal state.
    """
    p = _as_jax_array(p)
    p = _check_py_seq(p)
    if size is None:
        size = jnp.shape(p)
    key = self.split_key() if key is None else _formalize_key(key)
    # Normalize ``size`` so that a plain int is accepted like in the other samplers.
    u = jr.uniform(key, _size2shape(size))
    # NOTE(review): the smallest value produced here is 0, whereas NumPy's
    # ``geometric`` starts at 1 — confirm which convention is intended.
    r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
    return _return(r)
def _check_p2(self, p):
raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
[docs]
def multinomial(self, n, pvals, size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Multinomial counts computed by the module-level ``_multinomial`` helper.

    ``n`` must be concrete because its maximum is read on the host. The batch
    shape passed to the helper is the broadcast of ``pvals.shape[:-1]`` and
    ``n``'s shape, followed by ``size``.

    Raises
    ------
    ValueError
        If ``n`` is a JAX tracer.
    """
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    n = _as_jax_array(n)
    n = _check_py_seq(n)
    pvals = _as_jax_array(pvals)
    pvals = _check_py_seq(pvals)
    jit_error_checking(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
    if isinstance(n, jax.core.Tracer):
        raise ValueError("The total count parameter `n` should not be a jax abstract array.")
    extra_shape = _size2shape(size)
    n_max = int(np.max(jax.device_get(n)))
    batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
    counts = _multinomial(key, pvals, n, n_max, batch_shape + extra_shape)
    return _return(counts)
[docs]
def multivariate_normal(self, mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky',
                        key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Multivariate normal samples with the given ``mean`` and covariance ``cov``.

    The covariance is factorized with ``method`` ('svd', 'eigh' or
    'cholesky') and the factor is applied to standard normal draws.
    ``size`` is the batch shape; it defaults to the broadcast of the batch
    shapes of ``mean`` and ``cov`` and is otherwise checked against them.

    Raises
    ------
    ValueError
        For an unknown ``method``, ``mean.ndim < 1``, ``cov.ndim < 2``, a
        ``cov`` whose trailing shape is not ``(n, n)``, or an incompatible ``size``.
    """
    if method not in {'svd', 'eigh', 'cholesky'}:
        raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
    mean = _check_py_seq(_as_jax_array(mean))
    cov = _check_py_seq(_as_jax_array(cov))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)

    if jnp.ndim(mean) < 1:
        raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
    if jnp.ndim(cov) < 2:
        raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
    n = mean.shape[-1]
    if jnp.shape(cov)[-2:] != (n, n):
        raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
                         f"but got cov.shape == {jnp.shape(cov)}.")

    if size is None:
        size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    else:
        size = _size2shape(size)
        _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])

    if method == 'svd':
        left, singular, _ = jnp.linalg.svd(cov)
        factor = left * jnp.sqrt(singular[..., None, :])
    elif method == 'eigh':
        eigvals, eigvecs = jnp.linalg.eigh(cov)
        factor = eigvecs * jnp.sqrt(eigvals[..., None, :])
    else:  # 'cholesky'
        factor = jnp.linalg.cholesky(cov)

    standard = jr.normal(key, size + mean.shape[-1:])
    return _return(mean + jnp.einsum('...ij,...j->...i', factor, standard))
[docs]
def rayleigh(self, scale=1.0, size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a Rayleigh distribution by inverse-transform sampling.

    Parameters
    ----------
    scale : float, array_like
        Multiplier applied to the standard Rayleigh samples. Default 1.0.
    size : optional, int, tuple of int
        Shape of the uniform draw. When omitted, the shape of ``scale`` is used.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        ``scale * sqrt(-2 * log(1 - u))`` with ``u`` uniform on ``[0, 1)``.
    """
    scale = _check_py_seq(_as_jax_array(scale))
    if size is None:
        size = jnp.shape(scale)
    key = self.split_key() if key is None else _formalize_key(key)
    u = jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1)
    # `u` lies in [0, 1), so log(u) is -inf at u == 0 and would give an
    # infinite sample. 1 - u lies in (0, 1] and has the same distribution,
    # so log1p(-u) keeps every sample finite.
    x = jnp.sqrt(-2. * jnp.log1p(-u))
    r = x * scale
    return _return(r)
[docs]
def triangular(self, size: Optional[Union[int, Sequence[int]]] = None,
               key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample values that are either ``-1`` or ``+1`` with equal probability.

    NOTE(review): despite the name, the code maps fair Bernoulli draws to
    ``2 * b - 1``; it does not take ``left``/``mode``/``right`` parameters.
    Confirm this is intended.

    Parameters
    ----------
    size : optional, int, tuple of int
        Output shape, normalised with ``_size2shape``.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        The sampled values.
    """
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    coin_flips = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
    return _return(2 * coin_flips - 1)
[docs]
def vonmises(self, mu, kappa, size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a von Mises distribution.

    Parameters
    ----------
    mu : float, array_like
        Location added to the centered samples.
    kappa : float, array_like
        Concentration handed to ``_von_mises_centered``.
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of ``mu`` and
        ``kappa`` is used.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        Samples wrapped into the interval ``[-pi, pi)``.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    mu = _check_py_seq(_as_jax_array(mu))
    kappa = _check_py_seq(_as_jax_array(kappa))
    out_shape = _size2shape(
        lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa)) if size is None else size
    )
    centered = _von_mises_centered(key, kappa, out_shape)
    shifted = centered + mu
    # Wrap the shifted angles back into [-pi, pi).
    wrapped = (shifted + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
    return _return(wrapped)
[docs]
def weibull(self, a, size: Optional[Union[int, Sequence[int]]] = None,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a Weibull distribution by inverse-transform sampling.

    Parameters
    ----------
    a : float, array_like
        Shape (concentration) parameter. Must be a scalar when ``size`` is given.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``a`` is used.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        ``(-log(1 - u)) ** (1 / a)`` with ``u`` uniform on ``[0, 1)``.

    Raises
    ------
    ValueError
        If ``size`` is given and ``a`` has more than one element.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    a = _check_py_seq(_as_jax_array(a))
    if size is None:
        size = jnp.shape(a)
    elif jnp.size(a) > 1:
        raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
    out_shape = _size2shape(size)
    u = jr.uniform(key=key, shape=out_shape, minval=0, maxval=1)
    exponent = 1.0 / a
    return _return(jnp.power(-jnp.log1p(-u), exponent))
[docs]
def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample from a Weibull minimum distribution.

    Parameters
    ----------
    a: float, array_like
        The concentration parameter of the distribution.
    scale: optional, float, array_like
        The scale parameter of the distribution. ``None`` (the default) means
        no scaling, i.e. a scale of one.
    size: optional, int, tuple of int
        The output shape. When omitted, the broadcast shape of ``a`` and
        ``scale`` is used.
    key: optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out: array_like
        The sampling results, ``scale * (-log(1 - u)) ** (1 / a)`` with ``u``
        uniform on ``[0, 1)``.

    Raises
    ------
    ValueError
        If ``size`` is given and ``a`` has more than one element.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    a = _check_py_seq(_as_jax_array(a))
    # Only convert `scale` when one was supplied; the default `None` must not
    # be pushed through array helpers or `jnp.shape`.
    if scale is not None:
        scale = _check_py_seq(_as_jax_array(scale))
    if size is None:
        scale_shape = () if scale is None else jnp.shape(scale)
        size = jnp.broadcast_shapes(jnp.shape(a), scale_shape)
    elif jnp.size(a) > 1:
        raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
    size = _size2shape(size)
    random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1)
    r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
    if scale is not None:
        # A scale parameter stretches the distribution: X = scale * X_standard.
        r = r * scale
    return _return(r)
[docs]
def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples as the Euclidean norm of three independent standard normals.

    Parameters
    ----------
    size : optional, int, tuple of int
        Output shape.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        The norms, taken over a trailing axis of length 3.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    batch_shape = canonicalize_shape(_size2shape(size))
    # One extra trailing axis holds the three normal components.
    components = jr.normal(key=key, shape=batch_shape + (3,))
    return _return(jnp.linalg.norm(components, axis=-1))
[docs]
def negative_binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None,
                      key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw negative-binomial samples as a Gamma-Poisson mixture.

    A rate is drawn from ``gamma(shape=n, scale=(1 - p) / p)`` and then used as
    the mean of a Poisson draw.

    Parameters
    ----------
    n : float, array_like
        Shape parameter of the gamma draw.
    p : float, array_like
        Probability used to form the gamma scale.
    size : optional, int, tuple of int
        Shape of the gamma draw. When omitted, the broadcast shape of ``n``
        and ``p`` is used.
    key : optional, int or key array
        Seed or random key; it is split into one sub-key per draw. When
        omitted two keys are split from this generator.

    Returns
    -------
    out : array_like
        The Poisson samples.
    """
    n = _check_py_seq(_as_jax_array(n))
    p = _check_py_seq(_as_jax_array(p))
    out_shape = _size2shape(
        lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)) if size is None else size
    )
    log_odds = jnp.log(p) - jnp.log1p(-p)
    subkeys = self.split_keys(2) if key is None else jr.split(_formalize_key(key), 2)
    gamma_key = subkeys[0]
    poisson_key = subkeys[1]
    rate = self.gamma(shape=n, scale=jnp.exp(-log_odds), size=out_shape, key=gamma_key)
    counts = self.poisson(lam=rate, key=poisson_key)
    return _return(counts)
[docs]
def wald(self, mean, scale, size: Optional[Union[int, Sequence[int]]] = None,
         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a Wald (inverse Gaussian) distribution.

    Parameters
    ----------
    mean : float, array_like
        Mean of the distribution (``loc`` in the derivation below).
    scale : float, array_like
        Scale of the distribution (``conc`` in the derivation below).
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of ``mean`` and
        ``scale`` is used.
    key : optional, int or key array
        Seed or random key. It is split into two sub-keys, one for the normal
        draw and one for the uniform draw, so that an explicit key makes the
        result fully reproducible. When omitted two keys are split from this
        generator.

    Returns
    -------
    out : array_like
        The sampling results.
    """
    # Two independent random draws are needed below; give each its own key
    # instead of letting the normal draw fall back to the generator state.
    if key is None:
        keys = self.split_keys(2)
    else:
        keys = jr.split(_formalize_key(key), 2)
    mean = _check_py_seq(_as_jax_array(mean))
    scale = _check_py_seq(_as_jax_array(scale))
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
    size = _size2shape(size)
    sampled_chi2 = jnp.square(_as_jax_array(self.randn(*size, key=keys[0])))
    sampled_uniform = _as_jax_array(self.uniform(size=size, key=keys[1]))
    # Wikipedia defines an intermediate x with the formula
    #   x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
    # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
    # Let us write
    #   w = loc * y / (2 * conc)
    # Then we can extract the common factor in the last two terms to obtain
    #   x = loc + loc * w * (1 - sqrt(2 / w + 1))
    # Now we see that the Wikipedia formula suffers from catastrophic
    # cancellation for large w (e.g., if conc << loc).
    #
    # Fortunately, we can fix this by multiplying both sides
    # by 1 + sqrt(2 / w + 1).  We get
    #   x * (1 + sqrt(2 / w + 1)) =
    #     = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
    #     = loc * (sqrt(2 / w + 1) - 1)
    # The term sqrt(2 / w + 1) + 1 no longer presents numerical
    # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
    # sqrt1pm1(2 / w), which we know how to compute accurately.
    # This just leaves the matter of small w, where 2 / w may
    # overflow.  In the limit as w -> 0, x -> loc, so we just mask
    # that case.
    sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2)  # 2 / w above
    safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
    denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
    # expm1(0.5 * log1p(z)) == sqrt(1 + z) - 1, computed accurately.
    ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
    sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0)  # x above
    res = jnp.where(sampled_uniform <= mean / (mean + sampled),
                    sampled,
                    jnp.square(mean) / sampled)
    return _return(res)
[docs]
def t(self, df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a Student's t distribution.

    The sample is ``z * sqrt((df / 2) / g)`` with ``z`` standard normal and
    ``g`` drawn from ``gamma(df / 2)``.

    Parameters
    ----------
    df : float, array_like
        Degrees of freedom.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``df`` is used; otherwise it
        is checked against the shape of ``df`` with ``_check_shape``.
    key : optional, int or key array
        Seed or random key; it is split into one sub-key per draw. When
        omitted two keys are split from this generator.

    Returns
    -------
    out : array_like
        The sampling results.
    """
    df = _check_py_seq(_as_jax_array(df))
    if size is not None:
        size = _size2shape(size)
        _check_shape("t", size, np.shape(df))
    else:
        size = np.shape(df)
    subkeys = self.split_keys(2) if key is None else jr.split(_formalize_key(key), 2)
    normal_draw = jr.normal(subkeys[0], size)
    half_df = lax.div(df, _const(normal_draw, 2))
    gamma_draw = jr.gamma(subkeys[1], half_df, size)
    return _return(normal_draw * jnp.sqrt(half_df / gamma_draw))
[docs]
def orthogonal(self, n: int, size: Optional[Union[int, Sequence[int]]] = None,
               key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample random ``n x n`` orthogonal matrices.

    A standard normal matrix is QR-decomposed and the columns of ``Q`` are
    multiplied by the signs of the diagonal of ``R``.

    Parameters
    ----------
    n : int
        Matrix dimension; must be a concrete integer.
    size : optional, int, tuple of int
        Batch shape prepended to ``(n, n)``.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        Array of shape ``size + (n, n)``.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    batch_shape = _size2shape(size)
    _check_shape("orthogonal", batch_shape)
    n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
    gaussian = jr.normal(key, batch_shape + (n, n))
    q_mat, r_mat = jnp.linalg.qr(gaussian)
    diag = jnp.diagonal(r_mat, 0, -2, -1)
    signs = diag / abs(diag)
    return _return(q_mat * jnp.expand_dims(signs, -2))
[docs]
def noncentral_chisquare(self, df, nonc, size: Optional[Union[int, Sequence[int]]] = None,
                         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a noncentral chi-square distribution.

    Parameters
    ----------
    df : float, array_like
        Degrees of freedom.
    nonc : float, array_like
        Non-centrality.
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of ``df`` and ``nonc``
        is used.
    key : optional, int or key array
        Seed or random key; it is split into three sub-keys. When omitted
        three keys are split from this generator.

    Returns
    -------
    out : array_like
        The sampling results.
    """
    df = _check_py_seq(_as_jax_array(df))
    nonc = _check_py_seq(_as_jax_array(nonc))
    out_shape = _size2shape(
        lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc)) if size is None else size
    )
    subkeys = self.split_keys(3) if key is None else jr.split(_formalize_key(key), 3)
    poisson_draw = jr.poisson(subkeys[0], 0.5 * nonc, shape=out_shape)
    shifted_normal = jr.normal(subkeys[1], shape=out_shape) + jnp.sqrt(nonc)
    # df > 1: central chi-square with df - 1 plus a squared shifted normal.
    # otherwise: central chi-square with df + 2 * Poisson(nonc / 2).
    df_gt_one = jnp.greater(df, 1.0)
    effective_df = jnp.where(df_gt_one, df - 1.0, df + 2.0 * poisson_draw)
    central = 2.0 * jr.gamma(subkeys[2], 0.5 * effective_df, shape=out_shape)
    result = jnp.where(df_gt_one, central + shifted_normal * shifted_normal, central)
    return _return(result)
[docs]
def loggamma(self, a, size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples with ``jax.random.loggamma``.

    Parameters
    ----------
    a : float, array_like
        Parameter forwarded to ``jax.random.loggamma``.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``a`` is used.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        The sampling results.
    """
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    a = _check_py_seq(_as_jax_array(a))
    out_shape = _size2shape(jnp.shape(a) if size is None else size)
    return _return(jr.loggamma(key, a, shape=out_shape))
[docs]
def categorical(self, logits, axis: int = -1, size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw category indices with ``jax.random.categorical``.

    Parameters
    ----------
    logits : array_like
        Unnormalised log-probabilities.
    axis : int
        Axis of ``logits`` that enumerates the categories.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``logits`` with ``axis``
        removed is used.
    key : optional, int or key array
        Seed or random key. When omitted a new key is split from this generator.

    Returns
    -------
    out : array_like
        The sampled indices.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    logits = _check_py_seq(_as_jax_array(logits))
    if size is None:
        # Default: one draw per distribution, i.e. drop the category axis.
        remaining_dims = list(jnp.shape(logits))
        remaining_dims.pop(axis)
        size = remaining_dims
    samples = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
    return _return(samples)
[docs]
def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a Zipf distribution via ``numpy.random.zipf``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    a : float, array_like
        Distribution parameter forwarded to ``numpy.random.zipf``.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``a`` is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Integer samples of shape ``size``.
    """
    a = _check_py_seq(_as_jax_array(a))
    if size is None:
        size = jnp.shape(a)
    # Normalise `size` like the other callback-based samplers do, so that an
    # int `size` becomes a proper shape tuple for `ShapeDtypeStruct`.
    size = _size2shape(size)
    dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
    r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
                          jax.ShapeDtypeStruct(size, dtype),
                          a)
    return _return(r)
[docs]
def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a power distribution via ``numpy.random.power``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    a : float, array_like
        Distribution parameter forwarded to ``numpy.random.power``.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``a`` is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Floating-point samples of shape ``size``.
    """
    a = _check_py_seq(_as_jax_array(a))
    out_shape = _size2shape(jnp.shape(a) if size is None else size)
    out_dtype = jax.dtypes.canonicalize_dtype(jnp.float_)

    def _host_sample(a_val):
        return np.random.power(a=a_val, size=out_shape).astype(out_dtype)

    result = jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, out_dtype), a)
    return _return(result)
[docs]
def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
      key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from an F distribution via ``numpy.random.f``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    dfnum : float, array_like
        Numerator degrees of freedom.
    dfden : float, array_like
        Denominator degrees of freedom.
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of ``dfnum`` and
        ``dfden`` is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Floating-point samples of shape ``size``.
    """
    dfnum = _check_py_seq(_as_jax_array(dfnum))
    dfden = _check_py_seq(_as_jax_array(dfden))
    out_shape = _size2shape(
        jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) if size is None else size
    )
    out_dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
    params = {'dfnum': dfnum, 'dfden': dfden}

    def _host_sample(x):
        draws = np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], size=out_shape)
        return draws.astype(out_dtype)

    result = jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, out_dtype), params)
    return _return(result)
[docs]
def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
                   key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a hypergeometric distribution via ``numpy.random.hypergeometric``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    ngood, nbad, nsample : int, array_like
        Parameters forwarded to ``numpy.random.hypergeometric``.
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of the three
        parameters is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Integer samples of shape ``size``.
    """
    ngood = _check_py_seq(_as_jax_array(ngood))
    nbad = _check_py_seq(_as_jax_array(nbad))
    nsample = _check_py_seq(_as_jax_array(nsample))
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(ngood), jnp.shape(nbad), jnp.shape(nsample))
    out_shape = _size2shape(size)
    out_dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
    params = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}

    def _host_sample(d):
        draws = np.random.hypergeometric(ngood=d['ngood'],
                                         nbad=d['nbad'],
                                         nsample=d['nsample'],
                                         size=out_shape)
        return draws.astype(out_dtype)

    result = jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, out_dtype), params)
    return _return(result)
[docs]
def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
              key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a logarithmic series distribution via ``numpy.random.logseries``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    p : float, array_like
        Distribution parameter forwarded to ``numpy.random.logseries``.
    size : optional, int, tuple of int
        Output shape. When omitted, the shape of ``p`` is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Integer samples of shape ``size``.
    """
    p = _check_py_seq(_as_jax_array(p))
    out_shape = _size2shape(jnp.shape(p) if size is None else size)
    out_dtype = jax.dtypes.canonicalize_dtype(jnp.int_)

    def _host_sample(p_val):
        return np.random.logseries(p=p_val, size=out_shape).astype(out_dtype)

    result = jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, out_dtype), p)
    return _return(result)
[docs]
def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
                 key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a noncentral F distribution via ``numpy.random.noncentral_f``.

    The sampling runs on the host inside ``jax.pure_callback``.

    Parameters
    ----------
    dfnum, dfden, nonc : float, array_like
        Parameters forwarded to ``numpy.random.noncentral_f``.
    size : optional, int, tuple of int
        Output shape. When omitted, the broadcast shape of the three
        parameters is used.
    key : optional, int or key array
        Not used by this method: the draw comes from NumPy's global generator.

    Returns
    -------
    out : array_like
        Floating-point samples of shape ``size``.
    """
    dfnum = _check_py_seq(_as_jax_array(dfnum))
    dfden = _check_py_seq(_as_jax_array(dfden))
    nonc = _check_py_seq(_as_jax_array(nonc))
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden), jnp.shape(nonc))
    out_shape = _size2shape(size)
    params = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
    out_dtype = jax.dtypes.canonicalize_dtype(jnp.float_)

    def _host_sample(x):
        draws = np.random.noncentral_f(dfnum=x['dfnum'],
                                       dfden=x['dfden'],
                                       nonc=x['nonc'],
                                       size=out_shape)
        return draws.astype(out_dtype)

    result = jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, out_dtype), params)
    return _return(result)
# PyTorch compatibility #
# --------------------- #
[docs]
def rand_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Return uniform ``[0, 1)`` random numbers with the same shape as ``input``.

    Args:
      input: only its shape is used; it determines the shape of the output.
      dtype: data type handed to ``.astype`` on the sampled array.
        NOTE(review): ``None`` is passed to ``.astype`` unchanged, so the
        resulting dtype is whatever ``.astype(None)`` yields, not necessarily
        the dtype of ``input`` - confirm.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    out_shape = shape(input)
    samples = self.random(out_shape, key=key)
    return samples.astype(dtype)
[docs]
def randn_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Return standard normal random numbers with the same shape as ``input``.

    Args:
      input: only its shape is used; it determines the shape of the output.
      dtype: data type handed to ``.astype`` on the sampled array.
        NOTE(review): ``None`` is passed to ``.astype`` unchanged, so the
        resulting dtype is whatever ``.astype(None)`` yields, not necessarily
        the dtype of ``input`` - confirm.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    dims = shape(input)
    samples = self.randn(*dims, key=key)
    return samples.astype(dtype)
[docs]
def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Return random integers in ``[low, high)`` with the same shape as ``input``.

    Args:
      input: its shape determines the shape of the output. When ``high`` is
        omitted, its largest element is used as the exclusive upper bound.
      low: lowest integer to be drawn. Default: 0.
      high: one above the highest integer to be drawn. Default: the maximum
        over all elements of ``input``.
      dtype: data type forwarded to ``randint``.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    if high is None:
        # The builtin `max` iterates over the first axis only and fails on
        # multi-dimensional arrays; reduce over every element instead.
        high = jnp.max(jnp.asarray(_as_jax_array(input)))
    return self.randint(low, high=high, size=shape(input), dtype=dtype, key=key)
# Alias: ``Generator`` is a second name for ``RandomState``.
Generator = RandomState

# Default random generator used by the module-level sampling functions below.
# A temporary ``Array`` is created empty, its ``_value`` is filled with two
# uint32 numbers drawn by NumPy, and it is handed to ``RandomState``.
__a = Array(None)
__a._value = np.random.randint(0, 10000, size=2, dtype=np.uint32)
DEFAULT = RandomState(__a)
del __a  # drop the temporary name from the module namespace
[docs]
def split_key():
    """Create a new seed from the current seed.

    This function is useful for the consistency with JAX's random paradigm.
    It delegates to ``DEFAULT.split_key()``.

    Returns
    -------
    key
        Whatever ``DEFAULT.split_key()`` returns.
    """
    return DEFAULT.split_key()
[docs]
def split_keys(n):
    """Create multiple seeds from the current seed.

    This is used internally by `pmap` and `vmap` to ensure that random
    numbers are different in parallel threads. It delegates to
    ``DEFAULT.split_keys(n)``.

    .. versionadded:: 2.4.5

    Parameters
    ----------
    n : int
        The number of seeds to generate.

    Returns
    -------
    keys
        Whatever ``DEFAULT.split_keys(n)`` returns.
    """
    return DEFAULT.split_keys(n)
def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
    """Return a random state built from a seed/key, or derived from the default one.

    Args:
      seed_or_key: The seed (an integer) or the random key. When given, a new
        ``RandomState`` is constructed from it and ``clone`` is ignored.
      clone: Bool. Only used when ``seed_or_key`` is None: return a clone of
        the default random state (True) or the default state itself (False).

    Returns:
      The random state.
    """
    if seed_or_key is not None:
        return RandomState(seed_or_key)
    if clone:
        return DEFAULT.clone()
    return DEFAULT
[docs]
def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
    """Return a random state built from a seed/key, or derived from the default one.

    Args:
      seed_or_key: The seed (an integer) or the random key. When given, a new
        ``RandomState`` is constructed from it and ``clone`` is ignored.
      clone: Bool. Only used when ``seed_or_key`` is None: return a clone of
        the default random state (True) or the default state itself (False).

    Returns:
      The random state.
    """
    if seed_or_key is not None:
        return RandomState(seed_or_key)
    if clone:
        return DEFAULT.clone()
    return DEFAULT
[docs]
def seed(seed: Optional[int] = None):
    """Sets a new random seed.

    The seed is applied both to NumPy's global generator and to ``DEFAULT``.

    Parameters
    ----------
    seed: int, optional
      The random seed. When omitted, a seed is drawn with
      ``np.random.randint(0, 100000)``.
    """
    # NOTE(review): the nesting below was reconstructed from a listing that
    # had lost its indentation - confirm against the original file.
    with jax.ensure_compile_time_eval():
        if seed is None:
            seed = np.random.randint(0, 100000)
        np.random.seed(seed)
    DEFAULT.seed(seed)
[docs]
def rand(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""Random values in a given shape.

    .. note::
        This is a convenience function for users porting code from Matlab,
        and wraps `random_sample`. That function takes a
        tuple to specify the size of the output, which is consistent with
        other NumPy functions like `numpy.zeros` and `numpy.ones`.

    Create an array of the given shape and populate it with
    random samples from a uniform distribution
    over ``[0, 1)``.

    Parameters
    ----------
    d0, d1, ..., dn : int, optional
        The dimensions of the returned array, must be non-negative.
        If no argument is given a single Python float is returned.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.rand``.

    Returns
    -------
    out : ndarray, shape ``(d0, d1, ..., dn)``
        Random values.

    See Also
    --------
    random

    Examples
    --------
    >>> brainpy.math.random.rand(3,2)
    array([[ 0.14022471,  0.96360618],  #random
           [ 0.37601032,  0.25528411],  #random
           [ 0.49313049,  0.94909878]]) #random
    """
    return DEFAULT.rand(*dn, key=key)
[docs]
def randint(low, high=None, size: Optional[Union[int, Sequence[int]]] = None, dtype=int,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""Return random integers from `low` (inclusive) to `high` (exclusive).

    Return random integers from the "discrete uniform" distribution of
    the specified dtype in the "half-open" interval [`low`, `high`). If
    `high` is None (the default), then results are from [0, `low`).

    Parameters
    ----------
    low : int or array-like of ints
        Lowest (signed) integers to be drawn from the distribution (unless
        ``high=None``, in which case this parameter is one above the
        *highest* such integer).
    high : int or array-like of ints, optional
        If provided, one above the largest (signed) integer to be drawn
        from the distribution (see above for behavior if ``high=None``).
        If array-like, must contain integer values
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case a
        single value is returned.
    dtype : dtype, optional
        Desired dtype of the result. Byteorder must be native.
        The default value is int.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.randint``.

    Returns
    -------
    out : int or ndarray of ints
        `size`-shaped array of random integers from the appropriate
        distribution, or a single such random int if `size` not provided.

    See Also
    --------
    random_integers : similar to `randint`, only for the closed
        interval [`low`, `high`], and 1 is the lowest value if `high` is
        omitted.
    Generator.integers: which should be used for new code.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> bm.random.randint(2, size=10)
    array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random
    >>> bm.random.randint(1, size=10)
    array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    Generate a 2 x 4 array of ints between 0 and 4, inclusive:

    >>> bm.random.randint(5, size=(2, 4))
    array([[4, 0, 2, 1], # random
           [3, 2, 2, 0]])

    Generate a 1 x 3 array with 3 different upper bounds

    >>> bm.random.randint(1, [3, 5, 10])
    array([2, 2, 9]) # random

    Generate a 1 by 3 array with 3 different lower bounds

    >>> bm.random.randint([1, 5, 7], 10)
    array([9, 8, 7]) # random

    Generate a 2 by 4 array using broadcasting with dtype of uint8

    >>> bm.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8)
    array([[ 8,  6,  9,  7], # random
           [ 1, 16,  9, 12]], dtype=uint8)
    """
    return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key)
[docs]
def random_integers(low,
                    high=None,
                    size: Optional[Union[int, Sequence[int]]] = None,
                    key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Random integers of type `np.int_` between `low` and `high`, inclusive.

    Return random integers of type `np.int_` from the "discrete uniform"
    distribution in the closed interval [`low`, `high`].  If `high` is
    None (the default), then results are from [1, `low`]. The `np.int_`
    type translates to the C long integer type and its precision
    is platform dependent.

    Parameters
    ----------
    low : int
        Lowest (signed) integer to be drawn from the distribution (unless
        ``high=None``, in which case this parameter is the *highest* such
        integer).
    high : int, optional
        If provided, the largest (signed) integer to be drawn from the
        distribution (see above for behavior if ``high=None``).
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case a
        single value is returned.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.random_integers``.

    Returns
    -------
    out : int or ndarray of ints
        `size`-shaped array of random integers from the appropriate
        distribution, or a single such random int if `size` not provided.

    See Also
    --------
    randint : Similar to `random_integers`, only for the half-open
        interval [`low`, `high`), and 0 is the lowest value if `high` is
        omitted.

    Notes
    -----
    To sample from N evenly spaced floating-point numbers between a and b,
    use::

      a + (b - a) * (bm.random.random_integers(N) - 1) / (N - 1.)

    Examples
    --------
    >>> import brainpy.math as bm
    >>> bm.random.random_integers(5)
    4 # random
    >>> type(bm.random.random_integers(5))
    <class 'numpy.int64'>
    >>> bm.random.random_integers(5, size=(3,2))
    array([[5, 4], # random
           [3, 3],
           [4, 5]])

    Choose five random numbers from the set of five evenly-spaced
    numbers between 0 and 2.5, inclusive (*i.e.*, from the set
    :math:`{0, 5/8, 10/8, 15/8, 20/8}`):

    >>> 2.5 * (bm.random.random_integers(5, size=(5,)) - 1) / 4.
    array([ 0.625,  1.25 ,  0.625,  0.625,  2.5  ]) # random

    Roll two six sided dice 1000 times and sum the results:

    >>> d1 = bm.random.random_integers(1, 6, 1000)
    >>> d2 = bm.random.random_integers(1, 6, 1000)
    >>> dsums = d1 + d2

    Display results as a histogram:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(dsums, 11, density=True)
    >>> plt.show()
    """
    return DEFAULT.random_integers(low, high=high, size=size, key=key)
[docs]
def randn(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Return a sample (or samples) from the "standard normal" distribution.

    .. note::
        This is a convenience function for users porting code from Matlab,
        and wraps `standard_normal`. That function takes a
        tuple to specify the size of the output, which is consistent with
        other NumPy functions like `numpy.zeros` and `numpy.ones`.

    .. note::
        New code should use the ``standard_normal`` method of a ``default_rng()``
        instance instead; please see the :ref:`random-quick-start`.

    If positive int_like arguments are provided, `randn` generates an array
    of shape ``(d0, d1, ..., dn)``, filled
    with random floats sampled from a univariate "normal" (Gaussian)
    distribution of mean 0 and variance 1. A single float randomly sampled
    from the distribution is returned if no argument is provided.

    Parameters
    ----------
    d0, d1, ..., dn : int, optional
        The dimensions of the returned array, must be non-negative.
        If no argument is given a single Python float is returned.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.randn``.

    Returns
    -------
    Z : ndarray or float
        A ``(d0, d1, ..., dn)``-shaped array of floating-point samples from
        the standard normal distribution, or a single such float if
        no parameters were supplied.

    See Also
    --------
    standard_normal : Similar, but takes a tuple as its argument.
    normal : Also accepts mu and sigma arguments.

    Notes
    -----
    For random samples from :math:`N(\mu, \sigma^2)`, use:

    ``sigma * bm.random.randn(...) + mu``

    Examples
    --------
    >>> import brainpy.math as bm
    >>> bm.random.randn()
    2.1923875335537315  # random

    Two-by-four array of samples from N(3, 6.25):

    >>> 3 + 2.5 * bm.random.randn(2, 4)
    array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],   # random
           [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random
    """
    return DEFAULT.randn(*dn, key=key)
[docs]
def random(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Return random floats in the half-open interval [0.0, 1.0). Alias for
    `random_sample` to ease forward-porting to the new random API.

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape, forwarded to ``DEFAULT.random``.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.random``.
    """
    return DEFAULT.random(size, key=key)
[docs]
def random_sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Return random floats in the half-open interval [0.0, 1.0).

    Results are from the "continuous uniform" distribution over the
    stated interval.  To sample :math:`Unif[a, b), b > a` multiply
    the output of `random_sample` by `(b-a)` and add `a`::

      (b - a) * random_sample() + a

    .. note::
        New code should use the ``random`` method of a ``default_rng()``
        instance instead; please see the :ref:`random-quick-start`.

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case a
        single value is returned.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.random_sample``.

    Returns
    -------
    out : float or ndarray of floats
        Array of random floats of shape `size` (unless ``size=None``, in which
        case a single float is returned).

    See Also
    --------
    Generator.random: which should be used for new code.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> bm.random.random_sample()
    0.47108547995356098 # random
    >>> type(bm.random.random_sample())
    <class 'float'>
    >>> bm.random.random_sample((5,))
    array([ 0.30220482,  0.86820401,  0.1654503 ,  0.11659149,  0.54323428]) # random

    Three-by-two array of random numbers from [-5, 0):

    >>> 5 * bm.random.random_sample((3, 2)) - 5
    array([[-3.99149989, -0.52338984], # random
           [-2.99091858, -0.79479508],
           [-1.23204345, -1.75224494]])
    """
    return DEFAULT.random_sample(size, key=key)
[docs]
def ranf(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    This is an alias of `random_sample`. See `random_sample` for the complete
    documentation.

    Both arguments are forwarded to ``DEFAULT.ranf``.
    """
    return DEFAULT.ranf(size, key=key)
[docs]
def sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """
    This is an alias of `random_sample`. See `random_sample` for the complete
    documentation.

    Both arguments are forwarded to ``DEFAULT.sample``.
    """
    return DEFAULT.sample(size, key=key)
[docs]
def choice(a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Generates a random sample from a given 1-D array

    Parameters
    ----------
    a : 1-D array-like or int
        If an ndarray, a random sample is generated from its elements.
        If an int, the random sample is generated as if it were ``np.arange(a)``
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case a
        single value is returned.
    replace : boolean, optional
        Whether the sample is with or without replacement. Default is True,
        meaning that a value of ``a`` can be selected multiple times.
    p : 1-D array-like, optional
        The probabilities associated with each entry in a.
        If not given, the sample assumes a uniform distribution over all
        entries in ``a``.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.choice``.

    Returns
    -------
    samples : single item or ndarray
        The generated random samples

    Raises
    ------
    ValueError
        If a is an int and less than zero, if a or p are not 1-dimensional,
        if a is an array-like of size 0, if p is not a vector of
        probabilities, if a and p have different lengths, or if
        replace=False and the sample size is greater than the population
        size

    See Also
    --------
    randint, shuffle, permutation
    Generator.choice: which should be used in new code

    Notes
    -----
    Setting user-specified probabilities through ``p`` uses a more general but less
    efficient sampler than the default. The general sampler produces a different sample
    than the optimized sampler even if each element of ``p`` is 1 / len(a).

    Sampling random rows from a 2-D array is not possible with this function,
    but is possible with `Generator.choice` through its ``axis`` keyword.

    Examples
    --------
    Generate a uniform random sample from np.arange(5) of size 3:

    >>> import brainpy.math as bm
    >>> bm.random.choice(5, 3)
    array([0, 3, 4]) # random
    >>> #This is equivalent to brainpy.math.random.randint(0,5,3)

    Generate a non-uniform random sample from np.arange(5) of size 3:

    >>> bm.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
    array([3, 3, 0]) # random

    Generate a uniform random sample from np.arange(5) of size 3 without
    replacement:

    >>> bm.random.choice(5, 3, replace=False)
    array([3,1,0]) # random
    >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3]

    Generate a non-uniform random sample from np.arange(5) of size
    3 without replacement:

    >>> bm.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
    array([2, 3, 0]) # random

    Any of the above can be repeated with an arbitrary array-like
    instead of just integers. For instance:

    >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher']
    >>> bm.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3])
    array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random
          dtype='<U11')
    """
    # Pass `a` through the module's `_as_jax_array` helper before delegating.
    a = _as_jax_array(a)
    return DEFAULT.choice(a=a, size=size, replace=replace, p=p, key=key)
[docs]
def permutation(x,
                axis: int = 0,
                independent: bool = False,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Randomly permute a sequence, or return a permuted range.

    If `x` is a multi-dimensional array, it is only shuffled along its
    first index.

    Parameters
    ----------
    x : int or array_like
        If `x` is an integer, randomly permute ``np.arange(x)``.
        If `x` is an array, make a copy and shuffle the elements
        randomly.
    axis : int, optional
        Forwarded to ``DEFAULT.permutation``. Default is 0.
    independent : bool, optional
        Forwarded to ``DEFAULT.permutation``. Default is False.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.permutation``.

    Returns
    -------
    out : ndarray
        Permuted sequence or array range.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> bm.random.permutation(10)
    array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random

    >>> bm.random.permutation([1, 4, 9, 12, 15])
    array([15,  1,  9,  4, 12]) # random

    >>> arr = np.arange(9).reshape((3, 3))
    >>> bm.random.permutation(arr)
    array([[6, 7, 8], # random
           [0, 1, 2],
           [3, 4, 5]])
    """
    return DEFAULT.permutation(x, axis=axis, independent=independent, key=key)
[docs]
def shuffle(x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Modify a sequence in-place by shuffling its contents.

    This function only shuffles the array along the first axis of a
    multi-dimensional array. The order of sub-arrays is changed but
    their contents remains the same.

    Parameters
    ----------
    x : ndarray or MutableSequence
        The array, list or mutable sequence to be shuffled.
    axis : int, optional
        Forwarded to ``DEFAULT.shuffle``. Default is 0.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.shuffle``.

    Returns
    -------
    None
        The return value of ``DEFAULT.shuffle`` is not propagated.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> arr = np.arange(10)
    >>> bm.random.shuffle(arr)
    >>> arr
    [1 7 5 2 9 4 3 6 0 8] # random

    Multi-dimensional arrays are only shuffled along the first axis:

    >>> arr = np.arange(9).reshape((3, 3))
    >>> bm.random.shuffle(arr)
    >>> arr
    array([[3, 4, 5], # random
           [6, 7, 8],
           [0, 1, 2]])
    """
    DEFAULT.shuffle(x, axis, key=key)
[docs]
def beta(a, b, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Draw samples from a Beta distribution.

    The Beta distribution is a special case of the Dirichlet distribution,
    and is related to the Gamma distribution.  It has the probability
    distribution function

    .. math:: f(x; a,b) = \frac{1}{B(\alpha, \beta)} x^{\alpha - 1}
                                                     (1 - x)^{\beta - 1},

    where the normalization, B, is the beta function,

    .. math:: B(\alpha, \beta) = \int_0^1 t^{\alpha - 1}
                                 (1 - t)^{\beta - 1} dt.

    It is often seen in Bayesian inference and order statistics.

    Parameters
    ----------
    a : float or array_like of floats
        Alpha, positive (>0).
    b : float or array_like of floats
        Beta, positive (>0).
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
        a single value is returned if ``a`` and ``b`` are both scalars.
        Otherwise, ``np.broadcast(a, b).size`` samples are drawn.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.beta``.

    Returns
    -------
    out : ndarray or scalar
        Drawn samples from the parameterized beta distribution.
    """
    return DEFAULT.beta(a, b, size=size, key=key)
[docs]
def exponential(scale=None, size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Draw samples from an exponential distribution.

    Its probability density function is

    .. math:: f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}),

    for ``x > 0`` and 0 elsewhere. :math:`\beta` is the scale parameter,
    which is the inverse of the rate parameter :math:`\lambda = 1/\beta`.
    The rate parameter is an alternative, widely used parameterization
    of the exponential distribution [3]_.

    The exponential distribution is a continuous analogue of the
    geometric distribution.  It describes many common situations, such as
    the size of raindrops measured over many rainstorms [1]_, or the time
    between page requests to Wikipedia [2]_.

    Parameters
    ----------
    scale : float or array_like of floats
        The scale parameter, :math:`\beta = 1/\lambda`. Must be
        non-negative.
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
        a single value is returned if ``scale`` is a scalar.  Otherwise,
        ``np.array(scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Seed or random key, forwarded to ``DEFAULT.exponential``.

    Returns
    -------
    out : ndarray or scalar
        Drawn samples from the parameterized exponential distribution.

    References
    ----------
    .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and
           Random Signal Principles", 4th ed, 2001, p. 57.
    .. [2] Wikipedia, "Poisson process",
           https://en.wikipedia.org/wiki/Poisson_process
    .. [3] Wikipedia, "Exponential distribution",
           https://en.wikipedia.org/wiki/Exponential_distribution
    """
    return DEFAULT.exponential(scale, size, key=key)
[docs]
def gamma(
    shape,
    scale=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a Gamma distribution.

    The distribution is parameterized by `shape` (sometimes written "k")
    and `scale` (sometimes written "theta"); both parameters are > 0.

    Parameters
    ----------
    shape : float or array_like of floats
        The shape of the gamma distribution. Must be non-negative.
    scale : float or array_like of floats, optional
        The scale of the gamma distribution. Must be non-negative.
        Default is equal to 1.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``shape`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(shape, scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.gamma``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized gamma distribution.

    Notes
    -----
    The probability density for the Gamma distribution is

    .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)},

    where :math:`k` is the shape, :math:`\theta` the scale, and
    :math:`\Gamma` the Gamma function.

    The Gamma distribution is often used to model the times to failure of
    electronic components, and arises naturally in processes for which the
    waiting times between Poisson distributed events are relevant.

    References
    ----------
    .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A
           Wolfram Web Resource.
           http://mathworld.wolfram.com/GammaDistribution.html
    .. [2] Wikipedia, "Gamma distribution",
           https://en.wikipedia.org/wiki/Gamma_distribution
    """
    samples = DEFAULT.gamma(shape, scale, size=size, key=key)
    return samples
[docs]
def gumbel(
    loc=None,
    scale=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a Gumbel distribution.

    Samples are drawn from a Gumbel distribution with the given location
    and scale. See Notes and References below for background on the
    distribution.

    Parameters
    ----------
    loc : float or array_like of floats, optional
        The location of the mode of the distribution. Default is 0.
    scale : float or array_like of floats, optional
        The scale parameter of the distribution. Default is 1. Must be
        non-negative.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``loc`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(loc, scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.gumbel``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Gumbel distribution.

    Notes
    -----
    The Gumbel (or Smallest Extreme Value (SEV) or the Smallest Extreme
    Value Type I) distribution belongs to the class of Generalized Extreme
    Value (GEV) distributions used in modeling extreme value problems.
    It is a special case of the Extreme Value Type I distribution for
    maximums from distributions with "exponential-like" tails.

    The probability density for the Gumbel distribution is

    .. math:: p(x) = \frac{e^{-(x - \mu)/ \beta}}{\beta} e^{ -e^{-(x - \mu)/
              \beta}},

    where :math:`\mu` is the mode, a location parameter, and
    :math:`\beta` is the scale parameter.

    The Gumbel (named for German mathematician Emil Julius Gumbel) was used
    very early in the hydrology literature, for modeling the occurrence of
    flood events. It is also used for modeling maximum wind speed and
    rainfall rates. It is a "fat-tailed" distribution - the probability of
    an event in the tail of the distribution is larger than if one used a
    Gaussian, hence the surprisingly frequent occurrence of 100-year
    floods. Floods were initially modeled as a Gaussian process, which
    underestimated the frequency of extreme events.

    It is one of a class of extreme value distributions, the Generalized
    Extreme Value (GEV) distributions, which also includes the Weibull and
    Frechet.

    The function has a mean of :math:`\mu + 0.57721\beta` and a variance
    of :math:`\frac{\pi^2}{6}\beta^2`.

    References
    ----------
    .. [1] Gumbel, E. J., "Statistics of Extremes,"
           New York: Columbia University Press, 1958.
    .. [2] Reiss, R.-D. and Thomas, M., "Statistical Analysis of Extreme
           Values from Insurance, Finance, Hydrology and Other Fields,"
           Basel: Birkhauser Verlag, 2001.
    """
    samples = DEFAULT.gumbel(loc, scale, size=size, key=key)
    return samples
[docs]
def laplace(
    loc=None,
    scale=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from the Laplace (double exponential) distribution with the
    given location (or mean) and scale (decay).

    The Laplace distribution resembles the Gaussian/normal distribution,
    but is sharper at the peak and has fatter tails. It represents the
    difference between two independent, identically distributed
    exponential random variables.

    Parameters
    ----------
    loc : float or array_like of floats, optional
        The position, :math:`\mu`, of the distribution peak. Default is 0.
    scale : float or array_like of floats, optional
        :math:`\lambda`, the exponential decay. Default is 1. Must be
        non-negative.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``loc`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(loc, scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.laplace``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Laplace distribution.

    Notes
    -----
    The probability density function is

    .. math:: f(x; \mu, \lambda) = \frac{1}{2\lambda}
                                   \exp\left(-\frac{|x - \mu|}{\lambda}\right).

    The first law of Laplace, from 1774, states that the frequency
    of an error can be expressed as an exponential function of the
    absolute magnitude of the error, which leads to the Laplace
    distribution. For many problems in economics and health
    sciences, this distribution seems to model the data better
    than the standard Gaussian distribution.

    References
    ----------
    .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of
           Mathematical Functions with Formulas, Graphs, and Mathematical
           Tables, 9th printing," New York: Dover, 1972.
    .. [2] Kotz, Samuel, et. al. "The Laplace Distribution and
           Generalizations, " Birkhauser, 2001.
    .. [3] Weisstein, Eric W. "Laplace Distribution."
           From MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/LaplaceDistribution.html
    .. [4] Wikipedia, "Laplace distribution",
           https://en.wikipedia.org/wiki/Laplace_distribution

    Examples
    --------
    Draw samples from the distribution

    >>> loc, scale = 0., 1.
    >>> s = bm.random.laplace(loc, scale, 1000)

    Show a histogram of the samples together with the probability
    density function:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, 30, density=True)
    >>> x = np.arange(-8., 8., .01)
    >>> pdf = np.exp(-abs(x-loc)/scale)/(2.*scale)
    >>> plt.plot(x, pdf)

    Plot a Gaussian for comparison:

    >>> g = (1/(scale * np.sqrt(2 * np.pi)) *
    ...      np.exp(-(x - loc)**2 / (2 * scale**2)))
    >>> plt.plot(x,g)
    """
    samples = DEFAULT.laplace(loc, scale, size, key=key)
    return samples
[docs]
def logistic(
    loc=None,
    scale=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a logistic distribution.

    The distribution is specified by ``loc`` (location or mean, also the
    median) and ``scale`` (>0).

    Parameters
    ----------
    loc : float or array_like of floats, optional
        Parameter of the distribution. Default is 0.
    scale : float or array_like of floats, optional
        Parameter of the distribution. Must be non-negative.
        Default is 1.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``loc`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(loc, scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.logistic``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized logistic distribution.

    Notes
    -----
    The probability density for the Logistic distribution is

    .. math:: P(x) = \frac{e^{-(x-\mu)/s}}{s(1+e^{-(x-\mu)/s})^2},

    where :math:`\mu` = location and :math:`s` = scale.

    The Logistic distribution is used in Extreme Value problems where it
    can act as a mixture of Gumbel distributions, in Epidemiology, and by
    the World Chess Federation (FIDE) where it is used in the Elo ranking
    system, assuming the performance of each player is a logistically
    distributed random variable.

    References
    ----------
    .. [1] Reiss, R.-D. and Thomas M. (2001), "Statistical Analysis of
           Extreme Values, from Insurance, Finance, Hydrology and Other
           Fields," Birkhauser Verlag, Basel, pp 132-133.
    .. [2] Weisstein, Eric W. "Logistic Distribution." From
           MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/LogisticDistribution.html
    .. [3] Wikipedia, "Logistic-distribution",
           https://en.wikipedia.org/wiki/Logistic_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> loc, scale = 10, 1
    >>> s = bm.random.logistic(loc, scale, 10000)
    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, bins=50)

    Plot against the distribution:

    >>> def logist(x, loc, scale):
    ...     return np.exp((loc-x)/scale)/(scale*(1+np.exp((loc-x)/scale))**2)
    >>> lgst_val = logist(bins, loc, scale)
    >>> plt.plot(bins, lgst_val * count.max() / lgst_val.max())
    >>> plt.show()
    """
    samples = DEFAULT.logistic(loc, scale, size, key=key)
    return samples
[docs]
def normal(
    loc=None,
    scale=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a normal (Gaussian) distribution.

    The probability density function of the normal distribution, first
    derived by De Moivre and 200 years later by both Gauss and Laplace
    independently [2]_, is often called the bell curve because of
    its characteristic shape (see the example below).

    The normal distribution occurs often in nature. For example, it
    describes the commonly occurring distribution of samples influenced
    by a large number of tiny, random disturbances, each with its own
    unique distribution [2]_.

    Parameters
    ----------
    loc : float or array_like of floats
        Mean ("centre") of the distribution. Passed through unchanged
        (``None`` included) to ``DEFAULT.normal``.
    scale : float or array_like of floats
        Standard deviation (spread or "width") of the distribution. Must be
        non-negative. Passed through unchanged (``None`` included) to
        ``DEFAULT.normal``.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``loc`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(loc, scale).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.normal``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized normal distribution.

    Notes
    -----
    The probability density for the Gaussian distribution is

    .. math:: p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }}
                     e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} },

    where :math:`\mu` is the mean and :math:`\sigma` the standard
    deviation. The square of the standard deviation, :math:`\sigma^2`,
    is called the variance.

    The function has its peak at the mean, and its "spread" increases with
    the standard deviation (the function reaches 0.607 times its maximum at
    :math:`\mu + \sigma` and :math:`\mu - \sigma` [2]_). This implies that
    normal is more likely to return samples lying close to the mean, rather
    than those far away.

    References
    ----------
    .. [1] Wikipedia, "Normal distribution",
           https://en.wikipedia.org/wiki/Normal_distribution
    .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability,
           Random Variables and Random Signal Principles", 4th ed., 2001,
           pp. 51, 125.

    Examples
    --------
    Draw samples from the distribution:

    >>> mu, sigma = 0, 0.1 # mean and standard deviation
    >>> s = bm.random.normal(mu, sigma, 1000)

    Verify the mean and the variance:

    >>> abs(mu - np.mean(s))
    0.0  # may vary

    >>> abs(sigma - np.std(s, ddof=1))
    0.1  # may vary

    Show a histogram of the samples together with the probability
    density function:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, 30, density=True)
    >>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
    ...                np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
    ...          linewidth=2, color='r')
    >>> plt.show()

    Two-by-four array of samples from the normal distribution with
    mean 3 and standard deviation 2.5:

    >>> bm.random.normal(3, 2.5, size=(2, 4))
    array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],   # random
           [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random
    """
    samples = DEFAULT.normal(loc, scale, size, key=key)
    return samples
[docs]
def pareto(
    a,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a Pareto II or Lomax distribution with the given shape.

    The Lomax or Pareto II distribution is a shifted Pareto
    distribution. The classical Pareto distribution can be
    obtained from the Lomax distribution by adding 1 and
    multiplying by the scale parameter ``m`` (see Notes). The
    smallest value of the Lomax distribution is zero while for the
    classical Pareto distribution it is ``mu``, where the standard
    Pareto distribution has location ``mu = 1``. Lomax can also
    be considered as a simplified version of the Generalized
    Pareto distribution (available in SciPy), with the scale set
    to one and the location set to zero.

    The Pareto distribution must be greater than zero, and is
    unbounded above. It is also known as the "80-20 rule". In
    this distribution, 80 percent of the weights are in the lowest
    20 percent of the range, while the other 20 percent fill the
    remaining 80 percent of the range.

    Parameters
    ----------
    a : float or array_like of floats
        Shape of the distribution. Must be positive.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``a`` is a scalar; otherwise ``np.array(a).size``
        samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.pareto``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Pareto distribution.

    See Also
    --------
    scipy.stats.lomax : probability density function, distribution or
        cumulative density function, etc.
    scipy.stats.genpareto : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the Pareto distribution is

    .. math:: p(x) = \frac{am^a}{x^{a+1}}

    where :math:`a` is the shape and :math:`m` the scale.

    The Pareto distribution, named after the Italian economist
    Vilfredo Pareto, is a power law probability distribution
    useful in many real world problems. Outside the field of
    economics it is generally referred to as the Bradford
    distribution. Pareto developed the distribution to describe
    the distribution of wealth in an economy. It has also found
    use in insurance, web page access statistics, oil field sizes,
    and many other problems, including the download frequency for
    projects in Sourceforge [1]_. It is one of the so-called
    "fat-tailed" distributions.

    References
    ----------
    .. [1] Francis Hunt and Paul Johnson, On the Pareto Distribution of
           Sourceforge projects.
    .. [2] Pareto, V. (1896). Course of Political Economy. Lausanne.
    .. [3] Reiss, R.D., Thomas, M.(2001), Statistical Analysis of Extreme
           Values, Birkhauser Verlag, Basel, pp 23-30.
    .. [4] Wikipedia, "Pareto distribution",
           https://en.wikipedia.org/wiki/Pareto_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> a, m = 3., 2.  # shape and mode
    >>> s = (bm.random.pareto(a, 1000) + 1) * m

    Show a histogram of the samples together with the probability
    density function:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, _ = plt.hist(s, 100, density=True)
    >>> fit = a*m**a / bins**(a+1)
    >>> plt.plot(bins, max(count)*fit/max(fit), linewidth=2, color='r')
    >>> plt.show()
    """
    samples = DEFAULT.pareto(a, size, key=key)
    return samples
[docs]
def poisson(
    lam=1.0,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a Poisson distribution.

    The Poisson distribution is the limit of the binomial distribution
    for large N.

    Parameters
    ----------
    lam : float or array_like of floats
        Expected number of events occurring in a fixed-time interval,
        must be >= 0. A sequence must be broadcastable over the requested
        size. Defaults to 1.0.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``lam`` is a scalar; otherwise
        ``np.array(lam).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.poisson``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Poisson distribution.

    Notes
    -----
    The Poisson distribution

    .. math:: f(k; \lambda)=\frac{\lambda^k e^{-\lambda}}{k!}

    For events with an expected separation :math:`\lambda` the Poisson
    distribution :math:`f(k; \lambda)` describes the probability of
    :math:`k` events occurring within the observed
    interval :math:`\lambda`.

    Because the output is limited to the range of the C int64 type, a
    ValueError is raised when `lam` is within 10 sigma of the maximum
    representable value. (NOTE(review): this statement mirrors the NumPy
    documentation; confirm that ``DEFAULT.poisson`` behaves the same way.)

    References
    ----------
    .. [1] Weisstein, Eric W. "Poisson Distribution."
           From MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/PoissonDistribution.html
    .. [2] Wikipedia, "Poisson distribution",
           https://en.wikipedia.org/wiki/Poisson_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> import numpy as np
    >>> s = bm.random.poisson(5, 10000)

    Show a histogram of the sample:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, 14, density=True)
    >>> plt.show()

    Draw each 100 values for lambda 100 and 500:

    >>> s = bm.random.poisson(lam=(100., 500.), size=(100, 2))
    """
    samples = DEFAULT.poisson(lam, size, key=key)
    return samples
[docs]
def standard_cauchy(
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a standard Cauchy distribution with mode = 0.

    The distribution is also known as the Lorentz distribution.

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. Default is None, in which case a single value
        is returned.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.standard_cauchy``.

    Returns
    -------
    samples : ndarray or scalar
        The drawn samples.

    Notes
    -----
    The probability density function for the full Cauchy distribution is

    .. math:: P(x; x_0, \gamma) = \frac{1}{\pi \gamma \bigl[ 1+
              (\frac{x-x_0}{\gamma})^2 \bigr] }

    and the Standard Cauchy distribution just sets :math:`x_0=0` and
    :math:`\gamma=1`

    The Cauchy distribution arises in the solution to the driven harmonic
    oscillator problem, and also describes spectral line broadening. It
    also describes the distribution of values at which a line tilted at
    a random angle will cut the x axis.

    When studying hypothesis tests that assume normality, seeing how the
    tests perform on data from a Cauchy distribution is a good indicator of
    their sensitivity to a heavy-tailed distribution, since the Cauchy looks
    very much like a Gaussian distribution, but with heavier tails.

    References
    ----------
    .. [1] NIST/SEMATECH e-Handbook of Statistical Methods, "Cauchy
           Distribution",
           https://www.itl.nist.gov/div898/handbook/eda/section3/eda3663.htm
    .. [2] Weisstein, Eric W. "Cauchy Distribution." From MathWorld--A
           Wolfram Web Resource.
           http://mathworld.wolfram.com/CauchyDistribution.html
    .. [3] Wikipedia, "Cauchy distribution"
           https://en.wikipedia.org/wiki/Cauchy_distribution

    Examples
    --------
    Draw samples and plot the distribution:

    >>> import matplotlib.pyplot as plt
    >>> s = bm.random.standard_cauchy(1000000)
    >>> s = s[(s>-25) & (s<25)]  # truncate distribution so it plots well
    >>> plt.hist(s, bins=100)
    >>> plt.show()
    """
    samples = DEFAULT.standard_cauchy(size, key=key)
    return samples
[docs]
def standard_exponential(
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from the standard exponential distribution.

    `standard_exponential` is identical to the exponential distribution
    with a scale parameter of 1.

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. Default is None, in which case a single value
        is returned.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.standard_exponential``.

    Returns
    -------
    out : float or ndarray
        Drawn samples.

    Examples
    --------
    Output a 3x8000 array:

    >>> n = bm.random.standard_exponential((3, 8000))
    """
    samples = DEFAULT.standard_exponential(size, key=key)
    return samples
[docs]
def standard_gamma(
    shape,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a standard Gamma distribution.

    The distribution is a Gamma distribution with the given shape
    (sometimes written "k") and scale=1.

    Parameters
    ----------
    shape : float or array_like of floats
        Parameter, must be non-negative.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``shape`` is a scalar; otherwise
        ``np.array(shape).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.standard_gamma``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized standard gamma distribution.

    See Also
    --------
    scipy.stats.gamma : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the Gamma distribution is

    .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)},

    where :math:`k` is the shape, :math:`\theta` the scale, and
    :math:`\Gamma` the Gamma function.

    The Gamma distribution is often used to model the times to failure of
    electronic components, and arises naturally in processes for which the
    waiting times between Poisson distributed events are relevant.

    References
    ----------
    .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A
           Wolfram Web Resource.
           http://mathworld.wolfram.com/GammaDistribution.html
    .. [2] Wikipedia, "Gamma distribution",
           https://en.wikipedia.org/wiki/Gamma_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> shape, scale = 2., 1. # mean and width
    >>> s = bm.random.standard_gamma(shape, 1000000)

    Show a histogram of the samples together with the probability
    density function:

    >>> import matplotlib.pyplot as plt
    >>> import scipy.special as sps  # doctest: +SKIP
    >>> count, bins, ignored = plt.hist(s, 50, density=True)
    >>> y = bins**(shape-1) * ((np.exp(-bins/scale))/  # doctest: +SKIP
    ...                       (sps.gamma(shape) * scale**shape))
    >>> plt.plot(bins, y, linewidth=2, color='r')  # doctest: +SKIP
    >>> plt.show()
    """
    samples = DEFAULT.standard_gamma(shape, size, key=key)
    return samples
[docs]
def standard_normal(
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a standard Normal distribution (mean=0, stdev=1).

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. Default is None, in which case a single value
        is returned.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.standard_normal``.

    Returns
    -------
    out : float or ndarray
        A floating-point array of shape ``size`` of drawn samples, or a
        single sample if ``size`` was not specified.

    See Also
    --------
    normal :
        Equivalent function with additional ``loc`` and ``scale`` arguments
        for setting the mean and standard deviation.

    Notes
    -----
    To sample from the normal distribution with mean ``mu`` and
    standard deviation ``sigma``, use one of::

        mu + sigma * bm.random.standard_normal(size=...)
        bm.random.normal(mu, sigma, size=...)

    Examples
    --------
    >>> bm.random.standard_normal()
    2.1923875335537315 #random

    >>> s = bm.random.standard_normal(8000)
    >>> s
    array([ 0.6888893 ,  0.78096262, -0.89086505, ...,  0.49876311,  # random
           -0.38672696, -0.4685006 ])                                # random
    >>> s.shape
    (8000,)
    >>> s = bm.random.standard_normal(size=(3, 4, 2))
    >>> s.shape
    (3, 4, 2)

    Two-by-four array of samples from the normal distribution with
    mean 3 and standard deviation 2.5:

    >>> 3 + 2.5 * bm.random.standard_normal(size=(2, 4))
    array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],   # random
           [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random
    """
    samples = DEFAULT.standard_normal(size, key=key)
    return samples
[docs]
def standard_t(
    df,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""
    Sample from a standard Student's t distribution with `df` degrees
    of freedom.

    A special case of the hyperbolic distribution. As `df` gets
    large, the result resembles that of the standard normal
    distribution (`standard_normal`).

    Parameters
    ----------
    df : float or array_like of floats
        Degrees of freedom, must be > 0.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With ``None`` (default), a single value is
        returned when ``df`` is a scalar; otherwise ``np.array(df).size``
        samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.standard_t``.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized standard Student's t
        distribution.

    Notes
    -----
    The probability density function for the t distribution is

    .. math:: P(x, df) = \frac{\Gamma(\frac{df+1}{2})}{\sqrt{\pi df}
              \Gamma(\frac{df}{2})}\Bigl( 1+\frac{x^2}{df} \Bigr)^{-(df+1)/2}

    The t test is based on an assumption that the data come from a
    Normal distribution. The t test provides a way to test whether
    the sample mean (that is the mean calculated from the data) is
    a good estimate of the true mean.

    The derivation of the t-distribution was first published in
    1908 by William Gosset while working for the Guinness Brewery
    in Dublin. Due to proprietary issues, he had to publish under
    a pseudonym, and so he used the name Student.

    References
    ----------
    .. [1] Dalgaard, Peter, "Introductory Statistics With R",
           Springer, 2002.
    .. [2] Wikipedia, "Student's t-distribution"
           https://en.wikipedia.org/wiki/Student's_t-distribution

    Examples
    --------
    From Dalgaard page 83 [1]_, suppose the daily energy intake for 11
    women in kilojoules (kJ) is:

    >>> intake = np.array([5260., 5470, 5640, 6180, 6390, 6515, 6805, 7515, \
    ...                    7515, 8230, 8770])

    Does their energy intake deviate systematically from the recommended
    value of 7725 kJ? Our null hypothesis will be the absence of deviation,
    and the alternate hypothesis will be the presence of an effect that could be
    either positive or negative, hence making our test 2-tailed.

    Because we are estimating the mean and we have N=11 values in our sample,
    we have N-1=10 degrees of freedom. We set our significance level to 95% and
    compute the t statistic using the empirical mean and empirical standard
    deviation of our intake. We use a ddof of 1 to base the computation of our
    empirical standard deviation on an unbiased estimate of the variance (note:
    the final estimate is not unbiased due to the concave nature of the square
    root).

    >>> np.mean(intake)
    6753.636363636364
    >>> intake.std(ddof=1)
    1142.1232221373727
    >>> t = (np.mean(intake)-7725)/(intake.std(ddof=1)/np.sqrt(len(intake)))
    >>> t
    -2.8207540608310198

    We draw 1000000 samples from Student's t distribution with the adequate
    degrees of freedom.

    >>> import matplotlib.pyplot as plt
    >>> s = bm.random.standard_t(10, size=1000000)
    >>> h = plt.hist(s, bins=100, density=True)

    Does our t statistic land in one of the two critical regions found at
    both tails of the distribution?

    >>> np.sum(np.abs(t) < np.abs(s)) / float(len(s))
    0.018318  #random < 0.05, statistic is in critical region

    The probability value for this 2-tailed test is about 1.83%, which is
    lower than the 5% pre-determined significance threshold.

    Therefore, the probability of observing values as extreme as our intake
    conditionally on the null hypothesis being true is too low, and we reject
    the null hypothesis of no deviation.
    """
    samples = DEFAULT.standard_t(df, size, key=key)
    return samples
[docs]
def truncated_normal(
    lower,
    upper,
    size: Optional[Union[int, Sequence[int]]] = None,
    loc=0.,
    scale=1.,
    dtype=float,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Sample truncated normal random values with given shape and dtype.

    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Notes
    -----
    This distribution is the normal distribution centered on ``loc`` (default
    0), with standard deviation ``scale`` (default 1), and clipped at ``a``,
    ``b`` standard deviations to the left, right (respectively) from ``loc``.
    If ``myclip_a`` and ``myclip_b`` are clip values in the sample space (as
    opposed to the number of standard deviations) then they can be converted
    to the required form according to::

        a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale

    NOTE(review): the paragraph above speaks of ``a``/``b`` while this
    function takes ``lower``/``upper``; whether the bounds are expressed in
    standard deviations or in sample space should be confirmed against the
    underlying implementation.

    Parameters
    ----------
    lower : float, ndarray
        A float or array of floats representing the lower bound for
        truncation. Must be broadcast-compatible with ``upper``.
    upper : float, ndarray
        A float or array of floats representing the upper bound for
        truncation. Must be broadcast-compatible with ``lower``.
    size : optional, list of int, tuple of int
        A tuple of nonnegative integers specifying the result
        shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
        default (None) produces a result shape by broadcasting ``lower`` and
        ``upper``.
    loc : optional, float, ndarray
        Mean ("centre") of the distribution before truncating. Note that
        the mean of the truncated distribution will not be exactly equal
        to ``loc``. Default is 0.
    scale : float, ndarray
        Standard deviation (spread or "width") of the distribution. Must be
        non-negative. Default is 1.
    dtype : optional
        The float dtype for the returned values (default float64 if
        jax_enable_x64 is true, otherwise float32).
    key : jax.Array
        The key for random generator. Consistent with the jax's random
        paradigm.

    Returns
    -------
    out : Array
        A random array with the specified dtype and shape given by ``size`` if
        ``size`` is not None, or else by broadcasting ``lower`` and ``upper``.
        Returns values in the open interval ``(lower, upper)``.
    """
    samples = DEFAULT.truncated_normal(lower, upper, size, loc, scale, dtype=dtype, key=key)
    return samples


# Share the module-level docstring with the method of the same name.
RandomState.truncated_normal.__doc__ = truncated_normal.__doc__
[docs]
def bernoulli(
    p=0.5,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Sample Bernoulli random values with given shape and mean.

    Parameters
    ----------
    p : float, array_like, optional
        A float or array of floats for the mean of the random
        variables. Must be broadcast-compatible with ``size`` and the values
        should be within [0, 1]. Default 0.5.
    size : optional, tuple of int, int
        A tuple of nonnegative integers representing the result
        shape. Must be broadcast-compatible with ``p.shape``. The default (None)
        produces a result shape equal to ``p.shape``.
    key : int or jax.Array, optional
        Key for the random generator, forwarded as the ``key`` keyword of
        ``DEFAULT.bernoulli``.

    Returns
    -------
    out : array_like
        A random array with boolean dtype and shape given by ``size`` if
        ``size`` is not None, or else ``p.shape``.
    """
    samples = DEFAULT.bernoulli(p, size, key=key)
    return samples
[docs]
def lognormal(
    mean=None,
    sigma=None,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a log-normal distribution.

    The call is forwarded to ``DEFAULT.lognormal``. Note that ``mean`` and
    ``sigma`` are not the mean and standard deviation of the log-normal
    samples themselves; they describe the underlying normal distribution
    the log-normal one is derived from.

    Parameters
    ----------
    mean : float or array_like of floats, optional
        Mean value of the underlying normal distribution, nominally 0.
        The signature default ``None`` is forwarded unchanged (presumably
        treated as 0 by ``DEFAULT.lognormal`` -- TODO confirm).
    sigma : float or array_like of floats, optional
        Standard deviation of the underlying normal distribution,
        nominally 1. Must be non-negative. The signature default ``None``
        is forwarded unchanged (presumably treated as 1 by
        ``DEFAULT.lognormal`` -- TODO confirm).
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``mean`` and ``sigma`` are both scalars; otherwise
        ``np.broadcast(mean, sigma).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.lognormal``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized log-normal distribution.

    See Also
    --------
    scipy.stats.lognorm : probability density function, distribution,
        cumulative density function, etc.

    Notes
    -----
    A variable `x` is log-normally distributed when `log(x)` is normally
    distributed. The probability density function of the log-normal
    distribution is:

    .. math:: p(x) = \frac{1}{\sigma x \sqrt{2\pi}}
                     e^{(-\frac{(ln(x)-\mu)^2}{2\sigma^2})}

    where :math:`\mu` is the mean and :math:`\sigma` is the standard
    deviation of the normally distributed logarithm of the variable.

    Just as a normal distribution arises from the *sum* of a large number
    of independent, identically-distributed variables, a log-normal
    distribution arises from their *product*.

    References
    ----------
    .. [1] Limpert, E., Stahel, W. A., and Abbt, M., "Log-normal
           Distributions across the Sciences: Keys and Clues,"
           BioScience, Vol. 51, No. 5, May, 2001.
           https://stat.ethz.ch/~stahel/lognormal/bioscience.pdf
    .. [2] Reiss, R.D. and Thomas, M., "Statistical Analysis of Extreme
           Values," Basel: Birkhauser Verlag, 2001, pp. 31-32.

    Examples
    --------
    Draw samples from the distribution:

    >>> mu, sigma = 3., 1. # mean and standard deviation
    >>> s = bm.random.lognormal(mu, sigma, 1000)

    Display the histogram of the samples, along with
    the probability density function:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, 100, density=True, align='mid')
    >>> x = np.linspace(min(bins), max(bins), 10000)
    >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2))
    ...        / (x * sigma * np.sqrt(2 * np.pi)))
    >>> plt.plot(x, pdf, linewidth=2, color='r')
    >>> plt.axis('tight')
    >>> plt.show()

    Demonstrate that products of random samples drawn from a (shifted)
    normal distribution can be fit well by a log-normal probability
    density function.

    >>> # Generate a thousand samples: each is the product of 100 random
    >>> # values, drawn from a normal distribution.
    >>> b = []
    >>> for i in range(1000):
    ...    a = 10. + bm.random.standard_normal(100)
    ...    b.append(np.prod(a))
    >>> b = np.array(b) / np.min(b) # scale values to be positive
    >>> count, bins, ignored = plt.hist(b, 100, density=True, align='mid')
    >>> sigma = np.std(np.log(b))
    >>> mu = np.mean(np.log(b))
    >>> x = np.linspace(min(bins), max(bins), 10000)
    >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2))
    ...        / (x * sigma * np.sqrt(2 * np.pi)))
    >>> plt.plot(x, pdf, color='r', linewidth=2)
    >>> plt.show()
    """
    return DEFAULT.lognormal(mean, sigma, size, key=key)
[docs]
def binomial(
    n,
    p,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a binomial distribution.

    The call is forwarded to ``DEFAULT.binomial``. The distribution is
    parameterized by ``n`` trials and success probability ``p``, where
    ``n`` is an integer >= 0 and ``p`` lies in the interval [0, 1].
    (``n`` may be given as a float, but it is truncated to an integer
    in use.)

    Parameters
    ----------
    n : int or array_like of ints
        Parameter of the distribution, >= 0. Floats are also accepted,
        but they will be truncated to integers.
    p : float or array_like of floats
        Parameter of the distribution, >= 0 and <= 1.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``n`` and ``p`` are both scalars; otherwise
        ``np.broadcast(n, p).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.binomial``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized binomial distribution, where
        each sample equals the number of successes over the ``n`` trials.

    See Also
    --------
    scipy.stats.binom : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the binomial distribution is

    .. math:: P(N) = \binom{n}{N}p^N(1-p)^{n-N},

    where :math:`n` is the number of trials, :math:`p` is the probability
    of success, and :math:`N` is the number of successes.

    When estimating the standard error of a proportion in a population by
    using a random sample, the normal distribution works well unless the
    product p*n <= 5, where p = population proportion estimate, and n =
    number of samples, in which case the binomial distribution is used
    instead. For example, a sample of 15 people shows 4 who are left
    handed, and 11 who are right handed. Then p = 4/15 = 27%.
    0.27*15 = 4, so the binomial distribution should be used in this case.

    References
    ----------
    .. [1] Dalgaard, Peter, "Introductory Statistics with R",
           Springer-Verlag, 2002.
    .. [2] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill,
           Fifth Edition, 2002.
    .. [3] Lentner, Marvin, "Elementary Applied Statistics", Bogden
           and Quigley, 1972.
    .. [4] Weisstein, Eric W. "Binomial Distribution." From MathWorld--A
           Wolfram Web Resource.
           http://mathworld.wolfram.com/BinomialDistribution.html
    .. [5] Wikipedia, "Binomial distribution",
           https://en.wikipedia.org/wiki/Binomial_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> n, p = 10, .5  # number of trials, probability of each trial
    >>> s = bm.random.binomial(n, p, 1000)
    # result of flipping a coin 10 times, tested 1000 times.

    A real world example. A company drills 9 wild-cat oil exploration
    wells, each with an estimated probability of success of 0.1. All nine
    wells fail. What is the probability of that happening?

    Let's do 20,000 trials of the model, and count the number that
    generate zero positive results.

    >>> sum(bm.random.binomial(9, 0.1, 20000) == 0)/20000.
    # answer = 0.38885, or 38%.
    """
    return DEFAULT.binomial(n, p, size, key=key)
[docs]
def chisquare(
    df,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a chi-square distribution.

    The call is forwarded to ``DEFAULT.chisquare``. Squaring and summing
    `df` independent random variables, each following a standard normal
    distribution (mean 0, variance 1), yields a chi-square distributed
    value (see Notes). This distribution is often used in hypothesis
    testing.

    Parameters
    ----------
    df : float or array_like of floats
        Number of degrees of freedom, must be > 0.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``df`` is a scalar; otherwise
        ``np.array(df).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.chisquare``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized chi-square distribution.

    Raises
    ------
    ValueError
        Documented for `df` <= 0 or an inappropriate `size` (e.g.
        ``size=-1``). Any such check happens inside
        ``DEFAULT.chisquare`` -- TODO confirm it is actually raised there.

    Notes
    -----
    The variable obtained by summing the squares of `df` independent,
    standard normally distributed random variables:

    .. math:: Q = \sum_{i=1}^{\mathtt{df}} X^2_i

    is chi-square distributed, denoted

    .. math:: Q \sim \chi^2_k,

    with :math:`k` equal to `df`. The probability density function of the
    chi-squared distribution is

    .. math:: p(x) = \frac{(1/2)^{k/2}}{\Gamma(k/2)}
                     x^{k/2 - 1} e^{-x/2},

    where :math:`\Gamma` is the gamma function,

    .. math:: \Gamma(x) = \int_0^{\infty} t^{x - 1} e^{-t} dt.

    References
    ----------
    .. [1] NIST "Engineering Statistics Handbook"
           https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm

    Examples
    --------
    >>> bm.random.chisquare(2,4)
    array([ 1.89920014,  9.00867716,  3.13710533,  5.62318272]) # random
    """
    return DEFAULT.chisquare(df, size, key=key)
[docs]
def dirichlet(
    alpha,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from the Dirichlet distribution.

    The call is forwarded to ``DEFAULT.dirichlet``. `size` samples of
    dimension k are drawn. A Dirichlet-distributed random variable can be
    viewed as a multivariate generalization of a Beta distribution, and
    the Dirichlet distribution is a conjugate prior of a multinomial
    distribution in Bayesian inference.

    Parameters
    ----------
    alpha : sequence of floats, length k
        Parameter of the distribution (length ``k`` for sample of
        length ``k``).
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n)``, ``m * n * k``
        samples are drawn. With the default ``None``, a vector of
        length ``k`` is returned.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.dirichlet``.
        Default is None.

    Returns
    -------
    samples : ndarray
        The drawn samples, of shape ``(size, k)``.

    Raises
    ------
    ValueError
        Documented for any value in ``alpha`` less than or equal to zero.
        Any such check happens inside ``DEFAULT.dirichlet`` -- TODO
        confirm it is actually raised there.

    Notes
    -----
    The Dirichlet distribution is a distribution over vectors
    :math:`x` that fulfil the conditions :math:`x_i>0` and
    :math:`\sum_{i=1}^k x_i = 1`.

    The probability density function :math:`p` of a
    Dirichlet-distributed random vector :math:`X` is
    proportional to

    .. math:: p(x) \propto \prod_{i=1}^{k}{x^{\alpha_i-1}_i},

    where :math:`\alpha` is a vector containing the positive
    concentration parameters.

    The method uses the following property for computation: let :math:`Y`
    be a random vector which has components that follow a standard gamma
    distribution, then :math:`X = \frac{1}{\sum_{i=1}^k{Y_i}} Y`
    is Dirichlet-distributed.

    References
    ----------
    .. [1] David McKay, "Information Theory, Inference and Learning
           Algorithms," chapter 23,
           http://www.inference.org.uk/mackay/itila/
    .. [2] Wikipedia, "Dirichlet distribution",
           https://en.wikipedia.org/wiki/Dirichlet_distribution

    Examples
    --------
    Taking an example cited in Wikipedia, this distribution can be used if
    one wanted to cut strings (each of initial length 1.0) into K pieces
    with different lengths, where each piece had, on average, a designated
    average length, but allowing some variation in the relative sizes of
    the pieces.

    >>> s = bm.random.dirichlet((10, 5, 3), 20).transpose()

    >>> import matplotlib.pyplot as plt
    >>> plt.barh(range(20), s[0])
    >>> plt.barh(range(20), s[1], left=s[0], color='g')
    >>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r')
    >>> plt.title("Lengths of Strings")
    """
    return DEFAULT.dirichlet(alpha, size, key=key)
[docs]
def geometric(
    p,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from the geometric distribution.

    The call is forwarded to ``DEFAULT.geometric``.

    Bernoulli trials are experiments with one of two outcomes: success or
    failure (flipping a coin is an example of such an experiment). The
    geometric distribution models the number of trials that must be run
    in order to achieve success. It is therefore supported on the
    positive integers, ``k = 1, 2, ...``.

    The probability mass function of the geometric distribution is

    .. math:: f(k) = (1 - p)^{k - 1} p

    where `p` is the probability of success of an individual trial.

    Parameters
    ----------
    p : float or array_like of floats
        The probability of success of an individual trial.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``p`` is a scalar; otherwise
        ``np.array(p).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.geometric``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized geometric distribution.

    Examples
    --------
    Draw ten thousand values from the geometric distribution,
    with the probability of an individual success equal to 0.35:

    >>> z = bm.random.geometric(p=0.35, size=10000)

    How many trials succeeded after a single run?

    >>> (z == 1).sum() / 10000.
    0.34889999999999999 #random
    """
    return DEFAULT.geometric(p, size, key=key)
[docs]
def f(
    dfnum,
    dfden,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from an F distribution.

    The call is forwarded to ``DEFAULT.f``. The distribution is
    parameterized by `dfnum` (degrees of freedom in numerator) and
    `dfden` (degrees of freedom in denominator); both must be greater
    than zero.

    The random variate of the F distribution (also known as the
    Fisher distribution) is a continuous probability distribution
    that arises in ANOVA tests, and is the ratio of two chi-square
    variates.

    Parameters
    ----------
    dfnum : float or array_like of floats
        Degrees of freedom in numerator, must be > 0.
    dfden : float or array_like of float
        Degrees of freedom in denominator, must be > 0.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``dfnum`` and ``dfden`` are both scalars; otherwise
        ``np.broadcast(dfnum, dfden).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.f``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Fisher distribution.

    See Also
    --------
    scipy.stats.f : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The F statistic is used to compare in-group variances to between-group
    variances. Calculating the distribution depends on the sampling, and
    so it is a function of the respective degrees of freedom in the
    problem. The variable `dfnum` is the number of samples minus one, the
    between-groups degrees of freedom, while `dfden` is the within-groups
    degrees of freedom, the sum of the number of samples in each group
    minus the number of groups.

    References
    ----------
    .. [1] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill,
           Fifth Edition, 2002.
    .. [2] Wikipedia, "F-distribution",
           https://en.wikipedia.org/wiki/F-distribution

    Examples
    --------
    An example from Glantz[1], pp 47-40:

    Two groups, children of diabetics (25 people) and children from people
    without diabetes (25 controls). Fasting blood glucose was measured,
    case group had a mean value of 86.1, controls had a mean value of
    82.2. Standard deviations were 2.09 and 2.49 respectively. Are these
    data consistent with the null hypothesis that the parents diabetic
    status does not affect their children's blood glucose levels?
    Calculating the F statistic from the data gives a value of 36.01.

    Draw samples from the distribution:

    >>> dfnum = 1. # between group degrees of freedom
    >>> dfden = 48. # within groups degrees of freedom
    >>> s = bm.random.f(dfnum, dfden, 1000)

    The lower bound for the top 1% of the samples is :

    >>> np.sort(s)[-10]
    7.61988120985 # random

    So there is about a 1% chance that the F statistic will exceed 7.62,
    the measured value is 36, so the null hypothesis is rejected at the 1%
    level.
    """
    return DEFAULT.f(dfnum, dfden, size, key=key)
[docs]
def hypergeometric(
    ngood,
    nbad,
    nsample,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a Hypergeometric distribution.

    The call is forwarded to ``DEFAULT.hypergeometric``. The distribution
    is parameterized by `ngood` (ways to make a good selection), `nbad`
    (ways to make a bad selection), and `nsample` (number of items
    sampled, which is less than or equal to the sum ``ngood + nbad``).

    Parameters
    ----------
    ngood : int or array_like of ints
        Number of ways to make a good selection. Must be nonnegative.
    nbad : int or array_like of ints
        Number of ways to make a bad selection. Must be nonnegative.
    nsample : int or array_like of ints
        Number of items sampled. Must be at least 1 and at most
        ``ngood + nbad``.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when `ngood`, `nbad`, and `nsample` are all scalars;
        otherwise ``np.broadcast(ngood, nbad, nsample).size`` samples
        are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to
        ``DEFAULT.hypergeometric``. Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized hypergeometric distribution.
        Each sample is the number of good items within a randomly selected
        subset of size `nsample` taken from a set of `ngood` good items
        and `nbad` bad items.

    See Also
    --------
    scipy.stats.hypergeom : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the Hypergeometric distribution is

    .. math:: P(x) = \frac{\binom{g}{x}\binom{b}{n-x}}{\binom{g+b}{n}},

    where :math:`0 \le x \le n` and :math:`n-b \le x \le g`

    for P(x) the probability of ``x`` good results in the drawn sample,
    g = `ngood`, b = `nbad`, and n = `nsample`.

    Consider an urn with black and white marbles in it, `ngood` of them
    are black and `nbad` are white. If you draw `nsample` balls without
    replacement, then the hypergeometric distribution describes the
    distribution of black balls in the drawn sample.

    Note that this distribution is very similar to the binomial
    distribution, except that in this case, samples are drawn without
    replacement, whereas in the Binomial case samples are drawn with
    replacement (or the sample space is infinite). As the sample space
    becomes large, this distribution approaches the binomial.

    References
    ----------
    .. [1] Lentner, Marvin, "Elementary Applied Statistics", Bogden
           and Quigley, 1972.
    .. [2] Weisstein, Eric W. "Hypergeometric Distribution." From
           MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/HypergeometricDistribution.html
    .. [3] Wikipedia, "Hypergeometric distribution",
           https://en.wikipedia.org/wiki/Hypergeometric_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> ngood, nbad, nsamp = 100, 2, 10
    # number of good, number of bad, and number of samples
    >>> s = bm.random.hypergeometric(ngood, nbad, nsamp, 1000)
    >>> from matplotlib.pyplot import hist
    >>> hist(s)
    #   note that it is very unlikely to grab both bad items

    Suppose you have an urn with 15 white and 15 black marbles.
    If you pull 15 marbles at random, how likely is it that
    12 or more of them are one color?

    >>> s = bm.random.hypergeometric(15, 15, 15, 100000)
    >>> sum(s>=12)/100000. + sum(s<=3)/100000.
    #   answer = 0.003 ... pretty unlikely!
    """
    return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key)
[docs]
def logseries(
    p,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a logarithmic series distribution.

    The call is forwarded to ``DEFAULT.logseries``. The distribution is
    parameterized by the shape parameter ``p``, with 0 <= ``p`` < 1.

    Parameters
    ----------
    p : float or array_like of floats
        Shape parameter for the distribution. Must be in the range [0, 1).
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``p`` is a scalar; otherwise
        ``np.array(p).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.logseries``.
        Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized logarithmic series
        distribution.

    See Also
    --------
    scipy.stats.logser : probability density function, distribution or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the Log Series distribution is

    .. math:: P(k) = \frac{-p^k}{k \ln(1-p)},

    where p = probability.

    The log series distribution is frequently used to represent species
    richness and occurrence, first proposed by Fisher, Corbet, and
    Williams in 1943 [2].  It may also be used to model the numbers of
    occupants seen in cars [3].

    References
    ----------
    .. [1] Buzas, Martin A.; Culver, Stephen J.,  Understanding regional
           species diversity through the log series distribution of
           occurrences: BIODIVERSITY RESEARCH Diversity & Distributions,
           Volume 5, Number 5, September 1999 , pp. 187-195(9).
    .. [2] Fisher, R.A,, A.S. Corbet, and C.B. Williams. 1943. The
           relation between the number of species and the number of
           individuals in a random sample of an animal population.
           Journal of Animal Ecology, 12:42-58.
    .. [3] D. J. Hand, F. Daly, D. Lunn, E. Ostrowski, A Handbook of Small
           Data Sets, CRC Press, 1994.
    .. [4] Wikipedia, "Logarithmic distribution",
           https://en.wikipedia.org/wiki/Logarithmic_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> a = .6
    >>> s = bm.random.logseries(a, 10000)
    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s)

    #   plot against distribution

    >>> def logseries(k, p):
    ...     return -p**k/(k*np.log(1-p))
    >>> plt.plot(bins, logseries(bins, a)*count.max()/
    ...          logseries(bins, a).max(), 'r')
    >>> plt.show()
    """
    return DEFAULT.logseries(p, size, key=key)
[docs]
def multinomial(
    n,
    pvals,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a multinomial distribution.

    The call is forwarded to ``DEFAULT.multinomial``.

    The multinomial distribution is a multivariate generalization of the
    binomial distribution.  Take an experiment with one of ``p``
    possible outcomes.  An example of such an experiment is throwing a
    dice, where the outcome can be 1 through 6.  Each sample drawn from
    the distribution represents `n` such experiments.  Its values,
    ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the
    outcome was ``i``.

    Parameters
    ----------
    n : int
        Number of experiments.
    pvals : sequence of floats, length p
        Probabilities of each of the ``p`` different outcomes.  These
        must sum to 1 (however, the last element is always assumed to
        account for the remaining probability, as long as
        ``sum(pvals[:-1]) <= 1)``.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to ``DEFAULT.multinomial``.
        Default is None.

    Returns
    -------
    out : ndarray
        The drawn samples, of shape *size*, if that was provided.  If not,
        the shape is ``(N,)``.

        In other words, each entry ``out[i,j,...,:]`` is an N-dimensional
        value drawn from the distribution.

    Examples
    --------
    Throw a dice 20 times:

    >>> bm.random.multinomial(20, [1/6.]*6, size=1)
    array([[4, 1, 7, 5, 2, 1]]) # random

    It landed 4 times on 1, once on 2, etc.

    Now, throw the dice 20 times, and 20 times again:

    >>> bm.random.multinomial(20, [1/6.]*6, size=2)
    array([[3, 4, 3, 3, 4, 3], # random
           [2, 4, 3, 4, 0, 7]])

    For the first run, we threw 3 times 1, 4 times 2, etc.  For the second,
    we threw 2 times 1, 4 times 2, etc.

    A loaded die is more likely to land on number 6:

    >>> bm.random.multinomial(100, [1/7.]*5 + [2/7.])
    array([11, 16, 14, 17, 16, 26]) # random

    The probability inputs should be normalized. As an implementation
    detail, the value of the last entry is ignored and assumed to take
    up any leftover probability mass, but this should not be relied on.
    A biased coin which has twice as much weight on one side as on the
    other should be sampled like so:

    >>> bm.random.multinomial(100, [1.0 / 3, 2.0 / 3])  # RIGHT
    array([38, 62]) # random

    not like:

    >>> bm.random.multinomial(100, [1.0, 2.0])  # WRONG
    Traceback (most recent call last):
    ValueError: pvals < 0, pvals > 1 or pvals contains NaNs
    """
    return DEFAULT.multinomial(n, pvals, size, key=key)
[docs]
def multivariate_normal(
    mean,
    cov,
    size: Optional[Union[int, Sequence[int]]] = None,
    method: str = 'cholesky',
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw random samples from a multivariate normal distribution.

    The call is forwarded to ``DEFAULT.multivariate_normal``.

    The multivariate normal, multinormal or Gaussian distribution is a
    generalization of the one-dimensional normal distribution to higher
    dimensions.  Such a distribution is specified by its mean and
    covariance matrix.  These parameters are analogous to the mean
    (average or "center") and variance (standard deviation, or "width,"
    squared) of the one-dimensional normal distribution.

    Parameters
    ----------
    mean : 1-D array_like, of length N
        Mean of the N-dimensional distribution.
    cov : 2-D array_like, of shape (N, N)
        Covariance matrix of the distribution. It must be symmetric and
        positive-semidefinite for proper sampling.
    size : int or tuple of ints, optional
        Given a shape of, for example, ``(m,n,k)``, ``m*n*k`` samples are
        generated, and packed in an `m`-by-`n`-by-`k` arrangement.
        Because each sample is `N`-dimensional, the output shape is
        ``(m,n,k,N)``. If no shape is specified, a single (`N`-D) sample
        is returned.
    method : str, optional
        Method name forwarded to ``DEFAULT.multivariate_normal``
        (presumably selects the factorization applied to ``cov`` --
        TODO confirm the accepted values there). Default is
        ``'cholesky'``.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to
        ``DEFAULT.multivariate_normal``. Default is None.

    Returns
    -------
    out : ndarray
        The drawn samples, of shape *size*, if that was provided.  If not,
        the shape is ``(N,)``.

        In other words, each entry ``out[i,j,...,:]`` is an N-dimensional
        value drawn from the distribution.

    Notes
    -----
    The mean is a coordinate in N-dimensional space, which represents the
    location where samples are most likely to be generated.  This is
    analogous to the peak of the bell curve for the one-dimensional or
    univariate normal distribution.

    Covariance indicates the level to which two variables vary together.
    From the multivariate normal distribution, we draw N-dimensional
    samples, :math:`X = [x_1, x_2, ... x_N]`.  The covariance matrix
    element :math:`C_{ij}` is the covariance of :math:`x_i` and
    :math:`x_j`. The element :math:`C_{ii}` is the variance of
    :math:`x_i` (i.e. its "spread").

    Instead of specifying the full covariance matrix, popular
    approximations include:

      - Spherical covariance (`cov` is a multiple of the identity matrix)
      - Diagonal covariance (`cov` has non-negative elements, and only on
        the diagonal)

    This geometrical property can be seen in two dimensions by plotting
    generated data-points:

    >>> mean = [0, 0]
    >>> cov = [[1, 0], [0, 100]]  # diagonal covariance

    Diagonal covariance means that points are oriented along x or y-axis:

    >>> import matplotlib.pyplot as plt
    >>> x, y = bm.random.multivariate_normal(mean, cov, 5000).T
    >>> plt.plot(x, y, 'x')
    >>> plt.axis('equal')
    >>> plt.show()

    Note that the covariance matrix must be positive semidefinite (a.k.a.
    nonnegative-definite). Otherwise, the behavior of this method is
    undefined and backwards compatibility is not guaranteed.

    References
    ----------
    .. [1] Papoulis, A., "Probability, Random Variables, and Stochastic
           Processes," 3rd ed., New York: McGraw-Hill, 1991.
    .. [2] Duda, R. O., Hart, P. E., and Stork, D. G., "Pattern
           Classification," 2nd ed., New York: Wiley, 2001.

    Examples
    --------
    >>> mean = (1, 2)
    >>> cov = [[1, 0], [0, 1]]
    >>> x = bm.random.multivariate_normal(mean, cov, (3, 3))
    >>> x.shape
    (3, 3, 2)

    Here we generate 800 samples from the bivariate normal distribution
    with mean [0, 0] and covariance matrix [[6, -3], [-3, 3.5]].  The
    expected variances of the first and second components of the sample
    are 6 and 3.5, respectively, and the expected correlation
    coefficient is -3/sqrt(6*3.5) ≈ -0.65465.

    >>> cov = np.array([[6, -3], [-3, 3.5]])
    >>> pts = bm.random.multivariate_normal([0, 0], cov, size=800)

    Check that the mean, covariance, and correlation coefficient of the
    sample are close to the expected values:

    >>> pts.mean(axis=0)
    array([ 0.0326911 , -0.01280782])  # may vary
    >>> np.cov(pts.T)
    array([[ 5.96202397, -2.85602287],
           [-2.85602287,  3.47613949]])  # may vary
    >>> np.corrcoef(pts.T)[0, 1]
    -0.6273591314603949  # may vary

    We can visualize this data with a scatter plot.  The orientation
    of the point cloud illustrates the negative correlation of the
    components of this sample.

    >>> import matplotlib.pyplot as plt
    >>> plt.plot(pts[:, 0], pts[:, 1], '.', alpha=0.5)
    >>> plt.axis('equal')
    >>> plt.grid()
    >>> plt.show()
    """
    return DEFAULT.multivariate_normal(mean, cov, size, method, key=key)
[docs]
def negative_binomial(
    n,
    p,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a negative binomial distribution.

    The call is forwarded to ``DEFAULT.negative_binomial``. The
    distribution is parameterized by `n` successes and `p` probability of
    success, where `n` is > 0 and `p` is in the interval [0, 1].

    Parameters
    ----------
    n : float or array_like of floats
        Parameter of the distribution, > 0.
    p : float or array_like of floats
        Parameter of the distribution, >= 0 and <= 1.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``n`` and ``p`` are both scalars; otherwise
        ``np.broadcast(n, p).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to
        ``DEFAULT.negative_binomial``. Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized negative binomial
        distribution, where each sample is equal to N, the number of
        failures that occurred before a total of n successes was reached.

    Notes
    -----
    The probability mass function of the negative binomial distribution is

    .. math:: P(N;n,p) = \frac{\Gamma(N+n)}{N!\Gamma(n)}p^{n}(1-p)^{N},

    where :math:`n` is the number of successes, :math:`p` is the
    probability of success, :math:`N+n` is the number of trials, and
    :math:`\Gamma` is the gamma function. When :math:`n` is an integer,
    :math:`\frac{\Gamma(N+n)}{N!\Gamma(n)} = \binom{N+n-1}{N}`, which is
    the more common form of this term in the pmf. The negative
    binomial distribution gives the probability of N failures given n
    successes, with a success on the last trial.

    If one throws a die repeatedly until the third time a "1" appears,
    then the probability distribution of the number of non-"1"s that
    appear before the third "1" is a negative binomial distribution.

    References
    ----------
    .. [1] Weisstein, Eric W. "Negative Binomial Distribution." From
           MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/NegativeBinomialDistribution.html
    .. [2] Wikipedia, "Negative binomial distribution",
           https://en.wikipedia.org/wiki/Negative_binomial_distribution

    Examples
    --------
    Draw samples from the distribution:

    A real world example. A company drills wild-cat oil
    exploration wells, each with an estimated probability of
    success of 0.1.  What is the probability of having one success
    for each successive well, that is what is the probability of a
    single success after drilling 5 wells, after 6 wells, etc.?

    >>> s = bm.random.negative_binomial(1, 0.1, 100000)
    >>> for i in range(1, 11): # doctest: +SKIP
    ...    probability = sum(s<i) / 100000.
    ...    print(i, "wells drilled, probability of one success =", probability)
    """
    return DEFAULT.negative_binomial(n, p, size, key=key)
[docs]
def noncentral_chisquare(
    df,
    nonc,
    size: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[Union[int, JAX_RAND_KEY]] = None,
):
    r"""Draw samples from a noncentral chi-square distribution.

    The call is forwarded to ``DEFAULT.noncentral_chisquare``. The
    noncentral :math:`\chi^2` distribution is a generalization of the
    :math:`\chi^2` distribution.

    Parameters
    ----------
    df : float or array_like of floats
        Degrees of freedom, must be > 0.
    nonc : float or array_like of floats
        Non-centrality, must be non-negative.
    size : int or tuple of ints, optional
        Output shape. For a shape such as ``(m, n, k)``, ``m * n * k``
        samples are drawn. With the default ``None``, a single value is
        returned when ``df`` and ``nonc`` are both scalars; otherwise
        ``np.broadcast(df, nonc).size`` samples are drawn.
    key : int or jax.Array, optional
        Key for the random generator, in line with JAX's explicit-key
        random paradigm. Forwarded unchanged to
        ``DEFAULT.noncentral_chisquare``. Default is None.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized noncentral chi-square
        distribution.

    Notes
    -----
    The probability density function for the noncentral Chi-square
    distribution is

    .. math:: P(x;df,nonc) = \sum^{\infty}_{i=0}
                           \frac{e^{-nonc/2}(nonc/2)^{i}}{i!}
                           P_{Y_{df+2i}}(x),

    where :math:`Y_{q}` is the Chi-square with q degrees of freedom.

    References
    ----------
    .. [1] Wikipedia, "Noncentral chi-squared distribution"
           https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution

    Examples
    --------
    Draw values from the distribution and plot the histogram

    >>> import matplotlib.pyplot as plt
    >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000),
    ...                   bins=200, density=True)
    >>> plt.show()

    Draw values from a noncentral chisquare with very small noncentrality,
    and compare to a chisquare.

    >>> plt.figure()
    >>> values = plt.hist(bm.random.noncentral_chisquare(3, .0000001, 100000),
    ...                   bins=np.arange(0., 25, .1), density=True)
    >>> values2 = plt.hist(bm.random.chisquare(3, 100000),
    ...                    bins=np.arange(0., 25, .1), density=True)
    >>> plt.plot(values[1][0:-1], values[0]-values2[0], 'ob')
    >>> plt.show()

    Demonstrate how large values of non-centrality lead to a more symmetric
    distribution.

    >>> plt.figure()
    >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000),
    ...                   bins=200, density=True)
    >>> plt.show()
    """
    return DEFAULT.noncentral_chisquare(df, nonc, size, key=key)
[docs]
def noncentral_f(dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
                 key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from the noncentral F distribution.

    The distribution is described by the numerator degrees of freedom
    ``dfnum``, the denominator degrees of freedom ``dfden`` and the
    non-centrality parameter ``nonc``. Every argument is passed on to
    ``DEFAULT.noncentral_f``.

    Parameters
    ----------
    dfnum : float or array_like of floats
        Degrees of freedom of the numerator, must be > 0.
    dfden : float or array_like of floats
        Degrees of freedom of the denominator, must be > 0.
    nonc : float or array_like of floats
        Non-centrality parameter, i.e. the sum of the squares of the
        numerator means, must be >= 0.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``dfnum``, ``dfden`` and ``nonc`` are all scalars;
        otherwise ``np.broadcast(dfnum, dfden, nonc).size`` samples are
        drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized noncentral Fisher distribution.

    Notes
    -----
    The non-central F statistic matters when computing the power of an
    experiment (power = probability of rejecting the null hypothesis when
    a specific alternative is true). Under a true null hypothesis the
    F statistic follows a central F distribution; when the null hypothesis
    is not true, it follows a non-central F statistic.

    References
    ----------
    .. [1] Weisstein, Eric W. "Noncentral F-Distribution."
           From MathWorld--A Wolfram Web Resource.
           http://mathworld.wolfram.com/NoncentralF-Distribution.html
    .. [2] Wikipedia, "Noncentral F-distribution",
           https://en.wikipedia.org/wiki/Noncentral_F-distribution

    Examples
    --------
    In a study, testing for a specific alternative to the null hypothesis
    requires use of the Noncentral F distribution. We need to calculate the
    area in the tail of the distribution that exceeds the value of the F
    distribution for the null hypothesis. We'll plot the two probability
    distributions for comparison.

    >>> dfnum = 3 # between group deg of freedom
    >>> dfden = 20 # within groups degrees of freedom
    >>> nonc = 3.0
    >>> nc_vals = bm.random.noncentral_f(dfnum, dfden, nonc, 1000000)
    >>> NF = np.histogram(nc_vals, bins=50, density=True)
    >>> c_vals = bm.random.f(dfnum, dfden, 1000000)
    >>> F = np.histogram(c_vals, bins=50, density=True)
    >>> import matplotlib.pyplot as plt
    >>> plt.plot(F[1][1:], F[0])
    >>> plt.plot(NF[1][1:], NF[0])
    >>> plt.show()
    """
    samples = DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key)
    return samples
[docs]
def power(a,
          size: Optional[Union[int, Sequence[int]]] = None,
          key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample values in [0, 1] from a power distribution with positive
    exponent ``a - 1``.

    The distribution is also called the power function distribution.
    The call is delegated to ``DEFAULT.power``.

    Parameters
    ----------
    a : float or array_like of floats
        Parameter of the distribution. Must be non-negative.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``a`` is a scalar; otherwise
        ``np.array(a).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized power distribution.

    Raises
    ------
    ValueError
        If a <= 0. NOTE(review): this wrapper performs no check itself;
        confirm that ``DEFAULT.power`` actually raises.

    Notes
    -----
    The probability density function is

    .. math:: P(x; a) = ax^{a-1}, 0 \le x \le 1, a>0.

    The power function distribution is just the inverse of the Pareto
    distribution. It can also be viewed as a special case of the Beta
    distribution.

    One application is modeling the over-reporting of insurance claims.

    References
    ----------
    .. [1] Christian Kleiber, Samuel Kotz, "Statistical size distributions
           in economics and actuarial sciences", Wiley, 2003.
    .. [2] Heckert, N. A. and Filliben, James J. "NIST Handbook 148:
           Dataplot Reference Manual, Volume 2: Let Subcommands and Library
           Functions", National Institute of Standards and Technology
           Handbook Series, June 2003.
           https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/powpdf.pdf

    Examples
    --------
    Draw samples from the distribution:

    >>> a = 5. # shape
    >>> samples = 1000
    >>> s = bm.random.power(a, samples)

    Display the histogram of the samples, along with
    the probability density function:

    >>> import matplotlib.pyplot as plt
    >>> count, bins, ignored = plt.hist(s, bins=30)
    >>> x = np.linspace(0, 1, 100)
    >>> y = a*x**(a-1.)
    >>> normed_y = samples*np.diff(bins)[0]*y
    >>> plt.plot(x, normed_y)
    >>> plt.show()

    Compare the power function distribution to the inverse of the Pareto.

    >>> from scipy import stats  # doctest: +SKIP
    >>> rvs = bm.random.power(5, 1000000)
    >>> rvsp = bm.random.pareto(5, 1000000)
    >>> xx = np.linspace(0,1,100)
    >>> powpdf = stats.powerlaw.pdf(xx,5)  # doctest: +SKIP

    >>> plt.figure()
    >>> plt.hist(rvs, bins=50, density=True)
    >>> plt.plot(xx,powpdf,'r-')  # doctest: +SKIP
    >>> plt.title('bm.random.power(5)')

    >>> plt.figure()
    >>> plt.hist(1./(1.+rvsp), bins=50, density=True)
    >>> plt.plot(xx,powpdf,'r-')  # doctest: +SKIP
    >>> plt.title('inverse of 1 + bm.random.pareto(5)')

    >>> plt.figure()
    >>> plt.hist(1./(1.+rvsp), bins=50, density=True)
    >>> plt.plot(xx,powpdf,'r-')  # doctest: +SKIP
    >>> plt.title('inverse of stats.pareto(5)')
    """
    samples = DEFAULT.power(a, size, key=key)
    return samples
[docs]
def rayleigh(scale=1.0,
             size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from a Rayleigh distribution.

    The :math:`\chi` and Weibull distributions generalize the Rayleigh
    distribution. The call is delegated to ``DEFAULT.rayleigh``.

    Parameters
    ----------
    scale : float or array_like of floats, optional
        Scale, which also equals the mode. Must be non-negative.
        Default is 1.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``scale`` is a scalar; otherwise
        ``np.array(scale).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Rayleigh distribution.

    Notes
    -----
    The probability density function for the Rayleigh distribution is

    .. math:: P(x;scale) = \frac{x}{scale^2}e^{\frac{-x^2}{2 \cdotp scale^2}}

    A Rayleigh distribution would arise, for example, if the East and
    North components of the wind velocity had identical zero-mean
    Gaussian distributions: the wind speed would then be Rayleigh
    distributed.

    References
    ----------
    .. [1] Brighton Webs Ltd., "Rayleigh Distribution,"
           https://web.archive.org/web/20090514091424/http://brighton-webs.co.uk:80/distributions/rayleigh.asp
    .. [2] Wikipedia, "Rayleigh distribution"
           https://en.wikipedia.org/wiki/Rayleigh_distribution

    Examples
    --------
    Draw values from the distribution and plot the histogram

    >>> from matplotlib.pyplot import hist
    >>> values = hist(bm.random.rayleigh(3, 100000), bins=200, density=True)

    Wave heights tend to follow a Rayleigh distribution. If the mean wave
    height is 1 meter, what fraction of waves are likely to be larger than 3
    meters?

    >>> meanvalue = 1
    >>> modevalue = np.sqrt(2 / np.pi) * meanvalue
    >>> s = bm.random.rayleigh(modevalue, 1000000)

    The percentage of waves larger than 3 meters is:

    >>> 100.*sum(s>3)/1000000.
    0.087300000000000003 # random
    """
    samples = DEFAULT.rayleigh(scale, size, key=key)
    return samples
[docs]
def triangular(size: Optional[Union[int, Sequence[int]]] = None,
               key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Draw samples via ``DEFAULT.triangular``.

    This wrapper accepts only ``size`` and ``key`` and forwards both to
    ``DEFAULT.triangular``. It takes no ``left``, ``mode`` or ``right``
    arguments, so a NumPy-style call such as
    ``triangular(-3, 0, 8, 100000)`` does not match this signature.

    NOTE(review): the interval and mode that ``DEFAULT.triangular``
    samples from are not visible in this wrapper -- confirm them against
    the implementation before relying on the distribution description in
    the Notes section.

    Parameters
    ----------
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. Defaults to ``None``.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        The samples returned by ``DEFAULT.triangular``.

    Notes
    -----
    A triangular distribution is a continuous probability distribution
    with lower limit :math:`l`, peak at the mode :math:`m`, and upper
    limit :math:`r`. Its probability density function is

    .. math:: P(x;l, m, r) = \begin{cases}
              \frac{2(x-l)}{(r-l)(m-l)}& \text{for $l \leq x \leq m$},\\
              \frac{2(r-x)}{(r-l)(r-m)}& \text{for $m \leq x \leq r$},\\
              0& \text{otherwise}.
              \end{cases}

    The triangular distribution is often used in ill-defined
    problems where the underlying distribution is not known, but
    some knowledge of the limits and mode exists. Often it is used
    in simulations.

    References
    ----------
    .. [1] Wikipedia, "Triangular distribution"
           https://en.wikipedia.org/wiki/Triangular_distribution

    Examples
    --------
    Draw values and plot the histogram:

    >>> import matplotlib.pyplot as plt
    >>> h = plt.hist(bm.random.triangular(100000), bins=200,
    ...              density=True)
    >>> plt.show()
    """
    return DEFAULT.triangular(size, key=key)
[docs]
def vonmises(mu,
             kappa,
             size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from a von Mises distribution.

    Samples follow a von Mises distribution with mode ``mu`` and
    dispersion ``kappa`` on the interval [-pi, pi]. The call is delegated
    to ``DEFAULT.vonmises``.

    The von Mises distribution (also called the circular normal
    distribution) is a continuous probability distribution on the unit
    circle and can be regarded as the circular analogue of the normal
    distribution.

    Parameters
    ----------
    mu : float or array_like of floats
        Mode ("center") of the distribution.
    kappa : float or array_like of floats
        Dispersion of the distribution, has to be >=0.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``mu`` and ``kappa`` are both scalars; otherwise
        ``np.broadcast(mu, kappa).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized von Mises distribution.

    See Also
    --------
    scipy.stats.vonmises : probability density function, distribution, or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the von Mises distribution is

    .. math:: p(x) = \frac{e^{\kappa cos(x-\mu)}}{2\pi I_0(\kappa)},

    where :math:`\mu` is the mode and :math:`\kappa` the dispersion,
    and :math:`I_0(\kappa)` is the modified Bessel function of order 0.

    The von Mises is named for Richard Edler von Mises, who was born in
    Austria-Hungary, in what is now Ukraine. He fled to the United
    States in 1939 and became a professor at Harvard. He worked in
    probability theory, aerodynamics, fluid mechanics, and philosophy of
    science.

    References
    ----------
    .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of
           Mathematical Functions with Formulas, Graphs, and Mathematical
           Tables, 9th printing," New York: Dover, 1972.
    .. [2] von Mises, R., "Mathematical Theory of Probability
           and Statistics", New York: Academic Press, 1964.

    Examples
    --------
    Draw samples from the distribution:

    >>> mu, kappa = 0.0, 4.0 # mean and dispersion
    >>> s = bm.random.vonmises(mu, kappa, 1000)

    Display the histogram of the samples, along with
    the probability density function:

    >>> import matplotlib.pyplot as plt
    >>> from scipy.special import i0  # doctest: +SKIP
    >>> plt.hist(s, 50, density=True)
    >>> x = np.linspace(-np.pi, np.pi, num=51)
    >>> y = np.exp(kappa*np.cos(x-mu))/(2*np.pi*i0(kappa))  # doctest: +SKIP
    >>> plt.plot(x, y, linewidth=2, color='r')  # doctest: +SKIP
    >>> plt.show()
    """
    samples = DEFAULT.vonmises(mu, kappa, size, key=key)
    return samples
[docs]
def wald(mean,
         scale,
         size: Optional[Union[int, Sequence[int]]] = None,
         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from a Wald (inverse Gaussian) distribution.

    As the scale approaches infinity, the distribution becomes more like a
    Gaussian. Some references claim that the Wald is an inverse Gaussian
    with mean equal to 1, but this is by no means universal.

    The inverse Gaussian distribution was first studied in relationship to
    Brownian motion. In 1956 M.C.K. Tweedie used the name inverse Gaussian
    because there is an inverse relationship between the time to cover a
    unit distance and distance covered in unit time.

    The call is delegated to ``DEFAULT.wald``.

    Parameters
    ----------
    mean : float or array_like of floats
        Distribution mean, must be > 0.
    scale : float or array_like of floats
        Scale parameter, must be > 0.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``mean`` and ``scale`` are both scalars; otherwise
        ``np.broadcast(mean, scale).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Wald distribution.

    Notes
    -----
    The probability density function for the Wald distribution is

    .. math:: P(x;mean,scale) = \sqrt{\frac{scale}{2\pi x^3}}e^
                                \frac{-scale(x-mean)^2}{2\cdotp mean^2x}

    As noted above, the inverse Gaussian distribution first arose
    from attempts to model Brownian motion. It is also a
    competitor to the Weibull for use in reliability modeling and
    modeling stock returns and interest rate processes.

    References
    ----------
    .. [1] Brighton Webs Ltd., Wald Distribution,
           https://web.archive.org/web/20090423014010/http://www.brighton-webs.co.uk:80/distributions/wald.asp
    .. [2] Chhikara, Raj S., and Folks, J. Leroy, "The Inverse Gaussian
           Distribution: Theory : Methodology, and Applications", CRC Press,
           1988.
    .. [3] Wikipedia, "Inverse Gaussian distribution"
           https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution

    Examples
    --------
    Draw values from the distribution and plot the histogram:

    >>> import matplotlib.pyplot as plt
    >>> h = plt.hist(bm.random.wald(3, 2, 100000), bins=200, density=True)
    >>> plt.show()
    """
    samples = DEFAULT.wald(mean, scale, size, key=key)
    return samples
[docs]
def weibull(a,
            size: Optional[Union[int, Sequence[int]]] = None,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from a Weibull distribution.

    Samples follow a 1-parameter Weibull distribution with shape
    parameter `a`:

    .. math:: X = (-ln(U))^{1/a}

    Here, U is drawn from the uniform distribution over (0,1].

    The more common 2-parameter Weibull, including a scale parameter
    :math:`\lambda` is just :math:`X = \lambda(-ln(U))^{1/a}`.

    The call is delegated to ``DEFAULT.weibull``.

    .. note::
        New code should use the ``weibull`` method of a ``default_rng()``
        instance instead; please see the :ref:`random-quick-start`.

    Parameters
    ----------
    a : float or array_like of floats
        Shape parameter of the distribution. Must be nonnegative.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``a`` is a scalar; otherwise
        ``np.array(a).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Weibull distribution.

    Notes
    -----
    The Weibull (or Type III asymptotic extreme value distribution
    for smallest values, SEV Type III, or Rosin-Rammler
    distribution) is one of a class of Generalized Extreme Value
    (GEV) distributions used in modeling extreme value problems.
    This class includes the Gumbel and Frechet distributions.

    The probability density for the Weibull distribution is

    .. math:: p(x) = \frac{a}
                     {\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a},

    where :math:`a` is the shape and :math:`\lambda` the scale.

    The function has its peak (the mode) at
    :math:`\lambda(\frac{a-1}{a})^{1/a}`.

    When ``a = 1``, the Weibull distribution reduces to the exponential
    distribution.

    References
    ----------
    .. [1] Waloddi Weibull, Royal Technical University, Stockholm,
           1939 "A Statistical Theory Of The Strength Of Materials",
           Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939,
           Generalstabens Litografiska Anstalts Forlag, Stockholm.
    .. [2] Waloddi Weibull, "A Statistical Distribution Function of
           Wide Applicability", Journal Of Applied Mechanics ASME Paper
           1951.
    .. [3] Wikipedia, "Weibull distribution",
           https://en.wikipedia.org/wiki/Weibull_distribution

    Examples
    --------
    Draw samples from the distribution:

    >>> a = 5. # shape
    >>> s = brainpy.math.random.weibull(a, 1000)

    Display the histogram of the samples, along with
    the probability density function:

    >>> import matplotlib.pyplot as plt
    >>> x = np.arange(1,100.)/50.
    >>> def weib(x,n,a):
    ...     return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a)

    >>> count, bins, ignored = plt.hist(brainpy.math.random.weibull(5.,1000))
    >>> x = np.arange(1,100.)/50.
    >>> scale = count.max()/weib(x, 1., 5.).max()
    >>> plt.plot(x, weib(x, 1., 5.)*scale)
    >>> plt.show()
    """
    samples = DEFAULT.weibull(a, size, key=key)
    return samples
[docs]
def weibull_min(a,
                scale=None,
                size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample from a Weibull distribution.

    The scipy counterpart is `scipy.stats.weibull_min`.

    All arguments are forwarded to ``DEFAULT.weibull_min``.

    Args:
      a: First distribution parameter, passed positionally.
        NOTE(review): presumably the concentration (shape) parameter of
        the distribution -- confirm against ``DEFAULT.weibull_min``.
      scale: The scale parameter of the distribution. Defaults to ``None``.
      size: Requested shape of the returned samples, an int or a sequence
        of ints. Defaults to ``None``.
      key: a PRNG key or a seed.

    Returns:
      The samples produced by ``DEFAULT.weibull_min``.
    """
    return DEFAULT.weibull_min(a, scale, size, key=key)
[docs]
def zipf(a,
         size: Optional[Union[int, Sequence[int]]] = None,
         key: Optional[Union[int, JAX_RAND_KEY]] = None):
    r"""
    Sample from a Zipf distribution.

    Samples follow a Zipf distribution with parameter `a` > 1.

    The Zipf distribution (also called the zeta distribution) is a
    discrete probability distribution that satisfies Zipf's law: the
    frequency of an item is inversely proportional to its rank in a
    frequency table.

    The call is delegated to ``DEFAULT.zipf``.

    .. note::
        New code should use the ``zipf`` method of a ``default_rng()``
        instance instead; please see the :ref:`random-quick-start`.

    Parameters
    ----------
    a : float or array_like of floats
        Distribution parameter. Must be greater than 1.
    size : int or tuple of ints, optional
        Shape of the output. A shape such as ``(m, n, k)`` yields
        ``m * n * k`` samples. With ``None`` (the default) a single value
        is returned when ``a`` is a scalar; otherwise
        ``np.array(a).size`` samples are drawn.
    key : int or JAX random key, optional
        Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out : ndarray or scalar
        Samples drawn from the parameterized Zipf distribution.

    See Also
    --------
    scipy.stats.zipf : probability density function, distribution, or
        cumulative density function, etc.

    Notes
    -----
    The probability density for the Zipf distribution is

    .. math:: p(k) = \frac{k^{-a}}{\zeta(a)},

    for integers :math:`k \geq 1`, where :math:`\zeta` is the Riemann Zeta
    function.

    It is named for the American linguist George Kingsley Zipf, who noted
    that the frequency of any word in a sample of a language is inversely
    proportional to its rank in the frequency table.

    References
    ----------
    .. [1] Zipf, G. K., "Selected Studies of the Principle of Relative
           Frequency in Language," Cambridge, MA: Harvard Univ. Press,
           1932.

    Examples
    --------
    Draw samples from the distribution:

    >>> a = 4.0
    >>> n = 20000
    >>> s = brainpy.math.random.zipf(a, n)

    Display the histogram of the samples, along with
    the expected histogram based on the probability
    density function:

    >>> import matplotlib.pyplot as plt
    >>> from scipy.special import zeta  # doctest: +SKIP

    `bincount` provides a fast histogram for small integers.

    >>> count = np.bincount(s)
    >>> k = np.arange(1, s.max() + 1)

    >>> plt.bar(k, count[1:], alpha=0.5, label='sample count')
    >>> plt.plot(k, n*(k**-a)/zeta(a), 'k.-', alpha=0.5,
    ...          label='expected count')  # doctest: +SKIP
    >>> plt.semilogy()
    >>> plt.grid(alpha=0.4)
    >>> plt.legend()
    >>> plt.title(f'Zipf sample, a={a}, size={n}')
    >>> plt.show()
    """
    samples = DEFAULT.zipf(a, size, key=key)
    return samples
[docs]
def maxwell(size: Optional[Union[int, Sequence[int]]] = None,
            key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample from a one sided Maxwell distribution.

    The scipy counterpart is `scipy.stats.maxwell`.

    Both arguments are forwarded to ``DEFAULT.maxwell``.

    Args:
      size: The shape of the returned samples, an int or a sequence of
        ints. Defaults to ``None``.
      key: a PRNG key or a seed.

    Returns:
      The samples produced by ``DEFAULT.maxwell``, of the shape requested
      through ``size``.
    """
    return DEFAULT.maxwell(size, key=key)
[docs]
def t(df,
      size: Optional[Union[int, Sequence[int]]] = None,
      key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw random values from Student's t distribution.

    The call is delegated to ``DEFAULT.t``.

    Parameters
    ----------
    df: float, array_like
      The distribution parameter: a float or an array of floats that is
      broadcast-compatible with the requested shape.
    size: optional, int, tuple of int
      Result shape, given as a tuple of non-negative integers. It must be
      broadcast-compatible with `df`. The default (None) yields a result
      whose shape equals `df.shape`.
    key: optional, int, JAX random key
      Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out: array_like
      The sampled value.
    """
    drawn = DEFAULT.t(df, size, key=key)
    return drawn
[docs]
def orthogonal(n: int,
               size: Optional[Union[int, Sequence[int]]] = None,
               key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw uniform samples from the orthogonal group `O(n)`.

    The call is delegated to ``DEFAULT.orthogonal``.

    Parameters
    ----------
    n: int
      Integer giving the dimension of the result.
    size: optional, int, tuple of int
      Batch dimensions of the result.
    key: optional, int, JAX random key
      Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out: Array
      The sampled results.
    """
    drawn = DEFAULT.orthogonal(n, size, key=key)
    return drawn
[docs]
def loggamma(a,
             size: Optional[Union[int, Sequence[int]]] = None,
             key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw log-gamma distributed random values.

    The call is delegated to ``DEFAULT.loggamma``.

    Parameters
    ----------
    a: float, array_like
      The distribution parameter: a float or an array of floats that is
      broadcast-compatible with the requested shape.
    size: optional, int, tuple of int
      Result shape, given as a tuple of nonnegative integers. It must be
      broadcast-compatible with `a`. The default (None) yields a result
      whose shape equals `a.shape`.
    key: optional, int, JAX random key
      Seed or PRNG key, handed on unchanged through the ``key`` keyword.

    Returns
    -------
    out: array_like
      The sampled results.
    """
    drawn = DEFAULT.loggamma(a, size, key=key)
    return drawn
[docs]
def categorical(logits,
                axis: int = -1,
                size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Sample random values from categorical distributions.

    All arguments are forwarded to ``DEFAULT.categorical``.

    Args:
      logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
        so that `softmax(logits, axis)` gives the corresponding probabilities.
      axis: Axis along which logits belong to the same categorical distribution.
        Defaults to ``-1``.
      size: Optional, a tuple of nonnegative integers representing the result shape.
        Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
        The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
      key: a PRNG key or a seed used as the random key.

    Returns:
      A random array with int dtype and shape given by ``size`` if ``size``
      is not None, or else ``np.delete(logits.shape, axis)``.
    """
    return DEFAULT.categorical(logits, axis, size, key=key)
[docs]
def rand_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Counterpart of ``rand_like`` in torch.

    Produces a tensor of the same size as ``input`` whose entries are
    random numbers from a uniform distribution on the interval ``[0, 1)``.
    The call is delegated to ``DEFAULT.rand_like``.

    Args:
      input: its ``size`` determines the size of the output tensor.
      dtype: desired data type of the returned Tensor. Default: if ``None``,
        the dtype of ``input`` is used.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    filled = DEFAULT.rand_like(input, dtype=dtype, key=key)
    return filled
[docs]
def randn_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Counterpart of ``randn_like`` in torch.

    Produces a tensor of the same size as ``input`` whose entries are
    random numbers from a normal distribution with mean 0 and variance 1.
    The call is delegated to ``DEFAULT.randn_like``.

    Args:
      input: its ``size`` determines the size of the output tensor.
      dtype: desired data type of the returned Tensor. Default: if ``None``,
        the dtype of ``input`` is used.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    filled = DEFAULT.randn_like(input, dtype=dtype, key=key)
    return filled
[docs]
def randint_like(input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Counterpart of ``randint_like`` in torch.

    Produces a tensor with the same shape as Tensor ``input``, filled with
    random integers generated uniformly between ``low`` (inclusive) and
    ``high`` (exclusive). The call is delegated to ``DEFAULT.randint_like``
    with every argument passed by keyword.

    Args:
      input: its ``size`` determines the size of the output tensor.
      low: Lowest integer to be drawn from the distribution. Default: 0.
      high: One above the highest integer to be drawn from the distribution.
      dtype: desired data type of the returned Tensor. Default: if ``None``,
        the dtype of ``input`` is used.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    # Collect the keyword arguments first, then forward them in one call.
    forwarded = dict(input=input, low=low, high=high, dtype=dtype, key=key)
    return DEFAULT.randint_like(**forwarded)
# Copy docstrings from the module-level functions onto the RandomState
# attributes of the same name, but only where the attribute is a public
# callable that has no docstring of its own.
for __k in dir(RandomState):
    __t = getattr(RandomState, __k)
    # Skip dunder names, non-callables, and anything already documented.
    if __k.startswith('__') or not callable(__t) or __t.__doc__:
        continue
    __r = globals().get(__k, None)
    # Only a callable global of the same name can donate its docstring.
    if __r is None or not callable(__r):
        continue
    __t.__doc__ = __r.__doc__