# Source code for torch_optimizer.accsgd

import copy

from torch.optim.optimizer import Optimizer

from .types import OptFloat, OptLossClosure, Params

__all__ = ('AccSGD',)


class AccSGD(Optimizer):
    r"""Implements the AccSGD algorithm.

    It has been proposed in `On the insufficiency of existing momentum
    schemes for Stochastic Optimization`__ and `Accelerating Stochastic
    Gradient Descent For Least Squares Regression`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        kappa: ratio of the long step to the short step (default: 1000)
        xi: statistical advantage parameter (default: 10)
        small_const: small constant factor; any value <= 1 (default: 0.7)
        weight_decay: weight decay (L2 penalty) (default: 0)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.AccSGD(model.parameters(), lr=0.1)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ https://arxiv.org/abs/1704.08227
    __ https://arxiv.org/abs/1803.05591

    Note:
        Reference code: https://github.com/rahulkidambi/AccSGD
    """
    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        kappa: float = 1000.0,
        xi: float = 10.0,
        small_const: float = 0.7,
        weight_decay: float = 0.0,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )
        defaults = dict(
            lr=lr,
            kappa=kappa,
            xi=xi,
            small_const=small_const,
            weight_decay=weight_decay,
        )
        super(AccSGD, self).__init__(params, defaults)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
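        # For each parameter group, derive the constants used by the update
        # (mirroring the reference implementation linked above):
        # ``large_lr`` is the long step lr * kappa / small_const, while
        # alpha, beta and zeta control how strongly the short-step iterate
        # is coupled with the slowly moving averaged iterate.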
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            large_lr = (group['lr'] * group['kappa']) / group['small_const']
            alpha = 1.0 - (
                (group['small_const'] * group['small_const'] * group['xi'])
                / group['kappa']
            )
            beta = 1.0 - alpha
            zeta = group['small_const'] / (group['small_const'] + beta)
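            # Per-parameter update: add L2 weight decay to the gradient,
            # lazily initialise the momentum buffer with a copy of the
            # parameter, then update buffer and parameter in place.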
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = copy.deepcopy(p.data)
                buf = param_state['momentum_buffer']
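                # In-place buffer update, equivalent to
                # buf <- (1 - beta) * buf + beta * (p - large_lr * d_p):
                # the averaged iterate takes a long-step gradient move.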
                buf.mul_((1.0 / beta) - 1.0)
                buf.add_(d_p, alpha=-large_lr)
                buf.add_(p.data)
                buf.mul_(beta)
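                # Parameter update, equivalent to
                # p <- zeta * (p - lr * d_p) + (1 - zeta) * buf,
                # mixing the short gradient step with the averaged iterate.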
                p.data.add_(d_p, alpha=-group['lr'])
                p.data.mul_(zeta)
                p.data.add_(buf, alpha=1.0 - zeta)

        return loss
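

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it fits a toy
    # least-squares problem with AccSGD, mirroring the docstring example.
    # The data sizes, seed, and hyperparameters below are illustrative
    # assumptions, not values taken from the papers.
    import torch

    torch.manual_seed(0)
    x = torch.randn(128, 10)
    true_w = torch.randn(10, 1)
    y = x @ true_w + 0.01 * torch.randn(128, 1)

    model = torch.nn.Linear(10, 1, bias=False)
    loss_fn = torch.nn.MSELoss()
    optimizer = AccSGD(model.parameters(), lr=0.1)

    for _ in range(200):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    # The loss should drop well below its initial value on this problem.
    print('final loss: {:.6f}'.format(loss.item()))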