brainpy.math.parallels.vmap

brainpy.math.parallels.vmap(func, dyn_vars=None, batched_vars=None, in_axes=0, out_axes=0, axis_name=None, reduce_func=None, auto_infer=False)

Vectorized compilation for class objects.

Vectorize and compile a function or a module so that it runs in parallel on a single device.
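The in_axes and out_axes descriptions on this page follow those of jax.vmap. The sketch below shows the underlying idea. It assumes JAX is installed and calls jax.vmap directly, not brainpy.math.parallels.vmap, so it does not cover dyn_vars or batched_vars. A function written for one example is turned into a batched function that gives the same result as an explicit Python loop.

```python
import jax
import jax.numpy as jnp


def norm(v):
    # Written for a single example: v is a 1-D vector.
    return jnp.sqrt(jnp.sum(v ** 2))


vs = jnp.array([[3.0, 4.0],
                [6.0, 8.0],
                [0.0, 1.0]])  # a batch of 3 vectors stacked along axis 0

# Default in_axes=0 and out_axes=0: map over the leading axis of vs.
batched = jax.vmap(norm)(vs)

# Reference result from an explicit Python loop over the batch.
looped = jnp.stack([norm(v) for v in vs])
```

The vectorized call avoids the Python-level loop: the batch dimension is handled inside a single traced computation.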

Parameters
  • func (Base, function, callable) – The function or the module to compile.

  • dyn_vars (dict, sequence) –

  • batched_vars (dict) –

  • in_axes (optional, int, sequence of int) –

    Specify which input array axes to map over. If each positional argument to func is an array, then in_axes can be an integer, None, or a tuple of integers and Nones with length equal to the number of positional arguments to func. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.

    If the positional arguments to func are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to func.

    At least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

    Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).

  • out_axes (optional, int, tuple/list/dict) – Indicate where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions (axes) of the array returned by the vmap()-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by func.

  • axis_name (optional) –
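
The in_axes rules above (an integer, None, a per-argument tuple, negative axes, and keyword arguments always mapped over axis 0) can be sketched as follows. This assumes JAX is installed and uses jax.vmap, whose in_axes semantics the description above mirrors. The function weighted_sum is a hypothetical example, not part of BrainPy.

```python
import jax
import jax.numpy as jnp


def weighted_sum(x, w):
    # Hypothetical per-example computation: x and w are 1-D vectors of length 3.
    return jnp.dot(x, w)


xs = jnp.arange(6.0).reshape(2, 3)   # 2 examples stored along axis 0
xs_t = xs.T                          # same data, examples stored along axis 1
w = jnp.array([1.0, 0.0, -1.0])      # shared weights, not batched

# Tuple form: map axis 0 of x; None means w is not mapped (reused for every example).
a = jax.vmap(weighted_sum, in_axes=(0, None))(xs, w)

# Negative axes are allowed as long as they lie in [-ndim, ndim).
b = jax.vmap(weighted_sum, in_axes=(-1, None))(xs_t, w)

# A single integer applies to every positional argument; here each example
# gets its own weight vector.
ws = jnp.stack([w, 2 * w])
c = jax.vmap(weighted_sum, in_axes=0)(xs, ws)

# in_axes describes only the positional arguments; w passed as a keyword
# is always mapped over its leading axis (axis 0).
d = jax.vmap(weighted_sum, in_axes=(0,))(xs, w=ws)
```

Passing the unbatched w=w as a keyword in the last call would fail, because its leading axis (size 3) would be mapped and would not match the batch size of 2.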

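For container arguments, the matching element of in_axes can itself be a container with the same structure. The sketch below assumes JAX is installed and uses jax.vmap with a dict of parameters. The function predict and the keys "w" and "b" are hypothetical. The weight matrix is shared across the batch while the bias is mapped over its leading axis.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Hypothetical model: params is a dict holding a weight matrix and a bias vector.
    return x @ params["w"] + params["b"]


params = {
    "w": jnp.ones((3, 2)),              # shared by every example
    "b": jnp.array([[0.0, 0.0],         # one bias row per example (2 examples)
                    [10.0, 10.0]]),
}
xs = jnp.arange(6.0).reshape(2, 3)

# The first element of in_axes mirrors the dict: do not map "w", map axis 0 of "b".
batched = jax.vmap(predict, in_axes=({"w": None, "b": 0}, 0))
ys = batched(params, xs)
```

Here in_axes is a tree prefix of (params, xs): the dict entry matches params key by key, and the trailing 0 applies to xs.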
Returns

obj_or_func – Batched/vectorized version of func, with arguments that correspond to those of func but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of func but with extra array axes at positions indicated by out_axes.
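
How out_axes places the extra batch axis in the returned value can be sketched with jax.vmap, assuming JAX is installed. The functions scale and stats are hypothetical. Each batched output has one more dimension than the per-example output, and a tuple of outputs can take a tuple of out_axes.

```python
import jax
import jax.numpy as jnp


def scale(x):
    # Hypothetical per-example function: (3,) -> (3,).
    return 2.0 * x


def stats(x):
    # Hypothetical per-example function returning a scalar and a (3,) vector.
    m = jnp.mean(x)
    return m, x - m


xs = jnp.arange(6.0).reshape(2, 3)   # batch of 2 along axis 0

y0 = jax.vmap(scale, out_axes=0)(xs)   # batch axis first: shape (2, 3)
y1 = jax.vmap(scale, out_axes=1)(xs)   # batch axis last:  shape (3, 2)

# One out_axes entry per output: the scalar gets a batch axis at 0,
# the vector gets its batch axis at position 1.
m, centered = jax.vmap(stats, out_axes=(0, 1))(xs)
```

y1 holds the same values as y0, only transposed, because out_axes changes where the batch axis appears without changing the computation.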

Return type

Any