Autograd: 自動求導


Pytorch中神經網絡包中最核心的是autograd包,我們先來簡單地學習它,然后訓練我們第一個神經網絡。

autograd包為所有在tensor上的運算提供了自動求導的支持,這是一個逐步運行的框架,也就意味着后向傳播過程是按照你的代碼定義的,並且單個循環可以不同

我們通過一些簡單例子來了解

Tensor

torch.tensor是這個包的基礎類,如果你設置.requires_grads為True,它就會開始跟蹤上面的所有運算。如果你做完了運算使用.backward(),所有的梯度就會自動運算,tesor的梯度將會累加到.grad這個屬性。

若要停止tensor的歷史紀錄,可以使用.detch()將它從歷史計算中分離出來,防止未來的計算被跟蹤。

 為了防止追蹤歷史(並且使用內存),你也可以將代碼塊包含在with torch.no_grad():中。這對於評估模型時是很有用的,因為模型也許擁有可訓練的參數使用了requires_grad=True,但是這種情況下我們不需要梯度。

還有一個類對autograd的實現非常重要,——Function

Tensor和Function是相互關聯的並一起組成非循環圖,它編碼了所有計算的歷史,每個tensor擁有一個屬性.grad_fn,該屬性引用已創建tensor的Function。(除了用戶自己創建的tensor,它們的.grad_fn為None)。

如果你想計算導數,可以在一個Tensor上調用.backward()。如果Tensor是一個標量(也就是只包含一個元素數據),你不需要為backward指明任何參數,但是擁有多個元素的情況下,你需要指定一個匹配維度的gradient參數。

import torch

創建一個tensor並設置rquires_grad=True來追蹤上面的計算

x=torch.ones(2,2,requires_grad=True)
print(x)

out:
tensor([[ 1.,  1.],
        [ 1.,  1.]])

執行一個tensor運算

y=x+2
print(y)

out:
tensor([[ 3., 3.],
[ 3., 3.]])

y是通過運算的結果建立的,所以它有grad_fn

print(y.grad_fn)

out:
<AddBackward0 object at 0x000001EDFE054D30>

在y上進行進一步的運算

z=y*y*3
out=z.mean()
print(z,out)\

out:
tensor([[ 27.,  27.],
        [ 27.,  27.]]) tensor(27.)

.requires_grad_(...)可以用內建方式改變tensor的requires_grad標志位。如果沒有給定,輸入標志默認為False

a=torch.randn(2,2)
a=((a*3)/(a-1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b=(a*a).sum()
print(b.grad_fn)

out:
False
True
<SumBackward0 object at 0x000001EDFE054940>

 

Gradients

我們開始反向傳播,因為out包含單一標量,out.backward()相當於out.backward(torch.tensor(1)).

out.backward()

打印梯度d(out)/dx

print(x.grad)

out:
tensor([[ 4.5000, 4.5000],
[ 4.5000, 4.5000]])

你應該得到一個4.5的矩陣。可以簡單手動計算一下這一結果。

你可以使用autograd做許多瘋狂的事情

x=torch.randn(3,requires_grad=True)
y=x*2
while y.data.norm()<1000:
    y=y*2
print(y)

out:
tensor([  980.8958,  1180.4403,   614.2102])
gradients=torch.tensor([0.1,1.0,0.0001],dtype=torch.float)
y.backward(gradients)
print(x.grad)

out:
tensor([  102.4000,  1024.0000,     0.1024])

你可以將語句包含在with torch.no_grad()從Tensor的歷史停止自動求導

print(x.requires_grad)
print((x**2).requires_grad)
with torch.no_grad():
    print((x**2).requires_grad)

out:
True
True
False

 

   


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM