xconfig example
num_targets=3766
learning_rate_factor=20
dir=`mktemp -d`
mkdir -p $dir/configs
cat <<EOF > $dir/configs/network.xconfig
input dim=71 name=input
attention-relu-renorm-layer name=attention1 num-heads=5 value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 time-stride=3
output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
EOF
(cd ~/kaldi/egs/wsj/s5;steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/)
config example
component name=attention1.attention type=RestrictedAttentionComponent value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 num-left-inputs-required=-1 num-right-inputs-required=-1 output-context=True time-stride=3 num-heads=5 key-scale=0.158113883008
component-node name=attention1.attention component=attention1.attention input=attention1.affine
raw.txt example
<ComponentName> attention1.attention <RestrictedAttentionComponent> <NumHeads> 5 <KeyDim> 20 <ValueDim> 40 <NumLeftInputs> 5 <NumRightInputs> 2 <TimeStride> 3 <NumLeftInputsRequired> 5 <NumRightInputsRequired> 2 <OutputContext> T <KeyScale> 0.1581139 <StatsCount> 0 <EntropyStats> [ ] <PosteriorStats> [ ] </RestrictedAttentionComponent>
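Topology
As the topology shows, the Kaldi nnet3 RestrictedAttentionComponent acts as a nonlinearity layer: it takes the output of the preceding affine component (attention1.affine) as its input.
gdb example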
$ gdb -d ~/kaldi/src/nnet3 --args nnet3-compute ref.raw ark,t:/tmp/feat ark:/dev/null
(gdb) rb kaldi::nnet3::.*::Propagate
(gdb) run
Breakpoint 3, kaldi::nnet3::AffineComponent::Propagate (this=0x11a6ec80, indexes=0x0, in=...,
out=0x7fffffffb790) at nnet-simple-component.cc:1236
The input is a 71x71 matrix (71 frames, each with 71 feature dimensions):
(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()
71, 71, 71, 440
(gdb) c
Breakpoint 43, kaldi::nnet3::RestrictedAttentionComponent::Propagate (this=0x12f07e40, indexes_in=0x12f09a00,
in=..., out=0x7fffffffb790) at nnet-attention-component.cc:134
The input is a 71x440 matrix, split into 5 heads:
Head 1 | Head 2 | Head 3 | Head 4 | Head 5
71x88 | 71x88 | 71x88 | 71x88 | 71x88
(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()
71, 440, 50, 240
At this point attention is computed separately for each head, in PropagateOneHead:
(gdb) c
Breakpoint 44, kaldi::nnet3::RestrictedAttentionComponent::PropagateOneHead (this=this@entry=0x12f07e40, io=..., in=..., c=c@entry=0x7fffffffb630, out=out@entry=0x7fffffffb650) at nnet-attention-component.cc:164
164	    CuMatrixBase<BaseFloat> *out) const {
(gdb) printf "%d, %d, %d, %d\n", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()
71, 88, 50, 48
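The column counts seen in gdb (440 and 240 for the whole component, 88 and 48 per head) can be reproduced from the config values. The sketch below is only an illustration: the function name `attention_dims` is made up, and the formulas (one attention weight per attended position, a query of `key-dim` plus that many positional terms, and the weights appended to the output when `output-context` is true) are assumptions that happen to match the numbers above, not code taken from Kaldi.

```python
def attention_dims(num_heads, key_dim, value_dim,
                   num_left_inputs, num_right_inputs, output_context=True):
    """Return the assumed per-head and total dimensions (hypothetical helper)."""
    # One attention weight per attended position: left inputs, current frame, right inputs.
    context_dim = num_left_inputs + 1 + num_right_inputs
    # Assumption: the query holds key_dim values plus one positional term per position.
    query_dim = key_dim + context_dim
    in_per_head = key_dim + value_dim + query_dim
    # Assumption: with output-context, the attention weights are appended to the values.
    out_per_head = value_dim + (context_dim if output_context else 0)
    return {
        "context_dim": context_dim,
        "query_dim": query_dim,
        "in_per_head": in_per_head,
        "out_per_head": out_per_head,
        "input_dim": num_heads * in_per_head,
        "output_dim": num_heads * out_per_head,
    }


dims = attention_dims(num_heads=5, key_dim=20, value_dim=40,
                      num_left_inputs=5, num_right_inputs=2)
print(dims["in_per_head"], dims["out_per_head"])  # 88 48
print(dims["input_dim"], dims["output_dim"])      # 440 240
```

With the values from the xconfig, each head reads 20 + 40 + 28 = 88 columns and writes 40 + 8 = 48 columns, which agrees with the `71, 88, 50, 48` printed by gdb.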
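The 71 frames consist of:
- num-left-inputs*time-stride=5*3=15 frames of left context, which produce no output
- the 50 frames in the middle, which produce output
- num-right-inputs*time-stride=2*3=6 frames of right context, which produce no output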
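The frame bookkeeping above is plain arithmetic, sketched here with a hypothetical helper `split_frames` (the name and signature are not part of Kaldi):

```python
def split_frames(num_input_frames, num_left_inputs, num_right_inputs, time_stride):
    """Split the input frames into (left context, output frames, right context)."""
    left = num_left_inputs * time_stride    # frames consumed as left context
    right = num_right_inputs * time_stride  # frames consumed as right context
    middle = num_input_frames - left - right
    if middle <= 0:
        raise ValueError("not enough input frames for the requested context")
    return left, middle, right


left, middle, right = split_frames(71, num_left_inputs=5,
                                   num_right_inputs=2, time_stride=3)
print(left, middle, right)  # 15 50 6
```

This is why a 71-row input yields the 50-row output seen in the gdb session.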
A worked example of the PropagateOneHead computation:
The computation flow of the whole RestrictedAttentionComponent:
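The flow can also be sketched in plain Python. This is a minimal illustration under stated assumptions, not Kaldi's implementation: the per-head column layout (key, value, query), the contiguous split of columns into heads, and the way `key_scale` is applied only to the query-key dot product before a per-position term from the query is added are all assumptions. The function names are hypothetical. Only the shapes (71x440 in, 50x240 out) are taken from the gdb session.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def propagate_one_head(rows, key_dim, value_dim, num_left, num_right,
                       time_stride, key_scale, output_context=True):
    """Restricted attention for one head; rows[t] is assumed to be (key, value, query)."""
    context_dim = num_left + 1 + num_right
    left, right = num_left * time_stride, num_right * time_stride
    out = []
    for t in range(left, len(rows) - right):  # only the middle frames produce output
        query = rows[t][key_dim + value_dim:]
        q_key, q_pos = query[:key_dim], query[key_dim:]
        scores, values = [], []
        for i in range(context_dim):
            s = t + (i - num_left) * time_stride  # attended frame index
            key = rows[s][:key_dim]
            dot = sum(a * b for a, b in zip(q_key, key))
            scores.append(key_scale * dot + q_pos[i])  # assumed scoring rule
            values.append(rows[s][key_dim:key_dim + value_dim])
        c = softmax(scores)
        u = [sum(c[i] * values[i][d] for i in range(context_dim))
             for d in range(value_dim)]
        out.append(u + (c if output_context else []))
    return out


def propagate(rows, num_heads, **kwargs):
    """Split columns into heads, attend per head, concatenate the head outputs."""
    per_head = len(rows[0]) // num_heads
    heads = [propagate_one_head([r[h * per_head:(h + 1) * per_head] for r in rows],
                                **kwargs)
             for h in range(num_heads)]
    return [sum((head[t] for head in heads), []) for t in range(len(heads[0]))]


random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(440)] for _ in range(71)]
y = propagate(x, num_heads=5, key_dim=20, value_dim=40, num_left=5,
              num_right=2, time_stride=3, key_scale=0.1581139)
print(len(y), len(y[0]))  # 50 240
```

In this sketch the last 8 columns of each head's 48 output columns are the softmax weights over the 8 attended positions, so they sum to 1 for every output frame.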