# -*- coding: utf-8 -*-
import functools
import numbers
from typing import Union, Sequence, Any, Dict, Callable, Optional
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm
from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, _as_jax_array_)
from .base import BrainPyObject, ObjectTransform
from .naming import (
get_unique_name,
get_stack_cache,
cache_stack
)
from .tools import (
eval_shape,
dynvar_deprecation,
node_deprecation,
abstract
)
from .variables import (Variable, VariableStack)
# Public names exported by this module (``from ... import *``).
__all__ = [
    'make_loop',
    'make_while',
    'make_cond',
    'cond',
    'ifelse',
    'for_loop',
    'scan',
    'while_loop',
]
class ControlObject(ObjectTransform):
    """Object-oriented Control Flow Transformation in BrainPy.

    A thin callable wrapper returned by :func:`make_loop`, :func:`make_while`
    and :func:`make_cond`. Calling the object forwards every argument to the
    wrapped ``call`` function. ``dyn_vars`` are registered on the object via
    ``register_implicit_vars``.

    Parameters
    ----------
    call : callable
        The function executed by :meth:`__call__`.
    dyn_vars : Variable, sequence of Variable, dict of Variable
        The variables read/written by ``call``.
    repr_fun : dict
        Mapping ``label -> object`` displayed by :meth:`__repr__`.
    name : str, optional
        The object name, forwarded to the base class.
    """

    def __init__(
        self,
        call: Callable,
        dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]],
        repr_fun: Dict,
        name=None
    ):
        super().__init__(name=name)
        self.register_implicit_vars(dyn_vars)
        self._f = call
        self._dyn_vars = dyn_vars
        self._repr_fun = repr_fun

    def __call__(self, *args, **kwargs):
        return self._f(*args, **kwargs)

    def __repr__(self):
        name = self.__class__.__name__
        # Continuation lines are aligned just after "<name>(".
        indent = " " * (len(name) + 1)
        format_ref = [
            f'{k}={tools.repr_context(tools.repr_object(v), " " * (len(name) + len(k)))}'
            for k, v in self._repr_fun.items()
        ]
        # The newline must come before the padding, otherwise the padding is
        # emitted as trailing whitespace and the next line is not indented.
        splitor = ",\n" + indent
        return (f'{name}({splitor.join(format_ref)},\n'
                f'{indent}num_of_dyn_vars={len(self._dyn_vars)})')
def _get_scan_info(f, dyn_vars, out_vars=None, has_return=False):
    """Normalize ``dyn_vars``/``out_vars`` and build the scan body for :func:`make_loop`.

    Parameters
    ----------
    f : callable
        The user body function, called as ``f(x)`` for every slice ``x``.
    dyn_vars : Array, dict of Array, sequence of Array
        The variables carried through the scan. A single ``Array`` is accepted
        as well (consistent with :func:`make_cond`).
    out_vars : None, Array, dict of Array, sequence of Array
        The variables whose values are recorded at every step.
    has_return : bool
        Whether the return value of ``f`` is also collected.

    Returns
    -------
    fun2scan : callable
        ``fun2scan(dyn_values, x)`` writes ``dyn_values`` into ``dyn_vars``,
        calls ``f(x)`` and returns ``(new_dyn_values, out_values)``, or
        ``(new_dyn_values, (out_values, f(x)))`` when ``has_return`` is true.
    dyn_vars : tuple
        The variables as a tuple.
    tree : PyTreeDef
        The tree structure of ``out_vars`` used to rebuild the outputs.

    Raises
    ------
    ValueError
        If ``dyn_vars``/``out_vars`` have an unsupported type or contain
        non-``Array`` elements.
    """
    # iterable variables
    if isinstance(dyn_vars, Array):
        dyn_vars = (dyn_vars,)
    elif isinstance(dyn_vars, dict):
        dyn_vars = tuple(dyn_vars.values())
    elif isinstance(dyn_vars, (tuple, list)):
        dyn_vars = tuple(dyn_vars)
    else:
        raise ValueError(
            f'"dyn_vars" does not support {type(dyn_vars)}, '
            f'only support dict/list/tuple of {Array.__name__}')
    for v in dyn_vars:
        if not isinstance(v, Array):
            raise ValueError(
                f'brainpy.math.jax.make_loop only support '
                f'{Array.__name__}, but got {type(v)}')

    # outputs: the tree is taken from the original container so that the
    # caller gets back the same structure it passed in
    if out_vars is None:
        out_vars = ()
        _, tree = tree_flatten(out_vars)
    elif isinstance(out_vars, Array):
        _, tree = tree_flatten(out_vars)
        out_vars = (out_vars,)
    elif isinstance(out_vars, dict):
        _, tree = tree_flatten(out_vars)
        out_vars = tuple(out_vars.values())
    elif isinstance(out_vars, (tuple, list)):
        _, tree = tree_flatten(out_vars)
        out_vars = tuple(out_vars)
    else:
        raise ValueError(
            f'"out_vars" does not support {type(out_vars)}, '
            f'only support dict/list/tuple of {Array.__name__}')

    # function scanned by jax.lax.scan: carry = list of variable values
    def fun2scan(dyn_values, x):
        for v, d in zip(dyn_vars, dyn_values):
            v._value = d
        results = f(x)
        new_dyn_values = [v.value for v in dyn_vars]
        out_values = [v.value for v in out_vars]
        if has_return:
            return new_dyn_values, (out_values, results)
        return new_dyn_values, out_values

    return fun2scan, dyn_vars, tree
[docs]
def make_loop(
    body_fun: Callable,
    dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]],
    out_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
    has_return: bool = False
) -> ControlObject:
    """Make a for-loop function, which iterates over inputs.

    Examples
    --------
    >>> import brainpy.math as bm
    >>>
    >>> a = bm.Variable(bm.zeros(1))
    >>> def f(x): a.value += 1.
    >>> loop = bm.make_loop(f, dyn_vars=[a], out_vars=a)
    >>> loop(bm.arange(10))
    Variable([[ 1.],
              [ 2.],
              [ 3.],
              [ 4.],
              [ 5.],
              [ 6.],
              [ 7.],
              [ 8.],
              [ 9.],
              [10.]], dtype=float32)
    >>> b = bm.Variable(bm.zeros(1))
    >>> def f(x):
    >>>   b.value += 1
    >>>   return b + 1
    >>> loop = bm.make_loop(f, dyn_vars=[b], out_vars=b, has_return=True)
    >>> hist_b, hist_b_plus = loop(bm.arange(10))
    >>> hist_b
    Variable([[ 1.],
              [ 2.],
              [ 3.],
              [ 4.],
              [ 5.],
              [ 6.],
              [ 7.],
              [ 8.],
              [ 9.],
              [10.]], dtype=float32)
    >>> hist_b_plus
    ArrayType([[ 2.],
               [ 3.],
               [ 4.],
               [ 5.],
               [ 6.],
               [ 7.],
               [ 8.],
               [ 9.],
               [10.],
               [11.]], dtype=float32)

    Parameters
    ----------
    body_fun : callable, function
        A function receiving one argument. This argument refers to the iterable input ``x``.
    dyn_vars : dict of ArrayType, sequence of ArrayType
        The dynamically changed variables, while iterating between trials.
    out_vars : optional, ArrayType, dict of ArrayType, sequence of ArrayType
        The variables to output their values.
    has_return : bool
        Whether the function has return values.

    Returns
    -------
    loop_func : ControlObject
        The function for loop iteration, called as ``loop_func(xs=None, length=None)``.
        Both arguments are forwarded to ``jax.lax.scan`` (note ``body_fun`` receives
        one slice ``x`` of ``xs``). It returns the stacked ``out_vars`` history, or
        ``(history, stacked_returns)`` when ``has_return`` is true. If the scan raises,
        the variables in ``dyn_vars`` are restored to their values before the call.
    """
    fun2scan, dyn_vars, tree = _get_scan_info(f=body_fun,
                                              dyn_vars=dyn_vars,
                                              out_vars=out_vars,
                                              has_return=has_return)

    def _assign(values):
        # write concrete values back into the variables
        for v, d in zip(dyn_vars, values):
            v._value = d

    def call(xs=None, length=None):
        init_values = [v.value for v in dyn_vars]
        try:
            dyn_values, outs = jax.lax.scan(
                f=fun2scan, init=init_values, xs=xs, length=length
            )
        except UnexpectedTracerError as e:
            # fun2scan stored tracers in the variables; undo that before raising
            _assign(init_values)
            raise errors.JaxTracerError(variables=dyn_vars) from e
        except Exception:
            _assign(init_values)
            raise
        _assign(dyn_values)
        if has_return:
            out_values, results = outs
            return tree_unflatten(tree, out_values), results
        return tree_unflatten(tree, outs)

    return ControlObject(call, dyn_vars=dyn_vars, repr_fun={'body_fun': body_fun})
[docs]
def make_while(
    cond_fun,
    body_fun,
    dyn_vars
) -> ControlObject:
    """Make a while-loop function.

    This function is similar to ``jax.lax.while_loop``. The difference is that,
    if you are using ``Variable`` in your while loop code, this function helps
    you make an easy while loop function. ``cond_fun`` and ``body_fun`` both
    receive the single argument ``x`` given to the returned function.
    ``cond_fun`` should return a boolean value. ``body_fun`` does not support
    return values.

    Examples
    --------
    >>> import brainpy.math as bm
    >>>
    >>> a = bm.zeros(1)
    >>>
    >>> def cond_f(x): return a[0] < 10
    >>> def body_f(x): a.value += 1.
    >>>
    >>> loop = bm.make_while(cond_f, body_f, dyn_vars=[a])
    >>> loop()
    >>> a
    Array([10.], dtype=float32)

    Parameters
    ----------
    cond_fun : function, callable
        A function receiving one argument and returning a boolean value.
    body_fun : function, callable
        A function receiving one argument, without any returns.
    dyn_vars : ArrayType, dict of ArrayType, sequence of ArrayType
        The dynamically changed variables, while iterating between trials.

    Returns
    -------
    loop_func : ControlObject
        The function for loop iteration, which receives one argument ``x`` for
        external input. If the loop raises, the variables are restored to the
        values they had before the call.

    Raises
    ------
    ValueError
        If ``dyn_vars`` has an unsupported type or contains non-``Array`` elements.
    """
    # iterable variables
    if isinstance(dyn_vars, Array):
        dyn_vars = (dyn_vars,)
    elif isinstance(dyn_vars, dict):
        dyn_vars = tuple(dyn_vars.values())
    elif isinstance(dyn_vars, (tuple, list)):
        dyn_vars = tuple(dyn_vars)
    else:
        raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, '
                         f'only support dict/list/tuple of {Array.__name__}')
    for v in dyn_vars:
        if not isinstance(v, Array):
            raise ValueError(f'Only support {Array.__name__}, but got {type(v)}')

    def _assign(values):
        for v, d in zip(dyn_vars, values):
            v._value = d

    def _body_fun(op):
        dyn_values, static_values = op
        _assign(dyn_values)
        body_fun(static_values)
        return [v.value for v in dyn_vars], static_values

    def _cond_fun(op):
        dyn_values, static_values = op
        _assign(dyn_values)
        return as_jax(cond_fun(static_values))

    name = get_unique_name('_brainpy_object_oriented_make_while_')

    def call(x=None):
        dyn_init = [v.value for v in dyn_vars]
        try:
            dyn_values, _ = jax.lax.while_loop(cond_fun=_cond_fun,
                                               body_fun=_body_fun,
                                               init_val=(dyn_init, x))
        except UnexpectedTracerError as e:
            # the loop body left tracers in the variables; undo that before raising
            _assign(dyn_init)
            raise errors.JaxTracerError(variables=dyn_vars) from e
        except Exception:
            _assign(dyn_init)
            raise
        _assign(dyn_values)

    return ControlObject(call=call,
                         dyn_vars=dyn_vars,
                         repr_fun={'cond_fun': cond_fun, 'body_fun': body_fun},
                         name=name)
[docs]
def make_cond(
    true_fun,
    false_fun,
    dyn_vars=None
) -> ControlObject:
    """Make a condition (if-else) function.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> a = bm.zeros(2)
    >>> b = bm.ones(2)
    >>>
    >>> def true_f(x): a.value += 1
    >>> def false_f(x): b.value -= 1
    >>>
    >>> cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
    >>> cond(True)
    >>> a, b
    (Array([1., 1.], dtype=float32),
     Array([1., 1.], dtype=float32))
    >>> cond(False)
    >>> a, b
    (Array([1., 1.], dtype=float32),
     Array([0., 0.], dtype=float32))

    Parameters
    ----------
    true_fun : callable, function
        A function receiving one argument, without any returns.
    false_fun : callable, function
        A function receiving one argument, without any returns.
    dyn_vars : ArrayType, dict of ArrayType, sequence of ArrayType, optional
        The dynamically changed variables.

    Returns
    -------
    cond_func : ControlObject
        The conditional function receives two arguments: ``pred`` for true/false
        judgement and ``x`` for external input. It returns the selected branch's
        return value. If ``jax.lax.cond`` raises, the variables are restored to
        the values they had before the call.

    Raises
    ------
    ValueError
        If ``dyn_vars`` has an unsupported type or contains non-``Array`` elements.
    """
    # iterable variables
    if dyn_vars is None:
        dyn_vars = []
    if isinstance(dyn_vars, Array):
        dyn_vars = (dyn_vars,)
    elif isinstance(dyn_vars, dict):
        dyn_vars = tuple(dyn_vars.values())
    elif isinstance(dyn_vars, (tuple, list)):
        dyn_vars = tuple(dyn_vars)
    else:
        raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, '
                         f'only support dict/list/tuple of {Array.__name__}')
    for v in dyn_vars:
        if not isinstance(v, Array):
            raise ValueError(f'Only support {Array.__name__}, but got {type(v)}')

    name = get_unique_name('_brainpy_object_oriented_make_cond_')

    if len(dyn_vars) > 0:
        def _assign(values):
            for v, d in zip(dyn_vars, values):
                v._value = d

        def _true_fun(op):
            dyn_vals, static_vals = op
            _assign(dyn_vals)
            res = true_fun(static_vals)
            return [v.value for v in dyn_vars], res

        def _false_fun(op):
            dyn_vals, static_vals = op
            _assign(dyn_vals)
            res = false_fun(static_vals)
            return [v.value for v in dyn_vars], res

        def call(pred, x=None):
            old_values = [v.value for v in dyn_vars]
            try:
                dyn_values, res = jax.lax.cond(pred, _true_fun, _false_fun, (old_values, x))
            except UnexpectedTracerError as e:
                _assign(old_values)
                raise errors.JaxTracerError(variables=dyn_vars) from e
            except Exception:
                _assign(old_values)
                raise
            _assign(dyn_values)
            return res
    else:
        # nothing to carry: dispatch straight to jax.lax.cond
        def call(pred, x=None):
            return jax.lax.cond(pred, true_fun, false_fun, x)

    # pass the computed unique name, as make_while does
    return ControlObject(call, dyn_vars, repr_fun={'true_fun': true_fun, 'false_fun': false_fun},
                         name=name)
@functools.cache
def _warp(f):
    """Wrap callable ``f`` so ``_as_jax_array_`` is applied to every leaf of its output.

    BrainPy ``Array`` objects are treated as leaves. The wrapper is memoized
    per ``f`` (``functools.cache``), so wrapping the same callable twice returns
    the identical object. ``cond``/``ifelse`` key their stack cache on these
    wrappers. Note that the cache keeps every wrapped callable alive.
    """
    @functools.wraps(f)
    def new_f(*args, **kwargs):
        # ``jax.tree_map`` is deprecated/removed in recent JAX; use tree_util.
        return jax.tree_util.tree_map(_as_jax_array_, f(*args, **kwargs),
                                      is_leaf=lambda a: isinstance(a, Array))

    return new_f
def _warp_data(data):
    """Turn a non-callable branch value into a function returning it.

    The returned function ignores its arguments and applies ``_as_jax_array_``
    to every leaf of ``data`` (BrainPy ``Array`` objects treated as leaves).
    The conversion is redone on every call, so it reflects the current
    contents of ``data`` at call time.
    """
    def new_f(*args, **kwargs):
        # ``jax.tree_map`` is deprecated/removed in recent JAX; use tree_util.
        return jax.tree_util.tree_map(_as_jax_array_, data,
                                      is_leaf=lambda a: isinstance(a, Array))

    return new_f
def _check_f(f):
    """Normalize a branch: callables go through ``_warp``, plain data through ``_warp_data``."""
    wrapper = _warp if callable(f) else _warp_data
    return wrapper(f)
def _check_sequence(a):
return isinstance(a, (list, tuple))
def _cond_transform_fun(fun, dyn_vars):
@functools.wraps(fun)
def new_fun(dyn_vals, *static_vals):
for k, v in dyn_vars.items():
v._value = dyn_vals[k]
r = fun(*static_vals)
return {k: v.value for k, v in dyn_vars.items()}, r
return new_fun
def _get_cond_transform(dyn_vars, pred, true_fun, false_fun):
    """Return ``call_fun(operands)`` running ``jax.lax.cond`` over both branches.

    Both branches are adapted with ``_cond_transform_fun`` so the variable
    values in ``dyn_vars`` travel as the first operand (a dict).
    ``call_fun`` returns ``(new_var_values_dict, branch_result)``.
    """
    branch_true, branch_false = (_cond_transform_fun(fn, dyn_vars) for fn in (true_fun, false_fun))

    def call_fun(operands):
        current = dyn_vars.dict_data()
        return jax.lax.cond(pred, branch_true, branch_false, current, *operands)

    return call_fun
[docs]
def cond(
    pred: bool,
    true_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
    false_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
    operands: Any = (),
    # deprecated
    dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
    """Simple conditional statement (if-else) with instance of :py:class:`~.Variable`.

    >>> import brainpy.math as bm
    >>> a = bm.Variable(bm.zeros(2))
    >>> b = bm.Variable(bm.ones(2))
    >>> def true_f(): a.value += 1
    >>> def false_f(): b.value -= 1
    >>>
    >>> bm.cond(True, true_f, false_f)
    >>> a, b
    Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32)
    >>>
    >>> bm.cond(False, true_f, false_f)
    >>> a, b
    Variable([1., 1.], dtype=float32), Variable([0., 0.], dtype=float32)

    Parameters
    ----------
    pred: bool
        Boolean scalar type, indicating which branch function to apply.
    true_fun: callable, ArrayType, float, int, bool
        Function to be applied if ``pred`` is True. It is called as
        ``true_fun(*operands)``. A non-callable value is returned as-is.
    false_fun: callable, ArrayType, float, int, bool
        Function to be applied if ``pred`` is False. It is called as
        ``false_fun(*operands)``. A non-callable value is returned as-is.
    operands: Any
        Operands (A) input to branching function depending on ``pred``. The type
        can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
        A non-tuple/list value is treated as a single operand.
    dyn_vars: optional, Variable, sequence of Variable, dict
        The dynamically changed variables.

        .. deprecated:: 2.4.0
           No longer need to provide ``dyn_vars``. This function is capable of automatically
           collecting the dynamical variables used in the target ``func``.
    child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
        The children objects used in the target function.

        .. deprecated:: 2.4.0
           No longer need to provide ``child_objs``. This function is capable of automatically
           collecting the children objects used in the target ``func``.

    Returns
    -------
    res: Any
        The conditional results.
    """
    # functions: wrap both branches so Array outputs become JAX arrays
    true_fun = _check_f(true_fun)
    false_fun = _check_f(false_fun)
    # operands
    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    # dyn vars
    dynvar_deprecation(dyn_vars)
    node_deprecation(child_objs)
    # variables used by the branches are collected once and cached per
    # (true_fun, false_fun) pair
    dyn_vars = get_stack_cache((true_fun, false_fun))
    if not jax.config.jax_disable_jit and dyn_vars is None:
        # trace both branches abstractly to discover the variables they touch
        with VariableStack() as dyn_vars:
            rets = eval_shape(true_fun, *operands, with_stack=True)[1]
            _ = eval_shape(false_fun, *operands, with_stack=True)
        cache_stack((true_fun, false_fun), dyn_vars)
        # NOTE(review): presumably true when another VariableStack is still
        # active (an enclosing transform is collecting variables); in that
        # case only the evaluated result is returned, without running lax.cond.
        if not dyn_vars.is_first_stack():
            return rets
    dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
    dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
    # write the branch's final variable values back
    for k in dyn_values.keys():
        dyn_vars[k]._value = dyn_values[k]
    return res
def _if_else_return1(conditions, branches, operands):
for i, pred in enumerate(conditions):
if pred:
return branches[i](*operands)
else:
return branches[-1](*operands)
def _if_else_return2(conditions, branches):
for i, pred in enumerate(conditions):
if pred:
return branches[i]
else:
return branches[-1]
def _all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == x for x in iterator)
[docs]
def ifelse(
    conditions: Union[bool, Sequence[bool]],
    branches: Sequence[Any],
    operands: Any = None,
    show_code: bool = False,
    # deprecated
    dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
    """``If-else`` control flows looks like native Pythonic programming.

    Examples
    --------
    >>> import brainpy.math as bm
    >>> def f(a):
    >>>   return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
    >>>                    branches=[lambda: 1,
    >>>                              lambda: 2,
    >>>                              lambda: 3,
    >>>                              lambda: 4,
    >>>                              lambda: 5])
    >>> f(1)
    4
    >>> # or, it can be expressed as:
    >>> def f(a):
    >>>   return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
    >>>                    branches=[1, 2, 3, 4, 5])
    >>> f(3)
    3

    Parameters
    ----------
    conditions: bool, sequence of bool
        The boolean conditions. A single value is treated as a one-element list.
    branches: Any
        The branches, at least two elements. Elements can be functions,
        arrays, or numbers. The number of ``branches`` and ``conditions`` has
        the relationship of `len(branches) == len(conditions) + 1`.
        Each callable branch is called as ``branch(*operands)``.
    operands: optional, Any
        The operands for each branch. A non-tuple/list value is treated as a
        single operand.
    show_code: bool
        Whether to print the generated nested-``cond`` code.
    dyn_vars: Variable, sequence of Variable, dict
        The dynamically changed variables.

        .. deprecated:: 2.4.0
           No longer need to provide ``dyn_vars``. This function is capable of automatically
           collecting the dynamical variables used in the target ``func``.
    child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
        The children objects used in the target function.

        .. deprecated:: 2.4.0
           No longer need to provide ``child_objs``. This function is capable of automatically
           collecting the children objects used in the target ``func``.

    Returns
    -------
    res: Any
        The results of the control flow.

    Raises
    ------
    ValueError
        If ``branches`` is not a tuple/list or the branch/condition counts mismatch.
    TypeError
        If the branches return values with different tree structures.
    """
    # checking (a non-sequence condition is promoted to a one-element list, so
    # ``conditions`` is always a list/tuple after this point)
    if not isinstance(conditions, (tuple, list)):
        conditions = [conditions]
    if not isinstance(branches, (tuple, list)):
        raise ValueError(f'"branches" must be a tuple/list. '
                         f'But we got {type(branches)}.')
    branches = [_check_f(b) for b in branches]
    if len(branches) != len(conditions) + 1:
        raise ValueError(f'The numbers of branches and conditions do not match. '
                         f'Got len(conditions)={len(conditions)} and len(branches)={len(branches)}. '
                         f'We expect len(conditions) + 1 == len(branches). ')
    if operands is None:
        operands = tuple()
    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    dynvar_deprecation(dyn_vars)
    node_deprecation(child_objs)

    # a single condition is a plain if-else
    if len(conditions) == 1:
        return cond(conditions[0],
                    branches[0],
                    branches[1],
                    operands)

    # without jit, evaluate the conditions eagerly in Python
    if jax.config.jax_disable_jit:
        return _if_else_return1(conditions, branches, operands)

    # variables used by the branches are collected once and cached
    dyn_vars = get_stack_cache(tuple(branches))
    if dyn_vars is None:
        with VariableStack() as dyn_vars:
            rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches]
            trees = [jax.tree_util.tree_structure(ret) for ret in rets]
            if not _all_equal(trees):
                msg = 'All returns in branches should have the same tree structure. But we got:\n'
                msg += ''.join(f'- {tree}\n' for tree in trees)
                raise TypeError(msg)
        cache_stack(tuple(branches), dyn_vars)
        # NOTE(review): presumably true when an enclosing transform is still
        # collecting variables; only the evaluated result is returned then.
        if not dyn_vars.is_first_stack():
            return rets[0]

    # Build nested jax.lax.cond calls: f0 is the final "else" branch and each
    # f{i+1} tests one more condition, working backwards from the last one.
    branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches]
    code_scope = {'conditions': conditions, 'branches': branches}
    codes = ['def f(dyn_vals, *operands):',
             f'  f0 = branches[{len(conditions)}]']
    num_cond = len(conditions) - 1
    code_scope['_cond'] = jax.lax.cond
    for i in range(len(conditions) - 1):
        codes.append(f'  f{i + 1} = lambda *r: _cond(conditions[{num_cond - i}], branches[{num_cond - i}], f{i}, *r)')
    codes.append(f'  return _cond(conditions[0], branches[0], f{len(conditions) - 1}, dyn_vals, *operands)')
    codes = '\n'.join(codes)
    if show_code:
        print(codes)
    # the generated source only contains integer indices, no user input
    exec(compile(codes.strip(), '', 'exec'), code_scope)
    f = code_scope['f']
    dyn_values, res = f(dyn_vars.dict_data(), *operands)
    for k in dyn_values.keys():
        dyn_vars[k]._value = dyn_values[k]
    return res
def _loop_abstractify(x):
    """Return the abstract value of ``x`` with axis 0 mapped out via ``jax.core.mapped_aval``."""
    aval = abstract(x)
    leading = aval.shape[0]
    return jax.core.mapped_aval(leading, 0, aval)
def _get_for_loop_transform(
    body_fun,
    dyn_vars,
    bar: tqdm,
    progress_bar: bool,
    remat: bool,
    reverse: bool,
    unroll: int,
    unroll_kwargs: tools.DotDict
):
    """Build the ``jax.lax.scan`` runner used by :func:`for_loop`.

    The scan carry is ``dyn_vars.dict_data()``. Each step writes the carry back
    into the variables and calls ``body_fun(*x, **unroll_kwargs)``. When
    ``progress_bar`` is set, it also schedules ``bar.update()`` through ``id_tap``.
    The returned ``call(operands)`` yields ``(final_var_values, stacked_returns)``.
    """

    def scan_step(carry, x):
        for key in dyn_vars.keys():
            dyn_vars[key]._value = carry[key]
        step_out = body_fun(*x, **unroll_kwargs)
        if progress_bar:
            id_tap(lambda *arg: bar.update(), ())
        return dyn_vars.dict_data(), step_out

    step = functools.wraps(body_fun)(scan_step)
    if remat:
        step = jax.checkpoint(step)

    def call(operands):
        initial = dyn_vars.dict_data()
        return jax.lax.scan(f=step,
                            init=initial,
                            xs=operands,
                            reverse=reverse,
                            unroll=unroll)

    return call
[docs]
def for_loop(
    body_fun: Callable,
    operands: Any,
    reverse: bool = False,
    unroll: int = 1,
    remat: bool = False,
    jit: Optional[bool] = None,
    progress_bar: bool = False,
    unroll_kwargs: Optional[Dict] = None,
    # deprecated
    dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
    """``for-loop`` control flow with :py:class:`~.Variable`.

    .. versionadded:: 2.1.11

    .. versionchanged:: 2.3.0
       ``dyn_vars`` has been changed into a default argument.
       Please change your call from ``for_loop(fun, dyn_vars, operands)``
       to ``for_loop(fun, operands, dyn_vars)``.

    All returns in body function will be gathered
    as the return of the whole loop.

    >>> import brainpy.math as bm
    >>> a = bm.Variable(bm.zeros(1))
    >>> b = bm.Variable(bm.ones(1))
    >>> # first example
    >>> def body(x):
    >>>   a.value += x
    >>>   b.value *= x
    >>>   return a.value
    >>> a_hist = bm.for_loop(body, operands=bm.arange(1, 5))
    >>> a_hist
    DeviceArray([[ 1.],
                 [ 3.],
                 [ 6.],
                 [10.]], dtype=float32)
    >>> a
    Variable([10.], dtype=float32)
    >>> b
    Variable([24.], dtype=float32)
    >>>
    >>> # another example
    >>> def body(x, y):
    >>>   a.value += x
    >>>   b.value *= y
    >>>   return a.value
    >>> a_hist = bm.for_loop(body, operands=(bm.arange(1, 5), bm.arange(2, 6)))
    >>> a_hist
    [[11.]
     [13.]
     [16.]
     [20.]]

    Parameters
    ----------
    body_fun: callable
        A Python function to be scanned. It is called as ``body_fun(*x, **unroll_kwargs)``
        where ``x`` is one slice of ``operands`` along the leading axis; its output
        represents a slice of the return value.
    operands: Any
        The value over which to scan along the leading axis,
        where ``operands`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
        If body function `body_func` receives multiple arguments,
        `operands` should be a tuple/list whose length is equal to the
        number of arguments.
    remat: bool
        Make ``fun`` recompute internal linearization points when differentiated.
    jit: bool
        Whether to just-in-time compile the function. Defaults to
        ``not jax.config.jax_disable_jit``.
    reverse: bool
        Optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll: int
        Optional positive int specifying, in the underlying operation of the
        scan primitive, how many scan iterations to unroll within a single
        iteration of a loop.
    progress_bar: bool
        Whether we use the progress bar to report the running progress.

        .. versionadded:: 2.4.2
    dyn_vars: Variable, sequence of Variable, dict
        The instances of :py:class:`~.Variable`.

        .. deprecated:: 2.4.0
           No longer need to provide ``dyn_vars``. This function is capable of automatically
           collecting the dynamical variables used in the target ``func``.
    child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
        The children objects used in the target function.

        .. versionadded:: 2.3.1

        .. deprecated:: 2.4.0
           No longer need to provide ``child_objs``. This function is capable of automatically
           collecting the children objects used in the target ``func``.
    unroll_kwargs: dict
        The keyword arguments without unrolling.

    Returns
    -------
    outs: Any
        The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
    """
    dynvar_deprecation(dyn_vars)
    node_deprecation(child_objs)
    if unroll_kwargs is None:
        unroll_kwargs = dict()
    unroll_kwargs = tools.DotDict(unroll_kwargs)
    if not isinstance(operands, (list, tuple)):
        operands = (operands,)

    bar = None
    if progress_bar:
        num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]])
        bar = tqdm(total=num_total)

    # The bar must be closed on every exit path: the normal return, the early
    # return below, and any exception raised while tracing/running the scan.
    try:
        if jit is None:  # jax disable jit
            jit = not jax.config.jax_disable_jit
        stack = get_stack_cache((body_fun, unroll_kwargs))
        if jit:
            if stack is None:
                transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar,
                                                    remat, reverse, unroll, unroll_kwargs)
                # TODO: better cache mechanism?
                with VariableStack() as stack:
                    rets = eval_shape(transform, operands)
                cache_stack((body_fun, unroll_kwargs), stack)  # cache
                # NOTE(review): presumably true when an enclosing transform is
                # still collecting variables; only the evaluated result is returned.
                if not stack.is_first_stack():
                    return rets[1]
                del rets
        else:
            stack = VariableStack()

        # TODO: cache mechanism?
        transform = _get_for_loop_transform(body_fun, stack, bar,
                                            progress_bar, remat, reverse,
                                            unroll, unroll_kwargs)
        if jit:
            dyn_vals, out_vals = transform(operands)
        else:
            with jax.disable_jit():
                dyn_vals, out_vals = transform(operands)
        # write the final variable values back
        for key in stack.keys():
            stack[key]._value = dyn_vals[key]
        del dyn_vals, stack
        return out_vals
    finally:
        if bar is not None:
            bar.close()
def _get_scan_transform(
    body_fun: Callable,
    dyn_vars: VariableStack,
    bar: tqdm,
    progress_bar: bool,
    remat: bool,
    reverse: bool,
    unroll: int,
):
    """Build the ``jax.lax.scan`` runner used by :func:`scan`.

    The scan carry is ``(dyn_vars.dict_data(), user_carry)``. Each step writes
    the variable values back and calls ``body_fun(user_carry, x)``, which must
    return ``(new_user_carry, y)``. The returned ``call(init, operands)`` yields
    ``((final_var_values, final_carry), stacked_ys)``.
    """

    def _to_jax(tree):
        # apply _as_jax_array_ to every leaf, treating BrainPy Arrays as leaves;
        # ``jax.tree_map`` is deprecated/removed in recent JAX, use tree_util
        return jax.tree_util.tree_map(_as_jax_array_, tree, is_leaf=lambda a: isinstance(a, Array))

    def fun2scan(carry, x):
        dyn_vars_data, carry = carry
        for k in dyn_vars.keys():
            dyn_vars[k]._value = dyn_vars_data[k]
        carry, results = body_fun(carry, x)
        if progress_bar:
            id_tap(lambda *arg: bar.update(), ())
        carry = _to_jax(carry)
        return (dyn_vars.dict_data(), carry), results

    if remat:
        fun2scan = jax.checkpoint(fun2scan)

    def call(init, operands):
        init = _to_jax(init)
        return jax.lax.scan(f=fun2scan,
                            init=(dyn_vars.dict_data(), init),
                            xs=operands,
                            reverse=reverse,
                            unroll=unroll)

    return call
[docs]
def scan(
    body_fun: Callable,
    init: Any,
    operands: Any,
    reverse: bool = False,
    unroll: int = 1,
    remat: bool = False,
    progress_bar: bool = False,
):
    """``scan`` control flow with :py:class:`~.Variable`.

    Similar to ``jax.lax.scan``.

    .. versionadded:: 2.4.7

    All returns in body function will be gathered
    as the return of the whole loop.

    Parameters
    ----------
    body_fun: callable
        A Python function to be scanned, called as ``body_fun(carry, x)`` where
        ``x`` is one slice of ``operands`` along its leading axis. It must return
        ``(new_carry, y)``; the ``y`` values are stacked into the output.
    init: Any
        An initial loop carry value of type ``c``, which can be a scalar, array, or any pytree
        (nested Python tuple/list/dict) thereof, representing the initial loop carry value.
        This value must have the same structure as the first element of the pair returned
        by ``body_fun``.
    operands: Any
        The value over which to scan along the leading axis,
        where ``operands`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
        Each slice is passed to ``body_fun`` as the single argument ``x``.
    remat: bool
        Make ``fun`` recompute internal linearization points when differentiated.
    reverse: bool
        Optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll: int
        Optional positive int specifying, in the underlying operation of the
        scan primitive, how many scan iterations to unroll within a single
        iteration of a loop.
    progress_bar: bool
        Whether we use the progress bar to report the running progress.

        .. versionadded:: 2.4.2

    Returns
    -------
    outs: tuple
        ``(final_carry, stacked_ys)``: the last carry and the stacked second
        outputs of ``body_fun``.
    """
    bar = None
    if progress_bar:
        num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]])
        bar = tqdm(total=num_total)

    # The bar must be closed on every exit path: the normal return, the early
    # return below, and any exception raised while tracing/running the scan.
    try:
        stack = get_stack_cache(body_fun)
        if not jax.config.jax_disable_jit and stack is None:
            transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
            with VariableStack() as stack:
                rets = eval_shape(transform, init, operands)
            cache_stack(body_fun, stack)  # cache
            # NOTE(review): presumably true when an enclosing transform is
            # still collecting variables; only the evaluated result is returned.
            if not stack.is_first_stack():
                return rets[0][1], rets[1]
            del rets

        stack = VariableStack() if stack is None else stack
        transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll)
        (dyn_vals, carry), out_vals = transform(init, operands)
        # write the final variable values back
        for key in stack.keys():
            stack[key]._value = dyn_vals[key]
        return carry, out_vals
    finally:
        if bar is not None:
            bar.close()
def _get_while_transform(cond_fun, body_fun, dyn_vars):
    """Build the ``jax.lax.while_loop`` runner used by :func:`while_loop`.

    The loop carry is ``(dyn_vars.dict_data(), operands_tuple)``. ``body_fun``
    is called as ``body_fun(*operands)`` and returns the next operands. ``None``
    keeps them unchanged, a list is converted to a tuple, and any other
    non-tuple value is wrapped into a 1-tuple. ``cond_fun`` is called as
    ``cond_fun(*operands)``. A BrainPy ``Array`` result is unwrapped to its
    ``.value``. The returned ``call(operands)`` yields
    ``(final_var_values, final_operands)``.
    """

    def _body_fun(op):
        dyn_vals, old_vals = op
        for k, v in dyn_vars.items():
            v._value = dyn_vals[k]
        new_vals = body_fun(*old_vals)
        if new_vals is None:
            new_vals = old_vals
        # Convert lists before wrapping scalars; the other order would turn
        # [a, b] into ([a, b],) and break the carry structure.
        if isinstance(new_vals, list):
            new_vals = tuple(new_vals)
        if not isinstance(new_vals, tuple):
            new_vals = (new_vals,)
        return dyn_vars.dict_data(), new_vals

    def _cond_fun(op):
        dyn_vals, old_vals = op
        for k, v in dyn_vars.items():
            v._value = dyn_vals[k]
        with jax.ensure_compile_time_eval():
            r = cond_fun(*old_vals)
        # unwrap BrainPy arrays so lax.while_loop receives a plain JAX value
        return r.value if isinstance(r, Array) else r

    # TODO: cache mechanism?
    def call(operands):
        # keep the carry a tuple, matching what _body_fun returns
        if isinstance(operands, list):
            operands = tuple(operands)
        return jax.lax.while_loop(cond_fun=_cond_fun,
                                  body_fun=_body_fun,
                                  init_val=(dyn_vars.dict_data(), operands))

    return call
[docs]
def while_loop(
    body_fun: Callable,
    cond_fun: Callable,
    operands: Any,
    # deprecated
    dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
    """``while-loop`` control flow with :py:class:`~.Variable`.

    .. versionchanged:: 2.3.0
       ``dyn_vars`` has been changed into a default argument.
       Please change your call from ``while_loop(f1, f2, dyn_vars, operands)``
       to ``while_loop(f1, f2, operands, dyn_vars)``.

    Note the difference between ``for_loop`` and ``while_loop``:

    1. ``while_loop`` does not support accumulating history values.
    2. The returns on the body function of ``for_loop`` represent the values to stack at one moment.
       However, the functional returns of body function in ``while_loop`` represent the operands'
       values at the next moment, meaning that the body function of ``while_loop`` defines the
       updating rule of how the operands are updated.

    >>> import brainpy.math as bm
    >>>
    >>> a = bm.Variable(bm.zeros(1))
    >>> b = bm.Variable(bm.ones(1))
    >>>
    >>> def cond(x, y):
    >>>   return x < 6.
    >>>
    >>> def body(x, y):
    >>>   a.value += x
    >>>   b.value *= y
    >>>   return x + b[0], y + 1.
    >>>
    >>> res = bm.while_loop(body, cond, operands=(1., 1.))
    >>> res
    (10.0, 4.0)

    .. versionadded:: 2.1.11

    Parameters
    ----------
    body_fun: callable
        A function which defines the updating logic. It is called as
        ``body_fun(*operands)`` and returns the operands for the next iteration
        (returning ``None`` keeps the operands unchanged).
    cond_fun: callable
        A function which defines the stop condition. It is called as
        ``cond_fun(*operands)`` and returns one boolean value.
    operands: Any
        The operands for ``body_fun`` and ``cond_fun`` functions. A
        non-tuple/list value is treated as a single operand.
    dyn_vars: Variable, sequence of Variable, dict
        The dynamically changed variables.

        .. deprecated:: 2.4.0
           No longer need to provide ``dyn_vars``. This function is capable of automatically
           collecting the dynamical variables used in the target ``func``.
    child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
        The children objects used in the target function.

        .. deprecated:: 2.4.0
           No longer need to provide ``child_objs``. This function is capable of automatically
           collecting the children objects used in the target ``func``.

    Returns
    -------
    out: tuple
        The operands after the loop terminates.
    """
    dynvar_deprecation(dyn_vars)
    node_deprecation(child_objs)
    if not isinstance(operands, (list, tuple)):
        operands = (operands,)
    # variables used by both functions are collected once and cached
    stack = get_stack_cache((body_fun, cond_fun))
    if not jax.config.jax_disable_jit and stack is None:
        # trace both functions abstractly to discover the variables they touch
        with VariableStack() as stack:
            _ = eval_shape(cond_fun, *operands, with_stack=True)
            rets = eval_shape(body_fun, *operands, with_stack=True)[1]
        cache_stack((body_fun, cond_fun), stack)
        # NOTE(review): presumably true when an enclosing transform is still
        # collecting variables; only the evaluated body result is returned.
        if not stack.is_first_stack():
            return rets
    stack = VariableStack() if stack is None else stack
    dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands)
    # write the final variable values back
    for k, v in stack.items():
        v._value = dyn_values[k]
    return out