Parallel Simulation for Parameter Exploration#

@Tianqiu Zhang @Chaoming Wang

Parameter exploration and selection is an essential part of brain dynamics modeling. In general, parameter exploration raises two problems:

  1. How to run multiple models concurrently?

  2. How to manage device memory so that multiple models can run concurrently?

First, most BrainPy models support multiple kinds of parallelization, including multi-threading and multi-processing on a single machine, as well as parallelization across multiple devices. Below, we will illustrate these parallelization APIs one by one.

Second, every call of a BrainPy model consumes a fraction of device memory. Therefore, BrainPy provides an API, brainpy.math.clear_buffer_memory(), for memory cleaning.

In the following, we will illustrate how to combine these tools to achieve efficient parameter exploration for your models.

import brainpy as bp
import brainpy.math as bm
import numpy as np

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

Parallelization across different CPU processors#

Parallelization across multiple CPU processors can be achieved with a single function call, brainpy.running.cpu_ordered_parallel(). The following pseudocode demonstrates the usage of this API.

import brainpy as bp

# define your function
def run_model(par):
  model = YourModel(par)
  runner = bp.DSRunner(model)
  runner.run(duration)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# run models in a Jupyter notebook
results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)

# run models in a Python script: protect the entry point,
# since multi-processing re-imports the main module
if __name__ == '__main__':
  results = bp.running.cpu_ordered_parallel(run_model, all_params, num_process=10)

We will use a simple HH neuron model as an example to show this parallelization method. In this example, we use the multi-processing technique to test ten different input current values.

First, define your running function with well-defined input and output data.

def hh_spike_num(bg_current):  # "input" is the bg_current
  import brainpy as bp  # packages need to be re-imported inside the function
                        # when it is executed in Jupyter subprocesses
  model = bp.neurons.HH(1)
  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
  runner.run(1000.)
  return runner.mon['spike'].sum()  # "output" is the spike number

Then, define all your parameter spaces.

current = bm.linspace(1, 10.1, 10)  # here only one parameter

Finally, run your model concurrently with the parallelization syntax.

r = bp.running.cpu_ordered_parallel(hh_spike_num, [current], num_process=10)

r
[0, 0, 1, 48, 53, 0, 54, 63, 66, 68]

However, the above usage will accumulate buffer memory on the running device. If a single model occupies too much memory, an out-of-memory error will be raised during the parameter exploration.

A simple way to solve this issue is to clear all buffers after each run of the function. For example, call brainpy.math.clear_buffer_memory() before returning your results.

def hh_spike_num2(bg_current):  # "input" is the bg_current
  import brainpy as bp  # packages need to be re-imported inside the function
                        # when it is executed in Jupyter subprocesses

  bg_current = bp.math.as_jax(bg_current)
  model = bp.neurons.HH(1)
  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])
  runner.run(1000.)

  bp.math.clear_buffer_memory()
  return runner.mon['spike'].sum()  # "output" is the spike number

Note that clear_buffer_memory() clears all JAX arrays on the device. Therefore, it is better to pass inputs as NumPy arrays and return outputs as NumPy arrays.

current = np.linspace(1., 10., 10)

r = bp.running.cpu_ordered_parallel(hh_spike_num2, [current], num_process=10)
r
[0, 0, 1, 0, 0, 57, 60, 58, 65, 68]

If the order of the running results does not matter, you can also use the cpu_unordered_parallel() function. This maximizes the running efficiency of all processors, since all workers run in a non-blocking, unordered manner, as sketched below.
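As a minimal sketch, and assuming that cpu_unordered_parallel() shares the calling convention of cpu_ordered_parallel(), the previous exploration can be rewritten as follows. The returned spike counts may arrive in any order relative to the input currents.

current = np.linspace(1., 10., 10)

# results are collected as soon as workers finish, so their order
# may differ from the order of the input currents
r = bp.running.cpu_unordered_parallel(hh_spike_num2, [current], num_process=10)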

Parallelization with jax.vmap#

The second approach to multi-threading parallelization is JAX's vectorizing map, jax.vmap. jax.vmap vectorizes a function by mapping it over an array axis at the level of primitive operations. It avoids recompiling the model for every parameter in a batch, and automatically parallelizes model runs on the given machine. The following pseudocode demonstrates how simple this parallelization approach is.

from jax import vmap

def run_model(par):
  model = YourModel(par)
  runner = bp.DSRunner(model)
  runner.run(duration)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# batch simulation through jax.vmap
r = vmap(run_model)(*all_params)

Note that if you have too many parameters to search, jax.vmap will consume too much memory. In this case, you can use our wrapped API brainpy.running.jax_vectorize_map(), which controls the running batch size through the num_parallel parameter. You can set a smaller num_parallel value when your device memory is limited (on either the CPU or GPU).

def hh_spike_num3(bg_current):  # "input" is the bg_current
  model = bp.neurons.HH(1)
  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current],
                       numpy_mon_after_run=False)
  runner.run(1000.)
  return runner.mon['spike'].sum()  # "output" is the spike number

current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3)
r
Array([ 0,  0,  0,  0,  0, 45, 60, 63, 66, 68], dtype=int32)

The function passed into jax_vectorize_map() cannot call clear_buffer_memory(); otherwise, errors will be raised. Instead, users can set clear_buffer=True/False in jax_vectorize_map(). With this usage, all inputs and outputs are automatically transformed into NumPy arrays.

current = bm.linspace(1., 10.1, 10)
r = bp.running.jax_vectorize_map(hh_spike_num3, [current], num_parallel=3, clear_buffer=True)
r
array([ 0,  1,  1,  0,  0, 57, 60, 63, 66, 68])

Parallelization across multiple devices#

BrainPy supports parallelization on multiple devices (e.g., multiple GPUs or TPU cores) or HPC systems (e.g., supercomputers). Different from the above thread-based and process-based parallelization methods, in which the same model runs in parallel on a single device, device-based parallelization runs the same model in parallel on multiple devices.

One way to express multi-device parallelization of BrainPy models is through the jax.pmap instruction. JAX provides jax.pmap to express SIMD programs. It offers an interface to run the same model on multiple devices with different parameter values. Its usage is analogous to jax.vmap. The following pseudocode presents an example of running BrainPy models on multiple devices.

from jax import pmap

def run_model(par):
  model = YourModel(par)
  runner = bp.DSRunner(model)
  runner.run(duration)
  return runner.mon

# define all parameter values need to explore
all_params = [...]

# parallel simulation through jax.pmap
r = pmap(run_model)(*all_params)

jax.pmap has a similar issue to jax.vmap when you parallelize across many parameters. In this case, you can use the wrapped function brainpy.running.jax_parallelize_map().

If you are using pmap on a CPU device, you can set the number of virtual devices by calling brainpy.math.set_host_device_count(n). Then, you can call jax_parallelize_map() safely on your CPU platform.

bp.math.set_host_device_count(10)  # this should be placed at the top of the file

current = bm.linspace(1., 10.1, 20)
r = bp.running.jax_parallelize_map(hh_spike_num3, [current], num_parallel=10, clear_buffer=True)
r
array([ 0,  0,  0,  0,  0,  0,  0, 49, 52, 54, 56, 58, 59, 61, 62, 63, 65,
       66, 67, 68])

BrainPy also works well with job scheduling systems such as SLURM in a supercomputing center. Therefore, another way to express multi-device parallelization is to employ a classical resource management system. The following batch script demonstrates an example that can be submitted to SLURM.


#!/bin/bash
#SBATCH -J <job name>
#SBATCH -o <output file>
#SBATCH -p <partition>
#SBATCH -n <number of tasks>
#SBATCH -N <number of nodes>
#SBATCH -c <CPUs per task>

python your_script.py
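
As a hedged sketch of how the parameter exploration itself can be distributed over such jobs, the hypothetical your_script.py below assumes a SLURM array job (submitted with, e.g., sbatch --array=0-3 <batch script>). Each array task reads its SLURM_ARRAY_TASK_ID, simulates only its own slice of the parameter space with the hh_spike_num2 function defined earlier, and saves its partial results.

# your_script.py -- a minimal sketch; it assumes hh_spike_num2 (defined above)
# is also defined or imported in this file
import os
import numpy as np
import brainpy as bp

if __name__ == '__main__':
  # SLURM sets these environment variables for each task of an array job
  task_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))
  n_tasks = int(os.environ.get('SLURM_ARRAY_TASK_COUNT', 1))

  # split the full parameter space and keep only this task's chunk
  all_currents = np.linspace(1., 10., 40)
  chunk = np.array_split(all_currents, n_tasks)[task_id]

  # explore this chunk with multi-processing on the allocated node
  results = bp.running.cpu_ordered_parallel(hh_spike_num2, [chunk], num_process=10)
  np.save(f'results_task_{task_id}.npy', np.asarray(results))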