Differentiation

In this section, we will talk about how to perform automatic differentiation of your functions and class objects with the ‘jax’ backend. In modern machine learning systems, computing and using gradients is common in many situations. Specifically, we will cover

  • how to calculate derivatives of arbitrarily complex functions,

  • how to compute high-order gradients.

import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

All autodiff functions in BrainPy support pure functions and class objects.

grad()

brainpy.math.jax.grad() takes a function/object and returns a new function which computes the gradient of the original function/object.

Pure functions

For a pure function, the gradient is taken with respect to the first argument by default:

def f(a, b):
    return a * 2 + b

grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32)

However, this default behavior can be controlled via the argnums argument.

grad_f2 = bm.grad(f, argnums=(0, 1))

grad_f2(2., 1.)
(DeviceArray(2., dtype=float32), DeviceArray(1., dtype=float32))
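
argnums can also be a single integer, in which case only the gradient of that argument is returned. A minimal sketch (assuming a single-integer argnums behaves as in jax.grad; the expected result is shown as a comment):

# take the gradient with respect to the second argument only
grad_f_arg1 = bm.grad(f, argnums=1)
grad_f_arg1(2., 1.)  # expected: DeviceArray(1., dtype=float32)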

Class objects

For a class object or a class-bound function, the gradient is taken with respect to the variables provided in the grad_vars argument:

class F(bp.Base):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()
    
f = F()

The grad_vars can be a JaxArray, or a list/tuple/dict of JaxArrays.

bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': TrainVar(DeviceArray([2.], dtype=float32)),
 'F0.b': TrainVar(DeviceArray([2.], dtype=float32))}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
[TrainVar(DeviceArray([2.], dtype=float32)),
 TrainVar(DeviceArray([2.], dtype=float32))]
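
A single JaxArray also works as grad_vars, in which case a single gradient (instead of a list or dict) is returned. A minimal sketch reusing the f defined above, with the expected result as a comment:

bm.grad(f, grad_vars=f.a)(10.)
# expected: TrainVar(DeviceArray([2.], dtype=float32))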

If some variables are dynamically changed during the gradient computation, you can provide them in the dyn_vars argument.

class F2(bp.Base):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()
f2 = F2()
bm.grad(f2, dyn_vars=[f2.a], grad_vars=f2.b)(10.)
TrainVar(DeviceArray([2.], dtype=float32))

Also, if you are interested in the gradients of the input arguments, please use the argnums argument. In this situation, calling the gradient function will return (grads_of_grad_vars, *grads_of_args).

class F3(bp.Base):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()
f3 = F3()
grads_of_gv, grad_of_arg0 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
f3 = F3()
grads_of_gv, grad_of_arg0, grad_of_arg1 = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)

print("grads_of_gv :", grads_of_gv)
print("grad_of_arg0 :", grad_of_arg0)
print("grad_of_arg1 :", grad_of_arg1)
grads_of_gv : [TrainVar(DeviceArray([2.], dtype=float32)), TrainVar(DeviceArray([2.], dtype=float32))]
grad_of_arg0 : 3.0
grad_of_arg1 : 10.0

In fact, we recommend that you provide every dynamically changed variable in the dyn_vars argument, no matter whether it is updated inside the gradient function or not.
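
For instance, a model may also keep a non-trainable state that is updated during the call. The sketch below is hypothetical (the class F2b and its state variable are not part of the original example, and we assume bm.Variable marks such dynamically changed state); the expected result is given as a comment:

class F2b(bp.Base):
    def __init__(self):
        super(F2b, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))
        self.state = bm.Variable(bm.zeros(1))  # non-trainable, but updated below

    def __call__(self, c):
        ab2 = self.a * self.b * 2
        self.state.value = ab2 + c   # the state changes during the call
        return self.state.mean()

f2b = F2b()
# every dynamically changed variable goes into dyn_vars
bm.grad(f2b, dyn_vars=[f2b.state], grad_vars=[f2b.a, f2b.b])(10.)
# expected: [TrainVar(DeviceArray([2.], dtype=float32)),
#            TrainVar(DeviceArray([2.], dtype=float32))]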

Auxiliary data

Usually, we want to get the value of the loss, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data, and set return_value=True to return the loss value.

# return loss

grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)

print('grad: ', grad)
print('loss: ', loss)
grad:  TrainVar(DeviceArray([2.], dtype=float32))
loss:  12.0
class F4(bp.Base):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)

f4 = F4()

# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)

print('grad: ', grad)
print('aux_data: ', aux_data)
grad:  TrainVar(DeviceArray([1.], dtype=float32))
aux_data:  (JaxArray(DeviceArray([1.], dtype=float32)), JaxArray(DeviceArray([2.], dtype=float32)))

Note: Any function that computes gradients through brainpy.math.jax.grad() must return a scalar value. Otherwise, an error will be raised.

try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output was [0. 0.].
# this is correct, because the output is a scalar
bm.grad(lambda x: x.mean())(bm.zeros(2))
JaxArray(DeviceArray([0.5, 0.5], dtype=float32))

If you want to take gradients of vector-valued outputs, please use the brainpy.math.jax.jacobian() function.

jacobian()

Coming soon.
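
Until then, a Jacobian of a vector-valued pure function can be sketched with JAX's own jax.jacfwd, operating on plain jax.numpy arrays (an interim illustration that bypasses BrainPy's JaxArray wrapper):

import jax
import jax.numpy as jnp

def g(x):
    # a vector-valued function from R^2 to R^2
    return jnp.asarray([x[0] ** 2, x[0] * x[1]])

jax.jacfwd(g)(jnp.array([2., 3.]))
# expected Jacobian at (2, 3):
# [[4., 0.],
#  [3., 2.]]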

hessian()

Coming soon.
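
Until then, higher-order gradients of a pure scalar function can be sketched with nested jax.grad calls, or with jax.hessian, again using plain JAX as an interim illustration:

import jax

# the second derivative of x**3 is 6*x, so it equals 12 at x = 2
jax.grad(jax.grad(lambda x: x ** 3))(2.)   # expected: 12.0
jax.hessian(lambda x: x ** 3)(2.)          # expected: 12.0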