Random Number Generation for JIT Compilation

Chaoming Wang

Although brainpy.math.random is designed to be seamlessly compatible with numpy.random, there are still some differences in the context of JIT compilation.

In this section, we discuss how to write JIT-compatible code with brainpy.math.random.

import brainpy as bp
import brainpy.math as bm
import numpy as np

# bm.set_platform('cpu')
bp.__version__
'2.3.0'

Using bm.random outside JIT-compiled functions

Using brainpy.math.random outside of JIT-compiled functions works the same way as using numpy.random.

This covers cases where you generate random data for further processing. For example,

bm.random.rand(10)
Array([0.7222161 , 0.2043277 , 0.59838593, 0.255252  , 0.14954388,
       0.05150986, 0.214692  , 0.03857851, 0.81150043, 0.4669956 ],      dtype=float32)
np.random.rand(10)
array([0.60626677, 0.69464463, 0.81361761, 0.1583908 , 0.50378113,
       0.17677626, 0.7507633 , 0.75699064, 0.33320096, 0.38958635])

When you call functions in brainpy.math.random, you are actually calling the methods of a default RandomState. Similarly, numpy.random has its own default RandomState: calling a random function in the numpy.random module corresponds to calling the same function on this default NumPy RandomState.

bm.random.DEFAULT
RandomState(key=([3014124236, 2009892527], dtype=uint32))
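The NumPy side of this statement can be checked directly: seeding the module-level API and creating an explicit RandomState with the same seed yield the same numbers. A minimal sketch using only NumPy's legacy RandomState API:

```python
import numpy as np

# Seeding the module-level API reseeds NumPy's hidden global RandomState.
np.random.seed(42)
from_module = np.random.rand(3)

# An explicit RandomState with the same seed produces the same stream,
# because np.random.rand(...) forwards to that hidden default instance.
explicit = np.random.RandomState(42)
from_instance = explicit.rand(3)

print(np.allclose(from_module, from_instance))  # True
```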

Using bm.random inside a JIT-compiled function

If you use random sampling inside a JIT-compiled function, there are a few things you need to pay attention to; otherwise, you are likely to get errors or silently incorrect results.

As stated above, brainpy.math.random functions use the default RandomState. A RandomState is an instance of a BrainPy Variable, which means its value changes after each call to any of its built-in random functions. What changes is the key of the RandomState. For instance,

bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([ 873106783, 4065854088], dtype=uint32))
bm.random.rand(1)
print('Now, the DEFAULT is', bm.random.DEFAULT)
Now, the DEFAULT is RandomState(key=([3526960574,  230845945], dtype=uint32))
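To make the idea of a changing key concrete, here is a minimal plain-JAX sketch of a stateful wrapper that splits its key on every draw, which is the standard JAX idiom. TinyRandomState is a hypothetical class for illustration only and is not BrainPy's actual implementation:

```python
import jax
import jax.numpy as jnp

class TinyRandomState:
    """Hypothetical stateful wrapper around a JAX PRNG key (illustration only)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def rand(self, n):
        # Split the current key: keep one half as the new state,
        # consume the other half to draw the numbers.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.uniform(subkey, (n,))

rs = TinyRandomState(0)
k0 = rs.key
a = rs.rand(1)
k1 = rs.key
b = rs.rand(1)
print(k0, k1, rs.key)  # the key changes after every draw
print(a, b)            # and so do the samples
```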

Therefore, if you do not declare that a JIT-compiled function uses this DEFAULT RandomState, repeatedly calling that function will not give you fresh random numbers: the key is captured as a constant when the function is compiled, and its update is not carried over between calls. For instance,

@bm.jit
def get_data():
    return bm.random.random(2)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)
get_data()
Array([0.80141556, 0.19009137], dtype=float32)
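The same behavior can be reproduced in plain JAX. In this sketch, a hypothetical module-level key GLOBAL_KEY stands in for the undeclared DEFAULT: because it is not an argument of the jitted function, its value is baked into the compiled code at trace time, so every call reuses it:

```python
import jax

# Hypothetical global key playing the role of the undeclared DEFAULT RandomState.
GLOBAL_KEY = jax.random.PRNGKey(0)

@jax.jit
def get_data():
    # GLOBAL_KEY is not an argument, so jax.jit treats its value as a
    # constant captured when the function is traced on the first call.
    return jax.random.uniform(GLOBAL_KEY, (2,))

first = get_data()
second = get_data()
print(first, second)  # identical: the captured key never changes
```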

One correct way is to explicitly declare that the JIT transformation uses this DEFAULT variable, by passing it to dyn_vars.

bm.random.seed()
from functools import partial

@partial(bm.jit, dyn_vars=(bm.random.DEFAULT, ))
def get_data_v2():
    return bm.random.random(2)
get_data_v2()
Array([0.38541543, 0.5843446 ], dtype=float32)
get_data_v2()
Array([0.85543776, 0.36957836], dtype=float32)
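Conceptually, declaring dyn_vars lets the compiled function read the current key and write back the updated one. A rough plain-JAX equivalent, assuming we thread the key through the function by hand instead of relying on BrainPy's Variable machinery (get_data_pure is a hypothetical name):

```python
import jax

@jax.jit
def get_data_pure(key):
    # The key is now an input, so each call can receive a fresh one.
    key, subkey = jax.random.split(key)
    # Return the samples together with the advanced key.
    return jax.random.uniform(subkey, (2,)), key

key = jax.random.PRNGKey(0)
first, key = get_data_pure(key)
second, key = get_data_pure(key)
print(first, second)  # different values on each call
```

The caller is responsible for keeping the returned key; BrainPy's dyn_vars mechanism takes over exactly this bookkeeping for a Variable such as DEFAULT.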

Alternatively, turn the function into a BrainPyObject with bm.to_object(), declaring DEFAULT as its dynamical variable, and then apply bm.jit() to it.

@bm.jit
@bm.to_object(dyn_vars=bm.random.DEFAULT)
def get_data_v3():
    return bm.random.random(2)
get_data_v3()
Array([0.31096482, 0.7970413 ], dtype=float32)
get_data_v3()
Array([0.26830554, 0.15947664], dtype=float32)
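bm.to_object() bundles the function together with the Variables it uses. A hypothetical plain-JAX sketch of that pattern, in which an object owns the key and calls a pure jitted kernel that returns the next key (this illustrates the idea only; it is not how bm.to_object is implemented):

```python
import jax

@jax.jit
def _draw(key):
    # Pure kernel: consumes a key, returns samples plus the next key.
    key, subkey = jax.random.split(key)
    return jax.random.uniform(subkey, (2,)), key

class KeyOwner:
    """Hypothetical object that keeps its key as state outside the jitted kernel."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        samples, self.key = _draw(self.key)  # write the new key back
        return samples

f = KeyOwner(0)
a = f()
b = f()
print(a, b)  # different values on each call
```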

Using a RandomState inside objects to JIT

Another approach, which I recommend, is to give each object you want to JIT its own RandomState instance. For example, you can initialize a RandomState in the __init__() method and then use it anywhere in the object.

class MyOb(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.rng = bm.random.RandomState(123)

    def __call__(self):
        size = (50, 100)
        u = self.rng.random(size)
        v = self.rng.uniform(size=size)
        z = bm.sqrt(-2 * bm.log(u)) * bm.cos(2 * bm.pi * v)
        return z
ob = bm.jit(MyOb())
ob()
Array([[ 1.3595979 , -1.3462192 ,  0.7149456 , ...,  1.4283268 ,
        -1.1362855 , -0.18378317],
       [-0.26401126, -1.6798397 , -0.8422355 , ...,  1.0795223 ,
         0.41247413, -0.955116  ],
       [ 0.6234829 , -0.44811824, -0.03835859, ..., -2.5203867 ,
        -0.02713326,  1.6490041 ],
       ...,
       [-0.9861029 ,  0.36676335, -0.31499916, ...,  1.526808  ,
        -0.7946268 , -0.86713606],
       [-1.7008592 , -0.05957834, -0.5677447 , ..., -0.04765594,
         0.574145  , -0.11830498],
       [-0.22663854, -1.8517947 , -1.3546717 , ...,  1.2332705 ,
        -0.79247886, -1.9352005 ]], dtype=float32)
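The computation in __call__ is the Box-Muller transform, which maps two independent uniform samples to one standard normal sample. A quick NumPy check of that claim, assuming a large sample size so that the empirical mean and standard deviation land close to 0 and 1:

```python
import numpy as np

rng = np.random.RandomState(123)
size = 100_000

# 1 - U(0, 1) lies in (0, 1], so the logarithm never receives 0.
u = 1.0 - rng.random_sample(size)
v = rng.uniform(size=size)

# Box-Muller: sqrt(-2 ln u) * cos(2 pi v) is standard normal.
z = np.sqrt(-2 * np.log(u)) * np.cos(2 * np.pi * v)
print(z.mean(), z.std())  # approximately 0 and 1
```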

Note that any Variable instance directly accessible through self. is found automatically by BrainPy's JIT transformation functions. Therefore, in this case, we do not need to pass rng to dyn_vars in bm.jit().
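Besides being discovered automatically, a per-object RandomState keeps the object's randomness separate from the global default stream and reproducible from its seed. The same idea in plain NumPy, with a hypothetical Sampler class standing in for MyOb (no JIT involved):

```python
import numpy as np

class Sampler:
    """Hypothetical NumPy analogue of MyOb: each instance owns its own stream."""

    def __init__(self, seed=123):
        self.rng = np.random.RandomState(seed)

    def __call__(self):
        return self.rng.random_sample(3)

a, b = Sampler(), Sampler()
np.random.seed(0)        # reseeding the global state...
np.random.rand(10)       # ...and drawing from it...
out_a, out_b = a(), b()  # ...leaves the per-object streams untouched
print(np.array_equal(out_a, out_b))  # True: same seed, same stream
```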