# Source code for brainpy._src.running.jax_multiprocessing

# -*- coding: utf-8 -*-

from typing import Sequence, Dict, Union

import numpy as np
from jax import vmap, pmap
from jax.tree_util import tree_unflatten, tree_flatten

import brainpy.math as bm
from brainpy.types import ArrayType

# Public API of this module: batched ``jax.vmap`` and ``jax.pmap`` runners.
__all__ = ['jax_vectorize_map', 'jax_parallelize_map']


def jax_vectorize_map(
    func: callable,
    arguments: Union[Dict[str, ArrayType], Sequence[ArrayType]],
    num_parallel: int,
    clear_buffer: bool = False
):
  """Perform a vectorized map of a function by using ``jax.vmap``.

  This function can be used in CPU or GPU backends. But it is highly
  suitable to be used in GPU backends. This is because ``jax.vmap`` can
  parallelize the mapped axis on GPU devices.

  The leading axis of every array in ``arguments`` is split into consecutive
  chunks of at most ``num_parallel`` items. Each chunk is passed to
  ``vmap(func)``, and the per-chunk outputs are concatenated along axis 0.

  Parameters
  ----------
  func: callable, function
    The function to be mapped. When ``arguments`` is a dict it is called with
    keyword arguments, otherwise with positional arguments.
  arguments: sequence, dict
    The function arguments, used to define tasks. Every array leaf must have
    the same length (the number of tasks).
  num_parallel: int
    The batch size, i.e. the maximum number of tasks evaluated by a single
    ``vmap`` call. Must be at least 1.
  clear_buffer: bool
    Clear the buffer memory after running each batch data. When True, inputs
    and outputs are converted to NumPy arrays.

  Returns
  -------
  results: Any
    The running results: same tree structure as the output of ``func``, with
    every leaf concatenated along axis 0. ``None`` when the arguments have
    length zero (``func`` is never called).

  Raises
  ------
  TypeError
    If ``arguments`` is not a dict, tuple or list.
  ValueError
    If ``num_parallel`` is smaller than 1, if ``arguments`` contains no array,
    or if the arrays do not all have the same length.
  """
  if not isinstance(arguments, (dict, tuple, list)):
    raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
  # Validate up front: a step of 0 makes ``range`` fail with an unrelated
  # message, and a negative step would silently skip every task.
  if num_parallel < 1:
    raise ValueError(f'"num_parallel" must be a positive integer, but we got {num_parallel}')

  elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.Array))
  if not elements:
    raise ValueError('"arguments" must contain at least one array to map over.')
  if clear_buffer:
    # Presumably converted to NumPy so the inputs are not held as JAX buffers
    # while ``bm.clear_buffer_memory()`` runs between batches.
    elements = [np.asarray(ele) for ele in elements]
  num_pars = [len(ele) for ele in elements]
  if len(set(num_pars)) != 1:
    raise ValueError(f'All elements in parameters should have the same length. '
                     f'But we got {tree_unflatten(tree, num_pars)}')

  res_tree = None
  results = None
  vmap_func = vmap(func)
  for i in range(0, num_pars[0], num_parallel):
    # When clearing buffers, a fresh ``vmap`` wrapper is built for each batch
    # instead of reusing ``vmap_func``.
    run_f = vmap(func) if clear_buffer else vmap_func
    batch = tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])
    r = run_f(**batch) if isinstance(arguments, dict) else run_f(*batch)
    res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array))
    if results is None:
      results = tuple([] for _ in res_values)
    for j, val in enumerate(res_values):
      results[j].append(np.asarray(val) if clear_buffer else val)
    if clear_buffer:
      bm.clear_buffer_memory()

  if res_tree is None:
    # Zero tasks: the loop never ran, so there is no output structure.
    return None
  concat = np.concatenate if clear_buffer else bm.concatenate
  return tree_unflatten(res_tree, [concat(res, axis=0) for res in results])
def jax_parallelize_map(
    func: callable,
    arguments: Union[Dict[str, ArrayType], Sequence[ArrayType]],
    num_parallel: int,
    clear_buffer: bool = False
):
  """Perform a parallelized map of a function by using ``jax.pmap``.

  This function can be used in multi-CPU or GPU backends. If you are using it
  on a single CPU, please set the host device count by
  ``brainpy.math.set_host_device_count(n)`` beforehand.

  The leading axis of every array in ``arguments`` is split into consecutive
  chunks of at most ``num_parallel`` items. Each chunk is passed to
  ``pmap(func)``, and the per-chunk outputs are concatenated along axis 0.

  Parameters
  ----------
  func: callable, function
    The function to be mapped. When ``arguments`` is a dict it is called with
    keyword arguments, otherwise with positional arguments.
  arguments: sequence, dict
    The function arguments, used to define tasks. Every array leaf must have
    the same length (the number of tasks).
  num_parallel: int
    The batch size, i.e. the maximum number of tasks evaluated by a single
    ``pmap`` call. Must be at least 1. NOTE: ``jax.pmap`` maps each item of a
    batch to a device, so this presumably should not exceed the number of
    available devices (see the JAX ``pmap`` documentation).
  clear_buffer: bool
    Clear the buffer memory after running each batch data. When True, inputs
    and outputs are converted to NumPy arrays.

  Returns
  -------
  results: Any
    The running results: same tree structure as the output of ``func``, with
    every leaf concatenated along axis 0. ``None`` when the arguments have
    length zero (``func`` is never called).

  Raises
  ------
  TypeError
    If ``arguments`` is not a dict, tuple or list.
  ValueError
    If ``num_parallel`` is smaller than 1, if ``arguments`` contains no array,
    or if the arrays do not all have the same length.
  """
  if not isinstance(arguments, (dict, tuple, list)):
    raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
  # Validate up front: a step of 0 makes ``range`` fail with an unrelated
  # message, and a negative step would silently skip every task.
  if num_parallel < 1:
    raise ValueError(f'"num_parallel" must be a positive integer, but we got {num_parallel}')

  elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.Array))
  if not elements:
    raise ValueError('"arguments" must contain at least one array to map over.')
  if clear_buffer:
    # Presumably converted to NumPy so the inputs are not held as JAX buffers
    # while ``bm.clear_buffer_memory()`` runs between batches.
    elements = [np.asarray(ele) for ele in elements]
  num_pars = [len(ele) for ele in elements]
  if len(set(num_pars)) != 1:
    raise ValueError(f'All elements in parameters should have the same length. '
                     f'But we got {tree_unflatten(tree, num_pars)}')

  res_tree = None
  results = None
  pmap_func = pmap(func)
  for i in range(0, num_pars[0], num_parallel):
    # When clearing buffers, a fresh ``pmap`` wrapper is built for each batch
    # instead of reusing ``pmap_func``.
    run_f = pmap(func) if clear_buffer else pmap_func
    batch = tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])
    r = run_f(**batch) if isinstance(arguments, dict) else run_f(*batch)
    res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array))
    if results is None:
      results = tuple([] for _ in res_values)
    for j, val in enumerate(res_values):
      results[j].append(np.asarray(val) if clear_buffer else val)
    if clear_buffer:
      bm.clear_buffer_memory()

  if res_tree is None:
    # Zero tasks: the loop never ran, so there is no output structure.
    return None
  concat = np.concatenate if clear_buffer else bm.concatenate
  return tree_unflatten(res_tree, [concat(res, axis=0) for res in results])