scan


class brainpy.math.scan(body_fun, init, operands, reverse=False, unroll=1, remat=False, progress_bar=False)

Scan control flow with Variable support, similar to jax.lax.scan.

New in version 2.4.7.

All values returned by the body function are gathered as the return value of the whole loop.
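To make the carry/output contract concrete, here is a minimal pure-Python sketch of the jax.lax.scan-style semantics this function is described as following. The helper reference_scan and the running_sum body are hypothetical illustrations, not BrainPy's implementation, and plain lists stand in for arrays.

```python
from typing import Any, Callable, List, Sequence, Tuple


def reference_scan(
    body_fun: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Sequence[Any],
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of jax.lax.scan-style semantics.

    The carry is threaded through every step; the per-step outputs are
    gathered in order, which is what "stacking" along the leading axis means.
    """
    carry = init
    ys: List[Any] = []
    for x in xs:  # one slice of the operands per iteration
        carry, y = body_fun(carry, x)
        ys.append(y)
    return carry, ys


def running_sum(carry, x):
    # The new carry is the running total; the output slice is that same total.
    total = carry + x
    return total, total


final, outs = reference_scan(running_sum, 0, [1, 2, 3, 4])
print(final, outs)  # 10 [1, 3, 6, 10]
```

Only the carry survives from one step to the next; everything else the body wants to keep must be emitted as the per-step output.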

Parameters:
  • body_fun (callable) – A Python function to be scanned, called as body_fun(carry, x). It receives the current loop carry and a slice x of operands along its leading axis, and returns a pair (carry, y): the updated carry and a slice y of the return value.

  • init (Any) – The initial loop carry value of type c, which can be a scalar, an array, or any pytree (nested Python tuple/list/dict) thereof. It must have the same structure as the first element of the pair returned by body_fun.

  • operands (Any) – The values to scan over along the leading axis. operands can be an array or any pytree (nested Python tuple/list/dict) thereof, provided all leaves share the same leading-axis size. When operands is a tuple/list or other pytree, each slice passed to body_fun has the same structure, with every leaf indexed along its leading axis.

  • reverse (bool) – Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse. Running in reverse is equivalent to reversing the leading axes of the arrays in both the operands (xs) and the stacked outputs (ys).

  • unroll (int) – Optional positive int specifying how many scan iterations to unroll within a single iteration of the loop in the underlying scan primitive.

  • remat (bool) – If True, make body_fun recompute internal linearization points when differentiated, trading extra computation for lower memory use.

  • progress_bar (bool) –

    Whether to display a progress bar that reports the running progress.

    New in version 2.4.2.
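The equivalence stated for reverse can be checked with a small pure-Python model. The helper reference_scan below is hypothetical and only mirrors the documented semantics (output slice i always pairs with operand slice i). It does not model unroll, remat, or progress_bar, which are expected to affect performance or reporting rather than the computed values.

```python
from typing import Any, Callable, List, Sequence, Tuple


def reference_scan(
    body_fun: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Sequence[Any],
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of scan semantics with a reverse flag."""
    n = len(xs)
    order = range(n - 1, -1, -1) if reverse else range(n)
    carry = init
    ys: List[Any] = [None] * n
    for i in order:
        # ys[i] always corresponds to xs[i], regardless of iteration direction.
        carry, ys[i] = body_fun(carry, xs[i])
    return carry, ys


def running_sum(carry, x):
    total = carry + x
    return total, total


xs = [1, 2, 3, 4]
_, ys_rev = reference_scan(running_sum, 0, xs, reverse=True)

# Equivalent formulation: reverse xs, scan forward, then reverse ys.
_, ys_fwd = reference_scan(running_sum, 0, xs[::-1])
assert ys_rev == ys_fwd[::-1]
print(ys_rev)  # [10, 9, 7, 4]
```

With reverse=True the carry flows from the last slice to the first, so each output here is a suffix sum rather than a prefix sum.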

Returns:

outs – A pair (carry, ys): the final loop carry and the stacked outputs of body_fun when scanned over the leading axis of the inputs.

Return type:

Any
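When body_fun emits a pytree (for example a dict) as its per-step output, jax.lax.scan stacks each leaf separately along a new leading axis instead of collecting a list of dicts. Assuming brainpy.math.scan inherits this behavior from jax.lax.scan, the sketch below models the stacking with plain Python lists standing in for arrays; stack_outputs and the "v"/"spike" keys are hypothetical.

```python
from typing import Any, Dict, List


def stack_outputs(per_step: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical model of leaf-wise stacking for dict-valued outputs.

    Each leaf gains a new leading axis (here: a Python list) whose length
    equals the number of scan iterations.
    """
    if not per_step:
        return {}
    return {key: [step[key] for step in per_step] for key in per_step[0]}


# Outputs a body function might emit on three consecutive steps.
steps = [
    {"v": 0.0, "spike": False},
    {"v": 0.5, "spike": False},
    {"v": 1.0, "spike": True},
]
stacked = stack_outputs(steps)
print(stacked)  # {'v': [0.0, 0.5, 1.0], 'spike': [False, False, True]}
```

In real JAX code each stacked leaf would be an array of shape (T, ...), where T is the leading-axis size of the operands.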