Source code for torch_optimizer.aggmo

from typing import List, Tuple, Type, TypeVar, Union

import torch
from torch.optim.optimizer import Optimizer

from .types import OptFloat, OptLossClosure, Params

__all__ = ('AggMo',)


T = TypeVar('T', bound='AggMo')


[docs]class AggMo(Optimizer): r"""Implements Aggregated Momentum Gradient Descent. It has been proposed in `Aggregated Momentum: Stability Through Passive Damping`__ Example: >>> import torch_optimizer as optim >>> optimizer = optim.AggMo(model.parameters(), lr=0.1) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() __ https://arxiv.org/abs/1804.00325 Note: Reference code: https://github.com/AtheMathmo/AggMo/blob/master/aggmo.py # noqa """ def __init__( self, params: Params, lr: float = 1e-3, betas: Union[List[float], Tuple[float, ...]] = (0.0, 0.9, 0.99), weight_decay: float = 0, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) for i, beta in enumerate(betas): if not 0.0 <= beta < 1.0: msg = 'Invalid beta parameter at index 1: {}'.format(betas[i]) raise ValueError(msg) if weight_decay < 0.0: raise ValueError( 'Invalid weight_decay value: {}'.format(weight_decay) ) defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) super(AggMo, self).__init__(params, defaults) @classmethod def from_exp_form( cls: Type[T], params: Params, lr: float = 1e-3, a: float = 0.1, k: int = 3, weight_decay: float = 0, ) -> T: if lr <= 0.0: raise ValueError('Invalid parameter k: {}'.format(k)) betas = [1 - a ** i for i in range(k)] # type: List[float] return cls(params, lr, betas, weight_decay)
[docs] def step(self, closure: OptLossClosure = None) -> OptFloat: r"""Performs a single optimization step. Arguments: closure: A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: weight_decay = group['weight_decay'] betas = group['betas'] total_mom = float(len(betas)) for p in group['params']: if p.grad is None: continue d_p = p.grad.data if weight_decay != 0: d_p.add_(p.data, alpha=weight_decay) param_state = self.state[p] if 'momentum_buffer' not in param_state: param_state['momentum_buffer'] = {} for beta in betas: param_state['momentum_buffer'][ beta ] = torch.zeros_like(p.data) for beta in betas: buf = param_state['momentum_buffer'][beta] buf.mul_(beta).add_(d_p) p.data.sub_(buf, alpha=group['lr'] / total_mom) return loss