torch.optim.SGD
class torch.optim.SGD(params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
功能:
可實現SGD優化算法,帶動量SGD優化算法,帶NAG(Nesterov accelerated gradient)動量SGD優化算法,並且均可擁有weight_decay項
參數:
params(iterable)- 參數組(參數組的概念請查看 3.2 優化器基類:Optimizer),優化器要管理的那部分參數。
lr(float)- 初始學習率,可按需隨着訓練過程不斷調整學習率。
momentum(float)- 動量,通常設置為0.9,0.8
dampening(float)- dampening for momentum ,暫時不了其功能,在源碼中是這樣用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening必須為 0.
weight_decay(float)- 權值衰減系數,也就是L2正則項的系數
nesterov(bool)- bool選項,是否使用NAG(Nesterov accelerated gradient)
注意事項:
pytroch中使用SGD十分需要注意的是,更新公式與其他框架略有不同!
pytorch中是這樣的:
v=ρ∗v+g
p=p−lr∗v = p - lr∗ρ∗v - lr∗g
其他框架:
v=ρ∗v+lr∗g
p=p−v = p - ρ∗v - lr∗g
ρ是動量,v是速率,g是梯度,p是參數,其實差別就是在ρ∗v這一項,pytorch中將此項也乘了一個學習率。
torch.optim.ASGD
class torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)
功能:
ASGD也成為SAG,均表示隨機平均梯度下降(Averaged Stochastic Gradient Descent),簡單地說ASGD就是用空間換時間的一種SGD,詳細可參看論文:riejohnson.com/rie/stog
參數:
params(iterable)- 參數組(參數組的概念請查看 3.1 優化器基類:Optimizer),優化器要優化的那些參數。
lr(float)- 初始學習率,可按需隨着訓練過程不斷調整學習率。
lambd(float)- 衰減項,默認值1e-4。
alpha(float)- power for eta update ,默認值0.75。
t0(float)- point at which to start averaging,默認值1e6。
weight_decay(float)- 權值衰減系數,也就是L2正則項的系數。
torch.optim.Rprop
class torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))
功能:
實現Rprop優化方法(彈性反向傳播),優化方法原文《Martin Riedmiller und Heinrich Braun: Rprop - A Fast Adaptive Learning Algorithm. Proceedings of the International Symposium on Computer and Information Science VII, 1992》
該優化方法適用於full-batch,不適用於mini-batch,因而在min-batch大行其道的時代里,很少見到。
torch.optim.Adagrad
class torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
功能:
實現Adagrad優化方法(Adaptive Gradient),Adagrad是一種自適應優化方法,是自適應的為各個參數分配不同的學習率。這個學習率的變化,會受到梯度的大小和迭代次數的影響。梯度越大,學習率越小;梯度越小,學習率越大。缺點是訓練后期,學習率過小,因為Adagrad累加之前所有的梯度平方作為分母。
詳細公式請閱讀:Adaptive Subgradient Methods for Online Learning and Stochastic Optimization
torch.optim.Adadelta
class torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
功能:
實現Adadelta優化方法。Adadelta是Adagrad的改進。Adadelta分母中采用距離當前時間點比較近的累計項,這可以避免在訓練后期,學習率過小。
詳細公式請閱讀:arxiv.org/pdf/1212.5701
torch.optim.RMSprop
class torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
功能:
實現RMSprop優化方法(Hinton提出),RMS是均方根(root meam square)的意思。RMSprop和Adadelta一樣,也是對Adagrad的一種改進。RMSprop采用均方根作為分母,可緩解Adagrad學習率下降較快的問題。並且引入均方根,可以減少擺動,詳細了解可讀:cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
torch.optim.Adam(AMSGrad)
class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
功能:
實現Adam(Adaptive Moment Estimation))優化方法。Adam是一種自適應學習率的優化方法,Adam利用梯度的一階矩估計和二階矩估計動態的調整學習率。吳老師課上說過,Adam是結合了Momentum和RMSprop,並進行了偏差修正。
功能:
amsgrad- 是否采用AMSGrad優化方法,asmgrad優化方法是針對Adam的改進,通過添加額外的約束,使學習率始終為正值。(AMSGrad,ICLR-2018 Best-Pper之一,《On the convergence of Adam and Beyond》)。
詳細了解Adam可閱讀,Adam: A Method for Stochastic Optimization(Adam: A Method for Stochastic Optimization)。
torch.optim.Adamax
class torch.optim.Adamax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
功能:
實現Adamax優化方法。Adamax是對Adam增加了一個學習率上限的概念,所以也稱之為Adamax。
詳細了解可閱讀,Adam: A Method for Stochastic Optimization(arxiv.org/abs/1412.6980)(沒錯,就是Adam論文中提出了Adamax)。
torch.optim.SparseAdam
class torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
功能:
針對稀疏張量的一種“閹割版”Adam優化方法。
only moments that show up in the gradient get updated, and only those portions of the gradient get applied to the parameters
torch.optim.LBFGS
class torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-05, tolerance_change=1e-09, history_size=100, line_search_fn=None)
**功能:**
實現L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)優化方法。L-BFGS屬於擬牛頓算法。L-BFGS是對BFGS的改進,特點就是節省內存。
使用注意事項:
1.This optimizer doesn’t support per-parameter options and parameter groups (there can be only one).
Right now all parameters have to be on a single device. This will be improved in the future.(2018-10-07)