Gotchas of BrainPy Transformations

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.7.8'

BrainPy provides a novel concept of object-oriented transformations based on brainpy.math.Variable. However, these transformations come with several gotchas:

1. Variables that will be changed cannot be function arguments

Passing a Variable that the function modifies as an argument does not work with the new object-oriented transformations either:

@bm.jit
def f(a, b):
  a.value = b

a = bm.Variable(bm.ones(1))
b = bm.Variable(bm.ones(1) * 10)
f(a, b)

try:
  assert bm.allclose(a, b)
  print('a equals b.')
except AssertionError:
  print('a is not equal to b.')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 1
----> 1 f(a, b)
      3 try:
      4   assert bm.allclose(a, b)

(...)

ValueError: Inputs/outputs for brainstate transformations cannot be an instance of State. But we got [1.]
a
Variable(
  value=ShapedArray(float32[1]),
  _batch_axis=None,
  axis_names=None
)

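As the error shows, a Variable cannot be passed to a transformed function as an argument. Variables that the function reads or updates should instead be accessed from the enclosing (global) scope, where the transformation can find them automatically.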

Instead, this works:

@bm.jit
def f(b):
  a.value = b

a = bm.Variable(bm.ones(1))
b = bm.ones(1) * 10
f(b)

a
Variable(
  value=ShapedArray(float32[1]),
  _batch_axis=None,
  axis_names=None
)
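The rule mirrors how plain JAX works: inside `jax.jit`, arguments arrive as abstract tracers, so a jitted function has no caller-side buffer it could overwrite and must return new values instead. The following is a minimal sketch in plain JAX (not BrainPy's API) of that functional alternative, where the "update" of `a` happens outside the transformed function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(a, b):
    # Inside jit, `a` and `b` are tracers: there is nothing on the caller's
    # side to overwrite, so the assignment "a = b" is expressed as a return value.
    return b


a = jnp.ones(1)
b = jnp.ones(1) * 10
a = f(a, b)  # rebind the result outside the transformed function
print(a)  # [10.]
```

BrainPy's Variable automates this pattern: a Variable used from the global scope is threaded in and out of the compiled function for you, which is why it must not also be passed in as an argument.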

2. Functions to be transformed are called twice

The core mechanism of a BrainPy transformation is that it first calls the function to automatically find all Variables used in the model, and then calls the function again to compile the model with the Variables it found.

Therefore, the Python body of any function you transform may be executed more than once, and side effects in that body should not be relied on.

@bm.jit
def f(inp):
  print('calling f ...')
  return inp

@bm.jit
def g(inp):
  print('calling g ...')
  return f(inp)

Taking the functions above as an example, calling g gives:

g(1.)
calling g ...
calling f ...
Array(1., dtype=float32, weak_type=True)

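The first pass calls g, which in turn calls f, to infer all dynamical variables (instances of Variable) used in the two functions. This is where the lines calling g ... and calling f ... come from.

Under the two-pass mechanism described above, compiling the functions would execute their Python bodies again and print the same two lines a second time. In the run shown here (BrainPy 2.7.8), each line appears only once, so the exact number of times the Python body runs depends on the BrainPy version, and code should not depend on either count.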
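The difference between tracing a function and running it can be observed in plain JAX. The sketch below uses `jax.jit` directly rather than BrainPy's API, and `trace_count` is a hypothetical counter introduced for illustration. A Python side effect such as `print` or a list update runs only while the function is traced, whereas `jax.debug.print` is staged into the compiled computation and runs on every call.

```python
import jax
import jax.numpy as jnp

trace_count = [0]  # hypothetical Python-level counter, updated only during tracing


@jax.jit
def f(x):
    trace_count[0] += 1                         # runs at trace time only
    print("tracing f")                          # also trace time only
    jax.debug.print("running f with x={}", x)   # runs on every call
    return x * 2


f(1.0)          # first call: traces, compiles, and runs
f(2.0)          # same shape and dtype: reuses the compiled function, no retrace
f(jnp.ones(3))  # new shape: traces again
print(trace_count[0])  # 2
```

Here "tracing f" is printed twice (once per distinct input shape), while "running f ..." is printed three times.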

Because of this, Python-level side effects may not behave as expected. For example, suppose we use a global variable to record the number of times the function is called:

num = [0]

@bm.jit
def h(inp):
  num[0] += 1
  return inp
h(1.)
Array(1., dtype=float32, weak_type=True)

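We called h once. Under the two-pass mechanism the counter would read 2; in the run shown here it reads 1. In either case, a Python counter like this tracks how many times the function is traced, not how many times it is called: later calls that reuse the compiled function do not update it.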

num
[1]
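To count calls reliably, the counter has to be part of the state that the transformation tracks. In BrainPy that role is played by a Variable; in plain JAX, the equivalent is to thread the counter through the function explicitly. A minimal sketch in plain JAX (not BrainPy's API):

```python
import jax
import jax.numpy as jnp


@jax.jit
def h(count, inp):
    # The counter is an input and an output of the compiled computation,
    # so it is updated on every call, not only when the function is traced.
    return count + 1, inp


count = jnp.zeros((), dtype=jnp.int32)
for x in [1.0, 2.0, 3.0]:
    count, out = h(count, x)
print(int(count))  # 3
```

In BrainPy, a global `num = bm.Variable(bm.zeros(1))` updated inside the jitted function with `num.value += 1` should behave like `count` here, with the transformation passing it in and out automatically.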