反向求導


1.2 神經網絡的反向求導

在上一節中, 我們大致對神經網絡的梯度更新有了了解,其中最核心的部分就是求出損失函數對權重 𝑤𝑙𝑖𝑗wijl 的導數。由於網上大多數資料都是生搬硬套,因此我們以計算 𝑊1W1 的導數為例,對整個反向求導過程進行細致的剖析。如下圖所示:

其中,𝑤𝑙𝑗𝑘wjkl 表示從第 𝑙l 層的第 𝑗j 個節點到第 𝑙+1l+1 層中的第 𝑘k 個節點的權重,根據前向傳播的計算我們可以得到:

 

𝑦𝑜𝑢𝑡(𝑤311𝑤211+𝑤321𝑤212)𝑤111𝑥1,𝑦𝑜𝑢𝑡𝑤111=(𝑤311𝑤211+𝑤321𝑤212)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤211+𝑤321𝑤212)𝑤121𝑥2,𝑦𝑜𝑢𝑡𝑤121=(𝑤311𝑤211+𝑤321𝑤212)𝑥2𝑦𝑜𝑢𝑡(𝑤311𝑤221+𝑤321𝑤222)𝑤112𝑥1,𝑦𝑜𝑢𝑡𝑤112=(𝑤311𝑤221+𝑤321𝑤222)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤221+𝑤321𝑤222)𝑤122𝑥1,𝑦𝑜𝑢𝑡𝑤122=(𝑤311𝑤221+𝑤321𝑤222)𝑥2𝑦𝑜𝑢𝑡(𝑤311𝑤231+𝑤321𝑤232)𝑤113𝑥1,𝑦𝑜𝑢𝑡𝑤113=(𝑤311𝑤231+𝑤321𝑤232)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤231+𝑤321𝑤232)𝑤123𝑥2,𝑦𝑜𝑢𝑡𝑤123=(𝑤311𝑤231+𝑤321𝑤232)𝑥2yout∼(w113w112+w213w122)w111x1,∂yout∂w111=(w113w112+w213w122)x1;yout∼(w113w112+w213w122)w211x2,∂yout∂w211=(w113w112+w213w122)x2yout∼(w113w212+w213w222)w121x1,∂yout∂w121=(w113w212+w213w222)x1;yout∼(w113w212+w213w222)w221x1,∂yout∂w221=(w113w212+w213w222)x2yout∼(w113w312+w213w322)w131x1,∂yout∂w131=(w113w312+w213w322)x1;yout∼(w113w312+w213w322)w231x2,∂yout∂w231=(w113w312+w213w322)x2

 

用矩陣表示為:

 

𝐿𝑊1=⎡⎣⎢⎢⎢⎢⎢⎢⎢𝑦𝑜𝑢𝑡𝑤111𝑦𝑜𝑢𝑡𝑤112𝑦𝑜𝑢𝑡𝑤113𝑦𝑜𝑢𝑡𝑤121𝑦𝑜𝑢𝑡𝑤122𝑦𝑜𝑢𝑡𝑤123⎤⎦⎥⎥⎥⎥⎥⎥⎥=([𝑤311𝑤211+𝑤321𝑤212𝑤311𝑤221+𝑤321𝑤222𝑤311𝑤221+𝑤321𝑤232][𝑥1𝑥2])𝑇=(𝑊3𝑊2𝑋)𝑇∂L∂W1=[∂yout∂w111∂yout∂w211∂yout∂w121∂yout∂w221∂yout∂w131∂yout∂w231]=([w113w112+w213w122w113w212+w213w222w113w212+w213w322]⊙[x1x2])T=(W3W2⊙X)T

 

因此,整個反向傳播的過程如下:

首先計算:𝐿𝑊3=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑊3)𝑇=𝐿𝑦𝑜𝑢𝑡[𝑦𝑜𝑢𝑡𝑤311,𝑦𝑜𝑢𝑡𝑤312]𝑇=𝐿𝑦𝑜𝑢𝑡(𝑍2)𝑇∂L∂W3=∂L∂yout⊙(∂yout∂W3)T=∂L∂yout⊙[∂yout∂w113,∂yout∂w123]T=∂L∂yout⊙(Z2)T

然后計算:𝐿𝑊2=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍2𝑊2)𝑇=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍1)𝑇=𝐿𝑦𝑜𝑢𝑡(𝑊3𝑍1)𝑇∂L∂W2=∂L∂yout(∂yout∂Z2⊙∂Z2∂W2)T=∂L∂yout(∂yout∂Z2⊙Z1)T=∂L∂yout(W3⊙Z1)T

最后計算:𝐿𝑊1=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍2𝑍1𝑍1𝑊1)𝑇=𝐿𝑊1(𝑊3𝑊2𝑋)𝑇∂L∂W1=∂L∂yout(∂yout∂Z2∂Z2∂Z1⊙∂Z1∂W1)T=∂L∂W1(W3W2⊙X)T

為了方便計算,反向傳播通過使用計算圖的形式在 Tensorflow,PyTorch 等深度學習框架中實現,將上述過程繪制成計算圖如下:

根據計算圖,可以輕而易舉地計算出損失函數對每個變量的導數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM