brainpy.math.random.truncated_normal
- brainpy.math.random.truncated_normal(lower, upper, size=None, loc=0.0, scale=1.0, dtype=<class 'float'>, key=None)
Sample truncated standard normal random values with given shape and dtype.
Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
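As an illustration of the general idea, a truncated normal can be sampled by inverse-CDF sampling: draw a uniform value between the normal CDF values of the two bounds, then map it back through the inverse CDF. The sketch below is not brainpy's implementation. It uses only the Python standard library, the helper name sample_truncated_normal is hypothetical, and it assumes the bounds are given in the sample space.

```python
import random
from statistics import NormalDist


def sample_truncated_normal(lower, upper, loc=0.0, scale=1.0, rng=None):
    """Draw one sample from a normal(loc, scale) restricted to (lower, upper).

    Hypothetical helper for illustration only; bounds are assumed to be
    expressed in the sample space, not in standard deviations.
    """
    rng = rng or random.Random()
    dist = NormalDist(mu=loc, sigma=scale)
    # CDF values of the two truncation bounds.
    cdf_lo = dist.cdf(lower)
    cdf_hi = dist.cdf(upper)
    # Uniform draw restricted to the probability mass between the bounds.
    p = cdf_lo + rng.random() * (cdf_hi - cdf_lo)
    # Map the probability back to a value through the inverse CDF.
    return dist.inv_cdf(p)


rng = random.Random(0)
samples = [sample_truncated_normal(-1.0, 2.0, rng=rng) for _ in range(1000)]
print(min(samples), max(samples))
```

This simple version can fail when a bound lies so far in the tail that its CDF value rounds to 0 or 1; a production implementation would need to guard against that.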
Notes
This distribution is the normal distribution centered on loc (default 0), with standard deviation scale (default 1), and clipped at a, b standard deviations to the left, right (respectively) from loc. If myclip_a and myclip_b are clip values in the sample space (as opposed to the number of standard deviations), then they can be converted to the required form according to:
a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
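The conversion from sample-space clip values to standard-deviation units described in this note can be sketched as a small helper. The function name to_standard_bounds is hypothetical and the check on scale is an added assumption; only the formula itself comes from the note.

```python
def to_standard_bounds(myclip_a, myclip_b, loc=0.0, scale=1.0):
    """Convert sample-space clip values to standard deviations from loc."""
    if scale <= 0:
        # Assumption: a zero or negative scale makes the division meaningless.
        raise ValueError("scale must be positive for the conversion to be defined")
    a = (myclip_a - loc) / scale
    b = (myclip_b - loc) / scale
    return a, b


# Clip values 1.0 and 5.0 around loc=3.0 with scale=2.0 lie one standard
# deviation to either side of the mean.
a, b = to_standard_bounds(1.0, 5.0, loc=3.0, scale=2.0)
print(a, b)  # -1.0 1.0
```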
- Parameters:
lower (float, ndarray) – A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
upper (float, ndarray) – A float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
size (optional, list of int, tuple of int) – A tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
loc (optional, float, ndarray) – A float or array of floats representing the mean (“centre”) of the distribution before truncating. Default is 0. Note that the mean of the truncated distribution will generally not be exactly equal to loc.
scale (float, ndarray) – Standard deviation (spread or “width”) of the distribution. Must be non-negative. Default is 1.
dtype (optional) – The float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
key (jax.Array) – The key for the random number generator, consistent with JAX’s random paradigm.
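To make the "broadcast-compatible" requirement on lower, upper, and size concrete, the following sketch computes the shape that results from broadcasting two shapes under NumPy-style rules. The helper broadcast_shape is hypothetical and written in plain Python for illustration; it is not part of brainpy.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the shape obtained by broadcasting two shapes together."""
    result = []
    # Align the shapes from the trailing dimension; a missing dimension
    # behaves like a dimension of length 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(
                f"shapes {shape_a} and {shape_b} are not broadcast-compatible"
            )
    return tuple(reversed(result))


# A column of 3 lower bounds combined with a row of 4 upper bounds
# broadcasts to a 3-by-4 result.
print(broadcast_shape((3, 1), (4,)))  # (3, 4)
```

Under these rules, a scalar bound (shape ()) is compatible with any other shape, while shapes such as (3,) and (4,) are not compatible with each other.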
- Returns:
out – A random array with the specified dtype and shape given by size if size is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
- Return type: