# Optimizers

Gradient descent is one of the most popular optimization algorithms. Gradient-descent optimizers, combined with a loss function, are the key pieces that enable machine learning to work on your data. In this section, we are going to cover:

• how to use optimizers in BrainPy;

• how to customize your own optimizer.

```python
import brainpy as bp
import brainpy.math.jax as bm

bp.math.use_backend('jax')

import matplotlib.pyplot as plt
```


## Optimizers

The base optimizer class in BrainPy is brainpy.math.jax.optimizers.Optimizer (accessible here as bm.optimizers.Optimizer).


The following optimizers are provided in BrainPy:

• SGD

• Momentum

• Nesterov momentum

• RMSProp

• Adam

Users can also easily extend their own optimizers.
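To build intuition for what these optimizers do, here is a framework-agnostic NumPy sketch (not BrainPy code; the function names are illustrative only) of the update rules behind the two simplest ones, SGD and Momentum:

```python
import numpy as np

def sgd_update(param, grad, lr):
    # Plain SGD: step against the gradient, scaled by the learning rate.
    return param - lr * grad

def momentum_update(param, grad, velocity, lr, momentum=0.9):
    # Momentum: accumulate an exponentially decaying velocity, then step.
    velocity = momentum * velocity - lr * grad
    return param + velocity, velocity

param = np.array([1.0, 2.0])
grad = np.array([0.5, -0.5])
print(sgd_update(param, grad, lr=0.1))  # [0.95 2.05]
```

Momentum keeps one extra state array (the velocity) per trainable variable, which is why its vars() (shown later in this section) contains more entries than SGD's.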

Generally, an Optimizer is initialized with a learning rate lr, the trainable variables train_vars, and any hyperparameters specific to the optimizer.

• lr can be a float or an instance of bm.optimizers.Scheduler.

• train_vars should be a dict of JaxArray.

Here we create an SGD optimizer.

```python
a = bm.ones((5, 4))
b = bm.zeros((3, 3))

op = bm.optimizers.SGD(lr=0.001, train_vars={'a': a, 'b': b})
```


When updating the parameters, you must provide the corresponding gradient for each parameter to the update() method.

```python
op.update({'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)})

print('a:', a)
print('b:', b)
```

```
a: JaxArray(DeviceArray([[0.99970317, 0.99958736, 0.9991706 , 0.99929893],
                         [0.99986506, 0.9994412 , 0.9996797 , 0.9995855 ],
                         [0.99980134, 0.999285  , 0.99970514, 0.99927545],
                         [0.99907184, 0.9993837 , 0.99917775, 0.99953413],
                         [0.9999124 , 0.99908406, 0.9995969 , 0.9991523 ]], dtype=float32))
b: JaxArray(DeviceArray([[-5.8195234e-04, -4.3874790e-04, -3.3398748e-05],
                         [-5.7411409e-04, -7.0666044e-04, -9.4130711e-04],
                         [-7.1995187e-04, -1.1736620e-04, -9.5254736e-04]], dtype=float32))
```


You can process gradients before applying them. For example, we clip the gradients so that each one's L2-norm does not exceed a maximum value.

```python
grads_pre = {'a': bm.random.random(a.shape), 'b': bm.random.random(b.shape)}

grads_pre
```

```
{'a': JaxArray(DeviceArray([[0.99927866, 0.03028023, 0.8803668 , 0.64568734],
                            [0.64056313, 0.04791141, 0.7399359 , 0.87378347],
                            [0.96773326, 0.7771431 , 0.9618045 , 0.8374212 ],
                            [0.64901245, 0.24517596, 0.06224799, 0.6327405 ],
                            [0.31687486, 0.6385107 , 0.9160483 , 0.67039466]], dtype=float32)),
 'b': JaxArray(DeviceArray([[0.14722073, 0.52626574, 0.9817407 ],
                            [0.7333363 , 0.39472723, 0.82928896],
                            [0.7657701 , 0.93165004, 0.88332164]], dtype=float32))}
```

```python
grads_post = bm.clip_by_norm(grads_pre, 1.)

grads_post
```

```
{'a': JaxArray(DeviceArray([[0.31979424, 0.00969043, 0.28173944, 0.20663615],
                            [0.20499626, 0.01533285, 0.23679803, 0.27963263],
                            [0.3096989 , 0.24870528, 0.30780157, 0.26799577],
                            [0.20770025, 0.07846245, 0.01992092, 0.20249282],
                            [0.1014079 , 0.20433943, 0.2931584 , 0.2145431 ]], dtype=float32)),
 'b': JaxArray(DeviceArray([[0.0666547 , 0.23826863, 0.4444865 ],
                            [0.33202055, 0.17871413, 0.37546346],
                            [0.34670505, 0.4218078 , 0.39992693]], dtype=float32))}
```
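As a rough illustration of what norm clipping does, here is a framework-agnostic NumPy sketch that scales each gradient array so its L2-norm does not exceed the given maximum (arrays already within the limit are left untouched; the exact semantics of bm.clip_by_norm may differ in detail):

```python
import numpy as np

def clip_by_norm(grads, max_norm):
    # Rescale each gradient array whose L2-norm exceeds max_norm;
    # each array is clipped independently of the others.
    clipped = {}
    for name, g in grads.items():
        norm = np.linalg.norm(g)
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        clipped[name] = g * scale
    return clipped

grads = {'a': np.full((5, 4), 3.0)}           # L2-norm = 3 * sqrt(20) > 1
out = clip_by_norm(grads, 1.0)
print(np.linalg.norm(out['a']))               # ≈ 1.0 after clipping
```

This matches the output above: 'a' and 'b' are rescaled by different factors, since each array's norm is clipped on its own.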

```python
op.update(grads_post)

print('a:', a)
print('b:', b)
```

```
a: JaxArray(DeviceArray([[0.9993834 , 0.99957764, 0.99888885, 0.9990923 ],
                         [0.9996601 , 0.9994259 , 0.9994429 , 0.9993059 ],
                         [0.99949163, 0.99903625, 0.99939734, 0.99900746],
                         [0.9988641 , 0.99930525, 0.99915785, 0.99933165],
                         [0.999811  , 0.99887973, 0.99930376, 0.9989378 ]], dtype=float32))
b: JaxArray(DeviceArray([[-0.00064861, -0.00067702, -0.00047789],
                         [-0.00090613, -0.00088537, -0.00131677],
                         [-0.00106666, -0.00053917, -0.00135247]], dtype=float32))
```


Note

Optimizers usually have their own dynamically changing variables. If you JIT-compile a function whose logic contains an optimizer update, you should add the variables in Optimizer.vars() to the dyn_vars argument of bm.jit().

```python
op.vars()  # the SGD optimizer only has a step variable that records the training step
```

```
{'Constant0.step': Variable(DeviceArray([2], dtype=int32))}
```

```python
bm.optimizers.Momentum(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Momentum has velocity variables
```

```
{'Constant1.step': Variable(DeviceArray([0], dtype=int32)),
 'Momentum0.a_v': Variable(DeviceArray([[0., 0., 0., 0.],
                                        [0., 0., 0., 0.],
                                        [0., 0., 0., 0.],
                                        [0., 0., 0., 0.],
                                        [0., 0., 0., 0.]], dtype=float32)),
 'Momentum0.b_v': Variable(DeviceArray([[0., 0., 0.],
                                        [0., 0., 0.],
                                        [0., 0., 0.]], dtype=float32))}
```

```python
bm.optimizers.Adam(lr=0.001, train_vars={'a': a, 'b': b}).vars()  # Adam has more variables
```

```
{'Constant2.step': Variable(DeviceArray([0], dtype=int32)),
 ...}
```

(Adam additionally tracks zero-initialized moment variables for each of 'a' and 'b'; the full output is elided here.)


## Creating a custom optimizer

If you intend to create your own optimization algorithm, simply inherit from the bm.optimizers.Optimizer class and override the following methods:

• __init__(): the init function, which receives the learning rate (lr) and the trainable variables (train_vars).

• update(grads): the update function, which computes the updated parameters.

For example,

```python
class CustomizeOp(bm.optimizers.Optimizer):
    def __init__(self, lr, train_vars, *params, **other_params):
        super(CustomizeOp, self).__init__(lr, train_vars)
        # initialize any optimizer-specific variables here

    def update(self, grads):
        # compute and apply the parameter updates here
        pass
```
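To make the interface concrete, here is a toy, framework-agnostic sketch (plain NumPy, not a real bm.optimizers.Optimizer subclass; the ScaledSGD class and its scale hyperparameter are made up for illustration) of an optimizer with the same __init__/update shape:

```python
import numpy as np

class ScaledSGD:
    """Toy optimizer mirroring the BrainPy interface: it stores the
    trainable variables at construction and updates them in place."""

    def __init__(self, lr, train_vars, scale=1.0):
        self.lr = lr
        self.train_vars = train_vars  # dict of name -> np.ndarray
        self.scale = scale            # extra hyperparameter of this toy optimizer

    def update(self, grads):
        # Subtract the (scaled) gradient from each registered variable in place.
        for name, grad in grads.items():
            self.train_vars[name] -= self.lr * self.scale * grad

params = {'w': np.ones(3)}
op = ScaledSGD(lr=0.1, train_vars=params, scale=2.0)
op.update({'w': np.array([1.0, 0.0, -1.0])})
print(params['w'])  # [0.8 1.  1.2]
```

A real subclass would register any extra state (like Momentum's velocities) as BrainPy variables so that vars() exposes them for JIT compilation.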


## Schedulers

Scheduler seeks to adjust the learning rate during training by reducing the learning rate according to a pre-defined schedule. Common learning rate schedules include time-based decay, step decay and exponential decay.

For example, we set up an exponential decay scheduler, in which the learning rate decays exponentially with the training step.

```python
sc = bm.optimizers.ExponentialDecay(lr=0.1, decay_steps=2, decay_rate=0.99)

def show(steps, rates):
    plt.plot(steps, rates)
    plt.xlabel('Train Step')
    plt.ylabel('Learning Rate')
    plt.show()

steps = bm.arange(1000)
rates = sc(steps)

show(steps, rates)
```
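The math behind such schedules can be sketched in pure Python. The exponential decay below follows the conventional formula lr * decay_rate ** (step / decay_steps) (assumed to match BrainPy's form; check the implementation for exact details), and step decay, mentioned above, drops the rate by a fixed factor at regular intervals:

```python
def exponential_decay(lr0, step, decay_steps, decay_rate):
    # lr(step) = lr0 * decay_rate ** (step / decay_steps)
    return lr0 * decay_rate ** (step / decay_steps)

def step_decay(lr0, step, drop_every, drop_rate):
    # Multiply the rate by drop_rate once every `drop_every` steps.
    return lr0 * drop_rate ** (step // drop_every)

# With the parameters used above, the rate shrinks by 1% every 2 steps:
print(exponential_decay(0.1, step=2, decay_steps=2, decay_rate=0.99))  # ≈ 0.099
```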


After Optimizer initialization, the learning rate self.lr will always be an instance of bm.optimizers.Scheduler. Initializing with a scalar float learning rate results in a Constant scheduler.

```python
op.lr
```

```
<brainpy.math.jax.optimizers.Constant at 0x2ab375a3700>
```


One can get the current learning rate value by calling Scheduler.__call__(i=None).

• If i is not provided, the learning rate value will be evaluated at the built-in training step.

• Otherwise, the learning rate value will be evaluated at the given step i.

```python
op.lr()
```

```
0.001
```


In BrainPy, several commonly used learning rate schedulers are provided:

• Constant

• ExponentialDecay

• InverseTimeDecay

• PolynomialDecay

• PiecewiseConstant

```python
# InverseTimeDecay scheduler
rates = bm.optimizers.InverseTimeDecay(lr=0.01, decay_steps=10, decay_rate=0.999)(steps)
show(steps, rates)

# PolynomialDecay scheduler
rates = bm.optimizers.PolynomialDecay(lr=0.01, decay_steps=10, final_lr=0.0001)(steps)
show(steps, rates)
```
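For reference, the conventional formulas behind these two schedules can be sketched in pure Python (these follow the common definitions used by most frameworks; BrainPy's implementation may differ in detail):

```python
def inverse_time_decay(lr0, step, decay_steps, decay_rate):
    # lr(step) = lr0 / (1 + decay_rate * step / decay_steps)
    return lr0 / (1.0 + decay_rate * step / decay_steps)

def polynomial_decay(lr0, step, decay_steps, final_lr, power=1.0):
    # Interpolate from lr0 down to final_lr over decay_steps, then hold final_lr.
    step = min(step, decay_steps)
    return (lr0 - final_lr) * (1.0 - step / decay_steps) ** power + final_lr
```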


## Creating a custom scheduler

To implement your own scheduler, simply inherit from the bm.optimizers.Scheduler class and override the following methods:

• __init__(): the init function.

• __call__(i=None): the learning rate value evaluation.

```python
class CustomizeScheduler(bm.optimizers.Scheduler):
    def __init__(self, lr, *params, **other_params):
        super(CustomizeScheduler, self).__init__(lr)

    def __call__(self, i=None):
        # return the learning rate value at step i (or at the built-in step)
        pass
```
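As a concrete illustration of the __call__(i=None) contract, here is a toy warmup scheduler in plain Python (not a real Scheduler subclass; the LinearWarmup class and its step attribute, standing in for the scheduler's built-in step variable, are made up for illustration):

```python
class LinearWarmup:
    """Toy scheduler sketch: ramp the learning rate linearly from 0 to
    the base rate over `warmup_steps`, then hold it constant."""

    def __init__(self, lr, warmup_steps):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.step = 0  # stands in for the Scheduler's built-in step counter

    def __call__(self, i=None):
        # Evaluate at the given step i, or at the built-in step if i is None.
        step = self.step if i is None else i
        return self.lr * min(1.0, step / self.warmup_steps)

sc = LinearWarmup(lr=0.1, warmup_steps=100)
print(sc(50))   # 0.05
print(sc(200))  # 0.1
```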