brainpy.math.jit

brainpy.math.jit(func=None, dyn_vars=None, child_objs=None, static_argnums=None, static_argnames=None, device=None, inline=False, keep_unused=False, abstracted_axes=None)

JIT (Just-In-Time) compilation for class objects.

Like jax.jit, this function can JIT compile a pure function. In addition, it can JIT compile a brainpy.DynamicalSystem, a brainpy.Base object, or a bound method of a brainpy.Base object.

Note

Keep the following points in mind when using JIT compilation.

  1. Avoid storing a scalar in a Variable, TrainVar, etc.

For example,

>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> class Test(bp.BrainPyObject):
...   def __init__(self):
...     super(Test, self).__init__()
...     self.a = bm.Variable(1.)  # Avoid! DO NOT USE!
...   def __call__(self, *args, **kwargs):
...     self.a += 1.

This usage is deprecated because it may cause errors. Instead, we recommend defining a scalar-valued variable as:

>>> class Test(bp.BrainPyObject):
...   def __init__(self):
...     super(Test, self).__init__()
...     self.a = bm.Variable(bm.array([1.]))  # recommended: wrap the scalar in a one-element array
...   def __call__(self, *args, **kwargs):
...     self.a += 1.

That is, wrap the scalar in an array (here of shape (1,)) before storing it in the variable a.

  2. JIT compilation in brainpy.math does not support static_argnums. Instead, use static_argnames and call the jitted function with keyword arguments, as in jitted_func(arg1=var1, arg2=var2). For example,

>>> def f(a, b, c=1.):
...   if c > 0.: return a + b
...   else: return a * b
>>>
>>> # ERROR! https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
>>> bm.jit(f)(1, 2, 0)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where
concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)
>>> # correct: mark c as static and pass it by keyword
>>> bm.jit(f, static_argnames='c')(1, 2, c=0)
DeviceArray(2, dtype=int32, weak_type=True)
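A minimal sketch of the same pattern in plain JAX, assuming that brainpy.math.jit forwards static_argnames to jax.jit (brainpy is built on JAX). The counter trace_count is a hypothetical helper added only to show when the function is traced: the value of a static argument is part of the compilation cache key, so each new value of c triggers a fresh trace and compile, while new values of the traced arguments a and b reuse the compiled code.

```python
import jax

trace_count = 0  # hypothetical helper: counts how often JAX traces f


def f(a, b, c=1.):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces f
    if c > 0.:  # fine here, because c is a concrete Python value, not a tracer
        return a + b
    else:
        return a * b


# c is static: its concrete value is baked into the compiled code
jitted_f = jax.jit(f, static_argnames='c')

r0 = jitted_f(1, 2, c=0)  # first call: traces and compiles the a * b branch
r1 = jitted_f(3, 4, c=0)  # same static value and argument types: cache hit
r2 = jitted_f(1, 2, c=1)  # new static value: traced and compiled again
print(int(r0), int(r1), int(r2), trace_count)
```

Passing c by keyword keeps the call consistent with the note above; because each distinct static value compiles a separate program, static arguments work best for a small set of configuration values.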

Examples

You can JIT a brainpy.DynamicalSystem:

>>> import brainpy as bp
>>>
>>> class LIF(bp.NeuGroup):
...   pass  # define the variables and the update() method here
>>> lif = bp.math.jit(LIF(10))

You can JIT a brainpy.Base object that implements __call__():

>>> gru = bp.layers.GRU(100, 200)
>>> jit_gru = bp.math.jit(gru)

You can also JIT a bound method of a brainpy.Base object:

>>> class Hello(bp.BrainPyObject):
...   def __init__(self):
...     super(Hello, self).__init__()
...     self.a = bp.math.Variable(bp.math.array([10.]))
...     self.b = bp.math.Variable(bp.math.array([2.]))
...   def transform(self):
...     return self.a ** self.b
>>>
>>> test = Hello()
>>> jit_transform = bp.math.jit(test.transform)
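The following plain-JAX sketch illustrates the general idea behind JIT compiling a stateful method: the variables the method reads or writes become explicit inputs and outputs of a pure function. It is an illustration of the pattern, not brainpy's implementation; the names transform, increment_a, and the state dict are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical plain-JAX analogue of Hello: the variables live in a dict
# ("state") that is passed in and returned explicitly instead of being
# stored as object attributes.
def transform(state):
    return state["a"] ** state["b"]


def increment_a(state):
    # an in-place update such as `self.a += 1.` becomes "return a new state"
    return {**state, "a": state["a"] + 1.0}


jit_transform = jax.jit(transform)
jit_increment_a = jax.jit(increment_a)

# one-element arrays, following note 1
state = {"a": jnp.array([10.0]), "b": jnp.array([2.0])}
before = jit_transform(state)
state = jit_increment_a(state)
after = jit_transform(state)
print(before, after)
```

Threading state through function arguments and return values is what a pure JAX program has to do by hand; brainpy's Variable objects let the bound method keep its object-oriented form.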

Finally, you can JIT an ordinary function, exactly as in JAX:

>>> @bp.math.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)
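selu uses where rather than a Python if on x > 0, so it can be traced without any static arguments (compare note 2). The sketch below is a plain-JAX equivalent, assuming jax.numpy in place of bp.math, and checks the jitted result against an eager NumPy reference.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # jnp.where computes both branches elementwise and selects per element,
    # so no Python `if` on a traced value is needed
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.array([-1.0, 0.0, 2.0])
y = selu(x)  # traced and compiled on this first call

# NumPy reference, computed eagerly outside JIT
x_np = np.array([-1.0, 0.0, 2.0])
y_ref = 1.05 * np.where(x_np > 0, x_np, 1.67 * np.exp(x_np) - 1.67)
print(np.allclose(np.asarray(y), y_ref, atol=1e-6))
```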

Parameters:
  • func (Base, function, callable) – The Base instance, bound method, or function to JIT compile.

  • dyn_vars (optional, dict, sequence of Variable, Variable) – The variables that the function modifies or needs for its computation.

  • child_objs (optional, dict, sequence of BrainPyObject, BrainPyObject) –

    The children objects used in the target function.

    New in version 2.3.1.

  • static_argnames (optional, str, list, tuple, dict) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the description of static_argnums in jax.jit for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • device (optional, Any) – This is an experimental feature and the API is likely to change. The device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually jax.devices()[0].

Returns:

func – A callable jitted function, set up for just-in-time compilation.

Return type:

callable
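Assuming the returned callable behaves like one produced by jax.jit (brainpy is built on JAX), compilation is deferred until the first call, and the compiled code is cached per input shape and dtype. A plain-JAX sketch; square_sum and the counter n_traces are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp

n_traces = 0  # hypothetical helper: counts how often JAX traces square_sum


def square_sum(x):
    global n_traces
    n_traces += 1  # Python side effect: runs only when JAX (re)traces
    return jnp.sum(x ** 2)


jitted = jax.jit(square_sum)  # wrapping alone does not trace or compile
n_traces_after_wrap = n_traces

a = jitted(jnp.ones(3))   # first call with shape (3,): trace and compile
b = jitted(jnp.zeros(3))  # same shape and dtype: reuses the compiled code
c = jitted(jnp.ones(4))   # new shape (4,): traced and compiled again
print(float(a), float(b), float(c), n_traces)
```

Because every new input shape triggers a recompile, it usually pays to keep array shapes fixed across calls of a jitted function.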