Optimizers

@Chaoming Wang @Xiaoyu Chen

Gradient descent is one of the most popular optimization methods. Gradient descent optimizers, combined with a loss function, are central to machine learning today, especially deep learning. This section covers:

  • how to use optimizers in BrainPy

  • how to customize your own optimizer

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

bm.set_platform('cpu')  # run the examples on the CPU

Optimizers in BrainPy

The base optimizer class in BrainPy is brainpy.math.optimizers.Optimizer. BrainPy provides the following optimizers built on it:

  • SGD

  • Momentum

  • Nesterov momentum

  • Adagrad

  • Adadelta

  • RMSProp

  • Adam

All supported optimizers can be inspected through the brainpy.math.optimizers APIs.

In general, an optimizer is initialized with the learning rate lr, the trainable variables train_vars, and any other hyperparameters specific to that optimizer.

  • lr can be a float, or an instance of bm.optimizers.Scheduler.

  • train_vars should be a dict that maps names to JaxArray instances.

Here we create an SGD optimizer.

a = bm.ones((5, 4))
b = bm.zeros((3, 3))

op = bm.optimizers.SGD(lr=0.001, train_vars={'a': a, 'b': b})

To update the parameters, pass the gradient of every parameter to the update() method, keyed by the same names used in train_vars.

op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.9997439 , 0.9991836 , 0.9999782 , 0.9990992 ],
                      [0.999076  , 0.9997612 , 0.99925077, 0.99903256],
                      [0.9996492 , 0.9998097 , 0.99977213, 0.99905187],
                      [0.99963087, 0.99951845, 0.99903643, 0.99996334],
                      [0.99986696, 0.99989676, 0.99937785, 0.99970794]],            dtype=float32))
b: JaxArray(DeviceArray([[-7.7561024e-05, -4.1913200e-04, -6.5632869e-04],
                      [-3.4492972e-05, -5.1765458e-04, -9.3037548e-04],
                      [-9.2792397e-05, -3.1649830e-05, -5.2235392e-04]],            dtype=float32))
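
The printed values are consistent with the plain gradient descent rule, in which each parameter moves against its gradient by a step of size lr. The following is a minimal pure-Python sketch of that rule, not BrainPy's implementation; the helper name sgd_step is hypothetical and the arrays are simplified to flat lists.

```python
def sgd_step(params, grads, lr):
    """Return new parameters after one plain SGD step: p <- p - lr * g.

    `params` and `grads` are dicts mapping the same names to flat lists of floats.
    """
    return {name: [p - lr * g for p, g in zip(values, grads[name])]
            for name, values in params.items()}


# Toy values: 'a' starts at ones and 'b' at zeros, as in the example above.
params = {'a': [1.0, 1.0], 'b': [0.0, 0.0]}
grads = {'a': [0.5, 0.25], 'b': [0.5, 1.0]}
new_params = sgd_step(params, grads, lr=0.001)
```

With lr=0.001 and gradients between 0 and 1, entries that start at 1 end up slightly below 1 and entries that start at 0 become slightly negative, which is the pattern in the output above.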

You can process the gradients before applying them. For example, we clip the gradients by a maximum L2 norm.

grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
{'a': JaxArray(DeviceArray([[0.62677693, 0.56814206, 0.12360227, 0.257097  ],
                       [0.8980639 , 0.33391   , 0.1802653 , 0.26349783],
                       [0.9989817 , 0.25854266, 0.8259059 , 0.71850395],
                       [0.6676611 , 0.5614054 , 0.7707871 , 0.16712415],
                       [0.21876848, 0.8567476 , 0.7716671 , 0.7988616 ]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.4645934 , 0.2903055 , 0.08017159],
                       [0.6825682 , 0.0905968 , 0.8062532 ],
                       [0.32745683, 0.7631104 , 0.03143311]], dtype=float32))}
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
{'a': JaxArray(DeviceArray([[0.2291931 , 0.20775212, 0.04519756, 0.09401249],
                       [0.3283944 , 0.12210064, 0.06591749, 0.09635308],
                       [0.36529696, 0.09454112, 0.30200845, 0.26273486],
                       [0.24414317, 0.20528874, 0.2818532 , 0.06111218],
                       [0.07999692, 0.31328633, 0.282175  , 0.2921192 ]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.31898955, 0.19932356, 0.05504576],
                       [0.46865088, 0.0622037 , 0.55357295],
                       [0.22483166, 0.5239511 , 0.02158195]], dtype=float32))}
op.update(grads_post)

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.9995147 , 0.9989758 , 0.999933  , 0.9990052 ],
                      [0.9987476 , 0.9996391 , 0.99918485, 0.9989362 ],
                      [0.9992839 , 0.99971515, 0.9994701 , 0.99878913],
                      [0.9993867 , 0.9993132 , 0.99875456, 0.99990225],
                      [0.999787  , 0.9995835 , 0.9990957 , 0.9994158 ]],            dtype=float32))
b: JaxArray(DeviceArray([[-0.00039655, -0.00061846, -0.00071137],
                      [-0.00050314, -0.00057986, -0.00148395],
                      [-0.00031762, -0.0005556 , -0.00054394]], dtype=float32))
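
In the clipped gradients shown above, the entries of 'a' and 'b' are rescaled by different factors (roughly 0.366 and 0.687), which is consistent with each array being clipped by its own L2 norm rather than by one global norm. The sketch below illustrates that behavior in pure Python under this assumption; it is an illustration of the technique, not the code behind bm.clip_by_norm.

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale each gradient array so that its own L2 norm is at most `max_norm`."""
    clipped = {}
    for name, values in grads.items():
        norm = math.sqrt(sum(v * v for v in values))
        # Arrays whose norm is already within the limit are left unchanged.
        scale = max_norm / max(norm, max_norm)
        clipped[name] = [v * scale for v in values]
    return clipped


grads = {'a': [3.0, 4.0],   # L2 norm 5.0, gets scaled down
         'b': [0.3, 0.4]}   # L2 norm 0.5, already within the limit
clipped = clip_by_norm(grads, 1.0)
```

Clipping per array preserves the direction of each gradient while bounding its length, so one parameter with an unusually large gradient does not shrink the updates of the others.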

Note

Optimizers usually have their own dynamically changing variables. If you JIT-compile a function whose logic includes an optimizer update, the dyn_vars passed to bm.jit() should include the variables in Optimizer.vars().

op.vars()  # the SGD optimizer only has a `step` variable that records the training step
{'Constant0.step': Variable(DeviceArray([2], dtype=int32))}
bm.optimizers.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
{'Constant1.step': Variable(DeviceArray([0], dtype=int32)),
 'Momentum0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Momentum0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}
bm.optimizers.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
{'Constant2.step': Variable(DeviceArray([0], dtype=int32)),
 'Adam0.a_m': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_m': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32)),
 'Adam0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}
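
The `_m` and `_v` entries listed for Adam match the two running averages kept by the textbook Adam algorithm: one of the gradients and one of the squared gradients. The sketch below shows one such update for a single scalar parameter. The hyperparameter names beta1, beta2 and eps follow the standard algorithm and are assumptions here, not BrainPy's signature.

```python
import math


def adam_step(param, grad, m, v, step, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam update for a single scalar parameter.

    Returns the new parameter together with the updated state (m, v, step).
    """
    step += 1
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** step)             # bias correction for the zero start
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v, step


param, m, v, step = 1.0, 0.0, 0.0, 0
param, m, v, step = adam_step(param, grad=0.5, m=m, v=v, step=step)
```

Because m, v and step carry over from one call to the next, they are exactly the kind of state that has to be declared as dynamic variables when the update is JIT-compiled.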

Creating a Self-Customized Optimizer

To create your own optimization algorithm, inherit from the bm.optimizers.Optimizer class and override the following methods:

  • __init__(): init function that receives the learning rate (lr) and trainable variables (train_vars). Do not forget to register your dynamically changing variables in implicit_vars.

  • update(grads): update function that computes the updated parameters.

The general structure is shown below:

class CustomizeOp(bm.optimizers.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        
        # customize your initialization
        
    def update(self, grads):
        # customize your update logic
        pass
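
To make the two methods concrete, here is a standalone sketch of a momentum-style optimizer with the same shape: state is created in __init__() and consumed in update(grads). It deliberately does not import BrainPy or subclass bm.optimizers.Optimizer, so the class name MomentumSketch, the momentum argument and the use of plain lists are all illustrative assumptions.

```python
class MomentumSketch:
    """Standalone stand-in for a custom optimizer (not a BrainPy subclass)."""

    def __init__(self, lr, train_vars, momentum=0.9):
        self.lr = lr
        self.train_vars = train_vars   # dict: name -> list of floats, updated in place
        self.momentum = momentum
        # Per-parameter state that changes at every update. In a BrainPy
        # optimizer this is the kind of variable to register in implicit_vars.
        self.velocities = {k: [0.0] * len(v) for k, v in train_vars.items()}
        self.step = 0

    def update(self, grads):
        if set(grads) != set(self.train_vars):
            raise ValueError('grads must have the same keys as train_vars')
        for name, values in self.train_vars.items():
            vel = self.velocities[name]
            for i, g in enumerate(grads[name]):
                vel[i] = self.momentum * vel[i] - self.lr * g   # accumulate velocity
                values[i] += vel[i]                             # move the parameter
        self.step += 1


w = {'a': [1.0, 1.0]}
opt = MomentumSketch(lr=0.1, train_vars=w)
opt.update({'a': [1.0, 2.0]})
opt.update({'a': [1.0, 2.0]})
```

In a real subclass, the parameters and velocities would be BrainPy variables rather than Python lists, so that JIT compilation can track their changes.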

Schedulers

A scheduler adjusts the learning rate during training, typically by reducing it according to a predefined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.

Here we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.

sc = bm.optimizers.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
[Figure: learning rate versus training step for the ExponentialDecay scheduler]
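
The formula behind the scheduler is not stated here. Assuming it follows the common definition lr * decay_rate ** (step / decay_steps), the plotted curve can be reproduced without BrainPy as follows; the function name exponential_decay is hypothetical.

```python
def exponential_decay(step, lr=0.1, decay_steps=2, decay_rate=0.99):
    """Learning rate at `step` under the common exponential decay formula."""
    return lr * decay_rate ** (step / decay_steps)


# Same settings as the scheduler above, evaluated at steps 0..999.
rates = [exponential_decay(i) for i in range(1000)]
```

Under this definition the learning rate is multiplied by decay_rate once every decay_steps steps, so with the settings above it drops by 1% every two steps.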

After an Optimizer is initialized, its learning rate self.lr is always an instance of bm.optimizers.Scheduler. Initializing the optimizer with a scalar float learning rate produces a Constant scheduler.

op.lr
<brainpy.math.optimizers.Constant at 0x14899c76dc0>

One can get the current learning rate value by calling Scheduler.__call__(i=None).

  • If i is not provided, the learning rate value will be evaluated at the built-in training step.

  • Otherwise, the learning rate value will be evaluated at the given step i.

op.lr()
0.001

BrainPy provides several commonly used learning rate schedulers:

  • Constant

  • ExponentialDecay

  • InverseTimeDecay

  • PolynomialDecay

  • PiecewiseConstant

For more details, please see the brainpy.math.optimizers APIs.

# InverseTimeDecay scheduler

rates = bm.optimizers.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)
[Figure: learning rate versus training step for the InverseTimeDecay scheduler]

# PolynomialDecay scheduler

rates = bm.optimizers.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
[Figure: learning rate versus training step for the PolynomialDecay scheduler]
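
For comparison, the two schedules above can be written out under their usual textbook definitions. The formulas are assumptions, since this section does not state them, and the power argument of the polynomial schedule is a hypothetical parameter added for illustration.

```python
def inverse_time_decay(step, lr=0.01, decay_steps=10, decay_rate=0.999):
    """Common inverse-time decay: lr / (1 + decay_rate * step / decay_steps)."""
    return lr / (1 + decay_rate * step / decay_steps)


def polynomial_decay(step, lr=0.01, decay_steps=10, final_lr=0.0001, power=1.0):
    """Common polynomial decay from lr down to final_lr over decay_steps steps."""
    step = min(step, decay_steps)              # hold at final_lr afterwards
    frac = (1 - step / decay_steps) ** power   # goes from 1 down to 0
    return (lr - final_lr) * frac + final_lr


inverse_rates = [inverse_time_decay(i) for i in range(1000)]
polynomial_rates = [polynomial_decay(i) for i in range(1000)]
```

Under these definitions, inverse-time decay keeps shrinking slowly for as long as training runs, while polynomial decay reaches final_lr after decay_steps steps and stays there.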

Creating a Self-Customized Scheduler

To implement your own scheduler, inherit from the bm.optimizers.Scheduler class and override the following methods:

  • __init__(): the init function.

  • __call__(i=None): evaluates the learning rate value.

class CustomizeScheduler(bm.optimizers.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        
        # customize your initialization
        
    def __call__(self, i=None):
        # customize how the learning rate is evaluated at step i
        pass
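
As a concrete illustration of the __call__(i=None) contract described earlier, here is a standalone linear warmup schedule. It does not import BrainPy or subclass bm.optimizers.Scheduler; the class name WarmupSketch, the warmup_steps argument and the plain integer step counter are illustrative assumptions.

```python
class WarmupSketch:
    """Standalone stand-in for a custom scheduler (not a BrainPy subclass)."""

    def __init__(self, lr, warmup_steps):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.step = 0   # built-in training step, advanced by the caller

    def __call__(self, i=None):
        # Fall back to the built-in step when no explicit step is given.
        i = self.step if i is None else i
        # Ramp linearly from lr / warmup_steps up to lr, then stay constant.
        return self.lr * min(1.0, (i + 1) / self.warmup_steps)


sc = WarmupSketch(lr=0.1, warmup_steps=4)
explicit = [sc(i) for i in range(6)]   # evaluated at the given steps
sc.step = 10
implicit = sc()                        # evaluated at the built-in step
```

Accepting an optional step keeps the scheduler usable in two ways: an optimizer can call it with no argument during training, and plotting code can evaluate it over an arbitrary range of steps.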