jit

class brainpy.math.jit(func=None, static_argnums=None, static_argnames=None, donate_argnums=(), inline=False, keep_unused=False, abstracted_axes=None, dyn_vars=None, child_objs=None, **kwargs)

JIT (Just-In-Time) compilation for BrainPy computation.

Like jax.jit, this function can just-in-time compile a pure function, but it can also JIT compile a brainpy.DynamicalSystem or a brainpy.BrainPyObject instance.

Examples

You can JIT any object in which all dynamical variables are defined as Variable.

>>> import brainpy as bp
>>> class Hello(bp.BrainPyObject):
...   def __init__(self):
...     super(Hello, self).__init__()
...     self.a = bp.math.Variable(bp.math.array(10.))
...     self.b = bp.math.Variable(bp.math.array(2.))
...   def transform(self):
...     self.a *= self.b  # in-place update of a Variable
...
>>> test = Hello()
>>> jit_transform = bp.math.jit(test.transform)  # returns a jitted callable
>>> jit_transform()  # multiplies test.a by test.b

You can also JIT a normal function, just as you would in JAX.

>>> @bp.math.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)

Parameters:
  • func (BrainPyObject, function, callable) – The BrainPyObject instance, function, or other callable to compile.

  • static_argnums (optional, int, sequence of int) – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

  • static_argnames (optional, str, list, tuple, dict) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(func) to find the corresponding named arguments.

  • donate_argnums (int, sequence of int) – Specifies which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can use donated buffers to reduce the memory needed for a computation, for example by recycling one of your input buffers to store a result. Do not reuse buffers that you donate to a computation; JAX raises an error if you try. By default, no argument buffers are donated. Note that donate_argnums only works for positional arguments; keyword arguments are not donated.

  • device (optional, Any) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by func may be dropped from the resulting compiled XLA executables. Such arguments are neither transferred to the device nor provided to the underlying executable. If True, unused arguments are not pruned.

  • backend (optional, str) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline (bool) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.

  • dyn_vars (optional, dict, sequence of Variable, Variable) –

    The variables that are changed by the function or needed in the computation.

    Deprecated since version 2.4.0: dyn_vars no longer needs to be provided. This function automatically collects the dynamical variables used in the target func.

  • child_objs (optional, dict, sequence of BrainPyObject, BrainPyObject) –

    The children objects used in the target function.

    Deprecated since version 2.4.0: child_objs no longer needs to be provided. This function automatically collects the children objects used in the target func.
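The effect of static_argnums described above can be sketched as follows. This example calls jax.jit directly and assumes that bp.math.jit treats static_argnums the same way; the function power and the counter trace_count are hypothetical names introduced only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the Python function

@partial(jax.jit, static_argnums=(1,))
def power(x, n):
    # `n` is static, so during tracing it is a plain Python int and
    # can drive Python control flow such as this loop.
    global trace_count
    trace_count += 1
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

x = jnp.arange(3.0)           # [0., 1., 2.]
y2 = power(x, 2)              # traced and compiled for n=2
y2_again = power(x, 2)        # same static value: cached executable is reused
y3 = power(x, 3)              # new static value: traced and compiled again

print(y2, y3, trace_count)
```

Each distinct value of a static argument triggers a separate compilation, so static arguments work best for values that take only a few different settings.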

Returns:

func – A callable jitted function, set up for just-in-time compilation.

Return type:

JITTransform