pytorch中的backward


這個函數的作用是反向傳播計算梯度的。

這個只有標量才能直接使用 backward(),如果使用自定義的函數,得到的不是標量,則backward()時需要傳入 grad_variable 參數。

torch.tensor是autograd包的基礎類,如果你設置tensor的requires_grads為True,就會開始跟蹤這個tensor上面的所有運算,如果你做完運算后使用tensor.backward(),所有的梯度就會自動運算,tensor的梯度將會累加到它的.grad屬性里面去。
如果沒有進行tensor.backward()的話,梯度值將會是None,因此loss.backward()要寫在optimizer.step()之前。

 

在用pytorch訓練模型時,通常會在遍歷epochs的過程中依次用到optimizer.zero_grad(),loss.backward()和optimizer.step()三個函數,如下所示:

model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
 
for epoch in range(1, epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        output= model(inputs)
        loss = criterion(output, labels)
        
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

總得來說,這三個函數的作用是先將梯度歸零(optimizer.zero_grad()),然后反向傳播計算得到每個參數的梯度值(loss.backward()),最后通過梯度下降執行一步參數更新(optimizer.step())

接下來將通過源碼分別理解這三個函數的具體實現過程。在此之前,先簡要說明一下函數中常見的參數變量:

param_groups:Optimizer類在實例化時會在構造函數中創建一個param_groups列表,列表中有num_groups個長度為6的param_group字典(num_groups取決於你定義optimizer時傳入了幾組參數),每個param_group包含了 ['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'] 這6組鍵值對。

param_group['params']:由傳入的模型參數組成的列表,即實例化Optimizer類時傳入該group的參數,如果參數沒有分組,則為整個模型的參數model.parameters(),每個參數是一個torch.nn.parameter.Parameter對象。

 

一、optimizer.zero_grad():

    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()

optimizer.zero_grad()函數會遍歷模型的所有參數,通過p.grad.detach_()方法截斷反向傳播的梯度流,再通過p.grad.zero_()函數將每個參數的梯度值設為0,即上一次的梯度記錄被清空。

因為訓練的過程通常使用mini-batch方法,所以如果不將梯度清零的話,梯度會與上一個batch的數據相關,因此該函數要寫在反向傳播和梯度下降之前。

 

二、loss.backward():

PyTorch的反向傳播(即tensor.backward())是通過autograd包來實現的,autograd包會根據tensor進行過的數學運算來自動計算其對應的梯度。

具體來說,torch.tensor是autograd包的基礎類,如果你設置tensor的requires_grads為True,就會開始跟蹤這個tensor上面的所有運算,如果你做完運算后使用tensor.backward(),所有的梯度就會自動運算,tensor的梯度將會累加到它的.grad屬性里面去。

更具體地說,損失函數loss是由模型的所有權重w經過一系列運算得到的,若某個w的requires_grads為True,則w的所有上層參數(后面層的權重w)的.grad_fn屬性中就保存了對應的運算,然后在使用loss.backward()后,會一層層的反向傳播計算每個w的梯度值,並保存到該w的.grad屬性中。

如果沒有進行tensor.backward()的話,梯度值將會是None,因此loss.backward()要寫在optimizer.step()之前。

 

三、optimizer.step():

以SGD為例,torch.optim.SGD().step()源碼如下:

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
 
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
 
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf
 
                p.data.add_(-group['lr'], d_p)
 
        return loss

step()函數的作用是執行一次優化步驟,通過梯度下降法來更新參數的值。因為梯度下降是基於梯度的,所以在執行optimizer.step()函數前應先執行loss.backward()函數來計算梯度。

注意:optimizer只負責通過梯度下降進行優化,而不負責產生梯度,梯度是tensor.backward()方法產生的。

參考:https://www.cnblogs.com/Thinker-pcw/p/9630367.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM