jax_parallelize_map
- class brainpy.running.jax_parallelize_map(func, arguments, num_parallel, clear_buffer=False)
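Perform a parallelized map of a function by using jax.pmap.
This function can be used on multi-CPU or GPU backends. If you are running it on a single CPU, first set the host device count with brainpy.math.set_host_device_count(n). This tells XLA to expose n host (CPU) devices so that jax.pmap can run in parallel, and it only takes effect at the beginning of the program.
- Parameters: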
  - func (callable) – The function to be mapped.
  - arguments (Union[Dict[str, TypeVar(ArrayType, Array, Variable, TrainVar, Array, ndarray)], Sequence[TypeVar(ArrayType, Array, Variable, TrainVar, Array, ndarray)]]) – The function arguments, used to define tasks.
  - num_parallel (int) – The batch size, i.e. the number of tasks mapped in parallel at a time.
  - clear_buffer (bool) – If True, clear the buffer memory after running each batch of data.
- Returns:
results – The running results.
- Return type:
Any
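To make the roles of arguments and num_parallel more concrete, the following is a minimal NumPy sketch of the chunk-and-concatenate pattern that a pmap-based map of this kind typically follows. It is not BrainPy's implementation. The name parallelize_map_sketch is hypothetical, and the sketch makes several assumptions: each argument holds one task per entry along its leading axis, func returns a single array per task, dict arguments are passed as keyword arguments and sequences as positional ones, and per-chunk results are concatenated along the task axis. A sequential loop stands in for jax.pmap so that the sketch runs on a single device.

```python
import numpy as np


def parallelize_map_sketch(func, arguments, num_parallel):
    """Hypothetical sketch of a chunked parallel map (not BrainPy's code).

    Assumptions: every entry of `arguments` has a leading axis indexing tasks,
    and `func` returns one array per task. A sequential loop stands in for
    jax.pmap, so this runs on a single device.
    """
    if num_parallel < 1:
        raise ValueError("num_parallel must be a positive integer")
    is_dict = isinstance(arguments, dict)
    names = list(arguments.keys()) if is_dict else None
    leaves = [np.asarray(a) for a in (arguments.values() if is_dict else arguments)]

    # All arguments must describe the same number of tasks.
    lengths = {leaf.shape[0] for leaf in leaves}
    if len(lengths) != 1:
        raise ValueError(f"arguments have different leading lengths: {lengths}")
    num_tasks = lengths.pop()

    chunk_results = []
    for start in range(0, num_tasks, num_parallel):
        # Slice the next chunk of at most `num_parallel` tasks from every argument.
        chunk = [leaf[start:start + num_parallel] for leaf in leaves]
        # Stand-in for one jax.pmap call over this chunk: call `func` once per task.
        per_task = []
        for k in range(chunk[0].shape[0]):
            task_args = [c[k] for c in chunk]
            if is_dict:
                per_task.append(func(**dict(zip(names, task_args))))
            else:
                per_task.append(func(*task_args))
        chunk_results.append(np.stack(per_task))

    # Join the per-chunk results back together along the task axis.
    return np.concatenate(chunk_results, axis=0)


# Usage: five tasks processed in chunks of two (chunk sizes 2, 2, 1).
def simulate(gain, bias):
    return gain * np.arange(3.0) + bias


gains = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
biases = np.zeros(5)
res = parallelize_map_sketch(simulate, {"gain": gains, "bias": biases}, num_parallel=2)
print(res.shape)  # (5, 3)
```

Under these assumptions, each chunk of num_parallel tasks would correspond to a single jax.pmap call in real use. Because jax.pmap cannot map an axis larger than the number of available devices, num_parallel should not exceed the device count, which is why the host device count matters on a single CPU.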