Source code for torch_optimizer.aggmo
from typing import List, Tuple, Type, TypeVar, Union
import torch
from torch.optim.optimizer import Optimizer
from .types import OptFloat, OptLossClosure, Params
__all__ = ('AggMo',)
T = TypeVar('T', bound='AggMo')
class AggMo(Optimizer):
r"""Implements Aggregated Momentum Gradient Descent.
It has been proposed in `Aggregated Momentum: Stability Through Passive
Damping`__
Example:
>>> import torch_optimizer as optim
>>> optimizer = optim.AggMo(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ https://arxiv.org/abs/1804.00325
Note:
Reference code: https://github.com/AtheMathmo/AggMo/blob/master/aggmo.py # noqa
"""
def __init__(
self,
params: Params,
lr: float = 1e-3,
betas: Union[List[float], Tuple[float, ...]] = (0.0, 0.9, 0.99),
weight_decay: float = 0,
) -> None:
if lr <= 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
for i, beta in enumerate(betas):
if not 0.0 <= beta < 1.0:
                msg = 'Invalid beta parameter at index {}: {}'.format(i, beta)
raise ValueError(msg)
if weight_decay < 0.0:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay)
)
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super(AggMo, self).__init__(params, defaults)
@classmethod
def from_exp_form(
cls: Type[T],
params: Params,
lr: float = 1e-3,
a: float = 0.1,
k: int = 3,
weight_decay: float = 0,
) -> T:
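        r"""Alternative constructor: builds ``betas`` from the exponential
        form ``beta_i = 1 - a ** i`` for ``i = 0, ..., k - 1``.

        Arguments:
            params: iterable of parameters to optimize or dicts defining
                parameter groups
            lr: learning rate (default: 1e-3)
            a: base of the exponential form (default: 0.1)
            k: number of momentum coefficients to generate (default: 3)
            weight_decay: weight decay (L2 penalty) (default: 0)
        """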
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if k < 1:
            raise ValueError('Invalid number of betas k: {}'.format(k))
betas = [1 - a ** i for i in range(k)] # type: List[float]
return cls(params, lr, betas, weight_decay)
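
    # For example (a usage sketch, assuming a ``model`` as in the class
    # docstring): with the defaults a=0.1 and k=3,
    #     AggMo.from_exp_form(model.parameters(), lr=0.1)
    # builds betas = [1 - 0.1 ** 0, 1 - 0.1 ** 1, 1 - 0.1 ** 2]
    #              = [0.0, 0.9, 0.99],
    # i.e. the same values as the plain constructor's default ``betas``.
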
    def step(self, closure: OptLossClosure = None) -> OptFloat:
r"""Performs a single optimization step.
Arguments:
closure: A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
betas = group['betas']
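            # Number of aggregated momentum buffers (one per beta); each
            # parameter step is averaged over them.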
total_mom = float(len(betas))
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
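                # Fold L2 weight decay into the gradient: d_p += weight_decay * p.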
if weight_decay != 0:
d_p.add_(p.data, alpha=weight_decay)
param_state = self.state[p]
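                # Lazily create one zero-initialized momentum buffer per beta.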
if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = {}
for beta in betas:
param_state['momentum_buffer'][
beta
] = torch.zeros_like(p.data)
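                # Update every buffer, buf <- beta * buf + d_p, and apply the
                # averaged step p <- p - (lr / K) * buf for each one; summed
                # over all buffers this is p <- p - (lr / K) * sum_i buf_i.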
for beta in betas:
buf = param_state['momentum_buffer'][beta]
buf.mul_(beta).add_(d_p)
p.data.sub_(buf, alpha=group['lr'] / total_mom)
return loss