brainpy.math.random.truncated_normal


brainpy.math.random.truncated_normal(lower, upper, size=None, loc=0.0, scale=1.0, dtype=<class 'float'>, key=None)

Sample random values from a normal distribution truncated to (lower, upper), with the given shape and dtype.

Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
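A common way to sample a truncated normal, and the approach this sketch assumes, is inverse-CDF sampling: draw u uniformly between Φ(α) and Φ(β), where α and β are the standardized bounds, then map u back through the standard normal quantile function. The stdlib-only Python below illustrates that idea; it is not BrainPy's implementation. The function name truncated_normal_inv_cdf and its seed argument are hypothetical, and lower/upper are assumed to be sample-space values.

```python
import random
from statistics import NormalDist


def truncated_normal_inv_cdf(lower, upper, n, loc=0.0, scale=1.0, seed=0):
    """Draw n samples from N(loc, scale**2) truncated to (lower, upper).

    Hypothetical illustrative helper (not BrainPy's code); lower and upper
    are assumed to be given in sample space.
    """
    std = NormalDist()  # standard normal N(0, 1)
    # CDF values of the standardized bounds: the probability range to sample
    cdf_lo = std.cdf((lower - loc) / scale)
    cdf_hi = std.cdf((upper - loc) / scale)
    if not cdf_hi > cdf_lo:
        raise ValueError("truncation interval has (numerically) zero mass")
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        u = rng.uniform(cdf_lo, cdf_hi)
        # inv_cdf requires 0 < u < 1; redraw in the rare edge cases
        while not 0.0 < u < 1.0:
            u = rng.uniform(cdf_lo, cdf_hi)
        # Invert the standard normal CDF, then shift and scale
        samples.append(loc + scale * std.inv_cdf(u))
    return samples


# Five draws from a standard normal truncated to (-1, 2)
print(truncated_normal_inv_cdf(-1.0, 2.0, n=5, seed=0))
```

Plain inverse-CDF sampling loses accuracy when both bounds lie far out in the same tail, because Φ(α) and Φ(β) then round to the same value; the guard above raises an error in that case rather than looping forever.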

Notes

This distribution is the normal distribution centered on loc (default 0) with standard deviation scale (default 1), clipped at a and b standard deviations to the left and right of loc, respectively. If myclip_a and myclip_b are clip values in the sample space (as opposed to numbers of standard deviations), they can be converted to the required form according to:

a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale
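As a worked instance of this conversion, the hypothetical helper to_standardized_bounds below (not part of BrainPy) turns sample-space clip values -1 and 3 for a normal with loc = 1 and scale = 2 into bounds measured in standard deviations from loc.

```python
def to_standardized_bounds(myclip_a, myclip_b, loc=0.0, scale=1.0):
    """Hypothetical helper: convert sample-space clip values into bounds
    expressed in standard deviations from loc."""
    a = (myclip_a - loc) / scale
    b = (myclip_b - loc) / scale
    return a, b


# Clip N(1, 2**2) between -1 and 3 in sample space
a, b = to_standardized_bounds(-1.0, 3.0, loc=1.0, scale=2.0)
print(a, b)  # -1.0 1.0
```

Both clip values sit exactly one standard deviation from loc, so the standardized bounds are symmetric.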

Parameters:
  • lower (float, ndarray) – A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper (float, ndarray) – A float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • size (optional, list of int, tuple of int) – A tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • loc (optional, float, ndarray) – Mean (“centre”) of the distribution before truncation. Default is 0. Note that the mean of the truncated distribution is generally not equal to loc.

  • scale (optional, float, ndarray) – Standard deviation (spread or “width”) of the distribution before truncation. Must be non-negative. Default is 1.

  • dtype (optional) – The float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • key (optional, jax.Array) – The key for the random number generator, following JAX’s explicit-key random paradigm.
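To make the note on loc concrete: the mean of N(loc, scale²) truncated to (lower, upper) has the closed form loc + scale · (φ(α) - φ(β)) / (Φ(β) - Φ(α)), where α = (lower - loc) / scale, β = (upper - loc) / scale, and φ, Φ are the standard normal PDF and CDF. This equals loc only when the truncation is symmetric about loc. The stdlib-only sketch below evaluates the formula; truncated_normal_mean is a hypothetical helper, and lower/upper are assumed to be sample-space values.

```python
from statistics import NormalDist


def truncated_normal_mean(lower, upper, loc=0.0, scale=1.0):
    """Mean of N(loc, scale**2) truncated to (lower, upper).

    Hypothetical helper; lower and upper are assumed to be in sample space.
    """
    std = NormalDist()  # standard normal: pdf is phi, cdf is Phi
    alpha = (lower - loc) / scale
    beta = (upper - loc) / scale
    mass = std.cdf(beta) - std.cdf(alpha)  # probability kept by truncation
    return loc + scale * (std.pdf(alpha) - std.pdf(beta)) / mass


# Symmetric truncation about loc leaves the mean unchanged
print(truncated_normal_mean(-1.0, 1.0))  # 0.0
# One-sided truncation shifts it: the half-normal mean is sqrt(2/pi)
print(round(truncated_normal_mean(0.0, float("inf")), 4))  # 0.7979
```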

Returns:

out – A random array with the specified dtype and with shape given by size if size is not None, or otherwise by broadcasting lower and upper. Values lie in the open interval (lower, upper).

Return type:

Array
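The default output shape described above is the broadcast of the shapes of lower and upper. Assuming NumPy broadcasting rules (which JAX follows), it can be previewed with numpy.broadcast_shapes. The helper default_output_shape below is hypothetical, not part of BrainPy, and, like the text above, it considers only lower and upper.

```python
import numpy as np


def default_output_shape(lower, upper, size=None):
    """Hypothetical helper: infer the result shape as described above."""
    lower_shape = np.shape(lower)
    upper_shape = np.shape(upper)
    if size is None:
        # No explicit size: broadcast the bound shapes against each other
        return np.broadcast_shapes(lower_shape, upper_shape)
    size = tuple(size)
    # An explicit size must be broadcast-compatible with both bounds;
    # np.broadcast_shapes raises ValueError if it is not
    np.broadcast_shapes(size, lower_shape, upper_shape)
    return size


print(default_output_shape(np.zeros((3, 1)), np.ones(4)))  # (3, 4)
print(default_output_shape(0.0, 1.0, size=[2, 5]))  # (2, 5)
```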