[pytorch筆記] 調整網絡學習率


1. 為網絡的不同部分指定不同的學習率

 1 class LeNet(t.nn.Module):
 2     def __init__(self):
 3         super(LeNet, self).__init__()
 4         self.features = t.nn.Sequential(
 5             t.nn.Conv2d(3, 6, 5),
 6             t.nn.ReLU(),
 7             t.nn.MaxPool2d(2, 2),
 8             t.nn.Conv2d(6, 16, 5),
 9             t.nn.ReLU(),
10             t.nn.MaxPool2d(2, 2)
11         )
12         # 由於調整shape並不是一個class層,
13         # 所以在涉及這種操作(非nn.Module操作)需要拆分為多個模型
14         self.classifier = t.nn.Sequential(
15             t.nn.Linear(16*5*5, 120),
16             t.nn.ReLU(),
17             t.nn.Linear(120, 84),
18             t.nn.ReLU(),
19             t.nn.Linear(84, 10)
20         )
21  
22     def forward(self, x):
23         x = self.features(x)
24         x = x.view(-1, 16*5*5)
25         x = self.classifier(x)
26         return x

這里LeNet被拆解成features和classifier兩個模型來實現。在訓練時,可以為features和classifier分別指定不同的學習率。

1 model = LeNet()
2 optimizer = optim.SGD([{'params': model.features.parameters()}, 
3                        {'params': model.classifier.parameters(), 'lr': 1e-2}
4                       ], lr = 1e-5)

對於{'params': model.classifier.parameters(), 'lr': 1e-2} 被指定了特殊的學習率 'lr': 1e-2,則按照該值優化。

對於{'params': model.features.parameters()} 沒有特殊指定學習率,則使用 lr = 1e-5。

SGD的param_groups中保存着 'params', 'lr', 'momentum', 'dampening','weight_decay','nesterov'及對應值的字典。


在 CLASS torch.optim.Optimizer(params, defaults) 中,提供了 add_param_group(param_group) 函數,可以在optimizer中添加param group. 這在固定與訓練網絡模型部分,fine-tuning 訓練層部分時很實用。

2. 動態調整網絡模塊的學習率

1 for p in optimizer.param_groups:
2     p['lr'] = rate()

如果需要動態設置學習率,可以以這種方式,將關於學習率的函數賦值給參數的['lr']屬性。

還以以上定義的LeNet的optimizer為例,根據上面的定義,有兩個param_groups, 一個是model.features.parameters(), 一個是{'params': model.classifier.parameters()。

那么在for的迭代中,可以分別為這兩個param_group通過函數rate()實現動態賦予學習率的功能。


 

如果將optimizer定義為:

optimizer = optim.SGD(model.parameters(), lr = 0.001, momentum = 0.9)

那么param_groups中只有一個param group,也就是網絡中各個模塊共用同一個學習率。

3. 使用pytorch封裝好的方法

https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

torch.optim.lr_scheduler中提供了一些給予epochs的動態調整學習率的方法。

https://www.jianshu.com/p/a20d5a7ed6f3 這篇blog中繪制了一些學習率方法對應的圖示。

1)torch.optim.lr_scheduler.StepLR

 1 import torch
 2 import torch.optim as optim
 3 from torch.optim import lr_scheduler
 4 from torchvision.models import AlexNet
 5 import matplotlib.pyplot as plt
 6 
 7 model = AlexNet(num_classes=2)
 8 optimizer = optim.SGD(params=model.parameters(), lr=0.05)
 9 
10 # lr_scheduler.StepLR()
11 # Assuming optimizer uses lr = 0.05 for all groups
12 # lr = 0.05     if epoch < 30
13 # lr = 0.005    if 30 <= epoch < 60
14 # lr = 0.0005   if 60 <= epoch < 90
15 
16 scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
17 plt.figure()
18 x = list(range(100))
19 y = []
20 for epoch in range(100):
21     scheduler.step()
22     lr = scheduler.get_lr()
23     print(epoch, scheduler.get_lr()[0])
24     y.append(scheduler.get_lr()[0])
25 
26 plt.plot(x, y)

 

2)torch.optim.lr_scheduler.MultiStepLR

與StepLR相比,MultiStepLR可以設置指定的區間

 1 # ---------------------------------------------------------------
 2 # 可以指定區間
 3 # lr_scheduler.MultiStepLR()
 4 #  Assuming optimizer uses lr = 0.05 for all groups
 5 # lr = 0.05     if epoch < 30
 6 # lr = 0.005    if 30 <= epoch < 80
 7 #  lr = 0.0005   if epoch >= 80
 8 print()
 9 plt.figure()
10 y.clear()
11 scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 80], 0.1)
12 for epoch in range(100):
13     scheduler.step()
14     print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
15     y.append(scheduler.get_lr()[0])
16 
17 plt.plot(x, y)
18 plt.show()

3)torch.optim.lr_scheduler.ExponentialLR

指數衰減

 1 scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
 2 print()
 3 plt.figure()
 4 y.clear()
 5 for epoch in range(100):
 6     scheduler.step()
 7     print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
 8     y.append(scheduler.get_lr()[0])
 9 
10 plt.plot(x, y)
11 plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM