jax_vectorize_map

brainpy.running.jax_vectorize_map(func, arguments, num_parallel, clear_buffer=False)

Perform a vectorized map of a function by using jax.vmap.

This function can run on CPU or GPU backends, but it is best suited to GPU backends, where jax.vmap can parallelize the computation along the mapped axis.
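As a minimal sketch of what a vectorized map computes, the snippet below applies jax.vmap to a per-task function over the leading axis of its inputs and checks the result against an explicit Python loop. It uses plain JAX, not BrainPy, and `simulate` is a hypothetical stand-in for a user-supplied `func`.

```python
import jax
import jax.numpy as jnp


def simulate(tau, x0):
    # Hypothetical task: decay x0 with time constant tau over 100 Euler steps.
    def step(x, _):
        x = x - 0.1 * x / tau
        return x, None

    x_final, _ = jax.lax.scan(step, x0, None, length=100)
    return x_final


taus = jnp.linspace(1.0, 10.0, 8)  # one time constant per task
x0s = jnp.ones(8)                  # one initial value per task

# jax.vmap maps `simulate` over axis 0 of both arguments in one vectorized call.
vectorized = jax.vmap(simulate)(taus, x0s)

# The same results, computed one task at a time.
looped = jnp.stack([simulate(t, x) for t, x in zip(taus, x0s)])
```

On a GPU, the single vectorized call can execute all tasks in parallel, whereas the loop launches them one after another.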

Parameters:
  • func (callable, function) – The function to be mapped.

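  • arguments (sequence, dict) – The function arguments, used to define the tasks: a sequence of positional arguments or a dict of keyword arguments, where every element has the same length and each index along that length defines one task.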

  • num_parallel (int) – The batch size, i.e. the number of tasks run in parallel by each vectorized call.

  • clear_buffer (bool) – Whether to clear the buffer memory after running each batch of data.

Returns:

results – The results of running func on all tasks.

Return type:

Any
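To show how func, arguments and num_parallel interact, here is a minimal plain-JAX sketch of batched vectorized mapping. It is not BrainPy's implementation: `batched_vectorize_map` and `energy` are hypothetical names, the sketch assumes a dict is passed as keyword arguments and a sequence as positional arguments, it concatenates per-batch results along the task axis, and it omits clear_buffer.

```python
import jax
import jax.numpy as jnp


def batched_vectorize_map(func, arguments, num_parallel):
    """Hypothetical sketch: vmap `func` over tasks, `num_parallel` at a time.

    Assumes `arguments` is a dict (passed as keyword arguments) or a
    sequence (passed as positional arguments) of arrays whose leading
    axis indexes the tasks.
    """
    if isinstance(arguments, dict):
        keys = list(arguments)
        leaves = [jnp.asarray(arguments[k]) for k in keys]
        call = lambda *xs: func(**dict(zip(keys, xs)))
    else:
        leaves = [jnp.asarray(a) for a in arguments]
        call = func

    num_tasks = leaves[0].shape[0]
    if any(leaf.shape[0] != num_tasks for leaf in leaves):
        raise ValueError("All arguments must have the same number of tasks.")

    mapped = jax.jit(jax.vmap(call))
    outputs = []
    for start in range(0, num_tasks, num_parallel):
        # Slice one batch of tasks; the last batch may be smaller.
        batch = [leaf[start:start + num_parallel] for leaf in leaves]
        outputs.append(mapped(*batch))

    # Stitch the per-batch results back together along the task axis.
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs), *outputs)


def energy(mass, velocity):
    # Hypothetical per-task function.
    return 0.5 * mass * velocity ** 2


masses = jnp.arange(1.0, 11.0)   # 10 tasks
velocities = jnp.full(10, 2.0)

# Dict form: keys match the parameter names of `energy`.
res = batched_vectorize_map(energy, {"mass": masses, "velocity": velocities}, num_parallel=4)

# Sequence form: elements are passed positionally.
res_seq = batched_vectorize_map(energy, (masses, velocities), num_parallel=4)
```

With 10 tasks and num_parallel=4, the sketch runs three vectorized calls (4, 4 and 2 tasks), and each result equals 2 * mass. A smaller num_parallel lowers peak memory per call at the cost of more calls.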