An Analysis of Kaldi's Attention (RestrictedAttentionComponent)


xconfig example

num_targets=3766

learning_rate_factor=20

dir=`mktemp -d`

mkdir -p $dir/configs

cat <<EOF > $dir/configs/network.xconfig

input dim=71 name=input

attention-relu-renorm-layer name=attention1 num-heads=5 value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 time-stride=3

output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5

EOF

(cd ~/kaldi/egs/wsj/s5;steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/)
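Before looking at the generated config, it is worth working out the dimensions these parameters imply. The short Python sketch below is not Kaldi code; it assumes the usual decomposition of each head's input into a key, a value, and a query whose dimension is key-dim plus the number of attention positions. The resulting numbers match what the GDB session further down reports.

num_heads, key_dim, value_dim = 5, 20, 40
num_left_inputs, num_right_inputs = 5, 2
output_context = True

context_dim = num_left_inputs + num_right_inputs + 1                 # 8 attention positions per output frame
query_dim = key_dim + context_dim                                    # 28
per_head_in = key_dim + value_dim + query_dim                        # 88
per_head_out = value_dim + (context_dim if output_context else 0)    # 48

print(num_heads * per_head_in)    # 440: output dim of attention1.affine / input dim of attention1.attention
print(num_heads * per_head_out)   # 240: output dim of attention1.attention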

config example

component name=attention1.attention type=RestrictedAttentionComponent value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 num-left-inputs-required=-1 num-right-inputs-required=-1 output-context=True time-stride=3 num-heads=5 key-scale=0.158113883008

component-node name=attention1.attention component=attention1.attention input=attention1.affine
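One detail in the generated component line: key-scale is the factor applied to the query-key dot products before the per-frame softmax, and the value xconfig writes here is, numerically,

key-scale = 0.158113883008 ≈ 1/√40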

raw.txt example

<ComponentName> attention1.attention <RestrictedAttentionComponent> <NumHeads> 5 <KeyDim> 20 <ValueDim> 40 <NumLeftInputs> 5 <NumRightInputs> 2 <TimeStride> 3 <NumLeftInputsRequired> 5 <NumRightInputsRequired> 2 <OutputContext> T <KeyScale> 0.1581139 <StatsCount> 0 <EntropyStats> [ ]

<PosteriorStats> [ ]

</RestrictedAttentionComponent>

Topology

From the topology it can be seen that the Kaldi nnet3 RestrictedAttentionComponent is effectively a nonlinear layer: it has no trainable parameters of its own and simply transforms the output of the preceding attention1.affine component.
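For orientation, the chain of components generated here looks roughly as follows. The affine and attention dimensions are confirmed by the GDB session below; the relu/renorm stages are the standard expansion implied by the attention-relu-renorm-layer name, and output.affine comes from the output-layer line, so treat this as a sketch rather than a dump of the actual config:

input (dim 71)
  -> attention1.affine      71 -> 440
  -> attention1.attention   440 -> 240   (RestrictedAttentionComponent)
  -> attention1.relu        240 -> 240
  -> attention1.renorm      240 -> 240
  -> output.affine          240 -> 3766  (num_targets)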

GDB example

$ gdb -d ~/kaldi/src/nnet3 --args nnet3-compute ref.raw ark,t:/tmp/feat ark:/dev/null

(gdb) rb kaldi::nnet3::.*::Propagate

(gdb) run

Breakpoint 3, kaldi::nnet3::AffineComponent::Propagate (this=0x11a6ec80, indexes=0x0, in=...,

out=0x7fffffffb790) at nnet-simple-component.cc:1236

The input here is a 71x71 matrix (71 frames of the 71-dimensional input features):

(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 71, 71, 440

(gdb) c

Breakpoint 43, kaldi::nnet3::RestrictedAttentionComponent::Propagate (this=0x12f07e40, indexes_in=0x12f09a00,

in=..., out=0x7fffffffb790) at nnet-attention-component.cc:134

The input is a 71x440 matrix, split evenly across the 5 heads (Head 1 through Head 5), each head receiving a 71x88 block of columns:

(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 440, 50, 240

At this point attention is computed for each head separately, i.e. in PropagateOneHead.

(gdb) c

Breakpoint 44, kaldi::nnet3::RestrictedAttentionComponent::PropagateOneHead (this=this@entry=0x12f07e40,

io=..., in=..., c=c@entry=0x7fffffffb630, out=out@entry=0x7fffffffb650) at nnet-attention-component.cc:164

164 CuMatrixBase<BaseFloat> *out) const {

(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

71, 88, 50, 48

The 71 input frames break down as follows (15 + 50 + 6 = 71):

  1. num-left-inputs × time-stride = 5 × 3 = 15 frames of left context, which produce no output
  2. the middle 50 frames, which do produce output
  3. num-right-inputs × time-stride = 2 × 3 = 6 frames of right context, which produce no output

An example of the PropagateOneHead computation is as follows.
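Below is a minimal NumPy sketch (not Kaldi code) of what one head computes, under simplifying assumptions: the per-head columns are taken in the order [keys | values | queries], the query is split into a key-sized part and a context-sized part that acts as a positional bias on the attention logits, the queries are read from the 50 rows that produce output, and the frame reordering that Kaldi's ConvolutionComputationIo may introduce is ignored.

import numpy as np

key_dim, value_dim = 20, 40
num_left, num_right, stride = 5, 2, 3
context_dim = num_left + num_right + 1          # 8
key_scale = 0.158113883008                      # value from the generated config

def propagate_one_head(head_in):
    # head_in: (num_in_frames, 88) -> returns (num_out_frames, 48), since output-context=True
    num_in = head_in.shape[0]                            # 71 in this example
    num_out = num_in - (num_left + num_right) * stride   # 71 - 15 - 6 = 50
    keys    = head_in[:, :key_dim]                               # (71, 20)
    values  = head_in[:, key_dim:key_dim + value_dim]            # (71, 40)
    queries = head_in[num_left * stride:num_left * stride + num_out,
                      key_dim + value_dim:]                      # (50, 28)
    q_key, q_ctx = queries[:, :key_dim], queries[:, key_dim:]    # (50, 20) and (50, 8)

    out = np.zeros((num_out, value_dim + context_dim))
    for i in range(num_out):
        # output frame i attends to input rows i, i+3, ..., i+21 (8 positions, stride 3)
        rows = i + stride * np.arange(context_dim)
        logits = key_scale * keys[rows] @ q_key[i] + q_ctx[i]    # (8,)
        w = np.exp(logits - logits.max())
        w /= w.sum()                                             # softmax over the 8 positions
        out[i, :value_dim] = w @ values[rows]                    # weighted sum of the attended values
        out[i, value_dim:] = w                                   # output-context=True appends the weights
    return out

head_out = propagate_one_head(np.random.randn(71, 88))
print(head_out.shape)    # (50, 48), matching the GDB printout above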

The computation logic of the whole RestrictedAttentionComponent is as follows.
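Under the same assumptions, the whole component is just the per-head computation applied to each 88-column block of the input, with the per-head outputs concatenated:

def propagate_restricted_attention(full_in, num_heads=5):
    # full_in: (71, 440) -> (50, 240), reusing propagate_one_head() from the sketch above
    per_head = full_in.shape[1] // num_heads                      # 88 columns per head
    head_outs = [propagate_one_head(full_in[:, h * per_head:(h + 1) * per_head])
                 for h in range(num_heads)]
    return np.concatenate(head_outs, axis=1)                      # (50, 5 * 48) = (50, 240)

print(propagate_restricted_attention(np.random.randn(71, 440)).shape)   # (50, 240)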

