Pytorch學習筆記10----LSTM循環神經網絡原理


1.RNN的構造過程

RNN是一種特殊的神經網路結構,其本身是包含循環的網絡,允許信息在神經元之間傳遞,如下圖所示:

 

圖示是一個RNN結構示意圖,圖中的 [公式] 表示神經網絡模型,[公式] 表示模型的輸入信號,[公式] 表示模型的輸出信號,如果沒有 [公式] 的輸出信號傳遞到 [公式] 的那個箭頭, 這個網絡模型與普通的神經網絡結構無異。那么這個箭頭做了什么事情呢?它允許 [公式] 將信息傳遞給 [公式] ,神經網絡將自己的輸出作為輸入了!

關鍵在於輸入信號是一個時間序列,跟時間 [公式] 有關。也就是說,在 [公式] 時刻,輸入信號 [公式] 作為神經網絡 [公式] 的輸入,[公式] 的輸出分流為兩部分,一部分輸出給 [公式] ,一部分作為一個隱藏的信號流被輸入到 [公式] 中,在下一次時刻輸入信號 [公式] 時,這部分隱藏的信號流也作為輸入信號輸入到了 [公式] 中。此時神經網絡 [公式] 就同時接收了 [公式] 時刻和 [公式] 時刻的信號輸入了,此時的輸出信號又將被傳遞到下一時刻的 [公式] 中。如果我們把上面那個圖根據時間 [公式] 展開來看,就是:

 

循環神經網絡的記憶性: 

普通的神經網絡:當我們輸入一張卡比獸被噴水的圖片時,神經網絡會認出卡比獸和水,推斷出卡比獸有60%的概率在洗澡,30%的概率在喝水,10%的概率被攻擊。

循環神經網絡:在隱藏狀態(Hidden State)為“戰斗場景開始”的情況下輸入神奇寶貝噴水進攻圖,RNN能夠根據“嘴中噴水”的場景推測圖一神奇寶貝是在進攻的概率為85%。之后我們在記憶為“在戰斗、敵人在攻擊和敵人是水性攻擊”三個條件下輸入圖片二,RNN就會分析出“卡比獸被攻擊”是概率最大的情況。

2.長短時間記憶網絡LSTM概述

長短期記憶(Long Short Term Memory,LSTM)網絡是一種特殊的RNN模型,其特殊的結構設計使得它可以避免長期依賴問題,記住很早時刻的信息是LSTM的默認行為,而不需要專門為此付出很大代價。

粗看起來,這個結構有點復雜,不過不用擔心,接下來我們會慢慢解釋。在解釋這個神經網絡層時我們先來認識一些基本的模塊表示方法。圖中的模塊分為以下幾種:

  • 黃色方塊:表示一個神經網絡層(Neural Network Layer);
  • 粉色圓圈:表示按位操作或逐點操作(pointwise operation),例如向量加和、向量乘積等;
  • 單箭頭:表示信號傳遞(向量傳遞);
  • 合流箭頭:表示兩個信號的連接(向量拼接);
  • 分流箭頭:表示信號被復制后傳遞到2個不同的地方。

3.LSTM的基本思想

LSTM的關鍵是細胞狀態(直譯:cell state),表示為 [公式] ,用來保存當前LSTM的狀態信息並傳遞到下一時刻的LSTM中,也就是RNN中那根“自循環”的箭頭。當前的LSTM接收來自上一個時刻的細胞狀態 [公式] ,並與當前LSTM接收的信號輸入 [公式] 共同作用產生當前LSTM的細胞狀態 [公式],具體的作用方式下面將詳細介紹。

在LSTM中,采用專門設計的“門”來引入或者去除細胞狀態 [公式] 中的信息。門是一種讓信息選擇性通過的方法。有的門跟信號處理中的濾波器有點類似,允許信號部分通過或者通過時被門加工了;有的門也跟數字電路中的邏輯門類似,允許信號通過或者不通過。這里所采用的門包含一個 [公式] 神經網絡層和一個按位的乘法操作,如下圖所示:

其中黃色方塊表示[公式]神經網絡層,粉色圓圈表示按位乘法操作。[公式]神經網絡層可以將輸入信號轉換為 [公式] 到 [公式] 之間的數值,用來描述有多少量的輸入信號可以通過。[公式] 表示“不允許任何量通過”,[公式] 表示“允許所有量通過”。[公式]神經網絡層起到類似下圖的[公式]函數所示的作用:

其中,橫軸表示輸入信號,縱軸表示經過sigmod函數以后的輸出信號。

LSTM主要包括三個不同的門結構:遺忘門、記憶門和輸出門。這三個門用來控制LSTM的信息保留和傳遞,最終反映到細胞狀態 [公式] 和輸出信號 [公式] 。如下圖所示:

圖中標示了LSTM中各個門的構成情況和相互之間的關系,其中:

  • 遺忘門由一個[公式]神經網絡層和一個按位乘操作構成;
  • 記憶門由輸入門(input gate)與tanh神經網絡層和一個按位乘操作構成;
  • 輸出門(output gate)與 [公式] 函數(注意:這里不是 [公式] 神經網絡層)以及按位乘操作共同作用將細胞狀態和輸入信號傳遞到輸出端。

4.LSTM門講解

(1)遺忘門

顧名思義,遺忘門的作用就是用來“忘記”信息的。在LSTM的使用過程中,有一些信息不是必要的,因此遺忘門的作用就是用來選擇這些信息並“忘記”它們。遺忘門決定了細胞狀態 [公式] 中的哪些信息將被遺忘。那么遺忘門的工作原理是什么呢?看下面這張圖。

 

左邊高亮的結構就是遺忘門了,包含一個[公式]神經網絡層(黃色方框,神經網絡參數為 [公式]),接收 [公式] 時刻的輸入信號 [公式] 和 [公式] 時刻LSTM的上一個輸出信號 [公式] ,這兩個信號進行拼接以后共同輸入到[公式]神經網絡層中,然后輸出信號 [公式][公式]是一個 [公式] 到[公式]之間的數值,並與 [公式] 相乘來決定 [公式]中的哪些信息將被保留,哪些信息將被舍棄。

(2)記憶門

記憶門的作用與遺忘門相反,它將決定新輸入的信息 [公式] 和 [公式] 中哪些信息將被保留。

如圖所示,記憶門包含2個部分。第一個是包含[公式]神經網絡層(輸入門,神經網絡網絡參數為 [公式])和一個 [公式] 神經網絡層(神經網絡參數為 [公式])。

  • [公式]神經網絡層的作用很明顯,跟遺忘門一樣,它接收 [公式] 和 [公式] 作為輸入,然后輸出一個 [公式] 到 [公式] 之間的數值 [公式] 來決定哪些信息需要被更新;
  • Tanh神經網絡層的作用是將輸入的 [公式] 和 [公式] 整合,然后通過一個[公式]神經網絡層來創建一個新的狀態候選向量 [公式] ,[公式] 的值范圍在 [公式] 到 [公式] 之間。

記憶門的輸出由上述兩個神經網絡層的輸出決定,[公式] 與 [公式] 相乘來選擇哪些信息將被新加入到 [公式] 時刻的細胞狀態 [公式] 中。

(3)更新細胞狀態

有了遺忘門和記憶門,我們就可以更新細胞狀態 [公式] 了。

這里將遺忘門的輸出 [公式] 與上一時刻的細胞狀態 [公式] 相乘來選擇遺忘和保留一些信息,將記憶門的輸出與從遺忘門選擇后的信息加和得到新的細胞狀態 [公式]。這就表示 [公式] 時刻的細胞狀態 [公式] 已經包含了此時需要丟棄的 [公式] 時刻傳遞的信息和 [公式] 時刻從輸入信號獲取的需要新加入的信息 [公式] 。[公式] 將繼續傳遞到 [公式] 時刻的LSTM網絡中,作為新的細胞狀態傳遞下去。

(4)輸出門

前面已經講了LSTM如何來更新細胞狀態 [公式], 那么在 [公式] 時刻我們輸入信號 [公式] 以后,對應的輸出信號該如何計算呢?

如上面左圖所示,輸出門就是將[公式]時刻傳遞過來並經過了前面遺忘門與記憶門選擇后的細胞狀態[公式], 與 [公式] 時刻的輸出信號 [公式] 和 [公式] 時刻的輸入信號 [公式] 整合到一起作為當前時刻的輸出信號。整合的過程如上圖所示,[公式] 和 [公式] 經過一個[公式]神經網絡層(神經網絡參數為 [公式])輸出一個 [公式] 到 [公式] 之間的數值 [公式][公式] 經過一個[公式]函數(注意:這里不是 [公式] 神經網絡層)到一個在 [公式] 到 [公式] 之間的數值,並與[公式] 相乘得到輸出信號 [公式] ,同時 [公式] 也作為下一個時刻的輸入信號傳遞到下一階段。

其中,[公式]函數是激活函數的一種,函數圖像為:

參考文獻:

https://zhuanlan.zhihu.com/p/27345523

https://zhuanlan.zhihu.com/p/104475016


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM