for_loop

brainpy.math.for_loop(body_fun, operands, reverse=False, unroll=1, jit=None, progress_bar=False)

For-loop control flow over the leading axis of operands, with support for updating Variable instances inside the loop body.

Added in version 2.1.11.

Changed in version 2.3.0: dyn_vars became an optional argument with a default value. Please change your call from for_loop(fun, dyn_vars, operands) to for_loop(fun, operands, dyn_vars).

The values returned by the body function at each iteration are stacked along a new leading axis and returned as the output of the whole loop.

>>> import brainpy.math as bm
>>> a = bm.Variable(bm.zeros(1))
>>> b = bm.Variable(bm.ones(1))
>>> # first example: one operand, one slice per iteration
>>> def body(x):
...     a.value += x
...     b.value *= x
...     return a.value
>>> a_hist = bm.for_loop(body, operands=bm.arange(1, 5))
>>> a_hist
DeviceArray([[ 1.],
             [ 3.],
             [ 6.],
             [10.]], dtype=float32)
>>> a
Variable([10.], dtype=float32)
>>> b
Variable([24.], dtype=float32)
>>>
>>> # second example: a tuple of operands passes one slice of each
>>> # as a separate argument; a and b keep their values from above
>>> def body(x, y):
...     a.value += x
...     b.value *= y
...     return a.value
>>> a_hist = bm.for_loop(body, operands=(bm.arange(1, 5), bm.arange(2, 6)))
>>> a_hist
DeviceArray([[11.],
             [13.],
             [16.],
             [20.]], dtype=float32)

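The behaviour in the two examples above can be reproduced with a plain-Python reference model. This is a sketch under stated assumptions, not BrainPy's implementation: reference_for_loop is a hypothetical helper, a dict stands in for the two Variables, and JIT compilation, Variable tracking and pytree operands are ignored. It only shows the iteration order, how a tuple of operands is unpacked into separate arguments, and how the per-step return values are collected in order.

```python
# Hypothetical plain-Python model of for_loop (NOT BrainPy's implementation):
# it only reproduces iteration order and output stacking, with no JIT,
# no Variable tracking and no pytree support.
def reference_for_loop(body_fun, operands):
    if isinstance(operands, tuple):
        # Several operands: step i passes the i-th slice of each one
        # as a separate positional argument.
        steps = zip(*operands)
    else:
        # A single operand: step i passes its i-th slice.
        steps = ((x,) for x in operands)
    # Collect every return value in iteration order ("stacking").
    return [body_fun(*args) for args in steps]


# Plain dict standing in for the two Variables a and b.
state = {"a": 0.0, "b": 1.0}


def body(x):
    state["a"] += x
    state["b"] *= x
    return state["a"]


a_hist = reference_for_loop(body, range(1, 5))
print(a_hist, state)  # [1.0, 3.0, 6.0, 10.0] {'a': 10.0, 'b': 24.0}


def body2(x, y):
    state["a"] += x
    state["b"] *= y
    return state["a"]


a_hist2 = reference_for_loop(body2, (range(1, 5), range(2, 6)))
print(a_hist2, state)  # [11.0, 13.0, 16.0, 20.0] {'a': 20.0, 'b': 2880.0}
```

In this model only a tuple is treated as "several operands", which keeps a single list-like operand unambiguous; the real function also documents list operands.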
Parameters:
  • body_fun (Callable) – A Python function to be scanned. At each iteration it receives one slice of operands along the leading axis (or, when operands is a tuple/list, one slice of each element as separate positional arguments) and returns one output, which becomes the corresponding slice of the return value.

  • operands (Any) – The value over which to scan along the leading axis. operands can be an array or any pytree (nested Python tuple/list/dict) of arrays with consistent leading-axis sizes. If body_fun receives multiple arguments, operands should be a tuple/list whose length equals the number of arguments.

  • reverse (bool) – Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse. reverse=True is equivalent to reversing the leading axes of both the operands and the stacked outputs.

  • unroll (int) – Optional positive int specifying how many iterations of the underlying scan primitive are unrolled within a single iteration of the compiled loop.

  • jit (Optional[bool]) –

    Whether to just-in-time compile the function. Set to False to disable JIT compilation.

    Note

    jit=False is implemented via the global jax.disable_jit() context manager. Consequently it has no effect when for_loop is called inside an enclosing trace (e.g. within another jitted/scanned function): JAX is already tracing, so the loop runs as a compiled scan regardless of this flag.

  • progress_bar (Union[bool, ProgressBar, int]) –

    Whether and how to display a progress bar during execution:

    • False (default): no progress bar.

    • True: display a progress bar with default settings.

    • ProgressBar instance: display a progress bar with custom settings.

    • int N: display a progress bar that updates every N iterations (equivalent to ProgressBar(freq=N)).

    For advanced customization, create a brainpy.math.ProgressBar instance:

    >>> import brainpy.math as bm
    >>> # Custom update frequency
    >>> pbar = bm.ProgressBar(freq=10)
    >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
    >>>
    >>> # Custom description
    >>> pbar = bm.ProgressBar(desc="Processing data")
    >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
    >>>
    >>> # Update exactly 20 times during execution
    >>> pbar = bm.ProgressBar(count=20)
    >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
    >>>
    >>> # Integer shorthand (equivalent to ProgressBar(freq=10))
    >>> result = bm.for_loop(body_fun, operands, progress_bar=10)
    

    Added in version 2.4.2.

    Changed in version 2.7.3: Now accepts ProgressBar instances and integers for advanced customization.

  • dyn_vars (Variable, sequence of Variable, dict) –

    The Variable instances used in the body function.

    Deprecated since version 2.4.0: dyn_vars no longer needs to be provided. The function automatically collects the dynamical variables used in the target function.

  • child_objs (optional, dict, sequence of BrainPyObject, BrainPyObject) –

    The child objects used in the target function.

    Added in version 2.3.1.

    Deprecated since version 2.4.0: child_objs no longer needs to be provided. The function automatically collects the child objects used in the target function.
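The reverse parameter can be illustrated with another hypothetical plain-Python model, reference_for_loop (not part of BrainPy), in which a list stands in for an array and a closure over a mutable counter stands in for a Variable. It assumes the scan convention stated for reverse: with reverse=True the steps run from the last slice to the first, but the output of step i is still stored at index i, which is the same as reversing the inputs, running forward, and reversing the outputs.

```python
# Hypothetical plain-Python model (NOT BrainPy's implementation) of how
# `reverse` changes the iteration order of for_loop.
def reference_for_loop(body_fun, xs, reverse=False):
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    ys = [None] * len(xs)
    for i in order:
        # The output of step i is stored at index i, so ys stays aligned with xs.
        ys[i] = body_fun(xs[i])
    return ys


def make_running_sum():
    # A closure with mutable state stands in for a brainpy Variable.
    total = [0]

    def body(x):
        total[0] += x
        return total[0]

    return body


xs = [1, 2, 3, 4]
backward = reference_for_loop(make_running_sum(), xs, reverse=True)
print(backward)  # [10, 9, 7, 4]

# The stated equivalence: reverse the inputs, run forward,
# then reverse the outputs.
forward_on_reversed = reference_for_loop(make_running_sum(), xs[::-1])
assert backward == forward_on_reversed[::-1]
```

Because the body carries state, the order matters: the last slice is processed first, so it sees the smallest running total.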

Returns:

outs – The stacked outputs of body_fun when scanned over the leading axis of the inputs.

Return type:

Any
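When body_fun returns several values, the stacking applies to each value separately. The sketch below assumes the behaviour of the underlying scan primitive, where a returned tuple yields a tuple of stacked outputs; stack_outputs is a hypothetical helper, and Python lists stand in for arrays.

```python
# Hypothetical plain-Python illustration (NOT BrainPy's implementation) of
# "stacked outputs" when body_fun returns a tuple: each element of the tuple
# is stacked separately, giving a tuple of per-step histories.
def stack_outputs(step_outputs):
    first = step_outputs[0]
    if isinstance(first, tuple):
        # Transpose: list of per-step tuples -> tuple of per-value lists.
        return tuple(list(leaf) for leaf in zip(*step_outputs))
    return list(step_outputs)


def body(x):
    # Returns two values per step: the input and its square.
    return x, x * x


outs = stack_outputs([body(x) for x in range(1, 5)])
print(outs)  # ([1, 2, 3, 4], [1, 4, 9, 16])
```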