Optimizers

Gradient descent is one of the most popular optimization algorithms. Gradient-based optimizers, combined with a loss function, are the key pieces that allow a machine learning model to be fitted to your data. In this section, we will cover:

  • how to use optimizers in BrainPy;

  • how to customize your own optimizer.

import brainpy as bp
import brainpy.math.jax as bm
import matplotlib.pyplot as plt

bp.math.use_backend('jax')

Optimizers

The base optimizer class in BrainPy is

bm.optimizers.Optimizer
brainpy.math.jax.optimizers.Optimizer

Some of the optimizers provided by BrainPy are:

  • SGD

  • Momentum

  • Nesterov momentum

  • Adagrad

  • Adadelta

  • RMSProp

  • Adam

Users can also easily extend BrainPy with their own optimizers.

Generally, an Optimizer is initialized with a learning rate lr, the trainable variables train_vars, and other hyperparameters specific to that optimizer.

  • lr can be a float, or an instance of bm.optimizers.Scheduler.

  • train_vars should be a dict of JaxArray.

Here we create an SGD optimizer.

a = bm.ones((5, 4))
b = bm.zeros((3, 3))

op = bm.optimizers.SGD(lr=0.001, train_vars={'a': a, 'b': b})

To update the parameters, call the update() method with the gradient of each parameter. Pass the gradients as a dict keyed like train_vars (here 'a' and 'b').

op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.99970317, 0.99958736, 0.9991706 , 0.99929893],
                      [0.99986506, 0.9994412 , 0.9996797 , 0.9995855 ],
                      [0.99980134, 0.999285  , 0.99970514, 0.99927545],
                      [0.99907184, 0.9993837 , 0.99917775, 0.99953413],
                      [0.9999124 , 0.99908406, 0.9995969 , 0.9991523 ]],            dtype=float32))
b: JaxArray(DeviceArray([[-5.8195234e-04, -4.3874790e-04, -3.3398748e-05],
                      [-5.7411409e-04, -7.0666044e-04, -9.4130711e-04],
                      [-7.1995187e-04, -1.1736620e-04, -9.5254736e-04]],            dtype=float32))
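SGD applies the plain rule p <- p - lr * g to each array in train_vars, in place. The NumPy sketch below is not BrainPy's implementation. It assumes a plain float learning rate and uses a hypothetical helper name, sgd_update. As a check, it takes one entry from the second update further down in this section (a[0, 0] = 0.99970317 with clipped gradient 0.31979424) and reproduces a value close to the printed 0.9993834.

```python
import numpy as np


def sgd_update(train_vars, grads, lr):
    """Vanilla SGD step: p <- p - lr * g for every parameter, applied in place."""
    for name, p in train_vars.items():
        # In-place subtraction, so any other reference to this array sees the new values.
        p -= lr * grads[name]


# One entry of the second update in this section: a[0, 0] before the update
# and its clipped gradient, both taken from the printed outputs.
params = {'a': np.array([0.99970317], dtype=np.float32)}
sgd_update(params, {'a': np.array([0.31979424], dtype=np.float32)}, lr=0.001)
print(params['a'])  # close to 0.9993834
```

Because the update is in place, the arrays in the dict are modified directly. This mirrors how a and b above change after op.update() without being reassigned.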

You can process the gradients before applying them. For example, here we clip the gradients by their L2 norm, rescaling each gradient array so that its norm is at most 1.0.

grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
{'a': JaxArray(DeviceArray([[0.99927866, 0.03028023, 0.8803668 , 0.64568734],
                       [0.64056313, 0.04791141, 0.7399359 , 0.87378347],
                       [0.96773326, 0.7771431 , 0.9618045 , 0.8374212 ],
                       [0.64901245, 0.24517596, 0.06224799, 0.6327405 ],
                       [0.31687486, 0.6385107 , 0.9160483 , 0.67039466]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.14722073, 0.52626574, 0.9817407 ],
                       [0.7333363 , 0.39472723, 0.82928896],
                       [0.7657701 , 0.93165004, 0.88332164]], dtype=float32))}
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
{'a': JaxArray(DeviceArray([[0.31979424, 0.00969043, 0.28173944, 0.20663615],
                       [0.20499626, 0.01533285, 0.23679803, 0.27963263],
                       [0.3096989 , 0.24870528, 0.30780157, 0.26799577],
                       [0.20770025, 0.07846245, 0.01992092, 0.20249282],
                       [0.1014079 , 0.20433943, 0.2931584 , 0.2145431 ]],            dtype=float32)),
 'b': JaxArray(DeviceArray([[0.0666547 , 0.23826863, 0.4444865 ],
                       [0.33202055, 0.17871413, 0.37546346],
                       [0.34670505, 0.4218078 , 0.39992693]], dtype=float32))}
op.update(grads_post)

print('a:', a)
print('b:', b)
a: JaxArray(DeviceArray([[0.9993834 , 0.99957764, 0.99888885, 0.9990923 ],
                      [0.9996601 , 0.9994259 , 0.9994429 , 0.9993059 ],
                      [0.99949163, 0.99903625, 0.99939734, 0.99900746],
                      [0.9988641 , 0.99930525, 0.99915785, 0.99933165],
                      [0.999811  , 0.99887973, 0.99930376, 0.9989378 ]],            dtype=float32))
b: JaxArray(DeviceArray([[-0.00064861, -0.00067702, -0.00047789],
                      [-0.00090613, -0.00088537, -0.00131677],
                      [-0.00106666, -0.00053917, -0.00135247]], dtype=float32))
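In the printed values, 'a' and 'b' are scaled by different factors (about 0.320 and 0.453). This suggests that bm.clip_by_norm rescales each array independently rather than using one global norm. The NumPy sketch below reproduces that behavior under an assumption: the usual rule g * max_norm / max(||g||, max_norm). The function is a stand-in, not BrainPy's code. Applied to the 'b' gradient shown above, it recovers the first entry of grads_post.

```python
import numpy as np


def clip_by_norm(grads, max_norm):
    """Rescale each gradient array independently so its L2 norm is at most max_norm.

    Arrays whose norm is already <= max_norm are returned unchanged.
    """
    clipped = {}
    for name, g in grads.items():
        norm = np.sqrt(np.sum(g * g))
        # max(norm, max_norm) makes the factor 1 for small gradients (and avoids 0/0).
        clipped[name] = g * (max_norm / max(norm, max_norm))
    return clipped


# The 'b' gradient from grads_pre above; its L2 norm is about 2.21.
g_b = np.array([[0.14722073, 0.52626574, 0.9817407],
                [0.7333363, 0.39472723, 0.82928896],
                [0.7657701, 0.93165004, 0.88332164]], dtype=np.float32)
out = clip_by_norm({'b': g_b}, 1.0)
print(out['b'][0, 0])            # close to 0.0666547, as in grads_post
print(np.linalg.norm(out['b']))  # close to 1.0
```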

Note

Optimizers usually have their own dynamically changing variables, such as a step counter or momentum buffers. If you JIT-compile a function that calls the optimizer's update(), you should add the variables in Optimizer.vars() to the dyn_vars of bm.jit().

op.vars()  # SGD only has a `step` variable, which counts the training steps
{'Constant0.step': Variable(DeviceArray([2], dtype=int32))}
bm.optimizers.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
{'Constant1.step': Variable(DeviceArray([0], dtype=int32)),
 'Momentum0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Momentum0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}
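The a_v and b_v entries are per-parameter velocity buffers. A standalone sketch of classical (heavy-ball) momentum shows how such a buffer is used. Several details are assumptions, not taken from BrainPy: the update form v <- mu * v - lr * g, p <- p + v; the class name MomentumSketch; and the default momentum=0.9. Only the '<name>_v' naming is chosen to mirror the printed keys.

```python
import numpy as np


class MomentumSketch:
    """Classical momentum: v <- mu * v - lr * g, then p <- p + v (all in place)."""

    def __init__(self, lr, train_vars, momentum=0.9):
        self.lr = lr                  # plain float in this sketch
        self.momentum = momentum
        self.train_vars = train_vars  # dict of name -> array
        # One zero-initialized velocity buffer per parameter, like 'a_v' / 'b_v' above.
        self.state = {k + '_v': np.zeros_like(p) for k, p in train_vars.items()}

    def update(self, grads):
        for k, p in self.train_vars.items():
            v = self.state[k + '_v']
            v *= self.momentum   # decay the previous velocity
            v -= self.lr * grads[k]
            p += v               # move the parameter along the velocity


w = np.zeros(2)
opt = MomentumSketch(lr=0.1, train_vars={'w': w})
for _ in range(3):
    opt.update({'w': np.ones(2)})
print(w)  # velocities -0.1, -0.19, -0.271 accumulate to -0.561
```

Because the buffers change at every update, they belong to the optimizer's state. For this reason they appear in vars().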
bm.optimizers.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
{'Constant2.step': Variable(DeviceArray([0], dtype=int32)),
 'Adam0.a_m': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_m': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32)),
 'Adam0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.]], dtype=float32)),
 'Adam0.b_v': Variable(DeviceArray([[0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]], dtype=float32))}
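Adam keeps two buffers per parameter: a running mean of the gradients (m) and a running mean of their squares (v). The sketch below follows the standard Adam algorithm with bias correction, written in NumPy. The class name and the defaults beta1=0.9, beta2=0.999, eps=1e-8 are the usual textbook choices and are assumptions here. BrainPy's own defaults may differ, and in BrainPy the step counter lives in the learning-rate scheduler ('Constant2.step') rather than on the optimizer.

```python
import numpy as np


class AdamSketch:
    """Adam with bias correction; m and v are kept per parameter."""

    def __init__(self, lr, train_vars, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.train_vars = train_vars
        self.m = {k: np.zeros_like(p) for k, p in train_vars.items()}  # first moment
        self.v = {k: np.zeros_like(p) for k, p in train_vars.items()}  # second moment
        self.step = 0

    def update(self, grads):
        self.step += 1
        for k, p in self.train_vars.items():
            g = grads[k]
            self.m[k] = self.beta1 * self.m[k] + (1 - self.beta1) * g
            self.v[k] = self.beta2 * self.v[k] + (1 - self.beta2) * g * g
            # Bias correction compensates for the zero initialization of m and v.
            m_hat = self.m[k] / (1 - self.beta1 ** self.step)
            v_hat = self.v[k] / (1 - self.beta2 ** self.step)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


w = np.zeros(3)
opt = AdamSketch(lr=0.001, train_vars={'w': w})
opt.update({'w': np.array([0.5, -2.0, 1e-3])})
print(w)  # each entry moves by about lr = 0.001 against its gradient's sign
```

On the first step m_hat equals g and v_hat equals g squared. Each parameter therefore moves by roughly lr, whatever the gradient's magnitude.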

Creating a custom optimizer

To create your own optimization algorithm, inherit from the bm.optimizers.Optimizer class and override the following methods:

  • __init__(): init function which receives learning rate (lr) and trainable variables (train_vars).

  • update(grads): the update function, which uses the gradients to update the trainable variables.

For example,

class CustomizeOp(bm.optimizers.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        
        # customize your initialization
        
    def update(self, grads):
        # customize your update logic
        pass
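As a concrete illustration of the two methods, the sketch below implements a hypothetical sign-SGD optimizer (p <- p - lr * sign(g)). It is a standalone NumPy class that follows the same interface, not a subclass of bm.optimizers.Optimizer. The class name and the algorithm choice are ours for illustration. In BrainPy you would call the base __init__ and update the JaxArray values instead.

```python
import numpy as np


class SignSGD:
    """Hypothetical custom optimizer: p <- p - lr * sign(g)."""

    def __init__(self, lr, train_vars):
        self.lr = lr                  # plain float here; BrainPy wraps lr in a Scheduler
        self.train_vars = train_vars  # dict of name -> array, updated in place

    def update(self, grads):
        for name, p in self.train_vars.items():
            # Only the sign of each gradient entry is used, not its magnitude.
            p -= self.lr * np.sign(grads[name])


w = np.array([1.0, 1.0, 1.0])
SignSGD(lr=0.1, train_vars={'w': w}).update({'w': np.array([3.0, -0.2, 0.0])})
print(w)  # [0.9 1.1 1. ]
```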

Schedulers

A scheduler adjusts the learning rate during training, usually reducing it according to a predefined schedule. Common learning rate schedules include time-based decay, step decay, and exponential decay.

For example, we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.

sc = bm.optimizers.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)
def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()
steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
[Figure: learning rate versus training step for the ExponentialDecay scheduler]
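The curve above can be reproduced with the common exponential-decay formula lr * decay_rate ** (step / decay_steps). We assume, without confirming it from the source, that BrainPy's ExponentialDecay follows this formula. The sketch below is a plain NumPy function (its name is ours) using the same hyperparameters as sc.

```python
import numpy as np


def exponential_decay(step, lr=0.1, decay_steps=2, decay_rate=0.99):
    """Smooth exponential decay: lr * decay_rate ** (step / decay_steps)."""
    return lr * decay_rate ** (np.asarray(step) / decay_steps)


steps = np.arange(1000)
rates = exponential_decay(steps)
print(rates[0], rates[2])  # the rate starts at 0.1 and is multiplied by 0.99 every 2 steps
```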

After an Optimizer is initialized, its learning rate self.lr is always an instance of bm.optimizers.Scheduler. Initializing it with a scalar float learning rate produces a Constant scheduler.

op.lr
<brainpy.math.jax.optimizers.Constant at 0x2ab375a3700>

You can get the current learning rate value by calling the scheduler, i.e. Scheduler.__call__(i=None):

  • If i is not provided, the learning rate value will be evaluated at the built-in training step.

  • Otherwise, the learning rate value will be evaluated at the given step i.

op.lr()
0.001

BrainPy provides several commonly used learning rate schedulers:

  • Constant

  • ExponentialDecay

  • InverseTimeDecay

  • PolynomialDecay

  • PiecewiseConstant
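A piecewise-constant schedule holds the learning rate fixed between step boundaries. The sketch below is standalone Python; the function and its argument names are hypothetical. It assumes len(values) == len(boundaries) + 1 and that the rate switches once the step reaches a boundary. Implementations differ on that edge case, so check BrainPy's PiecewiseConstant before relying on it.

```python
import bisect


def piecewise_constant(step, boundaries, values):
    """Return values[k], where k is the number of boundaries that `step` has reached."""
    assert len(values) == len(boundaries) + 1, "need one more value than boundaries"
    # bisect_right counts the boundaries <= step, which selects the active segment.
    return values[bisect.bisect_right(boundaries, step)]


schedule = dict(boundaries=[100, 200], values=[0.1, 0.01, 0.001])
print([piecewise_constant(s, **schedule) for s in (0, 99, 100, 250)])
```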

# InverseTimeDecay scheduler

rates = bm.optimizers.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)
[Figure: learning rate versus training step for the InverseTimeDecay scheduler]
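Inverse time decay usually divides the initial rate by a term that grows linearly in time: lr / (1 + decay_rate * step / decay_steps). This is an assumption about BrainPy's exact formula. The optional staircase flag in the sketch (a hypothetical parameter here) floors step / decay_steps, so the rate only changes every decay_steps steps.

```python
import numpy as np


def inverse_time_decay(step, lr=0.01, decay_steps=10, decay_rate=0.999, staircase=False):
    """lr / (1 + decay_rate * t), with t = step / decay_steps (floored if staircase)."""
    t = np.asarray(step) / decay_steps
    if staircase:
        t = np.floor(t)  # hold the rate constant within each block of decay_steps steps
    return lr / (1 + decay_rate * t)


print(inverse_time_decay(0), inverse_time_decay(10))  # 0.01, then 0.01 / 1.999
```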
# PolynomialDecay scheduler

rates = bm.optimizers.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
[Figure: learning rate versus training step for the PolynomialDecay scheduler]
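Polynomial decay moves the rate from lr to final_lr over decay_steps steps and then holds it at final_lr. With decay_steps=10, the plotted curve therefore drops almost immediately and stays flat for the remaining steps. The sketch assumes the common formula (lr - final_lr) * (1 - t) ** power + final_lr with t = min(step, decay_steps) / decay_steps, and a linear default power=1.0. Neither detail is confirmed by the source.

```python
import numpy as np


def polynomial_decay(step, lr=0.01, decay_steps=10, final_lr=0.0001, power=1.0):
    """Decay from lr to final_lr over decay_steps steps, then hold at final_lr."""
    # Fraction of the decay completed, capped at 1 once step >= decay_steps.
    t = np.minimum(np.asarray(step), decay_steps) / decay_steps
    return (lr - final_lr) * (1 - t) ** power + final_lr


print(polynomial_decay(5), polynomial_decay(500))  # halfway value, then final_lr
```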

Creating a custom scheduler

To implement your own scheduler, inherit from the bm.optimizers.Scheduler class and override the following methods:

  • __init__(): the init function.

  • __call__(i=None): evaluates the learning rate value, at the built-in training step if i is None, otherwise at step i.

class CustomizeScheduler(bm.optimizers.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)
        
        # customize your initialization
        
    def __call__(self, i=None):
        # compute and return the learning rate at step i
        pass
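For example, the step decay mentioned earlier could be written with the same __call__(i=None) convention. The sketch below is a standalone Python class, not a bm.optimizers.Scheduler subclass. The name StepDecay and its every and factor parameters are hypothetical. Its step attribute stands in for the built-in training step, which a training loop would advance.

```python
class StepDecay:
    """Hypothetical scheduler: multiply lr by `factor` every `every` steps."""

    def __init__(self, lr, every, factor=0.5):
        self.lr = lr
        self.every = every
        self.factor = factor
        self.step = 0  # stand-in for the built-in training step

    def __call__(self, i=None):
        # Use the internal step when i is not given, as described above.
        i = self.step if i is None else i
        return self.lr * self.factor ** (i // self.every)


sc = StepDecay(lr=0.1, every=100)
print(sc(), sc(250))  # 0.1 0.025
sc.step = 100
print(sc())           # 0.05
```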