# -*- coding: utf-8 -*-
import warnings
from functools import partial
from typing import Sequence, Union
import jax
import jax.numpy as jnp
import brainpy.math as bm
from brainpy import check
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy.errors import MathError
# learning rate schedules #
# ----------------------- #
def make_schedule(scalar_or_schedule):
if isinstance(scalar_or_schedule, Scheduler):
return scalar_or_schedule
elif isinstance(scalar_or_schedule, (int, float, bm.Variable)):
return Constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule))
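# A minimal usage sketch of the dispatch above (illustrative values):
# ``make_schedule(0.01)`` wraps the scalar in a ``Constant`` scheduler, while
# ``make_schedule(StepLR(0.1, step_size=10))`` returns the scheduler unchanged.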
class Scheduler(BrainPyObject):
"""The learning rate scheduler."""
def __init__(self, lr: Union[float, bm.Variable], last_epoch: int = -1):
super(Scheduler, self).__init__()
assert bm.ndim(lr) == 0
self.lr = lr
check.is_integer(last_epoch, allow_none=False, min_bound=-1)
self.last_epoch = bm.Variable(jnp.asarray(last_epoch))
def set_value(self, learning_rate):
self.lr = learning_rate
def step_epoch(self):
self.last_epoch += 1
def step_call(self):
pass
def __repr__(self):
return f'{self.__class__.__name__}(lr={self.lr}, last_epoch={self.last_epoch.value})'
def __call__(self, i=None):
raise NotImplementedError
class Constant(Scheduler):
def __call__(self, i=None):
return self.lr
class CallBasedScheduler(Scheduler):
def __init__(self, lr: Union[float, bm.Variable], last_epoch: int = -1, last_call: int = -1):
super().__init__(lr=lr, last_epoch=last_epoch)
check.is_integer(last_call, allow_none=False, min_bound=-1)
self.last_call = bm.Variable(jnp.asarray(last_call))
def step_call(self):
self.last_call += 1
def __repr__(self):
return f'{self.__class__.__name__}(lr={self.lr}, last_call={self.last_call.value})'
class StepLR(Scheduler):
"""Decays the learning rate of each parameter group by gamma every
`step_size` epochs.
Parameters
----------
lr: float
Initial learning rate.
step_size: int
Period of learning rate decay.
gamma: float
Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch: int
The index of last epoch. Default: -1.
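Examples
--------
A minimal usage sketch with illustrative values::

    sch = StepLR(lr=0.1, step_size=10, gamma=0.5)
    sch(5)    # epochs 0-9 keep the initial rate: 0.1
    sch(15)   # decayed once by gamma: 0.05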
"""
def __init__(
self,
lr: float,
step_size: int,
gamma: float = 0.1,
last_epoch: int = -1
):
super().__init__(lr=lr, last_epoch=last_epoch)
self.step_size = check.is_integer(step_size, min_bound=1, allow_none=False)
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1., allow_int=False)
def __call__(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
def __repr__(self):
return (f'{self.__class__.__name__}(lr={self.lr}, '
f'step_size={self.step_size}, gamma={self.gamma}, '
f'last_epoch={self.last_epoch})')
class MultiStepLR(Scheduler):
"""Decays the learning rate of each parameter group by gamma once the
number of epochs reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the learning rate from outside
this scheduler. When last_epoch=-1, sets initial lr as lr.
Parameters
----------
lr: float
Initial learning rate.
milestones: sequence of int
List of epoch indices. Must be increasing.
gamma: float
Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch: int
The index of last epoch. Default: -1.
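Examples
--------
A minimal usage sketch with illustrative values::

    sch = MultiStepLR(lr=0.1, milestones=[30, 80], gamma=0.1)
    sch(10)   # before the first milestone: 0.1
    sch(50)   # past epoch 30: 0.01
    sch(90)   # past epoch 80: 0.001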
"""
def __init__(
self,
lr: float,
milestones: Sequence[int],
gamma: float = 0.1,
last_epoch: int = -1
):
super().__init__(lr=lr, last_epoch=last_epoch)
self.milestones = check.is_sequence(milestones, elem_type=int, allow_none=False)
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1., allow_int=False)
@bm.cls_jit(inline=True)
def __call__(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
p = bm.ifelse([i < m for m in self.milestones],
list(range(0, len(self.milestones))) + [len(self.milestones)])
return self.lr * self.gamma ** p
def __repr__(self):
return (f'{self.__class__.__name__}(lr={self.lr}, '
f'milestones={self.milestones}, gamma={self.gamma}, '
f'last_epoch={self.last_epoch})')
class CosineAnnealingLR(Scheduler):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math::
\begin{aligned}
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max}; \\
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
\end{aligned}
When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
is defined recursively, the learning rate can be simultaneously modified
outside this scheduler by other operators. If the learning rate is set
solely by this scheduler, the learning rate at each step becomes:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
\cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
It has been proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
implements the cosine annealing part of SGDR, and not the restarts.
Parameters
----------
lr: float
Initial learning rate.
T_max: int
Maximum number of iterations.
eta_min: float
Minimum learning rate. Default: 0.
last_epoch: int
The index of last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
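Examples
--------
A minimal usage sketch with illustrative values::

    sch = CosineAnnealingLR(lr=0.1, T_max=100, eta_min=0.)
    sch(0)     # start of the cycle: 0.1
    sch(50)    # halfway through: 0.05
    sch(100)   # end of the cycle: eta_min (here 0.)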
"""
def __init__(self,
lr: float,
T_max: int,
eta_min: float = 0.,
last_epoch: int = -1, ):
super().__init__(lr=lr, last_epoch=last_epoch)
self._init_epoch = last_epoch
self.T_max = check.is_integer(T_max, min_bound=1)
self.eta_min = eta_min
@bm.cls_jit(inline=True)
def __call__(self, i=None):
i = (self.last_epoch + 1) if i is None else i
return (self.eta_min + (self.lr - self.eta_min) *
(1 + jnp.cos(jnp.pi * i / self.T_max)) / 2)
class CosineAnnealingWarmRestarts(CallBasedScheduler):
"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
is the number of epochs since the last restart and :math:`T_{i}` is the number
of epochs between two warm restarts in SGDR:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
It has been proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_.
Parameters
----------
lr: float
Initial learning rate.
num_call_per_epoch: int
The number of times the scheduler is called within each epoch.
This usually equals the number of batches in each training epoch.
T_0: int
Number of iterations for the first restart.
T_mult: int
A factor by which :math:`T_{i}` increases after a restart. Default: 1.
eta_min: float
Minimum learning rate. Default: 0.
last_call: int
The index of last call. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
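Examples
--------
A minimal usage sketch with illustrative values; the scheduler is called
once per batch, and ``num_call_per_epoch`` converts call counts into epochs::

    sch = CosineAnnealingWarmRestarts(lr=0.1, num_call_per_epoch=100,
                                      T_0=10, T_mult=2)
    sch(0)           # first call of the first cycle: 0.1
    sch(500)         # epoch 5, halfway through the first 10-epoch cycle: 0.05
    sch.step_call()  # advance the internal call counter after each batch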
"""
def __init__(self,
lr: float,
num_call_per_epoch: int,
T_0: int,
T_mult: int = 1,
eta_min: float = 0.,
last_epoch: int = -1,
last_call: int = -1):
super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
self.T_mult = T_mult
self.eta_min = eta_min
self.T_0 = T_0
self.num_call_per_epoch = num_call_per_epoch
def _cond1(self, epoch):
if self.T_mult == 1:
T_cur = epoch % self.T_0
T_i = self.T_0
else:
n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
T_i = self.T_0 * self.T_mult ** n
return T_cur, T_i
def _cond2(self, epoch):
return epoch, self.T_0
@bm.cls_jit(inline=True)
def __call__(self, i=None):
i = (self.last_call + 1) if i is None else i
epoch = i / self.num_call_per_epoch
T_cur, T_i = jax.lax.cond(epoch >= self.T_0,
self._cond1,
self._cond2,
epoch)
return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
@bm.cls_jit(inline=True)
def current_epoch(self, i=None):
i = (self.last_call + 1) if i is None else i
return jnp.floor(i / self.num_call_per_epoch)
class ExponentialLR(Scheduler):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr.
Parameters
----------
lr: float
Initial learning rate.
gamma: float
Multiplicative factor of learning rate decay.
last_epoch: int
The index of last epoch. Default: -1.
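Examples
--------
A minimal usage sketch with illustrative values::

    sch = ExponentialLR(lr=0.1, gamma=0.9)
    sch(0)    # initial rate: 0.1
    sch(10)   # 0.1 * 0.9 ** 10, roughly 0.035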
"""
def __init__(self,
lr: float,
gamma: float,
last_epoch: int = -1):
super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
self.gamma = check.is_float(gamma, min_bound=0., max_bound=1.)
def __call__(self, i: int = None):
i = (self.last_epoch + 1) if i is None else i
return self.lr * self.gamma ** i
def __repr__(self):
return f'{self.__class__.__name__}(lr={self.lr}, last_epoch={self.last_epoch}, gamma={self.gamma})'
class ExponentialDecayLR(CallBasedScheduler):
def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
self.decay_steps = decay_steps
self.decay_rate = decay_rate
def __call__(self, i=None):
i = (self.last_call.value + 1) if i is None else i
return self.lr * self.decay_rate ** (i / self.decay_steps)
def __repr__(self):
return (f'{self.__class__.__name__}({self.lr}, '
f'decay_steps={self.decay_steps}, '
f'decay_rate={self.decay_rate}, '
f'last_call={self.last_call.value})')
class ExponentialDecay(ExponentialDecayLR):
def __init__(self, *args, **kwargs):
super(ExponentialDecay, self).__init__(*args, **kwargs)
warnings.warn("ExponentialDecay is abandoned, please use ExponentialDecayLR insteadly.")
class InverseTimeDecayLR(ExponentialDecayLR):
def __init__(self, lr, decay_steps, decay_rate, staircase=False,
last_epoch: int = -1, last_call: int = -1):
super(InverseTimeDecayLR, self).__init__(lr, decay_steps, decay_rate,
last_epoch=last_epoch,
last_call=last_call)
self.staircase = staircase
def __call__(self, i=None):
i = (self.last_call.value + 1) if i is None else i
if self.staircase:
return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
else:
return self.lr / (1 + self.decay_rate * i / self.decay_steps)
def __repr__(self):
return f'{self.__class__.__name__}({self.lr}, staircase={self.staircase})'
class InverseTimeDecay(InverseTimeDecayLR):
def __init__(self, *args, **kwargs):
super(InverseTimeDecay, self).__init__(*args, **kwargs)
warnings.warn("InverseTimeDecay is abandoned, please use InverseTimeDecayLR insteadly.")
class PolynomialDecayLR(CallBasedScheduler):
def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
self.decay_steps = decay_steps
self.final_lr = final_lr
self.power = power
def __call__(self, i=None):
i = (self.last_call.value + 1) if i is None else i
i = jnp.minimum(i, self.decay_steps)
step_mult = (1 - i / self.decay_steps) ** self.power
return step_mult * (self.lr - self.final_lr) + self.final_lr
def __repr__(self):
return (f'{self.__class__.__name__}({self.lr}, '
f'last_call={self.last_call.value}, '
f'decay_steps={self.decay_steps}, '
f'final_lr={self.final_lr}, '
f'power={self.power})')
class PolynomialDecay(PolynomialDecayLR):
def __init__(self, *args, **kwargs):
super(PolynomialDecay, self).__init__(*args, **kwargs)
warnings.warn("PolynomialDecay is abandoned, please use PolynomialDecayLR insteadly.")
class PiecewiseConstantLR(CallBasedScheduler):
def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
super(PiecewiseConstantLR, self).__init__(0., last_epoch=last_epoch, last_call=last_call)
boundaries = jnp.array(boundaries)
values = jnp.array(values)
if not boundaries.ndim == values.ndim == 1:
raise MathError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise MathError("boundaries length must be one shorter than values length")
self.boundaries = boundaries
self.values = values
def __call__(self, i=None):
i = (self.last_call.value + 1) if i is None else i
return self.values[jnp.sum(i > self.boundaries)]
class PiecewiseConstant(PiecewiseConstantLR):
def __init__(self, *args, **kwargs):
super(PiecewiseConstant, self).__init__(*args, **kwargs)
warnings.warn("PiecewiseConstant is abandoned, please use PiecewiseConstantLR insteadly.")