brainpy.math.random.truncated_normal(lower, upper, size=None, loc=0.0, scale=1.0, dtype=float, key=None)

Sample truncated normal random values with the given shape and dtype.

Notes

This distribution is the normal distribution centered on `loc` (default 0), with standard deviation `scale` (default 1), clipped at `a` and `b` standard deviations to the left and right of `loc`, respectively. If `myclip_a` and `myclip_b` are clip values in the sample space (as opposed to numbers of standard deviations), they can be converted to the required form as follows:

```
a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
```
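The conversion above can be wrapped in a small helper. This is a minimal pure-Python sketch; `to_standardized_bounds` is a hypothetical name used only for illustration and is not part of BrainPy.

```python
def to_standardized_bounds(myclip_a, myclip_b, loc=0.0, scale=1.0):
    """Convert sample-space clip values into bounds measured in standard
    deviations from `loc`, using a = (myclip_a - loc) / scale, etc."""
    if scale <= 0:
        # A zero scale would divide by zero; negative scales are invalid.
        raise ValueError("scale must be positive")
    a = (myclip_a - loc) / scale
    b = (myclip_b - loc) / scale
    return a, b


# Clip N(5, 2**2) to the sample-space interval [3, 9]:
a, b = to_standardized_bounds(3.0, 9.0, loc=5.0, scale=2.0)
print(a, b)  # -1.0 2.0
```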
Parameters:
• lower (float, ndarray) – A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with `upper`.

• upper (float, ndarray) – A float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with `lower`.

• loc (optional, float, ndarray) – Mean (“centre”) of the distribution before truncating. Note that the mean of the truncated distribution will not be exactly equal to `loc`.

• size (optional, list of int, tuple of int) – A tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `lower` and `upper`. The default (None) produces a result shape by broadcasting `lower` and `upper`.

• scale (optional, float, ndarray) – Standard deviation (spread or “width”) of the distribution. Must be non-negative. Default is 1.

• dtype (optional) – The float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

• key (optional, jax.Array) – The PRNG key used for sampling, following JAX's explicit-key random paradigm.
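To see why the mean of the truncated distribution generally differs from `loc`, the closed-form mean of a normal distribution truncated to [lower, upper] can be evaluated directly: loc + scale * (φ(α) - φ(β)) / (Φ(β) - Φ(α)), with α = (lower - loc) / scale and β = (upper - loc) / scale, where φ and Φ are the standard normal pdf and CDF. The sketch below uses only the Python standard library; `truncated_normal_mean` is a hypothetical helper, and it assumes `lower` and `upper` are given in sample space.

```python
from statistics import NormalDist


def truncated_normal_mean(lower, upper, loc=0.0, scale=1.0):
    """Mean of N(loc, scale**2) truncated to [lower, upper] (sample-space bounds)."""
    std = NormalDist()  # standard normal N(0, 1)
    alpha = (lower - loc) / scale
    beta = (upper - loc) / scale
    z = std.cdf(beta) - std.cdf(alpha)  # probability mass kept by the truncation
    return loc + scale * (std.pdf(alpha) - std.pdf(beta)) / z


# Bounds symmetric about loc leave the mean unchanged:
print(truncated_normal_mean(-1.0, 1.0))  # 0.0
# Asymmetric bounds shift it away from loc:
print(truncated_normal_mean(0.0, 2.0))   # roughly 0.723
```

Only bounds placed symmetrically around `loc` preserve the mean; any other choice moves it toward the side with more retained mass.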

Returns:

out – A random array with the specified dtype and with shape given by `size` if `size` is not None, or else by broadcasting `lower` and `upper`. Returns values in the open interval `(lower, upper)`.

Return type:

Array
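For intuition about how values can be confined to the open interval `(lower, upper)`, here is a minimal pure-Python sketch of inverse-CDF sampling, a common technique for truncated normals: draw u uniformly between the CDF values of the two bounds, then map u back through the normal quantile function. It is not BrainPy's implementation; `sample_truncated_normal` is a hypothetical helper, and it assumes `lower` and `upper` are sample-space bounds.

```python
import math
import random
from statistics import NormalDist

_STD = NormalDist()  # standard normal N(0, 1)


def sample_truncated_normal(lower, upper, n, loc=0.0, scale=1.0, rng=None):
    """Draw n samples from N(loc, scale**2) truncated to (lower, upper)."""
    if not lower < upper:
        raise ValueError("need lower < upper")
    if rng is None:
        rng = random.Random()
    # CDF values of the bounds after standardizing them.
    p_lo = _STD.cdf((lower - loc) / scale)
    p_hi = _STD.cdf((upper - loc) / scale)
    out = []
    for _ in range(n):
        u = rng.uniform(p_lo, p_hi)
        # inv_cdf requires 0 < u < 1, so nudge u off the endpoints.
        u = min(max(u, math.nextafter(0.0, 1.0)), math.nextafter(1.0, 0.0))
        x = loc + scale * _STD.inv_cdf(u)
        # Clamp into the open interval (lower, upper) to absorb rounding.
        x = min(max(x, math.nextafter(lower, math.inf)),
                math.nextafter(upper, -math.inf))
        out.append(x)
    return out


samples = sample_truncated_normal(3.0, 9.0, 5, loc=5.0, scale=2.0, rng=random.Random(0))
print(samples)
```

The final clamp mirrors the open-interval guarantee, but this simple version loses precision when both bounds lie far in the same tail, because their CDF values round to the same float.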