brainpy.optim.CosineAnnealingWarmRestarts
- class brainpy.optim.CosineAnnealingWarmRestarts(lr, num_call_per_epoch, T_0, T_mult=1, eta_min=0.0, last_epoch=-1, last_call=-1)
- Set the learning rate of each parameter group using a cosine annealing
schedule, where \(\eta_{max}\) is set to the initial lr, \(T_{cur}\) is the number of epochs since the last restart and \(T_{i}\) is the number of epochs between two warm restarts in SGDR:
\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)\]
When \(T_{cur}=T_{i}\), set \(\eta_t = \eta_{min}\). When \(T_{cur}=0\) after restart, set \(\eta_t=\eta_{max}\).
This schedule was proposed in SGDR: Stochastic Gradient Descent with Warm Restarts.
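The formula above can be checked with a few lines of plain Python. This is a minimal sketch that evaluates only the closed-form expression for given \(T_{cur}\) and \(T_{i}\); the helper name cosine_annealing_lr is hypothetical and is not part of BrainPy. It confirms the two boundary conditions: \(\eta_t=\eta_{max}\) at \(T_{cur}=0\) and \(\eta_t=\eta_{min}\) at \(T_{cur}=T_{i}\).

```python
import math


def cosine_annealing_lr(t_cur: float, t_i: float,
                        eta_max: float, eta_min: float = 0.0) -> float:
    """Hypothetical helper: SGDR cosine annealing for one cycle.

    t_cur: epochs elapsed since the last restart (T_cur).
    t_i:   length of the current cycle in epochs (T_i).
    """
    # eta_t = eta_min + 1/2 (eta_max - eta_min) (1 + cos(pi * T_cur / T_i))
    return eta_min + 0.5 * (eta_max - eta_min) * (1.0 + math.cos(math.pi * t_cur / t_i))


# Start of a cycle (T_cur = 0) gives eta_max; end of a cycle (T_cur = T_i) gives eta_min.
print(cosine_annealing_lr(0, 10, eta_max=0.1, eta_min=0.001))
print(cosine_annealing_lr(10, 10, eta_max=0.1, eta_min=0.001))
```

Halfway through a cycle (\(T_{cur}=T_{i}/2\)) the cosine term vanishes, so the rate sits at the midpoint between \(\eta_{max}\) and \(\eta_{min}\).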
Parameters
- lr: float
Initial learning rate.
- num_call_per_epoch: int
The number of times the scheduler is called in each epoch. This is usually the number of batches in one training epoch.
- T_0: int
Number of iterations for the first restart.
- T_mult: int
The factor by which \(T_{i}\) is multiplied after each restart. Default: 1.
- eta_min: float
Minimum learning rate. Default: 0.
- last_epoch: int
The index of the last epoch. Default: -1.
- last_call: int
The index of the last call. Default: -1.
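To show how the parameters interact, here is a pure-Python sketch of the learning rate as a function of the scheduler call index. It assumes that a call index maps to epoch floor(call / num_call_per_epoch), that the first cycle lasts T_0 epochs, and that each later cycle is T_mult times longer than the previous one. The function sgdr_lr is hypothetical: it mirrors the constructor arguments but is not BrainPy's implementation, which may compute \(T_{cur}\) and \(T_{i}\) differently (for example in closed form with JAX arrays).

```python
import math


def sgdr_lr(call: int, lr: float, num_call_per_epoch: int, T_0: int,
            T_mult: int = 1, eta_min: float = 0.0) -> float:
    """Hypothetical sketch: SGDR learning rate at a 0-based scheduler call index."""
    if T_0 <= 0 or T_mult < 1:
        raise ValueError("expected T_0 > 0 and T_mult >= 1")
    # Assumption: every num_call_per_epoch calls make up one epoch.
    epoch = call // num_call_per_epoch
    if T_mult == 1:
        # All cycles have the same length T_0.
        t_i = T_0
        t_cur = epoch % T_0
    else:
        # Cycle k lasts T_0 * T_mult**k epochs; skip whole cycles until
        # the remaining epoch count falls inside the current one.
        t_i = T_0
        t_cur = epoch
        while t_cur >= t_i:
            t_cur -= t_i
            t_i *= T_mult
    # Same cosine formula as above, with eta_max = lr.
    return eta_min + 0.5 * (lr - eta_min) * (1.0 + math.cos(math.pi * t_cur / t_i))


# One call per epoch, T_0=2, T_mult=2: cycles last 2, 4, 8, ... epochs,
# so the rate jumps back to lr at epochs 0, 2, 6, 14.
restarts = [e for e in range(20) if sgdr_lr(e, 0.1, 1, T_0=2, T_mult=2) == 0.1]
print(restarts)  # [0, 2, 6, 14]

# Ten calls per epoch: call 25 falls in epoch 2.
print(sgdr_lr(25, 0.1, num_call_per_epoch=10, T_0=4))
```

With T_mult=1 every cycle has the same length, so the restarts are evenly spaced every T_0 epochs; larger T_mult values spend progressively longer at low learning rates between restarts.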
Methods
- __init__(lr, num_call_per_epoch, T_0[, ...])
- cpu(): Move all variables into the CPU device.
- cuda(): Move all variables into the GPU device.
- current_epoch([i])
- load_state_dict(state_dict[, warn, compatible]): Copy parameters and buffers from state_dict into this module and its descendants.
- load_states(filename[, verbose]): Load the model states.
- nodes([method, level, include_self]): Collect all children nodes.
- register_implicit_nodes(*nodes[, node_cls])
- register_implicit_vars(*variables[, var_cls])
- save_states(filename[, variables]): Save the model states.
- state_dict(): Return a dictionary containing the whole state of the module.
- step_call()
- step_epoch()
- to(device): Move all variables into the given device.
- tpu(): Move all variables into the TPU device.
- train_vars([method, level, include_self]): The shortcut for retrieving all trainable variables.
- tree_flatten(): Flatten the object as a PyTree.
- tree_unflatten(aux, dynamic_values): Unflatten the data to construct an object of this class.
- unique_name([name, type_]): Get the unique name for this object.
- vars([method, level, include_self, ...]): Collect all variables in this node and the children nodes.
Attributes
- name: Name of the model.