語音識別算法閱讀之speechTransformer


 論文:
  SPEECH-TRANSFORMER: A NO-RECURRENCE SEQUENCE-TO-SEQUENCE MODELFOR SPEECH RECOGNITION
思路:
  1)整體采用seq2seq的encoder和decoder架構;
  2)借助transformer對文本位置信息進行學習;
  3)相對於RNN,transformer可並行化訓練,加速了訓練過程;
  4)論文提出了2D-attention結構,能夠對時域和頻域兩個維度進行建模,使得attention結構能更好的捕獲時域和空間域信息
模型:
  speech-transformer 整體采用encoder和decoder結構,其中encoder和decoder的主要模塊都是multi-head attention和feed-forward network;此外,encoder為更好的對時域和空域不變性建模,還額外添加了conv結構和2D-attention
  • conv:encoder采用了兩層3*3,stride=2的conv,對時域和頻域進行卷積,一方面提升模型學習時域信息能力;另一方面縮減時間片維度至跟目標輸出長度相近,節約計算和緩解特征序列和目標序列長度長度不匹配問題;conv的激活為ReLU
  • multi-head attention: encoder和decoder都采用了多層multi-head attention來獲取區分性更強的隱層表達(不同的head采用的變換不同,最后將不同變換后輸出進行拼接,思想有點類似於模型融合);multi-head attention結構由多個並行的scaled dot-product attention組成,在訓練時可並行計算
  • scaled dot-product attention:結構有三個輸入Q1(tq*dq),K1(tk*dk),V1(tv*dv);輸出維度為tq*dv;基本思想類似於attention的注意力機制,Q跟K的運算softmax(QKT)可以看作是計算相應的權重因子,用來衡量V中各個維度特征的重要性;縮放因子√dk的作用在論文中提到是為了緩解當dk過大時帶來的softmax剃度過小問題;mask分為padding mask和掩蔽mask,前者主要是用於解決padding后的序列中0造成的影響,后者主要是解碼階段不能看到未來的文本信息
√d k的解釋在[1]中為:
  • multi-head attention:結構有三個輸入Q0(tq*dmodel)、K0(tk*dmodel)、V0(tv*dmodel),分別經過h次不同的線性變換(WQi(dmodel*dq)、WKi(dmodel*dk)、WVi(dmodel*dv),i=1,2,3...,h),輸入到h個分支scaled dot-product attention,各個分支的輸出維度為tq*dv(dv=dmodel/h),這樣經過concat后維度變成tq*hdv,再經過最后的線性層WO(hdv*dmodel)之后就得到了最終的tq*dmodel
 注意這里的Q、K、V與scaled dot-product attention不等價,所以我這里用Q 0、K 0、V 0以作區分
  • feed-forward network:前饋網絡包含一個全連接層和一個線性層,全連接層激活為ReLU
 其中W 1(d model*d ff),W 2(d ff*d model),b 1(d ff),b 2(d model)
  • Positional Encoding:因為transformer中不包含RNN和conv,所以其對序列的位置信息相對不敏感,於是在輸入時引入與輸入編碼相同維度的位置編碼,增強序列的相對位置和絕對位置信息。
 其中,pos代表序列位置,i表示特征的第i個維度,PE (pos,i)一方面可以衡量位置pos的絕對位置信息,另一方面因為 sin(a+b)=sina*cosb+cosa*sinb、cos(a+b)=sina*cosb+cosa*sinb,所以對於位置p的相對位置k,PE (pos+k)可以表示PE pos的線性變換;這樣同時引入了序列的絕對和相對位置信息。
  • resnet-block:論文中引入了resnet中的跳躍連接,將底層的輸入不經過網絡直接傳遞到上層,減緩信息的流失和提升訓練穩定性
  • Layer Norm:論文采用LN,對每一層的神經元輸入進行歸一化,加速訓練
 
 
 其中,l為網絡的第l層,H為第l層的神經元節點數,g,b分別為學習參數,使得特征變換成歸一化前的特性,f為激活函數,h為輸出。
  • 2D-attention:transformer中的attention結構僅僅針對時域的位置相關性進行建模,但是人類在預測發音時同時依賴時域和頻域變化,所以作者在此基礎上提出了2D-attention結構,即對時域和頻域的位置相關性均進行建模,有利於增強模型對時域和頻域的不變性。
  1. 輸入:c通道的輸入I
  2. 卷積層:三個卷積層分別作用域I,獲得相應的Q、K、V,濾波器分別為WQi、WKi、WVi(i=1,2,3...,c)
  3. 兩個multi-head scaled dot-product attention分別對時域和頻域進行建模,獲取相應的時間和頻域依賴,head數為c
  4. 對時域和頻域的attention輸出進行concat,得到2c通道的輸出,並輸入到卷積中得到n通道的輸出
其中W o為最后一層卷積,t為時間軸,f為頻率軸
訓練:
  • 數據集:WSJ(Wall Street Journal),train si284/dev dev93/ test dev92
  • 輸入特征:80維fbank
  • 輸出單元:輸出單元數為31,包含26個小寫字母以及撇號,句點,空格,噪聲和序列結束標記
  • GPU型號:NVIDIA K80 GPU 100steps
  • 優化方法和學習率:Adam
 
 其中,n為steps數,k為縮放因子,warmupn= 25000,k= 10
  • 訓練時,在每一個resnet-block和attention中添加dropout(丟棄率為0.1);其中,resnet-block的dropout在殘差添加之前;attention的dropout在每個softmax激活之前
  • 所有用到的conv的輸出通道數固定為64,且添加BN加速訓練過程
  • 訓練之后,對最后得到的10個模型進行平均,得到最終模型
  • 解碼時采用beam width=10 beam search,此外length歸一化的權重因子為1
實驗:
  在實驗中dmodel固定為256,head數固定為4
  • encoder的深度相比於decoder深度更有利於模型效果
  • 論文提出的2D-attention相比於ResCNN, ResCNNLSTM效果更好;表現2D-attention可以更好的對時域和頻域相關性進行建模
  • encoder使用較多層數的resnet-block時(比如12),額外添加2D-attention對識別效果沒有提升,分析原因是當encoder達到足夠深度后,對聲學信息的提取和表達能力以及足夠,更多是無益
  • 訓練時間上,相比於seq2seq結構,在取得相似的識別效果的同時,訓練速度提升4.25倍
  • 環境:pytorch>=1.20;Torchaudio >= 0.3.0
  • 輸入特征:40fbank,CMVN=False
  • spec-augment[3]:頻率掩蔽+時間掩蔽,忽略時間扭曲(復雜度大,提升不明顯)
  • 模型結構:
  1. encoder:2*conv(3*2)->1*linear+pos embedding->6*(multi-head attention(head=4, d_model=320)+ffn(1280))
  2. decoder:6*(multi-head attention(head=4, d_model=320)+ffn(1280))
  • 語言模型:4*(multi-head attention(head=4, d_model=320)+ffn(1280))
  • 訓練:
  1. 優化算法:adam
  2. 學習率策略:stepwise=12000,學習率在前warmup_n迭代步數線性上升,在n_step^(-0.5)迭代次數時停止下降
  3. clip_grad=5
  4. label smoothing[4]:平滑參數0.1
  5. batch=16
  6. epoch=80
  • 解碼:beam search(beam width=5)+長度懲罰(權重因子=0.6)
  • 實驗效果:aishell test cer:6.7%
Reference:
[1] Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf
[2] SPEECH-TRANSFORMER: A NO-RECURRENCE SEQUENCE-TO-SEQUENCE MODEL FOR SPEECH RECOGNITION http://150.162.46.34:8080/icassp2018/ICASSP18_USB/pdfs/0005884.pdf


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM