Source code for torch_optimizer.shampoo

import torch
from torch.optim.optimizer import Optimizer

from .types import OptFloat, OptLossClosure, Params


def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor:
    # raise the (symmetric PSD) matrix to the given power via its SVD;
    # the SVD is run on CPU for speed
    device = matrix.device
    matrix = matrix.cpu()
    u, s, v = torch.svd(matrix)
    return (u @ s.pow_(power).diag() @ v.t()).to(device)
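# Illustration only (not part of the library source): for the symmetric PSD
# preconditioners built in ``Shampoo.step``, ``_matrix_power`` returns the
# inverse ``order``-th root that is multiplied into the gradient, e.g.
#
#     precond = 4.0 * torch.eye(3)
#     root = _matrix_power(precond, -0.5)
#     assert torch.allclose(root, 0.5 * torch.eye(3), atol=1e-5)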


class Shampoo(Optimizer):
    r"""Implements Shampoo Optimizer Algorithm.

    It has been proposed in `Shampoo: Preconditioned Stochastic Tensor
    Optimization`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-1)
        momentum: momentum factor (default: 0)
        weight_decay: weight decay (L2 penalty) (default: 0)
        epsilon: epsilon added to each mat_gbar_j for numerical stability
            (default: 1e-4)
        update_freq: update frequency to compute inverse (default: 1)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.Shampoo(model.parameters(), lr=0.01)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ https://arxiv.org/abs/1802.09568

    Note:
        Reference code: https://github.com/moskomule/shampoo.pytorch
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-1,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        epsilon: float = 1e-4,
        update_freq: int = 1,
    ):
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if momentum < 0.0:
            raise ValueError('Invalid momentum value: {}'.format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )
        if epsilon < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(epsilon))
        if update_freq < 1:
            raise ValueError(
                'Invalid update_freq value: {}'.format(update_freq)
            )

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            epsilon=epsilon,
            update_freq=update_freq,
        )
        super(Shampoo, self).__init__(params, defaults)
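    # Illustration of the per-parameter state (not part of the original
    # source): for a 2-D gradient ``G`` of shape ``(m, n)``, ``step`` keeps
    #   state['precond_0']: (m, m), accumulated from ``G @ G.t()``
    #   state['precond_1']: (n, n), accumulated analogously from the
    #                       matricization along the second dimension
    # plus the matching ``inv_precond_*`` buffers holding the inverse
    # ``order``-th roots (here order = 2) computed by ``_matrix_power``.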
    def step(self, closure: OptLossClosure = None) -> OptFloat:
        """Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                order = grad.ndimension()
                original_size = grad.size()
                state = self.state[p]
                momentum = group['momentum']
                weight_decay = group['weight_decay']
                if len(state) == 0:
                    state['step'] = 0
                    if momentum > 0:
                        state['momentum_buffer'] = grad.clone()
                    for dim_id, dim in enumerate(grad.size()):
                        # precondition matrices: one per tensor dimension,
                        # initialised to epsilon * identity
                        state['precond_{}'.format(dim_id)] = group[
                            'epsilon'
                        ] * torch.eye(dim, out=grad.new(dim, dim))
                        state[
                            'inv_precond_{dim_id}'.format(dim_id=dim_id)
                        ] = grad.new(dim, dim).zero_()

                if momentum > 0:
                    grad.mul_(1 - momentum).add_(
                        state['momentum_buffer'], alpha=momentum
                    )

                if weight_decay > 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                # See Algorithm 2 for detail: matricize the gradient along
                # each dimension, accumulate the corresponding preconditioner
                # and multiply in its inverse `order`-th root.
                for dim_id, dim in enumerate(grad.size()):
                    precond = state['precond_{}'.format(dim_id)]
                    inv_precond = state['inv_precond_{}'.format(dim_id)]

                    # mat_{dim_id}(grad)
                    grad = grad.transpose_(0, dim_id).contiguous()
                    transposed_size = grad.size()
                    grad = grad.view(dim, -1)

                    grad_t = grad.t()
                    precond.add_(grad @ grad_t)
                    if state['step'] % group['update_freq'] == 0:
                        inv_precond.copy_(_matrix_power(precond, -1 / order))

                    if dim_id == order - 1:
                        # finally
                        grad = grad_t @ inv_precond
                        # grad: (-1, last_dim)
                        grad = grad.view(original_size)
                    else:
                        # if not final
                        grad = inv_precond @ grad
                        # grad (dim, -1)
                        grad = grad.view(transposed_size)

                state['step'] += 1
                state['momentum_buffer'] = grad
                p.data.add_(grad, alpha=-group['lr'])

        return loss
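A minimal end-to-end usage sketch (illustration only, not part of this module; the toy model, data, and hyperparameter values below are assumptions):

import torch
import torch_optimizer as optim

# Hypothetical toy model and synthetic data, for illustration only.
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
x = torch.randn(32, 10)
y = torch.randn(32, 1)

optimizer = optim.Shampoo(
    model.parameters(),
    lr=0.01,            # step size
    momentum=0.9,       # gradient momentum
    weight_decay=1e-4,  # L2 penalty
    update_freq=1,      # recompute inverse roots every step
)

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()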