Source code for brainpy._src.running.native_multiprocessing
# -*- coding: utf-8 -*-
from typing import Union, Sequence, Dict
import multiprocessing
__all__ = [
'process_pool',
'process_pool_lock',
]
[docs]def process_pool(func: callable,
all_params: Union[Sequence, Dict],
num_process: int):
"""Run multiple models in multi-processes.
.. Note::
This multiprocessing function should be called within a `if __main__ == '__main__':` syntax.
Parameters
----------
func : callable
The function to run model.
all_params : list, tuple, dict
The parameters of the function arguments.
The parameters for each process can be a tuple, or a dictionary.
num_process : int
The number of the processes.
Returns
-------
results : list
Process results.
"""
print('{} jobs total.'.format(len(all_params)))
pool = multiprocessing.Pool(processes=num_process)
results = []
for params in all_params:
if isinstance(params, (list, tuple)):
results.append(pool.apply_async(func, args=tuple(params)))
elif isinstance(params, dict):
results.append(pool.apply_async(func, kwds=params))
else:
raise ValueError('Unknown parameter type: ', type(params))
pool.close()
pool.join()
return [r.get() for r in results]
[docs]def process_pool_lock(func: callable,
all_params: Union[Sequence, Dict],
num_process: int):
"""Run multiple models in multi-processes with lock.
Sometimes, you want to synchronize the processes. For example,
if you want to write something in a document, you cannot let
multiprocess simultaneously open this same file. So, you need
add a `lock` argument in your defined `func`:
.. code-block:: python
def some_func(..., lock, ...):
... do something ..
lock.acquire()
... something cannot simultaneously do by multi-process ..
lock.release()
In such case, you can use `process_pool_lock()` to run your model.
.. Note::
This multiprocessing function should be called within a `if __main__ == '__main__':` syntax.
Parameters
----------
func: callable
The function to run model.
all_params : list, tuple, dict
The parameters of the function arguments.
num_process : int
The number of the processes.
Returns
-------
results : list
Process results.
"""
print('{} jobs total.'.format(len(all_params)))
pool = multiprocessing.Pool(processes=num_process)
m = multiprocessing.Manager()
lock = m.Lock()
results = []
for net_params in all_params:
if isinstance(net_params, (list, tuple)):
results.append(pool.apply_async(func, args=tuple(net_params) + (lock,)))
elif isinstance(net_params, dict):
net_params.update(lock=lock)
results.append(pool.apply_async(func, kwds=net_params))
else:
raise ValueError('Unknown parameter type: ', type(net_params))
pool.close()
pool.join()
return [r.get() for r in results]