Gotchas of BrainPy Transformations
import brainpy as bp
import brainpy.math as bm
# Run all BrainPy/JAX computations on the CPU backend.
bm.set_platform('cpu')
# Display the installed BrainPy version (the recorded output below shows '2.7.8').
bp.__version__
'2.7.8'
BrainPy provides a novel concept of object-oriented transformations based on `brainpy.math.Variable`. However, this kind of transformation comes with several gotchas:
1. Variables that will be changed cannot be function arguments
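The following example, which passes the Variable to be modified as a function argument, does not work with the new object-oriented transformations either: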
# Anti-pattern: the Variable ``a`` that the jitted function mutates is passed
# in as a function argument. In this BrainPy version, calling ``f(a, b)``
# raises ``ValueError: Inputs/outputs for brainstate transformations cannot be
# an instance of State`` (see the recorded traceback below).
@bm.jit
def f(a, b):
    a.value = b


a = bm.Variable(bm.ones(1))
b = bm.Variable(bm.ones(1) * 10)

f(a, b)

try:
    assert bm.allclose(a, b)
    print('a equals to b.')
except AssertionError:
    # Catch only the failed equality check; a bare ``except`` would also hide
    # unrelated errors raised while comparing ``a`` and ``b``.
    print('a is not equal to b.')
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/miniconda3/lib/python3.13/site-packages/ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()
File ~/miniconda3/lib/python3.13/site-packages/traitlets/config/application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/kernelapp.py:758, in start()
757 try:
--> 758 self.io_loop.start()
759 except KeyboardInterrupt:
File ~/miniconda3/lib/python3.13/site-packages/tornado/platform/asyncio.py:211, in start()
210 def start(self) -> None:
--> 211 self.asyncio_loop.run_forever()
File ~/miniconda3/lib/python3.13/asyncio/base_events.py:683, in run_forever()
682 while True:
--> 683 self._run_once()
684 if self._stopping:
File ~/miniconda3/lib/python3.13/asyncio/base_events.py:2050, in _run_once()
2049 else:
-> 2050 handle._run()
2051 handle = None
File ~/miniconda3/lib/python3.13/asyncio/events.py:89, in _run()
88 try:
---> 89 self._context.run(self._callback, *self._args)
90 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/kernelbase.py:621, in shell_main()
620 async with asyncio_lock:
--> 621 await self.dispatch_shell(msg, subshell_id=subshell_id)
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/kernelbase.py:478, in dispatch_shell()
477 if inspect.isawaitable(result):
--> 478 await result
479 except Exception:
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/ipkernel.py:372, in execute_request()
371 """Override for cell output - cell reconciliation."""
--> 372 await super().execute_request(stream, ident, parent)
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/kernelbase.py:834, in execute_request()
833 if inspect.isawaitable(reply_content):
--> 834 reply_content = await reply_content
835 else:
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/ipkernel.py:464, in do_execute()
463 if accepts_params["cell_id"]:
--> 464 res = shell.run_cell(
465 code,
466 store_history=store_history,
467 silent=silent,
468 cell_id=cell_id,
469 )
470 else:
File ~/miniconda3/lib/python3.13/site-packages/ipykernel/zmqshell.py:663, in run_cell()
662 self._last_traceback = None
--> 663 return super().run_cell(*args, **kwargs)
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3123, in run_cell()
3122 try:
-> 3123 result = self._run_cell(
3124 raw_cell, store_history, silent, shell_futures, cell_id
3125 )
3126 finally:
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3178, in _run_cell()
3177 try:
-> 3178 result = runner(coro)
3179 except BaseException as e:
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
127 try:
--> 128 coro.send(None)
129 except StopIteration as exc:
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3400, in run_cell_async()
3397 interactivity = "none" if silent else self.ast_node_interactivity
-> 3400 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3401 interactivity=interactivity, compiler=compiler, result=result)
3403 self.last_execution_succeeded = not has_raised
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3641, in run_ast_nodes()
3640 asy = compare(code)
-> 3641 if await self.run_code(code, result, async_=asy):
3642 return True
File ~/miniconda3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3701, in run_code()
3700 else:
-> 3701 exec(code_obj, self.user_global_ns, self.user_ns)
3702 finally:
3703 # Reset our crash handler in place
Cell In[2], line 5
3 a.value = b
----> 5 a = bm.Variable(bm.ones(1))
6 b = bm.Variable(bm.ones(1) * 10)
File /mnt/d/codes/projects/BrainPy/.claude/worktrees/notebooks-fix/brainpy/math/object_transform/variables.py:85, in __init__()
84 Array.__init__(self, value, dtype=dtype)
---> 85 brainstate.State.__init__(self, value)
87 # check batch axis
JaxStackTraceBeforeTransformation: ValueError: Inputs/outputs for brainstate transformations cannot be an instance of State. But we got [1.]
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[3], line 1
----> 1 f(a, b)
3 try:
4 assert bm.allclose(a, b)
File ~/miniconda3/lib/python3.13/site-packages/brainstate/transform/_jit.py:131, in _get_jitted_fun.<locals>.jitted_fn(*args, **params)
128 return fun.fun(*args, **params)
130 # compile the function and get the state trace
--> 131 state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
132 read_state_vals = state_trace.get_read_state_values(True)
134 # call the jitted function
File ~/miniconda3/lib/python3.13/site-packages/brainstate/transform/_make_jaxpr.py:651, in StatefulFunction.get_state_trace(self, compile_if_miss, *args, **kwargs)
633 def get_state_trace(self, *args, compile_if_miss: bool = True, **kwargs) -> StateTraceStack:
634 """
635 Read the state trace of the function.
636
(...) 649 The state trace of the function.
650 """
--> 651 cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
652 return self.get_state_trace_by_cache(cache_key)
File ~/miniconda3/lib/python3.13/site-packages/brainstate/transform/_make_jaxpr.py:417, in StatefulFunction.get_arg_cache_key(self, compile_if_miss, *args, **kwargs)
378 def get_arg_cache_key(self, *args, compile_if_miss: bool = False, **kwargs) -> CacheKey:
379 """
380 Compute the cache key for the given arguments.
381
(...) 414 >>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)
415 """
--> 417 cache_key = get_arg_cache_key(
418 self.static_argnums,
419 self.static_argnames,
420 args,
421 kwargs,
422 )
423 if compile_if_miss:
424 compilation = self._compilation_cache.get(cache_key)
File ~/miniconda3/lib/python3.13/site-packages/brainstate/transform/_make_jaxpr.py:164, in get_arg_cache_key(static_argnums, static_argnames, args, kwargs, fn_to_check)
160 # Mirror the kwargs branch below: skip the State check when ``fn_to_check``
161 # is None. Without this guard ``jax.tree.map(None, ...)`` raises a confusing
162 # ``TypeError: 'NoneType' object is not callable`` on the positional args.
163 if fn_to_check is not None:
--> 164 jax.tree.map(fn_to_check, dyn_args, is_leaf=lambda x: isinstance(x, State))
165 dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
167 # kwargs -- the State check must run BEFORE abstractification, which
168 # flattens a State into its value leaves and would hide it
File ~/miniconda3/lib/python3.13/site-packages/jax/_src/tree.py:156, in map(f, tree, is_leaf, *rest)
116 def map(f: Callable[..., Any],
117 tree: Any,
118 *rest: Any,
119 is_leaf: Callable[[Any], bool] | None = None) -> Any:
120 """Maps a multi-input function over pytree args to produce a new pytree.
121
122 Args:
(...) 154 - :func:`jax.tree.reduce`
155 """
--> 156 return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
File ~/miniconda3/lib/python3.13/site-packages/jax/_src/tree_util.py:399, in tree_map(f, tree, is_leaf, *rest)
397 err = next(_prefix_error((), tree, r2, is_leaf), None) # type: ignore
398 raise (err('tree_map tree') if err is not None else e) from None
--> 399 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/miniconda3/lib/python3.13/site-packages/jax/_src/tree_util.py:399, in <genexpr>(.0)
397 err = next(_prefix_error((), tree, r2, is_leaf), None) # type: ignore
398 raise (err('tree_map tree') if err is not None else e) from None
--> 399 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/miniconda3/lib/python3.13/site-packages/brainstate/transform/_make_jaxpr.py:138, in _check_input_ouput(x)
136 def _check_input_ouput(x):
137 if isinstance(x, State):
--> 138 x.raise_error_with_source_info(
139 ValueError(
140 'Inputs/outputs for brainstate transformations cannot be an instance of State. '
141 f'But we got {x}'
142 )
143 )
File ~/miniconda3/lib/python3.13/site-packages/brainstate/_state.py:692, in State.raise_error_with_source_info(self, error)
690 name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
691 with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
--> 692 raise error
ValueError: Inputs/outputs for brainstate transformations cannot be an instance of State. But we got [1.]
a  # display the Variable ``a`` after the call above
Variable(
value=ShapedArray(float32[1]),
_batch_axis=None,
axis_names=None
)
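All Variables should be used in a global context: access them from the enclosing scope instead of passing them as function arguments. In the BrainPy version used here, passing a Variable as an argument raises a `ValueError` ("Inputs/outputs for brainstate transformations cannot be an instance of State"), as the traceback above shows.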
Instead, this works:
# Correct pattern: ``a`` is not a parameter. The jitted function looks ``a`` up
# in the enclosing (global) scope, and only the plain array ``b`` is passed in.
@bm.jit
def f(b):
    a.value = b  # assigns to the global Variable ``a``
a = bm.Variable(bm.ones(1))
b = bm.ones(1) * 10  # a plain array, not a Variable
f(b)
a  # display ``a`` after the call
Variable(
value=ShapedArray(float32[1]),
_batch_axis=None,
axis_names=None
)
2. Functions to be transformed are called twice
The core mechanism of any BrainPy transformation is that it first calls the function to automatically find all Variables used in the model, and then calls the function again to compile the model with the found Variables.
Therefore, any function that the user creates may be called more than once.
# Two nested jitted functions. The ``print`` calls are Python side effects,
# so they show how many times each Python body is executed by the
# transformation (``g`` calls ``f``).
@bm.jit
def f(inp):
    print('calling f ...')
    return inp
@bm.jit
def g(inp):
    print('calling g ...')
    return f(inp)
Taking the above functions as an example, calling `g` gives:
g(1.)  # call g once; each printed line marks one execution of a Python body
calling g ...
calling f ...
Array(1., dtype=float32, weak_type=True)
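It calls `g`, which in turn calls `f`, to infer all dynamical variables (instances of `Variable`) used in these two functions. This produces the two lines `calling g ...` and `calling f ...`.
In the two-pass design described above, compiling the two functions then runs their Python bodies again, printing the same two lines a second time. In the run recorded here (BrainPy 2.7.8), each line appears only once. This indicates that, in this version, Variable discovery and compilation happen in a single trace.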
Note that this behavior can lead to surprising values in Python-level variables. For example, consider using a global variable to record the number of times the function is called:
# A Python-level counter mutated inside a jitted function. It counts how many
# times the Python body of ``h`` is executed, which need not equal the number
# of times ``h`` is called.
num = [0]
@bm.jit
def h(inp):
    num[0] += 1  # Python side effect; not part of the compiled computation
    return inp
h(1.)
Array(1., dtype=float32, weak_type=True)
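We called the function `h` only once. Under the two-pass design described above, its Python body would run twice and `num` would read `[2]`; in the run recorded here (BrainPy 2.7.8), it reads `[1]`. Either way, `num` counts how many times the Python body was executed, not how many times `h` was called, so it should not be used as a call counter.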
num  # number of times the Python body of h has run so far
[1]