多頭Attention 和 自注意力機制


這個多頭attention確實挺搞的,這個東西繞來繞去,看torch的文檔也看不懂,看源碼也迷迷糊糊的,可能我的智商就是不夠吧。。。枯了

論文里的公式求法,可以看到它因為是self-multiheadsAttention。多頭自注意力機制,所以它這里的Q K V 實際上是同一個東西,也就是最后一維都是相同的。

為什么這里可以直接concat起來,是因為它將Q、K、V最后一維都進行了切割,也就是說,它的多頭attention不是說使用多個attention weight,而是說對不同part部分進行attention。比如論文將Q、K、V最后一個維度切成了8塊,它的8頭attention,就是每個attention就對這一塊部分進行attention機制,最后進行concat。這也是一個有意思的點,這樣就直接用點積attention來一次矩陣乘法就行了。

image-20211119095811841

這里有個參考的回答:

為什么切割方式求attention

這里有兩張參考的圖片:

image-20211119105801580

img

torch 文檔

image-20211119105306596

這里的embed_dim 就是后面Q的dim(最后一維)也就是詞向量的維度,這是模型輸出的維度,默認q、k、v的最后維度一致。

image-20211119105942893

key_padding_mask 是padding mask 是掩key的

attn_mask 是掩key_value pair的。

這么說可能很難理解,key_padding_mask就是說句子序列中有多少個padding,這些padding是不要的。但是attn_mask 是用來說,我不能提前看到后面的詞。(這個還是在自注意力機制用到),因為transformer的decoder第一層的自注意力層不能看到未來的詞。

自己實現多頭注意力機制

import torch
import torch.nn as nn
import math
from d2l import torch as d2l
class MultiAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, qdim=None, kdim=None, vdim=None, hdim=256, dropout=0.0) -> None:
        super(MultiAttention, self).__init__()
        self.attention = d2l.DotProductAttention(dropout)
        self.num_heads = num_heads
        nn.MultiheadAttention()
        # 先做一個全連接層好把Q、K、V不同維度轉為同一維度
        self.W_q = nn.Linear(qdim, embed_dim)
        self.W_k = nn.Linear(kdim, embed_dim)
        self.W_v = nn.Linear(vdim, embed_dim)
        self.W_o = nn.Linear(embed_dim, emded_dim)
        
    def forward(self, Q, K, V):
        # 注意這里的Q 的shape (batchsize, qn, qdim)
        # K (batchsize, kvn, kdim)
        # V (batchsize, kvn, vdim)
        Q = self.trans(self.W_q(Q), self.num_heads)
        K = self.trans(self.W_k(K), self.num_heads)
        V = self.trans(self.W_v(V), self.num_heads)
        # Q (batchsize *numheads, qn, embed_dim/num_heads)
        output = self.attention(Q, K, V)
        # output shape (batchsize*num_heads, qn, kvn, embed/num_heads)
        # 這里沒有返回attentionweight,但attentionweight的shape (batchsize*num_heads, qn, kvn)
        # output最后一維的embed/num_heads,是因為我們將V的最后一維切割了。
        output = self.retrans(output, self.num_heads) 
        return self.W_o(output)
    
    def trans(self, X, num_heads):
        # X shape (batchsize, 查詢或者‘鍵值對’數, embed_dim)
        X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        X = X.reshape(-1, X.shape[2], X.shape[3])
        return X
    
    def retrans(self, X, num_heads):
        # X shape (batchsize*num_heads, 查詢或者‘鍵值對’數, embed_dim/num_heads)
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        X = X.reshape(X.shape[0], X.shape[1], -1)
        return X

attention = MultiAttention(256, 2, 64, 64, 64)
q= k= v= torch.ones((32, 35, 64))
s = attention(q, k, v)

這里我其實寫的很不標准,因為幾個全連接搞得挺混亂的。但其實思想也是一致的。

自注意力機制

image-20211119110854322

可以看到RNN是很有時序性的,它是要求一個一個輸入。CNN也可以保留一定的時序性,因為卷積核的感受野可以保留部分時序信息。但是self-attention機制是完全沒有時序性的,它一次就可以看完全部。

位置編碼

這里就引入了位置編碼這個概念:

X = X + P其中P就是位置編碼,對應的值:

\[\begin{aligned} p_{i, 2 j} &=\sin \left(\frac{i}{10000^{2 j / d}}\right) \\ p_{i, 2 j+1} &=\cos \left(\frac{i}{10000^{2 j / d}}\right) \end{aligned} \]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM