注意力機制和Seq2Seq模型
1. 基本概念
Attention 是一種通用的帶權池化方法,輸入由兩部分構成:詢問(query)和鍵值對(key-value pairs)。\(𝐤_𝑖∈ℝ^{𝑑_𝑘}, 𝐯_𝑖∈ℝ^{𝑑_𝑣}\). Query \(𝐪∈ℝ^{𝑑_𝑞}\) , attention layer得到輸出與value的維度一致 \(𝐨∈ℝ^{𝑑_𝑣}\). 對於一個query來說,attention layer 會與每一個key計算注意力分數並進行權重的歸一化,輸出的向量\(o\)則是value的加權求和,而每個key計算的權重與value一一對應。
為了計算輸出,我們首先假設有一個函數\(\alpha\) 用於計算query和key的相似性,然后可以計算所有的 attention scores \(a_1, \ldots, a_n\) by
我們使用 softmax函數 獲得注意力權重:
最終的輸出就是value的加權求和:
不同的attetion layer的區別在於score函數的選擇,下面主要討論兩個常用的注意層 Dot-product Attention 和 Multilayer Perceptron Attention
Softmax屏蔽
在深入研究實現之前,首先介紹softmax操作符的一個屏蔽操作,主要目的是屏蔽無關信息。
def SequenceMask(X, X_len,value=-1e6):
maxlen = X.size(1)
#print(X.size(),torch.arange((maxlen),dtype=torch.float)[None, :],'\n',X_len[:, None] )
mask = torch.arange((maxlen),dtype=torch.float)[None, :] >= X_len[:, None]
#print(mask)
X[mask]=value
return X
def masked_softmax(X, valid_length):
# X: 3-D tensor, valid_length: 1-D or 2-D tensor
softmax = nn.Softmax(dim=-1)
if valid_length is None:
return softmax(X)
else:
shape = X.shape
if valid_length.dim() == 1:
try:
valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0))#[2,2,3,3]
except:
valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))#[2,2,3,3]
else:
valid_length = valid_length.reshape((-1,))
# fill masked elements with a large negative, whose exp is 0
X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)
return softmax(X).reshape(shape)
masked_softmax(torch.rand((2,2,4),dtype=torch.float), torch.FloatTensor([2,3]))
tensor([[[0.4047, 0.5953, 0.0000, 0.0000],
[0.4454, 0.5546, 0.0000, 0.0000]],
[[0.3397, 0.3389, 0.3213, 0.0000],
[0.3526, 0.3318, 0.3156, 0.0000]]])
超出2維矩陣的乘法
\(X\) 和 \(Y\) 是維度分別為\((b,n,m)\) 和\((b, m, k)\)的張量,進行 \(b\) 次二維矩陣乘法后得到 \(Z\), 維度為 \((b, n, k)\)。
torch.bmm(torch.ones((2,1,3), dtype = torch.float), torch.ones((2,3,2), dtype = torch.float))
tensor([[[3., 3.]],
[[3., 3.]]])
2. 兩種常用的attention層
2.1點積注意力
The dot product 假設query和keys有相同的維度, 即 $\forall i, 𝐪,𝐤_𝑖 ∈ ℝ_𝑑 $. 通過計算query和key轉置的乘積來計算attention score,通常還會除去 \(\sqrt{d}\) 減少計算出來的score對維度𝑑的依賴性,如下
假設 $ 𝐐∈ℝ^{𝑚×𝑑}$ 有 \(m\) 個query,\(𝐊∈ℝ^{𝑛×𝑑}\) 有 \(n\) 個keys. 我們可以通過矩陣運算的方式計算所有 \(mn\) 個score:
下面來實現這個層,它支持一批查詢和鍵值對。此外,它支持作為正則化隨機刪除一些注意力權重.
# Save to the d2l package.
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
# query: (batch_size, #queries, d)
# key: (batch_size, #kv_pairs, d)
# value: (batch_size, #kv_pairs, dim_v)
# valid_length: either (batch_size, ) or (batch_size, xx)
def forward(self, query, key, value, valid_length=None):
d = query.shape[-1]
# set transpose_b=True to swap the last two dimensions of key
scores = torch.bmm(query, key.transpose(1,2)) / math.sqrt(d)
attention_weights = self.dropout(masked_softmax(scores, valid_length))
print("attention_weight\n",attention_weights)
return torch.bmm(attention_weights, value)
測試
創建兩個batch,每個batch有一個query和10個key-values對。通過valid_length指定,對於第一批,只關注前2個鍵-值對,而對於第二批,我們將檢查前6個鍵-值對。盡管這兩個批處理具有相同的查詢和鍵值對,但獲得的輸出是不同的。
atten = DotProductAttention(dropout=0)
keys = torch.ones((2,10,2),dtype=torch.float)
values = torch.arange((40), dtype=torch.float).view(1,10,4).repeat(2,1,1)
print(values.shape)
atten(torch.ones((2,1,2),dtype=torch.float), keys, values, torch.FloatTensor([2, 6]))
torch.Size([2, 10, 4])
attention_weight
tensor([[[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
0.0000, 0.0000]]])
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]])
2.2 多層感知機注意力
在多層感知器中,我們首先將 query and keys 投影到 \(ℝ^ℎ\) .為了更具體,我們將可以學習的參數做如下映射
\(𝐖_𝑘∈ℝ^{ℎ×𝑑_𝑘}\) , \(𝐖_𝑞∈ℝ^{ℎ×𝑑_𝑞}\) , and \(𝐯∈ℝ^h\) . 將score函數定義
.
然后將key 和 value 在特征的維度上合並(concatenate),然后送至 a single hidden layer perceptron 這層中 hidden layer 為 ℎ and 輸出的size為 1 .隱層激活函數為tanh,無偏置.
# Save to the d2l package.
class MLPAttention(nn.Module):
def __init__(self, units,ipt_dim,dropout, **kwargs):
super(MLPAttention, self).__init__(**kwargs)
# Use flatten=True to keep query's and key's 3-D shapes.
self.W_k = nn.Linear(ipt_dim, units, bias=False)
self.W_q = nn.Linear(ipt_dim, units, bias=False)
self.v = nn.Linear(units, 1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, valid_length):
query, key = self.W_k(query), self.W_q(key)
#print("size",query.size(),key.size())
# expand query to (batch_size, #querys, 1, units), and key to
# (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.
features = query.unsqueeze(2) + key.unsqueeze(1)
#print("features:",features.size()) #--------------開啟
scores = self.v(features).squeeze(-1)
attention_weights = self.dropout(masked_softmax(scores, valid_length))
return torch.bmm(attention_weights, value)
測試
盡管MLPAttention包含一個額外的MLP模型,但如果給定相同的輸入和相同的鍵,我們將獲得與DotProductAttention相同的輸出
atten = MLPAttention(ipt_dim=2,units = 8, dropout=0)
atten(torch.ones((2,1,2), dtype = torch.float), keys, values, torch.FloatTensor([2, 6]))
torch.Size([2, 1, 10])
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward>)
3. 帶注意力機制的Seq2seq模型
解碼器
在解碼的每個時間步,使用解碼器的最后一個RNN層的輸出作為注意層的query。然后,將注意力模型的輸出與輸入嵌入向量連接起來,輸入到RNN層。雖然RNN層隱藏狀態也包含來自解碼器的歷史信息,但是attention model的輸出顯式地選擇了enc_valid_len以內的編碼器輸出,這樣attention機制就會盡可能排除其他不相關的信息。
class Seq2SeqAttentionDecoder(d2l.Decoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
self.attention_cell = MLPAttention(num_hiddens,num_hiddens, dropout)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.LSTM(embed_size+ num_hiddens,num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Linear(num_hiddens,vocab_size)
def init_state(self, enc_outputs, enc_valid_len, *args):
outputs, hidden_state = enc_outputs
# print("first:",outputs.size(),hidden_state[0].size(),hidden_state[1].size())
# Transpose outputs to (batch_size, seq_len, hidden_size)
return (outputs.permute(1,0,-1), hidden_state, enc_valid_len)
#outputs.swapaxes(0, 1)
def forward(self, X, state):
enc_outputs, hidden_state, enc_valid_len = state
#("X.size",X.size())
X = self.embedding(X).transpose(0,1)
# print("Xembeding.size2",X.size())
outputs = []
for l, x in enumerate(X):
# print(f"\n{l}-th token")
# print("x.first.size()",x.size())
# query shape: (batch_size, 1, hidden_size)
# select hidden state of the last rnn layer as query
query = hidden_state[0][-1].unsqueeze(1) # np.expand_dims(hidden_state[0][-1], axis=1)
# context has same shape as query
# print("query enc_outputs, enc_outputs:\n",query.size(), enc_outputs.size(), enc_outputs.size())
context = self.attention_cell(query, enc_outputs, enc_outputs, enc_valid_len)
# Concatenate on the feature dimension
# print("context.size:",context.size())
x = torch.cat((context, x.unsqueeze(1)), dim=-1)
# Reshape x to (1, batch_size, embed_size+hidden_size)
# print("rnn",x.size(), len(hidden_state))
out, hidden_state = self.rnn(x.transpose(0,1), hidden_state)
outputs.append(out)
outputs = self.dense(torch.cat(outputs, dim=0))
return outputs.transpose(0, 1), [enc_outputs, hidden_state,
enc_valid_len]
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8,
num_hiddens=16, num_layers=2)
# encoder.initialize()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
num_hiddens=16, num_layers=2)
X = torch.zeros((4, 7),dtype=torch.long)
print("batch size=4\nseq_length=7\nhidden dim=16\nnum_layers=2\n")
print('encoder output size:', encoder(X)[0].size())
print('encoder hidden size:', encoder(X)[1][0].size())
print('encoder memory size:', encoder(X)[1][1].size())
state = decoder.init_state(encoder(X), None)
out, state = decoder(X, state)
out.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
batch size=4
seq_length=7
hidden dim=16
num_layers=2
encoder output size: torch.Size([7, 4, 16])
encoder hidden size: torch.Size([2, 4, 16])
encoder memory size: torch.Size([2, 4, 16])
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([2, 4, 16]))