Autograd for Class Variables#

@Chaoming Wang @Xiaoyu Chen

In this section, we illustrate how to perform automatic differentiation with respect to variables in a function or a class object. Gradients are central to modern machine learning systems, so it is important to understand:

  • How to compute derivatives of arbitrarily complex functions?

  • How to compute high-order gradients?

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Preliminary#

Every autograd function in BrainPy accepts several keyword arguments. All examples below are illustrated through brainpy.math.grad(); other autograd functions share the same settings.

argnums and grad_vars#

The autograd functions in BrainPy can compute derivatives of function arguments (specified by argnums) or non-argument variables (specified by grad_vars). For instance, the following is a linear readout model:

class Linear(bp.Base):
    def __init__(self):
        super(Linear, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        r = bm.dot(self.w, x) + self.b
        return r.sum()
    
l = Linear()

To compute the derivative of the argument “x” when calling the update function, we can specify it through argnums:

grad = bm.grad(l.update, argnums=0)

grad(bm.ones(10))
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ], dtype=float32)

By contrast, if we focus on the derivatives of the parameters “self.w” and “self.b”, we should label them with grad_vars:

grad = bm.grad(l.update, grad_vars=(l.w, l.b))

grad(bm.ones(10))
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
 DeviceArray([1.], dtype=float32))

If we need the derivatives of both the argument “x” and the parameters “self.w” and “self.b”, argnums and grad_vars can be used together. In this case, the gradient function returns gradients in the format (var_grads, arg_grads), where var_grads refers to the gradients of “grad_vars” and arg_grads refers to the gradients of “argnums”.

grad = bm.grad(l.update, grad_vars=(l.w, l.b), argnums=0)

var_grads, arg_grads = grad(bm.ones(10))
var_grads
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
 DeviceArray([1.], dtype=float32))
arg_grads
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ], dtype=float32)

return_value#

As mentioned above, autograd functions return a function that computes gradients, regardless of the original function’s returned value. Sometimes, however, we care about the value the function returns, not just the gradients. In this case, you can set return_value=True in the autograd function.

grad = bm.grad(l.update, argnums=0, return_value=True)

gradient, value = grad(bm.ones(10))
gradient
JaxArray([0.9865978 , 0.14363837, 0.03861248, 0.42379665, 0.7038013 ,
          0.11866355, 0.67538667, 0.15790391, 0.6050298 , 0.778468  ], dtype=float32)
value
DeviceArray(4.6318984, dtype=float32)

has_aux#

In some situations, we are interested in the intermediate values of a function, and has_aux=True can be of great help. The constraint is that the function must return values in the format (loss, aux_data). For instance,

class LinearAux(bp.Base):
    def __init__(self):
        super(LinearAux, self).__init__()
        self.w = bm.random.random((1, 10))
        self.b = bm.zeros(1)
    
    def update(self, x):
        dot = bm.dot(self.w, x)
        r = (dot + self.b).sum()
        return r, (r, dot)  # here the aux data is a tuple that includes the loss and the dot value;
                            # however, aux data can be arbitrarily complex.
    
l2 = LinearAux()
grad = bm.grad(l2.update, argnums=0, has_aux=True)

gradient, aux = grad(bm.ones(10))
gradient
JaxArray([0.20289445, 0.4745227 , 0.36053288, 0.94524395, 0.8360598 ,
          0.06507981, 0.7748591 , 0.8377187 , 0.5767547 , 0.47604012], dtype=float32)
aux
(DeviceArray(5.5497055, dtype=float32), JaxArray([5.5497055], dtype=float32))

When multiple keywords (argnums, grad_vars, has_aux, or return_value) are set simultaneously, the return format of the gradient function can be inspected through the corresponding API documentation of brainpy.math.grad().

brainpy.math.grad()#

brainpy.math.grad() takes a function/object (\(f : \mathbb{R}^n \to \mathbb{R}\)) as input and returns a new function (\(\partial f(x) \in \mathbb{R}^n\)) that computes the gradient of the original function/object. It is worth noting that brainpy.math.grad() only supports functions returning scalar values.

Pure functions#

For a pure function, the gradient is taken with respect to the first argument by default:

def f(a, b):
    return a * 2 + b

grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32, weak_type=True)

However, this can be controlled via the argnums argument.

grad_f2 = bm.grad(f, argnums=(0, 1))

grad_f2(2., 1.)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(1., dtype=float32, weak_type=True))
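The introduction also asked about high-order gradients. Since the autograd transform returns an ordinary function, it can simply be composed with itself. The sketch below illustrates the pattern in plain JAX (the library BrainPy builds on); the function f here is a made-up example, not part of the tutorial above:

```python
import jax

# Second derivative by composing grad with itself:
# f(a) = a**3  ->  f'(a) = 3a**2  ->  f''(a) = 6a
f = lambda a: a ** 3
df = jax.grad(f)     # first-order gradient function
d2f = jax.grad(df)   # second-order gradient function

print(df(3.0))   # 27.0
print(d2f(3.0))  # 18.0
```

The same composition idea applies to the BrainPy wrappers, since each call returns a callable that can be differentiated again.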

Class objects#

For a class object or a class bound function, the gradient is taken with respect to the provided grad_vars and argnums setting:

class F(bp.Base):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()
    
f = F()

The grad_vars can be a JaxArray, or a list/tuple/dict of JaxArray.

bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': DeviceArray([2.], dtype=float32),
 'F0.b': DeviceArray([2.], dtype=float32)}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
(DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))

If there are dynamically changed values in the gradient function, you can provide them in the dyn_vars argument.

class F2(bp.Base):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
DeviceArray([2.], dtype=float32)

In addition, if you are interested in the gradients of the input arguments, use the argnums argument. The gradient function will then return (grads_of_grad_vars, grads_of_args).

class F3(bp.Base):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : 3.0
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : (DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32))
grad_of_args : (DeviceArray(3., dtype=float32, weak_type=True), DeviceArray(10., dtype=float32, weak_type=True))

In general, it is recommended to provide all dynamically changed variables in the dyn_vars argument, whether or not they are updated in the gradient function.
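To see why dyn_vars matters, recall that these transformations ultimately compile down to pure JAX functions, in which any state that changes must be passed in and returned explicitly. The following plain-JAX sketch shows, under that assumption, roughly what dyn_vars automates for the F2 example above; the names loss_fn and new_a are illustrative, not BrainPy API:

```python
import jax
import jax.numpy as jnp

# Pure-function version of F2: the updated state (self.a) is threaded
# through explicitly instead of being mutated in place.
def loss_fn(b, a, c):
    ab = a * b * 2
    new_a = ab                      # the "dynamically changed" value
    return (ab + c).mean(), new_a   # return it as auxiliary output

grads, new_a = jax.grad(loss_fn, has_aux=True)(jnp.ones(1), jnp.ones(1), 10.)
print(grads)  # [2.]
```

Passing variables through dyn_vars lets BrainPy do this state threading for you while keeping the convenient object-oriented style.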

Auxiliary data#

Usually, we want to get the loss value, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data and set return_value=True to return the loss value.

# return loss

grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)

print('grad: ', grad)
print('loss: ', loss)
grad:  [2.]
loss:  12.0
class F4(bp.Base):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)
    

f4 = F4()
    
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)

print('grad: ', grad)
print('aux_data: ', aux_data)
grad:  [1.]
aux_data:  (JaxArray([1.], dtype=float32), JaxArray([2.], dtype=float32))
Any function used to compute gradients through brainpy.math.grad() must return a scalar value; otherwise, an error will be raised.
try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output had shape: (2,).
# this is right

bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray([0.5, 0.5], dtype=float32)

brainpy.math.vector_grad()#

If users want to take gradients of vector-valued outputs, please use the brainpy.math.vector_grad() function. For example,

def f(a, b): 
    return bm.sin(b) * a

Gradients for vectors#

# vectors

a = bm.arange(5.)
b = bm.random.random(5)
bm.vector_grad(f)(a, b)
JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ], dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([0.22263631, 0.19832121, 0.47522876, 0.40596786, 0.2040254 ], dtype=float32),
 JaxArray([0.       , 0.9801371, 1.7597246, 2.741662 , 3.9158623], dtype=float32))

Gradients for matrices#

# matrix

a = bm.arange(6.).reshape((2, 3))
b = bm.random.random((2, 3))
bm.vector_grad(f, argnums=1)(a, b)
JaxArray([[0.       , 0.8662993, 1.1221857],
          [2.9322515, 2.3293345, 3.024507 ]], dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(JaxArray([[0.45055482, 0.49952534, 0.8277529 ],
           [0.21131878, 0.8129499 , 0.79630035]], dtype=float32),
 JaxArray([[0.       , 0.8662993, 1.1221857],
           [2.9322515, 2.3293345, 3.024507 ]], dtype=float32))

Similar to brainpy.math.grad(), brainpy.math.vector_grad() also supports derivatives of variables in a class object. Here is a simple example.

class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.ones(5)
    self.y = bm.ones(5)

  def __call__(self):
    return self.x ** 2 + self.y ** 3 + 10

t = Test()
bm.vector_grad(t, grad_vars=t.x)()
DeviceArray([2., 2., 2., 2., 2.], dtype=float32)
bm.vector_grad(t, grad_vars=(t.x, ))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),)
bm.vector_grad(t, grad_vars=(t.x, t.y))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),
 DeviceArray([3., 3., 3., 3., 3.], dtype=float32))

Other keywords like return_value and has_aux in brainpy.math.vector_grad() behave the same as those in brainpy.math.grad().
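Conceptually, for an element-wise function like f above, a vector gradient can be reproduced in plain JAX by composing vmap with grad. This is only a sketch of the idea, not BrainPy’s actual implementation:

```python
import jax
import jax.numpy as jnp

f = lambda a, b: jnp.sin(b) * a

# Element-wise vector gradient w.r.t. both arguments:
# d/da = sin(b), d/db = a * cos(b)
vgrad = jax.vmap(jax.grad(f, argnums=(0, 1)))

a = jnp.arange(4.)
b = jnp.zeros(4)
ga, gb = vgrad(a, b)
print(ga)  # sin(0) = 0 everywhere
print(gb)  # a * cos(0) = a
```

Each per-element gradient is a scalar derivative, and vmap lifts it across the whole vector at once.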

brainpy.math.jacobian()#

Another way to take gradients of a vector-valued output is to use brainpy.math.jacobian(). brainpy.math.jacobian() automatically computes the Jacobian matrix \(\partial f(x) \in \mathbb{R}^{m \times n}\) of a given function \(f : \mathbb{R}^n \to \mathbb{R}^m\) at a given point \(x \in \mathbb{R}^n\). Here, we will not go into the details of the implementation and usage of brainpy.math.jacobian(); instead, we only show two examples, one with a pure function and one with a class object.

Given the following function,

import jax.numpy as jnp

def f1(x, y):
    a = 4 * x[1] ** 2 - 2 * x[2]
    r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
    return r, a
_x = bm.array([1., 2., 3.])
_y = bm.array([10., 5.])
    
grads, vec, aux = bm.jacobian(f1, return_value=True, has_aux=True)(_x, _y)
grads
JaxArray([[10.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        , 25.        ],
          [ 0.        , 16.        , -2.        ],
          [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32)
vec
DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)
aux
DeviceArray(10., dtype=float32)

Given the following class objects,

class Test(bp.Base):
  def __init__(self):
    super(Test, self).__init__()
    self.x = bm.array([1., 2., 3.])

  def __call__(self, y):
    a = self.x[0] * y[0]
    b = 5 * self.x[2] * y[1]
    c = 4 * self.x[1] ** 2 - 2 * self.x[2]
    d = self.x[2] * jnp.sin(self.x[0])
    r = jnp.asarray([a, b, c, d])
    return r, (c, d)
t = Test()
f_grad = bm.jacobian(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)

(var_grads, arg_grads), value, aux = f_grad(_y)
var_grads
DeviceArray([[10.        ,  0.        ,  0.        ],
             [ 0.        ,  0.        , 25.        ],
             [ 0.        , 16.        , -2.        ],
             [ 1.6209068 ,  0.        ,  0.84147096]], dtype=float32)
arg_grads
JaxArray([[ 1.,  0.],
          [ 0., 15.],
          [ 0.,  0.],
          [ 0.,  0.]], dtype=float32)
value
DeviceArray([10.       , 75.       , 10.       ,  2.5244129], dtype=float32)
aux
(DeviceArray(10., dtype=float32), DeviceArray(2.5244129, dtype=float32))

For more details on automatic differentiation, please see our API documentation.