反向求导


1.2 神经网络的反向求导

在上一节中, 我们大致对神经网络的梯度更新有了了解,其中最核心的部分就是求出损失函数对权重 𝑤𝑙𝑖𝑗wijl 的导数。由于网上大多数资料都是生搬硬套,因此我们以计算 𝑊1W1 的导数为例,对整个反向求导过程进行细致的剖析。如下图所示:

其中,𝑤𝑙𝑗𝑘wjkl 表示从第 𝑙l 层的第 𝑗j 个节点到第 𝑙+1l+1 层中的第 𝑘k 个节点的权重,根据前向传播的计算我们可以得到:

 

𝑦𝑜𝑢𝑡(𝑤311𝑤211+𝑤321𝑤212)𝑤111𝑥1,𝑦𝑜𝑢𝑡𝑤111=(𝑤311𝑤211+𝑤321𝑤212)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤211+𝑤321𝑤212)𝑤121𝑥2,𝑦𝑜𝑢𝑡𝑤121=(𝑤311𝑤211+𝑤321𝑤212)𝑥2𝑦𝑜𝑢𝑡(𝑤311𝑤221+𝑤321𝑤222)𝑤112𝑥1,𝑦𝑜𝑢𝑡𝑤112=(𝑤311𝑤221+𝑤321𝑤222)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤221+𝑤321𝑤222)𝑤122𝑥1,𝑦𝑜𝑢𝑡𝑤122=(𝑤311𝑤221+𝑤321𝑤222)𝑥2𝑦𝑜𝑢𝑡(𝑤311𝑤231+𝑤321𝑤232)𝑤113𝑥1,𝑦𝑜𝑢𝑡𝑤113=(𝑤311𝑤231+𝑤321𝑤232)𝑥1;𝑦𝑜𝑢𝑡(𝑤311𝑤231+𝑤321𝑤232)𝑤123𝑥2,𝑦𝑜𝑢𝑡𝑤123=(𝑤311𝑤231+𝑤321𝑤232)𝑥2yout∼(w113w112+w213w122)w111x1,∂yout∂w111=(w113w112+w213w122)x1;yout∼(w113w112+w213w122)w211x2,∂yout∂w211=(w113w112+w213w122)x2yout∼(w113w212+w213w222)w121x1,∂yout∂w121=(w113w212+w213w222)x1;yout∼(w113w212+w213w222)w221x1,∂yout∂w221=(w113w212+w213w222)x2yout∼(w113w312+w213w322)w131x1,∂yout∂w131=(w113w312+w213w322)x1;yout∼(w113w312+w213w322)w231x2,∂yout∂w231=(w113w312+w213w322)x2

 

用矩阵表示为:

 

𝐿𝑊1=⎡⎣⎢⎢⎢⎢⎢⎢⎢𝑦𝑜𝑢𝑡𝑤111𝑦𝑜𝑢𝑡𝑤112𝑦𝑜𝑢𝑡𝑤113𝑦𝑜𝑢𝑡𝑤121𝑦𝑜𝑢𝑡𝑤122𝑦𝑜𝑢𝑡𝑤123⎤⎦⎥⎥⎥⎥⎥⎥⎥=([𝑤311𝑤211+𝑤321𝑤212𝑤311𝑤221+𝑤321𝑤222𝑤311𝑤221+𝑤321𝑤232][𝑥1𝑥2])𝑇=(𝑊3𝑊2𝑋)𝑇∂L∂W1=[∂yout∂w111∂yout∂w211∂yout∂w121∂yout∂w221∂yout∂w131∂yout∂w231]=([w113w112+w213w122w113w212+w213w222w113w212+w213w322]⊙[x1x2])T=(W3W2⊙X)T

 

因此,整个反向传播的过程如下:

首先计算:𝐿𝑊3=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑊3)𝑇=𝐿𝑦𝑜𝑢𝑡[𝑦𝑜𝑢𝑡𝑤311,𝑦𝑜𝑢𝑡𝑤312]𝑇=𝐿𝑦𝑜𝑢𝑡(𝑍2)𝑇∂L∂W3=∂L∂yout⊙(∂yout∂W3)T=∂L∂yout⊙[∂yout∂w113,∂yout∂w123]T=∂L∂yout⊙(Z2)T

然后计算:𝐿𝑊2=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍2𝑊2)𝑇=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍1)𝑇=𝐿𝑦𝑜𝑢𝑡(𝑊3𝑍1)𝑇∂L∂W2=∂L∂yout(∂yout∂Z2⊙∂Z2∂W2)T=∂L∂yout(∂yout∂Z2⊙Z1)T=∂L∂yout(W3⊙Z1)T

最后计算:𝐿𝑊1=𝐿𝑦𝑜𝑢𝑡(𝑦𝑜𝑢𝑡𝑍2𝑍2𝑍1𝑍1𝑊1)𝑇=𝐿𝑊1(𝑊3𝑊2𝑋)𝑇∂L∂W1=∂L∂yout(∂yout∂Z2∂Z2∂Z1⊙∂Z1∂W1)T=∂L∂W1(W3W2⊙X)T

为了方便计算,反向传播通过使用计算图的形式在 Tensorflow,PyTorch 等深度学习框架中实现,将上述过程绘制成计算图如下:

根据计算图,可以轻而易举地计算出损失函数对每个变量的导数。


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM