jax_parallelize_map

class brainpy.running.jax_parallelize_map(func, arguments, num_parallel, clear_buffer=False)

Perform a parallelized map of a function by using jax.pmap.

This function can be used on multi-CPU or multi-GPU backends. If you are running on a single CPU, set the host device count with brainpy.math.set_host_device_count(n) before calling it.

Parameters:
Returns:

results – The running results.

Return type:

Any
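Example:

The following is a minimal sketch of the underlying idea (mapping a function over batches of inputs with jax.pmap), not BrainPy's actual implementation. The helper pmap_in_batches and the function scaled_sum are hypothetical names introduced only for illustration. The XLA_FLAGS setting is assumed here as a plain-JAX stand-in for calling brainpy.math.set_host_device_count(n) on a single CPU, and it must be set before JAX is imported.

```python
import os

# Assumption: expose 4 virtual CPU devices. This must happen before importing JAX.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp


def pmap_in_batches(func, arguments, num_parallel):
    """Hypothetical helper: apply `func` along the leading axis of each array
    in `arguments`, running `num_parallel` items at a time with jax.pmap."""
    mapped = jax.pmap(func)
    total = len(arguments[0])
    outputs = []
    for start in range(0, total, num_parallel):
        # Slice the same batch out of every argument.
        batch = [jnp.asarray(a[start:start + num_parallel]) for a in arguments]
        outputs.append(mapped(*batch))
    # Stitch the per-batch outputs back into one array.
    return jnp.concatenate(outputs)


def scaled_sum(a, b):
    # Hypothetical per-item computation.
    return a * 2.0 + b


# A pmap batch cannot be larger than the number of available devices.
n_dev = jax.local_device_count()
xs = jnp.arange(8.0)
ys = jnp.ones(8)
results = pmap_in_batches(scaled_sum, (xs, ys), num_parallel=min(2, n_dev))
```

Each batch is capped at the number of local devices because jax.pmap places one item of the mapped axis on each device. This is why a single-CPU setup needs the host device count raised first.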