Transformer中的維度變換


自己總結記錄一下transformer中的維度變換

對於輸入

input: [batch_size * max_sen_len]

詞嵌入矩陣

vocab_matrix dim: [vocab_size * embedding_dim]

位置編碼

PE(pos,2i)=sin(pos/10000^(2i/embedding_dim))
PE(pos,2i+1)=cos(pos/10000^(2i/embedding_dim))

encoder input embedding x = input token emb + position emb :
[batch_size * max_sen_len * embedding_dim]

對每一句話(句尾</s>):[ max_sen_len * embedding_dim ]

ENCODER

流程:

input -> dropout ->
(multihead SAN -> attention dropout -> residual connection -> LN -> FFN -> dropout -> RS connection-> LN) * 6 ->
[batch_size, max_sen_len, embedding_dim]

---- multihead self atten ----
WQ,WK,WV: embedding_dim * embedding_dim,
其中WQ, WK, WV可以切分為多頭WQ_i, Wk_i, WV_i, 即第二個維度 = embedding_dim//num_heads=d_k
WQ_i,Wk_i,WV_i: embedding_dim * d_k
q_i,k_i,v_i = x * (WQ_i,WK_i,WV_i) : max_sen_len * d_k

weight compute:
q_i * k_i / sqrt(d_k) : [max_sen_len * max_sen_len]

(softmax之前要對q和k做mask,把pad 0的維度置為-inf,這樣softmax之后對應位置權重為0)

softmax(q_i * k_i / sqrt(d_k) + Mask) * v_i = head_i, 在最后一個維度上做softmax
head_i: [max_sen_len * d_k]
Multi_head = concat num_heads of head_i = [head_1,head_2,...,head_8]: [max_sen_len * embedding_dim]
W_outlayer : [ embedding_dim , embedding_dim ]
#context = Multi_head * W_outlayer :[max_sen_len * embedding_dim]

---- add & norm ----
[max_sen_len * embedding_dim]

----ffn & add & norm ----
ffn = Relu(W_1 * x + b_1) * W_2 +b_2
Relu = max(0,x)
W_1 : [embedding_dim * ffn_hidden_size]
b_1 : [1 * ffn_hidden_size ]
W_2 : [ffn_hidden_size * embedding_dim]
b_2 : [1 * embedding_dim]

---- enc out ----
[batch_size, max_sen_len, embedding_dim]

DECODER

流程:

decoder input -> droput ->
(masked multihead self atten -> attention dropout -> RS connection-> LN ->
multihead self atten -> dropout -> RS connection-> LN ->
FFN -> dropout -> RS connection-> LN) *6 ->
[batch_size, max_sen_len, vocab_size]

decoder input embedding y = input token emb + position emb :
[ batch_size * max_sen_len, embedding_dim]
對每一句話y(要添加起始符號<s>) : [ max_sen_len * embedding_dim ]

ENCODER的輸出給每一層DECODER
---- masked multihead self atten ----
上三角矩陣置為-inf
q,k 來自encoder輸出:[max_sen_len, embedding_dim]
q_i,k_i,v_i = y * WQ_i,WK_i,WV_i : [max_sen_len * d_k]

weight compute:
q_i * k_i / srqt(d_k) : [max_sen_len * max_sen_len]

softmax(q_i * k_i / sqrt(d_k) + Mask) * v_i = head_i : [max_sen_len * d_k]
Multi_head = concat num_heads of head_i = [head_1,head_2,...,head_8]: [max_sen_len * embedding_dim]
W_outlayer : [ embedding_dim , embedding_dim ]
#context = Multi_head * W_outlayer :[max_sen_len * embedding_dim]

---- multihead self att ----
維度變換同上
[max_sen_len * embedding_dim]

---- add & norm ----
[max_sen_len * embedding_dim]

----ffn & add & norm ----
ffn = Relu(W_1 * y + b_1) * W_2 +b_2
Relu = max(0,y)
W_1 : [embedding_dim * ffn_hidden_size]
b_1 : [1 * ffn_hidden_size]
W_2 : [ffn_hidden_size * vocab_size]
b_2 : [1 * vocab_size]

[batch_size, max_sen_len, vocab_size]

---- dec out ----
[batch_size, max_sen_len, vocab_size]
decoder輸出隱藏層變量,先乘以線性矩陣,再在最后一維做softmax(vocab_size維),得到詞典庫上的概率分布,
輸出最大的概率,與真實標簽進行交叉熵損失的計算,匯總一句話中每個的損失,優化,訓練


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM