jax_parallelize_map
- brainpy.running.jax_parallelize_map(func, arguments, num_parallel, clear_buffer=False)
Perform a parallelized map of a function by using jax.pmap. This function can be used on multi-CPU or multi-GPU backends. If you are running on a single CPU, first set the host device count with brainpy.math.set_host_device_count(n). This setting only takes effect at the start of the program, before JAX initializes its devices.
- Parameters:
- Returns:
results – The running results.
- Return type:
Any
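The sketch below illustrates, in plain JAX, the chunked parallel map described above. It assumes that jax_parallelize_map splits its inputs along the leading axis into groups of num_parallel items and runs each group through jax.pmap. chunked_pmap and simulate are hypothetical names, not BrainPy API, and the clear_buffer option is not modelled. The XLA flag plays the role of brainpy.math.set_host_device_count on CPU and must be set before JAX is imported.

```python
import os

# Plain-JAX stand-in for brainpy.math.set_host_device_count(4): ask XLA to expose
# 4 virtual CPU devices. This only has an effect if set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp
import numpy as np


def chunked_pmap(func, xs, num_parallel):
    """Hypothetical helper (not BrainPy's implementation): map `func` over the
    leading axis of `xs`, running `num_parallel` items at a time with jax.pmap."""
    n_dev = jax.local_device_count()
    if not 1 <= num_parallel <= n_dev:
        raise ValueError(f"num_parallel must be in [1, {n_dev}], got {num_parallel}")
    run = jax.pmap(func)  # one item per device along axis 0
    outs = []
    for start in range(0, xs.shape[0], num_parallel):
        chunk = xs[start:start + num_parallel]  # the last chunk may be smaller
        outs.append(np.asarray(run(chunk)))
    return np.concatenate(outs, axis=0)


def simulate(x):
    # Toy per-item workload: sum of squares of a small vector
    return jnp.sum(x ** 2)


xs = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)
results = chunked_pmap(simulate, xs, num_parallel=jax.local_device_count())
print(results.shape)  # (10,)
```

Using jax.local_device_count() as num_parallel keeps the example valid on any backend. On a GPU machine the XLA flag is ignored and the GPU count is used instead.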