# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numbers
from typing import Union, Sequence, Any, Callable, Optional
import brainstate
import jax
import jax.numpy as jnp
from brainpy.math.ndarray import Array
from ._utils import warp_to_no_state_input_output
def _unwrap_operand_leaf(x):
    """Return the raw value behind a ``State``/``Variable`` or BrainPy ``Array`` leaf.

    Two problems motivate this helper:

    * ``brainstate.transform.*`` rejects ``State`` objects passed as operands.
    * Feeding a BrainPy ``Array`` through brainstate's loop primitives makes it
      round-trip through ``tree_unflatten`` inside a JAX trace, where it is
      reconstructed from ``ShapedArray`` avals and fails.

    Handing brainstate the underlying ``.value`` (the raw ``jax.Array``) avoids
    both. Leaves of any other type are returned unchanged.
    """
    needs_unwrap = isinstance(x, (brainstate.State, Array))
    return x.value if needs_unwrap else x
def _unwrap_state_operands(operands):
    """Apply :func:`_unwrap_operand_leaf` to every leaf of the ``operands`` pytree.

    ``brainstate.State`` instances (e.g. :py:class:`~.Variable`) and
    :py:class:`~.Array` objects are treated as leaves, so their internals are not
    traversed. They are replaced with their raw ``jax.Array`` values before the
    tree is forwarded to brainstate. All other leaves pass through untouched.
    """
    def _is_wrapped(node):
        return isinstance(node, (brainstate.State, Array))

    return jax.tree.map(_unwrap_operand_leaf, operands, is_leaf=_is_wrapped)
# Public control-flow API of this module (what ``from ... import *`` exports).
__all__ = ['cond', 'ifelse', 'for_loop', 'scan', 'while_loop']
def _convert_progress_bar_to_pbar(
    progress_bar: Union[bool, brainstate.transform.ProgressBar, int, None]
) -> Optional[brainstate.transform.ProgressBar]:
    """Normalise the user-facing ``progress_bar`` argument into brainstate's ``pbar``.

    Parameters
    ----------
    progress_bar : bool, ProgressBar, int, None
        ``False``/``None`` disables the bar. ``True`` requests a default
        ``ProgressBar``. An ``int`` becomes its ``freq`` argument, following
        brainstate's convention. A ``ProgressBar`` instance is returned as-is.

    Returns
    -------
    pbar : ProgressBar or None
        The bar to hand to brainstate, or ``None`` when no bar is wanted.

    Raises
    ------
    TypeError
        If ``progress_bar`` is none of the supported types.
    """
    pbar_cls = brainstate.transform.ProgressBar
    # ``True``/``False`` are ``int`` subclasses, so they are matched by identity
    # before the generic ``int`` check below.
    if progress_bar is None or progress_bar is False:
        return None
    if progress_bar is True:
        return pbar_cls()
    if isinstance(progress_bar, int):
        return pbar_cls(freq=progress_bar)
    if isinstance(progress_bar, pbar_cls):
        return progress_bar
    raise TypeError(
        f"progress_bar must be bool, int, or ProgressBar instance, "
        f"got {type(progress_bar).__name__}"
    )
[docs]
def cond(
    pred: bool,
    true_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
    false_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
    operands: Any = (),
):
    """Simple conditional statement (if-else) with instance of :py:class:`~.Variable`.

    >>> import brainpy.math as bm
    >>> a = bm.Variable(bm.zeros(2))
    >>> b = bm.Variable(bm.ones(2))
    >>> def true_f(): a.value += 1
    >>> def false_f(): b.value -= 1
    >>>
    >>> bm.cond(True, true_f, false_f)
    >>> a, b
    Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32)
    >>>
    >>> bm.cond(False, true_f, false_f)
    >>> a, b
    Variable([1., 1.], dtype=float32), Variable([0., 0.], dtype=float32)

    Parameters
    ----------
    pred : bool
        Boolean scalar selecting which branch to apply. A ``Variable``/``Array``
        holding the flag is accepted as well; its raw value is used.
    true_fun : callable, ArrayType, float, int, bool
        Applied if ``pred`` is True. A callable receives the elements of
        ``operands`` as positional arguments. A non-callable value is returned
        as the branch result.
    false_fun : callable, ArrayType, float, int, bool
        Applied if ``pred`` is False. Same calling convention as ``true_fun``.
    operands : Any
        Operands passed to the selected branch. They can be a scalar, an array,
        or any pytree (nested Python tuple/list/dict) thereof. A value that is
        not a tuple/list is wrapped into a one-element tuple, so the branch then
        receives exactly one argument.

    Notes
    -----
    The ``dyn_vars`` and ``child_objs`` arguments (deprecated since 2.4.0) are no
    longer accepted. The variables used inside the branch functions are collected
    automatically.

    Returns
    -------
    res : Any
        The conditional results.
    """
    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    operands = _unwrap_state_operands(operands)
    # The predicate may itself live in a ``Variable``/``Array`` (e.g. a boolean
    # flag held in state). Give brainstate its raw value, exactly as for operands.
    pred = _unwrap_operand_leaf(pred)

    # ``true_fun``/``false_fun`` may be constants (array/number), per the
    # documented contract. A non-callable branch is wrapped into a callable that
    # ignores ``*operands`` and returns the (unwrapped) constant. Otherwise
    # brainstate would try to *call* the constant and raise ``TypeError``.
    def _make_branch(branch):
        if callable(branch):
            return warp_to_no_state_input_output(branch)
        const = _unwrap_operand_leaf(branch)
        return warp_to_no_state_input_output(lambda *args: const)

    return brainstate.transform.cond(
        pred,
        _make_branch(true_fun),
        _make_branch(false_fun),
        *operands
    )
[docs]
def ifelse(
    conditions: Union[bool, Sequence[bool]],
    branches: Sequence[Any],
    operands: Any = None,
):
    """``If-else`` control flows looks like native Pythonic programming.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> def f(a):
    ...     return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
    ...                      branches=[lambda: 1,
    ...                                lambda: 2,
    ...                                lambda: 3,
    ...                                lambda: 4,
    ...                                lambda: 5])
    >>> f(1)
    4
    >>> # or, it can be expressed as:
    >>> def f(a):
    ...     return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
    ...                      branches=[1, 2, 3, 4, 5])
    >>> f(3)
    3

    Parameters
    ----------
    conditions : bool, sequence of bool
        The boolean conditions, evaluated like an ``if``/``elif`` chain: the first
        condition that holds wins. A single scalar condition is accepted too.
        ``Variable``/``Array`` conditions are replaced by their raw values.
    branches : Any
        The branches, with at least two elements. Elements can be functions,
        arrays, or numbers. Normally ``len(branches) == len(conditions) + 1``, and
        the extra last branch is the ``else`` case. If the lengths are equal, the
        last condition is ignored and its branch acts as the ``else`` case.
        A callable branch receives the elements of ``operands`` as positional
        arguments. A non-callable branch is returned as the result.
    operands : optional, Any
        The operands for each branch. ``None`` means no operands. A value that
        is not a tuple/list is wrapped into a one-element tuple.

    Notes
    -----
    The ``dyn_vars`` and ``child_objs`` arguments (deprecated since 2.4.0) are no
    longer accepted. The variables used inside the branches are collected
    automatically.

    Returns
    -------
    res : Any
        The results of the control flow.
    """
    if operands is None:
        operands = ()
    elif not isinstance(operands, (tuple, list)):
        operands = (operands,)
    operands = _unwrap_state_operands(operands)

    # Convert non-callable branches (constants) to callables that ignore the
    # operands. Constants are unwrapped so no ``State``/``Array`` object reaches
    # brainstate (same handling as in ``cond``).
    def make_callable(branch):
        if callable(branch):
            return warp_to_no_state_input_output(branch)
        const = _unwrap_operand_leaf(branch)
        return warp_to_no_state_input_output(lambda *args: const)

    branches = [make_callable(branch) for branch in branches]

    # A single condition may be passed as a bare scalar bool/array. Normalise it
    # into a one-element list. Otherwise ``brainstate.transform.ifelse`` would
    # call ``len()`` on the scalar and raise ``TypeError``.
    if not isinstance(conditions, (list, tuple)):
        conditions = [conditions]
    # ``Variable``/``Array`` conditions cannot be combined with ``jnp`` ops and
    # ``&`` reliably, so use their raw values.
    conditions = [_unwrap_operand_leaf(c) for c in conditions]

    # Convert the if-elif-else chain into the mutually exclusive conditions
    # expected by brainstate.
    if conditions:
        n_cond = len(conditions)
        # none_before[k] == "none of conditions[0..k] holds". Building it as a
        # running prefix keeps the conversion O(n) instead of O(n**2).
        none_before = [jnp.logical_not(conditions[0])]
        for c in conditions[1:]:
            none_before.append(none_before[-1] & jnp.logical_not(c))
        # Condition i (i >= 1) fires only if it holds AND no earlier one held.
        exclusive_conditions = [conditions[0]]
        exclusive_conditions.extend(
            c & none_before[k] for k, c in enumerate(conditions[1:])
        )
        if len(branches) == n_cond:
            # The last branch is the default: "all conditions but the last are
            # false". With a single condition this is simply ``not conditions[0]``.
            exclusive_conditions[-1] = none_before[max(n_cond - 2, 0)]
        elif len(branches) > n_cond:
            # Extra default branch: taken when every condition is false.
            exclusive_conditions.append(none_before[-1])
        conditions = exclusive_conditions

    # The conditions were made mutually exclusive above, so brainstate does not
    # need to re-check exclusivity (it would otherwise reject overlapping inputs
    # at trace time).
    return brainstate.transform.ifelse(conditions, branches, *operands, check_cond=False)
[docs]
def for_loop(
    body_fun: Callable,
    operands: Any,
    reverse: bool = False,
    unroll: int = 1,
    jit: Optional[bool] = None,
    progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
):
    """``for-loop`` control flow with :py:class:`~.Variable`.

    .. versionadded:: 2.1.11

    .. versionchanged:: 2.3.0
       ``dyn_vars`` has been changed into a default argument.
       Please change your call from ``for_loop(fun, dyn_vars, operands)``
       to ``for_loop(fun, operands, dyn_vars)``.

    The values returned by ``body_fun`` at each iteration are stacked along a new
    leading axis and returned as the result of the whole loop.

    >>> import brainpy.math as bm
    >>> a = bm.Variable(bm.zeros(1))
    >>> b = bm.Variable(bm.ones(1))
    >>> # first example
    >>> def body(x):
    ...     a.value += x
    ...     b.value *= x
    ...     return a.value
    >>> a_hist = bm.for_loop(body, operands=bm.arange(1, 5))
    >>> a_hist
    DeviceArray([[ 1.],
                 [ 3.],
                 [ 6.],
                 [10.]], dtype=float32)
    >>> a
    Variable([10.], dtype=float32)
    >>> b
    Variable([24.], dtype=float32)
    >>>
    >>> # another example
    >>> def body(x, y):
    ...     a.value += x
    ...     b.value *= y
    ...     return a.value
    >>> a_hist = bm.for_loop(body, operands=(bm.arange(1, 5), bm.arange(2, 6)))
    >>> a_hist
    [[11.]
     [13.]
     [16.]
     [20.]]

    Parameters
    ----------
    body_fun : callable
        The function to be scanned. At each iteration it receives one slice
        (along the leading axis) of each element of ``operands`` as positional
        arguments. Its return value is the per-iteration output that gets stacked.
    operands : Any
        The values over which to scan along the leading axis. Each element can
        be an array or any pytree (nested Python tuple/list/dict) thereof with
        consistent leading axis sizes. If ``body_fun`` takes several arguments,
        ``operands`` must be a tuple/list with one entry per argument. A value
        that is not a tuple/list is passed as the single argument.
    reverse : bool
        Whether to run the scan iteration forward (the default) or in reverse.
        Reverse is equivalent to reversing the leading axes of the arrays in
        both ``xs`` and ``ys``.
    unroll : int
        Optional positive int specifying how many scan iterations to unroll
        within a single iteration of the loop in the underlying scan primitive.
    jit : bool
        Whether to just-in-time compile the function. Set to ``False`` to disable
        JIT compilation.

        .. note::
           ``jit=False`` is implemented via the global :py:func:`jax.disable_jit`
           context manager. Consequently it has no effect when ``for_loop`` is
           called inside an enclosing trace (e.g. within another jitted/scanned
           function): JAX is already tracing, so the loop runs as a compiled
           ``scan`` regardless of this flag. It is also ignored, with a
           ``UserWarning``, when any operand has a zero-length leading axis.
    progress_bar : bool, ProgressBar, int
        Whether and how to display a progress bar during execution:

        - ``False`` (default): No progress bar
        - ``True``: Display progress bar with default settings
        - ``ProgressBar`` instance: Display progress bar with custom settings
        - ``int``: Display progress bar updating every N iterations (treated as
          freq parameter)

        For advanced customization, create a :py:class:`brainpy.math.ProgressBar`
        instance:

        >>> import brainpy.math as bm
        >>> # Custom update frequency
        >>> pbar = bm.ProgressBar(freq=10)
        >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
        >>>
        >>> # Custom description
        >>> pbar = bm.ProgressBar(desc="Processing data")
        >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
        >>>
        >>> # Update exactly 20 times during execution
        >>> pbar = bm.ProgressBar(count=20)
        >>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
        >>>
        >>> # Integer shorthand (equivalent to ProgressBar(freq=10))
        >>> result = bm.for_loop(body_fun, operands, progress_bar=10)

        .. versionadded:: 2.4.2
        .. versionchanged:: 2.7.3
           Now accepts ProgressBar instances and integers for advanced customization.

    Notes
    -----
    The ``dyn_vars`` and ``child_objs`` arguments (deprecated since 2.4.0) are no
    longer accepted. The variables and children objects used in ``body_fun`` are
    collected automatically.

    Returns
    -------
    outs : Any
        The stacked outputs of ``body_fun`` when scanned over the leading axis of
        the inputs.
    """
    import contextlib

    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    operands = _unwrap_state_operands(operands)

    pbar = _convert_progress_bar_to_pbar(progress_bar)

    # JAX's scan does not support zero-length inputs in disable_jit mode, so a
    # zero-length loop keeps JIT on even when ``jit=False`` was requested.
    disable_jit = False
    if jit is False:
        has_zero_length = any(
            getattr(leaf, 'ndim', 0) > 0 and leaf.shape[0] == 0
            for leaf in jax.tree.leaves(operands)
        )
        if has_zero_length:
            import warnings
            warnings.warn(
                "for_loop with jit=False and zero-length input detected. "
                "Using JIT mode to avoid JAX's disable_jit limitation with zero-length scans.",
                UserWarning,
                stacklevel=2,  # attribute the warning to the caller of for_loop
            )
        else:
            disable_jit = True

    # A single call site: only the surrounding context differs between modes.
    jit_context = jax.disable_jit() if disable_jit else contextlib.nullcontext()
    with jit_context:
        return brainstate.transform.for_loop(
            warp_to_no_state_input_output(body_fun),
            *operands, reverse=reverse, unroll=unroll,
            pbar=pbar,
        )
[docs]
def scan(
    body_fun: Callable,
    init: Any,
    operands: Any,
    reverse: bool = False,
    unroll: int = 1,
    remat: bool = False,
    progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
):
    """``scan`` control flow with :py:class:`~.Variable`.

    Similar to ``jax.lax.scan``.

    .. versionadded:: 2.4.7

    Parameters
    ----------
    body_fun : callable
        The function to be scanned. It is called as ``body_fun(carry, x)``, where
        ``carry`` is the current loop carry and ``x`` is a slice of ``operands``
        along its leading axis. It must return a pair ``(new_carry, y)``. The
        ``y`` values of all iterations are stacked into the second output.
    init : Any
        An initial loop carry value of type ``c``. It can be a scalar, an array,
        or any pytree (nested Python tuple/list/dict) thereof. It must have the
        same structure as the first element of the pair returned by ``body_fun``.
    operands : Any
        The value over which to scan along the leading axis. It can be an array
        or any pytree (nested Python tuple/list/dict) thereof with consistent
        leading axis sizes. The whole pytree is sliced jointly and passed to
        ``body_fun`` as its single ``x`` argument. Unlike :py:func:`for_loop`,
        a tuple is not unpacked into several arguments.
    reverse : bool
        Whether to run the scan iteration forward (the default) or in reverse.
        Reverse is equivalent to reversing the leading axes of the arrays in
        both ``xs`` and ``ys``.
    unroll : int
        Optional positive int specifying how many scan iterations to unroll
        within a single iteration of the loop in the underlying scan primitive.
    remat : bool
        Intended to make ``fun`` recompute internal linearization points when
        differentiated. This is currently not supported: passing ``True`` emits
        a ``UserWarning``, and the loop runs without rematerialization.
    progress_bar : bool, ProgressBar, int
        Whether and how to display a progress bar during execution:

        - ``False`` (default): No progress bar
        - ``True``: Display progress bar with default settings
        - ``ProgressBar`` instance: Display progress bar with custom settings
        - ``int``: Display progress bar updating every N iterations (treated as
          freq parameter)

        See :py:func:`for_loop` for detailed examples of ProgressBar usage.

        .. versionadded:: 2.4.2
        .. versionchanged:: 2.7.3
           Now accepts ProgressBar instances and integers for advanced customization.

    Returns
    -------
    outs : tuple
        A two-element tuple ``(final_carry, stacked_ys)``:

        - ``final_carry``: the loop carry value returned by the last iteration of
          ``body_fun`` (same structure as ``init``).
        - ``stacked_ys``: the per-iteration outputs of ``body_fun`` stacked along a
          new leading axis.
    """
    if remat:
        # The flag is accepted for API compatibility but is not forwarded to
        # brainstate. Say so instead of silently ignoring it.
        import warnings
        warnings.warn(
            "`remat=True` is currently not supported by `brainpy.math.scan` "
            "and is ignored.",
            UserWarning,
            stacklevel=2,
        )

    pbar = _convert_progress_bar_to_pbar(progress_bar)
    init = _unwrap_state_operands(init)
    operands = _unwrap_state_operands(operands)
    return brainstate.transform.scan(
        warp_to_no_state_input_output(body_fun),
        init=init,
        xs=operands,
        reverse=reverse,
        unroll=unroll,
        pbar=pbar,
    )
[docs]
def while_loop(
    body_fun: Callable,
    cond_fun: Callable,
    operands: Any,
):
    """``while-loop`` control flow with :py:class:`~.Variable`.

    .. versionadded:: 2.1.11

    .. versionchanged:: 2.3.0
       ``dyn_vars`` has been changed into a default argument.
       Please change your call from ``while_loop(f1, f2, dyn_vars, operands)``
       to ``while_loop(f1, f2, operands, dyn_vars)``.

    Note the difference between ``for_loop`` and ``while_loop``:

    1. ``while_loop`` does not support accumulating history values.
    2. The returns of the body function of ``for_loop`` are the values to stack
       at one moment. The returns of the body function of ``while_loop`` are the
       operands' values at the next moment. In other words, the body function of
       ``while_loop`` defines the rule by which the operands are updated.

    >>> import brainpy.math as bm
    >>>
    >>> a = bm.Variable(bm.zeros(1))
    >>> b = bm.Variable(bm.ones(1))
    >>>
    >>> def cond(x, y):
    ...     return x < 6.
    >>>
    >>> def body(x, y):
    ...     a.value += x
    ...     b.value *= y
    ...     return x + b[0], y + 1.
    >>>
    >>> res = bm.while_loop(body, cond, operands=(1., 1.))
    >>> res
    (10.0, 4.0)

    Parameters
    ----------
    body_fun : callable
        Defines the updating logic. It receives the elements of ``operands`` as
        positional arguments and returns the operands for the next iteration:

        - a tuple with one entry per operand,
        - a bare value, which is wrapped into a one-element tuple (convenient
          when there is a single operand), or
        - ``None``, in which case the operands are carried over unchanged (the
          classic idiom where ``body_fun`` only mutates ``Variable`` state).

    cond_fun : callable
        Defines the stop condition. It receives the elements of ``operands`` as
        positional arguments and returns one boolean value.
    operands : Any
        The operands for ``body_fun`` and ``cond_fun``. A value that is not a
        tuple/list is wrapped into a one-element tuple.

    Notes
    -----
    The ``dyn_vars`` and ``child_objs`` arguments (deprecated since 2.4.0) are no
    longer accepted. The variables and children objects used in the loop
    functions are collected automatically.

    Returns
    -------
    res : tuple
        The operands after the last iteration, as a tuple.
    """
    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    operands = _unwrap_state_operands(tuple(operands))

    def _loop_cond(x):
        # A ``Variable``/``Array`` predicate is reduced to its raw value.
        return _unwrap_operand_leaf(cond_fun(*x))

    def _loop_body(x):
        new_x = body_fun(*x)
        if new_x is None:
            # Classic brainpy idiom: ``body_fun`` mutates ``Variable`` state in
            # place and returns ``None`` (often with empty ``operands``).
            # brainstate's ``while_loop`` tracks that state automatically, and the
            # loop condition is driven by the mutated state. The operands are
            # therefore threaded through unchanged.
            return x
        if not isinstance(new_x, tuple):
            # The carry is always a tuple. A bare return value (the usual form
            # for a single operand) must be wrapped to keep the carry structure.
            new_x = (new_x,)
        # Keep ``State``/``Array`` objects out of the loop carry, just as for the
        # initial operands.
        return _unwrap_state_operands(new_x)

    return brainstate.transform.while_loop(
        warp_to_no_state_input_output(_loop_cond),
        warp_to_no_state_input_output(_loop_body),
        operands
    )