| 本節講述Pytorch中torch.optim優化器包,學習率、參數Momentum動量的含義,以及常用的幾類優化器。【Latex公式采用在線編碼器】 優化器概念:管理並更新模型所選中的網絡參數,使得模型輸出更加接近真實標簽。 |
1. Optimizer基本屬性
(1)如何創建一個優化器
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 第一項也可自定義參數,用list封裝
# 后面介紹的基本方法,都是利用optimizer.方法
(2)繼承Optimizer父類
所有的optim中的優化器都繼承Optimizer父類,即:
class Optimizer(object):
def __init__(self, params, defaults):
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults # 1 保存優化器本身的參數,例如
if isinstance(params, torch.Tensor):
raise TypeError("params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " +
torch.typename(params))
self.state = defaultdict(dict) #2
self.param_groups = [] #3
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]
for param_group in param_groups:
self.add_param_group(param_group) # 調用add_param_group函數,將default優化器本身參數,送入param_groups中
由上式代碼注釋#可知,重要參數如下:
-
self.defaults:優化器本身參數,如學習率、動量等等
-
self.state:參數緩存,如動量緩存
-
self.param_groups:管理的參數組,注意這里是list(dict)形式,即列表中字典。
例如:<class 'list'>: [{'params': [網絡參數], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
注意:這里模型中的參數(如W)與param_groups中保存的W,地址相同。
2.Optimizer的基本方法
(1)optimizer.zero_grad()
清空所管理的網絡參數的梯度
class Optimizer(object):
def zero_grad(self):
r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
for group in self.param_groups: # 對於self.param_groups的list中字典key='params'對應的value
for p in group['params']:
if p.grad is not None:
p.grad.detach_() # 脫離原來的計算圖,被計算機捕捉到
p.grad.zero_()
(2)optimizer.step()
執行一步更新,根據對應的梯度下降策略。
(3)optimizer.add_param_group()
添加參數組,經常用於finetune,又例如設置兩部分參數,e.g. 網絡分為:特征提取層+全連接分類層,設置兩組優化參數。
class Optimizer(object):
def add_param_group(self, param_group):
"""
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options.
"""
params = param_group['params']
param_set = set()
for group in self.param_groups:
param_set.update(set(group['params']))
self.param_groups.append(param_group)
同一個優化器,添加新的優化參數:
weight = torch.randn((2, 2), requires_grad=True)
optimizer = optim.SGD([weight], lr=0.1)
print('添加之后未添加之前:{}'.format(optimizer.param_groups))
'''
添加之后未添加之前:[{'params': [tensor([[ 0.4523, 0.2895],
[-0.4283, 1.0688]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
'''
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, 'lr': 0.0001})
print("添加之后{}".format(optimizer.param_groups))
'''
添加之后[{'params': [tensor([[ 0.4523, 0.2895],
[-0.4283, 1.0688]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-1.0346, 1.2396, -1.4738],
[ 0.8029, -1.1723, 0.0783],
[ 0.7809, 0.4156, 0.3127]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
'''
可以看到添加之后,optimizer.param_groups list中含有兩個字典,一個字典是之前的參數,另一個字典是新添加的一系列優化器參數
(4)optimizer.state_dict()
獲取當前優化器的一系列信息參數。由代碼可知,返回的是字典,兩個key:'state'和'param_groups'
class Optimizer(object):
def state_dict(self):
...
...
return {
'state': packed_state,
'param_groups': param_groups,
}
self.state:參數緩存,如動量緩存,當網絡沒有經過optimizer.step(),即沒有根據loss.backward()得到的梯度去更新網絡參數時,state為空:
print(optimizer.state_dict())
'''
{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140306069859856]}]}
'''
當更新之后,'state'將保存'params'中value的地址以及{'momentun_buffer':tensor()}動量緩存,用於后續斷點恢復。
(5)optimizer.load_state_dict()
加載保存的狀態信息字典
'''保存優化器狀態信息'''
torch.save(optimizer.state_dict(), os.path.join(address, "name.pkl"))
'''加載優化器狀態信息'''
state_dict = torch.load(os.path.join(address, "name.pkl"))
optimizer.load_state_dict(state_dict)
3.學習率lr
學習率可以看作是對梯度的縮小因子,用來控制梯度更新的步伐:
-
lr不能過大(易loss激增);
-
lr不能過小(收斂較慢);
-
當設置lr適當小時,如0.01,此時可通過增加網絡訓練時間,進行彌補;
4.動量Momentum
(1)指數加權平均
結合當前梯度與上一時刻更新的信息,來更新當前梯度信息。Momentum 梯度下降法 可追溯到指數加權平均:
其中 \(\theta _{t}\) 為當前時刻的參數,因為 \(\beta < 1\) ,從上述公式可知,距離當前t時刻越遠的時刻參數,權重越小,對t時刻影響越小。
(2)Pytroch中的動量計算
可以看到, 當 \(Momentum\) 太大時,由於受到前面時刻梯度線性影響,會有一定的震盪。
5.optim.SGD隨機梯度下降
optimizer = optim.SGD(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False)
6.torch.optim下10種優化器
下次來補充啦😉!
