How to debug in BrainPy


import jax
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

bp.__version__
'2.4.2'

jax.disable_jit() context

To debug a model in BrainPy, turn off JIT compilation with the jax.disable_jit() context manager. Consider the following function:

@bm.jit
def f1(a):
    # print() is a Python side effect: it runs while the function is being traced
    print(f'call, a = {a} ...')
    return a

With JIT enabled, calling the function produces:

f1(1.)
call, a = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> ...
call, a = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> ...
Array(1., dtype=float32, weak_type=True)

The function body runs twice. The first pass infers the dynamical variables (brainpy.math.Variable) used in the function, and the second pass compiles the whole function. In both passes a is an abstract tracer that carries only its shape and dtype, so with JIT enabled we cannot see concrete values inside the function.
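The same trace-then-compile behavior can be observed with plain jax.jit. The minimal sketch below uses plain JAX rather than bm.jit, so it traces only once, without BrainPy's extra Variable-inference pass; the list name trace_log is hypothetical. A Python side effect inside the function runs only while the function is traced, and a second call with the same shape and dtype reuses the compiled code without running the Python body again.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log of every time the Python body actually runs

@jax.jit
def g(a):
    # Python side effect: executes only during tracing, never in compiled calls
    trace_log.append('traced')
    return a * 2.0

g(jnp.float32(1.0))  # first call: the body is traced and compiled
g(jnp.float32(3.0))  # same shape/dtype: cached compiled code, body not re-run
print(len(trace_log))  # prints 1
```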

We can turn off JIT compilation with the jax.disable_jit() context manager:

with jax.disable_jit():
    f1(1.)
call, a = 1.0 ...

Without JIT, the function prints the concrete value of a. This means you can use standard Python debugging tools, such as print() or pdb breakpoints, while designing your model.
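Concrete values matter most when the Python code itself depends on them. In the minimal sketch below, which uses plain jax.jit instead of bm.jit (the function name clipped is hypothetical), a Python if statement fails under JIT because the argument is a tracer, but works inside jax.disable_jit(), where the body runs eagerly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clipped(x):
    # A Python `if` needs a concrete boolean: fine eagerly, an error while tracing
    if x > 0:
        return x
    return jnp.zeros_like(x)

with jax.disable_jit():
    # The body runs eagerly, so `x > 0` is a concrete value; a breakpoint()
    # placed inside `clipped` would also stop here with real numbers
    eager_out = clipped(jnp.float32(2.0))

try:
    clipped(jnp.float32(2.0))  # JIT on: `x` is a tracer, so `if x > 0` raises
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```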

jax.disable_jit() works with most BrainPy transformations, including:

  • brainpy.math.jit()

  • brainpy.math.grad()

  • brainpy.math.vector_grad()

  • brainpy.math.while_loop()

  • brainpy.math.cond()

  • brainpy.math.ifelse()
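Assuming brainpy.math.while_loop() and brainpy.math.cond() are built on JAX's control-flow primitives, the effect can be sketched with jax.lax.while_loop and jax.lax.cond directly. Under jax.disable_jit(), JAX runs these as ordinary Python control flow, so the loop body receives concrete values; the names body and seen are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

seen = []  # hypothetical record of the concrete loop states

def body(i):
    # int() requires a concrete value; it would fail on a tracer
    seen.append(int(i))
    return i + 1

with jax.disable_jit():
    # Executed as a Python `while` loop rather than a compiled primitive
    final = lax.while_loop(lambda i: i < 3, body, jnp.int32(0))
    # Executed as a Python `if`/`else` on the concrete predicate
    sign = lax.cond(final > 0, lambda x: x, lambda x: -x, final)

print(seen)
```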

brainpy.DSRunner(..., jit=False)

If you simulate a brain dynamics model with brainpy.DSRunner, pass jit=False when creating the runner, as in brainpy.DSRunner(..., jit=False), to disable JIT compilation.
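The idea behind a runner-level jit switch can be sketched in plain JAX. The helpers run and leaky_step below are hypothetical and are not how brainpy.DSRunner is implemented: run wraps the update step in jax.jit only when jit=True, so with jit=False every step executes eagerly and can be inspected with print() or a debugger. Both modes should produce the same trajectory.

```python
import jax
import jax.numpy as jnp

def leaky_step(v, i_ext):
    # Hypothetical leaky integrator dv/dt = -v + i_ext, one Euler step with dt = 0.1
    return v + 0.1 * (-v + i_ext)

def run(step, v0, inputs, jit=True):
    # Hypothetical helper mirroring a runner-level `jit` flag (not DSRunner's code)
    step_fn = jax.jit(step) if jit else step  # jit=False: each step runs eagerly
    v, trace = v0, []
    for i_ext in inputs:  # Python loop over the per-step inputs
        v = step_fn(v, i_ext)
        trace.append(v)
    return jnp.stack(trace)

inputs = jnp.ones(5, dtype=jnp.float32)
v_jit = run(leaky_step, jnp.float32(0.0), inputs, jit=True)
v_eager = run(leaky_step, jnp.float32(0.0), inputs, jit=False)
```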

brainpy.for_loop(..., jit=False)

Similarly, if you use brainpy.for_loop, pass jit=False to the for_loop transformation to disable JIT compilation.
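Assuming brainpy.for_loop is built on jax.lax.scan, a plain-JAX analogue of for_loop(..., jit=False) is to run lax.scan inside jax.disable_jit(). JAX then executes the scan as a Python loop, calling the body once per step with concrete slices of the input; the names step and inputs_seen are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

inputs_seen = []  # hypothetical record of the concrete per-step inputs

def step(carry, x):
    # float() requires a concrete value; it would fail on a tracer
    inputs_seen.append(float(x))
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

xs = jnp.arange(4, dtype=jnp.float32)
with jax.disable_jit():
    # Runs as a Python loop over the leading axis of `xs`
    total, partial_sums = lax.scan(step, jnp.float32(0.0), xs)

print(inputs_seen)
```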