kaldi中CD-DNN-HMM網絡參數更新公式手寫推導


在基於DNN-HMM的語音識別中,DNN的作用跟GMM是一樣的,即它是取代GMM的,具體作用是算特征值對每個三音素狀態的概率,算出來哪個最大這個特征值就對應哪個狀態。只不過以前是用GMM算的,現在用DNN算了。這是典型的多分類問題,所以輸出層用的激活函數是softmax,損失函數用的是cross entropy(交叉熵)。不用均方差做損失函數的原因是在分類問題上它是非凸函數,不能保證全局最優解(只有凸函數才能保證全局最優解)。Kaldi中也支持DNN-HMM,它還依賴於上下文(context dependent, CD),所以叫CD-DNN-HMM。在kaldi的nnet1中,特征提取用filterbank,每幀40維數據,默認取當前幀前后5幀加上當前幀共11幀作為輸入,所以輸入層維數是440(440 = 40*11)。同時默認有4個隱藏層,每層1024個網元,激活函數是sigmoid。今天我們看看網絡的各種參數是怎么得到的(手寫推導)。由於真正的網絡比較復雜,為了推導方便這里對其進行了簡化,只有一個隱藏層,每層的網元均為3,同時只有weight沒有bias。這樣網絡如下圖:

上圖中輸入層3個網元為i1/i2/i3(i表示input),隱藏層3個網元為h1/h2/h3(h表示hidden),輸出層3個網元為o1/o2/o3(o表示output)。隱藏層h1的輸入為 (q11等表示輸入層和隱藏層之間的權值),輸出為。輸出層o1的輸入為(w11等表示隱藏層和輸出層之間的權值),輸出為。其他可類似推出。損失函數用交叉熵。今天我們看看網絡參數(以隱藏層和輸出層之間的w11以及輸入層和隱藏層之間的q11為例)在每次迭代訓練后是怎么更新的。先看隱藏層和輸出層之間的w11。

 

1,隱藏層和輸出層之間的w11的更新

 

 先分別求三個導數的值:

 

 所以最終的w11更新公式如下圖:

 

2,輸入層和隱藏層之間的q11的更新

 

先分別求三個導數的值:

 

所以最終的q11更新公式如下圖:

 

以上的公式推導中如有錯誤,煩請指出,非常感謝!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM