Concept 1: Object-oriented Transformation#

@Chaoming Wang

Most computation in BrainPy relies on JAX. JAX provides powerful transformations for Python programs, including differentiation, vectorization, parallelization, and just-in-time compilation. If you are not familiar with it, please see its documentation.

However, JAX only supports functional programming, i.e., transformations of Python functions. This is not what we want: brain dynamics modeling needs object-oriented programming.

To meet this requirement, BrainPy defines an interface for object-oriented (OO) transformations. These OO transformations can be easily applied to BrainPy objects. A minimal sketch contrasting the functional and object-oriented styles is given right after the imports below.

In this section, let’s talk about the BrainPy concept of object-oriented transformations.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')
bp.__version__
'2.3.0'
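
To see the difference concretely, here is a minimal sketch contrasting the pure-functional style required by JAX with the object-oriented style supported by BrainPy (the Counter class here is a toy example for illustration, not part of the training example below):

import jax

# Functional style required by JAX: state must be passed in and returned explicitly.
@jax.jit
def add_one(count):
    return count + 1

count = add_one(0)  # the state is threaded through the function by hand

# Object-oriented style supported by BrainPy: the state lives in a Variable
# owned by the object, and OO transformations track it automatically.
class Counter(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.count = bm.Variable(bm.zeros(1))

    def add_one(self):
        self.count += 1  # in-place update of the Variable

counter = Counter()
bm.jit(counter.add_one)()  # the Variable is updated inside the jitted call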

Illustrating example: Training a network#

To illustrate this concept, we need a demonstration example. Here, we choose the popular task of neural network training as the illustrating case.

In this training case, we want to teach the neural network to correctly classify random input arrays into two labels (True or False). That is, we have the training data:

num_in = 100
num_sample = 256
X = bm.random.rand(num_sample, num_in)
Y = (bm.random.rand(num_sample) < 0.5).astype(float)

We use a two-layer feedforward network:

class Linear(bp.BrainPyObject):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.num_in = n_in
        self.num_out = n_out
        init = bp.init.XavierNormal()
        self.W = bm.Variable(init((n_in, n_out)))
        self.b = bm.Variable(bm.zeros((1, n_out)))

    def __call__(self, x):
        return x @ self.W + self.b


net = bp.Sequential(Linear(num_in, 20),
                    bm.relu,
                    Linear(20, 2))
print(net)
Sequential(
  [0] Linear0
  [1] relu
  [2] Linear1
)

Here, we use a supervised learning training paradigm.

rng = bm.random.RandomState(123)


# Loss function
@bm.to_object(child_objs=net, dyn_vars=rng)
def loss():
    # shuffle the data
    key = rng.split_key()
    x_data = rng.permutation(X, key=key)
    y_data = rng.permutation(Y, key=key)
    # prediction
    predictions = net(dict(), x_data)
    # loss
    l = bp.losses.cross_entropy_loss(predictions, y_data)
    return l


# Gradient function
grad = bm.grad(loss, grad_vars=net.vars(), return_value=True)

# Optimizer
optimizer = bp.optim.SGD(lr=1e-2, train_vars=net.vars())


# Training step
@bm.to_object(child_objs=(grad, optimizer))
def train(i):
    grads, l = grad()
    optimizer.update(grads)
    return l


num_step = 400
for i in range(0, 4000, num_step):
    # train 400 steps once
    ls = bm.for_loop(train, operands=bm.arange(i, i + num_step))
    print(f'Train {i + num_step} epoch, loss = {bm.mean(ls):.4f}')
Train 400 epoch, loss = 0.6710
Train 800 epoch, loss = 0.5992
Train 1200 epoch, loss = 0.5332
Train 1600 epoch, loss = 0.4720
Train 2000 epoch, loss = 0.4189
Train 2400 epoch, loss = 0.3736
Train 2800 epoch, loss = 0.3335
Train 3200 epoch, loss = 0.2972
Train 3600 epoch, loss = 0.2644
Train 4000 epoch, loss = 0.2346

In the above example, we have seen the classical elements of neural network training, such as

  • net: neural network

  • loss: loss function

  • grad: gradient function

  • optimizer: parameter optimizer

  • train: training step

In BrainPy, all these elements can be defined as class objects and can be used for performing OO transformations.

In essence, the concept of BrainPy object-oriented transformation has three components:

  • BrainPyObject: the base class for object-oriented programming

  • Variable: the variables in the class object, whose values are ready to be changed/updated during transformation

  • ObjectTransform: the transformations for computation involving BrainPyObject and Variable

BrainPyObject and its Variable#

BrainPyObject is the base class for object-oriented programming in BrainPy. It can be viewed as a container of all the Variable instances needed for our computation.

In the above example, the Linear object has two Variable instances: W and b. The net we defined is further composed of two Linear objects, so we can expect that four variables can be retrieved from it.

net.vars().keys()
dict_keys(['Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])

An important question is: how do we define Variable instances in a BrainPyObject so that all of them can be retrieved?

Actually, any Variable instance that can be accessed through a self. attribute can be retrieved from a BrainPyObject recursively. No matter how deep the composition of BrainPyObject instances is, as long as the BrainPyObject instances and their Variable instances are accessible through the self. operation, all of them will be retrieved.

class SuperLinear(bp.BrainPyObject):
    def __init__(self, ):
        super().__init__()
        self.l1 = Linear(10, 20)
        self.v1 = bm.Variable(3)
        
sl = SuperLinear()
# retrieve Variable
sl.vars().keys()
dict_keys(['SuperLinear0.v1', 'Linear2.W', 'Linear2.b'])
# retrieve BrainPyObject
sl.nodes().keys()
dict_keys(['SuperLinear0', 'Linear2'])

However, we cannot access a BrainPyObject or Variable that is stored in a Python container (like a tuple, list, or dict). In this case, we can register our objects and variables through .register_implicit_vars() and .register_implicit_nodes():

class SuperSuperLinear(bp.BrainPyObject):
    def __init__(self, register=False):
        super().__init__()
        self.ss = [SuperLinear(), SuperLinear()]
        self.vv = {'v_a': bm.Variable(3)}
        if register:
            self.register_implicit_nodes(self.ss)
            self.register_implicit_vars(self.vv)
# without register
ssl = SuperSuperLinear(register=False)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys([])
dict_keys(['SuperSuperLinear0'])
# with register
ssl = SuperSuperLinear(register=True)
print(ssl.vars().keys())
print(ssl.nodes().keys())
dict_keys(['SuperSuperLinear1.v_a', 'SuperLinear3.v1', 'SuperLinear4.v1', 'Linear5.W', 'Linear5.b', 'Linear6.W', 'Linear6.b'])
dict_keys(['SuperSuperLinear1', 'SuperLinear3', 'SuperLinear4', 'Linear5', 'Linear6'])

Transform a function to BrainPyObject#

Let’s go back to our network training. After the definition of net, we further define a loss function whose computation involves the net object for neural network prediction and an rng Variable for data shuffling.

This Python function is then transformed into a BrainPyObject instance through the brainpy.math.to_object() interface.

loss
FunAsObject(nodes=[Sequential0],
            num_of_vars=1)

All Variable instances used in this object can also be retrieved through:

loss.vars().keys()
dict_keys(['loss0._var0', 'Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])

Note that, when using to_object(), we need to explicitly declare all BrainPyObject and Variable instances used in the Python function. Thanks to the recursive retrieval property of BrainPyObject, we only need to specify the highest-level composed objects.

For the above loss object, we do not need to specify the two Linear objects. Instead, we only need to pass the top-level object net into the to_object() transform.

Similarly, when we transform the train function into a BrainPyObject, we just need to point out the grad and optimizer objects we have used, rather than the previous loss, net, or rng.
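
As a quick sanity check (a sketch; outputs not shown), the recursive retrieval means that train still exposes all nested objects and variables, even though only grad and optimizer were declared explicitly:

print(train.nodes().keys())  # expected to include the grad, optimizer, loss and net objects
print(train.vars().keys())   # expected to include the network weights, biases and the rng state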

BrainPy object-oriented transformations#

BrainPy object-oriented transformations are designed to work on BrainPyObject instances. These transformations include automatic differentiation (brainpy.math.grad()) and just-in-time compilation (brainpy.math.jit()).

In our case, we used two OO transformations provided in BrainPy.

First, the grad object is defined from the loss function. Within it, we need to specify the variables whose gradients we want to compute through grad_vars.

Note that the OO transformation of any BrainPyObject results in another BrainPyObject. Therefore, it can be recursively used as a component to form a larger scope of object-oriented programming and object-oriented transformation (see the sketch after the output below).

grad
GradientTransform(target=loss0, 
                  num_of_grad_vars=4, 
                  num_of_dyn_vars=1)
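
For example, because train is itself a BrainPyObject, it can in turn be fed into another OO transformation. The sketch below assumes bm.jit accepts the train object directly:

train_jit = bm.jit(train)  # JIT-compile the whole training step object
l = train_jit(0)           # run one jitted training step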

Next, we train 400 steps at a time by using the for_loop transformation. Different from grad, which returns a BrainPyObject instance, for_loop directly returns the loop results.
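
To make this concrete, here is a minimal sketch of for_loop with a toy accumulator, written in the same style as the training code above (the acc variable and body function are illustrative only):

acc = bm.Variable(bm.zeros(1))

@bm.to_object(dyn_vars=acc)
def body(x):
    acc.value += x
    return acc.value

outs = bm.for_loop(body, operands=bm.arange(4.))
# outs stacks the return value of every loop step;
# acc now holds the cumulative sum of the operands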