scan
- brainpy.math.scan(body_fun, init, operands, reverse=False, unroll=1, remat=False, progress_bar=False)
scan control flow with Variable support. Similar to jax.lax.scan.
New in version 2.4.7.
All values returned by the body function are gathered as the return value of the whole loop.
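For intuition, the loop contract can be modeled in plain Python after the documented semantics of jax.lax.scan, which this function is described as similar to. This is a hypothetical reference sketch, not BrainPy's implementation: the helper name scan_reference is made up, it assumes body_fun(carry, x) returns a (new_carry, y) pair as the description of init implies, it collects outputs in a Python list instead of stacking arrays, and it ignores Variable tracking, JIT compilation, unroll, remat, and the progress bar.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scan_reference(
    body_fun: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    operands: Sequence[Any],
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of jax.lax.scan's loop semantics."""
    carry = init
    outs = []
    for x in operands:  # one slice along the leading axis per iteration
        carry, y = body_fun(carry, x)  # body returns (new carry, output slice)
        outs.append(y)  # outputs are gathered across iterations
    return carry, outs


# Usage: a running sum. The carry accumulates; each step emits the partial sum.
def cumsum_step(total, x):
    total = total + x
    return total, total


final, partial_sums = scan_reference(cumsum_step, 0, [1, 2, 3, 4])
print(final, partial_sums)  # 10 [1, 3, 6, 10]
```

Because the carry is threaded through every step while each output slice is collected, a running sum yields both the final total and every partial sum.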
- Parameters:
body_fun (callable) – A Python function to be scanned. It receives the loop carry and a slice of operands along its leading axis, and returns a pair: the new loop carry and a slice of the return value.
init (Any) – The initial loop carry value of type c, which can be a scalar, an array, or any pytree (nested Python tuple/list/dict) thereof. This value must have the same structure as the first element of the pair returned by body_fun.
operands (Any) – The value to scan over along its leading axis. operands can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes. If body_fun expects multiple per-step inputs, operands should be a tuple/list whose length equals the number of those inputs.
remat (bool) – Make body_fun recompute internal linearization points when differentiated.
reverse (bool) – Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, which is equivalent to reversing the leading axes of both operands and the stacked outputs.
unroll (int) – Optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop.
progress_bar (bool) – Whether to display a progress bar reporting the running progress.
New in version 2.4.2.
- Returns:
outs – The stacked outputs of body_fun when scanned over the leading axis of the inputs.
- Return type:
Any
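The reverse option can be checked against its stated equivalence (reversing the leading axes of both the inputs and the stacked outputs) with another plain-Python sketch. scan_with_reverse is a hypothetical helper that assumes the (carry, x) -> (carry, y) body convention of jax.lax.scan; it is not BrainPy's code and models neither arrays nor Variables.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scan_with_reverse(
    body_fun: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    operands: Sequence[Any],
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of scan, including the reverse flag."""
    carry = init
    outs = []
    # reverse=True visits the slices from last to first ...
    for x in (reversed(operands) if reverse else operands):
        carry, y = body_fun(carry, x)
        outs.append(y)
    # ... and stores each output at the index of the slice that produced it.
    if reverse:
        outs.reverse()
    return carry, outs


def cumsum_step(total, x):
    total = total + x
    return total, total


xs = [1, 2, 3, 4]
fwd = scan_with_reverse(cumsum_step, 0, xs)
rev = scan_with_reverse(cumsum_step, 0, xs, reverse=True)
print(fwd)  # (10, [1, 3, 6, 10])
print(rev)  # (10, [10, 9, 7, 4])

# Equivalence: reverse=True matches a forward scan on reversed inputs,
# with the collected outputs reversed afterwards.
carry, outs = scan_with_reverse(cumsum_step, 0, xs[::-1])
assert rev == (carry, outs[::-1])
```

With reverse=True the final carry is the same here only because addition is commutative; in general the carry depends on the visiting order.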