# Source code for torch_optimizer.adafactor

import math
from typing import Any, Dict, Tuple

import torch
from torch.optim.optimizer import Optimizer

from .types import OptFloat, OptLossClosure, Params, State

# Pair of floats; annotates the ``eps2`` argument of ``Adafactor.__init__``
# (regularization constants for the squared gradient and the parameter scale).
Eps2 = Tuple[float, float]
# One optimizer parameter group: option name -> value. Annotates the
# ``param_group`` arguments of the ``Adafactor`` helper methods.
ParamGroup = Dict[str, Any]

class Adafactor(Optimizer):
    """Implements Adafactor algorithm.

    It has been proposed in: `Adafactor: Adaptive Learning Rates with
    Sublinear Memory Cost`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: external learning rate (default: None)
        eps2: regularization constants for square gradient
            and parameter scale respectively (default: (1e-30, 1e-3))
        clip_threshold: threshold of root mean square of
            final gradient update (default: 1.0)
        decay_rate: coefficient used to compute running averages of square
            gradient (default: -0.8)
        beta1: coefficient used for computing running averages of gradient
            (default: None)
        weight_decay: weight decay (L2 penalty) (default: 0)
        scale_parameter: if true, learning rate is scaled by root mean square
            of parameter (default: True)
        relative_step: if true, time-dependent learning rate is computed
            instead of external learning rate (default: True)
        warmup_init: time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.Adafactor(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ https://arxiv.org/abs/1804.04235

    Note:
        Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py  # noqa
    """

    def __init__(
        self,
        params: Params,
        lr: OptFloat = None,
        eps2: Eps2 = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta1: OptFloat = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        """Validate the hyper-parameters and register them as group defaults.

        Raises:
            ValueError: if ``lr`` is given and not positive, or if
                ``weight_decay`` is negative.
        """
        if lr is not None and lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if weight_decay < 0.0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )

        defaults = dict(
            lr=lr,
            eps2=eps2,
            clip_threshold=clip_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step=relative_step,
            warmup_init=warmup_init,
        )
        super().__init__(params, defaults)

    def _get_lr(self, param_group: ParamGroup, param_state: State) -> float:
        """Return the effective step size for one parameter.

        With ``relative_step`` the external ``lr`` is replaced by
        ``min(min_step, 1 / sqrt(step))`` where ``min_step`` is
        ``1e-6 * step`` under ``warmup_init`` and ``1e-2`` otherwise.
        With ``scale_parameter`` the result is multiplied by
        ``max(eps2[1], RMS)``; since ``state['RMS']`` is produced by
        ``_rms`` the returned value may then be a 0-dim tensor.
        """
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            step = param_state['step']
            min_step = 1e-6 * step if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(step))

        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps2'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    def _get_options(
        self, param_group: ParamGroup, param_shape: Tuple[int, ...]
    ) -> Tuple[bool, bool]:
        """Return ``(factored, use_first_moment)`` for a parameter.

        The second moment is factored for shapes with at least two
        dimensions; the first moment is tracked only when ``beta1`` is set.
        """
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    def _rms(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the root mean square of ``tensor`` as a 0-dim tensor."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(
        self,
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Write the factored inverse-RMS estimate into ``output``.

        ``output`` receives the outer product of
        ``rsqrt(row / mean(row))`` and ``rsqrt(col)`` over the last two
        dimensions. The inputs are not modified.
        """
        # keepdim=True keeps the reduced axis so the division broadcasts per
        # row-vector; without it, inputs with leading (batch) dimensions are
        # divided along the wrong axis or fail to broadcast at all.
        row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
        r_factor = (exp_avg_sq_row / row_mean).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adafactor does not support sparse gradients.'
                    )

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(
                    group, grad_shape
                )

                # State Initialization
                if len(state) == 0:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        # Row/column statistics share grad's dtype and
                        # device; new_zeros allocates them there directly.
                        state['exp_avg_sq_row'] = grad.new_zeros(
                            grad_shape[:-1]
                        )
                        state['exp_avg_sq_col'] = grad.new_zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        )
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0

                state['step'] += 1
                # RMS of the parameter itself; consumed by _get_lr when
                # scale_parameter is enabled.
                state['RMS'] = self._rms(p.data)
                lr = self._get_lr(group, state)

                # Step-dependent decay: equals 0 on the first step, so the
                # running averages start from the first squared gradient.
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad ** 2) + group['eps2'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t
                    )
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t
                    )

                    # Approximation of exponential moving average of square
                    # of gradient
                    self._approx_sq_grad(
                        exp_avg_sq_row, exp_avg_sq_col, update
                    )
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    torch.rsqrt(exp_avg_sq, out=update).mul_(grad)

                # Clip: scale the update down when its RMS exceeds
                # clip_threshold.
                update.div_(
                    max(1.0, self._rms(update) / group['clip_threshold'])
                )
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(
                        update, alpha=1 - group['beta1']
                    )
                    update = exp_avg

                if group['weight_decay'] != 0:
                    # Shrinks the weights by weight_decay * lr before the
                    # gradient update is applied.
                    p.data.add_(p.data, alpha=-group['weight_decay'] * lr)

                p.data.add_(-update)

        return loss