Welcome to pytorch-optimizer’s documentation!¶
torch-optimizer – a collection of optimizers for PyTorch.
Simple example¶
import torch_optimizer as optim
# model = ...
optimizer = optim.DiffGrad(model.parameters(), lr=0.001)
optimizer.step()
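The snippet above only shows construction and a single step() call; in a real training loop you also zero the gradients and backpropagate a loss before each step. A minimal sketch, assuming a toy model, loss and random batch created here purely for illustration:

import torch
import torch.nn as nn
import torch_optimizer as optim

model = nn.Linear(10, 1)                         # stand-in model; replace with your own
loss_fn = nn.MSELoss()
optimizer = optim.DiffGrad(model.parameters(), lr=0.001)

x, y = torch.randn(32, 10), torch.randn(32, 1)   # dummy batch
optimizer.zero_grad()                            # clear gradients from the previous step
loss = loss_fn(model(x), y)                      # forward pass and loss
loss.backward()                                  # backpropagate
optimizer.step()                                 # update parameters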
Supported Optimizers¶
PID – https://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf
Ranger
RangerQH
RangerVA
Yogi – https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization
Contents¶
Available Optimizers¶
AccSGD¶
class torch_optimizer.AccSGD(params, lr=0.001, kappa=1000.0, xi=10.0, small_const=0.7, weight_decay=0)[source]¶
Implements AccSGD algorithm.
It has been proposed in On the insufficiency of existing momentum schemes for Stochastic Optimization and Accelerating Stochastic Gradient Descent For Least Squares Regression
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
kappa (float) – ratio of long to short step (default: 1000)
xi (float) – statistical advantage parameter (default: 10)
small_const (float) – any value <= 1 (default: 0.7)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.AccSGD(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/rahulkidambi/AccSGD
AdaBound¶
class torch_optimizer.AdaBound(params, lr=0.001, betas=(0.9, 0.999), final_lr=0.1, gamma=0.001, eps=1e-08, weight_decay=0, amsbound=False)[source]¶
Implements AdaBound algorithm.
It has been proposed in Adaptive Gradient Methods with Dynamic Bound of Learning Rate.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
final_lr (float) – final (SGD) learning rate (default: 0.1)
gamma (float) – convergence speed of the bound functions (default: 1e-3)
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
amsbound (bool) – whether to use the AMSBound variant of this algorithm
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.AdaBound(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/Luolc/AdaBound
AdaMod¶
class torch_optimizer.AdaMod(params, lr=0.001, betas=(0.9, 0.999), beta3=0.999, eps=1e-08, weight_decay=0)[source]¶
Implements AdaMod algorithm.
It has been proposed in Adaptive and Momental Bounds for Adaptive Learning Rate Methods.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
beta3 (float) – smoothing coefficient for adaptive learning rates (default: 0.999)
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.AdaMod(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/lancopku/AdaMod
Adafactor¶
class torch_optimizer.Adafactor(params, lr=None, eps2=(1e-30, 0.001), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False)[source]¶
Implements Adafactor algorithm.
It has been proposed in: Adafactor: Adaptive Learning Rates with Sublinear Memory Cost.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (Optional[float]) – external learning rate (default: None)
eps2 (Tuple[float, float]) – regularization constants for square gradient and parameter scale respectively (default: (1e-30, 1e-3))
clip_threshold (float) – threshold of root mean square of final gradient update (default: 1.0)
decay_rate (float) – coefficient used to compute running averages of square gradient (default: -0.8)
beta1 (Optional[float]) – coefficient used for computing running averages of gradient (default: None)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
scale_parameter (bool) – if true, learning rate is scaled by root mean square of parameter (default: True)
relative_step (bool) – if true, a time-dependent learning rate is computed instead of the external learning rate (default: True)
warmup_init (bool) – time-dependent learning rate computation depends on whether warm-up initialization is being used (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.Adafactor(model.parameters())
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
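Because relative_step=True by default, Adafactor derives a time-dependent learning rate internally, which is why lr is left as None in the example above. If you prefer to supply the learning rate yourself, a sketch of that usage follows; the specific values are illustrative, not recommendations:

>>> import torch_optimizer as optim
>>> optimizer = optim.Adafactor(
...     model.parameters(),
...     lr=1e-3,                # external learning rate
...     relative_step=False,    # disable the internal time-dependent schedule
...     scale_parameter=False,  # do not rescale the step by parameter RMS
... )
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()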
AdamP¶
class torch_optimizer.AdamP(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False)[source]¶
Implements AdamP algorithm.
It has been proposed in Slowing Down the Weight Norm Increase in Momentum-based Optimizers
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
delta (float) – threshold that determines whether a set of parameters is scale invariant or not (default: 0.1)
wd_ratio (float) – relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters (default: 0.1)
nesterov (bool) – enables Nesterov momentum (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.AdamP(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/clovaai/AdamP
AggMo¶
class torch_optimizer.AggMo(params, lr=0.001, betas=(0.0, 0.9, 0.99), weight_decay=0)[source]¶
Implements Aggregated Momentum Gradient Descent.
It has been proposed in Aggregated Momentum: Stability Through Passive Damping
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.AggMo(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/AtheMathmo/AggMo/blob/master/aggmo.py
DiffGrad¶
class torch_optimizer.DiffGrad(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)[source]¶
Implements DiffGrad algorithm.
It has been proposed in DiffGrad: An Optimization Method for Convolutional Neural Networks.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.DiffGrad(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/shivram1987/diffGrad
Lamb¶
class torch_optimizer.Lamb(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0, clamp_value=10, adam=False, debias=False)[source]¶
Implements Lamb algorithm.
It has been proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-6)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
clamp_value (float) – clamp weight_norm to (0, clamp_value); set it to a high value (e.g. 10e3) to effectively disable clamping (default: 10)
adam (bool) – always use trust ratio = 1, which turns this into Adam; useful for comparison purposes (default: False)
debias (bool) – debias adam by (1 - beta**step) (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.Lamb(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/cybertronai/pytorch-lamb
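The adam flag fixes the trust ratio to 1, turning the update into plain Adam, which makes it easy to check whether the layer-wise trust ratio is actually helping on your task. A sketch of such a comparison, with illustrative hyperparameters:

>>> import torch_optimizer as optim
>>> # Run 1: LAMB with the layer-wise trust ratio (default behaviour)
>>> optimizer = optim.Lamb(model.parameters(), lr=0.1, weight_decay=0.01)
>>> # Run 2: same settings, but trust ratio forced to 1 (Adam-like baseline)
>>> optimizer = optim.Lamb(model.parameters(), lr=0.1, weight_decay=0.01, adam=True)

Train one run with each configuration and compare the results rather than stepping both optimizers in the same loop.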
NovoGrad¶
class torch_optimizer.NovoGrad(params, lr=0.001, betas=(0.95, 0), eps=1e-08, weight_decay=0, grad_averaging=False, amsgrad=False)[source]¶
Implements Novograd optimization algorithm.
It has been proposed in Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.95, 0))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
grad_averaging (bool) – gradient averaging (default: False)
amsgrad (bool) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.NovoGrad(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
>>> optimizer.step()
>>> scheduler.step()
Note
Reference code: https://github.com/NVIDIA/DeepLearningExamples
PID¶
class torch_optimizer.PID(params, lr=0.001, momentum=0.0, dampening=0, weight_decay=0.0, integral=5.0, derivative=10.0)[source]¶
Implements PID optimization algorithm.
It has been proposed in A PID Controller Approach for Stochastic Optimization of Deep Networks.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
momentum (float) – momentum factor (default: 0.0)
weight_decay (float) – weight decay (L2 penalty) (default: 0.0)
dampening (float) – dampening for momentum (default: 0.0)
derivative (float) – D part of the PID (default: 10.0)
integral (float) – I part of the PID (default: 5.0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.PID(model.parameters(), lr=0.001, momentum=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/tensorboy/PIDOptimizer
QHAdam¶
class torch_optimizer.QHAdam(params, lr=0.001, betas=(0.9, 0.999), nus=(1.0, 1.0), weight_decay=0.0, decouple_weight_decay=False, eps=1e-08)[source]¶
Implements the QHAdam optimization algorithm.
It has been proposed in Quasi-hyperbolic momentum and Adam for deep learning.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
nus (Tuple[float, float]) – immediate discount factors used to estimate the gradient and its square (default: (1.0, 1.0))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
decouple_weight_decay (bool) – whether to decouple the weight decay from the gradient-based optimization step (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.QHAdam(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/facebookresearch/qhoptim
QHM¶
class torch_optimizer.QHM(params, lr=0.001, momentum=0.0, nu=0.7, weight_decay=0.0, weight_decay_type='grad')[source]¶
DIRECT = 'direct'¶
Implements quasi-hyperbolic momentum (QHM) optimization algorithm.
It has been proposed in Quasi-hyperbolic momentum and Adam for deep learning.
Parameters
params – iterable of parameters to optimize or dicts defining parameter groups
lr – learning rate (default: 1e-3)
momentum – momentum factor (\(\beta\) from the paper)
nu – immediate discount factor (\(\nu\) from the paper)
weight_decay – weight decay (L2 regularization coefficient, times two) (default: 0.0)
weight_decay_type – method of applying the weight decay: "grad" for accumulation in the gradient (same as torch.optim.SGD) or "direct" for direct application to the parameters (default: "grad")
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.QHM(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/facebookresearch/qhoptim
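The example above uses the default weight_decay_type='grad'. To apply the decay directly to the parameters instead of folding it into the gradient, pass 'direct'; the coefficient values below are illustrative:

>>> import torch_optimizer as optim
>>> optimizer = optim.QHM(
...     model.parameters(),
...     lr=0.1,
...     momentum=0.9,
...     nu=0.7,
...     weight_decay=1e-4,
...     weight_decay_type='direct',
... )
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()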
RAdam¶
class torch_optimizer.RAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)[source]¶
Implements RAdam optimization algorithm.
It has been proposed in On the Variance of the Adaptive Learning Rate and Beyond.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.RAdam(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/LiyuanLucasLiu/RAdam
SGDP¶
class torch_optimizer.SGDP(params, lr=0.001, momentum=0, dampening=0, eps=1e-08, weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False)[source]¶
Implements SGDP algorithm.
It has been proposed in Slowing Down the Weight Norm Increase in Momentum-based Optimizers
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
momentum (float) – momentum factor (default: 0)
dampening (float) – dampening for momentum (default: 0)
eps (float) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
delta (float) – threshold that determines whether a set of parameters is scale invariant or not (default: 0.1)
wd_ratio (float) – relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters (default: 0.1)
nesterov (bool) – enables Nesterov momentum (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.SGDP(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/clovaai/AdamP
SGDW¶
class torch_optimizer.SGDW(params, lr=0.001, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False)[source]¶
Implements SGDW algorithm.
It has been proposed in Decoupled Weight Decay Regularization.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
momentum (float) – momentum factor (default: 0)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
dampening (float) – dampening for momentum (default: 0)
nesterov (bool) – enables Nesterov momentum (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.SGDW(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/pytorch/pytorch/pull/22466
Shampoo¶
class torch_optimizer.Shampoo(params, lr=0.1, momentum=0.0, weight_decay=0.0, epsilon=0.0001, update_freq=1)[source]¶
Implements Shampoo Optimizer Algorithm.
It has been proposed in Shampoo: Preconditioned Stochastic Tensor Optimization.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-1)
momentum (float) – momentum factor (default: 0)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
epsilon (float) – epsilon added to each mat_gbar_j for numerical stability (default: 1e-4)
update_freq (int) – update frequency to compute inverse (default: 1)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.Shampoo(model.parameters(), lr=0.01)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/moskomule/shampoo.pytorch
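Shampoo preconditions each tensor dimension with a matrix root, which is expensive to recompute for large layers; update_freq controls how often that inverse is refreshed. A sketch that recomputes it only every 50 steps (the value 50 is illustrative, not a recommendation):

>>> import torch_optimizer as optim
>>> optimizer = optim.Shampoo(model.parameters(), lr=0.01, update_freq=50)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()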
SWATS¶
class torch_optimizer.SWATS(params, lr=0.001, betas=(0.9, 0.999), eps=0.001, weight_decay=0, amsgrad=False, nesterov=False)[source]¶
Implements SWATS Optimizer Algorithm. It has been proposed in Improving Generalization Performance by Switching from Adam to SGD.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-3)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-3)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
amsgrad (bool) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False)
nesterov (bool) – enables Nesterov momentum (default: False)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.SWATS(model.parameters(), lr=0.01)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/Mrpatekful/swats
Yogi¶
class torch_optimizer.Yogi(params, lr=0.01, betas=(0.9, 0.999), eps=0.001, initial_accumulator=1e-06, weight_decay=0)[source]¶
Implements Yogi Optimizer Algorithm. It has been proposed in Adaptive methods for Nonconvex Optimization.
Parameters
params (Union[Iterable[Tensor], Iterable[Dict[str, Any]]]) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate (default: 1e-2)
betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps (float) – term added to the denominator to improve numerical stability (default: 1e-3)
initial_accumulator (float) – initial values for first and second moments (default: 1e-6)
weight_decay (float) – weight decay (L2 penalty) (default: 0)
Example
>>> import torch_optimizer as optim
>>> optimizer = optim.Yogi(model.parameters(), lr=0.01)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
Reference code: https://github.com/4rtemi5/Yogi-Optimizer_Keras
Examples of pytorch-optimizer usage¶
Below is a list of examples from pytorch-optimizer/examples.
Every example is a small, self-contained Python program.
Basic Usage¶
A simple example that shows how to use the library with the MNIST dataset.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
import torch_optimizer as optim
from torchvision import datasets, transforms, utils
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(conf, model, device, train_loader, optimizer, epoch, writer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % conf.log_interval == 0:
loss = loss.item()
idx = batch_idx + epoch * (len(train_loader))
writer.add_scalar('Loss/train', loss, idx)
print(
'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss,
)
)
def test(conf, model, device, test_loader, epoch, writer):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
fmt = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
print(
fmt.format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)
writer.add_scalar('Accuracy', correct, epoch)
writer.add_scalar('Loss/test', test_loss, epoch)
def prepare_loaders(conf, use_cuda=False):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=True,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=conf.batch_size,
shuffle=True,
**kwargs,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'../data',
train=False,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=conf.test_batch_size,
shuffle=True,
**kwargs,
)
return train_loader, test_loader
class Config:
def __init__(
self,
batch_size: int = 64,
test_batch_size: int = 1000,
epochs: int = 15,
lr: float = 0.01,
gamma: float = 0.7,
no_cuda: bool = True,
seed: int = 42,
log_interval: int = 10,
):
self.batch_size = batch_size
self.test_batch_size = test_batch_size
self.epochs = epochs
self.lr = lr
self.gamma = gamma
self.no_cuda = no_cuda
self.seed = seed
self.log_interval = log_interval
def main():
conf = Config()
log_dir = 'runs/mnist_custom_optim'
print('Tensorboard: tensorboard --logdir={}'.format(log_dir))
with SummaryWriter(log_dir) as writer:
use_cuda = not conf.no_cuda and torch.cuda.is_available()
torch.manual_seed(conf.seed)
device = torch.device('cuda' if use_cuda else 'cpu')
train_loader, test_loader = prepare_loaders(conf, use_cuda)
model = Net().to(device)
# create grid of images and write to tensorboard
images, labels = next(iter(train_loader))
img_grid = utils.make_grid(images)
writer.add_image('mnist_images', img_grid)
# visualize NN computation graph
writer.add_graph(model, images)
# custom optimizer from torch_optimizer package
optimizer = optim.DiffGrad(model.parameters(), lr=conf.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=conf.gamma)
for epoch in range(1, conf.epochs + 1):
train(conf, model, device, train_loader, optimizer, epoch, writer)
test(conf, model, device, test_loader, epoch, writer)
scheduler.step()
for name, param in model.named_parameters():
writer.add_histogram(name, param, epoch)
writer.add_histogram('{}.grad'.format(name), param.grad, epoch)
if __name__ == '__main__':
main()
Contributing¶
Running Tests¶
Thanks for your interest in contributing to pytorch-optimizer; there are multiple ways and places you can contribute.
First of all, clone the repository:
$ git clone git@github.com:jettify/pytorch-optimizer.git
Create a virtualenv with Python 3.5 or newer (older versions are not supported). For example, using virtualenvwrapper the commands could look like:
$ cd pytorch-optimizer
$ mkvirtualenv --python=`which python3.7` pytorch-optimizer
After that, please install the libraries required for development:
$ pip install -r requirements-dev.txt
$ pip install -e .
Congratulations, you are ready to run the test suite:
$ make cov
To run an individual test, use the following command:
$ py.test -sv tests/test_basic.py -k test_name
Reporting an Issue¶
If you have found an issue with pytorch-optimizer, please do not hesitate to file an issue on the GitHub project. When filing your issue, please make sure you can express it with a reproducible test case.
When reporting an issue, we also need as much information about your environment as you can include. We never know what information will be pertinent when trying to narrow down the issue. Please include at least the following information:
Version of pytorch-optimizer and Python.
Version of PyTorch, if installed.
Version of CUDA, if installed.
Platform you’re running on (OS X, Linux).