關於RNN (循環神經網絡)相鄰采樣為什么在每次迭代之前都需要將參數detach
-
這個問題出自《動手學深度學習pytorch》中RNN 第六章6.4節內容,如下圖所示:
當時看到這個注釋,我是一臉懵逼,(難道就不能解釋清楚一點嘛,讓我獨自思考了那么長時間,差評!!!)我主要有以下疑惑:
-
每次小批量反向傳播之后,由於torch是動態計算圖,本質上該次的計算圖已經銷毀,與下次小批量迭代的構建的計算沒有任何關聯,detach不是多此一舉嘛?
-
按照注釋所說的,難道下次小批量構建的計算圖由於初始隱藏狀態引用於上次小批量迭代最后的時間步長的隱藏狀態,這樣計算圖存在分支關聯,方向傳播會經過以前所有批量迭代構建的計算圖,導致內存和計算資源開銷大?
- 帶着這兩個疑惑,我開始面向百度編程(網上的博客真的是千篇一律啊,10篇當中9篇一樣,哎世風日下,我也是服了,文章轉來轉去有意思嘛,自己收藏着看看不好嘛,非得全篇復制還轉載,真的***)百度之后,我發現了以下解釋(沒一個有用的)
-
胡說八道型
這講的啥?按你這么說,state是葉子節點了(估計不知道從哪抄的錯誤博客,害人匪淺啊),既然state都是葉子節點了,那還跟上一次批量的計算圖有毛關系,反向傳播個屁?葉子節點的定義:一棵樹當中沒有子結點(即度為0)的結點稱為葉子結點。除了第一次小批量的初始隱藏狀態是葉子節點外,其他批量的隱藏狀態都經過隱藏層的計算,所以state已經不再是葉子節點了,而是分支節點(即grad_fn屬性不為None的節點)不信,現場測試:
將源代碼略微添加以上代碼,驗證是否為葉子節點:
看出來了吧,除了第一個小批量state 是葉子節點,其他都不是。
-
理解不到位型
哎,這張禍害不淺的知乎轉載圖:Z不是葉子節點,他是經過計算的節點(其他內容不粘貼了)
-
既然不是葉子節點,那detach到底有什么作用呢
首先要明確一個意識:pytorch是動態計算圖,每次backward后,本次計算圖自動銷毀,但是計算圖中的節點都還保留。
方向傳播直到葉子節點為止,否者一直傳播,直到找到葉子節點
我的答案是有用,但根本不是為了防止梯度開銷過大(注釋真的害人不淺啊),detach的真正作用是梯度節流,防止反向傳播傳播到隱藏狀態時,因為上次小批量方向傳播計算圖的銷毀導致繼續向下傳播而引起報錯。啥意思呢,我以連續兩次小批量迭代舉例:
第一次小批量迭代,H0 是葉子節點,因為他沒經過任何計算。剩余H1是非葉子節點。在第一次方向傳播后,第一次的計算圖已經銷毀,但是節點數據仍然存在。
第二次小批量迭代,第一次批量迭代的最后時間節點的隱藏狀態H2 成為第二批次小的初始隱藏狀態( H0(第二次) = H2(第一次) ),這樣第二次在方向傳播時,當傳播到H0時,發現H0 是 分支節點(grad_fn+requires_grad) ,就會繼續向下傳播直到找到葉子節點為止,但是可惜的是H0 之后的計算圖(即第一次小批量的計算圖)已經銷毀,傳播發生中斷,因此就會導致出錯。而使用detach之后,H0 自然與上次的計算圖沒有任何關系,H0自身變為葉子節點,這樣傳播到H0時自然就結束了。
好了,驗證我所說的吧。
- 首先,不使用detach,會導致傳播報錯
將detach 操作刪除
運行結果:
看到沒,第二次在方向傳播時出錯了吧
-
使用detach,防止出錯,並使H0 變為葉子節點
代碼更改如下:
結果:全是true
綜上:detach在這里作用,大家明白不,喜歡點個贊!!!!
至於書中為什么將detach的作用注釋成那樣呢,我想作者在翻譯成torch的時候,忽略了MAXNET框架(原書是maxnet框架)與pytorch的區別。 MaxNet是支持靜態圖的,所以對於MaxNet ,detach的作用是與注釋相同的,但是pytorch是動態圖,所以作用在這里就不同了!!!