如何設置PyTorch的動態學習率
本文主要涉及內容:Optimizer
、_LRScheduler
等源碼分析。
本文依舊基於PyTorch 1.1.0。
Optimizer
PyTorch提供了torch.optim.lr_scheduler
來幫助用戶改變學習率,下邊將從Optimizer
入手,看一下這個類是如何工作的。
為什么從Optimizer入手,因為無論是Adam還是SGD,都是繼承的這個類。同時,scheduler也是給所有的Optimizer服務的,所以需要用的方法都會定義在這個基類里,直接看一下這個類的屬性即可。給出Doc中的代碼鏈接。
首先是初始化方法def __init__(self, params, defaults)
,這個方法的params參數,就是我們在初始化優化器的時候傳入的網絡的參數,如Alexnet.parameters()
,而后邊所有的參數都將合並成dict參數作為這個方法的defaults。
看一下Alexnet.parameters()
中存的都是什么:
1 |
for alex in Alexnet.parameters(): |
可以看到,這里邊存的就是整個網絡的參數。
有兩種定義optimizer的方法:
1 |
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) |
如果是第一種定義的方法:在這個初始化方法中,會把這些參數先改造成[{'params': Alexnet.parameters()}]
這樣的一個長度為1的list。然后對這個list進行加工,添加上defaults中的參數,如果我們使用Alexnet來做一個例子的話,就是下邊這個樣子:
1 |
optimizer = torch.optim.Adam(Alexnet.parameters(), lr=0.001) |
如果是第二種定義的方法:因為傳入的本身就是dict的形式,所以會繼續對他進行加工,添加上后邊的參數,我們直接看療效:
1 |
optimizer = torch.optim.SGD([ |
這次的list變成了兩個元素,而且每個元素的組成和使用Adam也不一樣了,這很明顯,因為不同的優化器需要的參數不同嘛~(關於不同層的lr不同的設置這里給出官網 鏈接)
但是兩者是相似的,就是每個元素都有params和lr,這就夠了。
_LRScheduler
所有的動態修改lr的類,都是繼承的這個類,所以我們看一下這個類包含什么方法。源碼鏈接。
在初始化方法中def __init__(self, optimizer, last_epoch=-1)
,包含兩個參數,第一個參數就是我們上邊提到的optimizer的任何一個子類。第二個參數的意思是當前執行到了哪個epoch。我們不指定它的時候,雖然默認是-1,但是init中會調用一次step並設置為0。
一定要注意PyTorch的版本!我的windows上用的是1.0.1,服務器用的是1.1.0,就鬧了很多問題。就拿這個類來說,在1.0.1中是先setp()
再訓練,而1.1.0進行了更新,先訓練,然后再step()
。
當我們調用了初始化后,會給optimizer增加一個字段,看一下:
1 |
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) |
新增加的
initial_lr
字段就是原始的lr。
在def step(self, epoch=None)
方法中,通常情況下我們不需要指定這個參數epoch,因為每次調用他都會增加1。在這個函數中會調用一個需要重載的方法get_lr()
,每次調用都會從這個方法中提取改變后的lr,賦值給optimizer。
這里其實我一直有個疑問的,就是scheduler的step和optimizer的step是一個什么關系,其實通過源碼,看到這里,這倆函數沒啥關系!scheduler的step只會修改lr,兩者都需要執行!
下邊看一下兩個scheduler的get_lr()
對比一下。先看一下SetpLR:
1 |
def get_lr(self): |
這個會在設置的步長的整倍數的時候將lr*gamma。
而ExponentialLR則會在每輪結束的時候都進行乘gamma的操作,這個減小也真的是指數倍的。
1 |
def get_lr(self): |
Demo
1 |
scheduler = StepLR(optimizer, step_size=30, gamma=0.1) |
- Post link: https://yichengsu.github.io/2019/08/how-to-set-lr-in-pytorch/
- Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stating additionally.
optimizer.param_groups:是長度為2的list,其中的元素是2個字典;
optimizer.param_groups[0]:長度為6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]這6個參數
optimizer.param_groups[1]:表示優化器的狀態的一個字典
————————————————
版權聲明:本文為CSDN博主「Wanderer001」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/weixin_36670529/article/details/107531773