brainpy.math.jit.jit
- brainpy.math.jit.jit(func, dyn_vars=None, static_argnames=None, device=None, auto_infer=True)
JIT (Just-In-Time) compilation for class objects.
Like jax.jit, this function can Just-In-Time compile a pure function. It can also JIT compile a brainpy.DynamicalSystem, a brainpy.Base object, or a bound method of a brainpy.Base object.
Note
Keep the following points in mind when using JIT compilation.
Avoid using a scalar in a Variable, TrainVar, etc.
For example,
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> class Test(bp.Base):
>>>   def __init__(self):
>>>     super(Test, self).__init__()
>>>     self.a = bm.Variable(1.)  # Avoid! DO NOT USE!
>>>   def __call__(self, *args, **kwargs):
>>>     self.a += 1.
This usage is deprecated because it may cause errors. Instead, we recommend defining a scalar-valued variable as:
>>> class Test(bp.Base):
>>>   def __init__(self):
>>>     super(Test, self).__init__()
>>>     self.a = bm.Variable(bm.array([1.]))  # use array to wrap a scalar is recommended
>>>   def __call__(self, *args, **kwargs):
>>>     self.a += 1.
Here, an ndarray is recommended for holding and updating the variable a.
JIT compilation in brainpy.math does not support static_argnums. Instead, use static_argnames, and pass those arguments to the jitted function as keywords, like jitted_func(arg1=var1, arg2=var2). For example,
>>> def f(a, b, c=1.):
>>>   if c > 0.: return a + b
>>>   else: return a * b
>>>
>>> # ERROR! https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
>>> bm.jit(f)(1, 2, 0)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)
>>> # this is right
>>> bm.jit(f, static_argnames='c')(1, 2, c=0)
DeviceArray(2, dtype=int32, weak_type=True)
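The static_argnames parameter appears to follow jax.jit, so the same behaviour can be reproduced with plain JAX. The minimal sketch below uses jax.jit directly rather than brainpy.math. It passes the static argument as a keyword, and it also shows f_traced, a hypothetical alternative added for illustration, which keeps c traced by branching with jax.lax.cond instead of Python if/else.

```python
import jax
import jax.numpy as jnp


def f(a, b, c=1.):
    # Python control flow on c: under jit, c must be a concrete (static) value.
    if c > 0.:
        return a + b
    else:
        return a * b


# Mark c as static; every new value of c triggers a fresh compilation.
f_static = jax.jit(f, static_argnames='c')


def f_traced(a, b, c=1.):
    # Hypothetical alternative: keep c as a traced value and let
    # lax.cond select the branch at run time, so c needs no recompilation.
    return jax.lax.cond(c > 0., lambda x, y: x + y, lambda x, y: x * y, a, b)


f_traced_jit = jax.jit(f_traced)

print(int(f_static(1, 2, c=0)))    # static argument passed as a keyword
print(int(f_traced_jit(1, 2, 0.)))  # c is an ordinary traced argument
```

Marking c as static is simplest when it takes only a few distinct values, since each new value compiles a new version of f. The lax.cond version compiles once, but both branches must return arrays of the same shape and dtype.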
Examples
You can JIT a brainpy.DynamicalSystem:
>>> import brainpy as bp
>>>
>>> class LIF(bp.NeuGroup):
>>>   pass
>>> lif = bp.math.jit(LIF(10))
You can JIT a brainpy.Base object that implements __call__():
>>> gru = bp.layers.GRU(100, 200)
>>> jit_gru = bp.math.jit(gru)
You can also JIT a bound method of a brainpy.Base object:
>>> class Hello(bp.Base):
>>>   def __init__(self):
>>>     super(Hello, self).__init__()
>>>     self.a = bp.math.Variable(bp.math.array(10.))
>>>     self.b = bp.math.Variable(bp.math.array(2.))
>>>   def transform(self):
>>>     return self.a ** self.b
>>>
>>> test = Hello()
>>> bp.math.jit(test.transform)
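Because jax.jit only compiles pure functions, JIT-compiling a bound method conceptually requires the object's state (the kind of variables that dyn_vars lists) to enter and leave the compiled function explicitly. The following is a minimal sketch of that general technique in plain JAX. The helper jit_method, the plain-Python stand-in for Hello, and the Counter class are all hypothetical and added for illustration; this is not BrainPy's actual implementation.

```python
import jax
import jax.numpy as jnp


class Hello:
    """Plain-Python stand-in for the Hello class (hypothetical)."""

    def __init__(self):
        self.a = jnp.array([10.])
        self.b = jnp.array([2.])

    def transform(self):
        return self.a ** self.b


class Counter:
    """Hypothetical stateful object whose method updates an attribute."""

    def __init__(self):
        self.count = jnp.array([0.])

    def step(self, x):
        self.count = self.count + x
        return self.count


def jit_method(obj, method_name, var_names):
    """Hypothetical helper: JIT-compile obj.<method_name>, treating the
    attributes listed in var_names as explicit inputs and outputs."""
    method = getattr(type(obj), method_name)

    def pure(values, *args):
        # Swap the traced values onto obj, run the method, collect the
        # (possibly updated) values, then restore the originals so that
        # no tracer is left on the object after tracing.
        saved = {name: getattr(obj, name) for name in var_names}
        for name, value in zip(var_names, values):
            setattr(obj, name, value)
        try:
            out = method(obj, *args)
            new_values = tuple(getattr(obj, name) for name in var_names)
        finally:
            for name, value in saved.items():
                setattr(obj, name, value)
        return out, new_values

    compiled = jax.jit(pure)

    def wrapper(*args):
        # Feed the current state in, then write the returned state back.
        values = tuple(getattr(obj, name) for name in var_names)
        out, new_values = compiled(values, *args)
        for name, value in zip(var_names, new_values):
            setattr(obj, name, value)
        return out

    return wrapper


test = Hello()
jit_transform = jit_method(test, 'transform', ('a', 'b'))
print(jit_transform())

counter = Counter()
jit_step = jit_method(counter, 'step', ('count',))
jit_step(1.)
jit_step(1.)
print(counter.count)
```

Threading state through inputs and outputs keeps the compiled function pure, and it is why the object's variables have to be known (passed or inferred) before compilation.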
You can also JIT a normal function, just as you would in JAX:
>>> @bp.math.jit
>>> def selu(x, alpha=1.67, lmbda=1.05):
>>>   return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)
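For comparison, here is the same function written directly with jax.jit and jax.numpy. This is a minimal sketch that assumes bp.math.where and bp.math.exp behave like jnp.where and jnp.exp.

```python
import jax
import jax.numpy as jnp


@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # Identity for positive inputs, scaled exponential branch otherwise,
    # with the whole result multiplied by lmbda.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.array([-1.0, 0.0, 1.0])
print(selu(x))
```

The first call traces and compiles selu for this input shape and dtype; later calls with matching inputs reuse the compiled version.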
- Parameters
func (Base, function, callable) – The instance of Base or a function.
dyn_vars (optional, dict, tuple, list, JaxArray) – The variables that are modified in the function or needed in the computation.
static_argnames (optional, str, list, tuple, dict) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constants). See the description of static_argnums in jax.jit for details; static_argnums itself is not supported by this function. (In jax.jit, if static_argnames is not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find the corresponding named arguments.)
device (optional, Any) – This is an experimental feature and the API is likely to change. The Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually jax.devices()[0].
auto_infer (bool) – Automatically infer the dynamical variables.
- Returns
func – A wrapped version of the Base object or function, set up for just-in-time compilation.
- Return type
Any