jax_vectorize_map


class brainpy.running.jax_vectorize_map(func, arguments, num_parallel, clear_buffer=False)

Perform a vectorized map of a function using jax.vmap.

This function runs on both CPU and GPU backends, but it is best suited to GPUs, where jax.vmap can execute the mapped axis in parallel on the device.
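The batched-vmap idea can be sketched in plain JAX. The function `chunked_vmap` below is a hypothetical illustration, not BrainPy's implementation. It assumes that `arguments` is a tuple of arrays sharing a leading task axis, and it reads `num_parallel` as the number of tasks vectorized in each `jax.vmap` call. The per-batch results, which may be arrays or pytrees, are concatenated back along the task axis.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(func, arguments, num_parallel):
    """Hypothetical sketch: apply `func` to every task (row) in `arguments`,
    vectorizing at most `num_parallel` tasks per jax.vmap call.

    Assumes `arguments` is a non-empty tuple of arrays that share the same
    leading (task) dimension, with at least one task.
    """
    num_tasks = arguments[0].shape[0]
    # Every argument must describe the same number of tasks.
    if any(a.shape[0] != num_tasks for a in arguments):
        raise ValueError("all arguments must share the leading task dimension")

    batched = jax.vmap(func)  # maps over axis 0 of every positional argument
    outputs = []
    for start in range(0, num_tasks, num_parallel):
        # Slice out one batch of at most `num_parallel` tasks.
        chunk = tuple(a[start:start + num_parallel] for a in arguments)
        outputs.append(batched(*chunk))

    # Concatenate each output leaf across batches along the task axis.
    return jax.tree_util.tree_map(
        lambda *parts: jnp.concatenate(parts, axis=0), *outputs
    )


# Usage: 10 tasks, vectorized 4 at a time (batches of 4, 4 and 2).
def simulate(a, b):
    return a * b + 1.0


xs = jnp.arange(10.0)
ys = jnp.full((10,), 2.0)
res = chunked_vmap(simulate, (xs, ys), num_parallel=4)
print(res)
```

A smaller `num_parallel` bounds how much work each vectorized call places on the device at once, at the cost of more sequential calls. Note that the final batch may be shorter than the rest, so it has a different shape from the earlier batches.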

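Parameters:

func (callable) – The function to be mapped.

arguments (sequence or dict) – The function arguments, used to define the tasks.

num_parallel (int) – The batch size, i.e. the number of tasks processed in each vectorized call.

clear_buffer (bool) – Whether to clear the buffer memory after each batch is run.
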
Returns:

results – The running results.

Return type:

Any