從頭學pytorch(二) 自動求梯度


PyTorch提供的autograd包能夠根據輸⼊和前向傳播過程⾃動構建計算圖,並執⾏反向傳播。

Tensor

Tensor的幾個重要屬性或方法

  • .requires_grad 設為true的話,tensor將開始追蹤在其上的所有操作
  • .backward()完成梯度計算
  • .grad屬性 計算的梯度累積到.grad屬性
  • .detach()解除對一個tensor上操作的追蹤,或者用with torch.no_grad()將不想被追蹤的操作代碼塊包裹起來.
  • .grad_fn屬性 該屬性即創建Tensor的Function類的類型,即該Tensor是由什么運算得來的

幾個例子具體地解釋一下:

import torch
x = torch.ones(2, 2, requires_grad=True)
print(x)
print(x.grad_fn)

y = x+2
print(y)
print(y.grad_fn)

z = y*y*3
out=z.mean()
print(z,out)

輸出

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
None
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward>)
<AddBackward object at 0x0000018752434B70>


tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward>) tensor(27., grad_fn=<MeanBackward1>)

y由加法得到,所以y.grad_fn= ,x直接創建,其x.grad_fn=None. x這種直接創建的又稱為葉子節點.

print(x.is_leaf, y.is_leaf) # True False

可以用.requires_grad_()來用in-place的方式改變requires_grad屬性.

a = torch.randn(2, 2) # 缺失情況下默認 requires_grad = False
a = ((a * 3) / (a - 1))
print(a.requires_grad) # False
a.requires_grad_(True)
print(a.requires_grad) # True
b = (a * a).sum()
print(b.grad_fn)

輸出

False
True
<SumBackward0 object at 0x0000018752434D30>

梯度

所計算的梯度都是結果變量關於創建變量的梯度。
比如對:

x = torch.ones(2, 2, requires_grad=True)
print(x)
print(x.grad_fn)

y = x+2
print(y)
print(y.grad_fn)

z = y*3
z.backward(torch.ones_like(z))
print(y.grad) #None  
print(x.grad)

輸出

None
tensor([[3., 3.],
        [3., 3.]])

上述代碼相當於創建了一個動態圖,其中x是我們創建的變量,y和z都是因為x的改變會改變的結果變量. 所以在這個動態圖里能夠求的梯度只有\(\frac{\partial{z}}{\partial{x}}\),\(\frac{\partial{y}}{\partial{x}}\)

為什么l.backward(gradient)需要傳入一個和l同樣形狀的gradient?
對於l.backward()而言,當l是標量時,可以不傳參,相當於l.backward(torch.tensor(1.))
當l不是標量時,需要傳入一個和l同shape的gradient。

假設 x 經過一番計算得到 y,那么 y.backward(w) 求的不是 y 對 x 的導數,而是 l = torch.sum(y*w) 對 x 的導數。w 可以視為 y 的各分量的權重,也可以視為遙遠的損失函數 l 對 y 的偏導數(這正是函數說明文檔的含義)。特別地,若 y 為標量,w 取默認值 1.0,才是按照我們通常理解的那樣,求 y 對 x 的導數

簡單地說就是,張量對張量沒法求導,所以我們需要人為地定義一個w,把一個非標量的Tensor通過torch.sum(y*w)的形式轉換成標量。我們自己定義的這個w的不同,當然最后得到的梯度就不同.通常定義為全1.也就是認為Tensor y中的每一個變量的重要性是等同的.

另一個角度的理解就是,y是一個tensor,是一個向量,有N個標量,這每一個標量都與x有關。對這N個標量我們需要賦以不同的權重,以顯示y中每一個標量受到x影響的程度.

比如對

import torch
x = torch.ones(2, 2, requires_grad=True)
print(x)
print(x.grad_fn)

y = x+2
print(y)
print(y.grad_fn)

z = y*3
print(z.shape)
w1=torch.Tensor([[1,2],[1,2]])
z.backward([w1])
print(x.grad)

x.grad.data.zero_()
w2=torch.Tensor([[1,1],[1,1]])
z.backward([w2])
print(x.grad)

輸出

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
None
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward>)
<AddBackward object at 0x00000187524A6828>
torch.Size([2, 2])
tensor([[3., 6.],
        [3., 6.]])
tensor([[3., 3.],
        [3., 3.]])

對w1和w2而言,z.backward()以后x.grad是不同的。
注意:梯度是累加的,所以第二次計算之前我們做了清零的操作:x.grad.data.zero_()

可以參考:
https://zhuanlan.zhihu.com/p/29923090
https://www.cnblogs.com/zhouyang209117/p/11023160.html

再來看看中斷梯度追蹤的例子:

x = torch.tensor(1.0, requires_grad=True)
y1 = x ** 2 
with torch.no_grad():
    y2 = x ** 3
y3 = y1 + y2
    
print(x.requires_grad)
print(y1, y1.requires_grad) # True
print(y2, y2.requires_grad) # False
print(y3, y3.requires_grad) # True

輸出:

True
tensor(1., grad_fn=<PowBackward0>) True
tensor(1.) False
tensor(2., grad_fn=<ThAddBackward>) True

反向傳播,求梯度

y3.backward()
print(x.grad)

輸出:

tensor(2.)

為什么是2呢?$ y_3 = y_1 + y_2 = x^2 + x^3$,當 \(x=1\)\(\frac {dy_3} {dx}\) 不應該是5嗎?事實上,由於 \(y_2\) 的定義是被torch.no_grad():包裹的,所以與 \(y_2\) 有關的梯度是不會回傳的,只有與 \(y_1\) 有關的梯度才會回傳,即 \(x^2\)\(x\) 的梯度。

上面提到,y2.requires_grad=False,所以不能調用 y2.backward(),會報錯:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

此外,如果我們想要修改tensor的數值,但是又不希望被autograd記錄(即不會影響反向傳播),那么我么可以對tensor.data進行操作。

x = torch.ones(1,requires_grad=True)

print(x.data) # 還是一個tensor
print(x.data.requires_grad) # 但是已經是獨立於計算圖之外

y = 2 * x
x.data *= 100 # 只改變了值,不會記錄在計算圖,所以不會影響梯度傳播

y.backward()
print(x) # 更改data的值也會影響tensor的值
print(x.grad)

輸出:

tensor([1.])
False
tensor([100.], requires_grad=True)
tensor([2.])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM