Automatic Differentiation with BrainPyObject
In this section, we will cover how to perform automatic differentiation on the variables in a function or a class object. Gradients are used throughout modern machine learning systems, so we need to understand:
How to calculate the derivatives of arbitrarily complex functions?
How to compute high-order gradients?
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
bp.__version__
'2.4.1'
Preliminary
Every autograd function in BrainPy has several keyword arguments. All examples below are illustrated with brainpy.math.grad(); the other autograd functions accept the same settings.
argnums and grad_vars
The autograd functions in BrainPy can compute derivatives of function arguments (specified by argnums) or of non-argument variables (specified by grad_vars). For instance, the following is a linear readout model:
class Linear(bp.BrainPyObject):
    def __init__(self):
        super(Linear, self).__init__()
        self.w = bm.Variable(bm.random.random((1, 10)))
        self.b = bm.Variable(bm.zeros(1))

    def update(self, x):
        r = bm.dot(self.w, x) + self.b
        return r.sum()

l = Linear()
If we want the derivative with respect to the argument "x" when calling the update function, we can specify this through argnums:
grad = bm.grad(l.update, argnums=0)
grad(bm.ones(10))
Array(value=DeviceArray([0.74814725, 0.16502357, 0.19869995, 0.9638033 , 0.7735306 ,
0.6862997 , 0.7359276 , 0.97442615, 0.2690258 , 0.02489543], dtype=float32),
dtype=float32)
By contrast, if we want the derivatives of the parameters "self.w" and "self.b", we should label them with grad_vars:
grad = bm.grad(l.update, grad_vars=(l.w, l.b))
grad(bm.ones(10))
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
DeviceArray([1.], dtype=float32))
If we want the derivatives of both the argument "x" and the parameters "self.w" and "self.b", argnums and grad_vars can be used together. In this case, the gradient function will return gradients in the format (var_grads, arg_grads), where arg_grads refers to the gradients of "argnums" and var_grads refers to the gradients of "grad_vars".
grad = bm.grad(l.update, grad_vars=(l.w, l.b), argnums=0)
var_grads, arg_grads = grad(bm.ones(10))
var_grads
(DeviceArray([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32),
DeviceArray([1.], dtype=float32))
arg_grads
Array(value=DeviceArray([0.74814725, 0.16502357, 0.19869995, 0.9638033 , 0.7735306 ,
0.6862997 , 0.7359276 , 0.97442615, 0.2690258 , 0.02489543], dtype=float32),
dtype=float32)
return_value
As mentioned above, autograd functions return a function that computes gradients, regardless of what the original function returns. Sometimes, however, we care about the value the function returns, not just the gradients. In this case, you can set return_value=True in the autograd function.
grad = bm.grad(l.update, argnums=0, return_value=True)
gradient, value = grad(bm.ones(10))
gradient
Array(value=DeviceArray([0.74814725, 0.16502357, 0.19869995, 0.9638033 , 0.7735306 ,
0.6862997 , 0.7359276 , 0.97442615, 0.2690258 , 0.02489543], dtype=float32),
dtype=float32)
value
DeviceArray(5.5397797, dtype=float32)
has_aux
In some situations, we are interested in intermediate values of a function, and has_aux=True can be of great help. The constraint is that the function must return values in the format (loss, aux_data). For instance,
class LinearAux(bp.BrainPyObject):
    def __init__(self):
        super(LinearAux, self).__init__()
        self.w = bm.Variable(bm.random.random((1, 10)))
        self.b = bm.Variable(bm.zeros(1))

    def update(self, x):
        dot = bm.dot(self.w, x)
        r = (dot + self.b).sum()
        # here the aux data is a tuple, including the loss and the dot value;
        # however, aux can be arbitrarily complex
        return r, (r, dot)

l2 = LinearAux()
grad = bm.grad(l2.update, argnums=0, has_aux=True)
gradient, aux = grad(bm.ones(10))
gradient
Array(value=DeviceArray([0.7152523 , 0.83822143, 0.47706044, 0.23839808, 0.3606074 ,
0.14133751, 0.2397281 , 0.30746818, 0.39058363, 0.11630356], dtype=float32),
dtype=float32)
aux
(DeviceArray(3.8249607, dtype=float32),
Array(value=DeviceArray([3.8249607]), dtype=float32))
When multiple keywords (argnums, grad_vars, has_aux, or return_value) are set simultaneously, the return format of the gradient function can be inspected in the corresponding API documentation of brainpy.math.grad().
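For instance, combining all four keywords on the LinearAux model above could look like the following sketch. The unpacking order shown here (grouped gradients first, then the returned value, then the auxiliary data) is an assumption based on the convention used in the brainpy.math.jacobian() example later in this section; please confirm the exact structure in the API documentation.

# a sketch: combining grad_vars, argnums, return_value and has_aux;
# the assumed return structure is ((var_grads, arg_grads), value, aux)
grad = bm.grad(l2.update, grad_vars=(l2.w, l2.b), argnums=0,
               return_value=True, has_aux=True)
(var_grads, arg_grads), value, aux = grad(bm.ones(10))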
brainpy.math.grad()
brainpy.math.grad() takes a function/object (\(f : \mathbb{R}^n \to \mathbb{R}\)) as input and returns a new function that computes the gradient \(\nabla f(x) \in \mathbb{R}^n\) of the original function/object. It is worth noting that brainpy.math.grad() only supports functions that return scalar values.
Pure functions
For a pure function, the gradient is taken with respect to the first argument:
def f(a, b):
    return a * 2 + b
grad_f1 = bm.grad(f)
grad_f1(2., 1.)
DeviceArray(2., dtype=float32, weak_type=True)
However, this can be controlled via the argnums argument.
grad_f2 = bm.grad(f, argnums=(0, 1))
grad_f2(2., 1.)
(DeviceArray(2., dtype=float32, weak_type=True),
DeviceArray(1., dtype=float32, weak_type=True))
Class objects
For a class object or a class-bound function, the gradient is taken with respect to the provided grad_vars and argnums settings:
class F(bp.BrainPyObject):
    def __init__(self):
        super(F, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        vv = ab2 + c
        return vv.mean()

f = F()
The grad_vars can be an Array, or a list/tuple/dict of Arrays.
bm.grad(f, grad_vars=f.train_vars())(10.)
{'F0.a': DeviceArray([2.], dtype=float32),
'F0.b': DeviceArray([2.], dtype=float32)}
bm.grad(f, grad_vars=[f.a, f.b])(10.)
[DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32)]
If there are dynamically changed values in the gradient function, you can provide them in the dyn_vars argument.
class F2(bp.BrainPyObject):
    def __init__(self):
        super(F2, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab = ab * 2
        self.a.value = ab
        return (ab + c).mean()

f2 = F2()
bm.grad(f2, grad_vars=f2.b)(10.)
DeviceArray([2.], dtype=float32)
Besides, if you are interested in the gradient of the input value, please use the argnums argument. Then, the gradient function will return (grads_of_grad_vars, grads_of_args).
class F3(bp.BrainPyObject):
    def __init__(self):
        super(F3, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c, d):
        ab = self.a * self.b
        ab = ab * 2
        return (ab + c * d).mean()

f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=0)(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : [DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32)]
grad_of_args : 3.0
f3 = F3()
grads_of_gv, grad_of_args = bm.grad(f3, grad_vars=[f3.a, f3.b], argnums=(0, 1))(10., 3.)
print("grads_of_gv :", grads_of_gv)
print("grad_of_args :", grad_of_args)
grads_of_gv : [DeviceArray([2.], dtype=float32), DeviceArray([2.], dtype=float32)]
grad_of_args : (DeviceArray(3., dtype=float32, weak_type=True), DeviceArray(10., dtype=float32, weak_type=True))
Actually, it is recommended to provide all dynamically changed variables in the dyn_vars argument, whether or not they are updated inside the gradient function.
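For example, for the F2 model defined above, a minimal sketch of declaring the updated variable explicitly might look like the following (assuming dyn_vars accepts a list of variables, as described above):

# a sketch: self.a is updated inside F2.__call__, so it is declared
# as a dynamically changed variable through dyn_vars
f2 = F2()
bm.grad(f2, grad_vars=f2.b, dyn_vars=[f2.a])(10.)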
Auxiliary data
Usually, we want to get the loss value, or we want to return some intermediate variables during the gradient computation. In these situations, users can set has_aux=True to return auxiliary data and return_value=True to return the loss value.
# return loss
grad, loss = bm.grad(f, grad_vars=f.a, return_value=True)(10.)
print('grad: ', grad)
print('loss: ', loss)
grad: [2.]
loss: 12.0
class F4(bp.BrainPyObject):
    def __init__(self):
        super(F4, self).__init__()
        self.a = bm.TrainVar(bm.ones(1))
        self.b = bm.TrainVar(bm.ones(1))

    def __call__(self, c):
        ab = self.a * self.b
        ab2 = ab * 2
        loss = (ab + c).mean()
        return loss, (ab, ab2)

f4 = F4()
# return intermediate values
grad, aux_data = bm.grad(f4, grad_vars=f4.a, has_aux=True)(10.)
print('grad: ', grad)
print('aux_data: ', aux_data)
grad: [1.]
aux_data: (Array(value=DeviceArray([1.]), dtype=float32), Array(value=DeviceArray([2.]), dtype=float32))
Any function used to compute gradients through brainpy.math.grad() must return a scalar value; otherwise an error will be raised.
try:
    bm.grad(lambda x: x)(bm.zeros(2))
except Exception as e:
    print(type(e), e)
<class 'TypeError'> Gradient only defined for scalar-output functions. Output had shape: (2,).
# this is right
bm.grad(lambda x: x.mean())(bm.zeros(2))
Array(value=DeviceArray([0.5, 0.5]), dtype=float32)
brainpy.math.vector_grad()
If users want to take gradients of vector-valued outputs, please use the brainpy.math.vector_grad() function. For example,
def f(a, b):
    return bm.sin(b) * a
Gradients for vectors
# vectors
a = bm.arange(5.)
b = bm.random.random(5)
bm.vector_grad(f)(a, b)
Array(value=DeviceArray([0.01985658, 0.20870303, 0.2764193 , 0.32965127, 0.7212195 ], dtype=float32), dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(Array(value=DeviceArray([0.01985658, 0.20870303, 0.2764193 , 0.32965127, 0.7212195 ], dtype=float32), dtype=float32),
Array(value=DeviceArray([0. , 0.97797906, 1.9220742 , 2.8323083 , 2.7708263 ], dtype=float32), dtype=float32))
Gradients for matrices
# matrix
a = bm.arange(6.).reshape((2, 3))
b = bm.random.random((2, 3))
bm.vector_grad(f, argnums=1)(a, b)
Array(value=DeviceArray([[0. , 0.9527361, 1.9759592],
[2.4942482, 2.2726011, 4.7790203]]),
dtype=float32)
bm.vector_grad(f, argnums=(0, 1))(a, b)
(Array(value=DeviceArray([[0.03127709, 0.3037993 , 0.15458442],
[0.5556503 , 0.82292485, 0.29400444]]),
dtype=float32),
Array(value=DeviceArray([[0. , 0.9527361, 1.9759592],
[2.4942482, 2.2726011, 4.7790203]]),
dtype=float32))
Similar to brainpy.math.grad(), brainpy.math.vector_grad() also supports derivatives of variables in a class object. Here is a simple example.
class Test(bp.BrainPyObject):
    def __init__(self):
        super(Test, self).__init__()
        self.x = bm.Variable(bm.ones(5))
        self.y = bm.Variable(bm.ones(5))

    def __call__(self):
        return self.x ** 2 + self.y ** 3 + 10

t = Test()
bm.vector_grad(t, grad_vars=t.x)()
DeviceArray([2., 2., 2., 2., 2.], dtype=float32)
bm.vector_grad(t, grad_vars=(t.x, ))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),)
bm.vector_grad(t, grad_vars=(t.x, t.y))()
(DeviceArray([2., 2., 2., 2., 2.], dtype=float32),
DeviceArray([3., 3., 3., 3., 3.], dtype=float32))
Other settings like return_value and has_aux in brainpy.math.vector_grad() work the same as those in brainpy.math.grad().
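For instance, reusing the Test object above, a sketch of also returning the function value could look like this (the (gradients, value) return order is assumed to match the brainpy.math.grad() examples above):

# a sketch: return_value=True is assumed to return (gradients, value)
grads, value = bm.vector_grad(t, grad_vars=(t.x, t.y), return_value=True)()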
brainpy.math.jacobian()
Another way to take gradients of a vector-valued output is to use brainpy.math.jacobian(). brainpy.math.jacobian() automatically computes the Jacobian matrix \(\partial f(x) \in \mathbb{R}^{m \times n}\) of a given function \(f : \mathbb{R}^n \to \mathbb{R}^m\) at a given point \(x \in \mathbb{R}^n\). Here, we will not go into the details of the implementation and usage of brainpy.math.jacobian(). Instead, we only show two examples: one with a pure function and one with a class object.
Given the following function,
import jax.numpy as jnp
def f1(x, y):
    a = 4 * x[1] ** 2 - 2 * x[2]
    r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
    return r, a
_x = bm.array([1., 2., 3.])
_y = bm.array([10., 5.])
grads, vec, aux = bm.jacobian(f1, return_value=True, has_aux=True)(_x, _y)
grads
Array(value=DeviceArray([[10. , 0. , 0. ],
[ 0. , 0. , 25. ],
[ 0. , 16. , -2. ],
[ 1.6209068 , 0. , 0.84147096]]),
dtype=float32)
vec
DeviceArray([10. , 75. , 10. , 2.5244129], dtype=float32)
aux
DeviceArray(10., dtype=float32)
Given the following class object,
class Test(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.x = bm.Variable(bm.array([1., 2., 3.]))

    def __call__(self, y):
        a = self.x[0] * y[0]
        b = 5 * self.x[2] * y[1]
        c = 4 * self.x[1] ** 2 - 2 * self.x[2]
        d = self.x[2] * jnp.sin(self.x[0])
        r = jnp.asarray([a, b, c, d])
        return r, (c, d)

t = Test()
f_grad = bm.jacobian(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)
(var_grads, arg_grads), value, aux = f_grad(_y)
var_grads
DeviceArray([[10. , 0. , 0. ],
[ 0. , 0. , 25. ],
[ 0. , 16. , -2. ],
[ 1.6209068 , 0. , 0.84147096]], dtype=float32)
arg_grads
Array(value=DeviceArray([[ 1., 0.],
[ 0., 15.],
[ 0., 0.],
[ 0., 0.]]),
dtype=float32)
value
DeviceArray([10. , 75. , 10. , 2.5244129], dtype=float32)
aux
(DeviceArray(10., dtype=float32), DeviceArray(2.5244129, dtype=float32))
For more details on automatic differentiation, please see our API documentation.