這個多頭attention確實挺搞的,這個東西繞來繞去,看torch的文檔也看不懂,看源碼也迷迷糊糊的,可能我的智商就是不夠吧。。。枯了
論文里的公式求法,可以看到它因為是self-multiheadsAttention。多頭自注意力機制,所以它這里的Q K V
實際上是同一個東西,也就是最后一維都是相同的。
為什么這里可以直接concat起來,是因為它將Q、K、V最后一維都進行了切割,也就是說,它的多頭attention不是說使用多個attention weight,而是說對不同part部分進行attention。比如論文將Q、K、V最后一個維度切成了8塊,它的8頭attention,就是每個attention就對這一塊部分進行attention機制,最后進行concat。這也是一個有意思的點,這樣就直接用點積attention來一次矩陣乘法就行了。
這里有個參考的回答:
這里有兩張參考的圖片:
torch 文檔
這里的embed_dim 就是后面Q的dim(最后一維)也就是詞向量的維度,這是模型輸出的維度,默認q、k、v的最后維度一致。
key_padding_mask 是padding mask 是掩key的
attn_mask 是掩key_value pair的。
這么說可能很難理解,key_padding_mask就是說句子序列中有多少個padding,這些padding是不要的。但是attn_mask 是用來說,我不能提前看到后面的詞。(這個還是在自注意力機制用到),因為transformer的decoder第一層的自注意力層不能看到未來的詞。
自己實現多頭注意力機制
import torch
import torch.nn as nn
import math
from d2l import torch as d2l
class MultiAttention(nn.Module):
def __init__(self, embed_dim, num_heads, qdim=None, kdim=None, vdim=None, hdim=256, dropout=0.0) -> None:
super(MultiAttention, self).__init__()
self.attention = d2l.DotProductAttention(dropout)
self.num_heads = num_heads
nn.MultiheadAttention()
# 先做一個全連接層好把Q、K、V不同維度轉為同一維度
self.W_q = nn.Linear(qdim, embed_dim)
self.W_k = nn.Linear(kdim, embed_dim)
self.W_v = nn.Linear(vdim, embed_dim)
self.W_o = nn.Linear(embed_dim, emded_dim)
def forward(self, Q, K, V):
# 注意這里的Q 的shape (batchsize, qn, qdim)
# K (batchsize, kvn, kdim)
# V (batchsize, kvn, vdim)
Q = self.trans(self.W_q(Q), self.num_heads)
K = self.trans(self.W_k(K), self.num_heads)
V = self.trans(self.W_v(V), self.num_heads)
# Q (batchsize *numheads, qn, embed_dim/num_heads)
output = self.attention(Q, K, V)
# output shape (batchsize*num_heads, qn, kvn, embed/num_heads)
# 這里沒有返回attentionweight,但attentionweight的shape (batchsize*num_heads, qn, kvn)
# output最后一維的embed/num_heads,是因為我們將V的最后一維切割了。
output = self.retrans(output, self.num_heads)
return self.W_o(output)
def trans(self, X, num_heads):
# X shape (batchsize, 查詢或者‘鍵值對’數, embed_dim)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3)
X = X.reshape(-1, X.shape[2], X.shape[3])
return X
def retrans(self, X, num_heads):
# X shape (batchsize*num_heads, 查詢或者‘鍵值對’數, embed_dim/num_heads)
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
X = X.reshape(X.shape[0], X.shape[1], -1)
return X
attention = MultiAttention(256, 2, 64, 64, 64)
q= k= v= torch.ones((32, 35, 64))
s = attention(q, k, v)
這里我其實寫的很不標准,因為幾個全連接搞得挺混亂的。但其實思想也是一致的。
自注意力機制
可以看到RNN是很有時序性的,它是要求一個一個輸入。CNN也可以保留一定的時序性,因為卷積核的感受野可以保留部分時序信息。但是self-attention機制是完全沒有時序性的,它一次就可以看完全部。
位置編碼
這里就引入了位置編碼這個概念:
X = X + P
其中P就是位置編碼,對應的值: