jax_parallelize_map

brainpy.running.jax_parallelize_map(func, arguments, num_parallel, clear_buffer=False)

Perform a parallelized map of a function by using jax.pmap.

This function can be used on multi-CPU or multi-GPU backends. To use it on a single CPU, first set the host device count with brainpy.math.set_host_device_count(n).
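To show what setting the host device count achieves, the sketch below makes JAX expose four virtual CPU devices through the XLA flag `--xla_force_host_platform_device_count`. Using the raw flag instead of `brainpy.math.set_host_device_count` is an assumption made to keep the example independent of BrainPy, and the count of 4 is arbitrary. The flag only takes effect if it is set before JAX is first imported.

```python
import os

# Assumption: request 4 virtual CPU devices via the raw XLA flag.
# This must run before `import jax`, because JAX reads XLA_FLAGS
# when it initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax

# The CPU backend now reports 4 devices instead of 1, so jax.pmap
# can spread work across them.
print(len(jax.devices("cpu")))
```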

Parameters:
  • func (callable, function) – The function to be mapped.

  • arguments (sequence, dict) – The function arguments, used to define tasks.

  • num_parallel (int) – The batch size, i.e. the number of tasks run in parallel at a time.

  • clear_buffer (bool) – Whether to clear the buffer memory after running each batch of data.
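The batching strategy these parameters describe can be sketched in plain JAX. The helper `batched_pmap` below is hypothetical and is not BrainPy's implementation. It splits the tasks (axis 0 of every argument) into chunks of `num_parallel`, runs each chunk through `jax.pmap`, and concatenates the outputs. It assumes every argument is an array of the same length, that `func` returns a single array, and that `num_parallel` does not exceed the device count; the `clear_buffer` option is omitted.

```python
import os

# Assumption for the demo: expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp
import numpy as np


def batched_pmap(func, arguments, num_parallel):
    """Hypothetical sketch: map `func` over tasks, `num_parallel` tasks at a time.

    `arguments` is a sequence of arrays (passed positionally) or a dict of
    arrays (passed as keywords); axis 0 of every array indexes the tasks.
    """
    if isinstance(arguments, dict):
        keys = list(arguments)
        elements = [np.asarray(arguments[k]) for k in keys]
        # pmap maps over positional arguments, so rebuild the keyword call inside.
        run_f = jax.pmap(lambda *args: func(**dict(zip(keys, args))))
    else:
        elements = [np.asarray(a) for a in arguments]
        run_f = jax.pmap(func)

    num_tasks = len(elements[0])
    outputs = []
    for start in range(0, num_tasks, num_parallel):
        # One task per device; the final batch may hold fewer tasks.
        batch = [jnp.asarray(e[start:start + num_parallel]) for e in elements]
        outputs.append(np.asarray(run_f(*batch)))
    # Stitch the per-batch results back together along the task axis.
    return np.concatenate(outputs, axis=0)


def simulate(a, b):
    # Hypothetical per-task function used only for the demo.
    return a * b + 1.0


a = np.arange(10, dtype=np.float32)
b = np.full(10, 2.0, dtype=np.float32)
n_dev = jax.local_device_count()  # never ask pmap for more devices than exist
print(batched_pmap(simulate, [a, b], num_parallel=n_dev))
print(batched_pmap(simulate, {"a": a, "b": b}, num_parallel=n_dev))
```

Because `jax.pmap` places one task per device, `num_parallel` is capped by the device count. A smaller final batch (e.g. 2 tasks when 10 tasks run 4 at a time) simply occupies fewer devices and costs one extra compilation.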

Returns:

results – The running results.

Return type:

Any